Preprocess dataset size fix (#1131)
* overwrite cache on preprocess step
* don't cache the TokenizedPromptDataset at all
* load_from_cache_file no longer needed
This commit is contained in:
@@ -25,6 +25,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
parsed_cfg.is_preprocess = True
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
check_user_token()
|
||||||
parser = transformers.HfArgumentParser((PreprocessCliArgs))
|
parser = transformers.HfArgumentParser((PreprocessCliArgs))
|
||||||
|
|||||||
@@ -35,7 +35,10 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
):
|
):
|
||||||
self.prompt_tokenizer = prompt_tokenizer
|
self.prompt_tokenizer = prompt_tokenizer
|
||||||
self.process_count = process_count
|
self.process_count = process_count
|
||||||
super().__init__(self.process(dataset).data, **kwargs)
|
super().__init__(
|
||||||
|
self.process(dataset).data,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def process(self, dataset):
|
def process(self, dataset):
|
||||||
features = dataset.features.keys()
|
features = dataset.features.keys()
|
||||||
@@ -52,6 +55,7 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
self.prompt_tokenizer.tokenize_prompt,
|
self.prompt_tokenizer.tokenize_prompt,
|
||||||
num_proc=num_proc,
|
num_proc=num_proc,
|
||||||
remove_columns=features,
|
remove_columns=features,
|
||||||
|
keep_in_memory=True,
|
||||||
**map_kwargs,
|
**map_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -594,12 +594,16 @@ def get_dataset_wrapper(
|
|||||||
)
|
)
|
||||||
dataset_prompter = UnsupportedPrompter()
|
dataset_prompter = UnsupportedPrompter()
|
||||||
dataset_wrapper = TokenizedPromptDataset(
|
dataset_wrapper = TokenizedPromptDataset(
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
ds_strategy,
|
||||||
|
dataset,
|
||||||
|
process_count=cfg.dataset_processes,
|
||||||
)
|
)
|
||||||
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
||||||
dataset_prompter = UnsupportedPrompter()
|
dataset_prompter = UnsupportedPrompter()
|
||||||
dataset_wrapper = TokenizedPromptDataset(
|
dataset_wrapper = TokenizedPromptDataset(
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
ds_strategy,
|
||||||
|
dataset,
|
||||||
|
process_count=cfg.dataset_processes,
|
||||||
)
|
)
|
||||||
elif d_base_type == "alpaca":
|
elif d_base_type == "alpaca":
|
||||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||||
@@ -610,7 +614,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
ds_strategy,
|
||||||
|
dataset,
|
||||||
|
process_count=cfg.dataset_processes,
|
||||||
)
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "explainchoice":
|
elif d_base_type == "explainchoice":
|
||||||
@@ -622,7 +628,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
ds_strategy,
|
||||||
|
dataset,
|
||||||
|
process_count=cfg.dataset_processes,
|
||||||
)
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "concisechoice":
|
elif d_base_type == "concisechoice":
|
||||||
@@ -634,7 +642,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
ds_strategy,
|
||||||
|
dataset,
|
||||||
|
process_count=cfg.dataset_processes,
|
||||||
)
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "summarizetldr":
|
elif d_base_type == "summarizetldr":
|
||||||
@@ -646,7 +656,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
ds_strategy,
|
||||||
|
dataset,
|
||||||
|
process_count=cfg.dataset_processes,
|
||||||
)
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "jeopardy":
|
elif d_base_type == "jeopardy":
|
||||||
@@ -658,7 +670,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
ds_strategy,
|
||||||
|
dataset,
|
||||||
|
process_count=cfg.dataset_processes,
|
||||||
)
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "oasst":
|
elif d_base_type == "oasst":
|
||||||
@@ -670,7 +684,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
ds_strategy,
|
||||||
|
dataset,
|
||||||
|
process_count=cfg.dataset_processes,
|
||||||
)
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "gpteacher":
|
elif d_base_type == "gpteacher":
|
||||||
@@ -682,7 +698,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
ds_strategy,
|
||||||
|
dataset,
|
||||||
|
process_count=cfg.dataset_processes,
|
||||||
)
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "reflection":
|
elif d_base_type == "reflection":
|
||||||
@@ -694,7 +712,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
ds_strategy,
|
||||||
|
dataset,
|
||||||
|
process_count=cfg.dataset_processes,
|
||||||
)
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -111,27 +111,39 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
|||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
if cfg.group_by_length:
|
if cfg.group_by_length:
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
add_length, num_proc=cfg.dataset_processes
|
add_length,
|
||||||
|
num_proc=cfg.dataset_processes,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
add_position_ids, num_proc=cfg.dataset_processes
|
add_position_ids,
|
||||||
|
num_proc=cfg.dataset_processes,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
)
|
)
|
||||||
if cfg.eval_sample_packing is not False:
|
if cfg.eval_sample_packing is not False:
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
add_position_ids, num_proc=cfg.dataset_processes
|
add_position_ids,
|
||||||
|
num_proc=cfg.dataset_processes,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.group_by_length or cfg.sample_packing:
|
if cfg.group_by_length or cfg.sample_packing:
|
||||||
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
||||||
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
||||||
|
|
||||||
train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes)
|
train_dataset = train_dataset.filter(
|
||||||
|
drop_long,
|
||||||
|
num_proc=cfg.dataset_processes,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
|
)
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.filter(
|
eval_dataset = eval_dataset.filter(
|
||||||
drop_long, num_proc=cfg.dataset_processes
|
drop_long,
|
||||||
|
num_proc=cfg.dataset_processes,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Phi doesn't want the attention_mask feature when training
|
# Phi doesn't want the attention_mask feature when training
|
||||||
|
|||||||
Reference in New Issue
Block a user