WIP quant config refactor
This commit is contained in:
committed by
Sung Ching Liu
parent
f6f5f89c6d
commit
aba484de97
@@ -60,7 +60,6 @@ class AxolotlInputConfig(
|
|||||||
ModelOutputConfig,
|
ModelOutputConfig,
|
||||||
LoraConfig,
|
LoraConfig,
|
||||||
ReLoRAConfig,
|
ReLoRAConfig,
|
||||||
QuantizationConfig,
|
|
||||||
HyperparametersConfig,
|
HyperparametersConfig,
|
||||||
WandbConfig,
|
WandbConfig,
|
||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
@@ -85,6 +84,8 @@ class AxolotlInputConfig(
|
|||||||
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
||||||
shrink_embeddings: bool | None = None
|
shrink_embeddings: bool | None = None
|
||||||
|
|
||||||
|
quantization: QuantizationConfig | None = None
|
||||||
|
|
||||||
rl: RLType | None = None
|
rl: RLType | None = None
|
||||||
trl: TRLConfig | None = Field(
|
trl: TRLConfig | None = Field(
|
||||||
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
|
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
Takes care of quantization configuration
|
Takes care of quantization configuration
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
from annotated_types import MinLen
|
from annotated_types import MinLen
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
@@ -11,8 +11,8 @@ from pydantic import BaseModel, Field, model_validator
|
|||||||
class HQQConfig(BaseModel):
|
class HQQConfig(BaseModel):
|
||||||
"""HQQ configuration subset"""
|
"""HQQ configuration subset"""
|
||||||
|
|
||||||
nbits: int | None = Field(default=None)
|
nbits: Literal[8, 4, 3, 2, 1]
|
||||||
group_size: int | None = Field(default=None)
|
group_size: int = Field(default=64)
|
||||||
target_modules: list[str] | str | None = Field(
|
target_modules: list[str] | str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -25,7 +25,9 @@ class QuantizationConfig(BaseModel):
|
|||||||
"""Overall quantization configuration subset"""
|
"""Overall quantization configuration subset"""
|
||||||
|
|
||||||
# We will use this class as the base for future refactoring of all quantization configs
|
# We will use this class as the base for future refactoring of all quantization configs
|
||||||
use_hqq: bool = False
|
backend: Literal["bnb", "hqq", "gptq"] | None = None
|
||||||
|
bits: int | None = None
|
||||||
|
bnb_config: dict[str, Any] | None = None
|
||||||
hqq_config: Annotated[list[HQQConfig], MinLen(1)] | None = None
|
hqq_config: Annotated[list[HQQConfig], MinLen(1)] | None = None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
|||||||
Reference in New Issue
Block a user