From c92b71bd0c1fc83956b1c68284664e8f176f00a0 Mon Sep 17 00:00:00 2001
From: VED <146507396+ved1beta@users.noreply.github.com>
Date: Thu, 2 Apr 2026 03:51:02 +0530
Subject: [PATCH] MX QAT patch (#3553)

* qat patch

* tests fixes

* fixup per PR code review

* use state dict hooks to handle dequant for saving safetensors from transformers

* use transformers torch ao quantizer hooks to save mx quantized model

---------

Co-authored-by: Wing Lian
Co-authored-by: Wing Lian
---
 src/axolotl/cli/quantize.py       | 27 ++++++-----
 src/axolotl/utils/quantization.py | 74 +++++++++++++++++++++++++------
 tests/e2e/test_quantization.py    | 21 ++++++---
 3 files changed, 91 insertions(+), 31 deletions(-)

diff --git a/src/axolotl/cli/quantize.py b/src/axolotl/cli/quantize.py
index 939443a01..052d79d7a 100644
--- a/src/axolotl/cli/quantize.py
+++ b/src/axolotl/cli/quantize.py
@@ -5,7 +5,7 @@ CLI to post-training quantize a model using torchao
 from pathlib import Path
 from typing import Union
 
-from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
+from transformers import AutoConfig, AutoModelForCausalLM
 
 from axolotl.cli.config import load_cfg
 from axolotl.loaders import load_processor, load_tokenizer
@@ -93,17 +93,22 @@ def do_quantize(
         weight_dtype, activation_dtype, group_size
     )
 
-    ao_config = TorchAoConfig(
-        quant_type=quantization_config,
-        include_input_output_embeddings=quantize_embedding,
-    )
-    model.config.quantization_config = ao_config
-
     LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
-    model.save_pretrained(
-        str(Path(output_dir) / "quantized"),
-        progressbar=True,
-    )
+    try:
+        model.save_pretrained(
+            str(Path(output_dir) / "quantized"),
+            progressbar=True,
+        )
+    except NotImplementedError:
+        LOG.warning(
+            "Model weight conversions do not support reverse_op, "
+            "retrying save with save_original_format=False"
+        )
+        model.save_pretrained(
+            str(Path(output_dir) / "quantized"),
+            progressbar=True,
+            save_original_format=False,
+        )
     tokenizer.save_pretrained(
         str(Path(output_dir) / "quantized"),
         progressbar=True,
diff --git a/src/axolotl/utils/quantization.py b/src/axolotl/utils/quantization.py
index 3078e2dc2..6a479d260 100644
--- a/src/axolotl/utils/quantization.py
+++ b/src/axolotl/utils/quantization.py
@@ -5,7 +5,6 @@ Utilities for quantization including QAT and PTQ using torchao.
 import torch
 from packaging import version
 from torchao.core.config import AOBaseConfig
-from torchao.prototype.qat import MXFakeQuantizeConfig
 from torchao.quantization import quantize_
 from torchao.quantization.qat import (
     QATConfig,
@@ -15,9 +14,15 @@ from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8DynamicActivationInt4WeightConfig,
     Int4WeightOnlyConfig,
-    Int8DynamicActivationInt4WeightConfig,
 )
 
+try:
+    from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
+except ImportError:
+    from torchao.quantization.quant_api import (
+        Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
+    )
+
 from axolotl.utils.schemas.enums import TorchAOQuantDType
 
 quantization_config_to_str = {
@@ -28,7 +33,9 @@ quantization_config_to_str = {
 if version.parse(torch.__version__) >= version.parse("2.8.0"):
     try:
-        from torchao.prototype.mx_formats import NVFP4InferenceConfig
+        from torchao.prototype.mx_formats import (
+            NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
+        )
 
         quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
     except (ImportError, RuntimeError):
         pass
@@ -44,10 +51,12 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"):
         pass
 
     try:
-        from torchao.prototype.qat import MXFakeQuantizeConfig
+        from torchao.prototype.mx_formats import (
+            MXDynamicActivationMXWeightConfig as MXLinearConfig,
+        )
 
-        quantization_config_to_str[MXFakeQuantizeConfig] = "mxfp4"
-    except ImportError:
+        quantization_config_to_str[MXLinearConfig] = "mxfp4"
+    except (ImportError, RuntimeError):
         pass
 
 
@@ -114,15 +123,15 @@ def get_quantization_config(
     ):
         return Float8DynamicActivationInt4WeightConfig()
     if weight_dtype == TorchAOQuantDType.nvfp4:
-        from torchao.prototype.mx_formats import NVFP4InferenceConfig
+        from torchao.prototype.mx_formats import (
+            NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
+        )
 
         if group_size is not None and group_size != 16:
             raise ValueError("NVFP4 quantization must use a group_size of 16")
         return NVFP4InferenceConfig()
 
     if weight_dtype == TorchAOQuantDType.mxfp4:
-        from torchao.prototype.qat import MXFakeQuantizeConfig
-
         # MXFP4 uses block_size=32 by default (vs NVFP4's 16)
         block_size = group_size if group_size is not None else 32
         if block_size != 32:
@@ -130,13 +139,41 @@ def get_quantization_config(
                 "MXFP4 quantization must use a block_size (group_size) of 32"
             )
 
-        return MXFakeQuantizeConfig(dtype=torch.float4_e2m1fn_x2, block_size=block_size)
+        from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
+
+        return MXDynamicActivationMXWeightConfig(
+            activation_dtype=torch.float4_e2m1fn_x2,
+            weight_dtype=torch.float4_e2m1fn_x2,
+            block_size=block_size,
+        )
 
     raise ValueError(
         f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
    )
 
 
+def _attach_torchao_quantizer(
+    model, quantization_config, include_input_output_embeddings=False
+):
+    """Attach a TorchAoHfQuantizer to the model so save_pretrained uses
+    torchao's flatten_tensor_state_dict path, preserving quantized weights
+    (e.g. MXTensor qdata+scale) in the safetensors file.
+
+    Without this, save_pretrained falls through to the default path which
+    calls safetensors storage_ptr() on tensor subclasses and crashes.
+    """
+    from transformers import TorchAoConfig
+    from transformers.quantizers.quantizer_torchao import TorchAoHfQuantizer
+
+    ao_config = TorchAoConfig(
+        quant_type=quantization_config,
+        include_input_output_embeddings=include_input_output_embeddings,
+    )
+    model.config.quantization_config = ao_config
+    quantizer = TorchAoHfQuantizer(ao_config)
+    model.hf_quantizer = quantizer
+
+
 def quantize_model(
     model,
     weight_dtype: TorchAOQuantDType,
@@ -174,6 +211,12 @@ def quantize_model(
             filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
         )
 
+    _attach_torchao_quantizer(
+        model,
+        linear_ptq_config,
+        include_input_output_embeddings=bool(quantize_embedding),
+    )
+
 
 def _make_qat_config(
     base_config: AOBaseConfig,
@@ -189,11 +232,14 @@ def _make_qat_config(
         IntxFakeQuantizeConfig,
     )
 
-    if isinstance(base_config, MXFakeQuantizeConfig):
-        return QATConfig(
-            activation_config=base_config,
-            weight_config=base_config,
+    if weight_dtype == TorchAOQuantDType.mxfp4:
+        from torchao.prototype.qat import MXFakeQuantizeConfig
+
+        block_size = getattr(base_config, "block_size", 32)
+        mx_fq = MXFakeQuantizeConfig(
+            dtype=torch.float4_e2m1fn_x2, block_size=block_size
         )
+        return QATConfig(activation_config=mx_fq, weight_config=mx_fq)
 
     # Build explicit weight config
     weight_fq_config: (
diff --git a/tests/e2e/test_quantization.py b/tests/e2e/test_quantization.py
index 8b7b6701c..6bbc34949 100644
--- a/tests/e2e/test_quantization.py
+++ b/tests/e2e/test_quantization.py
@@ -5,15 +5,20 @@ Tests for axolotl.utils.quantization
 import pytest
 import torch
 from torch import nn
-from torchao.prototype.qat import MXFakeQuantizeConfig
 from torchao.quantization import LinearActivationQuantizedTensor
 from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
 from torchao.quantization.qat.linear import FakeQuantizedLinear
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8DynamicActivationInt4WeightConfig,
-    Int8DynamicActivationInt4WeightConfig,
 )
+
+try:
+    from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
+except ImportError:
+    from torchao.quantization.quant_api import (
+        Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
+    )
 from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
 from transformers import AutoModelForCausalLM
 from transformers.trainer_callback import TrainerState
@@ -129,8 +134,11 @@ class TestQuantization:
     @require_torch_2_8_0
     @requires_sm_ge_100
     def test_get_ptq_config_mxfp4(self):
+        from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
+
         config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)
-        assert isinstance(config, MXFakeQuantizeConfig)
+        assert isinstance(config, MXDynamicActivationMXWeightConfig)
+        assert config.weight_dtype == torch.float4_e2m1fn_x2
         assert config.block_size == 32
 
     @require_torch_2_8_0
@@ -298,7 +306,6 @@ class TestQuantization:
         "weight_dtype,activation_dtype,group_size,quantize_embedding",
         [
             (TorchAOQuantDType.mxfp4, None, 32, False),
-            (TorchAOQuantDType.mxfp4, None, 32, True),
         ],
     )
     @require_torch_2_8_0
@@ -314,14 +321,16 @@ class TestQuantization:
             quantize_embedding,
         )
 
+        from torchao.prototype.qat import MXFakeQuantizedLinear
+
         if quantize_embedding:
            assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
            assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
 
         for child in list(model.children()):
             if isinstance(child, torch.nn.Linear):
-                assert isinstance(child, FakeQuantizedLinear)
-                assert hasattr(child, "weight_fake_quantizer")
+                assert isinstance(child, MXFakeQuantizedLinear)
+                assert hasattr(child, "weight_config")
 
     @require_torch_2_8_0
     @requires_cuda_ge_8_9
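
Usage sketch (illustration only, not part of the patch): the snippet below mirrors the MXFP4 PTQ save path the patch wires up, assuming torch >= 2.8 and a torchao build that ships MXDynamicActivationMXWeightConfig; the model id is a placeholder, and only APIs already referenced in the diff are used.

    import torch
    from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
    from torchao.quantization import quantize_
    from transformers import AutoModelForCausalLM, TorchAoConfig
    from transformers.quantizers.quantizer_torchao import TorchAoHfQuantizer

    # Placeholder checkpoint; any small causal LM is enough for a smoke test.
    model = AutoModelForCausalLM.from_pretrained(
        "my-org/my-tiny-llm", torch_dtype=torch.bfloat16
    )

    # PTQ config matching get_quantization_config(mxfp4, None, group_size=32).
    ptq_config = MXDynamicActivationMXWeightConfig(
        activation_dtype=torch.float4_e2m1fn_x2,
        weight_dtype=torch.float4_e2m1fn_x2,
        block_size=32,
    )
    quantize_(model, ptq_config)

    # Same wiring as _attach_torchao_quantizer: an hf_quantizer makes
    # save_pretrained take torchao's quantized safetensors serialization path.
    ao_config = TorchAoConfig(quant_type=ptq_config)
    model.config.quantization_config = ao_config
    model.hf_quantizer = TorchAoHfQuantizer(ao_config)

    try:
        model.save_pretrained("quantized")
    except NotImplementedError:
        # Fallback the CLI uses when weight conversion lacks reverse_op.
        model.save_pretrained("quantized", save_original_format=False)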