diff --git a/README.md b/README.md index 2ea99b8f1..89cf1ef36 100644 --- a/README.md +++ b/README.md @@ -618,6 +618,9 @@ push_dataset_to_hub: # repo path # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` # if not set. dataset_processes: # defaults to os.cpu_count() if not set +# Keep the dataset in memory while preprocessing. +# Only needed if the cached dataset is taking up too much storage. +dataset_keep_in_memory: # push checkpoints to hub hub_model_id: # repo path to push finetuned model # how to push checkpoints to hub diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index c9181bdab..837b0d674 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -24,6 +24,8 @@ class TokenizedPromptDataset(Dataset): Args: prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data. dataset (dataset.Dataset): Dataset with text files. + process_count (int): Number of processes to use for tokenizing. + keep_in_memory (bool): Whether to keep the tokenized dataset in memory. 
""" def __init__( # pylint: disable=super-init-not-called @@ -31,10 +33,12 @@ class TokenizedPromptDataset(Dataset): prompt_tokenizer: PromptTokenizingStrategy, dataset: IterableDataset, process_count: Optional[int] = None, + keep_in_memory: Optional[bool] = False, **kwargs, ): self.prompt_tokenizer = prompt_tokenizer self.process_count = process_count + self.keep_in_memory = keep_in_memory super().__init__( self.process(dataset).data, **kwargs, @@ -42,11 +46,8 @@ class TokenizedPromptDataset(Dataset): def process(self, dataset): features = dataset.features.keys() - num_proc = ( - min(64, self.process_count) - if self.process_count - else min(64, os.cpu_count()) - ) + num_proc = min(64, self.process_count if self.process_count else os.cpu_count()) + map_kwargs = {} if self.prompt_tokenizer.supports_batched: map_kwargs["batched"] = True @@ -55,7 +56,7 @@ class TokenizedPromptDataset(Dataset): self.prompt_tokenizer.tokenize_prompt, num_proc=num_proc, remove_columns=features, - keep_in_memory=True, + keep_in_memory=self.keep_in_memory, **map_kwargs, ) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 15ae8d5a5..1eff82694 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -588,6 +588,11 @@ def get_dataset_wrapper( dataset_wrapper = None dataset_prompter = None + ds_kwargs = { + "process_count": cfg.dataset_processes, + "keep_in_memory": cfg.dataset_keep_in_memory is True, + } + if ( "input_ids" in dataset.features and "attention_mask" in dataset.features @@ -604,14 +609,14 @@ def get_dataset_wrapper( dataset_wrapper = TokenizedPromptDataset( ds_strategy, dataset, - process_count=cfg.dataset_processes, + **ds_kwargs, ) elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset): dataset_prompter = UnsupportedPrompter() dataset_wrapper = TokenizedPromptDataset( ds_strategy, dataset, - process_count=cfg.dataset_processes, + **ds_kwargs, ) elif d_base_type == "alpaca": dataset_prompter = 
AlpacaPrompter(d_prompt_style) @@ -624,7 +629,7 @@ def get_dataset_wrapper( ds_wrapper = TokenizedPromptDataset( ds_strategy, dataset, - process_count=cfg.dataset_processes, + **ds_kwargs, ) dataset_wrapper = ds_wrapper elif d_base_type == "explainchoice": @@ -638,7 +643,7 @@ def get_dataset_wrapper( ds_wrapper = TokenizedPromptDataset( ds_strategy, dataset, - process_count=cfg.dataset_processes, + **ds_kwargs, ) dataset_wrapper = ds_wrapper elif d_base_type == "concisechoice": @@ -652,7 +657,7 @@ def get_dataset_wrapper( ds_wrapper = TokenizedPromptDataset( ds_strategy, dataset, - process_count=cfg.dataset_processes, + **ds_kwargs, ) dataset_wrapper = ds_wrapper elif d_base_type == "summarizetldr": @@ -666,7 +671,7 @@ def get_dataset_wrapper( ds_wrapper = TokenizedPromptDataset( ds_strategy, dataset, - process_count=cfg.dataset_processes, + **ds_kwargs, ) dataset_wrapper = ds_wrapper elif d_base_type == "jeopardy": @@ -680,7 +685,7 @@ def get_dataset_wrapper( ds_wrapper = TokenizedPromptDataset( ds_strategy, dataset, - process_count=cfg.dataset_processes, + **ds_kwargs, ) dataset_wrapper = ds_wrapper elif d_base_type == "oasst": @@ -694,7 +699,7 @@ def get_dataset_wrapper( ds_wrapper = TokenizedPromptDataset( ds_strategy, dataset, - process_count=cfg.dataset_processes, + **ds_kwargs, ) dataset_wrapper = ds_wrapper elif d_base_type == "gpteacher": @@ -708,7 +713,7 @@ def get_dataset_wrapper( ds_wrapper = TokenizedPromptDataset( ds_strategy, dataset, - process_count=cfg.dataset_processes, + **ds_kwargs, ) dataset_wrapper = ds_wrapper elif d_base_type == "reflection": @@ -722,7 +727,7 @@ def get_dataset_wrapper( ds_wrapper = TokenizedPromptDataset( ds_strategy, dataset, - process_count=cfg.dataset_processes, + **ds_kwargs, ) dataset_wrapper = ds_wrapper else: