add validation and tests for adamw hyperparam

This commit is contained in:
Wing Lian
2023-06-15 09:39:42 -04:00
parent c969f0a9dc
commit cb9d3af5c0
2 changed files with 47 additions and 0 deletions

View File

@@ -87,6 +87,11 @@ def validate_config(cfg):
"You probably want to disable group_by_length as it will force a streamed dataset to download completely." "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
) )
if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and (
not cfg.optimizer or "adamw" not in cfg.optimizer
):
logging.warning("adamw hyperparameters found, but no adamw optimizer set")
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -263,3 +263,45 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError, match=regex_exp): with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg) validate_config(cfg)
def test_adamw_hyperparams(self):
cfg = DictDefault(
{
"optimizer": None,
"adamw_epsilon": 0.0001,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"optimizer": "adafactor",
"adamw_beta1": 0.0001,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"optimizer": "adamw_bnb_8bit",
"adamw_beta1": 0.0001,
"adamw_beta2": 0.0001,
"adamw_epsilon": 0.0001,
}
)
validate_config(cfg)