47 lines
1.2 KiB
Python
47 lines
1.2 KiB
Python
"""
Test classes for checking the functionality of cfg normalization (normalize_config)
"""
|
|
import unittest
|
|
|
|
from axolotl.utils.config import normalize_config
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
|
|
class NormalizeConfigTestCase(unittest.TestCase):
    """
    Test class for normalize_config checks.
    """

    def _get_base_cfg(self):
        """Return the minimal base config shared by the tests below.

        A new DictDefault is built on every call, so each test can mutate
        its own copy without affecting the others.
        """
        return DictDefault(
            {
                "base_model": "JackFram/llama-68m",
                "base_model_config": "JackFram/llama-68m",
                "tokenizer_type": "LlamaTokenizer",
                "num_epochs": 1,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
            }
        )

    def test_lr_as_float(self):
        """A learning rate supplied as a string should end up as a float."""
        cfg = (
            self._get_base_cfg()
            | DictDefault(  # pylint: disable=unsupported-binary-operation
                {
                    "learning_rate": "5e-5",
                }
            )
        )

        normalize_config(cfg)

        # Use TestCase assertion methods rather than bare ``assert``: bare
        # asserts are stripped under ``python -O``, which would make this
        # test pass without checking anything.
        self.assertIsInstance(cfg.learning_rate, float)
        self.assertEqual(cfg.learning_rate, 0.00005)

    def test_base_model_config_set_when_empty(self):
        """A missing base_model_config should be filled in from base_model."""
        cfg = self._get_base_cfg()
        del cfg.base_model_config
        # The base config already has base_model_config == base_model, so
        # confirm the key is really gone; otherwise the final assertion
        # would pass even if normalize_config did nothing.
        self.assertNotIn("base_model_config", cfg)

        normalize_config(cfg)

        self.assertEqual(cfg.base_model_config, cfg.base_model)