From ba8e29c841e8cbdee50dc90515ff5b8b68cc86ab Mon Sep 17 00:00:00 2001
From: Sunny Liu
Date: Sat, 19 Apr 2025 17:24:02 -0400
Subject: [PATCH] quantization config refactoring - better integration

---
 src/axolotl/utils/config/__init__.py | 17 ++++++++++
 src/axolotl/utils/models.py          | 10 +++---
 src/axolotl/utils/schemas/config.py  |  3 --
 src/axolotl/utils/schemas/peft.py    | 40 +++++++++++------------
 src/axolotl/utils/schemas/quant.py   | 47 ++++++++++++++++++++--------
 5 files changed, 74 insertions(+), 43 deletions(-)

diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index b527dce08..95afae5c7 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -236,6 +236,23 @@ def normalize_config(cfg):
 
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
 
+    if cfg.quantization:
+        if cfg.quantization.backend in ["bnb"]:
+            if cfg.quantization.bits == 8:
+                cfg.load_in_8bit = True
+            elif cfg.quantization.bits == 4:
+                cfg.load_in_4bit = True
+
+        elif cfg.quantization.backend == "gptq":
+            cfg.gptq = True
+        elif cfg.quantization.backend == "hqq":
+            cfg.hqq = True
+
+        if cfg.hqq and not cfg.quantization.hqq_config:
+            raise ValueError(
+                "If using HQQ, `hqq_config` must be set to one or more HQQ config entries"
+            )
+
 
 def normalize_cfg_datasets(cfg):
     """
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 7c2877180..4c2a2ef5c 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -887,8 +887,8 @@ class ModelLoader:
                 # but deepspeed needs this still in bfloat16
                 bnb_config["bnb_4bit_quant_storage"] = torch.float32
 
-        if self.cfg.bnb_config_kwargs:
-            bnb_config.update(self.cfg.bnb_config_kwargs)
+        if self.cfg.quantization and self.cfg.quantization.bnb_config_kwargs:
+            bnb_config.update(self.cfg.quantization.bnb_config_kwargs)
 
         self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
             **bnb_config,
@@ -904,7 +904,7 @@ class ModelLoader:
                 **bnb_config,
             )
 
-        elif self.cfg.use_hqq:
+        elif self.cfg.hqq:
             from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
 
             self.model_kwargs["quantization_config"] = HqqConfig(
@@ -1044,7 +1044,7 @@ class ModelLoader:
                 config=self.model_config,
             )
         else:
-            if self.cfg.use_hqq:
+            if self.cfg.hqq:
                 # if using hqq, we need to set device_map to gpu otherwise the loading get stuck
                 self.model_kwargs["device_map"] = "auto"
             self.model = self.auto_model_loader.from_pretrained(
@@ -1201,7 +1201,7 @@ class ModelLoader:
         if (
             not skip_prepare_model_for_kbit_training
             and self.cfg.adapter in ["lora", "qlora"]
-            and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.use_hqq)
+            and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.hqq)
         ):
             LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
             self.model = prepare_model_for_kbit_training(
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index e1ccdde97..6455fbeb9 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -44,7 +44,6 @@ from axolotl.utils.schemas.model import (
 )
 from axolotl.utils.schemas.multimodal import MultiModalConfig
 from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
-from axolotl.utils.schemas.quant import QuantizationConfig
 from axolotl.utils.schemas.training import HyperparametersConfig
 from axolotl.utils.schemas.trl import TRLConfig
 from axolotl.utils.schemas.vllm import VllmConfig
@@ -84,8 +83,6 @@ class AxolotlInputConfig(
     # optionally shrink the embeddings when the tokenizer vocab size is smaller
     shrink_embeddings: bool | None = None
 
-    quantization: QuantizationConfig | None = None
-
     rl: RLType | None = None
     trl: TRLConfig | None = Field(
         default_factory=lambda: TRLConfig(),  # pylint: disable=unnecessary-lambda
diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py
index c05cc2ce6..98f94fcb2 100644
--- a/src/axolotl/utils/schemas/peft.py
+++ b/src/axolotl/utils/schemas/peft.py
@@ -1,9 +1,9 @@
 """Pydantic models for PEFT-related configuration"""
 
-from typing import Any
-
 from pydantic import BaseModel, Field, field_validator, model_validator
 
+from axolotl.utils.schemas.quant import QuantizationConfig
+
 
 class LoftQConfig(BaseModel):
     """LoftQ configuration subset"""
@@ -23,8 +23,11 @@ class PeftConfig(BaseModel):
 class LoraConfig(BaseModel):
     """Peft / LoRA configuration subset"""
 
-    load_in_8bit: bool | None = Field(default=False)
-    load_in_4bit: bool | None = Field(default=False)
+    quantization: QuantizationConfig | None = None
+    load_in_4bit: bool | None = None  # for internal use
+    load_in_8bit: bool | None = None  # for internal use
+    hqq: bool | None = None  # for internal use
+    gptq: bool | None = None  # for internal use
 
     adapter: str | None = None
     lora_model_dir: str | None = None
@@ -50,8 +53,6 @@ class LoraConfig(BaseModel):
         },
     )
     lora_on_cpu: bool | None = None
-    gptq: bool | None = None
-    bnb_config_kwargs: dict[str, Any] | None = None
 
     loraplus_lr_ratio: float | None = Field(
         default=None,
@@ -74,11 +75,11 @@ class LoraConfig(BaseModel):
         if (
             not data.get("adapter")
             and not data.get("inference")
-            and (data.get("load_in_8bit") or data.get("load_in_4bit"))
+            and data.get("quantization")
         ):
             raise ValueError(
-                "load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
-                "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
+                "Quantization is not supported without setting an adapter for training. "
+                "If you want to do a full fine-tune, turn off quantization."
             )
         return data
 
@@ -87,24 +88,19 @@ class LoraConfig(BaseModel):
         if self.adapter == "qlora":
             if self.merge_lora:
                 # can't merge qlora if loaded in 8bit or 4bit
-                if self.load_in_8bit:
-                    raise ValueError("Can't merge qlora if loaded in 8bit")
+                if self.quantization and self.quantization.backend == "gptq":
+                    raise ValueError("Can't merge qlora if using gptq")
 
-                if self.gptq:
-                    raise ValueError("Can't merge qlora if gptq")
-
-                if self.load_in_4bit and not self.use_hqq:
-                    raise ValueError("Can't merge qlora if loaded in 4bit")
+                if self.quantization:
+                    raise ValueError("Can't merge qlora if the model is loaded quantized")
             else:
-                if self.load_in_8bit:
-                    raise ValueError("Can't load qlora in 8bit")
+                if self.quantization and (self.quantization.bits or 0) > 4:
+                    raise ValueError("Can't load qlora in more than 4 bits")
 
-                if self.gptq:
-                    raise ValueError("Can't load qlora if gptq")
+                if self.quantization and self.quantization.backend == "gptq":
+                    raise ValueError("Can't load qlora if using gptq")
 
-                if not self.load_in_4bit and not self.use_hqq:
-                    raise ValueError("Require cfg.load_in_4bit to be True for qlora")
-
         return self
 
     @field_validator("loraplus_lr_embedding")
diff --git a/src/axolotl/utils/schemas/quant.py b/src/axolotl/utils/schemas/quant.py
index e4a65ea80..56e0d99f7 100644
--- a/src/axolotl/utils/schemas/quant.py
+++ b/src/axolotl/utils/schemas/quant.py
@@ -11,7 +11,13 @@ from pydantic import BaseModel, Field, model_validator
 class HQQConfig(BaseModel):
     """HQQ configuration subset"""
 
-    nbits: Literal[8, 4, 3, 2, 1]
+    nbits: Literal[8, 4, 3, 2, 1] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Number of bits for HQQ quantization: 8, 4, 3, 2, or 1."
+        },
+    )
+
     group_size: int = Field(default=64)
     target_modules: list[str] | str | None = Field(
         default=None,
@@ -26,23 +32,21 @@ class QuantizationConfig(BaseModel):
 
     # We will use this class as base future refactoring of all quantization configs
 
     backend: Literal["bnb", "hqq", "gptq"] | None = None
-    bits: int | None = None
-    bnb_config: dict[str, Any] | None = None
+    bits: Literal[8, 4, 3, 2, 1] | None = None
+    bnb_config_kwargs: dict[str, Any] | None = None
     hqq_config: HQQConfig | Annotated[list[HQQConfig], MinLen(1)] | None = None
 
     @model_validator(mode="before")
     @classmethod
     def check_hqq_config(cls, data):
-        if data.get("use_hqq") and not data.get("hqq_config"):
-            raise ValueError(
-                "If using HQQ, must set `hqq_config` to a list of HQQConfig objects"
-            )
+        if data.get("backend") == "hqq" and not data.get("hqq_config"):
+            raise ValueError("If using HQQ, `hqq_config` must be set")
 
         if data.get("hqq_config") and len(data.get("hqq_config")) > 1:
             for hqq_config in data.get("hqq_config"):
                 if hqq_config.get("target_modules") is None:
                     raise ValueError(
-                        "If using HQQ, `target_modules` must be specified for each HQQConfig object"
+                        "When using a list of HQQ configs, `target_modules` must be set for each entry"
                     )
         return data
@@ -51,21 +55,38 @@
 
 def get_hqq_quant_config_kwargs(cfg):
     # If no target module is specified, then target the whole model
-    if len(cfg.hqq_config) == 1 and cfg.hqq_config[0].target_modules is None:
+    if not isinstance(cfg.quantization.hqq_config, list):
+        cfg.quantization.hqq_config = [cfg.quantization.hqq_config]
+
+    if (
+        len(cfg.quantization.hqq_config) == 1
+        and cfg.quantization.hqq_config[0].target_modules is None
+    ):
+        nbits = (
+            cfg.quantization.hqq_config[0].nbits
+            if cfg.quantization.hqq_config[0].nbits is not None
+            else cfg.quantization.bits
+        )
+
         return {
-            "nbits": cfg.hqq_config[0].nbits,
-            "group_size": cfg.hqq_config[0].group_size,
+            "nbits": nbits,
+            "group_size": cfg.quantization.hqq_config[0].group_size,
         }
 
     hqq_quant_config_kwargs = {"dynamic_config": {}}
-    for hqq_config in cfg.hqq_config:
+    for hqq_config in cfg.quantization.hqq_config:
+        nbits = (
+            hqq_config.nbits if hqq_config.nbits is not None else cfg.quantization.bits
+        )
+
         target_modules = hqq_config.target_modules
         if not isinstance(target_modules, list):
            target_modules = [target_modules]
         for module in target_modules:
             hqq_quant_config_kwargs["dynamic_config"][module] = {
-                "nbits": hqq_config.nbits,
+                "nbits": nbits,
                 "group_size": hqq_config.group_size,
             }
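
Reviewer note, not part of the patch: a quick sketch of what the new `check_hqq_config` validator enforces, written as throwaway pytest-style checks against the patched schema module. It assumes the patched `axolotl` is importable from a dev install; module names like `self_attn.q_proj` are placeholders.

    # Sketch only -- exercises the validator added in src/axolotl/utils/schemas/quant.py
    import pytest
    from pydantic import ValidationError

    from axolotl.utils.schemas.quant import QuantizationConfig

    # backend="hqq" without an hqq_config is rejected
    with pytest.raises(ValidationError):
        QuantizationConfig(backend="hqq")

    # a multi-entry hqq_config must name target_modules on every entry
    with pytest.raises(ValidationError):
        QuantizationConfig(backend="hqq", hqq_config=[{"nbits": 4}, {"nbits": 2}])

    # valid: per-module entries, each with explicit target_modules (placeholder names)
    QuantizationConfig(
        backend="hqq",
        hqq_config=[
            {"nbits": 4, "target_modules": "self_attn.q_proj"},
            {"nbits": 2, "target_modules": "mlp.gate_proj"},
        ],
    )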
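
Also not part of the patch: a sketch of how `get_hqq_quant_config_kwargs` now resolves per-module settings, with `quantization.bits` acting as the fallback when an entry omits `nbits`. The `SimpleNamespace` below is a stand-in for axolotl's real cfg object, and the module names are again placeholders.

    from types import SimpleNamespace

    from axolotl.utils.schemas.quant import (
        QuantizationConfig,
        get_hqq_quant_config_kwargs,
    )

    quantization = QuantizationConfig(
        backend="hqq",
        bits=4,  # fallback for entries that omit nbits
        hqq_config=[
            {"nbits": 4, "group_size": 64, "target_modules": "self_attn.q_proj"},
            {"group_size": 32, "target_modules": ["mlp.gate_proj", "mlp.up_proj"]},
        ],
    )
    cfg = SimpleNamespace(quantization=quantization)  # stand-in for axolotl's cfg

    print(get_hqq_quant_config_kwargs(cfg))
    # {'dynamic_config': {'self_attn.q_proj': {'nbits': 4, 'group_size': 64},
    #                     'mlp.gate_proj': {'nbits': 4, 'group_size': 32},
    #                     'mlp.up_proj': {'nbits': 4, 'group_size': 32}}}

These kwargs feed straight into transformers' HqqConfig in ModelLoader: the single-entry case with no `target_modules` quantizes the whole model with flat `nbits`/`group_size` kwargs, while a list of entries builds a per-module `dynamic_config`.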