convert exponential notation lr to floats (#771)
This commit is contained in:
39
tests/test_normalize_config.py
Normal file
39
tests/test_normalize_config.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
Test classes for checking functionality of the cfg normalization
|
||||
"""
|
||||
import unittest
|
||||
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class NormalizeConfigTestCase(unittest.TestCase):
|
||||
"""
|
||||
test class for normalize_config checks
|
||||
"""
|
||||
|
||||
def _get_base_cfg(self):
|
||||
return DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"base_model_config": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
}
|
||||
)
|
||||
|
||||
def test_lr_as_float(self):
|
||||
cfg = (
|
||||
self._get_base_cfg()
|
||||
| DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"learning_rate": "5e-5",
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
normalize_config(cfg)
|
||||
|
||||
assert cfg.learning_rate == 0.00005
|
||||
Reference in New Issue
Block a user