hqq integration
This commit is contained in:
committed by
Sung Ching Liu
parent
7651550850
commit
99730ce40a
@@ -36,6 +36,7 @@ from transformers import (
|
|||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
Gemma3ForConditionalGeneration,
|
Gemma3ForConditionalGeneration,
|
||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
|
HqqConfig,
|
||||||
Llama4ForConditionalGeneration,
|
Llama4ForConditionalGeneration,
|
||||||
LlavaForConditionalGeneration,
|
LlavaForConditionalGeneration,
|
||||||
Mistral3ForConditionalGeneration,
|
Mistral3ForConditionalGeneration,
|
||||||
@@ -852,23 +853,23 @@ class ModelLoader:
|
|||||||
if (
|
if (
|
||||||
self.cfg.adapter in ["qlora", "lora"]
|
self.cfg.adapter in ["qlora", "lora"]
|
||||||
and hasattr(self.model_config, "quantization_config")
|
and hasattr(self.model_config, "quantization_config")
|
||||||
and self.model_config.quantization_config["quant_method"]
|
and getattr(self.model_config.quantization_config, "quant_method")
|
||||||
in ["gptq", "awq", "bitsandbytes"]
|
in ["gptq", "awq", "bitsandbytes", "hqq"]
|
||||||
):
|
):
|
||||||
if self.model_config.quantization_config["quant_method"] == "gptq":
|
quant_config_class_dict = {
|
||||||
self.model_kwargs["quantization_config"] = GPTQConfig(
|
"gptq": GPTQConfig,
|
||||||
**self.model_config.quantization_config
|
"awq": AwqConfig,
|
||||||
)
|
"bitsandbytes": BitsAndBytesConfig,
|
||||||
elif self.model_config.quantization_config["quant_method"] == "awq":
|
"hqq": HqqConfig,
|
||||||
self.model_kwargs["quantization_config"] = AwqConfig(
|
}
|
||||||
**self.model_config.quantization_config
|
|
||||||
)
|
quant_config_class = quant_config_class_dict[
|
||||||
elif (
|
getattr(self.model_config.quantization_config, "quant_method")
|
||||||
self.model_config.quantization_config["quant_method"] == "bitsandbytes"
|
]
|
||||||
):
|
self.model_kwargs["quantization_config"] = quant_config_class(
|
||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
**self.model_config.quantization_config
|
||||||
**self.model_config.quantization_config
|
)
|
||||||
)
|
|
||||||
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
|
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
|
||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
@@ -903,6 +904,13 @@ class ModelLoader:
|
|||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif self.cfg.hqq_nbits:
|
||||||
|
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
|
||||||
|
|
||||||
|
self.model_kwargs["quantization_config"] = HqqConfig(
|
||||||
|
get_hqq_quant_config_kwargs(self.cfg)
|
||||||
|
)
|
||||||
|
|
||||||
# no longer needed per https://github.com/huggingface/transformers/pull/26610
|
# no longer needed per https://github.com/huggingface/transformers/pull/26610
|
||||||
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
|
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
|
||||||
self.model_kwargs.pop("load_in_8bit", None)
|
self.model_kwargs.pop("load_in_8bit", None)
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ from axolotl.utils.schemas.model import (
|
|||||||
)
|
)
|
||||||
from axolotl.utils.schemas.multimodal import MultiModalConfig
|
from axolotl.utils.schemas.multimodal import MultiModalConfig
|
||||||
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
||||||
|
from axolotl.utils.schemas.quant import HQQConfig
|
||||||
from axolotl.utils.schemas.training import HyperparametersConfig
|
from axolotl.utils.schemas.training import HyperparametersConfig
|
||||||
from axolotl.utils.schemas.trl import TRLConfig
|
from axolotl.utils.schemas.trl import TRLConfig
|
||||||
from axolotl.utils.schemas.vllm import VllmConfig
|
from axolotl.utils.schemas.vllm import VllmConfig
|
||||||
@@ -59,6 +60,7 @@ class AxolotlInputConfig(
|
|||||||
ModelOutputConfig,
|
ModelOutputConfig,
|
||||||
LoraConfig,
|
LoraConfig,
|
||||||
ReLoRAConfig,
|
ReLoRAConfig,
|
||||||
|
HQQConfig,
|
||||||
HyperparametersConfig,
|
HyperparametersConfig,
|
||||||
WandbConfig,
|
WandbConfig,
|
||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
@@ -377,6 +379,18 @@ class AxolotlInputConfig(
|
|||||||
raise ValueError(f"Only one of {', '.join(fields)} must be set")
|
raise ValueError(f"Only one of {', '.join(fields)} must be set")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_hqq_config_redundancy(cls, data):
|
||||||
|
if (data.get("load_in_4bit") or data.get("load_in_8bit")) and data.get(
|
||||||
|
"hqq_nbits"
|
||||||
|
):
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
"Can't simultaneously set `hqq` configurations and `load_in_4bit` / `load_in_8bit`"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_batch_size_fields(cls, data):
|
def check_batch_size_fields(cls, data):
|
||||||
|
|||||||
49
src/axolotl/utils/schemas/quant.py
Normal file
49
src/axolotl/utils/schemas/quant.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
""" "
|
||||||
|
Takes care of quantization configuration
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
class HQQConfig(BaseModel):
|
||||||
|
"""HQQ configuration subset"""
|
||||||
|
|
||||||
|
hqq_nbits: Literal[8, 4, 3, 2, 1] | None = None
|
||||||
|
hqq_group_size: int | None = None
|
||||||
|
hqq_target_module: list[str] | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_hqq_config_fields(cls, data):
|
||||||
|
fields = ("hqq_nbits", "hqq_group_size")
|
||||||
|
non_empty_count = sum(1 for field in fields if data.get(field))
|
||||||
|
if non_empty_count == 1 or (
|
||||||
|
data.get("'hqq_target_module") and non_empty_count < 2
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"If using HQQ, must set both `hqq_nbits` and `hqq_group_size`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_hqq_quant_config_kwargs(cfg):
|
||||||
|
|
||||||
|
# If no target module is specified, then target the whole model
|
||||||
|
if cfg.hqq_module_name is None:
|
||||||
|
return {
|
||||||
|
"nbits": cfg.hqq_nbits,
|
||||||
|
"group_size": cfg.hqq_group_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
hqq_target_module = cfg.hqq_target_module
|
||||||
|
if not isinstance(cfg.hqq_target_module, list):
|
||||||
|
hqq_target_module = [hqq_target_module]
|
||||||
|
|
||||||
|
hqq_quant_config_kwargs = {"dynamic_config": {}}
|
||||||
|
for module in hqq_target_module:
|
||||||
|
hqq_quant_config_kwargs["dynamic_config"][module] = {
|
||||||
|
"nbits": cfg.hqq_nbits,
|
||||||
|
"group_size": cfg.hqq_group_size,
|
||||||
|
}
|
||||||
|
return hqq_quant_config_kwargs
|
||||||
Reference in New Issue
Block a user