From 52dd92a0cd559d7097c0e5c2224cdcdbdf083ce4 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 00:25:54 +0900 Subject: [PATCH] Feat: Update validate_config and add tests --- src/axolotl/utils/validation.py | 33 ++++++++---- tests/test_validation.py | 95 +++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 9 deletions(-) create mode 100644 tests/test_validation.py diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index bd759c3b7..97cde0677 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -3,24 +3,39 @@ import logging def validate_config(cfg): if cfg.load_4bit: - raise ValueError("cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq") + raise ValueError( + "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq" + ) if cfg.adapter == "qlora": if cfg.merge_lora: # can't merge qlora if loaded in 8bit or 4bit - assert cfg.load_in_8bit is not True - assert cfg.gptq is not True - assert cfg.load_in_4bit is not True + if cfg.load_in_8bit: + raise ValueError("Can't merge qlora if loaded in 8bit") + + if cfg.gptq: + raise ValueError("Can't merge qlora if gptq") + + if cfg.load_in_4bit: + raise ValueError("Can't merge qlora if loaded in 4bit") + else: - assert cfg.load_in_8bit is not True - assert cfg.gptq is not True - assert cfg.load_in_4bit is True + if cfg.load_in_8bit: + raise ValueError("Can't load qlora in 8bit") + + if cfg.gptq: + raise ValueError("Can't load qlora if gptq") + + if not cfg.load_in_4bit: + raise ValueError("Require cfg.load_in_4bit to be True for qlora") if not cfg.load_in_8bit and cfg.adapter == "lora": logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") - + if cfg.trust_remote_code: - logging.warning("`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model.") + logging.warning( + "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model." + ) # TODO # MPT 7b diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 000000000..e754b0ea7 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,95 @@ +import unittest + +import pytest + +from axolotl.utils.validation import validate_config +from axolotl.utils.dict import DictDefault + + +class ValidationTest(unittest.TestCase): + def test_load_4bit_deprecate(self): + cfg = DictDefault( + { + "load_4bit": True, + } + ) + + with pytest.raises(ValueError): + validate_config(cfg) + + def test_qlora(self): + base_cfg = DictDefault( + { + "adapter": "qlora", + } + ) + + cfg = base_cfg | DictDefault( + { + "load_in_8bit": True, + } + ) + + with pytest.raises(ValueError, match=r".*8bit.*"): + validate_config(cfg) + + cfg = base_cfg | DictDefault( + { + "gptq": True, + } + ) + + with pytest.raises(ValueError, match=r".*gptq.*"): + validate_config(cfg) + + cfg = base_cfg | DictDefault( + { + "load_in_4bit": False, + } + ) + + with pytest.raises(ValueError, match=r".*4bit.*"): + validate_config(cfg) + + cfg = base_cfg | DictDefault( + { + "load_in_4bit": True, + } + ) + + validate_config(cfg) + + def test_qlora_merge(self): + base_cfg = DictDefault( + { + "adapter": "qlora", + "merge_lora": True, + } + ) + + cfg = base_cfg | DictDefault( + { + "load_in_8bit": True, + } + ) + + with pytest.raises(ValueError, match=r".*8bit.*"): + validate_config(cfg) + + cfg = base_cfg | DictDefault( + { + "gptq": True, + } + ) + + with pytest.raises(ValueError, match=r".*gptq.*"): + validate_config(cfg) + + cfg = base_cfg | DictDefault( + { + "load_in_4bit": True, + } + ) + + with pytest.raises(ValueError, match=r".*4bit.*"): + validate_config(cfg)