diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index d7105daba..5c54c7998 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -36,6 +36,7 @@ from transformers import (
     BitsAndBytesConfig,
     Gemma3ForConditionalGeneration,
     GPTQConfig,
+    HqqConfig,
     Llama4ForConditionalGeneration,
     LlavaForConditionalGeneration,
     Mistral3ForConditionalGeneration,
@@ -852,23 +853,27 @@ class ModelLoader:
         if (
             self.cfg.adapter in ["qlora", "lora"]
             and hasattr(self.model_config, "quantization_config")
-            and self.model_config.quantization_config["quant_method"]
-            in ["gptq", "awq", "bitsandbytes"]
+            and self.model_config.quantization_config["quant_method"]
+            in ["gptq", "awq", "bitsandbytes", "hqq"]
         ):
-            if self.model_config.quantization_config["quant_method"] == "gptq":
-                self.model_kwargs["quantization_config"] = GPTQConfig(
-                    **self.model_config.quantization_config
-                )
-            elif self.model_config.quantization_config["quant_method"] == "awq":
-                self.model_kwargs["quantization_config"] = AwqConfig(
-                    **self.model_config.quantization_config
-                )
-            elif (
-                self.model_config.quantization_config["quant_method"] == "bitsandbytes"
-            ):
-                self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
-                    **self.model_config.quantization_config
-                )
+            # Map quant_method -> HF quantization config class; keeps the
+            # branch-per-method chain from growing with each new backend.
+            quant_config_class_dict = {
+                "gptq": GPTQConfig,
+                "awq": AwqConfig,
+                "bitsandbytes": BitsAndBytesConfig,
+                "hqq": HqqConfig,
+            }
+
+            # quantization_config loaded from config.json is a plain dict,
+            # so index it -- getattr() on a dict raises AttributeError.
+            quant_config_class = quant_config_class_dict[
+                self.model_config.quantization_config["quant_method"]
+            ]
+            self.model_kwargs["quantization_config"] = quant_config_class(
+                **self.model_config.quantization_config
+            )
+
         elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
             bnb_config = {
                 "load_in_4bit": True,
@@ -903,6 +908,15 @@ class ModelLoader:
                 **bnb_config,
             )
 
+        elif self.cfg.hqq_nbits:
+            from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
+
+            # On-the-fly HQQ quantization of a full-precision checkpoint,
+            # configured via the cfg.hqq_* fields (must unpack the kwargs dict).
+            self.model_kwargs["quantization_config"] = HqqConfig(
+                **get_hqq_quant_config_kwargs(self.cfg)
+            )
+
         # no longer needed per https://github.com/huggingface/transformers/pull/26610
         if "quantization_config" in self.model_kwargs or self.cfg.gptq:
             self.model_kwargs.pop("load_in_8bit", None)
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 732ae60cf..b3625ee25 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -44,6 +44,7 @@ from axolotl.utils.schemas.model import (
 )
 from axolotl.utils.schemas.multimodal import MultiModalConfig
 from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
+from axolotl.utils.schemas.quant import HQQConfig
 from axolotl.utils.schemas.training import HyperparametersConfig
 from axolotl.utils.schemas.trl import TRLConfig
 from axolotl.utils.schemas.vllm import VllmConfig
@@ -59,6 +60,7 @@ class AxolotlInputConfig(
     ModelOutputConfig,
     LoraConfig,
     ReLoRAConfig,
+    HQQConfig,
     HyperparametersConfig,
     WandbConfig,
     MLFlowConfig,
@@ -377,6 +379,18 @@ class AxolotlInputConfig(
         raise ValueError(f"Only one of {', '.join(fields)} must be set")
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_hqq_config_redundancy(cls, data):
+        """Reject configs that set both HQQ and bitsandbytes 4/8-bit loading."""
+        if (data.get("load_in_4bit") or data.get("load_in_8bit")) and data.get(
+            "hqq_nbits"
+        ):
+            raise ValueError(
+                "Can't simultaneously set `hqq` configurations and `load_in_4bit` / `load_in_8bit`"
+            )
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_batch_size_fields(cls, data):
diff --git a/src/axolotl/utils/schemas/quant.py b/src/axolotl/utils/schemas/quant.py
new file mode 100644
index 000000000..62c33d99b
--- /dev/null
+++ b/src/axolotl/utils/schemas/quant.py
@@ -0,0 +1,58 @@
+"""
+Takes care of quantization configuration
+"""
+
+from typing import Literal
+
+from pydantic import BaseModel, model_validator
+
+
+class HQQConfig(BaseModel):
+    """HQQ configuration subset"""
+
+    # Bit width for HQQ weight quantization; None disables HQQ.
+    hqq_nbits: Literal[8, 4, 3, 2, 1] | None = None
+    # Quantization group size; must be set together with hqq_nbits.
+    hqq_group_size: int | None = None
+    # Optional module name(s) to quantize; whole model when unset.
+    hqq_target_module: list[str] | None = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_hqq_config_fields(cls, data):
+        """Require `hqq_nbits` and `hqq_group_size` to be set as a pair."""
+        fields = ("hqq_nbits", "hqq_group_size")
+        non_empty_count = sum(1 for field in fields if data.get(field))
+        # Error when only one of nbits/group_size is set, or when target
+        # modules are given without both of them.
+        if non_empty_count == 1 or (
+            data.get("hqq_target_module") and non_empty_count < 2
+        ):
+            raise ValueError(
+                "If using HQQ, must set both `hqq_nbits` and `hqq_group_size`"
+            )
+        # "before"-mode validators must return the data for model construction.
+        return data
+
+
+def get_hqq_quant_config_kwargs(cfg):
+    """Build the kwargs for `transformers.HqqConfig` from an axolotl cfg."""
+    # If no target module is specified, then target the whole model
+    if cfg.hqq_target_module is None:
+        return {
+            "nbits": cfg.hqq_nbits,
+            "group_size": cfg.hqq_group_size,
+        }
+
+    hqq_target_module = cfg.hqq_target_module
+    if not isinstance(cfg.hqq_target_module, list):
+        hqq_target_module = [hqq_target_module]
+
+    # Per-module quantization settings go under HqqConfig's dynamic_config.
+    hqq_quant_config_kwargs = {"dynamic_config": {}}
+    for module in hqq_target_module:
+        hqq_quant_config_kwargs["dynamic_config"][module] = {
+            "nbits": cfg.hqq_nbits,
+            "group_size": cfg.hqq_group_size,
+        }
+    return hqq_quant_config_kwargs