limit num_proc when saving datasets to disk (#2948) [skip ci]

* limit num_proc when saving datasets to disk

* enforce at least 1 worker in case the integer division rounds down to 0, and use a sane divisor so each worker saves at least 8 rows

* update fixtures with dataset processes since that should never be NoneType

* improve reusability for tests
This commit is contained in:
Wing Lian
2025-07-21 11:39:38 -04:00
committed by GitHub
parent 8e5f146701
commit db5f6f4693
7 changed files with 27 additions and 9 deletions

View File

@@ -25,6 +25,7 @@ from huggingface_hub.errors import (
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.datasets import get_default_process_count
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -410,7 +411,7 @@ def save_preprocessed_dataset(
) -> None:
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
num_workers = cfg.dataset_processes
num_workers = cfg.dataset_processes or get_default_process_count()
if isinstance(dataset, IterableDataset):
ds_from_iter = Dataset.from_generator(
functools.partial(_generate_from_iterable_dataset, dataset),
@@ -432,7 +433,7 @@ def save_preprocessed_dataset(
os.makedirs(prepared_ds_path, exist_ok=True)
dataset.save_to_disk(
str(prepared_ds_path),
num_proc=num_workers,
num_proc=min(max(1, len(dataset) // 8), num_workers),
max_shard_size=None,
num_shards=cfg.num_dataset_shards_to_save,
)

View File

@@ -0,0 +1,11 @@
"""helper functions for datasets"""
import os
def get_default_process_count():
    """Return the default number of worker processes for dataset operations.

    Resolution order:
      1. ``AXOLOTL_DATASET_PROCESSES`` environment variable (explicit override)
      2. ``RUNPOD_CPU_COUNT`` environment variable (set by RunPod containers)
      3. ``os.cpu_count()``

    Returns:
        A positive ``int``. Never returns ``None``: ``os.cpu_count()`` may be
        ``None`` when the CPU count is undeterminable, in which case we fall
        back to 1 so callers can safely do arithmetic on the result.
    """
    if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
        return int(axolotl_dataset_processes)
    if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
        return int(runpod_cpu_count)
    # os.cpu_count() can return None; guarantee at least one worker.
    return os.cpu_count() or 1

View File

@@ -2,7 +2,6 @@
# pylint: disable=too-many-lines
import os
from typing import Annotated, Any, Literal
from annotated_types import MinLen
@@ -15,6 +14,7 @@ from pydantic import (
model_validator,
)
from axolotl.utils.datasets import get_default_process_count
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import (
DatasetConfig,
@@ -1211,11 +1211,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@classmethod
def default_dataset_processes(cls, data):
if data.get("dataset_processes") is None:
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
data["dataset_processes"] = int(axolotl_dataset_processes)
elif runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
data["dataset_processes"] = int(runpod_cpu_count)
else:
data["dataset_processes"] = os.cpu_count()
data["dataset_processes"] = get_default_process_count()
return data