convert exponential notation lr to floats (#771)

This commit is contained in:
Wing Lian
2023-10-22 15:37:03 -04:00
committed by GitHub
parent 32eeeb5b64
commit ca84cca2c0
2 changed files with 42 additions and 0 deletions

View File

@@ -0,0 +1,39 @@
"""
Test classes for checking functionality of the cfg normalization
"""
import unittest
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
class NormalizeConfigTestCase(unittest.TestCase):
"""
test class for normalize_config checks
"""
def _get_base_cfg(self):
return DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
)
def test_lr_as_float(self):
cfg = (
self._get_base_cfg()
| DictDefault( # pylint: disable=unsupported-binary-operation
{
"learning_rate": "5e-5",
}
)
)
normalize_config(cfg)
assert cfg.learning_rate == 0.00005