convert exponential notation lr to floats (#771)
This commit is contained in:
@@ -119,6 +119,9 @@ def normalize_config(cfg):
|
|||||||
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(cfg.learning_rate, str):
|
||||||
|
cfg.learning_rate = float(cfg.learning_rate)
|
||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
39
tests/test_normalize_config.py
Normal file
39
tests/test_normalize_config.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""
|
||||||
|
Test classes for checking functionality of the cfg normalization
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizeConfigTestCase(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
test class for normalize_config checks
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_base_cfg(self):
|
||||||
|
return DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"base_model_config": "JackFram/llama-68m",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_lr_as_float(self):
|
||||||
|
cfg = (
|
||||||
|
self._get_base_cfg()
|
||||||
|
| DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
|
{
|
||||||
|
"learning_rate": "5e-5",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
normalize_config(cfg)
|
||||||
|
|
||||||
|
assert cfg.learning_rate == 0.00005
|
||||||
Reference in New Issue
Block a user