add: support mxfp4 axo (#3375)
* mxfp4 axo
* import lint
* test for qat mxfp4
* config for mxfp4
* add qat:
* pass base config
* MXFakeQuantizeConfig
* lint
* tune config so it fits in 32GB VRAM

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
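For context, a minimal end-to-end sketch of exercising the new mxfp4 QAT path through the utilities this diff touches. The module paths and the full `prepare_model_for_qat` signature are assumptions inferred from the hunks below, not a documented API:

```python
# Hypothetical usage sketch; names are inferred from the diff, not a stable API.
import torch.nn as nn

from axolotl.utils.quantization import prepare_model_for_qat  # assumed path
from axolotl.utils.schemas.enums import TorchAOQuantDType  # assumed path

model = nn.Sequential(nn.Linear(64, 64))  # stand-in for a real model

# mxfp4 fake-quantizes weights and activations with a fixed block size of 32
prepare_model_for_qat(
    model,
    weight_dtype=TorchAOQuantDType.mxfp4,
    group_size=32,  # optional; the mxfp4 path rejects anything other than 32
)
```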
@@ -5,6 +5,7 @@ Utilities for quantization including QAT and PTQ using torchao.
 import torch
 from packaging import version
 from torchao.core.config import AOBaseConfig
+from torchao.prototype.qat import MXFakeQuantizeConfig
 from torchao.quantization import quantize_
 from torchao.quantization.qat import (
     QATConfig,
@@ -40,6 +41,13 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"):
     except:
         pass
 
+    try:
+        from torchao.prototype.qat import MXFakeQuantizeConfig
+
+        quantization_config_to_str[MXFakeQuantizeConfig] = "mxfp4"
+    except ImportError:
+        pass
+
 
 def get_quantization_config(
     weight_dtype: TorchAOQuantDType,
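The guarded registration above means `quantization_config_to_str` only gains an "mxfp4" entry when the installed torchao actually ships the prototype `MXFakeQuantizeConfig`; on older versions the `ImportError` is swallowed and the registry is unchanged. A small illustrative sketch of the reverse lookup this enables (the helper is hypothetical, not part of the PR):

```python
# Hypothetical helper: quantization_config_to_str maps config *classes*
# (e.g. MXFakeQuantizeConfig) to axolotl's string names (e.g. "mxfp4").
def config_to_name(config) -> str | None:
    return quantization_config_to_str.get(type(config))
```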
@@ -109,6 +117,19 @@ def get_quantization_config(
         if group_size is not None and group_size != 16:
             raise ValueError("NVFP4 quantization must use a group_size of 16")
         return NVFP4InferenceConfig()
+
+    if weight_dtype == TorchAOQuantDType.mxfp4:
+        from torchao.prototype.qat import MXFakeQuantizeConfig
+
+        # MXFP4 uses block_size=32 by default (vs NVFP4's 16)
+        block_size = group_size if group_size is not None else 32
+        if block_size != 32:
+            raise ValueError(
+                "MXFP4 quantization must use a block_size (group_size) of 32"
+            )
+
+        return MXFakeQuantizeConfig(dtype=torch.float4_e2m1fn_x2, block_size=block_size)
+
     raise ValueError(
         f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
     )
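The new branch is strict about block size: `group_size` may be omitted (defaulting to 32), but any other value raises, mirroring the NVFP4 check above it with a different constant. A hedged sketch of that behavior, assuming `get_quantization_config` takes the keyword arguments shown in the hunk:

```python
# Sketch of the validation added above (function and enum names from the diff).
cfg = get_quantization_config(
    weight_dtype=TorchAOQuantDType.mxfp4,
    activation_dtype=None,
    group_size=None,  # falls back to block_size=32
)

try:
    get_quantization_config(
        weight_dtype=TorchAOQuantDType.mxfp4,
        activation_dtype=None,
        group_size=16,  # valid for NVFP4, rejected for MXFP4
    )
except ValueError as err:
    print(err)  # MXFP4 quantization must use a block_size (group_size) of 32
```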
@@ -179,7 +200,13 @@ def prepare_model_for_qat(
         activation_dtype=activation_dtype,
         group_size=group_size,
     )
-    qat_config = QATConfig(base_config)
+    if isinstance(base_config, MXFakeQuantizeConfig):
+        qat_config = QATConfig(
+            activation_config=base_config,
+            weight_config=base_config,
+        )
+    else:
+        qat_config = QATConfig(base_config)
     quantize_(model, qat_config)
     if quantize_embedding:
         # activation fake quantization is not supported for embedding layers
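Design note: torchao's `QATConfig` can be constructed either from a single base (PTQ) config, from which it derives fake quantizers, or from explicit fake-quantize configs, as the diff implies. Since `MXFakeQuantizeConfig` is a fake-quantize spec rather than a base config, the branch above passes it through the explicit keyword form, applying the same config to both activations and weights.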
@@ -188,7 +215,12 @@ def prepare_model_for_qat(
             activation_dtype=None,
             group_size=group_size,
         )
-        embedding_qat_config = QATConfig(embedding_base_config)
+        if isinstance(embedding_base_config, MXFakeQuantizeConfig):
+            embedding_qat_config = QATConfig(
+                weight_config=embedding_base_config,
+            )
+        else:
+            embedding_qat_config = QATConfig(embedding_base_config)
         quantize_(
             model,
             embedding_qat_config,
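As with the linear layers, the MX path builds `QATConfig` from an explicit keyword config, but here only `weight_config` is passed: activation fake quantization is not supported for embedding layers, per the in-line comment above.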
@@ -10,6 +10,7 @@ class TorchAOQuantDType(Enum):
     int8 = torch.int8
     float8_e4m3fn = torch.float8_e4m3fn
     nvfp4 = "nvfp4"
+    mxfp4 = "mxfp4"
 
     def from_string(str):
         if str == "int4":
@@ -20,6 +21,8 @@ class TorchAOQuantDType(Enum):
             return TorchAOQuantDType.float8_e4m3fn
         if str == "nvfp4":
             return TorchAOQuantDType.nvfp4
+        if str == "mxfp4":
+            return TorchAOQuantDType.mxfp4
 
 
 class RLType(str, Enum):
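With the new enum member and parser branch in place, the string spelling round-trips; a quick sketch (assuming `from_string` is exposed as a class-level helper, as the hunk suggests):

```python
# Sketch: parsing the new dtype name via the method shown above.
dtype = TorchAOQuantDType.from_string("mxfp4")
assert dtype is TorchAOQuantDType.mxfp4
assert dtype.value == "mxfp4"  # mxfp4 is string-valued, like nvfp4
```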
@@ -20,6 +20,9 @@ def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
         return TorchAOQuantDType.float8_e4m3fn
     if v == "nvfp4":
         return TorchAOQuantDType.nvfp4
+    if v == "mxfp4":
+        return TorchAOQuantDType.mxfp4
+
     raise ValueError(
         f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}"
     )
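The schema-side validator accepts the same spelling, so "mxfp4" now survives config validation; a brief sketch (the invalid input is illustrative):

```python
# Sketch: validate_ao_dtype accepts "mxfp4" and rejects unknown names.
assert validate_ao_dtype("mxfp4") is TorchAOQuantDType.mxfp4

try:
    validate_ao_dtype("fp4")  # hypothetical invalid spelling
except ValueError as err:
    print(err)  # lists the valid TorchAOQuantDType names plus fp8 aliases
```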