Fix config validation and tests: optimizer hyperparameter keys are adam_* (adam_beta1, adam_beta2, adam_epsilon), not adamw_*

This commit is contained in:
Wing Lian
2023-07-08 12:13:39 -04:00
parent f74edd5b56
commit 19cf0bda99
2 changed files with 6 additions and 6 deletions

View File

@@ -87,7 +87,7 @@ def validate_config(cfg):
"You probably want to disable group_by_length as it will force a streamed dataset to download completely." "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
) )
if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and ( if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
not cfg.optimizer or "adamw" not in cfg.optimizer not cfg.optimizer or "adamw" not in cfg.optimizer
): ):
logging.warning("adamw hyperparameters found, but no adamw optimizer set") logging.warning("adamw hyperparameters found, but no adamw optimizer set")

View File

@@ -268,7 +268,7 @@ class ValidationTest(unittest.TestCase):
cfg = DictDefault( cfg = DictDefault(
{ {
"optimizer": None, "optimizer": None,
"adamw_epsilon": 0.0001, "adam_epsilon": 0.0001,
} }
) )
@@ -283,7 +283,7 @@ class ValidationTest(unittest.TestCase):
cfg = DictDefault( cfg = DictDefault(
{ {
"optimizer": "adafactor", "optimizer": "adafactor",
"adamw_beta1": 0.0001, "adam_beta1": 0.0001,
} }
) )
@@ -298,9 +298,9 @@ class ValidationTest(unittest.TestCase):
cfg = DictDefault( cfg = DictDefault(
{ {
"optimizer": "adamw_bnb_8bit", "optimizer": "adamw_bnb_8bit",
"adamw_beta1": 0.0001, "adam_beta1": 0.9,
"adamw_beta2": 0.0001, "adam_beta2": 0.99,
"adamw_epsilon": 0.0001, "adam_epsilon": 0.0001,
} }
) )