hqq integration

This commit is contained in:
Sunny Liu
2025-04-16 16:27:09 -04:00
committed by Sung Ching Liu
parent 7651550850
commit 99730ce40a
3 changed files with 87 additions and 16 deletions

View File

@@ -36,6 +36,7 @@ from transformers import (
BitsAndBytesConfig,
Gemma3ForConditionalGeneration,
GPTQConfig,
HqqConfig,
Llama4ForConditionalGeneration,
LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration,
@@ -852,23 +853,23 @@ class ModelLoader:
if (
self.cfg.adapter in ["qlora", "lora"]
and hasattr(self.model_config, "quantization_config")
and self.model_config.quantization_config["quant_method"]
in ["gptq", "awq", "bitsandbytes"]
and getattr(self.model_config.quantization_config, "quant_method")
in ["gptq", "awq", "bitsandbytes", "hqq"]
):
if self.model_config.quantization_config["quant_method"] == "gptq":
self.model_kwargs["quantization_config"] = GPTQConfig(
**self.model_config.quantization_config
)
elif self.model_config.quantization_config["quant_method"] == "awq":
self.model_kwargs["quantization_config"] = AwqConfig(
**self.model_config.quantization_config
)
elif (
self.model_config.quantization_config["quant_method"] == "bitsandbytes"
):
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
quant_config_class_dict = {
"gptq": GPTQConfig,
"awq": AwqConfig,
"bitsandbytes": BitsAndBytesConfig,
"hqq": HqqConfig,
}
quant_config_class = quant_config_class_dict[
getattr(self.model_config.quantization_config, "quant_method")
]
self.model_kwargs["quantization_config"] = quant_config_class(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
bnb_config = {
"load_in_4bit": True,
@@ -903,6 +904,13 @@ class ModelLoader:
**bnb_config,
)
elif self.cfg.hqq_nbits:
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
self.model_kwargs["quantization_config"] = HqqConfig(
get_hqq_quant_config_kwargs(self.cfg)
)
# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
self.model_kwargs.pop("load_in_8bit", None)

View File

@@ -44,6 +44,7 @@ from axolotl.utils.schemas.model import (
)
from axolotl.utils.schemas.multimodal import MultiModalConfig
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
from axolotl.utils.schemas.quant import HQQConfig
from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.vllm import VllmConfig
@@ -59,6 +60,7 @@ class AxolotlInputConfig(
ModelOutputConfig,
LoraConfig,
ReLoRAConfig,
HQQConfig,
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
@@ -377,6 +379,18 @@ class AxolotlInputConfig(
raise ValueError(f"Only one of {', '.join(fields)} must be set")
return data
@model_validator(mode="before")
@classmethod
def check_hqq_config_redundancy(cls, data):
if (data.get("load_in_4bit") or data.get("load_in_8bit")) and data.get(
"hqq_nbits"
):
raise ValueError(
"Can't simultaneously set `hqq` configurations and `load_in_4bit` / `load_in_8bit`"
)
return data
@model_validator(mode="before")
@classmethod
def check_batch_size_fields(cls, data):

View File

@@ -0,0 +1,49 @@
""" "
Takes care of quantization configuration
"""
from typing import Literal
from pydantic import BaseModel, model_validator
class HQQConfig(BaseModel):
    """HQQ (Half-Quadratic Quantization) configuration subset.

    `hqq_nbits` and `hqq_group_size` must be provided together; an optional
    `hqq_target_module` list restricts quantization to specific modules.
    """

    # Bit width for HQQ quantization; None disables HQQ entirely.
    hqq_nbits: Literal[8, 4, 3, 2, 1] | None = None
    # Group size for HQQ quantization; required alongside hqq_nbits.
    hqq_group_size: int | None = None
    # Optional list of module names to quantize; None targets the whole model.
    hqq_target_module: list[str] | None = None

    @model_validator(mode="before")
    @classmethod
    def check_hqq_config_fields(cls, data):
        """Ensure `hqq_nbits` and `hqq_group_size` are set together.

        Errors when only one of the pair is set, or when `hqq_target_module`
        is given without both. Returns `data` unchanged otherwise.
        """
        fields = ("hqq_nbits", "hqq_group_size")
        non_empty_count = sum(1 for field in fields if data.get(field))
        if non_empty_count == 1 or (
            # Fixed: original looked up "'hqq_target_module" (stray quote),
            # so this branch could never trigger.
            data.get("hqq_target_module")
            and non_empty_count < 2
        ):
            raise ValueError(
                "If using HQQ, must set both `hqq_nbits` and `hqq_group_size`"
            )
        # Fixed: a mode="before" validator must return the data; the original
        # implicitly returned None, wiping out the parsed config.
        return data
def get_hqq_quant_config_kwargs(cfg):
    """Build the kwargs for transformers' ``HqqConfig`` from an axolotl config.

    Args:
        cfg: config object exposing ``hqq_nbits``, ``hqq_group_size`` and
            ``hqq_target_module``.

    Returns:
        dict: ``{"nbits", "group_size"}`` when no target module is set
        (quantize the whole model), otherwise a ``{"dynamic_config": ...}``
        mapping with per-module settings.
    """
    # Fixed: original checked `cfg.hqq_module_name`, a field that does not
    # exist in the schema (it defines `hqq_target_module`), so per-module
    # targeting was silently never applied.
    if cfg.hqq_target_module is None:
        # No target module specified: quantize the whole model.
        return {
            "nbits": cfg.hqq_nbits,
            "group_size": cfg.hqq_group_size,
        }

    target_modules = cfg.hqq_target_module
    if not isinstance(target_modules, list):
        target_modules = [target_modules]

    # Per-module settings go through HqqConfig's `dynamic_config` argument.
    return {
        "dynamic_config": {
            module: {
                "nbits": cfg.hqq_nbits,
                "group_size": cfg.hqq_group_size,
            }
            for module in target_modules
        }
    }