hqq integration

This commit is contained in:
Sunny Liu
2025-04-16 16:27:09 -04:00
committed by Sung Ching Liu
parent 7651550850
commit 99730ce40a
3 changed files with 87 additions and 16 deletions

View File

@@ -36,6 +36,7 @@ from transformers import (
BitsAndBytesConfig, BitsAndBytesConfig,
Gemma3ForConditionalGeneration, Gemma3ForConditionalGeneration,
GPTQConfig, GPTQConfig,
HqqConfig,
Llama4ForConditionalGeneration, Llama4ForConditionalGeneration,
LlavaForConditionalGeneration, LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration, Mistral3ForConditionalGeneration,
@@ -852,23 +853,23 @@ class ModelLoader:
if ( if (
self.cfg.adapter in ["qlora", "lora"] self.cfg.adapter in ["qlora", "lora"]
and hasattr(self.model_config, "quantization_config") and hasattr(self.model_config, "quantization_config")
and self.model_config.quantization_config["quant_method"] and getattr(self.model_config.quantization_config, "quant_method")
in ["gptq", "awq", "bitsandbytes"] in ["gptq", "awq", "bitsandbytes", "hqq"]
): ):
if self.model_config.quantization_config["quant_method"] == "gptq": quant_config_class_dict = {
self.model_kwargs["quantization_config"] = GPTQConfig( "gptq": GPTQConfig,
**self.model_config.quantization_config "awq": AwqConfig,
) "bitsandbytes": BitsAndBytesConfig,
elif self.model_config.quantization_config["quant_method"] == "awq": "hqq": HqqConfig,
self.model_kwargs["quantization_config"] = AwqConfig( }
**self.model_config.quantization_config
) quant_config_class = quant_config_class_dict[
elif ( getattr(self.model_config.quantization_config, "quant_method")
self.model_config.quantization_config["quant_method"] == "bitsandbytes" ]
): self.model_kwargs["quantization_config"] = quant_config_class(
self.model_kwargs["quantization_config"] = BitsAndBytesConfig( **self.model_config.quantization_config
**self.model_config.quantization_config )
)
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]: elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
bnb_config = { bnb_config = {
"load_in_4bit": True, "load_in_4bit": True,
@@ -903,6 +904,13 @@ class ModelLoader:
**bnb_config, **bnb_config,
) )
elif self.cfg.hqq_nbits:
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
self.model_kwargs["quantization_config"] = HqqConfig(
get_hqq_quant_config_kwargs(self.cfg)
)
# no longer needed per https://github.com/huggingface/transformers/pull/26610 # no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq: if "quantization_config" in self.model_kwargs or self.cfg.gptq:
self.model_kwargs.pop("load_in_8bit", None) self.model_kwargs.pop("load_in_8bit", None)

View File

@@ -44,6 +44,7 @@ from axolotl.utils.schemas.model import (
) )
from axolotl.utils.schemas.multimodal import MultiModalConfig from axolotl.utils.schemas.multimodal import MultiModalConfig
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
from axolotl.utils.schemas.quant import HQQConfig
from axolotl.utils.schemas.training import HyperparametersConfig from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.trl import TRLConfig from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.vllm import VllmConfig from axolotl.utils.schemas.vllm import VllmConfig
@@ -59,6 +60,7 @@ class AxolotlInputConfig(
ModelOutputConfig, ModelOutputConfig,
LoraConfig, LoraConfig,
ReLoRAConfig, ReLoRAConfig,
HQQConfig,
HyperparametersConfig, HyperparametersConfig,
WandbConfig, WandbConfig,
MLFlowConfig, MLFlowConfig,
@@ -377,6 +379,18 @@ class AxolotlInputConfig(
raise ValueError(f"Only one of {', '.join(fields)} must be set") raise ValueError(f"Only one of {', '.join(fields)} must be set")
return data return data
@model_validator(mode="before")
@classmethod
def check_hqq_config_redundancy(cls, data):
    """Reject configs that request HQQ together with `load_in_4bit` / `load_in_8bit`.

    Runs before field validation on the raw config mapping.

    Args:
        data: Raw config mapping; read via ``.get`` for the keys
            ``hqq_nbits``, ``load_in_4bit`` and ``load_in_8bit``.

    Returns:
        The unchanged ``data`` mapping.

    Raises:
        ValueError: If ``hqq_nbits`` is truthy and either ``load_in_4bit`` or
            ``load_in_8bit`` is truthy as well.
    """
    # Nothing to cross-check when HQQ is not requested.
    if not data.get("hqq_nbits"):
        return data

    uses_load_in_flag = bool(data.get("load_in_4bit") or data.get("load_in_8bit"))
    if uses_load_in_flag:
        raise ValueError(
            "Can't simultaneously set `hqq` configurations and `load_in_4bit` / `load_in_8bit`"
        )
    return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_batch_size_fields(cls, data): def check_batch_size_fields(cls, data):

View File

@@ -0,0 +1,49 @@
""" "
Takes care of quantization configuration
"""
from typing import Literal
from pydantic import BaseModel, model_validator
class HQQConfig(BaseModel):
    """HQQ configuration subset.

    Groups the user-facing options that control HQQ quantization. Either none
    of the options are set (HQQ disabled), or both `hqq_nbits` and
    `hqq_group_size` are set, optionally together with `hqq_target_module`.
    """

    # Quantization bit width; None means HQQ is not requested.
    hqq_nbits: Literal[8, 4, 3, 2, 1] | None = None
    # Quantization group size; required whenever `hqq_nbits` is set.
    hqq_group_size: int | None = None
    # Optional list of module names to quantize.
    hqq_target_module: list[str] | None = None

    @model_validator(mode="before")
    @classmethod
    def check_hqq_config_fields(cls, data):
        """Ensure `hqq_nbits` and `hqq_group_size` are supplied together.

        Args:
            data: Raw config mapping, read via ``.get``.

        Returns:
            The unchanged ``data`` mapping. A ``mode="before"`` validator must
            hand the input back, otherwise pydantic validates ``None``.

        Raises:
            ValueError: If any HQQ option is set without both `hqq_nbits` and
                `hqq_group_size` being set.
        """
        required_fields = ("hqq_nbits", "hqq_group_size")
        provided_count = sum(1 for field in required_fields if data.get(field))

        # Any HQQ option (including only a target module) signals intent to
        # use HQQ, which in turn requires both mandatory options.
        hqq_requested = provided_count > 0 or bool(data.get("hqq_target_module"))
        if hqq_requested and provided_count < len(required_fields):
            raise ValueError(
                "If using HQQ, must set both `hqq_nbits` and `hqq_group_size`"
            )
        return data
def get_hqq_quant_config_kwargs(cfg):
    """Build quantization keyword arguments from the HQQ options on ``cfg``.

    Args:
        cfg: Config object exposing ``hqq_nbits``, ``hqq_group_size`` and
            ``hqq_target_module`` as attributes. ``hqq_target_module`` may be
            ``None``, a single module name, or a list of module names.

    Returns:
        ``{"nbits": ..., "group_size": ...}`` when no target module is given
        (the whole model is targeted); otherwise
        ``{"dynamic_config": {module: {"nbits": ..., "group_size": ...}}}``
        with one entry per target module.
    """
    target_modules = cfg.hqq_target_module

    # If no target module is specified, then target the whole model. An empty
    # list is treated the same way instead of yielding an empty dynamic config.
    if not target_modules:
        return {
            "nbits": cfg.hqq_nbits,
            "group_size": cfg.hqq_group_size,
        }

    # Accept a single module name as shorthand for a one-element list.
    if not isinstance(target_modules, list):
        target_modules = [target_modules]

    # Each module gets its own dict so later mutation of one entry cannot
    # leak into the others.
    return {
        "dynamic_config": {
            module: {
                "nbits": cfg.hqq_nbits,
                "group_size": cfg.hqq_group_size,
            }
            for module in target_modules
        }
    }