diff --git a/examples/llama-3/3b-qat-mxfp4.yaml b/examples/llama-3/3b-qat-mxfp4.yaml
new file mode 100644
index 000000000..7ae941e9e
--- /dev/null
+++ b/examples/llama-3/3b-qat-mxfp4.yaml
@@ -0,0 +1,65 @@
+base_model: meta-llama/Llama-3.2-3B
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+plugins:
+  - axolotl.integrations.liger.LigerPlugin
+
+liger_rope: true
+liger_rms_norm: true
+liger_glu_activation: true
+liger_layer_norm: true
+liger_fused_linear_cross_entropy: true
+
+datasets:
+  - path: yahma/alpaca-cleaned
+    type: alpaca
+    split: train[:95%]
+
+output_dir: ./outputs/qat_out/
+dataset_prepared_path: ./outputs/dataset_prepared
+
+sequence_len: 2048
+flash_attention: true
+
+qat:
+  activation_dtype: mxfp4
+  weight_dtype: mxfp4
+  group_size: 32
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_checkpointing: true
+activation_offloading: true
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_torch_8bit
+
+cosine_constant_lr_ratio: 0
+cosine_min_lr_ratio: 1.0
+learning_rate: 2e-5
+save_only_model: true
+bf16: true
+
+resume_from_checkpoint:
+logging_steps: 1
+
+evals_per_epoch: 1
+saves_per_epoch: 1
+
+warmup_ratio: 0.1
+weight_decay: 0.0
+
+special_tokens:
+  pad_token: <|finetune_right_pad_id|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
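The example config above enables MXFP4 quantization-aware training via its `qat:` block; the remaining keys are ordinary training settings. A minimal sketch of how those keys map onto axolotl's schema, mirroring `test_qat_config_with_mxfp4` added later in this patch:

```python
# Minimal sketch, assuming the QATConfig schema mirrors the `qat:` YAML keys
# (this replays test_qat_config_with_mxfp4 from the tests below).
from axolotl.utils.schemas.enums import TorchAOQuantDType
from axolotl.utils.schemas.quantization import QATConfig

config = QATConfig(
    weight_dtype="mxfp4",
    group_size=32,
    quantize_embedding=False,
)
assert config.weight_dtype == TorchAOQuantDType.mxfp4
assert config.group_size == 32
```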
diff --git a/src/axolotl/utils/quantization.py b/src/axolotl/utils/quantization.py
index 6c29a5442..43af858b1 100644
--- a/src/axolotl/utils/quantization.py
+++ b/src/axolotl/utils/quantization.py
@@ -5,6 +5,7 @@ Utilities for quantization including QAT and PTQ using torchao.
 import torch
 from packaging import version
 from torchao.core.config import AOBaseConfig
+from torchao.prototype.qat import MXFakeQuantizeConfig
 from torchao.quantization import quantize_
 from torchao.quantization.qat import (
     QATConfig,
@@ -40,6 +41,13 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"):
     except:
         pass
 
+    try:
+        from torchao.prototype.qat import MXFakeQuantizeConfig
+
+        quantization_config_to_str[MXFakeQuantizeConfig] = "mxfp4"
+    except ImportError:
+        pass
+
 
 def get_quantization_config(
     weight_dtype: TorchAOQuantDType,
@@ -109,6 +117,19 @@ def get_quantization_config(
         if group_size is not None and group_size != 16:
             raise ValueError("NVFP4 quantization must use a group_size of 16")
         return NVFP4InferenceConfig()
+
+    if weight_dtype == TorchAOQuantDType.mxfp4:
+        from torchao.prototype.qat import MXFakeQuantizeConfig
+
+        # MXFP4 uses block_size=32 by default (vs NVFP4's 16)
+        block_size = group_size if group_size is not None else 32
+        if block_size != 32:
+            raise ValueError(
+                "MXFP4 quantization must use a block_size (group_size) of 32"
+            )
+
+        return MXFakeQuantizeConfig(dtype=torch.float4_e2m1fn_x2, block_size=block_size)
+
     raise ValueError(
         f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
     )
@@ -179,7 +200,13 @@ def prepare_model_for_qat(
         activation_dtype=activation_dtype,
         group_size=group_size,
     )
-    qat_config = QATConfig(base_config)
+    if isinstance(base_config, MXFakeQuantizeConfig):
+        qat_config = QATConfig(
+            activation_config=base_config,
+            weight_config=base_config,
+        )
+    else:
+        qat_config = QATConfig(base_config)
     quantize_(model, qat_config)
     if quantize_embedding:
         # activation fake quantization is not supported for embedding layers
@@ -188,7 +215,12 @@ def prepare_model_for_qat(
             activation_dtype=None,
             group_size=group_size,
         )
-        embedding_qat_config = QATConfig(embedding_base_config)
+        if isinstance(embedding_base_config, MXFakeQuantizeConfig):
+            embedding_qat_config = QATConfig(
+                weight_config=embedding_base_config,
+            )
+        else:
+            embedding_qat_config = QATConfig(embedding_base_config)
         quantize_(
             model,
             embedding_qat_config,
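A hedged usage sketch of the two helpers patched above. The argument order follows the calls in the tests further down in this patch; the toy `nn.Sequential` model is a hypothetical stand-in, not part of the patch:

```python
# Usage sketch, not the library's documented API; argument order follows
# the calls in tests/e2e/test_quantization.py below.
import torch.nn as nn

from axolotl.utils.quantization import get_quantization_config, prepare_model_for_qat
from axolotl.utils.schemas.enums import TorchAOQuantDType

# Build the torchao fake-quantize config for MXFP4 (block_size 32).
config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)

# Swap eligible layers for fake-quantized equivalents in place.
model = nn.Sequential(nn.Linear(64, 64))  # hypothetical toy model
prepare_model_for_qat(
    model,
    TorchAOQuantDType.mxfp4,  # weight_dtype
    32,                       # group_size
    None,                     # activation_dtype
    False,                    # quantize_embedding
)
```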
diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py
index b67888e0f..893f23288 100644
--- a/src/axolotl/utils/schemas/enums.py
+++ b/src/axolotl/utils/schemas/enums.py
@@ -10,6 +10,7 @@ class TorchAOQuantDType(Enum):
     int8 = torch.int8
     float8_e4m3fn = torch.float8_e4m3fn
     nvfp4 = "nvfp4"
+    mxfp4 = "mxfp4"
 
     def from_string(str):
         if str == "int4":
@@ -20,6 +21,8 @@ class TorchAOQuantDType(Enum):
             return TorchAOQuantDType.float8_e4m3fn
         if str == "nvfp4":
             return TorchAOQuantDType.nvfp4
+        if str == "mxfp4":
+            return TorchAOQuantDType.mxfp4
 
 
 class RLType(str, Enum):
diff --git a/src/axolotl/utils/schemas/quantization.py b/src/axolotl/utils/schemas/quantization.py
index a7c130574..b15e5d225 100644
--- a/src/axolotl/utils/schemas/quantization.py
+++ b/src/axolotl/utils/schemas/quantization.py
@@ -20,6 +20,9 @@ def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
         return TorchAOQuantDType.float8_e4m3fn
     if v == "nvfp4":
         return TorchAOQuantDType.nvfp4
+    if v == "mxfp4":
+        return TorchAOQuantDType.mxfp4
+
     raise ValueError(
         f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}"
     )
diff --git a/tests/e2e/test_qat.py b/tests/e2e/test_qat.py
index 5cfbc8553..251d5b17b 100644
--- a/tests/e2e/test_qat.py
+++ b/tests/e2e/test_qat.py
@@ -8,6 +8,8 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.schemas.enums import TorchAOQuantDType
+from axolotl.utils.schemas.quantization import QATConfig, validate_ao_dtype
 
 from .utils import check_model_output_exists, check_tensorboard
 
@@ -130,3 +132,32 @@ class TestQATLlama:
             loss_threshold,
             "Train Loss (%s) is too high",
         )
+
+
+class TestMXFP4Schema:
+    """Test MXFP4 schema validation"""
+
+    def test_validate_mxfp4_dtype(self):
+        result = validate_ao_dtype("mxfp4")
+        assert result == TorchAOQuantDType.mxfp4
+
+    def test_qat_config_with_mxfp4(self):
+        """Test QATConfig accepts mxfp4 weight_dtype"""
+        config = QATConfig(
+            weight_dtype="mxfp4",
+            group_size=32,
+            quantize_embedding=False,
+        )
+        assert config.weight_dtype == TorchAOQuantDType.mxfp4
+        assert config.group_size == 32
+
+    def test_qat_config_mxfp4_invalid_group_size(self):
+        """Test that an invalid group_size passes schema validation but is rejected at runtime"""
+        # Note: schema validation doesn't check group_size compatibility;
+        # that happens in get_quantization_config
+        config = QATConfig(
+            weight_dtype="mxfp4",
+            group_size=16,  # Invalid for mxfp4, but the schema allows it
+        )
+        assert config.group_size == 16  # Schema accepts it
+        # Actual validation happens at runtime in get_quantization_config
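The last test above documents a schema-versus-runtime split. Sketched directly (this only replays behaviour asserted elsewhere in the patch, it is not a new API):

```python
# Sketch of the schema-vs-runtime split: QATConfig accepts group_size=16
# for mxfp4, but building the torchao config rejects it.
from axolotl.utils.quantization import get_quantization_config
from axolotl.utils.schemas.enums import TorchAOQuantDType

try:
    get_quantization_config(TorchAOQuantDType.mxfp4, None, 16)
except ValueError as err:
    print(err)  # MXFP4 quantization must use a block_size (group_size) of 32
```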
diff --git a/tests/e2e/test_quantization.py b/tests/e2e/test_quantization.py
index 706279c6c..371ffb659 100644
--- a/tests/e2e/test_quantization.py
+++ b/tests/e2e/test_quantization.py
@@ -5,6 +5,7 @@ Tests for axolotl.utils.quantization
 import pytest
 import torch
 from torch import nn
+from torchao.prototype.qat import MXFakeQuantizeConfig
 from torchao.quantization import LinearActivationQuantizedTensor
 from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
 from torchao.quantization.qat.linear import FakeQuantizedLinear
@@ -117,6 +118,21 @@ class TestQuantization:
         config = get_quantization_config(weight_dtype, activation_dtype, group_size)
         assert isinstance(config, expected_type)
 
+    @require_torch_2_8_0
+    @requires_sm_ge_100
+    def test_get_ptq_config_mxfp4(self):
+        config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)
+        assert isinstance(config, MXFakeQuantizeConfig)
+        assert config.block_size == 32
+
+    @require_torch_2_8_0
+    @requires_sm_ge_100
+    def test_get_ptq_config_mxfp4_invalid_group_size(self):
+        with pytest.raises(
+            ValueError, match="MXFP4 quantization must use a block_size"
+        ):
+            get_quantization_config(TorchAOQuantDType.mxfp4, None, 16)
+
     @requires_cuda_ge_8_9
     @require_torch_2_8_0
     def test_get_ptq_config_int4_weight_only(self):
@@ -262,6 +278,35 @@ class TestQuantization:
         else:
             assert child.activation_fake_quantizer is None
 
+    @pytest.mark.parametrize(
+        "weight_dtype,activation_dtype,group_size,quantize_embedding",
+        [
+            (TorchAOQuantDType.mxfp4, None, 32, False),
+            (TorchAOQuantDType.mxfp4, None, 32, True),
+        ],
+    )
+    @require_torch_2_8_0
+    @requires_sm_ge_100
+    def test_prepare_model_for_qat_mxfp4(
+        self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
+    ):
+        prepare_model_for_qat(
+            model,
+            weight_dtype,
+            group_size,
+            activation_dtype,
+            quantize_embedding,
+        )
+
+        if quantize_embedding:
+            assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
+            assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
+
+        for child in list(model.children()):
+            if isinstance(child, torch.nn.Linear):
+                assert isinstance(child, FakeQuantizedLinear)
+                assert hasattr(child, "weight_fake_quantizer")
+
     @require_torch_2_8_0
     @requires_cuda_ge_8_9
     def test_convert_qat_model(self, model):
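For reference, a minimal sketch of the torchao objects the new code path constructs for MXFP4 QAT, assuming torch >= 2.8 and a torchao build that ships the MX QAT prototype (per the version guards in the patch):

```python
# Sketch of the torchao configs the patch builds for MXFP4 QAT;
# assumes torch >= 2.8 and torchao with the MX QAT prototype.
import torch
from torchao.prototype.qat import MXFakeQuantizeConfig
from torchao.quantization.qat import QATConfig

mx_config = MXFakeQuantizeConfig(dtype=torch.float4_e2m1fn_x2, block_size=32)

# Linear layers fake-quantize both activations and weights...
linear_qat = QATConfig(activation_config=mx_config, weight_config=mx_config)
# ...while embedding layers fake-quantize weights only.
embedding_qat = QATConfig(weight_config=mx_config)
```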