diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index fb9fb4bf6..0c4f1cf80 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -904,7 +904,7 @@ class ModelLoader:
                 **bnb_config,
             )
 
-        elif self.cfg.hqq_nbits:
+        elif self.cfg.use_hqq:
             from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
 
             self.model_kwargs["quantization_config"] = HqqConfig(
@@ -1044,6 +1044,9 @@ class ModelLoader:
                 config=self.model_config,
             )
         else:
+            if self.cfg.use_hqq:
+                # if using HQQ, we need to set device_map to the GPU, otherwise model loading gets stuck
+                self.model_kwargs["device_map"] = "auto"
             self.model = self.auto_model_loader.from_pretrained(
                 self.base_model,
                 config=self.model_config,
@@ -1198,7 +1201,7 @@ class ModelLoader:
         if (
             not skip_prepare_model_for_kbit_training
             and self.cfg.adapter in ["lora", "qlora"]
-            and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.hqq_nbits)
+            and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.use_hqq)
         ):
             LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
             self.model = prepare_model_for_kbit_training(
diff --git a/src/axolotl/utils/schemas/quant.py b/src/axolotl/utils/schemas/quant.py
index 9297c1640..0b9d2dd0f 100644
--- a/src/axolotl/utils/schemas/quant.py
+++ b/src/axolotl/utils/schemas/quant.py
@@ -2,50 +2,69 @@
 Takes care of quantization configuration
 """
 
-from typing import Literal
+from typing import Annotated
 
-from pydantic import BaseModel, model_validator
+from annotated_types import MinLen
+from pydantic import BaseModel, Field, model_validator
 
 
 class HQQConfig(BaseModel):
     """HQQ configuration subset"""
 
-    hqq_nbits: Literal[8, 4, 3, 2, 1] | None = None
-    hqq_group_size: int | None = None
-    hqq_target_module: list[str] | None = None
+    n_bits: int | None = Field(default=None)
+    group_size: int | None = Field(default=None)
+    target_modules: list[str] | str | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Target modules for HQQ quantization. If not specified, the whole model will be quantized."
+        },
+    )
+
+
+class QuantizationConfig(BaseModel):
+    """Overall quantization configuration subset"""
+
+    # We will use this class as the base for future refactoring of all quantization configs
+    use_hqq: bool = False
+    hqq_config: Annotated[list[HQQConfig], MinLen(1)] | None = None
 
     @model_validator(mode="before")
     @classmethod
-    def check_hqq_config_fields(cls, data):
-        fields = ("hqq_nbits", "hqq_group_size")
-        non_empty_count = sum(1 for field in fields if data.get(field))
-        if non_empty_count == 1 or (
-            data.get("'hqq_target_module") and non_empty_count < 2
-        ):
+    def check_hqq_config(cls, data):
+        if data.get("use_hqq") and not data.get("hqq_config"):
             raise ValueError(
-                "If using HQQ, must set both `hqq_nbits` and `hqq_group_size`"
+                "If using HQQ, must set `hqq_config` to a list of HQQConfig objects"
             )
+        if data.get("hqq_config") and len(data.get("hqq_config")) > 1:
+            for hqq_config in data.get("hqq_config"):
+                if hqq_config.get("target_modules") is None:
+                    raise ValueError(
+                        "If using HQQ, `target_modules` must be specified for each HQQConfig object"
+                    )
+
         return data
 
 
 def get_hqq_quant_config_kwargs(cfg):
     # If no target module is specified, then target the whole model
-    if cfg.hqq_target_module is None:
+    if len(cfg.hqq_config) == 1 and cfg.hqq_config[0].target_modules is None:
         return {
-            "nbits": cfg.hqq_nbits,
-            "group_size": cfg.hqq_group_size,
+            "nbits": cfg.hqq_config[0].n_bits,
+            "group_size": cfg.hqq_config[0].group_size,
         }
 
-    hqq_target_module = cfg.hqq_target_module
-    if not isinstance(cfg.hqq_target_module, list):
-        hqq_target_module = [hqq_target_module]
-
     hqq_quant_config_kwargs = {"dynamic_config": {}}
-    for module in hqq_target_module:
-        hqq_quant_config_kwargs["dynamic_config"][module] = {
-            "nbits": cfg.hqq_nbits,
-            "group_size": cfg.hqq_group_size,
-        }
+    for hqq_config in cfg.hqq_config:
+        target_modules = hqq_config.target_modules
+        if not isinstance(target_modules, list):
+            target_modules = [target_modules]
+
+        for module in target_modules:
+            hqq_quant_config_kwargs["dynamic_config"][module] = {
+                "nbits": hqq_config.n_bits,
+                "group_size": hqq_config.group_size,
+            }
+
     return hqq_quant_config_kwargs
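Usage note (illustrative, not part of the patch): a minimal sketch of how the reworked schema behaves, assuming `QuantizationConfig` and `get_hqq_quant_config_kwargs` are imported from `axolotl.utils.schemas.quant` as defined above, and that `get_hqq_quant_config_kwargs` accepts any object exposing an `hqq_config` attribute (here the `QuantizationConfig` model itself). Entries are passed as plain dicts because the `mode="before"` validator inspects raw input with `.get()`.

```python
from axolotl.utils.schemas.quant import QuantizationConfig, get_hqq_quant_config_kwargs

# One entry with no target_modules: the shortcut path quantizes the
# whole model uniformly.
whole_model = QuantizationConfig(
    use_hqq=True,
    hqq_config=[{"n_bits": 4, "group_size": 64}],
)
assert get_hqq_quant_config_kwargs(whole_model) == {"nbits": 4, "group_size": 64}

# Several entries: each must name its target_modules, and the result is a
# per-module dynamic_config for transformers' HqqConfig.
mixed = QuantizationConfig(
    use_hqq=True,
    hqq_config=[
        {"n_bits": 4, "group_size": 64, "target_modules": ["self_attn.q_proj"]},
        {"n_bits": 8, "group_size": 128, "target_modules": "mlp.down_proj"},
    ],
)
assert get_hqq_quant_config_kwargs(mixed) == {
    "dynamic_config": {
        "self_attn.q_proj": {"nbits": 4, "group_size": 64},
        "mlp.down_proj": {"nbits": 8, "group_size": 128},
    }
}
```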