diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index 298d36c4e..2e0da69b3 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -87,6 +87,11 @@ def validate_config(cfg):
             "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
         )
 
+    if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and (
+        not cfg.optimizer or "adamw" not in cfg.optimizer
+    ):
+        logging.warning("adamw hyperparameters found, but no adamw optimizer set")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/tests/test_validation.py b/tests/test_validation.py
index dba54586e..cc6d29a23 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -263,3 +263,54 @@ class ValidationTest(unittest.TestCase):
 
         with pytest.raises(ValueError, match=regex_exp):
             validate_config(cfg)
+
+    def test_adamw_hyperparams(self):
+        """Warn about adamw hyperparameters only when no adamw optimizer is set."""
+        warning_msg = "adamw hyperparameters found, but no adamw optimizer set"
+
+        def adamw_warning_logged():
+            return any(
+                warning_msg in record.message for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "optimizer": None,
+                "adamw_epsilon": 0.0001,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert adamw_warning_logged()
+
+        # caplog accumulates records for the whole test, so drop the warning
+        # captured above; otherwise the next assertion passes unconditionally.
+        self._caplog.clear()
+
+        cfg = DictDefault(
+            {
+                "optimizer": "adafactor",
+                "adamw_beta1": 0.0001,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert adamw_warning_logged()
+
+        self._caplog.clear()
+
+        cfg = DictDefault(
+            {
+                "optimizer": "adamw_bnb_8bit",
+                "adamw_beta1": 0.0001,
+                "adamw_beta2": 0.0001,
+                "adamw_epsilon": 0.0001,
+            }
+        )
+
+        # A matching adamw optimizer must not trigger the warning.
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert not adamw_warning_logged()