From cb9d3af5c00e0189f95c03d64efdc283aec54679 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 15 Jun 2023 09:39:42 -0400
Subject: [PATCH] add validation and tests for adamw hyperparam

---
 src/axolotl/utils/validation.py |  7 +++++
 tests/test_validation.py        | 51 +++++++++++++++++++++++++++++++++
 2 files changed, 58 insertions(+)

diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index 298d36c4e..2e0da69b3 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -87,6 +87,13 @@ def validate_config(cfg):
             "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
         )
 
+    # Warn when any adamw_* hyperparameter is present but the optimizer name
+    # does not contain "adamw". NOTE: truthiness check, so a value of 0 counts as unset.
+    if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and (
+        not cfg.optimizer or "adamw" not in cfg.optimizer
+    ):
+        logging.warning("adamw hyperparameters found, but no adamw optimizer set")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/tests/test_validation.py b/tests/test_validation.py
index dba54586e..cc6d29a23 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -263,3 +263,54 @@ class ValidationTest(unittest.TestCase):
 
         with pytest.raises(ValueError, match=regex_exp):
             validate_config(cfg)
+
+    def test_adamw_hyperparams(self):
+        """Warn when adamw_* hyperparameters are set without an adamw optimizer."""
+        warning_msg = "adamw hyperparameters found, but no adamw optimizer set"
+
+        # no optimizer configured at all -> warning expected
+        cfg = DictDefault(
+            {
+                "optimizer": None,
+                "adamw_epsilon": 0.0001,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                warning_msg in record.message for record in self._caplog.records
+            )
+
+        # non-adamw optimizer -> warning expected; drop the earlier records so
+        # this assertion cannot be satisfied by the warning captured above
+        self._caplog.clear()
+        cfg = DictDefault(
+            {
+                "optimizer": "adafactor",
+                "adamw_beta1": 0.0001,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                warning_msg in record.message for record in self._caplog.records
+            )
+
+        # adamw optimizer -> the hyperparameters are consistent, no warning
+        self._caplog.clear()
+        cfg = DictDefault(
+            {
+                "optimizer": "adamw_bnb_8bit",
+                "adamw_beta1": 0.0001,
+                "adamw_beta2": 0.0001,
+                "adamw_epsilon": 0.0001,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert not any(
+                warning_msg in record.message for record in self._caplog.records
+            )