"""helper functions for datasets"""

import os


def get_default_process_count() -> int:
    """Return the default number of worker processes for dataset preprocessing.

    Resolution order:
      1. ``AXOLOTL_DATASET_PROCESSES`` environment variable — explicit user
         override, always wins.
      2. ``RUNPOD_CPU_COUNT`` environment variable — presumably set by RunPod
         containers to the rented CPU allotment, which can be lower than the
         host's raw CPU count (TODO confirm against RunPod docs).
      3. ``os.cpu_count()``, falling back to 1 when it returns ``None``
         (possible on platforms where the count is undetermined).

    Returns:
        A positive integer usable directly as ``num_proc``/``num_workers``.

    Raises:
        ValueError: if either environment variable is set to a non-integer
            string (propagated from ``int()``, surfacing the misconfiguration
            instead of silently ignoring it).
    """
    if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
        return int(axolotl_dataset_processes)
    if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
        return int(runpod_cpu_count)
    # os.cpu_count() may return None; callers do arithmetic such as
    # min(max(1, len(dataset) // 8), num_workers) with this value, so a
    # None here would raise TypeError downstream — never return it.
    return os.cpu_count() or 1
@classmethod
def default_dataset_processes(cls, data):
    """Pre-validation hook: populate ``dataset_processes`` when it is unset.

    An explicitly configured value (including a falsy one such as ``0``) is
    left untouched; only a missing key or an explicit ``None`` is replaced
    with the environment-derived default from ``get_default_process_count``.
    """
    configured = data.get("dataset_processes")
    if configured is None:
        data["dataset_processes"] = get_default_process_count()
    return data
"alpaca", }, ], + "dataset_processes": 4, } ) @@ -281,6 +285,7 @@ class TestDatasetPreparation: "type": "alpaca", }, ], + "dataset_processes": 4, } ) @@ -365,6 +370,7 @@ class TestDatasetPreparation: "rl": "dpo", "chat_template": "llama3", "datasets": [ALPACA_MESSAGES_CONFIG_REVISION], + "dataset_processes": 4, } ) @@ -466,6 +472,7 @@ class TestDatasetPreparation: "type": "alpaca", }, ], + "dataset_processes": 4, } ) diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py index 45a327a40..d97aad8ea 100644 --- a/tests/test_exact_deduplication.py +++ b/tests/test_exact_deduplication.py @@ -210,6 +210,7 @@ class TestDeduplicateRLDataset: ALPACA_MESSAGES_CONFIG_REVISION, ALPACA_MESSAGES_CONFIG_REVISION, ], + "dataset_processes": 4, } ) yield fixture diff --git a/tests/test_packed_dataset.py b/tests/test_packed_dataset.py index 8b29eab21..699d5e6cc 100644 --- a/tests/test_packed_dataset.py +++ b/tests/test_packed_dataset.py @@ -99,6 +99,7 @@ class TestPacking(unittest.TestCase): "type": "alpaca", }, ], + "dataset_processes": 4, "num_epochs": 1, "max_steps": 20, "save_steps": 10,