From af29d81f80e42afbeb38daa3829d86b53b2aa7f7 Mon Sep 17 00:00:00 2001 From: JohanWork <39947546+JohanWork@users.noreply.github.com> Date: Fri, 26 Jan 2024 16:38:55 +0100 Subject: [PATCH] ADD: warning if hub_model_id ist set but not any save strategy (#1202) * warning if hub model id set but no save * add warning * move the warning * add test * allow more public methods for tests for now * fix tests --------- Co-authored-by: Wing Lian --- src/axolotl/utils/config.py | 5 +++++ tests/test_validation.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index c27849d83..3bc01fc7f 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -340,6 +340,11 @@ def validate_config(cfg): "push_to_hub_model_id is deprecated. Please use hub_model_id instead." ) + if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch): + LOG.warning( + "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch." + ) + if cfg.gptq and cfg.model_revision: raise ValueError( "model_revision is not supported for GPTQ models. " diff --git a/tests/test_validation.py b/tests/test_validation.py index 5c3641f65..d73ae34eb 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -26,6 +26,7 @@ class BaseValidation(unittest.TestCase): self._caplog = caplog +# pylint: disable=too-many-public-methods class ValidationTest(BaseValidation): """ Test the validation module @@ -698,6 +699,22 @@ class ValidationTest(BaseValidation): ): validate_config(cfg) + def test_hub_model_id_save_value_warns(self): + cfg = DictDefault({"hub_model_id": "test"}) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert ( + "set without any models being saved" in self._caplog.records[0].message + ) + + def test_hub_model_id_save_value(self): + cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert len(self._caplog.records) == 0 + class ValidationCheckModelConfig(BaseValidation): """