Improve Dataset Processing Multiprocessing, Sharding, and Qwen Tokenizer Bug Fix. (#2918)

* Added a feature to save the prepared dataset in a specified number of shards, removed the limiter on multiprocessing during tokenization, and fixed a bug in the Qwen tokenizer

* removed limiters and fixed config variable name

* black lint

* chore: lint

* feat: update handling of dataset_processes

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Varun Gumma
2025-07-17 19:17:58 +05:30
committed by GitHub
parent 9dde9e1b71
commit 9f2bb188a4
6 changed files with 42 additions and 22 deletions

View File

@@ -2,7 +2,6 @@
chat dataset module
"""
import os
from typing import Callable, Optional, Union
from datasets import Dataset
@@ -41,14 +40,10 @@ class TokenizedChatDataset(Dataset):
)
return ex.tokenized(model_transform)
process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment]
)
num_proc = min(32, process_or_cpu_count)
features = data.features.keys()
tokenized_data = data.map(
map_fn,
num_proc=num_proc,
num_proc=process_count,
keep_in_memory=keep_in_memory,
remove_columns=features,
desc="Tokenizing Chats",

View File

@@ -1,7 +1,5 @@
"""Module containing Dataset functionality"""
import os
import torch
from datasets import Dataset, IterableDataset
@@ -46,7 +44,6 @@ class TokenizedPromptDataset(Dataset):
def process(self, dataset):
features = dataset.features.keys()
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
@@ -59,13 +56,13 @@ class TokenizedPromptDataset(Dataset):
):
dataset = dataset.filter(
self.prompt_tokenizer.filter_rows,
num_proc=num_proc,
num_proc=self.process_count,
desc="Strategy Filtering Rows",
)
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
num_proc=self.process_count,
remove_columns=features,
keep_in_memory=self.keep_in_memory,
desc="Tokenizing Prompts",

View File

@@ -188,7 +188,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
tokenizer.padding_side = "left"
# Qwen base only has single token, so we need to set the special tokens
if cfg.is_qwen_derived_model:
# the following check is for Qwen1 base models
if cfg.is_qwen_derived_model and hasattr(tokenizer, "eod_id"):
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
for attr_name in token_ids:
if getattr(tokenizer, attr_name) is None:

View File

@@ -148,8 +148,6 @@ def normalize_config(cfg):
f"Invalid value for eval_steps ({eval_steps}) from evals_per_epoch and/or num_epochs. Skipping evaluations."
)
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
if not cfg.base_model_config:
cfg.base_model_config = cfg.base_model

View File

@@ -410,9 +410,8 @@ def save_preprocessed_dataset(
) -> None:
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
num_workers = cfg.dataset_processes
if isinstance(dataset, IterableDataset):
num_workers = cfg.dataset_processes
ds_from_iter = Dataset.from_generator(
functools.partial(_generate_from_iterable_dataset, dataset),
features=dataset.features,
@@ -423,10 +422,20 @@ def save_preprocessed_dataset(
"num_workers": [num_workers] * num_workers,
},
)
ds_from_iter.save_to_disk(str(prepared_ds_path))
ds_from_iter.save_to_disk(
str(prepared_ds_path),
num_proc=num_workers,
max_shard_size=None,
num_shards=cfg.num_dataset_shards_to_save,
)
else:
os.makedirs(prepared_ds_path, exist_ok=True)
dataset.save_to_disk(str(prepared_ds_path))
dataset.save_to_disk(
str(prepared_ds_path),
num_proc=num_workers,
max_shard_size=None,
num_shards=cfg.num_dataset_shards_to_save,
)
if cfg.push_dataset_to_hub:
LOG.info(
"Pushing merged prepared dataset to Huggingface hub at "

View File

@@ -193,6 +193,12 @@ class AxolotlInputConfig(
json_schema_extra={"description": "Index of shard to use for whole dataset"},
)
skip_prepare_dataset: bool | None = False
num_dataset_shards_to_save: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of shards to save the prepared dataset"
},
)
pretraining_dataset: (
Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None
@@ -203,11 +209,12 @@ class AxolotlInputConfig(
},
)
dataset_processes: int | None = Field(
default=min(
int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count()
), # type: ignore[type-var]
default=None,
json_schema_extra={
"description": "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set."
"description": (
"The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
"For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
)
},
)
dataset_exact_deduplication: bool | None = Field(
@@ -1199,3 +1206,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
data["dataloader_prefetch_factor"] = 256
return data
@model_validator(mode="before")
@classmethod
def default_dataset_processes(cls, data):
    """Populate ``dataset_processes`` when the user left it unset.

    Resolution order: the ``AXOLOTL_DATASET_PROCESSES`` environment
    variable, then ``RUNPOD_CPU_COUNT`` (Runpod VMs report vCPUs there),
    and finally ``os.cpu_count()``.
    """
    if data.get("dataset_processes") is not None:
        # Explicit user setting wins; nothing to derive.
        return data

    env_override = os.environ.get("AXOLOTL_DATASET_PROCESSES")
    runpod_cpus = os.environ.get("RUNPOD_CPU_COUNT")
    if env_override:
        data["dataset_processes"] = int(env_override)
    elif runpod_cpus:
        data["dataset_processes"] = int(runpod_cpus)
    else:
        data["dataset_processes"] = os.cpu_count()
    return data