add: support mxfp4 axo (#3375)

* mxfp4 axo

* import lint

* test for qat mxfp4

* config for mxfp4

* add qat:

* pass base config

* MXFakeQuantizeConfig

* lint

* tune config so it fits in 32GB VRAM

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Author: VED
Date: 2026-03-06 00:10:45 +05:30
Committed by: GitHub
Parent: 4b8bc52424
Commit: 1eaf4d7418
6 changed files with 181 additions and 2 deletions
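Background (not part of the diff): MXFP4 is the OCP microscaling 4-bit format, blocks of 32 FP4 (E2M1) elements sharing one power-of-two (E8M0) scale, which is why the new tests below insist on a group size of 32. A simplified, illustrative fake-quantization sketch (not the torchao implementation):

import torch

# Representable magnitudes of one FP4 (E2M1) element; sign is handled separately.
FP4_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
BLOCK_SIZE = 32  # the OCP MX spec fixes the scaling block at 32 elements

def fake_quantize_mxfp4(x: torch.Tensor) -> torch.Tensor:
    """Round each 32-element block to FP4 values under a shared power-of-two
    (E8M0-style) scale. Simplified: no special-value or overflow handling."""
    blocks = x.reshape(-1, BLOCK_SIZE)
    amax = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    # shared per-block scale: a power of two derived from the block max
    # (2 is the maximum exponent of an E2M1 element)
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 2)
    scaled = blocks / scale
    # snap each magnitude to the nearest representable FP4 value (saturates at 6.0)
    idx = (scaled.abs().unsqueeze(-1) - FP4_E2M1_VALUES).abs().argmin(dim=-1)
    quant = FP4_E2M1_VALUES[idx] * scaled.sign()
    return (quant * scale).reshape(x.shape)

weights = torch.randn(4, 64)
print(fake_quantize_mxfp4(weights)[0, :8])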


@@ -5,6 +5,7 @@ Tests for axolotl.utils.quantization
import pytest
import torch
from torch import nn
from torchao.prototype.qat import MXFakeQuantizeConfig
from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
@@ -117,6 +118,21 @@ class TestQuantization:
        config = get_quantization_config(weight_dtype, activation_dtype, group_size)
        assert isinstance(config, expected_type)

    @require_torch_2_8_0
    @requires_sm_ge_100
    def test_get_ptq_config_mxfp4(self):
        config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)
        assert isinstance(config, MXFakeQuantizeConfig)
        assert config.block_size == 32

    @require_torch_2_8_0
    @requires_sm_ge_100
    def test_get_ptq_config_mxfp4_invalid_group_size(self):
        with pytest.raises(
            ValueError, match="MXFP4 quantization must use a block_size"
        ):
            get_quantization_config(TorchAOQuantDType.mxfp4, None, 16)

    @requires_cuda_ge_8_9
    @require_torch_2_8_0
    def test_get_ptq_config_int4_weight_only(self):
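The two new tests above pin down the mxfp4 branch of get_quantization_config: a group size other than 32 is rejected, and a valid request yields a torchao MXFakeQuantizeConfig with that block size. A minimal sketch of that behavior, assuming a free-standing helper; the real implementation in axolotl.utils.quantization and the MXFakeQuantizeConfig constructor arguments (not shown in the diff) may differ:

from torchao.prototype.qat import MXFakeQuantizeConfig

MX_BLOCK_SIZE = 32  # MX formats scale in fixed blocks of 32 elements

def mxfp4_qat_config(group_size: int) -> MXFakeQuantizeConfig:
    if group_size != MX_BLOCK_SIZE:
        # message prefix matches what test_get_ptq_config_mxfp4_invalid_group_size expects
        raise ValueError(
            f"MXFP4 quantization must use a block_size of {MX_BLOCK_SIZE}, got {group_size}"
        )
    # constructor arguments are an assumption; check the torchao prototype API
    return MXFakeQuantizeConfig(block_size=group_size)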
@@ -262,6 +278,35 @@ class TestQuantization:
                else:
                    assert child.activation_fake_quantizer is None

    @pytest.mark.parametrize(
        "weight_dtype,activation_dtype,group_size,quantize_embedding",
        [
            (TorchAOQuantDType.mxfp4, None, 32, False),
            (TorchAOQuantDType.mxfp4, None, 32, True),
        ],
    )
    @require_torch_2_8_0
    @requires_sm_ge_100
    def test_prepare_model_for_qat_mxfp4(
        self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
    ):
        prepare_model_for_qat(
            model,
            weight_dtype,
            group_size,
            activation_dtype,
            quantize_embedding,
        )
        if quantize_embedding:
            assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
            assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
        for child in list(model.children()):
            if isinstance(child, torch.nn.Linear):
                assert isinstance(child, FakeQuantizedLinear)
                assert hasattr(child, "weight_fake_quantizer")

    @require_torch_2_8_0
    @requires_cuda_ge_8_9
    def test_convert_qat_model(self, model):
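Outside the test suite, the same preparation call could look like the sketch below. The call signature mirrors the parametrized test above; the import paths and the toy model are assumptions standing in for the test's model fixture:

import torch
from torch import nn
from torchao.quantization.qat.linear import FakeQuantizedLinear

from axolotl.utils.quantization import prepare_model_for_qat  # assumed path
from axolotl.utils.schemas.enums import TorchAOQuantDType     # assumed path

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 16))

prepare_model_for_qat(
    model,
    TorchAOQuantDType.mxfp4,  # weight dtype
    32,                       # group size; MXFP4 requires a block size of 32
    None,                     # no separate activation fake-quant dtype
    False,                    # no embedding layer to quantize in this toy model
)

# every nn.Linear should now be swapped for its fake-quantized counterpart
for child in model.children():
    if isinstance(child, nn.Linear):
        assert isinstance(child, FakeQuantizedLinear)
        assert hasattr(child, "weight_fake_quantizer")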