add: support mxfp4 axo (#3375)

* mxfp4 axo

* import lint

* test for qat mxfp4

* config for mxfp4

* add qat:

* pass base config

* MXFakeQuantizeConfig

* lint

* tune config so it fits in 32GB VRAM

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
VED
2026-03-06 00:10:45 +05:30
committed by GitHub
parent 4b8bc52424
commit 1eaf4d7418
6 changed files with 181 additions and 2 deletions

View File

@@ -0,0 +1,65 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
split: train[:95%]
output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/dataset_prepared
sequence_len: 2048
flash_attention: true
qat:
activation_dtype: mxfp4
weight_dtype: mxfp4
group_size: 32
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_checkpointing: true
activation_offloading: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -5,6 +5,7 @@ Utilities for quantization including QAT and PTQ using torchao.
import torch import torch
from packaging import version from packaging import version
from torchao.core.config import AOBaseConfig from torchao.core.config import AOBaseConfig
from torchao.prototype.qat import MXFakeQuantizeConfig
from torchao.quantization import quantize_ from torchao.quantization import quantize_
from torchao.quantization.qat import ( from torchao.quantization.qat import (
QATConfig, QATConfig,
@@ -40,6 +41,13 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"):
except: except:
pass pass
try:
from torchao.prototype.qat import MXFakeQuantizeConfig
quantization_config_to_str[MXFakeQuantizeConfig] = "mxfp4"
except ImportError:
pass
def get_quantization_config( def get_quantization_config(
weight_dtype: TorchAOQuantDType, weight_dtype: TorchAOQuantDType,
@@ -109,6 +117,19 @@ def get_quantization_config(
if group_size is not None and group_size != 16: if group_size is not None and group_size != 16:
raise ValueError("NVFP4 quantization must use a group_size of 16") raise ValueError("NVFP4 quantization must use a group_size of 16")
return NVFP4InferenceConfig() return NVFP4InferenceConfig()
if weight_dtype == TorchAOQuantDType.mxfp4:
from torchao.prototype.qat import MXFakeQuantizeConfig
# MXFP4 uses block_size=32 by default (vs NVFP4's 16)
block_size = group_size if group_size is not None else 32
if block_size != 32:
raise ValueError(
"MXFP4 quantization must use a block_size (group_size) of 32"
)
return MXFakeQuantizeConfig(dtype=torch.float4_e2m1fn_x2, block_size=block_size)
raise ValueError( raise ValueError(
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}" f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
) )
@@ -179,7 +200,13 @@ def prepare_model_for_qat(
activation_dtype=activation_dtype, activation_dtype=activation_dtype,
group_size=group_size, group_size=group_size,
) )
qat_config = QATConfig(base_config) if isinstance(base_config, MXFakeQuantizeConfig):
qat_config = QATConfig(
activation_config=base_config,
weight_config=base_config,
)
else:
qat_config = QATConfig(base_config)
quantize_(model, qat_config) quantize_(model, qat_config)
if quantize_embedding: if quantize_embedding:
# activation fake quantization is not supported for embedding layers # activation fake quantization is not supported for embedding layers
@@ -188,7 +215,12 @@ def prepare_model_for_qat(
activation_dtype=None, activation_dtype=None,
group_size=group_size, group_size=group_size,
) )
embedding_qat_config = QATConfig(embedding_base_config) if isinstance(embedding_base_config, MXFakeQuantizeConfig):
embedding_qat_config = QATConfig(
weight_config=embedding_base_config,
)
else:
embedding_qat_config = QATConfig(embedding_base_config)
quantize_( quantize_(
model, model,
embedding_qat_config, embedding_qat_config,

View File

@@ -10,6 +10,7 @@ class TorchAOQuantDType(Enum):
int8 = torch.int8 int8 = torch.int8
float8_e4m3fn = torch.float8_e4m3fn float8_e4m3fn = torch.float8_e4m3fn
nvfp4 = "nvfp4" nvfp4 = "nvfp4"
mxfp4 = "mxfp4"
def from_string(str): def from_string(str):
if str == "int4": if str == "int4":
@@ -20,6 +21,8 @@ class TorchAOQuantDType(Enum):
return TorchAOQuantDType.float8_e4m3fn return TorchAOQuantDType.float8_e4m3fn
if str == "nvfp4": if str == "nvfp4":
return TorchAOQuantDType.nvfp4 return TorchAOQuantDType.nvfp4
if str == "mxfp4":
return TorchAOQuantDType.mxfp4
class RLType(str, Enum): class RLType(str, Enum):

View File

@@ -20,6 +20,9 @@ def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
return TorchAOQuantDType.float8_e4m3fn return TorchAOQuantDType.float8_e4m3fn
if v == "nvfp4": if v == "nvfp4":
return TorchAOQuantDType.nvfp4 return TorchAOQuantDType.nvfp4
if v == "mxfp4":
return TorchAOQuantDType.mxfp4
raise ValueError( raise ValueError(
f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}" f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}"
) )

View File

@@ -8,6 +8,8 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import TorchAOQuantDType
from axolotl.utils.schemas.quantization import QATConfig, validate_ao_dtype
from .utils import check_model_output_exists, check_tensorboard from .utils import check_model_output_exists, check_tensorboard
@@ -130,3 +132,32 @@ class TestQATLlama:
loss_threshold, loss_threshold,
"Train Loss (%s) is too high", "Train Loss (%s) is too high",
) )
class TestMXFP4Schema:
"""Test MXFP4 schema validation"""
def test_validate_mxfp4_dtype(self):
result = validate_ao_dtype("mxfp4")
assert result == TorchAOQuantDType.mxfp4
def test_qat_config_with_mxfp4(self):
"""Test QATConfig accepts mxfp4 weight_dtype"""
config = QATConfig(
weight_dtype="mxfp4",
group_size=32,
quantize_embedding=False,
)
assert config.weight_dtype == TorchAOQuantDType.mxfp4
assert config.group_size == 32
def test_qat_config_mxfp4_invalid_group_size(self):
"""Test that invalid group_size raises appropriate error during quantization"""
# Note: Schema validation doesn't check group_size compatibility,
# that happens in get_quantization_config
config = QATConfig(
weight_dtype="mxfp4",
group_size=16, # Invalid for mxfp4, but schema allows it
)
assert config.group_size == 16 # Schema accepts it
# Actual validation happens at runtime in get_quantization_config

View File

@@ -5,6 +5,7 @@ Tests for axolotl.utils.quantization
import pytest import pytest
import torch import torch
from torch import nn from torch import nn
from torchao.prototype.qat import MXFakeQuantizeConfig
from torchao.quantization import LinearActivationQuantizedTensor from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear from torchao.quantization.qat.linear import FakeQuantizedLinear
@@ -117,6 +118,21 @@ class TestQuantization:
config = get_quantization_config(weight_dtype, activation_dtype, group_size) config = get_quantization_config(weight_dtype, activation_dtype, group_size)
assert isinstance(config, expected_type) assert isinstance(config, expected_type)
@require_torch_2_8_0
@requires_sm_ge_100
def test_get_ptq_config_mxfp4(self):
config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)
assert isinstance(config, MXFakeQuantizeConfig)
assert config.block_size == 32
@require_torch_2_8_0
@requires_sm_ge_100
def test_get_ptq_config_mxfp4_invalid_group_size(self):
with pytest.raises(
ValueError, match="MXFP4 quantization must use a block_size"
):
get_quantization_config(TorchAOQuantDType.mxfp4, None, 16)
@requires_cuda_ge_8_9 @requires_cuda_ge_8_9
@require_torch_2_8_0 @require_torch_2_8_0
def test_get_ptq_config_int4_weight_only(self): def test_get_ptq_config_int4_weight_only(self):
@@ -262,6 +278,35 @@ class TestQuantization:
else: else:
assert child.activation_fake_quantizer is None assert child.activation_fake_quantizer is None
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,quantize_embedding",
[
(TorchAOQuantDType.mxfp4, None, 32, False),
(TorchAOQuantDType.mxfp4, None, 32, True),
],
)
@require_torch_2_8_0
@requires_sm_ge_100
def test_prepare_model_for_qat_mxfp4(
self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
):
prepare_model_for_qat(
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
)
if quantize_embedding:
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child, FakeQuantizedLinear)
assert hasattr(child, "weight_fake_quantizer")
@require_torch_2_8_0 @require_torch_2_8_0
@requires_cuda_ge_8_9 @requires_cuda_ge_8_9
def test_convert_qat_model(self, model): def test_convert_qat_model(self, model):