quantization config refactoring - better integration

This commit is contained in:
Sunny Liu
2025-04-19 17:24:02 -04:00
committed by Sung Ching Liu
parent 143b2e082c
commit ba8e29c841
5 changed files with 74 additions and 43 deletions

View File

@@ -236,6 +236,23 @@ def normalize_config(cfg):
log_gpu_memory_usage(LOG, "baseline", cfg.device) log_gpu_memory_usage(LOG, "baseline", cfg.device)
if cfg.quantization:
if cfg.quantization.backend in ["bnb"]:
if cfg.quantization.bits == 8:
cfg.load_in_8bit = True
elif cfg.quantization.bits == 4:
cfg.load_in_4bit = True
elif cfg.quantization.backend == "gptq":
cfg.gptq = True
elif cfg.quantization.backend == "hqq":
cfg.hqq = True
if cfg.hqq and not cfg.quantization.hqq_config:
raise ValueError(
"If using HQQ, must set `hqq_config` to a list of HQQConfig objects"
)
def normalize_cfg_datasets(cfg): def normalize_cfg_datasets(cfg):
""" """

View File

@@ -887,8 +887,8 @@ class ModelLoader:
# but deepspeed needs this still in bfloat16 # but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32 bnb_config["bnb_4bit_quant_storage"] = torch.float32
if self.cfg.bnb_config_kwargs: if self.cfg.quantization.bnb_config_kwargs:
bnb_config.update(self.cfg.bnb_config_kwargs) bnb_config.update(self.cfg.quantization.bnb_config_kwargs)
self.model_kwargs["quantization_config"] = BitsAndBytesConfig( self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config, **bnb_config,
@@ -904,7 +904,7 @@ class ModelLoader:
**bnb_config, **bnb_config,
) )
elif self.cfg.use_hqq: elif self.cfg.hqq:
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
self.model_kwargs["quantization_config"] = HqqConfig( self.model_kwargs["quantization_config"] = HqqConfig(
@@ -1044,7 +1044,7 @@ class ModelLoader:
config=self.model_config, config=self.model_config,
) )
else: else:
if self.cfg.use_hqq: if self.cfg.hqq:
# if using hqq, we need to set device_map to gpu otherwise the loading get stuck # if using hqq, we need to set device_map to gpu otherwise the loading get stuck
self.model_kwargs["device_map"] = "auto" self.model_kwargs["device_map"] = "auto"
self.model = self.auto_model_loader.from_pretrained( self.model = self.auto_model_loader.from_pretrained(
@@ -1201,7 +1201,7 @@ class ModelLoader:
if ( if (
not skip_prepare_model_for_kbit_training not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"] and self.cfg.adapter in ["lora", "qlora"]
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.use_hqq) and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.hqq)
): ):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
self.model = prepare_model_for_kbit_training( self.model = prepare_model_for_kbit_training(

View File

@@ -44,7 +44,6 @@ from axolotl.utils.schemas.model import (
) )
from axolotl.utils.schemas.multimodal import MultiModalConfig from axolotl.utils.schemas.multimodal import MultiModalConfig
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
from axolotl.utils.schemas.quant import QuantizationConfig
from axolotl.utils.schemas.training import HyperparametersConfig from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.trl import TRLConfig from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.vllm import VllmConfig from axolotl.utils.schemas.vllm import VllmConfig
@@ -84,8 +83,6 @@ class AxolotlInputConfig(
# optionally shrink the embeddings when the tokenizer vocab size is smaller # optionally shrink the embeddings when the tokenizer vocab size is smaller
shrink_embeddings: bool | None = None shrink_embeddings: bool | None = None
quantization: QuantizationConfig | None = None
rl: RLType | None = None rl: RLType | None = None
trl: TRLConfig | None = Field( trl: TRLConfig | None = Field(
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda

View File

@@ -1,9 +1,9 @@
"""Pydantic models for PEFT-related configuration""" """Pydantic models for PEFT-related configuration"""
from typing import Any
from pydantic import BaseModel, Field, field_validator, model_validator from pydantic import BaseModel, Field, field_validator, model_validator
from axolotl.utils.schemas.quant import QuantizationConfig
class LoftQConfig(BaseModel): class LoftQConfig(BaseModel):
"""LoftQ configuration subset""" """LoftQ configuration subset"""
@@ -23,8 +23,11 @@ class PeftConfig(BaseModel):
class LoraConfig(BaseModel): class LoraConfig(BaseModel):
"""Peft / LoRA configuration subset""" """Peft / LoRA configuration subset"""
load_in_8bit: bool | None = Field(default=False) quantization: QuantizationConfig | None = None
load_in_4bit: bool | None = Field(default=False) load_in_4bit: bool | None = None # for internal use
load_in_8bit: bool | None = None # for internal use
hqq: bool | None = None # for internal use
gptq: bool | None = None # for internal use
adapter: str | None = None adapter: str | None = None
lora_model_dir: str | None = None lora_model_dir: str | None = None
@@ -50,8 +53,6 @@ class LoraConfig(BaseModel):
}, },
) )
lora_on_cpu: bool | None = None lora_on_cpu: bool | None = None
gptq: bool | None = None
bnb_config_kwargs: dict[str, Any] | None = None
loraplus_lr_ratio: float | None = Field( loraplus_lr_ratio: float | None = Field(
default=None, default=None,
@@ -74,11 +75,11 @@ class LoraConfig(BaseModel):
if ( if (
not data.get("adapter") not data.get("adapter")
and not data.get("inference") and not data.get("inference")
and (data.get("load_in_8bit") or data.get("load_in_4bit")) and (data.get("quantization"))
): ):
raise ValueError( raise ValueError(
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training." "Quantization is not supported without setting an adapter for training."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit." "If you want to full finetune, please turn off Quantization."
) )
return data return data
@@ -87,24 +88,19 @@ class LoraConfig(BaseModel):
if self.adapter == "qlora": if self.adapter == "qlora":
if self.merge_lora: if self.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit # can't merge qlora if loaded in 8bit or 4bit
if self.load_in_8bit: if self.quantization:
raise ValueError("Can't merge qlora if loaded in 8bit") raise ValueError("Can't merge qlora if loaded in quantized model")
if self.gptq: if self.quantization.backend == "gptq":
raise ValueError("Can't merge qlora if gptq") raise ValueError("Can't merge qlora if using gptq")
if self.load_in_4bit and not self.use_hqq:
raise ValueError("Can't merge qlora if loaded in 4bit")
else: else:
if self.load_in_8bit: if self.quantization.bits >= 4:
raise ValueError("Can't load qlora in 8bit") raise ValueError("Can't load qlora in >4 bit")
if self.gptq: if self.quantization.backend == "gptq":
raise ValueError("Can't load qlora if gptq") raise ValueError("Can't load qlora if using gptq")
if not self.load_in_4bit and not self.use_hqq:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self return self
@field_validator("loraplus_lr_embedding") @field_validator("loraplus_lr_embedding")

View File

@@ -11,7 +11,13 @@ from pydantic import BaseModel, Field, model_validator
class HQQConfig(BaseModel): class HQQConfig(BaseModel):
"""HQQ configuration subset""" """HQQ configuration subset"""
nbits: Literal[8, 4, 3, 2, 1] nbits: Literal[8, 4, 3, 2, 1] | None = Field(
default=None,
json_schema_extra={
"description": "Number of bits for HQQ quantization. 8, 4, 3, 2, or 1."
},
)
group_size: int = Field(default=64) group_size: int = Field(default=64)
target_modules: list[str] | str | None = Field( target_modules: list[str] | str | None = Field(
default=None, default=None,
@@ -26,23 +32,21 @@ class QuantizationConfig(BaseModel):
# We will use this class as base future refactoring of all quantization configs # We will use this class as base future refactoring of all quantization configs
backend: Literal["bnb", "hqq", "gptq"] | None = None backend: Literal["bnb", "hqq", "gptq"] | None = None
bits: int | None = None bits: Literal[8, 4, 3, 2, 1] | None = None
bnb_config: dict[str, Any] | None = None bnb_config_kwargs: dict[str, Any] | None = None
hqq_config: HQQConfig | Annotated[list[HQQConfig], MinLen(1)] | None = None hqq_config: HQQConfig | Annotated[list[HQQConfig], MinLen(1)] | None = None
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_hqq_config(cls, data): def check_hqq_config(cls, data):
if data.get("use_hqq") and not data.get("hqq_config"): if data.get("backend") == "hqq" and not data.get("hqq_config"):
raise ValueError( raise ValueError("If using HQQ, must set `group_size` under `hqq_config`")
"If using HQQ, must set `hqq_config` to a list of HQQConfig objects"
)
if data.get("hqq_config") and len(data.get("hqq_config")) > 1: if data.get("hqq_config") and len(data.get("hqq_config")) > 1:
for hqq_config in data.get("hqq_config"): for hqq_config in data.get("hqq_config"):
if hqq_config.get("target_modules") is None: if hqq_config.get("target_modules") is None:
raise ValueError( raise ValueError(
"If using HQQ, `target_modules` must be specified for each HQQConfig object" "For list of hqq configs, `target_modules` must be specified for each"
) )
return data return data
@@ -51,21 +55,38 @@ class QuantizationConfig(BaseModel):
def get_hqq_quant_config_kwargs(cfg): def get_hqq_quant_config_kwargs(cfg):
# If no target module is specified, then target the whole model # If no target module is specified, then target the whole model
if len(cfg.hqq_config) == 1 and cfg.hqq_config[0].target_modules is None: if not isinstance(cfg.quantization.hqq_config, list):
cfg.quantization.hqq_config = [cfg.quantization.hqq_config]
if (
len(cfg.quantization.hqq_config) == 1
and cfg.quantization.hqq_config[0].target_modules is None
):
nbits = (
cfg.quantization.hqq_config[0].nbits
if cfg.quantization.hqq_config[0].nbits is not None
else cfg.quantization.bits
)
return { return {
"nbits": cfg.hqq_config[0].nbits, "nbits": nbits,
"group_size": cfg.hqq_config[0].group_size, "group_size": cfg.quantization.hqq_config[0].group_size,
} }
hqq_quant_config_kwargs = {"dynamic_config": {}} hqq_quant_config_kwargs = {"dynamic_config": {}}
for hqq_config in cfg.hqq_config: for hqq_config in cfg.quantization.hqq_config:
nbits = (
hqq_config.nbits if hqq_config.nbits is not None else cfg.quantization.bits
)
target_modules = hqq_config.target_modules target_modules = hqq_config.target_modules
if not isinstance(target_modules, list): if not isinstance(target_modules, list):
target_modules = [target_modules] target_modules = [target_modules]
for module in target_modules: for module in target_modules:
hqq_quant_config_kwargs["dynamic_config"][module] = { hqq_quant_config_kwargs["dynamic_config"][module] = {
"nbits": hqq_config.nbits, "nbits": nbits,
"group_size": hqq_config.group_size, "group_size": hqq_config.group_size,
} }