Feat: Update validate_config and add tests
This commit is contained in:
95
tests/test_validation.py
Normal file
95
tests/test_validation.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.validation import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class ValidationTest(unittest.TestCase):
|
||||
def test_load_4bit_deprecate(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"load_4bit": True,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_qlora(self):
|
||||
base_cfg = DictDefault(
|
||||
{
|
||||
"adapter": "qlora",
|
||||
}
|
||||
)
|
||||
|
||||
cfg = base_cfg | DictDefault(
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*8bit.*"):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = base_cfg | DictDefault(
|
||||
{
|
||||
"gptq": True,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*gptq.*"):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = base_cfg | DictDefault(
|
||||
{
|
||||
"load_in_4bit": False,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*4bit.*"):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = base_cfg | DictDefault(
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_qlora_merge(self):
|
||||
base_cfg = DictDefault(
|
||||
{
|
||||
"adapter": "qlora",
|
||||
"merge_lora": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = base_cfg | DictDefault(
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*8bit.*"):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = base_cfg | DictDefault(
|
||||
{
|
||||
"gptq": True,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*gptq.*"):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = base_cfg | DictDefault(
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*4bit.*"):
|
||||
validate_config(cfg)
|
||||
Reference in New Issue
Block a user