add validation and tests for adamw hyperparam
This commit is contained in:
@@ -87,6 +87,11 @@ def validate_config(cfg):
|
|||||||
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and (
|
||||||
|
not cfg.optimizer or "adamw" not in cfg.optimizer
|
||||||
|
):
|
||||||
|
logging.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -263,3 +263,45 @@ class ValidationTest(unittest.TestCase):
|
|||||||
|
|
||||||
with pytest.raises(ValueError, match=regex_exp):
|
with pytest.raises(ValueError, match=regex_exp):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_adamw_hyperparams(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"optimizer": None,
|
||||||
|
"adamw_epsilon": 0.0001,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert any(
|
||||||
|
"adamw hyperparameters found, but no adamw optimizer set"
|
||||||
|
in record.message
|
||||||
|
for record in self._caplog.records
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"optimizer": "adafactor",
|
||||||
|
"adamw_beta1": 0.0001,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert any(
|
||||||
|
"adamw hyperparameters found, but no adamw optimizer set"
|
||||||
|
in record.message
|
||||||
|
for record in self._caplog.records
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"optimizer": "adamw_bnb_8bit",
|
||||||
|
"adamw_beta1": 0.0001,
|
||||||
|
"adamw_beta2": 0.0001,
|
||||||
|
"adamw_epsilon": 0.0001,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user