From 19cf0bda99b0957dd4ccd2152d27faa84f6f58a8 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 8 Jul 2023 12:13:39 -0400
Subject: [PATCH] params are adam_*, not adamw_*

---
 src/axolotl/utils/validation.py |  2 +-
 tests/test_validation.py        | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index 43b4b1d16..40dfb84a9 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -87,7 +87,7 @@ def validate_config(cfg):
             "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
         )
 
-    if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and (
+    if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
         not cfg.optimizer or "adamw" not in cfg.optimizer
     ):
         logging.warning("adamw hyperparameters found, but no adamw optimizer set")
diff --git a/tests/test_validation.py b/tests/test_validation.py
index d39a4618e..88c97f0b7 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -268,7 +268,7 @@ class ValidationTest(unittest.TestCase):
         cfg = DictDefault(
             {
                 "optimizer": None,
-                "adamw_epsilon": 0.0001,
+                "adam_epsilon": 0.0001,
             }
         )
 
@@ -283,7 +283,7 @@ class ValidationTest(unittest.TestCase):
         cfg = DictDefault(
             {
                 "optimizer": "adafactor",
-                "adamw_beta1": 0.0001,
+                "adam_beta1": 0.0001,
             }
         )
 
@@ -298,9 +298,9 @@ class ValidationTest(unittest.TestCase):
         cfg = DictDefault(
             {
                 "optimizer": "adamw_bnb_8bit",
-                "adamw_beta1": 0.0001,
-                "adamw_beta2": 0.0001,
-                "adamw_epsilon": 0.0001,
+                "adam_beta1": 0.9,
+                "adam_beta2": 0.99,
+                "adam_epsilon": 0.0001,
             }
         )
 