Patch summary: in normalize_config, convert cfg.learning_rate to a float when
it is given as a string, and add a unit test that passes learning_rate as the
string "5e-5" and checks the normalized value.

diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 9fa83b9af..c165bc97b 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -119,6 +119,9 @@ def normalize_config(cfg):
         or (cfg.model_type and "mistral" in cfg.model_type.lower())
     )
 
+    if isinstance(cfg.learning_rate, str):
+        cfg.learning_rate = float(cfg.learning_rate)
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
 
 
diff --git a/tests/test_normalize_config.py b/tests/test_normalize_config.py
new file mode 100644
index 000000000..01b8c162c
--- /dev/null
+++ b/tests/test_normalize_config.py
@@ -0,0 +1,39 @@
+"""
+Test classes for checking functionality of the cfg normalization
+"""
+import unittest
+
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+
+class NormalizeConfigTestCase(unittest.TestCase):
+    """
+    test class for normalize_config checks
+    """
+
+    def _get_base_cfg(self):
+        return DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "base_model_config": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "num_epochs": 1,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+            }
+        )
+
+    def test_lr_as_float(self):
+        cfg = (
+            self._get_base_cfg()
+            | DictDefault(  # pylint: disable=unsupported-binary-operation
+                {
+                    "learning_rate": "5e-5",
+                }
+            )
+        )
+
+        normalize_config(cfg)
+
+        assert cfg.learning_rate == 0.00005