Feat: Update validate_config and add tests
This commit is contained in:
@@ -3,24 +3,39 @@ import logging
|
|||||||
|
|
||||||
def validate_config(cfg):
|
def validate_config(cfg):
|
||||||
if cfg.load_4bit:
|
if cfg.load_4bit:
|
||||||
raise ValueError("cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq")
|
raise ValueError(
|
||||||
|
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.adapter == "qlora":
|
if cfg.adapter == "qlora":
|
||||||
if cfg.merge_lora:
|
if cfg.merge_lora:
|
||||||
# can't merge qlora if loaded in 8bit or 4bit
|
# can't merge qlora if loaded in 8bit or 4bit
|
||||||
assert cfg.load_in_8bit is not True
|
if cfg.load_in_8bit:
|
||||||
assert cfg.gptq is not True
|
raise ValueError("Can't merge qlora if loaded in 8bit")
|
||||||
assert cfg.load_in_4bit is not True
|
|
||||||
|
if cfg.gptq:
|
||||||
|
raise ValueError("Can't merge qlora if gptq")
|
||||||
|
|
||||||
|
if cfg.load_in_4bit:
|
||||||
|
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
assert cfg.load_in_8bit is not True
|
if cfg.load_in_8bit:
|
||||||
assert cfg.gptq is not True
|
raise ValueError("Can't load qlora in 8bit")
|
||||||
assert cfg.load_in_4bit is True
|
|
||||||
|
if cfg.gptq:
|
||||||
|
raise ValueError("Can't load qlora if gptq")
|
||||||
|
|
||||||
|
if not cfg.load_in_4bit:
|
||||||
|
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||||
|
|
||||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||||
logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||||
|
|
||||||
if cfg.trust_remote_code:
|
if cfg.trust_remote_code:
|
||||||
logging.warning("`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model.")
|
logging.warning(
|
||||||
|
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||||
|
)
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
|
|||||||
95
tests/test_validation.py
Normal file
95
tests/test_validation.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.validation import validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationTest(unittest.TestCase):
|
||||||
|
def test_load_4bit_deprecate(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"load_4bit": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_qlora(self):
|
||||||
|
base_cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"adapter": "qlora",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = base_cfg | DictDefault(
|
||||||
|
{
|
||||||
|
"load_in_8bit": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r".*8bit.*"):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = base_cfg | DictDefault(
|
||||||
|
{
|
||||||
|
"gptq": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r".*gptq.*"):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = base_cfg | DictDefault(
|
||||||
|
{
|
||||||
|
"load_in_4bit": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r".*4bit.*"):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = base_cfg | DictDefault(
|
||||||
|
{
|
||||||
|
"load_in_4bit": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_qlora_merge(self):
|
||||||
|
base_cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"adapter": "qlora",
|
||||||
|
"merge_lora": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = base_cfg | DictDefault(
|
||||||
|
{
|
||||||
|
"load_in_8bit": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r".*8bit.*"):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = base_cfg | DictDefault(
|
||||||
|
{
|
||||||
|
"gptq": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r".*gptq.*"):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = base_cfg | DictDefault(
|
||||||
|
{
|
||||||
|
"load_in_4bit": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r".*4bit.*"):
|
||||||
|
validate_config(cfg)
|
||||||
Reference in New Issue
Block a user