hqq integration
This commit is contained in:
committed by
Sung Ching Liu
parent
7651550850
commit
99730ce40a
@@ -36,6 +36,7 @@ from transformers import (
|
||||
BitsAndBytesConfig,
|
||||
Gemma3ForConditionalGeneration,
|
||||
GPTQConfig,
|
||||
HqqConfig,
|
||||
Llama4ForConditionalGeneration,
|
||||
LlavaForConditionalGeneration,
|
||||
Mistral3ForConditionalGeneration,
|
||||
@@ -852,23 +853,23 @@ class ModelLoader:
|
||||
if (
|
||||
self.cfg.adapter in ["qlora", "lora"]
|
||||
and hasattr(self.model_config, "quantization_config")
|
||||
and self.model_config.quantization_config["quant_method"]
|
||||
in ["gptq", "awq", "bitsandbytes"]
|
||||
and getattr(self.model_config.quantization_config, "quant_method")
|
||||
in ["gptq", "awq", "bitsandbytes", "hqq"]
|
||||
):
|
||||
if self.model_config.quantization_config["quant_method"] == "gptq":
|
||||
self.model_kwargs["quantization_config"] = GPTQConfig(
|
||||
**self.model_config.quantization_config
|
||||
)
|
||||
elif self.model_config.quantization_config["quant_method"] == "awq":
|
||||
self.model_kwargs["quantization_config"] = AwqConfig(
|
||||
**self.model_config.quantization_config
|
||||
)
|
||||
elif (
|
||||
self.model_config.quantization_config["quant_method"] == "bitsandbytes"
|
||||
):
|
||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**self.model_config.quantization_config
|
||||
)
|
||||
quant_config_class_dict = {
|
||||
"gptq": GPTQConfig,
|
||||
"awq": AwqConfig,
|
||||
"bitsandbytes": BitsAndBytesConfig,
|
||||
"hqq": HqqConfig,
|
||||
}
|
||||
|
||||
quant_config_class = quant_config_class_dict[
|
||||
getattr(self.model_config.quantization_config, "quant_method")
|
||||
]
|
||||
self.model_kwargs["quantization_config"] = quant_config_class(
|
||||
**self.model_config.quantization_config
|
||||
)
|
||||
|
||||
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
|
||||
bnb_config = {
|
||||
"load_in_4bit": True,
|
||||
@@ -903,6 +904,13 @@ class ModelLoader:
|
||||
**bnb_config,
|
||||
)
|
||||
|
||||
elif self.cfg.hqq_nbits:
|
||||
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
|
||||
|
||||
self.model_kwargs["quantization_config"] = HqqConfig(
|
||||
get_hqq_quant_config_kwargs(self.cfg)
|
||||
)
|
||||
|
||||
# no longer needed per https://github.com/huggingface/transformers/pull/26610
|
||||
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
|
||||
self.model_kwargs.pop("load_in_8bit", None)
|
||||
|
||||
@@ -44,6 +44,7 @@ from axolotl.utils.schemas.model import (
|
||||
)
|
||||
from axolotl.utils.schemas.multimodal import MultiModalConfig
|
||||
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
||||
from axolotl.utils.schemas.quant import HQQConfig
|
||||
from axolotl.utils.schemas.training import HyperparametersConfig
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
from axolotl.utils.schemas.vllm import VllmConfig
|
||||
@@ -59,6 +60,7 @@ class AxolotlInputConfig(
|
||||
ModelOutputConfig,
|
||||
LoraConfig,
|
||||
ReLoRAConfig,
|
||||
HQQConfig,
|
||||
HyperparametersConfig,
|
||||
WandbConfig,
|
||||
MLFlowConfig,
|
||||
@@ -377,6 +379,18 @@ class AxolotlInputConfig(
|
||||
raise ValueError(f"Only one of {', '.join(fields)} must be set")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_hqq_config_redundancy(cls, data):
|
||||
if (data.get("load_in_4bit") or data.get("load_in_8bit")) and data.get(
|
||||
"hqq_nbits"
|
||||
):
|
||||
|
||||
raise ValueError(
|
||||
"Can't simultaneously set `hqq` configurations and `load_in_4bit` / `load_in_8bit`"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_batch_size_fields(cls, data):
|
||||
|
||||
49
src/axolotl/utils/schemas/quant.py
Normal file
49
src/axolotl/utils/schemas/quant.py
Normal file
@@ -0,0 +1,49 @@
|
||||
""" "
|
||||
Takes care of quantization configuration
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class HQQConfig(BaseModel):
    """HQQ (Half-Quadratic Quantization) configuration subset.

    Mirrors the ``hqq_*`` keys of the axolotl input config. ``hqq_nbits``
    and ``hqq_group_size`` must be provided together; ``hqq_target_module``
    optionally restricts quantization to the named modules.
    """

    # bit-width for HQQ weight quantization; None disables HQQ
    hqq_nbits: Literal[8, 4, 3, 2, 1] | None = None
    # quantization group size; must accompany hqq_nbits
    hqq_group_size: int | None = None
    # module names to quantize; None means the whole model
    hqq_target_module: list[str] | None = None

    @model_validator(mode="before")
    @classmethod
    def check_hqq_config_fields(cls, data):
        """Require ``hqq_nbits`` and ``hqq_group_size`` to be set together.

        Raises:
            ValueError: if exactly one of the pair is set, or if
                ``hqq_target_module`` is set without both of them.
        """
        fields = ("hqq_nbits", "hqq_group_size")
        non_empty_count = sum(1 for field in fields if data.get(field))
        # Fix: original looked up data.get("'hqq_target_module") — the stray
        # apostrophe inside the key made the lookup always return None, so
        # the target-module consistency check could never trigger.
        if non_empty_count == 1 or (
            data.get("hqq_target_module") and non_empty_count < 2
        ):
            raise ValueError(
                "If using HQQ, must set both `hqq_nbits` and `hqq_group_size`"
            )
        # Fix: a mode="before" validator must return the data it validated;
        # the original returned None, which pydantic would then validate.
        return data
|
||||
|
||||
def get_hqq_quant_config_kwargs(cfg):
    """Build keyword arguments for ``transformers.HqqConfig`` from *cfg*.

    Args:
        cfg: axolotl config object exposing ``hqq_nbits``, ``hqq_group_size``
            and ``hqq_target_module``.

    Returns:
        dict: either global ``{"nbits", "group_size"}`` kwargs when no target
        module is configured, or a ``{"dynamic_config": {module: {...}}}``
        mapping that applies the same settings per targeted module.
    """
    # Fix: original read `cfg.hqq_module_name`, which is not a declared
    # field — the schema (and the rest of this function) use
    # `cfg.hqq_target_module`.
    # If no target module is specified, then target the whole model.
    if cfg.hqq_target_module is None:
        return {
            "nbits": cfg.hqq_nbits,
            "group_size": cfg.hqq_group_size,
        }

    # Accept a bare string as a convenience and normalize it to a list.
    hqq_target_module = cfg.hqq_target_module
    if not isinstance(hqq_target_module, list):
        hqq_target_module = [hqq_target_module]

    # Per-module ("dynamic") config: every targeted module gets the same
    # nbits / group_size settings.
    hqq_quant_config_kwargs = {"dynamic_config": {}}
    for module in hqq_target_module:
        hqq_quant_config_kwargs["dynamic_config"][module] = {
            "nbits": cfg.hqq_nbits,
            "group_size": cfg.hqq_group_size,
        }
    return hqq_quant_config_kwargs
|
||||
Reference in New Issue
Block a user