quantization config refactoring - better integration
This commit is contained in:
committed by
Sung Ching Liu
parent
143b2e082c
commit
ba8e29c841
@@ -236,6 +236,23 @@ def normalize_config(cfg):
|
|||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
|
if cfg.quantization:
|
||||||
|
if cfg.quantization.backend in ["bnb"]:
|
||||||
|
if cfg.quantization.bits == 8:
|
||||||
|
cfg.load_in_8bit = True
|
||||||
|
elif cfg.quantization.bits == 4:
|
||||||
|
cfg.load_in_4bit = True
|
||||||
|
|
||||||
|
elif cfg.quantization.backend == "gptq":
|
||||||
|
cfg.gptq = True
|
||||||
|
elif cfg.quantization.backend == "hqq":
|
||||||
|
cfg.hqq = True
|
||||||
|
|
||||||
|
if cfg.hqq and not cfg.quantization.hqq_config:
|
||||||
|
raise ValueError(
|
||||||
|
"If using HQQ, must set `hqq_config` to a list of HQQConfig objects"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def normalize_cfg_datasets(cfg):
|
def normalize_cfg_datasets(cfg):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -887,8 +887,8 @@ class ModelLoader:
|
|||||||
# but deepspeed needs this still in bfloat16
|
# but deepspeed needs this still in bfloat16
|
||||||
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
||||||
|
|
||||||
if self.cfg.bnb_config_kwargs:
|
if self.cfg.quantization.bnb_config_kwargs:
|
||||||
bnb_config.update(self.cfg.bnb_config_kwargs)
|
bnb_config.update(self.cfg.quantization.bnb_config_kwargs)
|
||||||
|
|
||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**bnb_config,
|
**bnb_config,
|
||||||
@@ -904,7 +904,7 @@ class ModelLoader:
|
|||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif self.cfg.use_hqq:
|
elif self.cfg.hqq:
|
||||||
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
|
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
|
||||||
|
|
||||||
self.model_kwargs["quantization_config"] = HqqConfig(
|
self.model_kwargs["quantization_config"] = HqqConfig(
|
||||||
@@ -1044,7 +1044,7 @@ class ModelLoader:
|
|||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if self.cfg.use_hqq:
|
if self.cfg.hqq:
|
||||||
# if using hqq, we need to set device_map to gpu otherwise the loading gets stuck
|
# if using hqq, we need to set device_map to gpu otherwise the loading gets stuck
|
||||||
self.model_kwargs["device_map"] = "auto"
|
self.model_kwargs["device_map"] = "auto"
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
@@ -1201,7 +1201,7 @@ class ModelLoader:
|
|||||||
if (
|
if (
|
||||||
not skip_prepare_model_for_kbit_training
|
not skip_prepare_model_for_kbit_training
|
||||||
and self.cfg.adapter in ["lora", "qlora"]
|
and self.cfg.adapter in ["lora", "qlora"]
|
||||||
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.use_hqq)
|
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.hqq)
|
||||||
):
|
):
|
||||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||||
self.model = prepare_model_for_kbit_training(
|
self.model = prepare_model_for_kbit_training(
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ from axolotl.utils.schemas.model import (
|
|||||||
)
|
)
|
||||||
from axolotl.utils.schemas.multimodal import MultiModalConfig
|
from axolotl.utils.schemas.multimodal import MultiModalConfig
|
||||||
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
||||||
from axolotl.utils.schemas.quant import QuantizationConfig
|
|
||||||
from axolotl.utils.schemas.training import HyperparametersConfig
|
from axolotl.utils.schemas.training import HyperparametersConfig
|
||||||
from axolotl.utils.schemas.trl import TRLConfig
|
from axolotl.utils.schemas.trl import TRLConfig
|
||||||
from axolotl.utils.schemas.vllm import VllmConfig
|
from axolotl.utils.schemas.vllm import VllmConfig
|
||||||
@@ -84,8 +83,6 @@ class AxolotlInputConfig(
|
|||||||
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
||||||
shrink_embeddings: bool | None = None
|
shrink_embeddings: bool | None = None
|
||||||
|
|
||||||
quantization: QuantizationConfig | None = None
|
|
||||||
|
|
||||||
rl: RLType | None = None
|
rl: RLType | None = None
|
||||||
trl: TRLConfig | None = Field(
|
trl: TRLConfig | None = Field(
|
||||||
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
|
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
"""Pydantic models for PEFT-related configuration"""
|
"""Pydantic models for PEFT-related configuration"""
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.quant import QuantizationConfig
|
||||||
|
|
||||||
|
|
||||||
class LoftQConfig(BaseModel):
|
class LoftQConfig(BaseModel):
|
||||||
"""LoftQ configuration subset"""
|
"""LoftQ configuration subset"""
|
||||||
@@ -23,8 +23,11 @@ class PeftConfig(BaseModel):
|
|||||||
class LoraConfig(BaseModel):
|
class LoraConfig(BaseModel):
|
||||||
"""Peft / LoRA configuration subset"""
|
"""Peft / LoRA configuration subset"""
|
||||||
|
|
||||||
load_in_8bit: bool | None = Field(default=False)
|
quantization: QuantizationConfig | None = None
|
||||||
load_in_4bit: bool | None = Field(default=False)
|
load_in_4bit: bool | None = None # for internal use
|
||||||
|
load_in_8bit: bool | None = None # for internal use
|
||||||
|
hqq: bool | None = None # for internal use
|
||||||
|
gptq: bool | None = None # for internal use
|
||||||
|
|
||||||
adapter: str | None = None
|
adapter: str | None = None
|
||||||
lora_model_dir: str | None = None
|
lora_model_dir: str | None = None
|
||||||
@@ -50,8 +53,6 @@ class LoraConfig(BaseModel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
lora_on_cpu: bool | None = None
|
lora_on_cpu: bool | None = None
|
||||||
gptq: bool | None = None
|
|
||||||
bnb_config_kwargs: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
loraplus_lr_ratio: float | None = Field(
|
loraplus_lr_ratio: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -74,11 +75,11 @@ class LoraConfig(BaseModel):
|
|||||||
if (
|
if (
|
||||||
not data.get("adapter")
|
not data.get("adapter")
|
||||||
and not data.get("inference")
|
and not data.get("inference")
|
||||||
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
|
and (data.get("quantization"))
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
|
"Quantization is not supported without setting an adapter for training."
|
||||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
"If you want to full finetune, please turn off Quantization."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -87,24 +88,19 @@ class LoraConfig(BaseModel):
|
|||||||
if self.adapter == "qlora":
|
if self.adapter == "qlora":
|
||||||
if self.merge_lora:
|
if self.merge_lora:
|
||||||
# can't merge qlora if loaded in 8bit or 4bit
|
# can't merge qlora if loaded in 8bit or 4bit
|
||||||
if self.load_in_8bit:
|
if self.quantization:
|
||||||
raise ValueError("Can't merge qlora if loaded in 8bit")
|
raise ValueError("Can't merge qlora if loaded in quantized model")
|
||||||
|
|
||||||
if self.gptq:
|
if self.quantization.backend == "gptq":
|
||||||
raise ValueError("Can't merge qlora if gptq")
|
raise ValueError("Can't merge qlora if using gptq")
|
||||||
|
|
||||||
if self.load_in_4bit and not self.use_hqq:
|
|
||||||
raise ValueError("Can't merge qlora if loaded in 4bit")
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if self.load_in_8bit:
|
if self.quantization.bits >= 4:
|
||||||
raise ValueError("Can't load qlora in 8bit")
|
raise ValueError("Can't load qlora in >4 bit")
|
||||||
|
|
||||||
if self.gptq:
|
if self.quantization.backend == "gptq":
|
||||||
raise ValueError("Can't load qlora if gptq")
|
raise ValueError("Can't load qlora if using gptq")
|
||||||
|
|
||||||
if not self.load_in_4bit and not self.use_hqq:
|
|
||||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@field_validator("loraplus_lr_embedding")
|
@field_validator("loraplus_lr_embedding")
|
||||||
|
|||||||
@@ -11,7 +11,13 @@ from pydantic import BaseModel, Field, model_validator
|
|||||||
class HQQConfig(BaseModel):
|
class HQQConfig(BaseModel):
|
||||||
"""HQQ configuration subset"""
|
"""HQQ configuration subset"""
|
||||||
|
|
||||||
nbits: Literal[8, 4, 3, 2, 1]
|
nbits: Literal[8, 4, 3, 2, 1] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of bits for HQQ quantization. 8, 4, 3, 2, or 1."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
group_size: int = Field(default=64)
|
group_size: int = Field(default=64)
|
||||||
target_modules: list[str] | str | None = Field(
|
target_modules: list[str] | str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -26,23 +32,21 @@ class QuantizationConfig(BaseModel):
|
|||||||
|
|
||||||
# We will use this class as the base for future refactoring of all quantization configs
|
# We will use this class as the base for future refactoring of all quantization configs
|
||||||
backend: Literal["bnb", "hqq", "gptq"] | None = None
|
backend: Literal["bnb", "hqq", "gptq"] | None = None
|
||||||
bits: int | None = None
|
bits: Literal[8, 4, 3, 2, 1] | None = None
|
||||||
bnb_config: dict[str, Any] | None = None
|
bnb_config_kwargs: dict[str, Any] | None = None
|
||||||
hqq_config: HQQConfig | Annotated[list[HQQConfig], MinLen(1)] | None = None
|
hqq_config: HQQConfig | Annotated[list[HQQConfig], MinLen(1)] | None = None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_hqq_config(cls, data):
|
def check_hqq_config(cls, data):
|
||||||
if data.get("use_hqq") and not data.get("hqq_config"):
|
if data.get("backend") == "hqq" and not data.get("hqq_config"):
|
||||||
raise ValueError(
|
raise ValueError("If using HQQ, must set `group_size` under `hqq_config`")
|
||||||
"If using HQQ, must set `hqq_config` to a list of HQQConfig objects"
|
|
||||||
)
|
|
||||||
|
|
||||||
if data.get("hqq_config") and len(data.get("hqq_config")) > 1:
|
if data.get("hqq_config") and len(data.get("hqq_config")) > 1:
|
||||||
for hqq_config in data.get("hqq_config"):
|
for hqq_config in data.get("hqq_config"):
|
||||||
if hqq_config.get("target_modules") is None:
|
if hqq_config.get("target_modules") is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If using HQQ, `target_modules` must be specified for each HQQConfig object"
|
"For list of hqq configs, `target_modules` must be specified for each"
|
||||||
)
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
@@ -51,21 +55,38 @@ class QuantizationConfig(BaseModel):
|
|||||||
def get_hqq_quant_config_kwargs(cfg):
|
def get_hqq_quant_config_kwargs(cfg):
|
||||||
|
|
||||||
# If no target module is specified, then target the whole model
|
# If no target module is specified, then target the whole model
|
||||||
if len(cfg.hqq_config) == 1 and cfg.hqq_config[0].target_modules is None:
|
if not isinstance(cfg.quantization.hqq_config, list):
|
||||||
|
cfg.quantization.hqq_config = [cfg.quantization.hqq_config]
|
||||||
|
|
||||||
|
if (
|
||||||
|
len(cfg.quantization.hqq_config) == 1
|
||||||
|
and cfg.quantization.hqq_config[0].target_modules is None
|
||||||
|
):
|
||||||
|
|
||||||
|
nbits = (
|
||||||
|
cfg.quantization.hqq_config[0].nbits
|
||||||
|
if cfg.quantization.hqq_config[0].nbits is not None
|
||||||
|
else cfg.quantization.bits
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"nbits": cfg.hqq_config[0].nbits,
|
"nbits": nbits,
|
||||||
"group_size": cfg.hqq_config[0].group_size,
|
"group_size": cfg.quantization.hqq_config[0].group_size,
|
||||||
}
|
}
|
||||||
|
|
||||||
hqq_quant_config_kwargs = {"dynamic_config": {}}
|
hqq_quant_config_kwargs = {"dynamic_config": {}}
|
||||||
for hqq_config in cfg.hqq_config:
|
for hqq_config in cfg.quantization.hqq_config:
|
||||||
|
nbits = (
|
||||||
|
hqq_config.nbits if hqq_config.nbits is not None else cfg.quantization.bits
|
||||||
|
)
|
||||||
|
|
||||||
target_modules = hqq_config.target_modules
|
target_modules = hqq_config.target_modules
|
||||||
if not isinstance(target_modules, list):
|
if not isinstance(target_modules, list):
|
||||||
target_modules = [target_modules]
|
target_modules = [target_modules]
|
||||||
|
|
||||||
for module in target_modules:
|
for module in target_modules:
|
||||||
hqq_quant_config_kwargs["dynamic_config"][module] = {
|
hqq_quant_config_kwargs["dynamic_config"][module] = {
|
||||||
"nbits": hqq_config.nbits,
|
"nbits": nbits,
|
||||||
"group_size": hqq_config.group_size,
|
"group_size": hqq_config.group_size,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user