Include HQQLinear and HQQLinearLoRA among the target linear classes in find_all_linear_names
This commit is contained in:
committed by
Sung Ching Liu
parent
8a5ad8aee3
commit
ac24eba2ac
@@ -854,13 +854,13 @@ class ModelLoader:
|
||||
self.cfg.adapter in ["qlora", "lora"]
|
||||
and hasattr(self.model_config, "quantization_config")
|
||||
and self.model_config.quantization_config["quant_method"]
|
||||
in ["gptq", "awq", "bitsandbytes", "hqq"]
|
||||
in ["gptq", "awq", "bitsandbytes"]
|
||||
and not self.cfg.hqq
|
||||
):
|
||||
quant_config_class_dict = {
|
||||
"gptq": GPTQConfig,
|
||||
"awq": AwqConfig,
|
||||
"bitsandbytes": BitsAndBytesConfig,
|
||||
"hqq": HqqConfig,
|
||||
}
|
||||
|
||||
quant_config_class = quant_config_class_dict[
|
||||
@@ -904,7 +904,7 @@ class ModelLoader:
|
||||
**bnb_config,
|
||||
)
|
||||
|
||||
elif self.cfg.hqq:
|
||||
if self.cfg.hqq:
|
||||
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
|
||||
|
||||
self.model_kwargs["quantization_config"] = HqqConfig(
|
||||
@@ -1471,7 +1471,16 @@ def load_llama_adapter(model, cfg):
|
||||
|
||||
|
||||
def find_all_linear_names(model):
|
||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
||||
from hqq.core.peft import HQQLinearLoRA
|
||||
from hqq.core.quantize import HQQLinear
|
||||
|
||||
cls = (
|
||||
bnb.nn.Linear4bit,
|
||||
bnb.nn.Linear8bitLt,
|
||||
torch.nn.Linear,
|
||||
HQQLinear,
|
||||
HQQLinearLoRA,
|
||||
)
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if (
|
||||
|
||||
Reference in New Issue
Block a user