include HQQLinear in find target_linear
This commit is contained in:
committed by
Sung Ching Liu
parent
8a5ad8aee3
commit
ac24eba2ac
@@ -854,13 +854,13 @@ class ModelLoader:
|
|||||||
self.cfg.adapter in ["qlora", "lora"]
|
self.cfg.adapter in ["qlora", "lora"]
|
||||||
and hasattr(self.model_config, "quantization_config")
|
and hasattr(self.model_config, "quantization_config")
|
||||||
and self.model_config.quantization_config["quant_method"]
|
and self.model_config.quantization_config["quant_method"]
|
||||||
in ["gptq", "awq", "bitsandbytes", "hqq"]
|
in ["gptq", "awq", "bitsandbytes"]
|
||||||
|
and not self.cfg.hqq
|
||||||
):
|
):
|
||||||
quant_config_class_dict = {
|
quant_config_class_dict = {
|
||||||
"gptq": GPTQConfig,
|
"gptq": GPTQConfig,
|
||||||
"awq": AwqConfig,
|
"awq": AwqConfig,
|
||||||
"bitsandbytes": BitsAndBytesConfig,
|
"bitsandbytes": BitsAndBytesConfig,
|
||||||
"hqq": HqqConfig,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
quant_config_class = quant_config_class_dict[
|
quant_config_class = quant_config_class_dict[
|
||||||
@@ -904,7 +904,7 @@ class ModelLoader:
|
|||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif self.cfg.hqq:
|
if self.cfg.hqq:
|
||||||
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
|
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
|
||||||
|
|
||||||
self.model_kwargs["quantization_config"] = HqqConfig(
|
self.model_kwargs["quantization_config"] = HqqConfig(
|
||||||
@@ -1471,7 +1471,16 @@ def load_llama_adapter(model, cfg):
|
|||||||
|
|
||||||
|
|
||||||
def find_all_linear_names(model):
|
def find_all_linear_names(model):
|
||||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
from hqq.core.peft import HQQLinearLoRA
|
||||||
|
from hqq.core.quantize import HQQLinear
|
||||||
|
|
||||||
|
cls = (
|
||||||
|
bnb.nn.Linear4bit,
|
||||||
|
bnb.nn.Linear8bitLt,
|
||||||
|
torch.nn.Linear,
|
||||||
|
HQQLinear,
|
||||||
|
HQQLinearLoRA,
|
||||||
|
)
|
||||||
lora_module_names = set()
|
lora_module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if (
|
if (
|
||||||
|
|||||||
Reference in New Issue
Block a user