From 9f2bb188a4e21b72802807164f451a311e4d581b Mon Sep 17 00:00:00 2001 From: Varun Gumma <45076943+VarunGumma@users.noreply.github.com> Date: Thu, 17 Jul 2025 19:17:58 +0530 Subject: [PATCH] Improve Dataset Processing Multiprocessing, Sharding, and Qwen Tokenizer Bug Fix. (#2918) * Added a feature to save prepared dataset in specified shards, removed limiter on multiprocessing during tokenization, and a bug fix of qwen tokenizer * removed limiters and fixed config variable name * black lint * chore: lint * feat: update handling of dataset_processes --------- Co-authored-by: NanoCode012 --- src/axolotl/core/datasets/chat.py | 7 +------ src/axolotl/datasets.py | 7 ++----- src/axolotl/loaders/tokenizer.py | 3 ++- src/axolotl/utils/config/__init__.py | 2 -- src/axolotl/utils/data/shared.py | 17 +++++++++++++---- src/axolotl/utils/schemas/config.py | 28 ++++++++++++++++++++++++---- 6 files changed, 42 insertions(+), 22 deletions(-) diff --git a/src/axolotl/core/datasets/chat.py b/src/axolotl/core/datasets/chat.py index 724f12866..a4dc300d9 100644 --- a/src/axolotl/core/datasets/chat.py +++ b/src/axolotl/core/datasets/chat.py @@ -2,7 +2,6 @@ chat dataset module """ -import os from typing import Callable, Optional, Union from datasets import Dataset @@ -41,14 +40,10 @@ class TokenizedChatDataset(Dataset): ) return ex.tokenized(model_transform) - process_or_cpu_count: int = ( - process_count or os.cpu_count() # type: ignore[assignment] - ) - num_proc = min(32, process_or_cpu_count) features = data.features.keys() tokenized_data = data.map( map_fn, - num_proc=num_proc, + num_proc=process_count, keep_in_memory=keep_in_memory, remove_columns=features, desc="Tokenizing Chats", diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 7c112c59e..c9d006ac8 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -1,7 +1,5 @@ """Module containing Dataset functionality""" -import os - import torch from datasets import Dataset, IterableDataset @@ -46,7 +44,6 @@ class TokenizedPromptDataset(Dataset): def process(self, dataset): features = dataset.features.keys() - num_proc = min(64, self.process_count if self.process_count else os.cpu_count()) map_kwargs = {} if self.prompt_tokenizer.supports_batched: @@ -59,13 +56,13 @@ class TokenizedPromptDataset(Dataset): ): dataset = dataset.filter( self.prompt_tokenizer.filter_rows, - num_proc=num_proc, + num_proc=self.process_count, desc="Strategy Filtering Rows", ) return dataset.map( self.prompt_tokenizer.tokenize_prompt, - num_proc=num_proc, + num_proc=self.process_count, remove_columns=features, keep_in_memory=self.keep_in_memory, desc="Tokenizing Prompts", diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py index 9fdb7d5cc..2f0ccbcbb 100644 --- a/src/axolotl/loaders/tokenizer.py +++ b/src/axolotl/loaders/tokenizer.py @@ -188,7 +188,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer: tokenizer.padding_side = "left" # Qwen base only has single token, so we need to set the special tokens - if cfg.is_qwen_derived_model: + # the following check is for Qwen1 base models + if cfg.is_qwen_derived_model and hasattr(tokenizer, "eod_id"): token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"] for attr_name in token_ids: if getattr(tokenizer, attr_name) is None: diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index aaa203e82..c9613c39b 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -148,8 +148,6 @@ def normalize_config(cfg): f"Invalid value for eval_steps ({eval_steps}) from evals_per_epoch and/or num_epochs. Skipping evaluations." ) - cfg.dataset_processes = cfg.dataset_processes or os.cpu_count() - if not cfg.base_model_config: cfg.base_model_config = cfg.base_model diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index c30459d5b..3a3657240 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -410,9 +410,8 @@ def save_preprocessed_dataset( ) -> None: """Save preprocessed dataset to disk and optionally push to the HF Hub.""" prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash) + num_workers = cfg.dataset_processes if isinstance(dataset, IterableDataset): - num_workers = cfg.dataset_processes - ds_from_iter = Dataset.from_generator( functools.partial(_generate_from_iterable_dataset, dataset), features=dataset.features, @@ -423,10 +422,20 @@ def save_preprocessed_dataset( "num_workers": [num_workers] * num_workers, }, ) - ds_from_iter.save_to_disk(str(prepared_ds_path)) + ds_from_iter.save_to_disk( + str(prepared_ds_path), + num_proc=num_workers, + max_shard_size=None, + num_shards=cfg.num_dataset_shards_to_save, + ) else: os.makedirs(prepared_ds_path, exist_ok=True) - dataset.save_to_disk(str(prepared_ds_path)) + dataset.save_to_disk( + str(prepared_ds_path), + num_proc=num_workers, + max_shard_size=None, + num_shards=cfg.num_dataset_shards_to_save, + ) if cfg.push_dataset_to_hub: LOG.info( "Pushing merged prepared dataset to Huggingface hub at " diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 06212a27f..d3fb0b14c 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -193,6 +193,12 @@ class AxolotlInputConfig( json_schema_extra={"description": "Index of shard to use for whole dataset"}, ) skip_prepare_dataset: bool | None = False + num_dataset_shards_to_save: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of shards to save the prepared dataset" + }, + ) pretraining_dataset: ( Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None @@ -203,11 +209,12 @@ class AxolotlInputConfig( }, ) dataset_processes: int | None = Field( - default=min( - int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count() - ), # type: ignore[type-var] + default=None, json_schema_extra={ - "description": "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set." + "description": ( + "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n" + "For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT." + ) }, ) dataset_exact_deduplication: bool | None = Field( @@ -1199,3 +1206,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): data["dataloader_prefetch_factor"] = 256 return data + + @model_validator(mode="before") + @classmethod + def default_dataset_processes(cls, data): + if data.get("dataset_processes") is None: + if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"): + data["dataset_processes"] = int(axolotl_dataset_processes) + elif runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"): + data["dataset_processes"] = int(runpod_cpu_count) + else: + data["dataset_processes"] = os.cpu_count() + + return data