Include HQQLinear in find_all_linear_names when finding target linear layers

This commit is contained in:
Sunny Liu
2025-04-20 12:48:14 -04:00
committed by Sung Ching Liu
parent 8a5ad8aee3
commit ac24eba2ac

View File

@@ -854,13 +854,13 @@ class ModelLoader:
self.cfg.adapter in ["qlora", "lora"]
and hasattr(self.model_config, "quantization_config")
and self.model_config.quantization_config["quant_method"]
in ["gptq", "awq", "bitsandbytes", "hqq"]
in ["gptq", "awq", "bitsandbytes"]
and not self.cfg.hqq
):
quant_config_class_dict = {
"gptq": GPTQConfig,
"awq": AwqConfig,
"bitsandbytes": BitsAndBytesConfig,
"hqq": HqqConfig,
}
quant_config_class = quant_config_class_dict[
@@ -904,7 +904,7 @@ class ModelLoader:
**bnb_config,
)
elif self.cfg.hqq:
if self.cfg.hqq:
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
self.model_kwargs["quantization_config"] = HqqConfig(
@@ -1471,7 +1471,16 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
from hqq.core.peft import HQQLinearLoRA
from hqq.core.quantize import HQQLinear
cls = (
bnb.nn.Linear4bit,
bnb.nn.Linear8bitLt,
torch.nn.Linear,
HQQLinear,
HQQLinearLoRA,
)
lora_module_names = set()
for name, module in model.named_modules():
if (