diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 641b444e9..e1ccdde97 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -60,7 +60,6 @@ class AxolotlInputConfig(
     ModelOutputConfig,
     LoraConfig,
     ReLoRAConfig,
-    QuantizationConfig,
     HyperparametersConfig,
     WandbConfig,
     MLFlowConfig,
@@ -85,6 +84,8 @@ class AxolotlInputConfig(
     # optionally shrink the embeddings when the tokenizer vocab size is smaller
     shrink_embeddings: bool | None = None
 
+    quantization: QuantizationConfig | None = None
+
     rl: RLType | None = None
     trl: TRLConfig | None = Field(
         default_factory=lambda: TRLConfig(),  # pylint: disable=unnecessary-lambda
diff --git a/src/axolotl/utils/schemas/quant.py b/src/axolotl/utils/schemas/quant.py
index 4714d0799..6f8b17b53 100644
--- a/src/axolotl/utils/schemas/quant.py
+++ b/src/axolotl/utils/schemas/quant.py
@@ -2,7 +2,7 @@
 Takes care of quantization configuration
 """
 
-from typing import Annotated
+from typing import Annotated, Any, Literal
 
 from annotated_types import MinLen
 from pydantic import BaseModel, Field, model_validator
@@ -11,8 +11,8 @@ from pydantic import BaseModel, Field, model_validator
 class HQQConfig(BaseModel):
     """HQQ configuration subset"""
 
-    nbits: int | None = Field(default=None)
-    group_size: int | None = Field(default=None)
+    nbits: Literal[8, 4, 3, 2, 1]
+    group_size: int = Field(default=64)
     target_modules: list[str] | str | None = Field(
         default=None,
         json_schema_extra={
@@ -25,7 +25,9 @@ class QuantizationConfig(BaseModel):
     """Over all Quantization configuration subset"""
 
     # We will use this class as base future refactoring of all quantization configs
-    use_hqq: bool = False
+    backend: Literal["bnb", "hqq", "gptq"] | None = None
+    bits: int | None = None
+    bnb_config: dict[str, Any] | None = None
     hqq_config: Annotated[list[HQQConfig], MinLen(1)] | None = None
 
     @model_validator(mode="before")