Migrate QAT API; fix axolotl quantize for QAT-ed models; add NVFP4 (#3107)

Commit: 58d67bf98d
Parent: 0401a15888
Author: salman
Date: 2025-09-12 10:55:50 +01:00
Committed by: GitHub
16 changed files with 554 additions and 339 deletions

View File

@@ -115,6 +115,7 @@ class QuantizeCliArgs:
quantize_embedding: Optional[bool] = field(default=None)
group_size: Optional[int] = field(default=None)
output_dir: Optional[str] = field(default=None)
hub_model_id: Optional[str] = field(default=None)
@dataclass

View File

@@ -5,12 +5,17 @@ CLI to post-training quantize a model using torchao
from pathlib import Path
from typing import Union
from transformers import AutoModelForCausalLM
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
from axolotl.utils.quantization import (
TorchAOQuantDType,
get_quantization_config,
quantization_config_to_str,
quantize_model,
)
LOG = get_logger(__name__)
@@ -43,13 +48,13 @@ def do_quantize(
"No quantization configuration found. Please specify either qat or quantization in your config file."
)
model_path = cli_args.get("model_path") or cfg.output_dir
model_path = cli_args.get("base_model") or cfg.output_dir
if weight_dtype := cli_args.get("weight_dtype"):
weight_dtype = TorchIntDType[weight_dtype]
weight_dtype = TorchAOQuantDType.from_string(weight_dtype)
else:
weight_dtype = quantize_cfg.weight_dtype
if activation_dtype := cli_args.get("activation_dtype"):
activation_dtype = TorchIntDType[activation_dtype]
activation_dtype = TorchAOQuantDType.from_string(activation_dtype)
else:
activation_dtype = quantize_cfg.activation_dtype
group_size = cli_args.get("group_size") or quantize_cfg.group_size
@@ -57,10 +62,15 @@ def do_quantize(
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
)
output_dir = cli_args.get("output_dir") or cfg.output_dir
hub_model_id = cli_args.get("hub_model_id") or cfg.hub_model_id
LOG.info(f"Loading model from {model_path}...")
LOG.info(f"Loading model from {model_path}.")
tokenizer = load_tokenizer(cfg)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
config = AutoConfig.from_pretrained(model_path)
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", torch_dtype=torch_dtype
)
LOG.info(
f"Quantizing model with configuration: \n"
@@ -70,11 +80,21 @@ def do_quantize(
f"\tquantize_embedding: {quantize_embedding}"
)
quantize_model_for_ptq(
quantize_model(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
quantization_config = get_quantization_config(
weight_dtype, activation_dtype, group_size
)
ao_config = TorchAoConfig(
quant_type=quantization_config,
include_input_output_embeddings=quantize_embedding,
)
model.config.quantization_config = ao_config
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
@@ -86,4 +106,14 @@ def do_quantize(
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")
if hub_model_id:
hub_model_id = (
hub_model_id.rstrip("-")
+ f"-{quantization_config_to_str[type(quantization_config)]}"
)
model.push_to_hub(hub_model_id, safe_serialization=False)
tokenizer.push_to_hub(hub_model_id)
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")

View File

@@ -30,11 +30,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.integrations.base import PluginManager
from axolotl.loaders import (
ModelLoader,
load_processor,
load_tokenizer,
)
from axolotl.loaders import ModelLoader, load_processor, load_tokenizer
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
@@ -234,16 +230,15 @@ def save_trained_model(
# handle QAT
if cfg.qat:
from axolotl.utils.quantization import convert_qat_model_for_ptq
from axolotl.utils.quantization import convert_qat_model
LOG.info("Processing QAT model for saving...")
convert_qat_model_for_ptq(
convert_qat_model(
model,
quantize_embedding=cfg.qat.quantize_embedding,
)
LOG.info(
"QAT modules have been converted for PTQ. Please ensure you quantize "
"your model weights with `axolotl quantize`."
"QAT usage note: please ensure you quantize your model fine-tuned using QAT by running `axolotl quantize`"
" with the same config which you used for training."
)
# Handle ReLoRA early return case
if cfg.relora:
@@ -337,9 +332,7 @@ def save_trained_model(
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin
from axolotl.integrations.llm_compressor.utils import (
save_compressed_model,
)
from axolotl.integrations.llm_compressor.utils import save_compressed_model
save_compressed_model(
model=model,

View File

@@ -3,30 +3,47 @@ Utilities for quantization including QAT and PTQ using torchao.
"""
import torch
from torch import nn
from packaging import version
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.qat import (
FakeQuantizeConfig,
FromIntXQuantizationAwareTrainingConfig,
IntXQuantizationAwareTrainingConfig,
QATConfig,
)
from torchao.quantization.quant_api import (
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
UIntXWeightOnlyConfig,
_is_linear,
)
from axolotl.utils.schemas.enums import TorchIntDType
from axolotl.utils.schemas.enums import TorchAOQuantDType
quantization_config_to_str = {
Int8DynamicActivationInt4WeightConfig: "int8int4",
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
Float8DynamicActivationInt4WeightConfig: "fp8int4",
}
if version.parse(torch.__version__) >= version.parse("2.8.0"):
try:
from torchao.prototype.mx_formats import NVFP4InferenceConfig
quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
except:
pass
# int4 weight config imports will fail on machines with fbgemm-gpu installed
# without a CUDA runtime available so we do this safely
try:
from torchao.quantization.quant_api import Int4WeightOnlyConfig
quantization_config_to_str[Int4WeightOnlyConfig] = "int4"
except:
pass
def get_ptq_config(
weight_dtype: TorchIntDType,
activation_dtype: TorchIntDType | None = None,
def get_quantization_config(
weight_dtype: TorchAOQuantDType,
activation_dtype: TorchAOQuantDType | None = None,
group_size: int | None = None,
) -> AOBaseConfig:
"""
@@ -45,44 +62,101 @@ def get_ptq_config(
or if the group size is not specified for int8 or int4 weight only quantization.
"""
if activation_dtype is None:
if not weight_dtype.value.is_signed: # type: ignore[attr-defined,union-attr]
return UIntXWeightOnlyConfig(
dtype=weight_dtype.value,
group_size=group_size,
set_inductor_config=False,
)
if weight_dtype == TorchIntDType.int8:
if group_size is None:
raise ValueError(
"group_size must be specified for int8 weight only quantization"
)
return Int8WeightOnlyConfig(
group_size=group_size,
)
if weight_dtype == TorchIntDType.int4:
if group_size is None:
raise ValueError(
"group_size must be specified for int4 weight only quantization"
)
return Int4WeightOnlyConfig(
group_size=group_size,
)
if activation_dtype == TorchIntDType.int4 and weight_dtype == TorchIntDType.int4:
return Int4DynamicActivationInt4WeightConfig()
if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int8:
return Int8DynamicActivationInt8WeightConfig()
if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int4:
return Int8DynamicActivationInt4WeightConfig()
if weight_dtype == TorchAOQuantDType.int8:
raise ValueError("Int8WeightOnlyConfig is not supported by torchao QAT.")
if weight_dtype == TorchAOQuantDType.int4:
from torchao.quantization.quant_api import Int4WeightOnlyConfig
if group_size is not None:
return Int4WeightOnlyConfig(group_size=group_size, version=2)
else:
return Int4WeightOnlyConfig(version=2)
if (
activation_dtype == TorchAOQuantDType.int4
and weight_dtype == TorchAOQuantDType.int4
):
raise ValueError(
"Int4DynamicActivationInt4WeightConfig is not supported by torchao QAT."
)
if (
activation_dtype == TorchAOQuantDType.int8
and weight_dtype == TorchAOQuantDType.int8
):
raise ValueError(
"Int8DynamicActivationInt8WeightConfig is not supported by torchao QAT."
)
if (
activation_dtype == TorchAOQuantDType.int8
and weight_dtype == TorchAOQuantDType.int4
):
if group_size is not None:
return Int8DynamicActivationInt4WeightConfig(group_size=group_size)
else:
return Int8DynamicActivationInt4WeightConfig()
if (
activation_dtype == TorchAOQuantDType.float8_e4m3fn
and weight_dtype == TorchAOQuantDType.float8_e4m3fn
):
return Float8DynamicActivationFloat8WeightConfig()
if (
activation_dtype == TorchAOQuantDType.float8_e4m3fn
and weight_dtype == TorchAOQuantDType.int4
):
return Float8DynamicActivationInt4WeightConfig()
if weight_dtype == TorchAOQuantDType.nvfp4:
from torchao.prototype.mx_formats import NVFP4InferenceConfig
if group_size is not None and group_size != 16:
raise ValueError("NVFP4 quantization must use a group_size of 16")
return NVFP4InferenceConfig()
raise ValueError(
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
)
def quantize_model(
model,
weight_dtype: TorchAOQuantDType,
group_size: int | None = None,
activation_dtype: TorchAOQuantDType | None = None,
quantize_embedding: bool | None = None,
):
"""
This function is used to quantize a model.
Args:
model: The model to quantize.
weight_dtype: The dtype to use for weight quantization.
group_size: The group size to use for weight quantization.
activation_dtype: The dtype to use for activation quantization.
quantize_embedding: Whether to quantize the model's embedding weights.
"""
linear_ptq_config = get_quantization_config(
weight_dtype=weight_dtype,
activation_dtype=activation_dtype,
group_size=group_size,
)
quantize_(model, linear_ptq_config)
if quantize_embedding:
# activation fake quantization is not supported for embedding layers
embedding_quantize_config = get_quantization_config(
weight_dtype=weight_dtype,
activation_dtype=None,
group_size=group_size,
)
quantize_(
model,
embedding_quantize_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
def prepare_model_for_qat(
model,
weight_dtype: TorchIntDType,
group_size: int,
activation_dtype: TorchIntDType | None = None,
weight_dtype: TorchAOQuantDType,
group_size: int | None = None,
activation_dtype: TorchAOQuantDType | None = None,
quantize_embedding: bool = False,
):
"""
@@ -100,86 +174,40 @@ def prepare_model_for_qat(
Raises:
ValueError: If the activation/weight dtype combination is invalid.
"""
if activation_dtype:
activation_config = FakeQuantizeConfig(
dtype=activation_dtype.value, granularity="per_token", is_symmetric=False
)
weight_config = FakeQuantizeConfig(dtype=weight_dtype.value, group_size=group_size)
linear_quantize_config = IntXQuantizationAwareTrainingConfig(
activation_config=None if activation_dtype is None else activation_config,
weight_config=weight_config,
)
quantize_(model, linear_quantize_config)
if quantize_embedding:
# activation fake quantization is not supported for embedding layers
embedding_quantize_config = IntXQuantizationAwareTrainingConfig(
activation_config=None,
weight_config=weight_config,
)
quantize_(
model,
embedding_quantize_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
def quantize_model_for_ptq(
model,
weight_dtype: TorchIntDType,
group_size: int | None = None,
activation_dtype: TorchIntDType | None = None,
quantize_embedding: bool | None = None,
):
"""
This function is used to quantize a model for post-training quantization.
It swaps the model's linear layers with fake quantized linear layers.
If `quantize_embedding` is True, it will also swap the model's embedding weights with fake quantized embedding weights.
Args:
model: The model to quantize.
weight_dtype: The dtype to use for weight quantization.
group_size: The group size to use for weight quantization.
activation_dtype: The dtype to use for activation quantization.
quantize_embedding: Whether to quantize the model's embedding weights.
"""
linear_ptq_config = get_ptq_config(
base_config = get_quantization_config(
weight_dtype=weight_dtype,
activation_dtype=activation_dtype,
group_size=group_size,
)
quantize_(model, linear_ptq_config)
qat_config = QATConfig(base_config)
quantize_(model, qat_config)
if quantize_embedding:
embedding_quantize_config = get_ptq_config(
# activation fake quantization is not supported for embedding layers
embedding_base_config = get_quantization_config(
weight_dtype=weight_dtype,
activation_dtype=None,
group_size=group_size,
)
embedding_qat_config = QATConfig(embedding_base_config)
quantize_(
model,
embedding_quantize_config,
embedding_qat_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
def convert_qat_model_for_ptq(
def convert_qat_model(
model,
*,
quantize_embedding: bool | None = None,
quantize_embedding: bool = False,
):
"""
This function is used to convert a swap fake-quantized modules in a model
which has been trained with QAT back to the original modules, ready for PTQ.
Args:
model: The model to convert.
quantize_embedding: Whether to quantize the model's embedding weights.
This function converts a QAT model which has fake quantized layers back to the original model.
"""
config = QATConfig(step="convert")
quantize_(model, config)
if quantize_embedding:
def filter_fn(m, _):
return isinstance(m, nn.Embedding) or _is_linear(m)
else:
filter_fn = _is_linear
quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
quantize_(
model,
config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
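For reference, a small sketch (not part of the commit) of how the reworked get_quantization_config maps dtype pairs onto torchao configs, following the branches above:

from axolotl.utils.quantization import get_quantization_config
from axolotl.utils.schemas.enums import TorchAOQuantDType

# int8 activations + int4 weights -> Int8DynamicActivationInt4WeightConfig(group_size=32)
cfg = get_quantization_config(
    weight_dtype=TorchAOQuantDType.int4,
    activation_dtype=TorchAOQuantDType.int8,
    group_size=32,
)

# fp8 activations + fp8 weights -> Float8DynamicActivationFloat8WeightConfig()
cfg = get_quantization_config(
    weight_dtype=TorchAOQuantDType.float8_e4m3fn,
    activation_dtype=TorchAOQuantDType.float8_e4m3fn,
)

# weight-only NVFP4 -> NVFP4InferenceConfig(); any group_size other than 16 raises ValueError
cfg = get_quantization_config(weight_dtype=TorchAOQuantDType.nvfp4)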

View File

@@ -5,18 +5,21 @@ from enum import Enum
import torch
class TorchIntDType(Enum):
"""Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4"""
class TorchAOQuantDType(Enum):
int4 = torch.int4
int8 = torch.int8
float8_e4m3fn = torch.float8_e4m3fn
nvfp4 = "nvfp4"
uint1 = getattr(torch, "uint1", None)
uint2 = getattr(torch, "uint2", None)
uint3 = getattr(torch, "uint3", None)
uint4 = getattr(torch, "uint4", None)
uint5 = getattr(torch, "uint5", None)
uint6 = getattr(torch, "uint6", None)
uint7 = getattr(torch, "uint7", None)
int4 = getattr(torch, "int4", None)
int8 = getattr(torch, "int8", None)
def from_string(str):
if str == "int4":
return TorchAOQuantDType.int4
if str == "int8":
return TorchAOQuantDType.int8
if str in ["float8_e4m3fn", "fp8", "float8"]:
return TorchAOQuantDType.float8_e4m3fn
if str == "nvfp4":
return TorchAOQuantDType.nvfp4
class RLType(str, Enum):
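Assumed usage of the new from_string helper (sketch, not part of the diff):

from axolotl.utils.schemas.enums import TorchAOQuantDType

TorchAOQuantDType.from_string("int4")   # -> TorchAOQuantDType.int4
TorchAOQuantDType.from_string("fp8")    # -> TorchAOQuantDType.float8_e4m3fn ("float8" also works)
TorchAOQuantDType.from_string("nvfp4")  # -> TorchAOQuantDType.nvfp4 (backed by the string "nvfp4")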

View File

@@ -6,7 +6,23 @@ from typing import Any
from pydantic import BaseModel, Field, field_validator
from axolotl.utils.schemas.enums import TorchIntDType
from axolotl.utils.schemas.enums import TorchAOQuantDType
def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
if v is None:
return None
if v == "int4":
return TorchAOQuantDType.int4
if v == "int8":
return TorchAOQuantDType.int8
if v in ["float8_e4m3fn", "fp8", "float8"]:
return TorchAOQuantDType.float8_e4m3fn
if v == "nvfp4":
return TorchAOQuantDType.nvfp4
raise ValueError(
f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}"
)
class QATConfig(BaseModel):
@@ -14,13 +30,13 @@ class QATConfig(BaseModel):
QAT Config Schema
"""
activation_dtype: TorchIntDType | None = Field(
activation_dtype: TorchAOQuantDType | None = Field(
default=None,
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
description="Fake quantization layout to use for activation quantization.",
)
weight_dtype: TorchIntDType = Field(
default=TorchIntDType.int8,
description='Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"',
weight_dtype: TorchAOQuantDType = Field(
default=TorchAOQuantDType.int8,
description="Fake quantization layout to use for weight quantization.",
)
quantize_embedding: bool | None = Field(
default=False, description="Quantize embedding"
@@ -35,12 +51,8 @@ class QATConfig(BaseModel):
@field_validator("activation_dtype", "weight_dtype", mode="before")
@classmethod
def validate_dtype(cls, v: Any) -> TorchIntDType | None:
if v == "int4":
return TorchIntDType.int4
if v == "int8":
return TorchIntDType.int8
raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")
def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None:
return validate_ao_dtype(v)
class PTQConfig(BaseModel):
@@ -48,13 +60,13 @@ class PTQConfig(BaseModel):
PTQ Config Schema
"""
weight_dtype: TorchIntDType = Field(
default=TorchIntDType.int8,
description="Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8",
weight_dtype: TorchAOQuantDType = Field(
default=TorchAOQuantDType.int8,
description="Fake quantization layout to use for weight quantization.",
)
activation_dtype: TorchIntDType | None = Field(
activation_dtype: TorchAOQuantDType | None = Field(
default=None,
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
description="Fake quantization layout to use for activation quantization.",
)
quantize_embedding: bool | None = Field(
default=None, description="Whether to quantize the embedding layer."
@@ -66,9 +78,5 @@ class PTQConfig(BaseModel):
@field_validator("activation_dtype", "weight_dtype", mode="before")
@classmethod
def validate_dtype(cls, v: Any) -> TorchIntDType | None:
if v == "int4":
return TorchIntDType.int4
if v == "int8":
return TorchIntDType.int8
raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")
def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None:
return validate_ao_dtype(v)
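A quick, illustrative sketch of how the shared validator behaves for both schemas, assuming it is run in the module that defines validate_ao_dtype (not part of the commit):

validate_ao_dtype(None)      # -> None, field left unset
validate_ao_dtype("int8")    # -> TorchAOQuantDType.int8
validate_ao_dtype("float8")  # -> TorchAOQuantDType.float8_e4m3fn ("fp8" also accepted)
validate_ao_dtype("uint4")   # raises ValueError; the old uintX-only layouts are no longer valid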