add: support mxfp4 axo (#3375)
* mxfp4 axo * import lint * test for qat mxfp4 * config for mxfp4 * add qat: * pass base config * MXFakeQuantizeConfig * lint * tune config so it fits in 32GB VRAM --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -8,6 +8,8 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||
from axolotl.utils.schemas.quantization import QATConfig, validate_ao_dtype
|
||||
|
||||
from .utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
@@ -130,3 +132,32 @@ class TestQATLlama:
|
||||
loss_threshold,
|
||||
"Train Loss (%s) is too high",
|
||||
)
|
||||
|
||||
|
||||
class TestMXFP4Schema:
|
||||
"""Test MXFP4 schema validation"""
|
||||
|
||||
def test_validate_mxfp4_dtype(self):
|
||||
result = validate_ao_dtype("mxfp4")
|
||||
assert result == TorchAOQuantDType.mxfp4
|
||||
|
||||
def test_qat_config_with_mxfp4(self):
|
||||
"""Test QATConfig accepts mxfp4 weight_dtype"""
|
||||
config = QATConfig(
|
||||
weight_dtype="mxfp4",
|
||||
group_size=32,
|
||||
quantize_embedding=False,
|
||||
)
|
||||
assert config.weight_dtype == TorchAOQuantDType.mxfp4
|
||||
assert config.group_size == 32
|
||||
|
||||
def test_qat_config_mxfp4_invalid_group_size(self):
|
||||
"""Test that invalid group_size raises appropriate error during quantization"""
|
||||
# Note: Schema validation doesn't check group_size compatibility,
|
||||
# that happens in get_quantization_config
|
||||
config = QATConfig(
|
||||
weight_dtype="mxfp4",
|
||||
group_size=16, # Invalid for mxfp4, but schema allows it
|
||||
)
|
||||
assert config.group_size == 16 # Schema accepts it
|
||||
# Actual validation happens at runtime in get_quantization_config
|
||||
|
||||
Reference in New Issue
Block a user