add: support mxfp4 axo (#3375)
* mxfp4 axo
* import lint
* test for qat mxfp4
* config for mxfp4
* add qat:
* pass base config
* MXFakeQuantizeConfig
* lint
* tune config so it fits in 32GB VRAM

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
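As a rough illustration of the flow these changes enable, here is a minimal sketch built only from the calls exercised in the tests below. The module path axolotl.utils.quantization is taken from the test file's docstring; the availability of TorchAOQuantDType, get_quantization_config, and prepare_model_for_qat under that exact path, and the toy model, are assumptions. Per the test markers, this needs torch >= 2.8.0 and an SM 10.0+ (Blackwell) GPU.

    # Hedged sketch; import path assumed from the test module's docstring.
    from torch import nn

    from axolotl.utils.quantization import (
        TorchAOQuantDType,
        get_quantization_config,
        prepare_model_for_qat,
    )

    # MXFP4 requires a block size of 32; other values raise ValueError
    # (see test_get_ptq_config_mxfp4_invalid_group_size below).
    config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)
    assert config.block_size == 32

    # Toy stand-in for the test suite's `model` fixture (assumption).
    model = nn.Sequential(nn.Linear(64, 64))

    # Swaps nn.Linear modules for FakeQuantizedLinear in place; argument
    # order mirrors the call in test_prepare_model_for_qat_mxfp4.
    prepare_model_for_qat(
        model,
        TorchAOQuantDType.mxfp4,  # weight_dtype
        32,                       # group_size
        None,                     # activation_dtype
        False,                    # quantize_embedding
    )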
@@ -5,6 +5,7 @@ Tests for axolotl.utils.quantization
 import pytest
 import torch
 from torch import nn
+from torchao.prototype.qat import MXFakeQuantizeConfig
 from torchao.quantization import LinearActivationQuantizedTensor
 from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
 from torchao.quantization.qat.linear import FakeQuantizedLinear
@@ -117,6 +118,21 @@ class TestQuantization:
         config = get_quantization_config(weight_dtype, activation_dtype, group_size)
         assert isinstance(config, expected_type)
 
+    @require_torch_2_8_0
+    @requires_sm_ge_100
+    def test_get_ptq_config_mxfp4(self):
+        config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)
+        assert isinstance(config, MXFakeQuantizeConfig)
+        assert config.block_size == 32
+
+    @require_torch_2_8_0
+    @requires_sm_ge_100
+    def test_get_ptq_config_mxfp4_invalid_group_size(self):
+        with pytest.raises(
+            ValueError, match="MXFP4 quantization must use a block_size"
+        ):
+            get_quantization_config(TorchAOQuantDType.mxfp4, None, 16)
+
     @requires_cuda_ge_8_9
     @require_torch_2_8_0
     def test_get_ptq_config_int4_weight_only(self):
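The mxfp4 tests here are gated behind a requires_sm_ge_100 marker whose definition is not shown in this diff. As a hedged sketch (the repo's actual helper may differ), such a gate could be built from torch.cuda.get_device_capability:

    # Hypothetical reimplementation of the SM >= 10.0 skip marker.
    import pytest
    import torch

    def _sm_ge_100() -> bool:
        # get_device_capability returns (major, minor), e.g. (10, 0) on Blackwell.
        if not torch.cuda.is_available():
            return False
        return torch.cuda.get_device_capability() >= (10, 0)

    requires_sm_ge_100 = pytest.mark.skipif(
        not _sm_ge_100(),
        reason="MXFP4 kernels need an SM 10.0+ (Blackwell) GPU",
    )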
@@ -262,6 +278,35 @@ class TestQuantization:
             else:
                 assert child.activation_fake_quantizer is None
 
+    @pytest.mark.parametrize(
+        "weight_dtype,activation_dtype,group_size,quantize_embedding",
+        [
+            (TorchAOQuantDType.mxfp4, None, 32, False),
+            (TorchAOQuantDType.mxfp4, None, 32, True),
+        ],
+    )
+    @require_torch_2_8_0
+    @requires_sm_ge_100
+    def test_prepare_model_for_qat_mxfp4(
+        self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
+    ):
+        prepare_model_for_qat(
+            model,
+            weight_dtype,
+            group_size,
+            activation_dtype,
+            quantize_embedding,
+        )
+
+        if quantize_embedding:
+            assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
+            assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
+
+        for child in list(model.children()):
+            if isinstance(child, torch.nn.Linear):
+                assert isinstance(child, FakeQuantizedLinear)
+                assert hasattr(child, "weight_fake_quantizer")
+
     @require_torch_2_8_0
     @requires_cuda_ge_8_9
     def test_convert_qat_model(self, model):
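Assuming the file lives at tests/test_quantization.py (the commit view does not show the path), the new cases can be selected by keyword:

    pytest tests/test_quantization.py -k mxfp4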