ADD: warning hub model (#1301)
* update warning for save_strategy * update * clean up * update * Update test_validation.py * fix validation step * update * test_validation * update * fix * fix --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
This commit is contained in:
@@ -383,9 +383,9 @@ def legacy_validate_config(cfg):
|
|||||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
|
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.gptq and cfg.revision_of_model:
|
if cfg.gptq and cfg.revision_of_model:
|
||||||
@@ -448,10 +448,14 @@ def legacy_validate_config(cfg):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||||
)
|
)
|
||||||
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
|
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
||||||
)
|
)
|
||||||
|
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||||
|
raise ValueError(
|
||||||
|
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||||
|
)
|
||||||
if cfg.evals_per_epoch and cfg.eval_steps:
|
if cfg.evals_per_epoch and cfg.eval_steps:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
||||||
@@ -464,11 +468,6 @@ def legacy_validate_config(cfg):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||||
)
|
)
|
||||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
|
||||||
raise ValueError(
|
|
||||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
cfg.evaluation_strategy
|
cfg.evaluation_strategy
|
||||||
and cfg.eval_steps
|
and cfg.eval_steps
|
||||||
|
|||||||
@@ -780,11 +780,11 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_push_save(cls, data):
|
def check_push_save(cls, data):
|
||||||
if data.get("hub_model_id") and not (
|
if data.get("hub_model_id") and (
|
||||||
data.get("save_steps") or data.get("saves_per_epoch")
|
data.get("save_strategy") not in ["steps", "epoch", None]
|
||||||
):
|
):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
"hub_model_id is set without any models being saved. To save a model, set save_strategy."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|||||||
@@ -1067,17 +1067,51 @@ class TestValidation(BaseValidation):
|
|||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
def test_hub_model_id_save_value_warns(self, minimal_cfg):
|
def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
|
||||||
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
|
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg
|
||||||
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
with self._caplog.at_level(logging.WARNING):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
assert (
|
assert len(self._caplog.records) == 1
|
||||||
"set without any models being saved" in self._caplog.records[0].message
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_hub_model_id_save_value(self, minimal_cfg):
|
def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg):
|
||||||
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg
|
cfg = (
|
||||||
|
DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert len(self._caplog.records) == 1
|
||||||
|
|
||||||
|
def test_hub_model_id_save_value_steps(self, minimal_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault({"hub_model_id": "test", "save_strategy": "steps"})
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert len(self._caplog.records) == 0
|
||||||
|
|
||||||
|
def test_hub_model_id_save_value_epochs(self, minimal_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault({"hub_model_id": "test", "save_strategy": "epoch"})
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert len(self._caplog.records) == 0
|
||||||
|
|
||||||
|
def test_hub_model_id_save_value_none(self, minimal_cfg):
|
||||||
|
cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg
|
||||||
|
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert len(self._caplog.records) == 0
|
||||||
|
|
||||||
|
def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
|
||||||
|
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
|
||||||
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
with self._caplog.at_level(logging.WARNING):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user