MX QAT patch (#3553)

* qat patch

* test fixes

* fixup per PR code review

* use state dict hooks to handle dequant for saving safetensors from transformers

* use transformers torchao quantizer hooks to save MX-quantized model

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
VED
2026-04-02 03:51:02 +05:30
committed by GitHub
parent 6c92b5c31c
commit c92b71bd0c
3 changed files with 91 additions and 31 deletions

View File

@@ -5,7 +5,7 @@ CLI to post-training quantize a model using torchao
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig from transformers import AutoConfig, AutoModelForCausalLM
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.loaders import load_processor, load_tokenizer from axolotl.loaders import load_processor, load_tokenizer
@@ -93,17 +93,22 @@ def do_quantize(
weight_dtype, activation_dtype, group_size weight_dtype, activation_dtype, group_size
) )
ao_config = TorchAoConfig(
quant_type=quantization_config,
include_input_output_embeddings=quantize_embedding,
)
model.config.quantization_config = ao_config
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.") LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
model.save_pretrained( try:
str(Path(output_dir) / "quantized"), model.save_pretrained(
progressbar=True, str(Path(output_dir) / "quantized"),
) progressbar=True,
)
except NotImplementedError:
LOG.warning(
"Model weight conversions do not support reverse_op, "
"retrying save with save_original_format=False"
)
model.save_pretrained(
str(Path(output_dir) / "quantized"),
progressbar=True,
save_original_format=False,
)
tokenizer.save_pretrained( tokenizer.save_pretrained(
str(Path(output_dir) / "quantized"), str(Path(output_dir) / "quantized"),
progressbar=True, progressbar=True,

View File

@@ -5,7 +5,6 @@ Utilities for quantization including QAT and PTQ using torchao.
import torch import torch
from packaging import version from packaging import version
from torchao.core.config import AOBaseConfig from torchao.core.config import AOBaseConfig
from torchao.prototype.qat import MXFakeQuantizeConfig
from torchao.quantization import quantize_ from torchao.quantization import quantize_
from torchao.quantization.qat import ( from torchao.quantization.qat import (
QATConfig, QATConfig,
@@ -15,9 +14,15 @@ from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig, Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig, Float8DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig, Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
) )
try:
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
except ImportError:
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
)
from axolotl.utils.schemas.enums import TorchAOQuantDType from axolotl.utils.schemas.enums import TorchAOQuantDType
quantization_config_to_str = { quantization_config_to_str = {
@@ -28,7 +33,9 @@ quantization_config_to_str = {
if version.parse(torch.__version__) >= version.parse("2.8.0"): if version.parse(torch.__version__) >= version.parse("2.8.0"):
try: try:
from torchao.prototype.mx_formats import NVFP4InferenceConfig from torchao.prototype.mx_formats import (
NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
)
quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4" quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
except (ImportError, RuntimeError): except (ImportError, RuntimeError):
@@ -44,10 +51,12 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"):
pass pass
try: try:
from torchao.prototype.qat import MXFakeQuantizeConfig from torchao.prototype.mx_formats import (
MXDynamicActivationMXWeightConfig as MXLinearConfig,
)
quantization_config_to_str[MXFakeQuantizeConfig] = "mxfp4" quantization_config_to_str[MXLinearConfig] = "mxfp4"
except ImportError: except (ImportError, RuntimeError):
pass pass
@@ -114,15 +123,15 @@ def get_quantization_config(
): ):
return Float8DynamicActivationInt4WeightConfig() return Float8DynamicActivationInt4WeightConfig()
if weight_dtype == TorchAOQuantDType.nvfp4: if weight_dtype == TorchAOQuantDType.nvfp4:
from torchao.prototype.mx_formats import NVFP4InferenceConfig from torchao.prototype.mx_formats import (
NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
)
if group_size is not None and group_size != 16: if group_size is not None and group_size != 16:
raise ValueError("NVFP4 quantization must use a group_size of 16") raise ValueError("NVFP4 quantization must use a group_size of 16")
return NVFP4InferenceConfig() return NVFP4InferenceConfig()
if weight_dtype == TorchAOQuantDType.mxfp4: if weight_dtype == TorchAOQuantDType.mxfp4:
from torchao.prototype.qat import MXFakeQuantizeConfig
# MXFP4 uses block_size=32 by default (vs NVFP4's 16) # MXFP4 uses block_size=32 by default (vs NVFP4's 16)
block_size = group_size if group_size is not None else 32 block_size = group_size if group_size is not None else 32
if block_size != 32: if block_size != 32:
@@ -130,13 +139,41 @@ def get_quantization_config(
"MXFP4 quantization must use a block_size (group_size) of 32" "MXFP4 quantization must use a block_size (group_size) of 32"
) )
return MXFakeQuantizeConfig(dtype=torch.float4_e2m1fn_x2, block_size=block_size) from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
return MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float4_e2m1fn_x2,
weight_dtype=torch.float4_e2m1fn_x2,
block_size=block_size,
)
raise ValueError( raise ValueError(
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}" f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
) )
def _attach_torchao_quantizer(
    model, quantization_config, include_input_output_embeddings=False
):
    """Register a torchao HF quantizer on ``model`` ahead of ``save_pretrained``.

    The given torchao ``quantization_config`` is wrapped in a transformers
    ``TorchAoConfig``, stored on ``model.config.quantization_config``, and a
    ``TorchAoHfQuantizer`` built from it is assigned to ``model.hf_quantizer``.

    Rationale (as stated by the original author): with the quantizer
    attached, ``save_pretrained`` takes torchao's
    ``flatten_tensor_state_dict`` path, which keeps quantized weights
    (e.g. MXTensor qdata+scale) in the safetensors file. Without it,
    ``save_pretrained`` falls through to the default path, which calls
    safetensors ``storage_ptr()`` on tensor subclasses and crashes.

    Args:
        model: Model to mutate in place; must expose a ``config`` attribute.
        quantization_config: torchao config passed as ``quant_type``.
        include_input_output_embeddings: Forwarded unchanged to
            ``TorchAoConfig``.
    """
    # Imported lazily so these transformers modules are only loaded when a
    # quantizer is actually attached.
    from transformers import TorchAoConfig
    from transformers.quantizers.quantizer_torchao import TorchAoHfQuantizer

    torchao_config = TorchAoConfig(
        quant_type=quantization_config,
        include_input_output_embeddings=include_input_output_embeddings,
    )
    # Record the config on the model first, then hang the quantizer built
    # from that same config object on the model.
    model.config.quantization_config = torchao_config
    model.hf_quantizer = TorchAoHfQuantizer(torchao_config)
def quantize_model( def quantize_model(
model, model,
weight_dtype: TorchAOQuantDType, weight_dtype: TorchAOQuantDType,
@@ -174,6 +211,12 @@ def quantize_model(
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
) )
_attach_torchao_quantizer(
model,
linear_ptq_config,
include_input_output_embeddings=bool(quantize_embedding),
)
def _make_qat_config( def _make_qat_config(
base_config: AOBaseConfig, base_config: AOBaseConfig,
@@ -189,11 +232,14 @@ def _make_qat_config(
IntxFakeQuantizeConfig, IntxFakeQuantizeConfig,
) )
if isinstance(base_config, MXFakeQuantizeConfig): if weight_dtype == TorchAOQuantDType.mxfp4:
return QATConfig( from torchao.prototype.qat import MXFakeQuantizeConfig
activation_config=base_config,
weight_config=base_config, block_size = getattr(base_config, "block_size", 32)
mx_fq = MXFakeQuantizeConfig(
dtype=torch.float4_e2m1fn_x2, block_size=block_size
) )
return QATConfig(activation_config=mx_fq, weight_config=mx_fq)
# Build explicit weight config # Build explicit weight config
weight_fq_config: ( weight_fq_config: (

View File

@@ -5,15 +5,20 @@ Tests for axolotl.utils.quantization
import pytest import pytest
import torch import torch
from torch import nn from torch import nn
from torchao.prototype.qat import MXFakeQuantizeConfig
from torchao.quantization import LinearActivationQuantizedTensor from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear from torchao.quantization.qat.linear import FakeQuantizedLinear
from torchao.quantization.quant_api import ( from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig, Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig, Float8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt4WeightConfig,
) )
try:
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
except ImportError:
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
from transformers.trainer_callback import TrainerState from transformers.trainer_callback import TrainerState
@@ -129,8 +134,11 @@ class TestQuantization:
@require_torch_2_8_0 @require_torch_2_8_0
@requires_sm_ge_100 @requires_sm_ge_100
def test_get_ptq_config_mxfp4(self): def test_get_ptq_config_mxfp4(self):
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32) config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)
assert isinstance(config, MXFakeQuantizeConfig) assert isinstance(config, MXDynamicActivationMXWeightConfig)
assert config.weight_dtype == torch.float4_e2m1fn_x2
assert config.block_size == 32 assert config.block_size == 32
@require_torch_2_8_0 @require_torch_2_8_0
@@ -298,7 +306,6 @@ class TestQuantization:
"weight_dtype,activation_dtype,group_size,quantize_embedding", "weight_dtype,activation_dtype,group_size,quantize_embedding",
[ [
(TorchAOQuantDType.mxfp4, None, 32, False), (TorchAOQuantDType.mxfp4, None, 32, False),
(TorchAOQuantDType.mxfp4, None, 32, True),
], ],
) )
@require_torch_2_8_0 @require_torch_2_8_0
@@ -314,14 +321,16 @@ class TestQuantization:
quantize_embedding, quantize_embedding,
) )
from torchao.prototype.qat import MXFakeQuantizedLinear
if quantize_embedding: if quantize_embedding:
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer") assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
for child in list(model.children()): for child in list(model.children()):
if isinstance(child, torch.nn.Linear): if isinstance(child, torch.nn.Linear):
assert isinstance(child, FakeQuantizedLinear) assert isinstance(child, MXFakeQuantizedLinear)
assert hasattr(child, "weight_fake_quantizer") assert hasattr(child, "weight_config")
@require_torch_2_8_0 @require_torch_2_8_0
@requires_cuda_ge_8_9 @requires_cuda_ge_8_9