WIP quant config refactor

This commit is contained in:
Sunny Liu
2025-04-19 01:32:36 -04:00
committed by Sung Ching Liu
parent f6f5f89c6d
commit aba484de97
2 changed files with 8 additions and 5 deletions

View File

@@ -60,7 +60,6 @@ class AxolotlInputConfig(
ModelOutputConfig, ModelOutputConfig,
LoraConfig, LoraConfig,
ReLoRAConfig, ReLoRAConfig,
QuantizationConfig,
HyperparametersConfig, HyperparametersConfig,
WandbConfig, WandbConfig,
MLFlowConfig, MLFlowConfig,
@@ -85,6 +84,8 @@ class AxolotlInputConfig(
# optionally shrink the embeddings when the tokenizer vocab size is smaller # optionally shrink the embeddings when the tokenizer vocab size is smaller
shrink_embeddings: bool | None = None shrink_embeddings: bool | None = None
quantization: QuantizationConfig | None = None
rl: RLType | None = None rl: RLType | None = None
trl: TRLConfig | None = Field( trl: TRLConfig | None = Field(
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda

View File

@@ -2,7 +2,7 @@
Takes care of quantization configuration Takes care of quantization configuration
""" """
from typing import Annotated from typing import Annotated, Any, Literal
from annotated_types import MinLen from annotated_types import MinLen
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
@@ -11,8 +11,8 @@ from pydantic import BaseModel, Field, model_validator
class HQQConfig(BaseModel): class HQQConfig(BaseModel):
"""HQQ configuration subset""" """HQQ configuration subset"""
nbits: int | None = Field(default=None) nbits: Literal[8, 4, 3, 2, 1]
group_size: int | None = Field(default=None) group_size: int = Field(default=64)
target_modules: list[str] | str | None = Field( target_modules: list[str] | str | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
@@ -25,7 +25,9 @@ class QuantizationConfig(BaseModel):
"""Over all Quantization configuration subset""" """Over all Quantization configuration subset"""
# We will use this class as base future refactoring of all quantization configs # We will use this class as base future refactoring of all quantization configs
use_hqq: bool = False backend: Literal["bnb", "hqq", "gptq"] | None = None
bits: int | None = None
bnb_config: dict[str, Any] | None = None
hqq_config: Annotated[list[HQQConfig], MinLen(1)] | None = None hqq_config: Annotated[list[HQQConfig], MinLen(1)] | None = None
@model_validator(mode="before") @model_validator(mode="before")