ADD: warning if hub_model_id ist set but not any save strategy (#1202)
* warning if hub model id set but no save * add warning * move the warning * add test * allow more public methods for tests for now * fix tests --------- Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -340,6 +340,11 @@ def validate_config(cfg):
|
||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||
)
|
||||
|
||||
if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
|
||||
LOG.warning(
|
||||
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
||||
)
|
||||
|
||||
if cfg.gptq and cfg.model_revision:
|
||||
raise ValueError(
|
||||
"model_revision is not supported for GPTQ models. "
|
||||
|
||||
@@ -26,6 +26,7 @@ class BaseValidation(unittest.TestCase):
|
||||
self._caplog = caplog
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
class ValidationTest(BaseValidation):
|
||||
"""
|
||||
Test the validation module
|
||||
@@ -698,6 +699,22 @@ class ValidationTest(BaseValidation):
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_hub_model_id_save_value_warns(self):
|
||||
cfg = DictDefault({"hub_model_id": "test"})
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert (
|
||||
"set without any models being saved" in self._caplog.records[0].message
|
||||
)
|
||||
|
||||
def test_hub_model_id_save_value(self):
|
||||
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4})
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
|
||||
class ValidationCheckModelConfig(BaseValidation):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user