Improve Dataset Processing Multiprocessing, Sharding, and Qwen Tokenizer Bug Fix. (#2918)
* Added a feature to save prepared dataset in specified shards, removed limiter on multiprocessing during tokenization, and a bug fix of qwen tokenizer * removed limiters and fixed config variable name * black lint * chore: lint * feat: update handling of dataset_processes --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -2,7 +2,6 @@
|
|||||||
chat dataset module
|
chat dataset module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Callable, Optional, Union
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
@@ -41,14 +40,10 @@ class TokenizedChatDataset(Dataset):
|
|||||||
)
|
)
|
||||||
return ex.tokenized(model_transform)
|
return ex.tokenized(model_transform)
|
||||||
|
|
||||||
process_or_cpu_count: int = (
|
|
||||||
process_count or os.cpu_count() # type: ignore[assignment]
|
|
||||||
)
|
|
||||||
num_proc = min(32, process_or_cpu_count)
|
|
||||||
features = data.features.keys()
|
features = data.features.keys()
|
||||||
tokenized_data = data.map(
|
tokenized_data = data.map(
|
||||||
map_fn,
|
map_fn,
|
||||||
num_proc=num_proc,
|
num_proc=process_count,
|
||||||
keep_in_memory=keep_in_memory,
|
keep_in_memory=keep_in_memory,
|
||||||
remove_columns=features,
|
remove_columns=features,
|
||||||
desc="Tokenizing Chats",
|
desc="Tokenizing Chats",
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
"""Module containing Dataset functionality"""
|
"""Module containing Dataset functionality"""
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
|
|
||||||
@@ -46,7 +44,6 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
|
|
||||||
def process(self, dataset):
|
def process(self, dataset):
|
||||||
features = dataset.features.keys()
|
features = dataset.features.keys()
|
||||||
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
|
||||||
|
|
||||||
map_kwargs = {}
|
map_kwargs = {}
|
||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
@@ -59,13 +56,13 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
):
|
):
|
||||||
dataset = dataset.filter(
|
dataset = dataset.filter(
|
||||||
self.prompt_tokenizer.filter_rows,
|
self.prompt_tokenizer.filter_rows,
|
||||||
num_proc=num_proc,
|
num_proc=self.process_count,
|
||||||
desc="Strategy Filtering Rows",
|
desc="Strategy Filtering Rows",
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataset.map(
|
return dataset.map(
|
||||||
self.prompt_tokenizer.tokenize_prompt,
|
self.prompt_tokenizer.tokenize_prompt,
|
||||||
num_proc=num_proc,
|
num_proc=self.process_count,
|
||||||
remove_columns=features,
|
remove_columns=features,
|
||||||
keep_in_memory=self.keep_in_memory,
|
keep_in_memory=self.keep_in_memory,
|
||||||
desc="Tokenizing Prompts",
|
desc="Tokenizing Prompts",
|
||||||
|
|||||||
@@ -188,7 +188,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
|||||||
tokenizer.padding_side = "left"
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
# Qwen base only has single token, so we need to set the special tokens
|
# Qwen base only has single token, so we need to set the special tokens
|
||||||
if cfg.is_qwen_derived_model:
|
# the following check is for Qwen1 base models
|
||||||
|
if cfg.is_qwen_derived_model and hasattr(tokenizer, "eod_id"):
|
||||||
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
|
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
|
||||||
for attr_name in token_ids:
|
for attr_name in token_ids:
|
||||||
if getattr(tokenizer, attr_name) is None:
|
if getattr(tokenizer, attr_name) is None:
|
||||||
|
|||||||
@@ -148,8 +148,6 @@ def normalize_config(cfg):
|
|||||||
f"Invalid value for eval_steps ({eval_steps}) from evals_per_epoch and/or num_epochs. Skipping evaluations."
|
f"Invalid value for eval_steps ({eval_steps}) from evals_per_epoch and/or num_epochs. Skipping evaluations."
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
|
|
||||||
|
|
||||||
if not cfg.base_model_config:
|
if not cfg.base_model_config:
|
||||||
cfg.base_model_config = cfg.base_model
|
cfg.base_model_config = cfg.base_model
|
||||||
|
|
||||||
|
|||||||
@@ -410,9 +410,8 @@ def save_preprocessed_dataset(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
|
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
|
||||||
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
|
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
|
||||||
|
num_workers = cfg.dataset_processes
|
||||||
if isinstance(dataset, IterableDataset):
|
if isinstance(dataset, IterableDataset):
|
||||||
num_workers = cfg.dataset_processes
|
|
||||||
|
|
||||||
ds_from_iter = Dataset.from_generator(
|
ds_from_iter = Dataset.from_generator(
|
||||||
functools.partial(_generate_from_iterable_dataset, dataset),
|
functools.partial(_generate_from_iterable_dataset, dataset),
|
||||||
features=dataset.features,
|
features=dataset.features,
|
||||||
@@ -423,10 +422,20 @@ def save_preprocessed_dataset(
|
|||||||
"num_workers": [num_workers] * num_workers,
|
"num_workers": [num_workers] * num_workers,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ds_from_iter.save_to_disk(str(prepared_ds_path))
|
ds_from_iter.save_to_disk(
|
||||||
|
str(prepared_ds_path),
|
||||||
|
num_proc=num_workers,
|
||||||
|
max_shard_size=None,
|
||||||
|
num_shards=cfg.num_dataset_shards_to_save,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
os.makedirs(prepared_ds_path, exist_ok=True)
|
os.makedirs(prepared_ds_path, exist_ok=True)
|
||||||
dataset.save_to_disk(str(prepared_ds_path))
|
dataset.save_to_disk(
|
||||||
|
str(prepared_ds_path),
|
||||||
|
num_proc=num_workers,
|
||||||
|
max_shard_size=None,
|
||||||
|
num_shards=cfg.num_dataset_shards_to_save,
|
||||||
|
)
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Pushing merged prepared dataset to Huggingface hub at "
|
"Pushing merged prepared dataset to Huggingface hub at "
|
||||||
|
|||||||
@@ -193,6 +193,12 @@ class AxolotlInputConfig(
|
|||||||
json_schema_extra={"description": "Index of shard to use for whole dataset"},
|
json_schema_extra={"description": "Index of shard to use for whole dataset"},
|
||||||
)
|
)
|
||||||
skip_prepare_dataset: bool | None = False
|
skip_prepare_dataset: bool | None = False
|
||||||
|
num_dataset_shards_to_save: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of shards to save the prepared dataset"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
pretraining_dataset: (
|
pretraining_dataset: (
|
||||||
Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None
|
Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None
|
||||||
@@ -203,11 +209,12 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
dataset_processes: int | None = Field(
|
dataset_processes: int | None = Field(
|
||||||
default=min(
|
default=None,
|
||||||
int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count()
|
|
||||||
), # type: ignore[type-var]
|
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set."
|
"description": (
|
||||||
|
"The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
|
||||||
|
"For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
dataset_exact_deduplication: bool | None = Field(
|
dataset_exact_deduplication: bool | None = Field(
|
||||||
@@ -1199,3 +1206,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
data["dataloader_prefetch_factor"] = 256
|
data["dataloader_prefetch_factor"] = 256
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def default_dataset_processes(cls, data):
|
||||||
|
if data.get("dataset_processes") is None:
|
||||||
|
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
|
||||||
|
data["dataset_processes"] = int(axolotl_dataset_processes)
|
||||||
|
elif runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
|
||||||
|
data["dataset_processes"] = int(runpod_cpu_count)
|
||||||
|
else:
|
||||||
|
data["dataset_processes"] = os.cpu_count()
|
||||||
|
|
||||||
|
return data
|
||||||
|
|||||||
Reference in New Issue
Block a user