update to be deprecated evaluation_strategy (#1682) [skip ci]
* update to be deprecated evaluation_strategy and c4 dataset * chore: lint * remap eval strategy to new config and add tests
This commit is contained in:
@@ -726,7 +726,7 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "epoch",
|
||||
"eval_strategy": "epoch",
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
@@ -734,14 +734,14 @@ class TestValidation(BaseValidation):
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
||||
ValueError, match=r".*eval_strategy and eval_steps mismatch.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "no",
|
||||
"eval_strategy": "no",
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
@@ -749,14 +749,14 @@ class TestValidation(BaseValidation):
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
||||
ValueError, match=r".*eval_strategy and eval_steps mismatch.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "steps",
|
||||
"eval_strategy": "steps",
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
@@ -767,7 +767,7 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "steps",
|
||||
"eval_strategy": "steps",
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
@@ -790,7 +790,7 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "no",
|
||||
"eval_strategy": "no",
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
@@ -801,7 +801,7 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "epoch",
|
||||
"eval_strategy": "epoch",
|
||||
"val_set_size": 0,
|
||||
}
|
||||
)
|
||||
@@ -810,7 +810,7 @@ class TestValidation(BaseValidation):
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
||||
match=r".*eval_steps and eval_strategy are not supported with val_set_size == 0.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
@@ -826,7 +826,7 @@ class TestValidation(BaseValidation):
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
||||
match=r".*eval_steps and eval_strategy are not supported with val_set_size == 0.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
@@ -856,7 +856,7 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "epoch",
|
||||
"eval_strategy": "epoch",
|
||||
"val_set_size": 0.01,
|
||||
}
|
||||
)
|
||||
@@ -1095,6 +1095,24 @@ class TestValidation(BaseValidation):
|
||||
assert new_cfg["dpo_beta"] is None
|
||||
assert len(self._caplog.records) == 1
|
||||
|
||||
def test_eval_strategy_remap(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "steps",
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.eval_strategy == "steps"
|
||||
assert (
|
||||
"evaluation_strategy is deprecated, use eval_strategy instead"
|
||||
in self._caplog.records[0].message
|
||||
)
|
||||
|
||||
|
||||
class TestValidationCheckModelConfig(BaseValidation):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user