WIP quant config refactor
This commit is contained in:
committed by
Sung Ching Liu
parent
f6f5f89c6d
commit
aba484de97
@@ -60,7 +60,6 @@ class AxolotlInputConfig(
|
||||
ModelOutputConfig,
|
||||
LoraConfig,
|
||||
ReLoRAConfig,
|
||||
QuantizationConfig,
|
||||
HyperparametersConfig,
|
||||
WandbConfig,
|
||||
MLFlowConfig,
|
||||
@@ -85,6 +84,8 @@ class AxolotlInputConfig(
|
||||
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
||||
shrink_embeddings: bool | None = None
|
||||
|
||||
quantization: QuantizationConfig | None = None
|
||||
|
||||
rl: RLType | None = None
|
||||
trl: TRLConfig | None = Field(
|
||||
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Takes care of quantization configuration
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from annotated_types import MinLen
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
@@ -11,8 +11,8 @@ from pydantic import BaseModel, Field, model_validator
|
||||
class HQQConfig(BaseModel):
|
||||
"""HQQ configuration subset"""
|
||||
|
||||
nbits: int | None = Field(default=None)
|
||||
group_size: int | None = Field(default=None)
|
||||
nbits: Literal[8, 4, 3, 2, 1]
|
||||
group_size: int = Field(default=64)
|
||||
target_modules: list[str] | str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
@@ -25,7 +25,9 @@ class QuantizationConfig(BaseModel):
|
||||
"""Over all Quantization configuration subset"""
|
||||
|
||||
# We will use this class as the base for future refactoring of all quantization configs
|
||||
use_hqq: bool = False
|
||||
backend: Literal["bnb", "hqq", "gptq"] | None = None
|
||||
bits: int | None = None
|
||||
bnb_config: dict[str, Any] | None = None
|
||||
hqq_config: Annotated[list[HQQConfig], MinLen(1)] | None = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
|
||||
Reference in New Issue
Block a user