feat: add support for dataset_num_proc (#3129) [skip ci]

* feat: add support for dataset_num_proc

* chore

* required changes

* requested changes

* required changes

* required changes

* required changes

* elif get_default_process_count()

* add:del data

* Update cicd/Dockerfile.jinja

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update cicd/single_gpu.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
This commit is contained in:
VED
2025-10-13 15:48:12 +05:30
committed by GitHub
parent 143dea4753
commit cd856b45b1
18 changed files with 57 additions and 34 deletions

View File

@@ -9,7 +9,7 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_PROCESSES="8"
ENV AXOLOTL_DATASET_NUM_PROC="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm

View File

@@ -65,7 +65,7 @@ def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
sp_env = os.environ.copy()
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
# Propagate errors from subprocess.
try:

View File

@@ -13,7 +13,7 @@ datasets:
val_set_size: 0
output_dir: temp_debug/axolotl_outputs/model
dataset_prepared_path: temp_debug/axolotl_outputs/data
dataset_processes: 1
dataset_num_proc: 1
sequence_len: 4096
sample_packing: false

View File

@@ -29,7 +29,7 @@ While debugging it's helpful to simplify your test scenario as much as possible.
1. **Make sure you are using the latest version of axolotl**: This project changes often and bugs get fixed fast. Check your git branch and make sure you have pulled the latest changes from `main`.
1. **Eliminate concurrency**: Restrict the number of processes to 1 for both training and data preprocessing:
- Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.
- Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`.
- Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`.
2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config):
```yaml
@@ -101,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
"-m", "axolotl.cli.train", "dev_chat_template.yml",
// The flags below simplify debugging by overriding the axolotl config
// with the debugging tips above. Modify as needed.
"--dataset_processes=1", // limits data preprocessing to one process
"--dataset_num_proc=1", // limits data preprocessing to one process
"--max_steps=1", // limits training to just one step
"--batch_size=1", // minimizes batch size
"--micro_batch_size=1", // minimizes batch size

View File

@@ -491,6 +491,7 @@ class TrainerBuilderBase(abc.ABC):
"dion_momentum",
"dion_rank_fraction",
"dion_rank_multiple_of",
"dataset_num_proc",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
@@ -514,9 +515,6 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
# max_length is not used in CausalTrainer
if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len

View File

@@ -113,7 +113,7 @@ def _map_dataset(
dataset = dataset.map(
ds_transform_fn,
num_proc=cfg.dataset_processes,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Mapping RL Dataset",
**map_kwargs,
@@ -234,7 +234,7 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)

View File

@@ -409,7 +409,7 @@ def save_preprocessed_dataset(
) -> None:
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
num_workers = cfg.dataset_processes or get_default_process_count()
num_workers = cfg.dataset_num_proc or get_default_process_count()
if isinstance(dataset, IterableDataset):
ds_from_iter = Dataset.from_generator(
functools.partial(_generate_from_iterable_dataset, dataset),

View File

@@ -223,7 +223,7 @@ def handle_long_seq_in_dataset(
filter_map_kwargs = {}
if not isinstance(dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {}

View File

@@ -80,7 +80,7 @@ def get_dataset_wrapper(
"""
# Common parameters for dataset wrapping
dataset_kwargs: dict[str, Any] = {
"process_count": cfg.dataset_processes,
"process_count": cfg.dataset_num_proc,
"keep_in_memory": cfg.dataset_keep_in_memory is True,
}

View File

@@ -4,6 +4,8 @@ import os
def get_default_process_count():
if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"):
return int(axolotl_dataset_num_proc)
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
return int(axolotl_dataset_processes)
if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):

View File

@@ -234,6 +234,7 @@ class AxolotlInputConfig(
)
dataset_processes: int | None = Field(
default=None,
deprecated="Use `dataset_num_proc` instead. This parameter will be removed in a future version.",
json_schema_extra={
"description": (
"The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
@@ -241,6 +242,16 @@ class AxolotlInputConfig(
)
},
)
dataset_num_proc: int | None = Field(
default=None,
json_schema_extra={
"description": (
"The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
"For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
)
},
)
dataset_exact_deduplication: bool | None = Field(
default=None,
json_schema_extra={
@@ -1314,10 +1325,22 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@model_validator(mode="before")
@classmethod
def default_dataset_processes(cls, data):
if data.get("dataset_processes") is None:
data["dataset_processes"] = get_default_process_count()
def default_dataset_num_proc(cls, data):
if data.get("dataset_processes") is not None:
if data.get("dataset_num_proc") is None:
data["dataset_num_proc"] = data["dataset_processes"]
LOG.warning(
"dataset_processes is deprecated and will be removed in a future version. "
"Please use dataset_num_proc instead."
)
else:
LOG.warning(
"Both dataset_processes and dataset_num_proc are set. "
"Using dataset_num_proc and ignoring dataset_processes."
)
del data["dataset_processes"]
elif data.get("dataset_num_proc") is None:
data["dataset_num_proc"] = get_default_process_count()
return data
@model_validator(mode="before")

View File

@@ -278,7 +278,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
prior_len = None
filter_map_kwargs = {}
if not isinstance(train_dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {}
@@ -318,7 +318,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length,
num_proc=cfg.dataset_processes,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Group By Length",
)
@@ -335,7 +335,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
)
train_dataset = train_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
@@ -344,7 +344,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset:
eval_dataset = eval_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
@@ -469,7 +469,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
bin_size=cfg.sample_packing_bin_size,
sequential=cfg.sample_packing_sequentially,
drop_last=True,
num_processes=cfg.dataset_processes,
num_processes=cfg.dataset_num_proc,
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
)

View File

@@ -440,7 +440,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
]
else:
raise ValueError(f"Unhandled cfg_string: {cfg_string}")
cfg["dataset_processes"] = 4
cfg["dataset_num_proc"] = 4
if cfg_string == "grpo_cfg":
rewards_dir = tmp_path / "rewards_test"

View File

@@ -69,7 +69,7 @@ class TestActivationCheckpointing:
"save_safetensors": True,
"gradient_checkpointing": gradient_checkpointing,
"save_first_step": False,
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)

View File

@@ -29,7 +29,7 @@ class TestPretrainLlama:
"sequence_len": 1024,
"sample_packing": sample_packing,
"pretrain_multipack_attn": pretrain_multipack_attn,
"dataset_processes": 1,
"dataset_num_proc": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},

View File

@@ -141,7 +141,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)
@@ -180,7 +180,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)
@@ -219,7 +219,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)
@@ -252,7 +252,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)
@@ -285,7 +285,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)
@@ -370,7 +370,7 @@ class TestDatasetPreparation:
"rl": "dpo",
"chat_template": "llama3",
"datasets": [ALPACA_MESSAGES_CONFIG_REVISION],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)
@@ -471,7 +471,7 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)

View File

@@ -210,7 +210,7 @@ class TestDeduplicateRLDataset:
ALPACA_MESSAGES_CONFIG_REVISION,
ALPACA_MESSAGES_CONFIG_REVISION,
],
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)
yield fixture

View File

@@ -55,7 +55,7 @@ class TestPacking(unittest.TestCase):
"type": "alpaca",
},
],
"dataset_processes": 4,
"dataset_num_proc": 4,
"num_epochs": 1,
"max_steps": 20,
"save_steps": 10,