From 7570446596dfab69763c06ac7885b9c42c08aa24 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 17 Jan 2024 11:02:41 -0500
Subject: [PATCH] Preprocess dataset size fix (#1131)

* overwrite cache on preprocess step

* don't cache the TokenizedPromptDataset at all

* load_from_cache_file no longer needed
---
 src/axolotl/cli/preprocess.py |  1 +
 src/axolotl/datasets.py       |  6 +++++-
 src/axolotl/utils/data.py     | 40 ++++++++++++++++++++++++++---------
 src/axolotl/utils/trainer.py  | 22 ++++++++++++++-----
 4 files changed, 53 insertions(+), 16 deletions(-)

diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py
index 2c2709519..76b655afb 100644
--- a/src/axolotl/cli/preprocess.py
+++ b/src/axolotl/cli/preprocess.py
@@ -25,6 +25,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
     parsed_cfg = load_cfg(config, **kwargs)
+    parsed_cfg.is_preprocess = True
     check_accelerate_default_config()
     check_user_token()
     parser = transformers.HfArgumentParser((PreprocessCliArgs))
diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index d5362643f..c9181bdab 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -35,7 +35,10 @@ class TokenizedPromptDataset(Dataset):
     ):
         self.prompt_tokenizer = prompt_tokenizer
         self.process_count = process_count
-        super().__init__(self.process(dataset).data, **kwargs)
+        super().__init__(
+            self.process(dataset).data,
+            **kwargs,
+        )

     def process(self, dataset):
         features = dataset.features.keys()
@@ -52,6 +55,7 @@ class TokenizedPromptDataset(Dataset):
         return dataset.map(
             self.prompt_tokenizer.tokenize_prompt,
             num_proc=num_proc,
             remove_columns=features,
+            keep_in_memory=True,
             **map_kwargs,
         )
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 8ef3a7f78..d65e19ab4 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -594,12 +594,16 @@ def get_dataset_wrapper(
         )
         dataset_prompter = UnsupportedPrompter()
         dataset_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
     elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
         dataset_prompter = UnsupportedPrompter()
         dataset_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
     elif d_base_type == "alpaca":
         dataset_prompter = AlpacaPrompter(d_prompt_style)
@@ -610,7 +614,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "explainchoice":
@@ -622,7 +628,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "concisechoice":
@@ -634,7 +642,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "summarizetldr":
@@ -646,7 +656,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "jeopardy":
@@ -658,7 +670,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "oasst":
@@ -670,7 +684,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "gpteacher":
@@ -682,7 +698,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "reflection":
@@ -694,7 +712,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     else:
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 5588e768f..3fc244605 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -111,27 +111,39 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     with zero_first(is_main_process()):
         if cfg.group_by_length:
             train_dataset = train_dataset.map(
-                add_length, num_proc=cfg.dataset_processes
+                add_length,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
             )

         if cfg.sample_packing:
             train_dataset = train_dataset.map(
-                add_position_ids, num_proc=cfg.dataset_processes
+                add_position_ids,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
             )
             if cfg.eval_sample_packing is not False:
                 if eval_dataset:
                     eval_dataset = eval_dataset.map(
-                        add_position_ids, num_proc=cfg.dataset_processes
+                        add_position_ids,
+                        num_proc=cfg.dataset_processes,
+                        load_from_cache_file=not cfg.is_preprocess,
                     )

         if cfg.group_by_length or cfg.sample_packing:
             max_input_len = np.max(get_dataset_lengths(train_dataset))
             LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)

-        train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes)
+        train_dataset = train_dataset.filter(
+            drop_long,
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+        )
         if eval_dataset:
             eval_dataset = eval_dataset.filter(
-                drop_long, num_proc=cfg.dataset_processes
+                drop_long,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
             )

         # Phi doesn't want the attention_mask feature when training
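
For context, the fix leans on two standard arguments of `datasets.Dataset.map`: `keep_in_memory` (used for the `TokenizedPromptDataset` wrapper so it is never cached) and `load_from_cache_file` (set to `not cfg.is_preprocess` so the preprocess CLI always recomputes instead of trusting a stale cache). A minimal sketch of their semantics follows; the toy data and the `add_length` helper are illustrative stand-ins, not code from the repository.

```python
from datasets import Dataset

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5]]})

def add_length(sample):
    # Toy transform standing in for axolotl's real map functions.
    sample["length"] = len(sample["input_ids"])
    return sample

# keep_in_memory=True keeps the mapped result in RAM instead of
# writing it to an on-disk Arrow cache file, so no cache entry is
# created for (or reused by) this transform.
in_memory = ds.map(add_length, keep_in_memory=True)

# load_from_cache_file=False tells `map` not to reuse a previously
# cached result and to recompute the transform; this is what the
# patch enables whenever cfg.is_preprocess is set.
recomputed = ds.map(add_length, load_from_cache_file=False)

print(in_memory["length"], recomputed["length"])  # [3, 2] [3, 2]
```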