diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index c165bc97b..9260577db 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -360,6 +360,12 @@ def validate_config(cfg): "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false." ) + if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit): + raise ValueError( + "load_in_8bit and load_in_4bit are not supported without setting an adapter." + "If you want to full finetune, please turn off load_in_8bit and load_in_4bit." + ) + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 diff --git a/tests/test_validation.py b/tests/test_validation.py index 98c91da49..dbd030da3 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -606,3 +606,46 @@ class ValidationTest(unittest.TestCase): ) validate_config(cfg) + + def test_load_in_x_bit_without_adapter(self): + cfg = DictDefault( + { + "load_in_4bit": True, + } + ) + + with pytest.raises( + ValueError, + match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*", + ): + validate_config(cfg) + + cfg = DictDefault( + { + "load_in_8bit": True, + } + ) + + with pytest.raises( + ValueError, + match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*", + ): + validate_config(cfg) + + cfg = DictDefault( + { + "load_in_4bit": True, + "adapter": "qlora", + } + ) + + validate_config(cfg) + + cfg = DictDefault( + { + "load_in_8bit": True, + "adapter": "lora", + } + ) + + validate_config(cfg)