setup defaults for dataloader to ensure GPU is kept busy (#2632) [skip ci]
This commit is contained in:
@@ -1114,3 +1114,17 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
raise ValueError("QAT is not supported on torch version < 2.6.0")
|
raise ValueError("QAT is not supported on torch version < 2.6.0")
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
@classmethod
def default_dataloader_opts(cls, data):
    """Fill in dataloader defaults when none of them were configured.

    Operates on the raw input mapping (this is a ``mode="before"``
    validator). When ``dataloader_num_workers``, ``dataloader_pin_memory``
    and ``dataloader_prefetch_factor`` are all missing or ``None``, they are
    set to:

    * ``dataloader_num_workers``: ``capabilities["n_gpu"]``, or 1 when the
      capabilities mapping or its ``n_gpu`` key is absent
    * ``dataloader_pin_memory``: ``True``
    * ``dataloader_prefetch_factor``: 256

    If any one of the three is set explicitly, nothing is changed.

    Returns:
        The same ``data`` mapping, possibly mutated in place.
    """
    dataloader_keys = (
        "dataloader_num_workers",
        "dataloader_pin_memory",
        "dataloader_prefetch_factor",
    )
    if all(data.get(key) is None for key in dataloader_keys):
        # `capabilities` has not been validated at this point and may be
        # missing or None; fall back to an empty mapping so we default to a
        # single worker instead of raising AttributeError on `None.get`.
        capabilities = data.get("capabilities") or {}
        data["dataloader_num_workers"] = capabilities.get("n_gpu", 1)
        data["dataloader_pin_memory"] = True
        data["dataloader_prefetch_factor"] = 256

    return data
|
||||||
|
|||||||
@@ -1690,3 +1690,18 @@ class TestValidationMLflow(BaseValidation):
|
|||||||
assert new_cfg.use_mlflow is True
|
assert new_cfg.use_mlflow is True
|
||||||
|
|
||||||
os.environ.pop("MLFLOW_EXPERIMENT_NAME", None)
|
os.environ.pop("MLFLOW_EXPERIMENT_NAME", None)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDataloaderValidation(BaseValidation):
    """
    Validation tests covering the automatic dataloader_* defaults
    """

    def test_dataloader_auto_defaults(self, minimal_cfg):
        """Defaults are derived from the reported GPU count.

        Presumably ``minimal_cfg`` sets none of the dataloader_* options, so
        validation is expected to fill them in -- confirm against the fixture.
        """
        validated = validate_config(
            minimal_cfg,
            {"n_gpu": 8},
            {"torch_version": "2.6.0"},
        )

        # One worker per GPU passed in the capabilities mapping above.
        assert validated.dataloader_num_workers == 8
        assert validated.dataloader_pin_memory is True
        assert validated.dataloader_prefetch_factor == 256
|
||||||
|
|||||||
Reference in New Issue
Block a user