From ac24eba2acda6cba4368114a849bb82b6dc679af Mon Sep 17 00:00:00 2001
From: Sunny Liu
Date: Sun, 20 Apr 2025 12:48:14 -0400
Subject: [PATCH] include HQQLinear in find target_linear

---
 src/axolotl/utils/models.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 4c2a2ef5c..ac14e37c5 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -854,13 +854,13 @@ class ModelLoader:
             self.cfg.adapter in ["qlora", "lora"]
             and hasattr(self.model_config, "quantization_config")
             and self.model_config.quantization_config["quant_method"]
-            in ["gptq", "awq", "bitsandbytes", "hqq"]
+            in ["gptq", "awq", "bitsandbytes"]
+            and not self.cfg.hqq
         ):
             quant_config_class_dict = {
                 "gptq": GPTQConfig,
                 "awq": AwqConfig,
                 "bitsandbytes": BitsAndBytesConfig,
-                "hqq": HqqConfig,
             }
 
             quant_config_class = quant_config_class_dict[
@@ -904,7 +904,7 @@ class ModelLoader:
                 **bnb_config,
             )
 
-        elif self.cfg.hqq:
+        if self.cfg.hqq:
             from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
 
             self.model_kwargs["quantization_config"] = HqqConfig(
@@ -1471,7 +1471,16 @@ def load_llama_adapter(model, cfg):
 
 
 def find_all_linear_names(model):
-    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
+    from hqq.core.peft import HQQLinearLoRA
+    from hqq.core.quantize import HQQLinear
+
+    cls = (
+        bnb.nn.Linear4bit,
+        bnb.nn.Linear8bitLt,
+        torch.nn.Linear,
+        HQQLinear,
+        HQQLinearLoRA,
+    )
     lora_module_names = set()
     for name, module in model.named_modules():
         if (
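
Note for reviewers: the last hunk truncates at the body of
find_all_linear_names, so below is a minimal, self-contained sketch of why
the isinstance tuple needs the HQQ classes. HQQLinear holds its quantized
weights in its own torch.nn.Module rather than subclassing torch.nn.Linear,
so the old tuple silently skips every HQQ-quantized layer and LoRA ends up
with no target modules. FakeHQQLinear and find_linear_names here are
hypothetical stand-ins so the snippet runs without hqq installed; they are
not the axolotl implementation.

import torch
import torch.nn as nn


class FakeHQQLinear(nn.Module):
    # Hypothetical stand-in for hqq.core.quantize.HQQLinear, which wraps
    # quantized weights in a Module and is not a torch.nn.Linear subclass.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))

    def forward(self, x):
        return x @ self.weight.t()


def find_linear_names(model: nn.Module, linear_classes: tuple) -> list:
    # Simplified sketch of find_all_linear_names: keep the leaf attribute
    # name (e.g. "q_proj" from "model.layers.0.self_attn.q_proj") of every
    # module whose type matches the class tuple.
    names = set()
    for name, module in model.named_modules():
        if isinstance(module, linear_classes):
            parts = name.split(".")
            names.add(parts[0] if len(parts) == 1 else parts[-1])
    names.discard("lm_head")  # the LM head is excluded from LoRA targets
    return sorted(names)


model = nn.Sequential()
model.add_module("q_proj", FakeHQQLinear(8, 8))
model.add_module("lm_head", nn.Linear(8, 8))

print(find_linear_names(model, (nn.Linear,)))                # [] -- HQQ layer missed
print(find_linear_names(model, (nn.Linear, FakeHQQLinear)))  # ['q_proj']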