Fix(cfg): Add validation for save_strategy and eval_strategy (#633)
* Fix(cfg): Check save_strategy cfg conflict with save_steps * Fix(cfg): Check evaluation_strategy cfg conflict with eval_steps * chore: add extra check for steps only
This commit is contained in:
@@ -397,3 +397,171 @@ class ValidationTest(unittest.TestCase):
|
||||
for record in self._caplog.records
|
||||
)
|
||||
assert cfg.datasets[0].type == "sharegpt:load_role"
|
||||
|
||||
def test_no_conflict_save_strategy(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"save_strategy": "epoch",
|
||||
"save_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*save_strategy and save_steps mismatch.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"save_strategy": "no",
|
||||
"save_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*save_strategy and save_steps mismatch.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"save_strategy": "steps",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"save_strategy": "steps",
|
||||
"save_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"save_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"save_strategy": "no",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_no_conflict_eval_strategy(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "epoch",
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "no",
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "steps",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "steps",
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "no",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "epoch",
|
||||
"val_set_size": 0,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"eval_steps": 10,
|
||||
"val_set_size": 0,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"val_set_size": 0,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"eval_steps": 10,
|
||||
"val_set_size": 0.01,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "epoch",
|
||||
"val_set_size": 0.01,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
Reference in New Issue
Block a user