add: support mxfp4 axo (#3375)

* mxfp4 axo

* import lint

* test for qat mxfp4

* config for mxfp4

* add qat

* pass base config

* MXFakeQuantizeConfig

* lint

* tune config so it fits in 32GB VRAM

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Author: VED
Date: 2026-03-06 00:10:45 +05:30
Committed by: GitHub
Parent: 4b8bc52424
Commit: 1eaf4d7418
6 changed files with 181 additions and 2 deletions


@@ -5,6 +5,7 @@ Utilities for quantization including QAT and PTQ using torchao.
 import torch
 from packaging import version
 from torchao.core.config import AOBaseConfig
+from torchao.prototype.qat import MXFakeQuantizeConfig
 from torchao.quantization import quantize_
 from torchao.quantization.qat import (
     QATConfig,
@@ -40,6 +41,13 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"):
     except:
         pass
+    try:
+        from torchao.prototype.qat import MXFakeQuantizeConfig
+
+        quantization_config_to_str[MXFakeQuantizeConfig] = "mxfp4"
+    except ImportError:
+        pass


 def get_quantization_config(
     weight_dtype: TorchAOQuantDType,
@@ -109,6 +117,19 @@ def get_quantization_config(
         if group_size is not None and group_size != 16:
             raise ValueError("NVFP4 quantization must use a group_size of 16")
         return NVFP4InferenceConfig()
+    if weight_dtype == TorchAOQuantDType.mxfp4:
+        from torchao.prototype.qat import MXFakeQuantizeConfig
+
+        # MXFP4 uses block_size=32 by default (vs NVFP4's 16)
+        block_size = group_size if group_size is not None else 32
+        if block_size != 32:
+            raise ValueError(
+                "MXFP4 quantization must use a block_size (group_size) of 32"
+            )
+        return MXFakeQuantizeConfig(
+            dtype=torch.float4_e2m1fn_x2, block_size=block_size
+        )
     raise ValueError(
         f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
     )
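
For orientation, a minimal sketch of how this new branch behaves when called directly. Only the function name, the enum member, and the group_size/weight_dtype keywords come from this diff; the import paths and passing activation_dtype=None are assumptions.

# Sketch only: exercising the mxfp4 branch of get_quantization_config.
from axolotl.utils.quantization import get_quantization_config  # assumed path
from axolotl.utils.schemas.enums import TorchAOQuantDType  # assumed path

# group_size defaults to 32 for MXFP4; any other value raises.
config = get_quantization_config(
    weight_dtype=TorchAOQuantDType.mxfp4,
    activation_dtype=None,  # assumption: weight-only fake quantization
    group_size=32,
)

# Passing the NVFP4-style group size is rejected:
#   get_quantization_config(..., group_size=16)
#   -> ValueError: MXFP4 quantization must use a block_size (group_size) of 32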
@@ -179,7 +200,13 @@ def prepare_model_for_qat(
         activation_dtype=activation_dtype,
         group_size=group_size,
     )
-    qat_config = QATConfig(base_config)
+    if isinstance(base_config, MXFakeQuantizeConfig):
+        qat_config = QATConfig(
+            activation_config=base_config,
+            weight_config=base_config,
+        )
+    else:
+        qat_config = QATConfig(base_config)
     quantize_(model, qat_config)
     if quantize_embedding:
         # activation fake quantization is not supported for embedding layers
@@ -188,7 +215,12 @@ def prepare_model_for_qat(
             activation_dtype=None,
             group_size=group_size,
         )
-        embedding_qat_config = QATConfig(embedding_base_config)
+        if isinstance(embedding_base_config, MXFakeQuantizeConfig):
+            embedding_qat_config = QATConfig(
+                weight_config=embedding_base_config,
+            )
+        else:
+            embedding_qat_config = QATConfig(embedding_base_config)
         quantize_(
             model,
             embedding_qat_config,
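
Reading the branches above: MXFakeQuantizeConfig evidently cannot be wrapped as a single positional base config, so QATConfig receives it as explicit activation/weight configs, and embeddings get a weight-only config. A standalone sketch of the same wiring, assuming a torchao build that ships the prototype MX QAT API and a PyTorch that exposes torch.float4_e2m1fn_x2:

import torch
from torchao.prototype.qat import MXFakeQuantizeConfig
from torchao.quantization.qat import QATConfig

mx = MXFakeQuantizeConfig(dtype=torch.float4_e2m1fn_x2, block_size=32)

# Linear layers: fake-quantize both activations and weights with the MX config.
linear_qat_config = QATConfig(activation_config=mx, weight_config=mx)

# Embedding layers: activation fake quantization is not supported, so weights only.
embedding_qat_config = QATConfig(weight_config=mx)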


@@ -10,6 +10,7 @@ class TorchAOQuantDType(Enum):
     int8 = torch.int8
     float8_e4m3fn = torch.float8_e4m3fn
     nvfp4 = "nvfp4"
+    mxfp4 = "mxfp4"

     def from_string(str):
         if str == "int4":
@@ -20,6 +21,8 @@ class TorchAOQuantDType(Enum):
             return TorchAOQuantDType.float8_e4m3fn
         if str == "nvfp4":
             return TorchAOQuantDType.nvfp4
+        if str == "mxfp4":
+            return TorchAOQuantDType.mxfp4


 class RLType(str, Enum):
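
A small usage sketch for the new enum value; it assumes from_string can be called on the class as shown (any decorator sits above this hunk and is not visible here):

# Sketch: the "mxfp4" string now round-trips through the enum.
dtype = TorchAOQuantDType.from_string("mxfp4")
assert dtype is TorchAOQuantDType.mxfp4
assert dtype.value == "mxfp4"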


@@ -20,6 +20,9 @@ def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
         return TorchAOQuantDType.float8_e4m3fn
     if v == "nvfp4":
         return TorchAOQuantDType.nvfp4
+    if v == "mxfp4":
+        return TorchAOQuantDType.mxfp4
+
     raise ValueError(
         f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}"
     )
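
And a sketch of the validator's behavior with the new value; only validate_ao_dtype and TorchAOQuantDType are taken from this diff, the import paths below are hypothetical:

from axolotl.utils.schemas.enums import TorchAOQuantDType  # assumed path
from axolotl.utils.schemas.validation import validate_ao_dtype  # hypothetical path

assert validate_ao_dtype("mxfp4") is TorchAOQuantDType.mxfp4
assert validate_ao_dtype("nvfp4") is TorchAOQuantDType.nvfp4

try:
    validate_ao_dtype("int3")  # not a supported dtype string
except ValueError as err:
    print(err)  # lists the valid TorchAOQuantDType names plus the fp8/float8 aliases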