Improve Dataset Processing Multiprocessing, Sharding, and Qwen Tokenizer Bug Fix. (#2918)
* Added a feature to save prepared dataset in specified shards, removed limiter on multiprocessing during tokenization, and a bug fix of qwen tokenizer * removed limiters and fixed config variable name * black lint * chore: lint * feat: update handling of dataset_processes --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
chat dataset module
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from datasets import Dataset
|
||||
@@ -41,14 +40,10 @@ class TokenizedChatDataset(Dataset):
|
||||
)
|
||||
return ex.tokenized(model_transform)
|
||||
|
||||
process_or_cpu_count: int = (
|
||||
process_count or os.cpu_count() # type: ignore[assignment]
|
||||
)
|
||||
num_proc = min(32, process_or_cpu_count)
|
||||
features = data.features.keys()
|
||||
tokenized_data = data.map(
|
||||
map_fn,
|
||||
num_proc=num_proc,
|
||||
num_proc=process_count,
|
||||
keep_in_memory=keep_in_memory,
|
||||
remove_columns=features,
|
||||
desc="Tokenizing Chats",
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Module containing Dataset functionality"""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
|
||||
@@ -46,7 +44,6 @@ class TokenizedPromptDataset(Dataset):
|
||||
|
||||
def process(self, dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
||||
|
||||
map_kwargs = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
@@ -59,13 +56,13 @@ class TokenizedPromptDataset(Dataset):
|
||||
):
|
||||
dataset = dataset.filter(
|
||||
self.prompt_tokenizer.filter_rows,
|
||||
num_proc=num_proc,
|
||||
num_proc=self.process_count,
|
||||
desc="Strategy Filtering Rows",
|
||||
)
|
||||
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
num_proc=self.process_count,
|
||||
remove_columns=features,
|
||||
keep_in_memory=self.keep_in_memory,
|
||||
desc="Tokenizing Prompts",
|
||||
|
||||
@@ -188,7 +188,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
# Qwen base only has single token, so we need to set the special tokens
|
||||
if cfg.is_qwen_derived_model:
|
||||
# the following check is for Qwen1 base models
|
||||
if cfg.is_qwen_derived_model and hasattr(tokenizer, "eod_id"):
|
||||
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
|
||||
for attr_name in token_ids:
|
||||
if getattr(tokenizer, attr_name) is None:
|
||||
|
||||
@@ -148,8 +148,6 @@ def normalize_config(cfg):
|
||||
f"Invalid value for eval_steps ({eval_steps}) from evals_per_epoch and/or num_epochs. Skipping evaluations."
|
||||
)
|
||||
|
||||
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
|
||||
|
||||
if not cfg.base_model_config:
|
||||
cfg.base_model_config = cfg.base_model
|
||||
|
||||
|
||||
@@ -410,9 +410,8 @@ def save_preprocessed_dataset(
|
||||
) -> None:
|
||||
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
|
||||
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
|
||||
num_workers = cfg.dataset_processes
|
||||
if isinstance(dataset, IterableDataset):
|
||||
num_workers = cfg.dataset_processes
|
||||
|
||||
ds_from_iter = Dataset.from_generator(
|
||||
functools.partial(_generate_from_iterable_dataset, dataset),
|
||||
features=dataset.features,
|
||||
@@ -423,10 +422,20 @@ def save_preprocessed_dataset(
|
||||
"num_workers": [num_workers] * num_workers,
|
||||
},
|
||||
)
|
||||
ds_from_iter.save_to_disk(str(prepared_ds_path))
|
||||
ds_from_iter.save_to_disk(
|
||||
str(prepared_ds_path),
|
||||
num_proc=num_workers,
|
||||
max_shard_size=None,
|
||||
num_shards=cfg.num_dataset_shards_to_save,
|
||||
)
|
||||
else:
|
||||
os.makedirs(prepared_ds_path, exist_ok=True)
|
||||
dataset.save_to_disk(str(prepared_ds_path))
|
||||
dataset.save_to_disk(
|
||||
str(prepared_ds_path),
|
||||
num_proc=num_workers,
|
||||
max_shard_size=None,
|
||||
num_shards=cfg.num_dataset_shards_to_save,
|
||||
)
|
||||
if cfg.push_dataset_to_hub:
|
||||
LOG.info(
|
||||
"Pushing merged prepared dataset to Huggingface hub at "
|
||||
|
||||
@@ -193,6 +193,12 @@ class AxolotlInputConfig(
|
||||
json_schema_extra={"description": "Index of shard to use for whole dataset"},
|
||||
)
|
||||
skip_prepare_dataset: bool | None = False
|
||||
num_dataset_shards_to_save: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of shards to save the prepared dataset"
|
||||
},
|
||||
)
|
||||
|
||||
pretraining_dataset: (
|
||||
Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None
|
||||
@@ -203,11 +209,12 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
dataset_processes: int | None = Field(
|
||||
default=min(
|
||||
int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count()
|
||||
), # type: ignore[type-var]
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set."
|
||||
"description": (
|
||||
"The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
|
||||
"For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
|
||||
)
|
||||
},
|
||||
)
|
||||
dataset_exact_deduplication: bool | None = Field(
|
||||
@@ -1199,3 +1206,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
data["dataloader_prefetch_factor"] = 256
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def default_dataset_processes(cls, data):
|
||||
if data.get("dataset_processes") is None:
|
||||
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
|
||||
data["dataset_processes"] = int(axolotl_dataset_processes)
|
||||
elif runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
|
||||
data["dataset_processes"] = int(runpod_cpu_count)
|
||||
else:
|
||||
data["dataset_processes"] = os.cpu_count()
|
||||
|
||||
return data
|
||||
|
||||
Reference in New Issue
Block a user