MX QAT patch (#3553)

* qat patch

* tests fixes

* fixup per PR code review

* use state dict hooks to handle dequant for saving safetensors from transformers

* use transformers torch ao quantizer hooks to save mx quantized model

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
VED
2026-04-02 03:51:02 +05:30
committed by GitHub
parent 6c92b5c31c
commit c92b71bd0c
3 changed files with 91 additions and 31 deletions

View File

@@ -5,7 +5,7 @@ CLI to post-training quantize a model using torchao
from pathlib import Path
from typing import Union
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
from transformers import AutoConfig, AutoModelForCausalLM
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_processor, load_tokenizer
@@ -93,17 +93,22 @@ def do_quantize(
weight_dtype, activation_dtype, group_size
)
ao_config = TorchAoConfig(
quant_type=quantization_config,
include_input_output_embeddings=quantize_embedding,
)
model.config.quantization_config = ao_config
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
progressbar=True,
)
try:
model.save_pretrained(
str(Path(output_dir) / "quantized"),
progressbar=True,
)
except NotImplementedError:
LOG.warning(
"Model weight conversions do not support reverse_op, "
"retrying save with save_original_format=False"
)
model.save_pretrained(
str(Path(output_dir) / "quantized"),
progressbar=True,
save_original_format=False,
)
tokenizer.save_pretrained(
str(Path(output_dir) / "quantized"),
progressbar=True,

View File

@@ -5,7 +5,6 @@ Utilities for quantization including QAT and PTQ using torchao.
import torch
from packaging import version
from torchao.core.config import AOBaseConfig
from torchao.prototype.qat import MXFakeQuantizeConfig
from torchao.quantization import quantize_
from torchao.quantization.qat import (
QATConfig,
@@ -15,9 +14,15 @@ from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
)
try:
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
except ImportError:
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
)
from axolotl.utils.schemas.enums import TorchAOQuantDType
quantization_config_to_str = {
@@ -28,7 +33,9 @@ quantization_config_to_str = {
if version.parse(torch.__version__) >= version.parse("2.8.0"):
try:
from torchao.prototype.mx_formats import NVFP4InferenceConfig
from torchao.prototype.mx_formats import (
NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
)
quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
except (ImportError, RuntimeError):
@@ -44,10 +51,12 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"):
pass
try:
from torchao.prototype.qat import MXFakeQuantizeConfig
from torchao.prototype.mx_formats import (
MXDynamicActivationMXWeightConfig as MXLinearConfig,
)
quantization_config_to_str[MXFakeQuantizeConfig] = "mxfp4"
except ImportError:
quantization_config_to_str[MXLinearConfig] = "mxfp4"
except (ImportError, RuntimeError):
pass
@@ -114,15 +123,15 @@ def get_quantization_config(
):
return Float8DynamicActivationInt4WeightConfig()
if weight_dtype == TorchAOQuantDType.nvfp4:
from torchao.prototype.mx_formats import NVFP4InferenceConfig
from torchao.prototype.mx_formats import (
NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
)
if group_size is not None and group_size != 16:
raise ValueError("NVFP4 quantization must use a group_size of 16")
return NVFP4InferenceConfig()
if weight_dtype == TorchAOQuantDType.mxfp4:
from torchao.prototype.qat import MXFakeQuantizeConfig
# MXFP4 uses block_size=32 by default (vs NVFP4's 16)
block_size = group_size if group_size is not None else 32
if block_size != 32:
@@ -130,13 +139,41 @@ def get_quantization_config(
"MXFP4 quantization must use a block_size (group_size) of 32"
)
return MXFakeQuantizeConfig(dtype=torch.float4_e2m1fn_x2, block_size=block_size)
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
return MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float4_e2m1fn_x2,
weight_dtype=torch.float4_e2m1fn_x2,
block_size=block_size,
)
raise ValueError(
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
)
def _attach_torchao_quantizer(
model, quantization_config, include_input_output_embeddings=False
):
"""Attach a TorchAoHfQuantizer to the model so save_pretrained uses
torchao's flatten_tensor_state_dict path, preserving quantized weights
(e.g. MXTensor qdata+scale) in the safetensors file.
Without this, save_pretrained falls through to the default path which
calls safetensors storage_ptr() on tensor subclasses and crashes.
"""
from transformers import TorchAoConfig
from transformers.quantizers.quantizer_torchao import TorchAoHfQuantizer
ao_config = TorchAoConfig(
quant_type=quantization_config,
include_input_output_embeddings=include_input_output_embeddings,
)
model.config.quantization_config = ao_config
quantizer = TorchAoHfQuantizer(ao_config)
model.hf_quantizer = quantizer
def quantize_model(
model,
weight_dtype: TorchAOQuantDType,
@@ -174,6 +211,12 @@ def quantize_model(
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
_attach_torchao_quantizer(
model,
linear_ptq_config,
include_input_output_embeddings=bool(quantize_embedding),
)
def _make_qat_config(
base_config: AOBaseConfig,
@@ -189,11 +232,14 @@ def _make_qat_config(
IntxFakeQuantizeConfig,
)
if isinstance(base_config, MXFakeQuantizeConfig):
return QATConfig(
activation_config=base_config,
weight_config=base_config,
if weight_dtype == TorchAOQuantDType.mxfp4:
from torchao.prototype.qat import MXFakeQuantizeConfig
block_size = getattr(base_config, "block_size", 32)
mx_fq = MXFakeQuantizeConfig(
dtype=torch.float4_e2m1fn_x2, block_size=block_size
)
return QATConfig(activation_config=mx_fq, weight_config=mx_fq)
# Build explicit weight config
weight_fq_config: (

View File

@@ -5,15 +5,20 @@ Tests for axolotl.utils.quantization
import pytest
import torch
from torch import nn
from torchao.prototype.qat import MXFakeQuantizeConfig
from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt4WeightConfig,
)
try:
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
except ImportError:
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
from transformers import AutoModelForCausalLM
from transformers.trainer_callback import TrainerState
@@ -129,8 +134,11 @@ class TestQuantization:
@require_torch_2_8_0
@requires_sm_ge_100
def test_get_ptq_config_mxfp4(self):
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)
assert isinstance(config, MXFakeQuantizeConfig)
assert isinstance(config, MXDynamicActivationMXWeightConfig)
assert config.weight_dtype == torch.float4_e2m1fn_x2
assert config.block_size == 32
@require_torch_2_8_0
@@ -298,7 +306,6 @@ class TestQuantization:
"weight_dtype,activation_dtype,group_size,quantize_embedding",
[
(TorchAOQuantDType.mxfp4, None, 32, False),
(TorchAOQuantDType.mxfp4, None, 32, True),
],
)
@require_torch_2_8_0
@@ -314,14 +321,16 @@ class TestQuantization:
quantize_embedding,
)
from torchao.prototype.qat import MXFakeQuantizedLinear
if quantize_embedding:
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child, FakeQuantizedLinear)
assert hasattr(child, "weight_fake_quantizer")
assert isinstance(child, MXFakeQuantizedLinear)
assert hasattr(child, "weight_config")
@require_torch_2_8_0
@requires_cuda_ge_8_9