Migrate QAT API; fix axolotl quantize for QAT-ed models; add NVFP4 (#3107)
This commit is contained in:
@@ -115,6 +115,7 @@ class QuantizeCliArgs:
|
||||
quantize_embedding: Optional[bool] = field(default=None)
|
||||
group_size: Optional[int] = field(default=None)
|
||||
output_dir: Optional[str] = field(default=None)
|
||||
hub_model_id: Optional[str] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -5,12 +5,17 @@ CLI to post-training quantize a model using torchao
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.loaders import load_tokenizer
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
|
||||
from axolotl.utils.quantization import (
|
||||
TorchAOQuantDType,
|
||||
get_quantization_config,
|
||||
quantization_config_to_str,
|
||||
quantize_model,
|
||||
)
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -43,13 +48,13 @@ def do_quantize(
|
||||
"No quantization configuration found. Please specify either qat or quantization in your config file."
|
||||
)
|
||||
|
||||
model_path = cli_args.get("model_path") or cfg.output_dir
|
||||
model_path = cli_args.get("base_model") or cfg.output_dir
|
||||
if weight_dtype := cli_args.get("weight_dtype"):
|
||||
weight_dtype = TorchIntDType[weight_dtype]
|
||||
weight_dtype = TorchAOQuantDType.from_string(weight_dtype)
|
||||
else:
|
||||
weight_dtype = quantize_cfg.weight_dtype
|
||||
if activation_dtype := cli_args.get("activation_dtype"):
|
||||
activation_dtype = TorchIntDType[activation_dtype]
|
||||
activation_dtype = TorchAOQuantDType.from_string(activation_dtype)
|
||||
else:
|
||||
activation_dtype = quantize_cfg.activation_dtype
|
||||
group_size = cli_args.get("group_size") or quantize_cfg.group_size
|
||||
@@ -57,10 +62,15 @@ def do_quantize(
|
||||
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
|
||||
)
|
||||
output_dir = cli_args.get("output_dir") or cfg.output_dir
|
||||
hub_model_id = cli_args.get("hub_model_id") or cfg.hub_model_id
|
||||
|
||||
LOG.info(f"Loading model from {model_path}...")
|
||||
LOG.info(f"Loading model from {model_path}.")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, device_map="auto", torch_dtype=torch_dtype
|
||||
)
|
||||
|
||||
LOG.info(
|
||||
f"Quantizing model with configuration: \n"
|
||||
@@ -70,11 +80,21 @@ def do_quantize(
|
||||
f"\tquantize_embedding: {quantize_embedding}"
|
||||
)
|
||||
|
||||
quantize_model_for_ptq(
|
||||
quantize_model(
|
||||
model, weight_dtype, group_size, activation_dtype, quantize_embedding
|
||||
)
|
||||
|
||||
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
|
||||
quantization_config = get_quantization_config(
|
||||
weight_dtype, activation_dtype, group_size
|
||||
)
|
||||
|
||||
ao_config = TorchAoConfig(
|
||||
quant_type=quantization_config,
|
||||
include_input_output_embeddings=quantize_embedding,
|
||||
)
|
||||
model.config.quantization_config = ao_config
|
||||
|
||||
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
|
||||
model.save_pretrained(
|
||||
str(Path(output_dir) / "quantized"),
|
||||
safe_serialization=False,
|
||||
@@ -86,4 +106,14 @@ def do_quantize(
|
||||
progressbar=True,
|
||||
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
||||
)
|
||||
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")
|
||||
|
||||
if hub_model_id:
|
||||
hub_model_id = (
|
||||
hub_model_id.rstrip("-")
|
||||
+ f"-{quantization_config_to_str[type(quantization_config)]}"
|
||||
)
|
||||
model.push_to_hub(hub_model_id, safe_serialization=False)
|
||||
tokenizer.push_to_hub(hub_model_id)
|
||||
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
|
||||
|
||||
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")
|
||||
|
||||
@@ -30,11 +30,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||
fix_untrained_tokens,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders import (
|
||||
ModelLoader,
|
||||
load_processor,
|
||||
load_tokenizer,
|
||||
)
|
||||
from axolotl.loaders import ModelLoader, load_processor, load_tokenizer
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
@@ -234,16 +230,15 @@ def save_trained_model(
|
||||
|
||||
# handle QAT
|
||||
if cfg.qat:
|
||||
from axolotl.utils.quantization import convert_qat_model_for_ptq
|
||||
from axolotl.utils.quantization import convert_qat_model
|
||||
|
||||
LOG.info("Processing QAT model for saving...")
|
||||
convert_qat_model_for_ptq(
|
||||
convert_qat_model(
|
||||
model,
|
||||
quantize_embedding=cfg.qat.quantize_embedding,
|
||||
)
|
||||
LOG.info(
|
||||
"QAT modules have been converted for PTQ. Please ensure you quantize "
|
||||
"your model weights with `axolotl quantize`."
|
||||
"QAT usage note: please ensure you quantize your model fine-tuned using QAT by running `axolotl quantize`"
|
||||
" with the same config which you used for training."
|
||||
)
|
||||
# Handle ReLoRA early return case
|
||||
if cfg.relora:
|
||||
@@ -337,9 +332,7 @@ def save_trained_model(
|
||||
|
||||
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
||||
# TODO: add integration support so this can be implemented completely within the plugin
|
||||
from axolotl.integrations.llm_compressor.utils import (
|
||||
save_compressed_model,
|
||||
)
|
||||
from axolotl.integrations.llm_compressor.utils import save_compressed_model
|
||||
|
||||
save_compressed_model(
|
||||
model=model,
|
||||
|
||||
@@ -3,30 +3,47 @@ Utilities for quantization including QAT and PTQ using torchao.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from packaging import version
|
||||
from torchao.core.config import AOBaseConfig
|
||||
from torchao.quantization import quantize_
|
||||
from torchao.quantization.qat import (
|
||||
FakeQuantizeConfig,
|
||||
FromIntXQuantizationAwareTrainingConfig,
|
||||
IntXQuantizationAwareTrainingConfig,
|
||||
QATConfig,
|
||||
)
|
||||
from torchao.quantization.quant_api import (
|
||||
Int4DynamicActivationInt4WeightConfig,
|
||||
Int4WeightOnlyConfig,
|
||||
Float8DynamicActivationFloat8WeightConfig,
|
||||
Float8DynamicActivationInt4WeightConfig,
|
||||
Int8DynamicActivationInt4WeightConfig,
|
||||
Int8DynamicActivationInt8WeightConfig,
|
||||
Int8WeightOnlyConfig,
|
||||
UIntXWeightOnlyConfig,
|
||||
_is_linear,
|
||||
)
|
||||
|
||||
from axolotl.utils.schemas.enums import TorchIntDType
|
||||
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||
|
||||
quantization_config_to_str = {
|
||||
Int8DynamicActivationInt4WeightConfig: "int8int4",
|
||||
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
|
||||
Float8DynamicActivationInt4WeightConfig: "fp8int4",
|
||||
}
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse("2.8.0"):
|
||||
try:
|
||||
from torchao.prototype.mx_formats import NVFP4InferenceConfig
|
||||
|
||||
quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
|
||||
except:
|
||||
pass
|
||||
|
||||
# int4 weight config imports will fail on machines with fbgemm-gpu installed
|
||||
# without a CUDA runtime available so we do this safely
|
||||
try:
|
||||
from torchao.quantization.quant_api import Int4WeightOnlyConfig
|
||||
|
||||
quantization_config_to_str[Int4WeightOnlyConfig] = "int4"
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def get_ptq_config(
|
||||
weight_dtype: TorchIntDType,
|
||||
activation_dtype: TorchIntDType | None = None,
|
||||
def get_quantization_config(
|
||||
weight_dtype: TorchAOQuantDType,
|
||||
activation_dtype: TorchAOQuantDType | None = None,
|
||||
group_size: int | None = None,
|
||||
) -> AOBaseConfig:
|
||||
"""
|
||||
@@ -45,44 +62,101 @@ def get_ptq_config(
|
||||
or if the group size is not specified for int8 or int4 weight only quantization.
|
||||
"""
|
||||
if activation_dtype is None:
|
||||
if not weight_dtype.value.is_signed: # type: ignore[attr-defined,union-attr]
|
||||
return UIntXWeightOnlyConfig(
|
||||
dtype=weight_dtype.value,
|
||||
group_size=group_size,
|
||||
set_inductor_config=False,
|
||||
)
|
||||
if weight_dtype == TorchIntDType.int8:
|
||||
if group_size is None:
|
||||
raise ValueError(
|
||||
"group_size must be specified for int8 weight only quantization"
|
||||
)
|
||||
return Int8WeightOnlyConfig(
|
||||
group_size=group_size,
|
||||
)
|
||||
if weight_dtype == TorchIntDType.int4:
|
||||
if group_size is None:
|
||||
raise ValueError(
|
||||
"group_size must be specified for int4 weight only quantization"
|
||||
)
|
||||
return Int4WeightOnlyConfig(
|
||||
group_size=group_size,
|
||||
)
|
||||
if activation_dtype == TorchIntDType.int4 and weight_dtype == TorchIntDType.int4:
|
||||
return Int4DynamicActivationInt4WeightConfig()
|
||||
if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int8:
|
||||
return Int8DynamicActivationInt8WeightConfig()
|
||||
if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int4:
|
||||
return Int8DynamicActivationInt4WeightConfig()
|
||||
if weight_dtype == TorchAOQuantDType.int8:
|
||||
raise ValueError("Int8WeightOnlyConfig is not supported by torchao QAT.")
|
||||
if weight_dtype == TorchAOQuantDType.int4:
|
||||
from torchao.quantization.quant_api import Int4WeightOnlyConfig
|
||||
|
||||
if group_size is not None:
|
||||
return Int4WeightOnlyConfig(group_size=group_size, version=2)
|
||||
else:
|
||||
return Int4WeightOnlyConfig(version=2)
|
||||
if (
|
||||
activation_dtype == TorchAOQuantDType.int4
|
||||
and weight_dtype == TorchAOQuantDType.int4
|
||||
):
|
||||
raise ValueError(
|
||||
"Int4DynamicActivationInt4WeightConfig is not supported by torchao QAT."
|
||||
)
|
||||
if (
|
||||
activation_dtype == TorchAOQuantDType.int8
|
||||
and weight_dtype == TorchAOQuantDType.int8
|
||||
):
|
||||
raise ValueError(
|
||||
"Int8DynamicActivationInt8WeightConfig is not supported by torchao QAT."
|
||||
)
|
||||
if (
|
||||
activation_dtype == TorchAOQuantDType.int8
|
||||
and weight_dtype == TorchAOQuantDType.int4
|
||||
):
|
||||
if group_size is not None:
|
||||
return Int8DynamicActivationInt4WeightConfig(group_size=group_size)
|
||||
else:
|
||||
return Int8DynamicActivationInt4WeightConfig()
|
||||
if (
|
||||
activation_dtype == TorchAOQuantDType.float8_e4m3fn
|
||||
and weight_dtype == TorchAOQuantDType.float8_e4m3fn
|
||||
):
|
||||
return Float8DynamicActivationFloat8WeightConfig()
|
||||
if (
|
||||
activation_dtype == TorchAOQuantDType.float8_e4m3fn
|
||||
and weight_dtype == TorchAOQuantDType.int4
|
||||
):
|
||||
return Float8DynamicActivationInt4WeightConfig()
|
||||
if weight_dtype == TorchAOQuantDType.nvfp4:
|
||||
from torchao.prototype.mx_formats import NVFP4InferenceConfig
|
||||
|
||||
if group_size is not None and group_size != 16:
|
||||
raise ValueError("NVFP4 quantization must use a group_size of 16")
|
||||
return NVFP4InferenceConfig()
|
||||
raise ValueError(
|
||||
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
|
||||
)
|
||||
|
||||
|
||||
def quantize_model(
|
||||
model,
|
||||
weight_dtype: TorchAOQuantDType,
|
||||
group_size: int | None = None,
|
||||
activation_dtype: TorchAOQuantDType | None = None,
|
||||
quantize_embedding: bool | None = None,
|
||||
):
|
||||
"""
|
||||
This function is used to quantize a model.
|
||||
|
||||
Args:
|
||||
model: The model to quantize.
|
||||
weight_dtype: The dtype to use for weight quantization.
|
||||
group_size: The group size to use for weight quantization.
|
||||
activation_dtype: The dtype to use for activation quantization.
|
||||
quantize_embedding: Whether to quantize the model's embedding weights.
|
||||
|
||||
"""
|
||||
linear_ptq_config = get_quantization_config(
|
||||
weight_dtype=weight_dtype,
|
||||
activation_dtype=activation_dtype,
|
||||
group_size=group_size,
|
||||
)
|
||||
quantize_(model, linear_ptq_config)
|
||||
if quantize_embedding:
|
||||
# activation fake quantization is not supported for embedding layers
|
||||
embedding_quantize_config = get_quantization_config(
|
||||
weight_dtype=weight_dtype,
|
||||
activation_dtype=None,
|
||||
group_size=group_size,
|
||||
)
|
||||
quantize_(
|
||||
model,
|
||||
embedding_quantize_config,
|
||||
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
||||
)
|
||||
|
||||
|
||||
def prepare_model_for_qat(
|
||||
model,
|
||||
weight_dtype: TorchIntDType,
|
||||
group_size: int,
|
||||
activation_dtype: TorchIntDType | None = None,
|
||||
weight_dtype: TorchAOQuantDType,
|
||||
group_size: int | None = None,
|
||||
activation_dtype: TorchAOQuantDType | None = None,
|
||||
quantize_embedding: bool = False,
|
||||
):
|
||||
"""
|
||||
@@ -100,86 +174,40 @@ def prepare_model_for_qat(
|
||||
Raises:
|
||||
ValueError: If the activation/weight dtype combination is invalid.
|
||||
"""
|
||||
if activation_dtype:
|
||||
activation_config = FakeQuantizeConfig(
|
||||
dtype=activation_dtype.value, granularity="per_token", is_symmetric=False
|
||||
)
|
||||
weight_config = FakeQuantizeConfig(dtype=weight_dtype.value, group_size=group_size)
|
||||
linear_quantize_config = IntXQuantizationAwareTrainingConfig(
|
||||
activation_config=None if activation_dtype is None else activation_config,
|
||||
weight_config=weight_config,
|
||||
)
|
||||
quantize_(model, linear_quantize_config)
|
||||
if quantize_embedding:
|
||||
# activation fake quantization is not supported for embedding layers
|
||||
embedding_quantize_config = IntXQuantizationAwareTrainingConfig(
|
||||
activation_config=None,
|
||||
weight_config=weight_config,
|
||||
)
|
||||
quantize_(
|
||||
model,
|
||||
embedding_quantize_config,
|
||||
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
||||
)
|
||||
|
||||
|
||||
def quantize_model_for_ptq(
|
||||
model,
|
||||
weight_dtype: TorchIntDType,
|
||||
group_size: int | None = None,
|
||||
activation_dtype: TorchIntDType | None = None,
|
||||
quantize_embedding: bool | None = None,
|
||||
):
|
||||
"""
|
||||
This function is used to quantize a model for post-training quantization.
|
||||
It swaps the model's linear layers with fake quantized linear layers.
|
||||
If `quantize_embedding` is True, it will also swap the model's embedding weights with fake quantized embedding weights.
|
||||
|
||||
Args:
|
||||
model: The model to quantize.
|
||||
weight_dtype: The dtype to use for weight quantization.
|
||||
group_size: The group size to use for weight quantization.
|
||||
activation_dtype: The dtype to use for activation quantization.
|
||||
quantize_embedding: Whether to quantize the model's embedding weights.
|
||||
|
||||
"""
|
||||
linear_ptq_config = get_ptq_config(
|
||||
base_config = get_quantization_config(
|
||||
weight_dtype=weight_dtype,
|
||||
activation_dtype=activation_dtype,
|
||||
group_size=group_size,
|
||||
)
|
||||
quantize_(model, linear_ptq_config)
|
||||
qat_config = QATConfig(base_config)
|
||||
quantize_(model, qat_config)
|
||||
if quantize_embedding:
|
||||
embedding_quantize_config = get_ptq_config(
|
||||
# activation fake quantization is not supported for embedding layers
|
||||
embedding_base_config = get_quantization_config(
|
||||
weight_dtype=weight_dtype,
|
||||
activation_dtype=None,
|
||||
group_size=group_size,
|
||||
)
|
||||
embedding_qat_config = QATConfig(embedding_base_config)
|
||||
quantize_(
|
||||
model,
|
||||
embedding_quantize_config,
|
||||
embedding_qat_config,
|
||||
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
||||
)
|
||||
|
||||
|
||||
def convert_qat_model_for_ptq(
|
||||
def convert_qat_model(
|
||||
model,
|
||||
*,
|
||||
quantize_embedding: bool | None = None,
|
||||
quantize_embedding: bool = False,
|
||||
):
|
||||
"""
|
||||
This function is used to convert a swap fake-quantized modules in a model
|
||||
which has been trained with QAT back to the original modules, ready for PTQ.
|
||||
|
||||
Args:
|
||||
model: The model to convert.
|
||||
quantize_embedding: Whether to quantize the model's embedding weights.
|
||||
This function converts a QAT model which has fake quantized layers back to the original model.
|
||||
"""
|
||||
config = QATConfig(step="convert")
|
||||
quantize_(model, config)
|
||||
if quantize_embedding:
|
||||
|
||||
def filter_fn(m, _):
|
||||
return isinstance(m, nn.Embedding) or _is_linear(m)
|
||||
|
||||
else:
|
||||
filter_fn = _is_linear
|
||||
quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
|
||||
quantize_(
|
||||
model,
|
||||
config,
|
||||
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
||||
)
|
||||
|
||||
@@ -5,18 +5,21 @@ from enum import Enum
|
||||
import torch
|
||||
|
||||
|
||||
class TorchIntDType(Enum):
|
||||
"""Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4"""
|
||||
class TorchAOQuantDType(Enum):
|
||||
int4 = torch.int4
|
||||
int8 = torch.int8
|
||||
float8_e4m3fn = torch.float8_e4m3fn
|
||||
nvfp4 = "nvfp4"
|
||||
|
||||
uint1 = getattr(torch, "uint1", None)
|
||||
uint2 = getattr(torch, "uint2", None)
|
||||
uint3 = getattr(torch, "uint3", None)
|
||||
uint4 = getattr(torch, "uint4", None)
|
||||
uint5 = getattr(torch, "uint5", None)
|
||||
uint6 = getattr(torch, "uint6", None)
|
||||
uint7 = getattr(torch, "uint7", None)
|
||||
int4 = getattr(torch, "int4", None)
|
||||
int8 = getattr(torch, "int8", None)
|
||||
def from_string(str):
|
||||
if str == "int4":
|
||||
return TorchAOQuantDType.int4
|
||||
if str == "int8":
|
||||
return TorchAOQuantDType.int8
|
||||
if str in ["float8_e4m3fn", "fp8", "float8"]:
|
||||
return TorchAOQuantDType.float8_e4m3fn
|
||||
if str == "nvfp4":
|
||||
return TorchAOQuantDType.nvfp4
|
||||
|
||||
|
||||
class RLType(str, Enum):
|
||||
|
||||
@@ -6,7 +6,23 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from axolotl.utils.schemas.enums import TorchIntDType
|
||||
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||
|
||||
|
||||
def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
|
||||
if v is None:
|
||||
return None
|
||||
if v == "int4":
|
||||
return TorchAOQuantDType.int4
|
||||
if v == "int8":
|
||||
return TorchAOQuantDType.int8
|
||||
if v in ["float8_e4m3fn", "fp8", "float8"]:
|
||||
return TorchAOQuantDType.float8_e4m3fn
|
||||
if v == "nvfp4":
|
||||
return TorchAOQuantDType.nvfp4
|
||||
raise ValueError(
|
||||
f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}"
|
||||
)
|
||||
|
||||
|
||||
class QATConfig(BaseModel):
|
||||
@@ -14,13 +30,13 @@ class QATConfig(BaseModel):
|
||||
QAT Config Schema
|
||||
"""
|
||||
|
||||
activation_dtype: TorchIntDType | None = Field(
|
||||
activation_dtype: TorchAOQuantDType | None = Field(
|
||||
default=None,
|
||||
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
|
||||
description="Fake quantization layout to use for activation quantization.",
|
||||
)
|
||||
weight_dtype: TorchIntDType = Field(
|
||||
default=TorchIntDType.int8,
|
||||
description='Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"',
|
||||
weight_dtype: TorchAOQuantDType = Field(
|
||||
default=TorchAOQuantDType.int8,
|
||||
description="Fake quantization layout to use for weight quantization.",
|
||||
)
|
||||
quantize_embedding: bool | None = Field(
|
||||
default=False, description="Quantize embedding"
|
||||
@@ -35,12 +51,8 @@ class QATConfig(BaseModel):
|
||||
|
||||
@field_validator("activation_dtype", "weight_dtype", mode="before")
|
||||
@classmethod
|
||||
def validate_dtype(cls, v: Any) -> TorchIntDType | None:
|
||||
if v == "int4":
|
||||
return TorchIntDType.int4
|
||||
if v == "int8":
|
||||
return TorchIntDType.int8
|
||||
raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")
|
||||
def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None:
|
||||
return validate_ao_dtype(v)
|
||||
|
||||
|
||||
class PTQConfig(BaseModel):
|
||||
@@ -48,13 +60,13 @@ class PTQConfig(BaseModel):
|
||||
PTQ Config Schema
|
||||
"""
|
||||
|
||||
weight_dtype: TorchIntDType = Field(
|
||||
default=TorchIntDType.int8,
|
||||
description="Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8",
|
||||
weight_dtype: TorchAOQuantDType = Field(
|
||||
default=TorchAOQuantDType.int8,
|
||||
description="Fake quantization layout to use for weight quantization.",
|
||||
)
|
||||
activation_dtype: TorchIntDType | None = Field(
|
||||
activation_dtype: TorchAOQuantDType | None = Field(
|
||||
default=None,
|
||||
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
|
||||
description="Fake quantization layout to use for activation quantization.",
|
||||
)
|
||||
quantize_embedding: bool | None = Field(
|
||||
default=None, description="Whether to quantize the embedding layer."
|
||||
@@ -66,9 +78,5 @@ class PTQConfig(BaseModel):
|
||||
|
||||
@field_validator("activation_dtype", "weight_dtype", mode="before")
|
||||
@classmethod
|
||||
def validate_dtype(cls, v: Any) -> TorchIntDType | None:
|
||||
if v == "int4":
|
||||
return TorchIntDType.int4
|
||||
if v == "int8":
|
||||
return TorchIntDType.int8
|
||||
raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")
|
||||
def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None:
|
||||
return validate_ao_dtype(v)
|
||||
|
||||
Reference in New Issue
Block a user