add: support mxfp4 axo (#3375)
* mxfp4 axo * import lint * test for qat mxfp4 * config for mxfp4 * add qat: * pass base config * MXFakeQuantizeConfig * lint * tune config so it fits in 32GB VRAM --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
65
examples/llama-3/3b-qat-mxfp4.yaml
Normal file
65
examples/llama-3/3b-qat-mxfp4.yaml
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
base_model: meta-llama/Llama-3.2-3B
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.liger.LigerPlugin
|
||||||
|
|
||||||
|
liger_rope: true
|
||||||
|
liger_rms_norm: true
|
||||||
|
liger_glu_activation: true
|
||||||
|
liger_layer_norm: true
|
||||||
|
liger_fused_linear_cross_entropy: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: yahma/alpaca-cleaned
|
||||||
|
type: alpaca
|
||||||
|
split: train[:95%]
|
||||||
|
|
||||||
|
output_dir: ./outputs/qat_out/
|
||||||
|
dataset_prepared_path: ./outputs/dataset_prepared
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
qat:
|
||||||
|
activation_dtype: mxfp4
|
||||||
|
weight_dtype: mxfp4
|
||||||
|
group_size: 32
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
activation_offloading: true
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_8bit
|
||||||
|
|
||||||
|
cosine_constant_lr_ratio: 0
|
||||||
|
cosine_min_lr_ratio: 1.0
|
||||||
|
learning_rate: 2e-5
|
||||||
|
save_only_model: true
|
||||||
|
bf16: true
|
||||||
|
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|finetune_right_pad_id|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
@@ -5,6 +5,7 @@ Utilities for quantization including QAT and PTQ using torchao.
|
|||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from torchao.core.config import AOBaseConfig
|
from torchao.core.config import AOBaseConfig
|
||||||
|
from torchao.prototype.qat import MXFakeQuantizeConfig
|
||||||
from torchao.quantization import quantize_
|
from torchao.quantization import quantize_
|
||||||
from torchao.quantization.qat import (
|
from torchao.quantization.qat import (
|
||||||
QATConfig,
|
QATConfig,
|
||||||
@@ -40,6 +41,13 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"):
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torchao.prototype.qat import MXFakeQuantizeConfig
|
||||||
|
|
||||||
|
quantization_config_to_str[MXFakeQuantizeConfig] = "mxfp4"
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_quantization_config(
|
def get_quantization_config(
|
||||||
weight_dtype: TorchAOQuantDType,
|
weight_dtype: TorchAOQuantDType,
|
||||||
@@ -109,6 +117,19 @@ def get_quantization_config(
|
|||||||
if group_size is not None and group_size != 16:
|
if group_size is not None and group_size != 16:
|
||||||
raise ValueError("NVFP4 quantization must use a group_size of 16")
|
raise ValueError("NVFP4 quantization must use a group_size of 16")
|
||||||
return NVFP4InferenceConfig()
|
return NVFP4InferenceConfig()
|
||||||
|
|
||||||
|
if weight_dtype == TorchAOQuantDType.mxfp4:
|
||||||
|
from torchao.prototype.qat import MXFakeQuantizeConfig
|
||||||
|
|
||||||
|
# MXFP4 uses block_size=32 by default (vs NVFP4's 16)
|
||||||
|
block_size = group_size if group_size is not None else 32
|
||||||
|
if block_size != 32:
|
||||||
|
raise ValueError(
|
||||||
|
"MXFP4 quantization must use a block_size (group_size) of 32"
|
||||||
|
)
|
||||||
|
|
||||||
|
return MXFakeQuantizeConfig(dtype=torch.float4_e2m1fn_x2, block_size=block_size)
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
|
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
|
||||||
)
|
)
|
||||||
@@ -179,7 +200,13 @@ def prepare_model_for_qat(
|
|||||||
activation_dtype=activation_dtype,
|
activation_dtype=activation_dtype,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
)
|
)
|
||||||
qat_config = QATConfig(base_config)
|
if isinstance(base_config, MXFakeQuantizeConfig):
|
||||||
|
qat_config = QATConfig(
|
||||||
|
activation_config=base_config,
|
||||||
|
weight_config=base_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
qat_config = QATConfig(base_config)
|
||||||
quantize_(model, qat_config)
|
quantize_(model, qat_config)
|
||||||
if quantize_embedding:
|
if quantize_embedding:
|
||||||
# activation fake quantization is not supported for embedding layers
|
# activation fake quantization is not supported for embedding layers
|
||||||
@@ -188,7 +215,12 @@ def prepare_model_for_qat(
|
|||||||
activation_dtype=None,
|
activation_dtype=None,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
)
|
)
|
||||||
embedding_qat_config = QATConfig(embedding_base_config)
|
if isinstance(embedding_base_config, MXFakeQuantizeConfig):
|
||||||
|
embedding_qat_config = QATConfig(
|
||||||
|
weight_config=embedding_base_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
embedding_qat_config = QATConfig(embedding_base_config)
|
||||||
quantize_(
|
quantize_(
|
||||||
model,
|
model,
|
||||||
embedding_qat_config,
|
embedding_qat_config,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ class TorchAOQuantDType(Enum):
|
|||||||
int8 = torch.int8
|
int8 = torch.int8
|
||||||
float8_e4m3fn = torch.float8_e4m3fn
|
float8_e4m3fn = torch.float8_e4m3fn
|
||||||
nvfp4 = "nvfp4"
|
nvfp4 = "nvfp4"
|
||||||
|
mxfp4 = "mxfp4"
|
||||||
|
|
||||||
def from_string(str):
|
def from_string(str):
|
||||||
if str == "int4":
|
if str == "int4":
|
||||||
@@ -20,6 +21,8 @@ class TorchAOQuantDType(Enum):
|
|||||||
return TorchAOQuantDType.float8_e4m3fn
|
return TorchAOQuantDType.float8_e4m3fn
|
||||||
if str == "nvfp4":
|
if str == "nvfp4":
|
||||||
return TorchAOQuantDType.nvfp4
|
return TorchAOQuantDType.nvfp4
|
||||||
|
if str == "mxfp4":
|
||||||
|
return TorchAOQuantDType.mxfp4
|
||||||
|
|
||||||
|
|
||||||
class RLType(str, Enum):
|
class RLType(str, Enum):
|
||||||
|
|||||||
@@ -20,6 +20,9 @@ def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
|
|||||||
return TorchAOQuantDType.float8_e4m3fn
|
return TorchAOQuantDType.float8_e4m3fn
|
||||||
if v == "nvfp4":
|
if v == "nvfp4":
|
||||||
return TorchAOQuantDType.nvfp4
|
return TorchAOQuantDType.nvfp4
|
||||||
|
if v == "mxfp4":
|
||||||
|
return TorchAOQuantDType.mxfp4
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}"
|
f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||||
|
from axolotl.utils.schemas.quantization import QATConfig, validate_ao_dtype
|
||||||
|
|
||||||
from .utils import check_model_output_exists, check_tensorboard
|
from .utils import check_model_output_exists, check_tensorboard
|
||||||
|
|
||||||
@@ -130,3 +132,32 @@ class TestQATLlama:
|
|||||||
loss_threshold,
|
loss_threshold,
|
||||||
"Train Loss (%s) is too high",
|
"Train Loss (%s) is too high",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMXFP4Schema:
|
||||||
|
"""Test MXFP4 schema validation"""
|
||||||
|
|
||||||
|
def test_validate_mxfp4_dtype(self):
|
||||||
|
result = validate_ao_dtype("mxfp4")
|
||||||
|
assert result == TorchAOQuantDType.mxfp4
|
||||||
|
|
||||||
|
def test_qat_config_with_mxfp4(self):
|
||||||
|
"""Test QATConfig accepts mxfp4 weight_dtype"""
|
||||||
|
config = QATConfig(
|
||||||
|
weight_dtype="mxfp4",
|
||||||
|
group_size=32,
|
||||||
|
quantize_embedding=False,
|
||||||
|
)
|
||||||
|
assert config.weight_dtype == TorchAOQuantDType.mxfp4
|
||||||
|
assert config.group_size == 32
|
||||||
|
|
||||||
|
def test_qat_config_mxfp4_invalid_group_size(self):
|
||||||
|
"""Test that invalid group_size raises appropriate error during quantization"""
|
||||||
|
# Note: Schema validation doesn't check group_size compatibility,
|
||||||
|
# that happens in get_quantization_config
|
||||||
|
config = QATConfig(
|
||||||
|
weight_dtype="mxfp4",
|
||||||
|
group_size=16, # Invalid for mxfp4, but schema allows it
|
||||||
|
)
|
||||||
|
assert config.group_size == 16 # Schema accepts it
|
||||||
|
# Actual validation happens at runtime in get_quantization_config
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ Tests for axolotl.utils.quantization
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torchao.prototype.qat import MXFakeQuantizeConfig
|
||||||
from torchao.quantization import LinearActivationQuantizedTensor
|
from torchao.quantization import LinearActivationQuantizedTensor
|
||||||
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
|
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
|
||||||
from torchao.quantization.qat.linear import FakeQuantizedLinear
|
from torchao.quantization.qat.linear import FakeQuantizedLinear
|
||||||
@@ -117,6 +118,21 @@ class TestQuantization:
|
|||||||
config = get_quantization_config(weight_dtype, activation_dtype, group_size)
|
config = get_quantization_config(weight_dtype, activation_dtype, group_size)
|
||||||
assert isinstance(config, expected_type)
|
assert isinstance(config, expected_type)
|
||||||
|
|
||||||
|
@require_torch_2_8_0
|
||||||
|
@requires_sm_ge_100
|
||||||
|
def test_get_ptq_config_mxfp4(self):
|
||||||
|
config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)
|
||||||
|
assert isinstance(config, MXFakeQuantizeConfig)
|
||||||
|
assert config.block_size == 32
|
||||||
|
|
||||||
|
@require_torch_2_8_0
|
||||||
|
@requires_sm_ge_100
|
||||||
|
def test_get_ptq_config_mxfp4_invalid_group_size(self):
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match="MXFP4 quantization must use a block_size"
|
||||||
|
):
|
||||||
|
get_quantization_config(TorchAOQuantDType.mxfp4, None, 16)
|
||||||
|
|
||||||
@requires_cuda_ge_8_9
|
@requires_cuda_ge_8_9
|
||||||
@require_torch_2_8_0
|
@require_torch_2_8_0
|
||||||
def test_get_ptq_config_int4_weight_only(self):
|
def test_get_ptq_config_int4_weight_only(self):
|
||||||
@@ -262,6 +278,35 @@ class TestQuantization:
|
|||||||
else:
|
else:
|
||||||
assert child.activation_fake_quantizer is None
|
assert child.activation_fake_quantizer is None
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"weight_dtype,activation_dtype,group_size,quantize_embedding",
|
||||||
|
[
|
||||||
|
(TorchAOQuantDType.mxfp4, None, 32, False),
|
||||||
|
(TorchAOQuantDType.mxfp4, None, 32, True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@require_torch_2_8_0
|
||||||
|
@requires_sm_ge_100
|
||||||
|
def test_prepare_model_for_qat_mxfp4(
|
||||||
|
self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
|
||||||
|
):
|
||||||
|
prepare_model_for_qat(
|
||||||
|
model,
|
||||||
|
weight_dtype,
|
||||||
|
group_size,
|
||||||
|
activation_dtype,
|
||||||
|
quantize_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
if quantize_embedding:
|
||||||
|
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||||
|
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
|
||||||
|
|
||||||
|
for child in list(model.children()):
|
||||||
|
if isinstance(child, torch.nn.Linear):
|
||||||
|
assert isinstance(child, FakeQuantizedLinear)
|
||||||
|
assert hasattr(child, "weight_fake_quantizer")
|
||||||
|
|
||||||
@require_torch_2_8_0
|
@require_torch_2_8_0
|
||||||
@requires_cuda_ge_8_9
|
@requires_cuda_ge_8_9
|
||||||
def test_convert_qat_model(self, model):
|
def test_convert_qat_model(self, model):
|
||||||
|
|||||||
Reference in New Issue
Block a user