ADD: warning if hub_model_id ist set but not any save strategy (#1202)
* warning if hub model id set but no save * add warning * move the warning * add test * allow more public methods for tests for now * fix tests --------- Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -340,6 +340,11 @@ def validate_config(cfg):
|
|||||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
|
||||||
|
LOG.warning(
|
||||||
|
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.gptq and cfg.model_revision:
|
if cfg.gptq and cfg.model_revision:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"model_revision is not supported for GPTQ models. "
|
"model_revision is not supported for GPTQ models. "
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class BaseValidation(unittest.TestCase):
|
|||||||
self._caplog = caplog
|
self._caplog = caplog
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=too-many-public-methods
|
||||||
class ValidationTest(BaseValidation):
|
class ValidationTest(BaseValidation):
|
||||||
"""
|
"""
|
||||||
Test the validation module
|
Test the validation module
|
||||||
@@ -698,6 +699,22 @@ class ValidationTest(BaseValidation):
|
|||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_hub_model_id_save_value_warns(self):
|
||||||
|
cfg = DictDefault({"hub_model_id": "test"})
|
||||||
|
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert (
|
||||||
|
"set without any models being saved" in self._caplog.records[0].message
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_hub_model_id_save_value(self):
|
||||||
|
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4})
|
||||||
|
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert len(self._caplog.records) == 0
|
||||||
|
|
||||||
|
|
||||||
class ValidationCheckModelConfig(BaseValidation):
|
class ValidationCheckModelConfig(BaseValidation):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user