Fix(cfg): Add validation for save_strategy and eval_strategy (#633)
* Fix(cfg): Check save_strategy cfg conflict with save_steps * Fix(cfg): Check evaluation_strategy cfg conflict with eval_steps * chore: add extra check for steps only
This commit is contained in:
@@ -296,6 +296,24 @@ def validate_config(cfg):
|
|||||||
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
||||||
"sharegpt_simple", "sharegpt"
|
"sharegpt_simple", "sharegpt"
|
||||||
)
|
)
|
||||||
|
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||||
|
raise ValueError(
|
||||||
|
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
cfg.evaluation_strategy
|
||||||
|
and cfg.eval_steps
|
||||||
|
and cfg.evaluation_strategy != "steps"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
|
||||||
|
raise ValueError(
|
||||||
|
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||||
|
)
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
|
|||||||
@@ -604,26 +604,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
"sample_packing_efficiency"
|
"sample_packing_efficiency"
|
||||||
] = cfg.sample_packing_eff_est
|
] = cfg.sample_packing_eff_est
|
||||||
|
|
||||||
if cfg.eval_steps and cfg.evaluation_strategy:
|
if cfg.eval_steps:
|
||||||
# assume if the user set both, they know what they're doing
|
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||||
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
|
|
||||||
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
|
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
|
||||||
|
elif cfg.evaluation_strategy:
|
||||||
|
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
|
||||||
elif cfg.val_set_size == 0:
|
elif cfg.val_set_size == 0:
|
||||||
# no eval set, so don't eval
|
# no eval set, so don't eval
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||||
elif cfg.evaluation_strategy and cfg.evaluation_strategy in ["epoch", "no"]:
|
|
||||||
# if explicitly set for epoch, just set, and eval steps don't matter
|
|
||||||
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
|
|
||||||
elif cfg.eval_steps:
|
|
||||||
# steps isn't used w/ epochs
|
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
|
||||||
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
|
|
||||||
else:
|
else:
|
||||||
# we have an eval set, but no steps defined, default to use epoch
|
# we have an eval set, but no steps defined, default to use epoch
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||||
|
|
||||||
if cfg.save_steps:
|
if cfg.save_steps:
|
||||||
# save_steps implies save_strategy of steps
|
|
||||||
training_arguments_kwargs["save_strategy"] = "steps"
|
training_arguments_kwargs["save_strategy"] = "steps"
|
||||||
training_arguments_kwargs["save_steps"] = cfg.save_steps
|
training_arguments_kwargs["save_steps"] = cfg.save_steps
|
||||||
elif cfg.save_strategy:
|
elif cfg.save_strategy:
|
||||||
|
|||||||
@@ -397,3 +397,171 @@ class ValidationTest(unittest.TestCase):
|
|||||||
for record in self._caplog.records
|
for record in self._caplog.records
|
||||||
)
|
)
|
||||||
assert cfg.datasets[0].type == "sharegpt:load_role"
|
assert cfg.datasets[0].type == "sharegpt:load_role"
|
||||||
|
|
||||||
|
def test_no_conflict_save_strategy(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"save_strategy": "epoch",
|
||||||
|
"save_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match=r".*save_strategy and save_steps mismatch.*"
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"save_strategy": "no",
|
||||||
|
"save_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match=r".*save_strategy and save_steps mismatch.*"
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"save_strategy": "steps",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"save_strategy": "steps",
|
||||||
|
"save_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"save_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"save_strategy": "no",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_no_conflict_eval_strategy(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"evaluation_strategy": "epoch",
|
||||||
|
"eval_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"evaluation_strategy": "no",
|
||||||
|
"eval_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"evaluation_strategy": "steps",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"evaluation_strategy": "steps",
|
||||||
|
"eval_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"eval_steps": 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"evaluation_strategy": "no",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"evaluation_strategy": "epoch",
|
||||||
|
"val_set_size": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"eval_steps": 10,
|
||||||
|
"val_set_size": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"val_set_size": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"eval_steps": 10,
|
||||||
|
"val_set_size": 0.01,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"evaluation_strategy": "epoch",
|
||||||
|
"val_set_size": 0.01,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user