MX QAT patch (#3553)
* qat patch * tests fixes * fixup per PR code review * use state dict hooks to handle dequant for saving safetensors from transformers * use transformers torch ao quantizer hooks to save mx quantized model --------- Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -5,15 +5,20 @@ Tests for axolotl.utils.quantization
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchao.prototype.qat import MXFakeQuantizeConfig
|
||||
from torchao.quantization import LinearActivationQuantizedTensor
|
||||
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
|
||||
from torchao.quantization.qat.linear import FakeQuantizedLinear
|
||||
from torchao.quantization.quant_api import (
|
||||
Float8DynamicActivationFloat8WeightConfig,
|
||||
Float8DynamicActivationInt4WeightConfig,
|
||||
Int8DynamicActivationInt4WeightConfig,
|
||||
)
|
||||
|
||||
try:
|
||||
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
|
||||
except ImportError:
|
||||
from torchao.quantization.quant_api import (
|
||||
Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
|
||||
)
|
||||
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers.trainer_callback import TrainerState
|
||||
@@ -129,8 +134,11 @@ class TestQuantization:
|
||||
@require_torch_2_8_0
|
||||
@requires_sm_ge_100
|
||||
def test_get_ptq_config_mxfp4(self):
|
||||
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
|
||||
|
||||
config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)
|
||||
assert isinstance(config, MXFakeQuantizeConfig)
|
||||
assert isinstance(config, MXDynamicActivationMXWeightConfig)
|
||||
assert config.weight_dtype == torch.float4_e2m1fn_x2
|
||||
assert config.block_size == 32
|
||||
|
||||
@require_torch_2_8_0
|
||||
@@ -298,7 +306,6 @@ class TestQuantization:
|
||||
"weight_dtype,activation_dtype,group_size,quantize_embedding",
|
||||
[
|
||||
(TorchAOQuantDType.mxfp4, None, 32, False),
|
||||
(TorchAOQuantDType.mxfp4, None, 32, True),
|
||||
],
|
||||
)
|
||||
@require_torch_2_8_0
|
||||
@@ -314,14 +321,16 @@ class TestQuantization:
|
||||
quantize_embedding,
|
||||
)
|
||||
|
||||
from torchao.prototype.qat import MXFakeQuantizedLinear
|
||||
|
||||
if quantize_embedding:
|
||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
|
||||
|
||||
for child in list(model.children()):
|
||||
if isinstance(child, torch.nn.Linear):
|
||||
assert isinstance(child, FakeQuantizedLinear)
|
||||
assert hasattr(child, "weight_fake_quantizer")
|
||||
assert isinstance(child, MXFakeQuantizedLinear)
|
||||
assert hasattr(child, "weight_config")
|
||||
|
||||
@require_torch_2_8_0
|
||||
@requires_cuda_ge_8_9
|
||||
|
||||
Reference in New Issue
Block a user