WIP quant config refactor

This commit is contained in:
Sunny Liu
2025-04-19 01:32:36 -04:00
committed by Sung Ching Liu
parent f6f5f89c6d
commit aba484de97
2 changed files with 8 additions and 5 deletions

View File

@@ -60,7 +60,6 @@ class AxolotlInputConfig(
ModelOutputConfig,
LoraConfig,
ReLoRAConfig,
QuantizationConfig,
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
@@ -85,6 +84,8 @@ class AxolotlInputConfig(
# optionally shrink the embeddings when the tokenizer vocab size is smaller
shrink_embeddings: bool | None = None
quantization: QuantizationConfig | None = None
rl: RLType | None = None
trl: TRLConfig | None = Field(
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda

View File

@@ -2,7 +2,7 @@
Takes care of quantization configuration
"""
from typing import Annotated
from typing import Annotated, Any, Literal
from annotated_types import MinLen
from pydantic import BaseModel, Field, model_validator
@@ -11,8 +11,8 @@ from pydantic import BaseModel, Field, model_validator
class HQQConfig(BaseModel):
"""HQQ configuration subset"""
nbits: int | None = Field(default=None)
group_size: int | None = Field(default=None)
nbits: Literal[8, 4, 3, 2, 1]
group_size: int = Field(default=64)
target_modules: list[str] | str | None = Field(
default=None,
json_schema_extra={
@@ -25,7 +25,9 @@ class QuantizationConfig(BaseModel):
"""Over all Quantization configuration subset"""
# We will use this class as base future refactoring of all quantization configs
use_hqq: bool = False
backend: Literal["bnb", "hqq", "gptq"] | None = None
bits: int | None = None
bnb_config: dict[str, Any] | None = None
hqq_config: Annotated[list[HQQConfig], MinLen(1)] | None = None
@model_validator(mode="before")