From 759cefb74154d3bdbedd9bf21f9388b1ffced4bf Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 7 Jul 2025 10:10:58 -0400 Subject: [PATCH] setup defaults for dataloader to ensure GPU is kept busy (#2632) [skip ci] --- src/axolotl/utils/schemas/config.py | 14 ++++++++++++++ tests/patched/test_validation.py | 15 +++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 6481202c7..21b69824f 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1114,3 +1114,17 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): raise ValueError("QAT is not supported on torch version < 2.6.0") return data + + @model_validator(mode="before") + @classmethod + def default_dataloader_opts(cls, data): + if ( + data.get("dataloader_num_workers") is None + and data.get("dataloader_pin_memory") is None + and data.get("dataloader_prefetch_factor") is None + ): + data["dataloader_num_workers"] = data.get("capabilities").get("n_gpu", 1) + data["dataloader_pin_memory"] = True + data["dataloader_prefetch_factor"] = 256 + + return data diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py index 2c28a71ea..55e25daf7 100644 --- a/tests/patched/test_validation.py +++ b/tests/patched/test_validation.py @@ -1690,3 +1690,18 @@ class TestValidationMLflow(BaseValidation): assert new_cfg.use_mlflow is True os.environ.pop("MLFLOW_EXPERIMENT_NAME", None) + + +class TestDataloaderValidation(BaseValidation): + """ + tests for dataloader_* sane defaults + """ + + def test_dataloader_auto_defaults(self, minimal_cfg): + cfg = minimal_cfg + + new_cfg = validate_config(cfg, {"n_gpu": 8}, {"torch_version": "2.6.0"}) + + assert new_cfg.dataloader_num_workers == 8 + assert new_cfg.dataloader_pin_memory is True + assert new_cfg.dataloader_prefetch_factor == 256