diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja index 94c9a67e3..6a1ddb66d 100644 --- a/cicd/Dockerfile.jinja +++ b/cicd/Dockerfile.jinja @@ -9,7 +9,7 @@ ENV GITHUB_REF="{{ GITHUB_REF }}" ENV GITHUB_SHA="{{ GITHUB_SHA }}" ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}" ENV HF_HOME="{{ HF_HOME }}" -ENV AXOLOTL_DATASET_PROCESSES="8" +ENV AXOLOTL_DATASET_NUM_PROC="8" RUN apt-get update && \ apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm diff --git a/cicd/single_gpu.py b/cicd/single_gpu.py index 3bca5806f..cd73f60b8 100644 --- a/cicd/single_gpu.py +++ b/cicd/single_gpu.py @@ -65,7 +65,7 @@ def run_cmd(cmd: str, run_folder: str): import subprocess # nosec sp_env = os.environ.copy() - sp_env["AXOLOTL_DATASET_PROCESSES"] = "8" + sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8" # Propagate errors from subprocess. try: diff --git a/devtools/dev_chat_template.yml b/devtools/dev_chat_template.yml index 27dc9be1a..32d5e56a0 100644 --- a/devtools/dev_chat_template.yml +++ b/devtools/dev_chat_template.yml @@ -13,7 +13,7 @@ datasets: val_set_size: 0 output_dir: temp_debug/axolotl_outputs/model dataset_prepared_path: temp_debug/axolotl_outputs/data -dataset_processes: 1 +dataset_num_proc: 1 sequence_len: 4096 sample_packing: false diff --git a/docs/debugging.qmd b/docs/debugging.qmd index bf3c6fe7e..04b4faa64 100644 --- a/docs/debugging.qmd +++ b/docs/debugging.qmd @@ -29,7 +29,7 @@ While debugging it's helpful to simplify your test scenario as much as possible. 1. **Make sure you are using the latest version of axolotl**: This project changes often and bugs get fixed fast. Check your git branch and make sure you have pulled the latest changes from `main`. 1. **Eliminate concurrency**: Restrict the number of processes to 1 for both training and data preprocessing: - Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`. 
- - Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`. + - Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`. 2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config): ```yaml @@ -101,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler "-m", "axolotl.cli.train", "dev_chat_template.yml", // The flags below simplify debugging by overriding the axolotl config // with the debugging tips above. Modify as needed. 
- "--dataset_processes=1", // limits data preprocessing to one process + "--dataset_num_proc=1", // limits data preprocessing to one process "--max_steps=1", // limits training to just one step "--batch_size=1", // minimizes batch size "--micro_batch_size=1", // minimizes batch size diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 3ad8012f9..8c86e335e 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -491,6 +491,7 @@ class TrainerBuilderBase(abc.ABC): "dion_momentum", "dion_rank_fraction", "dion_rank_multiple_of", + "dataset_num_proc", ]: if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: training_args_kwargs[arg] = getattr(self.cfg, arg) @@ -514,9 +515,6 @@ class TrainerBuilderBase(abc.ABC): training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1 training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs - if self.cfg.dataset_processes: - training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes - # max_length is not used in CausalTrainer if self.cfg.reward_model or self.cfg.rl: training_args_kwargs["max_length"] = self.cfg.sequence_len diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index d371c9acb..f7a5ec04c 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -113,7 +113,7 @@ def _map_dataset( dataset = dataset.map( ds_transform_fn, - num_proc=cfg.dataset_processes, + num_proc=cfg.dataset_num_proc, load_from_cache_file=not cfg.is_preprocess, desc="Mapping RL Dataset", **map_kwargs, @@ -234,7 +234,7 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset: prior_len = len(split_datasets[i]) split_datasets[i] = split_datasets[i].filter( drop_long, - num_proc=cfg.dataset_processes, + num_proc=cfg.dataset_num_proc, load_from_cache_file=not cfg.is_preprocess, desc="Dropping Long Sequences", ) diff --git a/src/axolotl/utils/data/shared.py 
b/src/axolotl/utils/data/shared.py index 6b6e0e281..c9a91b829 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -409,7 +409,7 @@ def save_preprocessed_dataset( ) -> None: """Save preprocessed dataset to disk and optionally push to the HF Hub.""" prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash) - num_workers = cfg.dataset_processes or get_default_process_count() + num_workers = cfg.dataset_num_proc or get_default_process_count() if isinstance(dataset, IterableDataset): ds_from_iter = Dataset.from_generator( functools.partial(_generate_from_iterable_dataset, dataset), diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 445a65d6c..2d0ca9d0e 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -223,7 +223,7 @@ def handle_long_seq_in_dataset( filter_map_kwargs = {} if not isinstance(dataset, IterableDataset): - filter_map_kwargs["num_proc"] = cfg.dataset_processes + filter_map_kwargs["num_proc"] = cfg.dataset_num_proc filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess drop_long_kwargs = {} diff --git a/src/axolotl/utils/data/wrappers.py b/src/axolotl/utils/data/wrappers.py index cb9e2c6b4..3a10bde00 100644 --- a/src/axolotl/utils/data/wrappers.py +++ b/src/axolotl/utils/data/wrappers.py @@ -80,7 +80,7 @@ def get_dataset_wrapper( """ # Common parameters for dataset wrapping dataset_kwargs: dict[str, Any] = { - "process_count": cfg.dataset_processes, + "process_count": cfg.dataset_num_proc, "keep_in_memory": cfg.dataset_keep_in_memory is True, } diff --git a/src/axolotl/utils/datasets.py b/src/axolotl/utils/datasets.py index 93e1a2416..9b8a8e25a 100644 --- a/src/axolotl/utils/datasets.py +++ b/src/axolotl/utils/datasets.py @@ -4,6 +4,8 @@ import os def get_default_process_count(): + if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"): + return int(axolotl_dataset_num_proc) if axolotl_dataset_processes := 
os.environ.get("AXOLOTL_DATASET_PROCESSES"): return int(axolotl_dataset_processes) if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"): diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 7cf8c3b4a..4d1d0aab2 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -234,6 +234,7 @@ class AxolotlInputConfig( ) dataset_processes: int | None = Field( default=None, + deprecated="Use `dataset_num_proc` instead. This parameter will be removed in a future version.", json_schema_extra={ "description": ( "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n" @@ -241,6 +242,16 @@ class AxolotlInputConfig( ) }, ) + dataset_num_proc: int | None = Field( + default=None, + json_schema_extra={ + "description": ( + "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n" + "For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT." + ) + }, + ) + dataset_exact_deduplication: bool | None = Field( default=None, json_schema_extra={ @@ -1314,10 +1325,22 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): @model_validator(mode="before") @classmethod - def default_dataset_processes(cls, data): - if data.get("dataset_processes") is None: - data["dataset_processes"] = get_default_process_count() - + def default_dataset_num_proc(cls, data): + if data.get("dataset_processes") is not None: + if data.get("dataset_num_proc") is None: + data["dataset_num_proc"] = data["dataset_processes"] + LOG.warning( + "dataset_processes is deprecated and will be removed in a future version. " + "Please use dataset_num_proc instead." + ) + else: + LOG.warning( + "Both dataset_processes and dataset_num_proc are set. " + "Using dataset_num_proc and ignoring dataset_processes." 
+ ) + del data["dataset_processes"] + elif data.get("dataset_num_proc") is None: + data["dataset_num_proc"] = get_default_process_count() return data @model_validator(mode="before") diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index f2f8279f3..d97577d86 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -278,7 +278,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): prior_len = None filter_map_kwargs = {} if not isinstance(train_dataset, IterableDataset): - filter_map_kwargs["num_proc"] = cfg.dataset_processes + filter_map_kwargs["num_proc"] = cfg.dataset_num_proc filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess drop_long_kwargs = {} @@ -318,7 +318,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if cfg.group_by_length: train_dataset = train_dataset.map( add_length, - num_proc=cfg.dataset_processes, + num_proc=cfg.dataset_num_proc, load_from_cache_file=not cfg.is_preprocess, desc="Group By Length", ) @@ -335,7 +335,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): ) train_dataset = train_dataset.map( pose_fn, - num_proc=cfg.dataset_processes, + num_proc=cfg.dataset_num_proc, load_from_cache_file=not cfg.is_preprocess, desc="Add position_id column (PoSE)", ) @@ -344,7 +344,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): if eval_dataset: eval_dataset = eval_dataset.map( pose_fn, - num_proc=cfg.dataset_processes, + num_proc=cfg.dataset_num_proc, load_from_cache_file=not cfg.is_preprocess, desc="Add position_id column (PoSE)", ) @@ -469,7 +469,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): bin_size=cfg.sample_packing_bin_size, sequential=cfg.sample_packing_sequentially, drop_last=True, - num_processes=cfg.dataset_processes, + num_processes=cfg.dataset_num_proc, mp_start_method=cfg.sample_packing_mp_start_method or "fork", ) diff --git a/tests/core/test_builders.py 
b/tests/core/test_builders.py index 6428aa977..67481b2ad 100644 --- a/tests/core/test_builders.py +++ b/tests/core/test_builders.py @@ -440,7 +440,7 @@ def rand_reward_func(prompts, completions) -> list[float]: ] else: raise ValueError(f"Unhandled cfg_string: {cfg_string}") - cfg["dataset_processes"] = 4 + cfg["dataset_num_proc"] = 4 if cfg_string == "grpo_cfg": rewards_dir = tmp_path / "rewards_test" diff --git a/tests/e2e/patched/test_activation_checkpointing.py b/tests/e2e/patched/test_activation_checkpointing.py index ddace8ef1..e8006c162 100644 --- a/tests/e2e/patched/test_activation_checkpointing.py +++ b/tests/e2e/patched/test_activation_checkpointing.py @@ -69,7 +69,7 @@ class TestActivationCheckpointing: "save_safetensors": True, "gradient_checkpointing": gradient_checkpointing, "save_first_step": False, - "dataset_processes": 4, + "dataset_num_proc": 4, } ) diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py index a041244e7..f0daa9dd6 100644 --- a/tests/e2e/test_llama_pretrain.py +++ b/tests/e2e/test_llama_pretrain.py @@ -29,7 +29,7 @@ class TestPretrainLlama: "sequence_len": 1024, "sample_packing": sample_packing, "pretrain_multipack_attn": pretrain_multipack_attn, - "dataset_processes": 1, + "dataset_num_proc": 1, "special_tokens": { "pad_token": "<|endoftext|>", }, diff --git a/tests/test_datasets.py b/tests/test_datasets.py index ea5ee368d..bd1c8f2c2 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -141,7 +141,7 @@ class TestDatasetPreparation: "type": "alpaca", }, ], - "dataset_processes": 4, + "dataset_num_proc": 4, } ) @@ -180,7 +180,7 @@ class TestDatasetPreparation: "type": "alpaca", }, ], - "dataset_processes": 4, + "dataset_num_proc": 4, } ) @@ -219,7 +219,7 @@ class TestDatasetPreparation: "type": "alpaca", }, ], - "dataset_processes": 4, + "dataset_num_proc": 4, } ) @@ -252,7 +252,7 @@ class TestDatasetPreparation: "type": "alpaca", }, ], - "dataset_processes": 4, + "dataset_num_proc": 4, } ) @@ 
-285,7 +285,7 @@ class TestDatasetPreparation: "type": "alpaca", }, ], - "dataset_processes": 4, + "dataset_num_proc": 4, } ) @@ -370,7 +370,7 @@ class TestDatasetPreparation: "rl": "dpo", "chat_template": "llama3", "datasets": [ALPACA_MESSAGES_CONFIG_REVISION], - "dataset_processes": 4, + "dataset_num_proc": 4, } ) @@ -471,7 +471,7 @@ class TestDatasetPreparation: "type": "alpaca", }, ], - "dataset_processes": 4, + "dataset_num_proc": 4, } ) diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py index 65deb5209..a519db525 100644 --- a/tests/test_exact_deduplication.py +++ b/tests/test_exact_deduplication.py @@ -210,7 +210,7 @@ class TestDeduplicateRLDataset: ALPACA_MESSAGES_CONFIG_REVISION, ALPACA_MESSAGES_CONFIG_REVISION, ], - "dataset_processes": 4, + "dataset_num_proc": 4, } ) yield fixture diff --git a/tests/test_packed_dataset.py b/tests/test_packed_dataset.py index 64f314e2e..953d523af 100644 --- a/tests/test_packed_dataset.py +++ b/tests/test_packed_dataset.py @@ -55,7 +55,7 @@ class TestPacking(unittest.TestCase): "type": "alpaca", }, ], - "dataset_processes": 4, + "dataset_num_proc": 4, "num_epochs": 1, "max_steps": 20, "save_steps": 10,