diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 42cbe52c1..cdbe47b8f 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -1314,6 +1314,7 @@ class AxolotlInputConfig(
             and data.get("gradient_checkpointing_kwargs", {})
             and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
             is False
+            and data.get("deepspeed", "") is not None
             and "zero3" in data.get("deepspeed", "")
         ):
             # may result in:
diff --git a/tests/test_validation.py b/tests/test_validation.py
index f3f4d18ab..491f230c3 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -68,6 +68,53 @@ class TestValidation(BaseValidation):
         assert cfg.train_on_inputs is False
         assert cfg.weight_decay is None
 
+    def test_zero3_qlora_use_reentrant_false(self, minimal_cfg):
+        test_cfg = DictDefault(
+            {
+                "deepspeed": "deepspeed_configs/zero3_bf16.json",
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": False},
+                "load_in_4bit": True,
+                "adapter": "qlora",
+            }
+            | minimal_cfg
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(test_cfg)
+            assert (
+                "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
+                in self._caplog.records[0].message
+            )
+
+    def test_deepspeed_empty(self, minimal_cfg):
+        test_cfg = DictDefault(
+            {
+                "deepspeed": "",
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": False},
+                "load_in_4bit": True,
+                "adapter": "qlora",
+            }
+            | minimal_cfg
+        )
+
+        _ = validate_config(test_cfg)
+
+    def test_deepspeed_not_set(self, minimal_cfg):
+        test_cfg = DictDefault(
+            {
+                "deepspeed": None,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": False},
+                "load_in_4bit": True,
+                "adapter": "qlora",
+            }
+            | minimal_cfg
+        )
+
+        _ = validate_config(test_cfg)
+
     def test_datasets_min_length(self):
         cfg = DictDefault(
             {