QAT and quantization w/torchao
This commit is contained in:
salman
2025-05-28 12:35:47 +01:00
committed by GitHub
parent 20fda75917
commit 5fca214108
26 changed files with 1372 additions and 13 deletions

View File

@@ -90,6 +90,18 @@ class VllmServeCliArgs:
)
@dataclass
class QuantizeCliArgs:
    """Dataclass with CLI arguments for `axolotl quantize` command."""

    # Path / hub id of the model to quantize.
    # NOTE(review): `do_quantize` reads `model_path` from the CLI kwargs, so this
    # field appears unused there -- confirm the intended key name.
    base_model: Optional[str] = field(default=None)
    # Name of a `TorchIntDType` member to use for weight quantization (e.g. "int8").
    weight_dtype: Optional[str] = field(default=None)
    # Name of a `TorchIntDType` member for activation quantization; weight-only when unset.
    activation_dtype: Optional[str] = field(default=None)
    # Whether to also quantize the embedding weights.
    quantize_embedding: Optional[bool] = field(default=None)
    # Group size for weight quantization.
    group_size: Optional[int] = field(default=None)
    # Where to write the quantized model; falls back to the config's output_dir.
    output_dir: Optional[str] = field(default=None)
@dataclass
class EvaluateCliArgs:
"""Dataclass with CLI arguments for `axolotl evaluate` command."""

View File

@@ -17,6 +17,7 @@ import axolotl
from axolotl.cli.args import (
EvaluateCliArgs,
PreprocessCliArgs,
QuantizeCliArgs,
TrainerCliArgs,
VllmServeCliArgs,
)
@@ -333,6 +334,16 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
do_vllm_serve(config, cli_args)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(QuantizeCliArgs)
@filter_none_kwargs
def quantize(config: str, **cli_args: QuantizeCliArgs):
    """Quantize a trained model's weights using the settings in CONFIG."""
    # Imported inside the function -- presumably to defer heavy imports until
    # the subcommand is actually invoked; matches the other subcommands' style.
    from axolotl.cli.quantize import do_quantize

    do_quantize(config, cli_args)
@cli.command()
@click.argument("model", type=click.Path(exists=True, path_type=str))
@click.argument("output", type=click.Path(exists=False, path_type=str))

View File

@@ -0,0 +1,90 @@
"""
CLI to post-training quantize a model using torchao
"""
import logging
from pathlib import Path
from typing import Union
from transformers import AutoModelForCausalLM
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
LOG = logging.getLogger(__name__)
def do_quantize(
    config: Union[Path, str],
    cli_args: dict,
):
    """
    Quantize a model's weights post-training using torchao.

    CLI arguments take precedence over the corresponding values in the
    `qat` / `quantization` section of the config file.

    Args:
        config (Union[Path, str]): The path to the config file
        cli_args (dict): Additional command-line arguments (see `QuantizeCliArgs`)

    Raises:
        ValueError: If both, or neither, of `qat` and `quantization` are set
            in the config.
    """
    print_axolotl_text_art()
    cfg = load_cfg(config)

    if cfg.qat and cfg.quantization:
        raise ValueError(
            "QAT and quantization cannot be used together. Please specify only one of qat or quantization in your config file."
        )
    if cfg.qat:
        quantize_cfg = cfg.qat
    elif cfg.quantization:
        quantize_cfg = cfg.quantization
    else:
        raise ValueError(
            "No quantization configuration found. Please specify either qat or quantization in your config file."
        )

    # Fix: the CLI dataclass exposes `base_model` (not `model_path`), so the
    # previous `cli_args.get("model_path")` lookup could never succeed and the
    # CLI override silently fell back to cfg.output_dir.
    model_path = cli_args.get("base_model") or cfg.output_dir

    # CLI dtypes arrive as strings and must be mapped to TorchIntDType; config
    # values are already coerced by the pydantic schema.
    if weight_dtype := cli_args.get("weight_dtype"):
        weight_dtype = TorchIntDType[weight_dtype]
    else:
        weight_dtype = quantize_cfg.weight_dtype
    if activation_dtype := cli_args.get("activation_dtype"):
        activation_dtype = TorchIntDType[activation_dtype]
    else:
        activation_dtype = quantize_cfg.activation_dtype
    group_size = cli_args.get("group_size") or quantize_cfg.group_size
    quantize_embedding = (
        cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
    )
    output_dir = cli_args.get("output_dir") or cfg.output_dir
    save_path = Path(output_dir) / "quantized"

    LOG.info(f"Loading model from {model_path}...")
    # NOTE(review): the tokenizer is loaded from cfg (not model_path) -- verify
    # this matches expectations when quantizing a model from another directory.
    tokenizer = load_tokenizer(cfg)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")

    LOG.info(
        f"Quantizing model with configuration: \n"
        f"\tweight_dtype: {weight_dtype}\n"
        f"\tactivation_dtype: {activation_dtype}\n"
        f"\tgroup_size: {group_size}\n"
        f"\tquantize_embedding: {quantize_embedding}"
    )
    quantize_model_for_ptq(
        model, weight_dtype, group_size, activation_dtype, quantize_embedding
    )

    LOG.info(f"Saving quantized model to: {str(save_path)}...")
    # safe_serialization=False: presumably the quantized tensors are not
    # safetensors-compatible -- confirm against torchao serialization docs.
    model.save_pretrained(
        str(save_path),
        safe_serialization=False,
        progressbar=True,
    )
    # Fix: tokenizer.save_pretrained takes no model-saving kwargs; the old
    # `safe_serialization` / `progressbar` arguments were silently ignored.
    tokenizer.save_pretrained(str(save_path))
    LOG.info(f"Quantized model saved to: {str(save_path)}...")

View File

@@ -79,6 +79,7 @@ from axolotl.utils.callbacks import (
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
@@ -254,6 +255,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):

View File

@@ -191,6 +191,7 @@ class ModelLoader:
self._adjust_model_config()
self._log_memory_usage()
self._configure_embedding_dtypes()
self._configure_qat()
def _resize_token_embeddings(self):
"""Resize token embeddings if needed."""
@@ -305,6 +306,19 @@ class ModelLoader:
before_kbit_train_or_finetune=False,
)
def _configure_qat(self):
    """Swap linear (and optionally embedding) layers for fake-quantized
    versions when a QAT config is present; otherwise do nothing."""
    if not self.cfg.qat:
        return

    # Imported lazily so the torchao dependency is only pulled in when
    # QAT is actually enabled.
    from axolotl.utils.quantization import prepare_model_for_qat

    qat_cfg = self.cfg.qat
    prepare_model_for_qat(
        self.model,
        qat_cfg.weight_dtype,
        qat_cfg.group_size,
        qat_cfg.activation_dtype,
        qat_cfg.quantize_embedding,
    )
def _load_adapters(self) -> PeftConfig | None:
"""Load LoRA or other adapters."""
# Load LoRA or adapter

View File

@@ -80,9 +80,9 @@ class PatchManager:
def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations."""
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
patch_accelerate_fsdp_utils()
patch_accelerate_fsdp2()
def _apply_adapter_patches(self):
"""Apply patches for adapter configurations."""

View File

@@ -1,5 +1,5 @@
"""
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration, and saving full state dicts
"""
import logging
@@ -52,7 +52,146 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
model.load_state_dict(sharded_sd, assign=True)
def patch_accelerate_fsdp_utils():
def set_state_dict_type(self, state_dict_type=None):
    """
    Set the state dict config based on the `StateDictType`.

    NOTE(review): appears to be monkeypatched over accelerate's
    `FullyShardedDataParallelPlugin.set_state_dict_type` -- kept close to the
    upstream implementation to ease future diffs; confirm against the pinned
    accelerate version.
    """
    import os

    from torch.distributed.fsdp.fully_sharded_data_parallel import (
        FullOptimStateDictConfig,
        FullStateDictConfig,
        ShardedOptimStateDictConfig,
        ShardedStateDictConfig,
        StateDictType,
    )

    # Override the state_dict_type if provided, typical use case:
    # user trains with sharded, but final save is with full
    if state_dict_type is not None:
        self.state_dict_type = state_dict_type

    # Fall back to the env var, defaulting to FULL for FSDP v1 and
    # SHARDED for anything else (FSDP v2).
    if self.state_dict_type is None:
        self.state_dict_type = os.environ.get(
            "FSDP_STATE_DICT_TYPE",
            "FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT",
        )
    # Accept either the enum's numeric value (e.g. "1") or its name
    # (case-insensitive, e.g. "full_state_dict").
    if isinstance(self.state_dict_type, str):
        if self.state_dict_type.isdigit():
            self.state_dict_type = StateDictType(int(self.state_dict_type))
        else:
            self.state_dict_type = StateDictType[self.state_dict_type.upper()]

    if self.state_dict_type == StateDictType.FULL_STATE_DICT:
        # Only fill in configs the user has not already supplied; gather on
        # rank 0 and offload to CPU to bound GPU memory during save.
        if self.state_dict_config is None:
            self.state_dict_config = FullStateDictConfig(
                offload_to_cpu=True, rank0_only=True
            )
        if self.optim_state_dict_config is None:
            self.optim_state_dict_config = FullOptimStateDictConfig(
                offload_to_cpu=True, rank0_only=True
            )
    elif self.state_dict_type == StateDictType.SHARDED_STATE_DICT:
        if self.state_dict_config is None:
            self.state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
        if self.optim_state_dict_config is None:
            self.optim_state_dict_config = ShardedOptimStateDictConfig(
                offload_to_cpu=True
            )
def get_state_dict(self, model, unwrap=True):
    """
    Returns the state dictionary of a model sent through [`Accelerator.prepare`] potentially without full
    precision.

    NOTE(review): monkeypatched over `accelerate.Accelerator.get_state_dict`
    (see the patch function below in this module); kept close to upstream plus
    the FSDP2 branch.

    Args:
        model (`torch.nn.Module`):
            A PyTorch model sent through [`Accelerator.prepare`]
        unwrap (`bool`, *optional*, defaults to `True`):
            Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict

    Returns:
        `dict`: The state dictionary of the model potentially without full precision.

    Example:

    ```python
    >>> import torch
    >>> from accelerate import Accelerator

    >>> accelerator = Accelerator()
    >>> net = torch.nn.Linear(2, 2)
    >>> net = accelerator.prepare(net)
    >>> state_dict = accelerator.get_state_dict(net)
    ```
    """
    from accelerate import DistributedType
    from accelerate.utils import compare_versions

    if self.distributed_type == DistributedType.DEEPSPEED:
        zero3_sharding = self.deepspeed_config["zero_optimization"]["stage"] == 3
        tp_sharding = (
            self.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0) > 1
        )
        if zero3_sharding or tp_sharding:
            # Sharded weights must be consolidated to 16-bit on save, which
            # requires the corresponding flag in the deepspeed config.
            if model.zero_gather_16bit_weights_on_model_save():
                if tp_sharding and not compare_versions("deepspeed", ">=", "0.16.4"):
                    raise ImportError(
                        "Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`."
                    )
                state_dict = (
                    model._consolidated_16bit_state_dict()  # pylint: disable=protected-access
                    if tp_sharding
                    else model._zero3_consolidated_16bit_state_dict()  # pylint: disable=protected-access
                )
            else:
                raise ValueError(
                    "Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. "
                    "To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or "
                    "set `zero3_save_16bit_model` to True when using `accelerate config`. "
                    "To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights."
                )
        else:
            from deepspeed.checkpoint.utils import clone_tensors_for_torch_save

            state_dict = clone_tensors_for_torch_save(
                self.unwrap_model(model).state_dict()
            )
    elif self.is_fsdp2:
        # https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
        # Gather each sharded parameter into a full tensor; only rank 0 keeps
        # the result (on CPU), so non-zero ranks return an empty dict.
        state_dict = {}
        sharded_state_dict = model.state_dict()
        for param_name, param in sharded_state_dict.items():
            if param.is_cpu:
                # Move the shard onto the accelerator before the gather.
                # NOTE(review): hard-codes "cuda" -- confirm for non-CUDA backends.
                param = param.to(torch.device("cuda"))
            param = param.full_tensor()
            if torch.distributed.get_rank() == 0:
                state_dict[param_name] = param.cpu()
            # Keep ranks in lockstep between per-parameter gathers.
            torch.distributed.barrier()
    elif self.distributed_type == DistributedType.FSDP:
        from torch.distributed.fsdp import FullStateDictConfig
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.fsdp import StateDictType

        full_state_dict_config = FullStateDictConfig(
            offload_to_cpu=True, rank0_only=True
        )
        with FSDP.state_dict_type(
            model, StateDictType.FULL_STATE_DICT, full_state_dict_config
        ):
            state_dict = model.state_dict()
    else:
        # Non-sharded case: optionally unwrap the accelerate wrapper first.
        if unwrap:
            model = self.unwrap_model(model)
        state_dict = model.state_dict()

    return state_dict
def patch_accelerate_fsdp2():
import accelerate
from accelerate.utils import fsdp_utils
fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict
@@ -61,3 +200,19 @@ def patch_accelerate_fsdp_utils():
"fsdp2_load_full_state_dict",
fsdp2_load_full_state_dict,
)
accelerate.Accelerator.get_state_dict = get_state_dict
setattr(
sys.modules["accelerate"],
"Accelerator.get_state_dict",
get_state_dict,
)
accelerate.utils.dataclasses.FullyShardedDataParallelPlugin.set_state_dict_type = (
set_state_dict_type
)
setattr(
sys.modules["accelerate.utils.dataclasses"],
"FullyShardedDataParallelPlugin.set_state_dict_type",
set_state_dict_type,
)

View File

@@ -238,13 +238,27 @@ def save_trained_model(
model: The trained model to save.
safe_serialization: Whether to use safe serialization.
"""
LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
# Post training module hooks
for name, module in model.named_modules():
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
# handle QAT
if cfg.qat:
from axolotl.utils.quantization import convert_qat_model_for_ptq
LOG.info("Processing QAT model for saving...")
convert_qat_model_for_ptq(
model,
quantize_embedding=cfg.qat.quantize_embedding,
)
LOG.info(
"QAT modules have been converted for PTQ. Please ensure you quantize "
"your model weights with `axolotl quantize`."
)
# Handle FSDP state dict type
state_dict_type = "FULL_STATE_DICT"
if trainer.is_fsdp_enabled and str(cfg.fsdp_config.fsdp_version) != "2":
@@ -321,6 +335,8 @@ def save_trained_model(
save_compressed=cfg.llmcompressor.save_compressed,
)
LOG.info(f"Model successfully saved to {cfg.output_dir}")
def create_model_card(cfg: DictDefault, trainer: Trainer):
"""

View File

@@ -0,0 +1,50 @@
"""QAT Callback for HF Causal Trainer"""
import logging
from functools import partial
from torch import nn
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
from transformers import TrainerCallback
from axolotl.utils.schemas.quantization import QATConfig
LOG = logging.getLogger(__name__)
def toggle_fake_quant(mod: nn.Module, enable: bool):
    """
    Enable or disable fake quantization on a single module.

    Intended for use via ``model.apply(partial(toggle_fake_quant, enable=...))``;
    modules that are not fake-quantized linears/embeddings are left untouched.

    Args:
        mod: The module to toggle fake quantization for.
        enable: Whether to enable or disable fake quantization.
    """
    if not isinstance(mod, (FakeQuantizedLinear, FakeQuantizedEmbedding)):
        return

    # Only linear layers carry an activation fake-quantizer, and it is optional.
    has_activation_quantizer = (
        isinstance(mod, FakeQuantizedLinear)
        and mod.activation_fake_quantizer is not None
    )
    if has_activation_quantizer:
        mod.activation_fake_quantizer.enabled = enable
    mod.weight_fake_quantizer.enabled = enable
class QATCallback(TrainerCallback):
    """
    Callback to toggle fake quantization for the model.

    When `fake_quant_after_n_steps` is set, fake quantization is kept disabled
    for the first n global steps and enabled from step n onwards.
    """

    def __init__(self, cfg: QATConfig):
        self.cfg = cfg
        # Last state we applied to the model; lets us (a) avoid re-traversing
        # the model every step and (b) apply the correct state on the first
        # step we observe, whatever it is.
        self._fake_quant_enabled: bool | None = None

    def on_step_begin(
        self, args, state, control, model, **kwargs
    ):  # pylint: disable=unused-argument
        n_steps = self.cfg.fake_quant_after_n_steps
        if n_steps is None:
            return
        # Fix: the previous `global_step == 0` / `== n_steps` equality checks
        # never fired when training resumed from a checkpoint at any other
        # step, leaving fake quantization in the wrong state.
        should_enable = state.global_step >= n_steps
        if should_enable != self._fake_quant_enabled:
            LOG.info(
                f"{'Enabling' if should_enable else 'Disabling'} fake quantization "
                f"at step {state.global_step}"
            )
            model.apply(partial(toggle_fake_quant, enable=should_enable))
            self._fake_quant_enabled = should_enable

View File

@@ -103,7 +103,12 @@ def cleanup_distributed():
termination or when training successfully completes.
"""
# Ensure that all operations are completed before destroying the process group
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.synchronize()
if torch.xpu.is_available():
torch.xpu.synchronize()
# Destroy the process group
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()

View File

@@ -0,0 +1,189 @@
"""
Utilities for quantization including QAT and PTQ using torchao.
"""
import logging
import torch
from torch import nn
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.qat import (
FakeQuantizeConfig,
FromIntXQuantizationAwareTrainingConfig,
IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.quant_api import (
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
UIntXWeightOnlyConfig,
_is_linear,
)
from axolotl.utils.schemas.enums import TorchIntDType
LOG = logging.getLogger(__name__)
def get_ptq_config(
    weight_dtype: TorchIntDType,
    activation_dtype: TorchIntDType | None = None,
    group_size: int | None = None,
) -> AOBaseConfig:
    """
    Build a torchao post-training quantization config.

    Args:
        weight_dtype: The dtype to use for weight quantization.
        activation_dtype: The dtype to use for activation quantization;
            when None, a weight-only config is returned.
        group_size: The group size to use for weight quantization.

    Returns:
        The post-training quantization config.

    Raises:
        ValueError: If the activation/weight dtype combination is unsupported,
            or if group_size is missing for int8/int4 weight-only quantization.
    """
    if activation_dtype is None:
        # Weight-only quantization. Unsigned dtypes (uint1..uint7) use the
        # generic UIntX config; signed int8/int4 have dedicated configs.
        if not weight_dtype.value.is_signed:  # type: ignore[attr-defined,union-attr]
            return UIntXWeightOnlyConfig(
                dtype=weight_dtype.value,
                group_size=group_size,
                set_inductor_config=False,
            )
        if weight_dtype == TorchIntDType.int8:
            if group_size is None:
                raise ValueError(
                    "group_size must be specified for int8 weight only quantization"
                )
            return Int8WeightOnlyConfig(group_size=group_size)
        if weight_dtype == TorchIntDType.int4:
            if group_size is None:
                raise ValueError(
                    "group_size must be specified for int4 weight only quantization"
                )
            return Int4WeightOnlyConfig(group_size=group_size)
    else:
        # Dynamic activation + weight quantization: dispatch on the pair.
        dynamic_config_by_dtypes = {
            (TorchIntDType.int4, TorchIntDType.int4): Int4DynamicActivationInt4WeightConfig,
            (TorchIntDType.int8, TorchIntDType.int8): Int8DynamicActivationInt8WeightConfig,
            (TorchIntDType.int8, TorchIntDType.int4): Int8DynamicActivationInt4WeightConfig,
        }
        config_cls = dynamic_config_by_dtypes.get((activation_dtype, weight_dtype))
        if config_cls is not None:
            return config_cls()

    raise ValueError(
        f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
    )
def prepare_model_for_qat(
    model,
    weight_dtype: TorchIntDType,
    group_size: int,
    activation_dtype: TorchIntDType | None = None,
    quantize_embedding: bool = False,
):
    """
    Prepare a model for QAT by swapping its linear layers with fake-quantized
    linear layers, and optionally its embedding weights with fake-quantized
    embedding weights.

    Args:
        model: The model to quantize.
        weight_dtype: The dtype to use for weight quantization.
        group_size: The group size to use for weight quantization.
        activation_dtype: The dtype to use for activation quantization.
        quantize_embedding: Whether to quantize the model's embedding weights.

    Raises:
        ValueError: If the activation/weight dtype combination is invalid.
    """
    weight_config = FakeQuantizeConfig(dtype=weight_dtype.value, group_size=group_size)
    activation_config = None
    if activation_dtype is not None:
        activation_config = FakeQuantizeConfig(
            dtype=activation_dtype.value, granularity="per_token", is_symmetric=False
        )

    # Fake-quantize every linear layer.
    quantize_(
        model,
        IntXQuantizationAwareTrainingConfig(
            activation_config=activation_config,
            weight_config=weight_config,
        ),
    )

    if quantize_embedding:
        # activation fake quantization is not supported for embedding layers
        quantize_(
            model,
            IntXQuantizationAwareTrainingConfig(
                activation_config=None,
                weight_config=weight_config,
            ),
            filter_fn=lambda mod, _: isinstance(mod, torch.nn.Embedding),
        )
def quantize_model_for_ptq(
    model,
    weight_dtype: TorchIntDType,
    group_size: int | None = None,
    activation_dtype: TorchIntDType | None = None,
    quantize_embedding: bool | None = None,
):
    """
    Apply post-training quantization to a model in place, replacing its linear
    layers (and optionally its embedding layers) with quantized versions.

    Args:
        model: The model to quantize.
        weight_dtype: The dtype to use for weight quantization.
        group_size: The group size to use for weight quantization.
        activation_dtype: The dtype to use for activation quantization.
        quantize_embedding: Whether to quantize the model's embedding weights.
    """
    quantize_(
        model,
        get_ptq_config(
            weight_dtype=weight_dtype,
            activation_dtype=activation_dtype,
            group_size=group_size,
        ),
    )

    if quantize_embedding:
        # Embeddings are quantized weight-only (no activation quantization).
        quantize_(
            model,
            get_ptq_config(
                weight_dtype=weight_dtype,
                activation_dtype=None,
                group_size=group_size,
            ),
            filter_fn=lambda mod, _: isinstance(mod, torch.nn.Embedding),
        )
def convert_qat_model_for_ptq(
    model,
    *,
    quantize_embedding: bool | None = None,
):
    """
    Swap the fake-quantized modules of a QAT-trained model back to their
    original module types, leaving the model ready for PTQ.

    Args:
        model: The model to convert.
        quantize_embedding: Whether the model's embedding weights were also
            fake-quantized (and so need converting back too).
    """

    def filter_fn(mod, _):
        # Always convert linears; additionally convert embeddings when requested.
        if quantize_embedding and isinstance(mod, nn.Embedding):
            return True
        return _is_linear(mod)

    quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)

View File

@@ -44,6 +44,7 @@ from axolotl.utils.schemas.model import (
)
from axolotl.utils.schemas.multimodal import MultiModalConfig
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
from axolotl.utils.schemas.quantization import PTQConfig, QATConfig
from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.vllm import VllmConfig
@@ -91,6 +92,8 @@ class AxolotlInputConfig(
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(), # pylint: disable=unnecessary-lambda
)
qat: QATConfig | None = None
quantization: PTQConfig | None = None
reward_model: bool | None = None
process_reward_model: bool | None = None
num_labels: int | None = None
@@ -126,7 +129,7 @@ class AxolotlInputConfig(
default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"},
)
dataset_processes: int | None = Field(default=min(32, os.cpu_count())) # type: ignore[type-var]
dataset_processes: int | None = Field(default=min(32, os.cpu_count() or 1))
dataset_exact_deduplication: bool | None = None
dataset_keep_in_memory: bool | None = None
dataloader_pin_memory: bool | None = None
@@ -1481,3 +1484,42 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return self
@model_validator(mode="before")
@classmethod
def check_qat_config(cls, data):
    """Validate QAT-related settings before model construction.

    Rejects configs that combine QAT with PEFT or bitsandbytes loading, and
    enforces minimum torch versions (2.6 for QAT; 2.7 when combined with FSDP2).
    Returns `data` unchanged when no `qat` section is present.
    """
    qat_cfg = data.get("qat", {})
    if not qat_cfg:
        return data

    # QAT swaps modules in the full-precision model, which is incompatible
    # with adapter training and quantized loading.
    if data.get("peft"):
        raise ValueError("QAT and PEFT cannot be used together.")
    if data.get("load_in_8bit"):
        raise ValueError("QAT and load_in_8bit cannot be used together.")
    if data.get("load_in_4bit"):
        raise ValueError("QAT and load_in_4bit cannot be used together.")

    # Prefer the torch version reported via env capabilities (when validating
    # with capabilities attached); otherwise inspect the installed torch.
    env_capabilities = data.get("env_capabilities", {})
    torch_version = env_capabilities.get("torch_version")
    if torch_version is None:
        import torch

        # Strip any local build suffix, e.g. "2.6.0+cu124" -> "2.6.0".
        torch_version = str(torch.__version__).split("+", maxsplit=1)[0]

    if (
        data.get("fsdp")
        and data.get("fsdp_config")
        and str(data["fsdp_config"].get("fsdp_version")) == "2"
    ):
        if version.parse(torch_version) < version.parse("2.7.0"):
            raise ValueError(
                "FSDP2 and QAT are not supported on torch version < 2.7.0"
            )
    if version.parse(torch_version) < version.parse("2.6.0"):
        raise ValueError("QAT is not supported on torch version < 2.6.0")
    return data

View File

@@ -2,6 +2,22 @@
from enum import Enum
import torch
class TorchIntDType(Enum):
    """Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4"""

    # NOTE(review): on torch builds missing these attributes the values fall
    # back to None, and under Enum semantics multiple None-valued members
    # alias to the first one -- member lookup by name would then misbehave on
    # torch < 2.6; confirm this path is unreachable (config validation already
    # enforces torch >= 2.6 for QAT).
    uint1 = getattr(torch, "uint1", None)  # pylint: disable=invalid-name
    uint2 = getattr(torch, "uint2", None)  # pylint: disable=invalid-name
    uint3 = getattr(torch, "uint3", None)  # pylint: disable=invalid-name
    uint4 = getattr(torch, "uint4", None)  # pylint: disable=invalid-name
    uint5 = getattr(torch, "uint5", None)  # pylint: disable=invalid-name
    uint6 = getattr(torch, "uint6", None)  # pylint: disable=invalid-name
    uint7 = getattr(torch, "uint7", None)  # pylint: disable=invalid-name
    # int4 only exists on torch >= 2.6; int8 is always available.
    int4 = getattr(torch, "int4", None)  # pylint: disable=invalid-name
    int8 = getattr(torch, "int8", None)  # pylint: disable=invalid-name
class RLType(str, Enum):
"""RL trainer type configuration subset"""

View File

@@ -0,0 +1,64 @@
"""
QAT Config Schema
"""
from typing import Any
from pydantic import BaseModel, Field, field_validator
from axolotl.utils.schemas.enums import TorchIntDType
class QATConfig(BaseModel):
    """
    QAT Config Schema
    """

    activation_dtype: TorchIntDType | None = Field(
        default=None, description="Activation dtype"
    )
    weight_dtype: TorchIntDType = Field(
        default=TorchIntDType.int8, description="Weight dtype"
    )
    quantize_embedding: bool | None = Field(
        default=False, description="Quantize embedding"
    )
    group_size: int | None = Field(default=32, description="Group size")
    fake_quant_after_n_steps: int | None = Field(
        default=None, description="Fake quant after n steps"
    )

    @field_validator("activation_dtype", "weight_dtype", mode="before")
    @classmethod
    def validate_dtype(cls, v: Any) -> TorchIntDType | None:
        """Coerce string dtype names ("int4"/"int8") to `TorchIntDType`.

        Passes through None (valid for the optional `activation_dtype`; the
        non-optional `weight_dtype` field typing still rejects it) and values
        that are already `TorchIntDType` members.
        """
        # Fix: previously an explicit `activation_dtype: null` or an
        # already-coerced enum member raised "Invalid dtype" instead of
        # validating.
        if v is None or isinstance(v, TorchIntDType):
            return v
        if v == "int4":
            return TorchIntDType.int4
        if v == "int8":
            return TorchIntDType.int8
        raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")
class PTQConfig(BaseModel):
    """
    PTQ Config Schema
    """

    weight_dtype: TorchIntDType = Field(
        default=TorchIntDType.int8, description="Weight dtype"
    )
    activation_dtype: TorchIntDType | None = Field(
        default=None, description="Activation dtype"
    )
    quantize_embedding: bool | None = Field(
        default=None, description="Quantize embedding"
    )
    group_size: int | None = Field(default=32, description="Group size")

    @field_validator("activation_dtype", "weight_dtype", mode="before")
    @classmethod
    def validate_dtype(cls, v: Any) -> TorchIntDType | None:
        """Coerce string dtype names ("int4"/"int8") to `TorchIntDType`.

        Passes through None (valid for the optional `activation_dtype`; the
        non-optional `weight_dtype` field typing still rejects it) and values
        that are already `TorchIntDType` members.
        """
        # Fix: previously an explicit `activation_dtype: null` or an
        # already-coerced enum member raised "Invalid dtype" instead of
        # validating.
        if v is None or isinstance(v, TorchIntDType):
            return v
        if v == "int4":
            return TorchIntDType.int4
        if v == "int8":
            return TorchIntDType.int8
        raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")