limit num_proc when saving datasets to disk (#2948) [skip ci]
* limit num_proc when saving datasets to disk * enforce at least 1 in case it rounds down to 0, and sane divisor is at least 8 rows per worker to save * update fixtures with dataset processes since that should never be NoneType * improve reusability for tests
This commit is contained in:
@@ -25,6 +25,7 @@ from huggingface_hub.errors import (
|
|||||||
|
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
||||||
|
from axolotl.utils.datasets import get_default_process_count
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
@@ -410,7 +411,7 @@ def save_preprocessed_dataset(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
|
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
|
||||||
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
|
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
|
||||||
num_workers = cfg.dataset_processes
|
num_workers = cfg.dataset_processes or get_default_process_count()
|
||||||
if isinstance(dataset, IterableDataset):
|
if isinstance(dataset, IterableDataset):
|
||||||
ds_from_iter = Dataset.from_generator(
|
ds_from_iter = Dataset.from_generator(
|
||||||
functools.partial(_generate_from_iterable_dataset, dataset),
|
functools.partial(_generate_from_iterable_dataset, dataset),
|
||||||
@@ -432,7 +433,7 @@ def save_preprocessed_dataset(
|
|||||||
os.makedirs(prepared_ds_path, exist_ok=True)
|
os.makedirs(prepared_ds_path, exist_ok=True)
|
||||||
dataset.save_to_disk(
|
dataset.save_to_disk(
|
||||||
str(prepared_ds_path),
|
str(prepared_ds_path),
|
||||||
num_proc=num_workers,
|
num_proc=min(max(1, len(dataset) // 8), num_workers),
|
||||||
max_shard_size=None,
|
max_shard_size=None,
|
||||||
num_shards=cfg.num_dataset_shards_to_save,
|
num_shards=cfg.num_dataset_shards_to_save,
|
||||||
)
|
)
|
||||||
|
|||||||
11
src/axolotl/utils/datasets.py
Normal file
11
src/axolotl/utils/datasets.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""helper functions for datasets"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_process_count():
|
||||||
|
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
|
||||||
|
return int(axolotl_dataset_processes)
|
||||||
|
if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
|
||||||
|
return int(runpod_cpu_count)
|
||||||
|
return os.cpu_count()
|
||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import Annotated, Any, Literal
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
from annotated_types import MinLen
|
from annotated_types import MinLen
|
||||||
@@ -15,6 +14,7 @@ from pydantic import (
|
|||||||
model_validator,
|
model_validator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from axolotl.utils.datasets import get_default_process_count
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.datasets import (
|
from axolotl.utils.schemas.datasets import (
|
||||||
DatasetConfig,
|
DatasetConfig,
|
||||||
@@ -1211,11 +1211,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def default_dataset_processes(cls, data):
|
def default_dataset_processes(cls, data):
|
||||||
if data.get("dataset_processes") is None:
|
if data.get("dataset_processes") is None:
|
||||||
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
|
data["dataset_processes"] = get_default_process_count()
|
||||||
data["dataset_processes"] = int(axolotl_dataset_processes)
|
|
||||||
elif runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
|
|
||||||
data["dataset_processes"] = int(runpod_cpu_count)
|
|
||||||
else:
|
|
||||||
data["dataset_processes"] = os.cpu_count()
|
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ def fixture_base_cfg():
|
|||||||
"ddp_timeout": 1800,
|
"ddp_timeout": 1800,
|
||||||
"ddp_bucket_cap_mb": 25,
|
"ddp_bucket_cap_mb": 25,
|
||||||
"ddp_broadcast_buffers": False,
|
"ddp_broadcast_buffers": False,
|
||||||
|
"dataset_processes": 4,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -440,6 +441,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unhandled cfg_string: {cfg_string}")
|
raise ValueError(f"Unhandled cfg_string: {cfg_string}")
|
||||||
|
cfg["dataset_processes"] = 4
|
||||||
|
|
||||||
if cfg_string == "grpo_cfg":
|
if cfg_string == "grpo_cfg":
|
||||||
rewards_dir = tmp_path / "rewards_test"
|
rewards_dir = tmp_path / "rewards_test"
|
||||||
|
|||||||
@@ -141,6 +141,7 @@ class TestDatasetPreparation:
|
|||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
"dataset_processes": 4,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -179,6 +180,7 @@ class TestDatasetPreparation:
|
|||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
"dataset_processes": 4,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -217,6 +219,7 @@ class TestDatasetPreparation:
|
|||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
"dataset_processes": 4,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -249,6 +252,7 @@ class TestDatasetPreparation:
|
|||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
"dataset_processes": 4,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -281,6 +285,7 @@ class TestDatasetPreparation:
|
|||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
"dataset_processes": 4,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -365,6 +370,7 @@ class TestDatasetPreparation:
|
|||||||
"rl": "dpo",
|
"rl": "dpo",
|
||||||
"chat_template": "llama3",
|
"chat_template": "llama3",
|
||||||
"datasets": [ALPACA_MESSAGES_CONFIG_REVISION],
|
"datasets": [ALPACA_MESSAGES_CONFIG_REVISION],
|
||||||
|
"dataset_processes": 4,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -466,6 +472,7 @@ class TestDatasetPreparation:
|
|||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
"dataset_processes": 4,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -210,6 +210,7 @@ class TestDeduplicateRLDataset:
|
|||||||
ALPACA_MESSAGES_CONFIG_REVISION,
|
ALPACA_MESSAGES_CONFIG_REVISION,
|
||||||
ALPACA_MESSAGES_CONFIG_REVISION,
|
ALPACA_MESSAGES_CONFIG_REVISION,
|
||||||
],
|
],
|
||||||
|
"dataset_processes": 4,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
yield fixture
|
yield fixture
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ class TestPacking(unittest.TestCase):
|
|||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
"dataset_processes": 4,
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 20,
|
"max_steps": 20,
|
||||||
"save_steps": 10,
|
"save_steps": 10,
|
||||||
|
|||||||
Reference in New Issue
Block a user