From c4910da015bd842f1136e53f4577272f6335a4cd Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Mon, 21 Apr 2025 22:17:08 -0400 Subject: [PATCH] update more tests + better hqq validation --- src/axolotl/utils/config/__init__.py | 7 +---- src/axolotl/utils/models.py | 1 + src/axolotl/utils/schemas/peft.py | 26 ++++++++++++++++--- .../e2e/patched/test_lora_llama_multipack.py | 3 +++ 4 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 95afae5c7..7e9b4a952 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -243,16 +243,11 @@ def normalize_config(cfg): elif cfg.quantization.bits == 4: cfg.load_in_4bit = True - elif cfg.quantization.backend == "gptq": + if cfg.quantization.backend == "gptq": cfg.gptq = True elif cfg.quantization.backend == "hqq": cfg.hqq = True - if cfg.hqq and not cfg.quantization.hqq_config: - raise ValueError( - "If using HQQ, must set `hqq_config` to a list of HQQConfig objects" - ) - def normalize_cfg_datasets(cfg): """ diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 5e39f242e..53cb0413c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -838,6 +838,7 @@ class ModelLoader: (not self.cfg.quantization) and (not self.cfg.load_in_8bit) and (not self.cfg.load_in_4bit) + and not self.cfg.gptq ): return self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index 3cedec8b2..0fc50b783 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -87,24 +87,24 @@ class LoraConfig(BaseModel): def validate_qlora(self): if self.adapter == "qlora": if self.merge_lora: - if self.quantization.bits == 8: + if self.quantization.bits == 8 or self.load_in_8bit: raise ValueError("Can't merge qlora if loaded in 8bit") if self.quantization.backend == "gptq": raise 
ValueError("Can't merge qlora if using gptq") - if self.quantization.bits == 4: + if self.quantization.bits == 4 or self.load_in_4bit: raise ValueError("Can't merge qlora if loaded in 4bit") else: if self.quantization: - if self.quantization.bits == 8: + if self.quantization.bits == 8 or self.load_in_8bit: raise ValueError("Can't load qlora in 8bit") if self.quantization.backend == "gptq": raise ValueError("Can't load qlora if using gptq") - if not self.quantization.bits == 4: + if not (self.quantization.bits == 4 or self.load_in_4bit): raise ValueError("Require quantization.bits == 4 or load_in_4bit for qlora") return self @@ -123,6 +123,24 @@ class LoraConfig(BaseModel): data["lora_dropout"] = 0.0 return data + @model_validator(mode="before") + @classmethod + def validate_hqq(cls, data): + if ( + data.get("quantization") + and data.get("quantization").get("backend") == "hqq" + ): + if not data.get("quantization").get("hqq_config"): + raise ValueError( + "If using HQQ, must set `hqq_config` under `quantization`" + ) + + if data.get("load_in_4bit") or data.get("load_in_8bit"): + raise ValueError( + "If using HQQ quantization, please remove load_in_4bit or load_in_8bit" + ) + return data + class ReLoRAConfig(BaseModel): """ReLoRA configuration subset""" diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py index e544eb4fd..3e1a6d45a 100644 --- a/tests/e2e/patched/test_lora_llama_multipack.py +++ b/tests/e2e/patched/test_lora_llama_multipack.py @@ -89,6 +89,9 @@ class TestLoraLlama(unittest.TestCase): "sequence_len": 1024, "sample_packing": True, "flash_attention": True, + "quantization": { + "backend": "gptq", + }, "load_in_8bit": True, "adapter": "lora", "gptq": True,