feat: add support for dataset_num_processes (#3129) [skip ci]

* feat: add support for dataset_num_processes

* chore

* required changes

* requested changes

* required changes

* required changes

* required changes

* elif get_default_process_count()

* add: del data

* Update cicd/Dockerfile.jinja

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update cicd/single_gpu.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
VED
2025-10-13 15:48:12 +05:30
committed by GitHub
parent 143dea4753
commit cd856b45b1
18 changed files with 57 additions and 34 deletions

View File

@@ -491,6 +491,7 @@ class TrainerBuilderBase(abc.ABC):
             "dion_momentum",
             "dion_rank_fraction",
             "dion_rank_multiple_of",
+            "dataset_num_proc",
         ]:
             if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
                 training_args_kwargs[arg] = getattr(self.cfg, arg)
@@ -514,9 +515,6 @@ class TrainerBuilderBase(abc.ABC):
         training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
         training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
-        if self.cfg.dataset_processes:
-            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
         # max_length is not used in CausalTrainer
         if self.cfg.reward_model or self.cfg.rl:
             training_args_kwargs["max_length"] = self.cfg.sequence_len
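With `dataset_num_proc` added to the generic pass-through list above, the dedicated `dataset_processes` branch becomes redundant. A minimal sketch of that forwarding pattern, using a stand-in `Cfg` object rather than the builder's real types:

```python
# Any cfg attribute named in the list is copied into the trainer kwargs
# when it is set; unset optionals are silently skipped.
class Cfg:
    dataset_num_proc = 8
    dion_momentum = None  # unset, so it will be skipped

cfg = Cfg()
training_args_kwargs = {}
for arg in ["dion_momentum", "dataset_num_proc"]:
    if hasattr(cfg, arg) and getattr(cfg, arg) is not None:
        training_args_kwargs[arg] = getattr(cfg, arg)

assert training_args_kwargs == {"dataset_num_proc": 8}
```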

View File

@@ -113,7 +113,7 @@ def _map_dataset(
     dataset = dataset.map(
         ds_transform_fn,
-        num_proc=cfg.dataset_processes,
+        num_proc=cfg.dataset_num_proc,
         load_from_cache_file=not cfg.is_preprocess,
         desc="Mapping RL Dataset",
         **map_kwargs,
@@ -234,7 +234,7 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
         prior_len = len(split_datasets[i])
         split_datasets[i] = split_datasets[i].filter(
             drop_long,
-            num_proc=cfg.dataset_processes,
+            num_proc=cfg.dataset_num_proc,
             load_from_cache_file=not cfg.is_preprocess,
             desc="Dropping Long Sequences",
         )
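For reference, the `num_proc` argument being renamed here is the standard multiprocessing knob on `datasets.Dataset.map`/`.filter` from the Hugging Face `datasets` package. A self-contained example (the toy data is ours, not from the commit):

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "bb", "ccc", "dddd"]})
# num_proc fans the transform out across worker processes;
# desc labels the progress bar, as in the hunks above.
ds = ds.map(lambda ex: {"n_chars": len(ex["text"])}, num_proc=2, desc="Counting chars")
ds = ds.filter(lambda ex: ex["n_chars"] > 1, num_proc=2, desc="Dropping short rows")
print(ds["n_chars"])  # [2, 3, 4]
```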

View File

@@ -409,7 +409,7 @@ def save_preprocessed_dataset(
 ) -> None:
     """Save preprocessed dataset to disk and optionally push to the HF Hub."""
     prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
-    num_workers = cfg.dataset_processes or get_default_process_count()
+    num_workers = cfg.dataset_num_proc or get_default_process_count()
     if isinstance(dataset, IterableDataset):
         ds_from_iter = Dataset.from_generator(
             functools.partial(_generate_from_iterable_dataset, dataset),

View File

@@ -223,7 +223,7 @@ def handle_long_seq_in_dataset(
     filter_map_kwargs = {}
     if not isinstance(dataset, IterableDataset):
-        filter_map_kwargs["num_proc"] = cfg.dataset_processes
+        filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
         filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
     drop_long_kwargs = {}

View File

@@ -80,7 +80,7 @@ def get_dataset_wrapper(
     """
     # Common parameters for dataset wrapping
     dataset_kwargs: dict[str, Any] = {
-        "process_count": cfg.dataset_processes,
+        "process_count": cfg.dataset_num_proc,
         "keep_in_memory": cfg.dataset_keep_in_memory is True,
    }

View File

@@ -4,6 +4,8 @@ import os
 def get_default_process_count():
+    if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"):
+        return int(axolotl_dataset_num_proc)
     if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
         return int(axolotl_dataset_processes)
     if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
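The new environment variable is checked before the legacy one. A small sketch of the expected resolution order (the body mirrors the hunk; the final `os.cpu_count()` fallback is our assumption about the elided tail of the function):

```python
import os

def default_process_count_sketch() -> int:
    # New variable wins over the deprecated one, then Runpod's vCPU count.
    if num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"):
        return int(num_proc)
    if processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
        return int(processes)
    if runpod_cpus := os.environ.get("RUNPOD_CPU_COUNT"):
        return int(runpod_cpus)
    return os.cpu_count() or 1  # assumed fallback, not shown in the hunk

os.environ["AXOLOTL_DATASET_PROCESSES"] = "4"  # legacy name
os.environ["AXOLOTL_DATASET_NUM_PROC"] = "8"   # new name, takes priority
assert default_process_count_sketch() == 8
```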

View File

@@ -234,6 +234,7 @@ class AxolotlInputConfig(
     )
     dataset_processes: int | None = Field(
         default=None,
+        deprecated="Use `dataset_num_proc` instead. This parameter will be removed in a future version.",
         json_schema_extra={
             "description": (
                 "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
@@ -241,6 +242,16 @@ class AxolotlInputConfig(
             )
         },
     )
+    dataset_num_proc: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": (
+                "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
+                "For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
+            )
+        },
+    )
     dataset_exact_deduplication: bool | None = Field(
         default=None,
         json_schema_extra={
@@ -1314,10 +1325,22 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
     @model_validator(mode="before")
     @classmethod
-    def default_dataset_processes(cls, data):
-        if data.get("dataset_processes") is None:
-            data["dataset_processes"] = get_default_process_count()
+    def default_dataset_num_proc(cls, data):
+        if data.get("dataset_processes") is not None:
+            if data.get("dataset_num_proc") is None:
+                data["dataset_num_proc"] = data["dataset_processes"]
+                LOG.warning(
+                    "dataset_processes is deprecated and will be removed in a future version. "
+                    "Please use dataset_num_proc instead."
+                )
+            else:
+                LOG.warning(
+                    "Both dataset_processes and dataset_num_proc are set. "
+                    "Using dataset_num_proc and ignoring dataset_processes."
+                )
+            del data["dataset_processes"]
+        elif data.get("dataset_num_proc") is None:
+            data["dataset_num_proc"] = get_default_process_count()
         return data

     @model_validator(mode="before")
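The rewritten validator migrates the deprecated key, resolves conflicts in favor of the new one, and otherwise falls back to the process-count default. A minimal sketch of that resolution on a plain dict, mirroring the hunk's logic outside of pydantic (`resolve_num_proc` is our name; `os.cpu_count()` stands in for `get_default_process_count()`):

```python
import logging
import os

LOG = logging.getLogger(__name__)

def resolve_num_proc(data: dict) -> dict:
    if data.get("dataset_processes") is not None:
        if data.get("dataset_num_proc") is None:
            # Legacy key only: migrate its value, warn about the deprecation.
            data["dataset_num_proc"] = data["dataset_processes"]
            LOG.warning("dataset_processes is deprecated; use dataset_num_proc.")
        else:
            # Both keys set: the new key wins.
            LOG.warning("Both set; using dataset_num_proc.")
        del data["dataset_processes"]
    elif data.get("dataset_num_proc") is None:
        # Neither key set: fall back to the default process count.
        data["dataset_num_proc"] = os.cpu_count() or 1
    return data

assert resolve_num_proc({"dataset_processes": 4}) == {"dataset_num_proc": 4}
assert resolve_num_proc({"dataset_processes": 4, "dataset_num_proc": 8}) == {"dataset_num_proc": 8}
```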

View File

@@ -278,7 +278,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     prior_len = None
     filter_map_kwargs = {}
     if not isinstance(train_dataset, IterableDataset):
-        filter_map_kwargs["num_proc"] = cfg.dataset_processes
+        filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
         filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
     drop_long_kwargs = {}
@@ -318,7 +318,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     if cfg.group_by_length:
         train_dataset = train_dataset.map(
             add_length,
-            num_proc=cfg.dataset_processes,
+            num_proc=cfg.dataset_num_proc,
             load_from_cache_file=not cfg.is_preprocess,
             desc="Group By Length",
         )
@@ -335,7 +335,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     )
     train_dataset = train_dataset.map(
         pose_fn,
-        num_proc=cfg.dataset_processes,
+        num_proc=cfg.dataset_num_proc,
         load_from_cache_file=not cfg.is_preprocess,
         desc="Add position_id column (PoSE)",
     )
@@ -344,7 +344,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     if eval_dataset:
         eval_dataset = eval_dataset.map(
             pose_fn,
-            num_proc=cfg.dataset_processes,
+            num_proc=cfg.dataset_num_proc,
             load_from_cache_file=not cfg.is_preprocess,
             desc="Add position_id column (PoSE)",
         )
@@ -469,7 +469,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
         bin_size=cfg.sample_packing_bin_size,
         sequential=cfg.sample_packing_sequentially,
         drop_last=True,
-        num_processes=cfg.dataset_processes,
+        num_processes=cfg.dataset_num_proc,
         mp_start_method=cfg.sample_packing_mp_start_method or "fork",
     )