diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index c27849d83..3bc01fc7f 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -340,6 +340,11 @@ def validate_config(cfg): "push_to_hub_model_id is deprecated. Please use hub_model_id instead." ) + if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch): + LOG.warning( + "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch." + ) + if cfg.gptq and cfg.model_revision: raise ValueError( "model_revision is not supported for GPTQ models. " diff --git a/tests/test_validation.py b/tests/test_validation.py index 5c3641f65..d73ae34eb 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -26,6 +26,7 @@ class BaseValidation(unittest.TestCase): self._caplog = caplog +# pylint: disable=too-many-public-methods class ValidationTest(BaseValidation): """ Test the validation module @@ -698,6 +699,22 @@ class ValidationTest(BaseValidation): ): validate_config(cfg) + def test_hub_model_id_save_value_warns(self): + cfg = DictDefault({"hub_model_id": "test"}) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert ( + "set without any models being saved" in self._caplog.records[0].message + ) + + def test_hub_model_id_save_value(self): + cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert len(self._caplog.records) == 0 + class ValidationCheckModelConfig(BaseValidation): """