more comprehensive hqq config options
committed by Sung Ching Liu (parent f55b3c805b, commit 0c40d12a18)
@@ -904,7 +904,7 @@ class ModelLoader:
                 **bnb_config,
             )
 
-        elif self.cfg.hqq_nbits:
+        elif self.cfg.use_hqq:
             from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
 
             self.model_kwargs["quantization_config"] = HqqConfig(
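The hunk is truncated before the HqqConfig arguments. As rough orientation, here is a hypothetical sketch of how the helper's output maps onto transformers' HqqConfig; the bit widths and module names are made-up examples, and it assumes transformers' HQQ integration plus an installed hqq package:

from transformers import HqqConfig

# Whole-model case: get_hqq_quant_config_kwargs returns {"nbits": ..., "group_size": ...},
# which unpacks directly into HqqConfig.
whole_model_cfg = HqqConfig(nbits=4, group_size=64)

# Per-module case: it returns {"dynamic_config": {module_name: {...}}}; HqqConfig applies
# a different quant setting per matched linear layer.
per_module_cfg = HqqConfig(
    dynamic_config={
        "self_attn.q_proj": {"nbits": 4, "group_size": 64},
        "mlp.down_proj": {"nbits": 2, "group_size": 16},
    }
)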
@@ -1044,6 +1044,9 @@ class ModelLoader:
                 config=self.model_config,
             )
         else:
+            if self.cfg.use_hqq:
+                # if using hqq, we need to set device_map to gpu, otherwise loading gets stuck
+                self.model_kwargs["device_map"] = "auto"
             self.model = self.auto_model_loader.from_pretrained(
                 self.base_model,
                 config=self.model_config,
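For context, the keyword arguments ModelLoader ends up passing for an HQQ load look roughly like the following. This is an illustrative sketch only, not code from the commit: the checkpoint name is an arbitrary placeholder, and a CUDA device plus the hqq package are required.

from transformers import AutoModelForCausalLM, HqqConfig

model_kwargs = {
    "quantization_config": HqqConfig(nbits=4, group_size=64),
    # without an explicit device_map, HQQ loading can hang, hence the "auto" set above
    "device_map": "auto",
}
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", **model_kwargs)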
@@ -1198,7 +1201,7 @@ class ModelLoader:
         if (
             not skip_prepare_model_for_kbit_training
             and self.cfg.adapter in ["lora", "qlora"]
-            and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.hqq_nbits)
+            and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.use_hqq)
         ):
             LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
             self.model = prepare_model_for_kbit_training(
@@ -2,50 +2,69 @@
 Takes care of quantization configuration
 """
 
-from typing import Literal
+from typing import Annotated
 
-from pydantic import BaseModel, model_validator
+from annotated_types import MinLen
+from pydantic import BaseModel, Field, model_validator
 
 
 class HQQConfig(BaseModel):
     """HQQ configuration subset"""
 
-    hqq_nbits: Literal[8, 4, 3, 2, 1] | None = None
-    hqq_group_size: int | None = None
-    hqq_target_module: list[str] | None = None
+    n_bits: int | None = Field(default=None)
+    group_size: int | None = Field(default=None)
+    target_modules: list[str] | str | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Target modules for HQQ quantization. If not specified, the whole model will be quantized."
+        },
+    )
 
 
 class QuantizationConfig(BaseModel):
     """Over all Quantization configuration subset"""
 
     # We will use this class as base future refactoring of all quantization configs
+    use_hqq: bool = False
+    hqq_config: Annotated[list[HQQConfig], MinLen(1)] | None = None
 
     @model_validator(mode="before")
     @classmethod
-    def check_hqq_config_fields(cls, data):
-        fields = ("hqq_nbits", "hqq_group_size")
-        non_empty_count = sum(1 for field in fields if data.get(field))
-        if non_empty_count == 1 or (
-            data.get("'hqq_target_module") and non_empty_count < 2
-        ):
+    def check_hqq_config(cls, data):
+        if data.get("use_hqq") and not data.get("hqq_config"):
             raise ValueError(
-                "If using HQQ, must set both `hqq_nbits` and `hqq_group_size`"
+                "If using HQQ, must set `hqq_config` to a list of HQQConfig objects"
             )
 
+        if data.get("hqq_config") and len(data.get("hqq_config")) > 1:
+            for hqq_config in data.get("hqq_config"):
+                if hqq_config.get("target_modules") is None:
+                    raise ValueError(
+                        "If using HQQ, `target_modules` must be specified for each HQQConfig object"
+                    )
+
         return data
 
 
 def get_hqq_quant_config_kwargs(cfg):
 
     # If no target module is specified, then target the whole model
-    if cfg.hqq_target_module is None:
+    if len(cfg.hqq_config) == 1 and cfg.hqq_config[0].target_modules is None:
         return {
-            "nbits": cfg.hqq_nbits,
-            "group_size": cfg.hqq_group_size,
+            "nbits": cfg.hqq_config[0].n_bits,
+            "group_size": cfg.hqq_config[0].group_size,
         }
 
-    hqq_target_module = cfg.hqq_target_module
-    if not isinstance(cfg.hqq_target_module, list):
-        hqq_target_module = [hqq_target_module]
-
     hqq_quant_config_kwargs = {"dynamic_config": {}}
-    for module in hqq_target_module:
-        hqq_quant_config_kwargs["dynamic_config"][module] = {
-            "nbits": cfg.hqq_nbits,
-            "group_size": cfg.hqq_group_size,
-        }
+    for hqq_config in cfg.hqq_config:
+        target_modules = hqq_config.target_modules
+        if not isinstance(target_modules, list):
+            target_modules = [target_modules]
+
+        for module in target_modules:
+            hqq_quant_config_kwargs["dynamic_config"][module] = {
+                "nbits": hqq_config.n_bits,
+                "group_size": hqq_config.group_size,
+            }
 
     return hqq_quant_config_kwargs
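Finally, a hypothetical usage sketch of the new schema and helper as reconstructed above. ModelLoader actually passes the full training cfg; a bare QuantizationConfig works here only because the helper reads nothing but hqq_config, and all field values are arbitrary examples:

from axolotl.utils.schemas.quant import QuantizationConfig, get_hqq_quant_config_kwargs

quant_cfg = QuantizationConfig(
    use_hqq=True,
    hqq_config=[
        # with more than one entry, target_modules is mandatory for each entry
        {"n_bits": 4, "group_size": 64, "target_modules": ["q_proj", "k_proj", "v_proj"]},
        {"n_bits": 2, "group_size": 16, "target_modules": "down_proj"},
    ],
)

print(get_hqq_quant_config_kwargs(quant_cfg))
# {'dynamic_config': {'q_proj': {'nbits': 4, 'group_size': 64},
#                     'k_proj': {'nbits': 4, 'group_size': 64},
#                     'v_proj': {'nbits': 4, 'group_size': 64},
#                     'down_proj': {'nbits': 2, 'group_size': 16}}}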