"""
Tests for axolotl.utils.quantization
"""
import pytest
import torch
from torch import nn
from torchao.quantization import IntxUnpackedToInt8Tensor
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int8DynamicActivationIntxWeightConfig,
)
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
from transformers import AutoModelForCausalLM
from transformers.trainer_callback import TrainerState
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.quantization import (
convert_qat_model,
get_quantization_config,
prepare_model_for_qat,
quantize_model,
)
from axolotl.utils.schemas.enums import TorchAOQuantDType
from axolotl.utils.schemas.quantization import QATConfig
from tests.e2e.utils import (
require_torch_2_8_0,
requires_cuda_ge_8_9,
requires_sm_ge_100,
)


def _get_fake_quant_config_dtype(config):
"""Get the weight dtype from a fake quantize config, handling different config types."""
if hasattr(config, "dtype"):
return config.dtype
# Int4WeightFakeQuantizeConfig doesn't have .dtype — weight is always int4
return torch.int4


@pytest.fixture()
def model():
dummy_model = AutoModelForCausalLM.from_pretrained(
"axolotl-ai-co/tiny-qwen2-129m",
device_map="auto",
dtype=torch.bfloat16,
)
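    # Swap in a freshly initialized nn.Embedding with the same shape and dtype as the
    # pretrained one, so the embedding quantization/QAT paths below operate on a plain,
    # standalone embedding module.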
with torch.device(dummy_model.device):
dummy_model.model.embed_tokens = torch.nn.Embedding(
dummy_model.model.embed_tokens.weight.shape[0],
dummy_model.model.embed_tokens.weight.shape[1],
dtype=dummy_model.model.embed_tokens.weight.dtype,
)
yield dummy_model
del dummy_model


ptq_config_test_cases = [
# weight_dtype, activation_dtype, group_size, expected_type
(
TorchAOQuantDType.int4,
TorchAOQuantDType.int8,
None,
Int8DynamicActivationIntxWeightConfig,
),
(
TorchAOQuantDType.float8_e4m3fn,
TorchAOQuantDType.float8_e4m3fn,
None,
Float8DynamicActivationFloat8WeightConfig,
),
(
TorchAOQuantDType.int4,
TorchAOQuantDType.float8_e4m3fn,
None,
Float8DynamicActivationInt4WeightConfig,
),
]

ptq_test_cases = [
# weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception, expected_tensor_class
(TorchAOQuantDType.int4, None, 4, True, None, Int4Tensor),
(
TorchAOQuantDType.int4,
TorchAOQuantDType.int8,
8,
False,
None,
IntxUnpackedToInt8Tensor,
),
# (
# TorchAOQuantDType.int4,
# TorchAOQuantDType.float8_e4m3fn,
# None,
# False,
# None,
# Int4Tensor,
# ),
(TorchAOQuantDType.int4, None, None, False, None, Int4Tensor),
# Deprecated configs
(TorchAOQuantDType.int8, None, 8, False, ValueError, None),
(TorchAOQuantDType.int4, TorchAOQuantDType.int4, 8, False, ValueError, None),
(TorchAOQuantDType.int8, TorchAOQuantDType.int8, 8, True, ValueError, None),
]


class TestQuantization:
"""
Test quantization utilities
"""
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,expected_type",
ptq_config_test_cases,
)
@requires_cuda_ge_8_9
@require_torch_2_8_0
def test_get_ptq_config(
self, weight_dtype, activation_dtype, group_size, expected_type
):
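        """Each weight/activation dtype combination should map to the expected torchao config class."""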
config = get_quantization_config(weight_dtype, activation_dtype, group_size)
assert isinstance(config, expected_type)
@require_torch_2_8_0
@requires_sm_ge_100
def test_get_ptq_config_mxfp4(self):
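        """mxfp4 should yield an MX config with fp4 (e2m1) weights and block_size 32."""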
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)
assert isinstance(config, MXDynamicActivationMXWeightConfig)
assert config.weight_dtype == torch.float4_e2m1fn_x2
assert config.block_size == 32
@require_torch_2_8_0
@requires_sm_ge_100
def test_get_ptq_config_mxfp4_invalid_group_size(self):
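        """An unsupported mxfp4 block size should raise a ValueError."""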
with pytest.raises(
ValueError, match="MXFP4 quantization must use a block_size"
):
get_quantization_config(TorchAOQuantDType.mxfp4, None, 16)
@requires_cuda_ge_8_9
@require_torch_2_8_0
def test_get_ptq_config_int4_weight_only(self):
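        """Weight-only int4 should map to Int4WeightOnlyConfig."""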
from torchao.quantization.quant_api import Int4WeightOnlyConfig
config = get_quantization_config(TorchAOQuantDType.int4, None, 4)
assert isinstance(config, Int4WeightOnlyConfig)
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception,expected_tensor_class",
ptq_test_cases,
)
@requires_cuda_ge_8_9
@require_torch_2_8_0
def test_quantize_model_for_ptq(
self,
model,
weight_dtype,
activation_dtype,
group_size,
quantize_embedding,
expected_exception,
expected_tensor_class,
):
# TODO: add mslk-cuda as a CI dependency once pytorch 2.10.x is available
# (see https://pypi.org/project/mslk-cuda/)
if expected_tensor_class is Int4Tensor and activation_dtype is None:
try:
from torchao.quantization.quantize_.workflows.int4.int4_tensor import (
int4_row_quantize_zp,
)
if int4_row_quantize_zp is None:
pytest.skip("Int4Tensor requires mslk >= 1.0.0")
except ImportError:
pytest.skip("Int4Tensor requires mslk >= 1.0.0")
if expected_exception:
with pytest.raises(expected_exception):
quantize_model(
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
)
else:
quantize_model(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
if quantize_embedding:
assert isinstance(
model.model.embed_tokens.weight, expected_tensor_class
), "Embedding weight should be quantized"
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child.weight, expected_tensor_class)
@require_torch_2_8_0
@requires_sm_ge_100
def test_quantize_model_for_ptq_fp8(
self,
model,
):
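        """float8 PTQ should yield Float8Tensor weights with float8 activation quant kwargs."""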
from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
Float8Tensor,
QuantizeTensorToFloat8Kwargs,
)
quantize_model(
model,
TorchAOQuantDType.float8_e4m3fn,
None,
TorchAOQuantDType.float8_e4m3fn,
)
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child.weight, Float8Tensor)
assert child.weight.act_quant_kwargs is not None and isinstance(
child.weight.act_quant_kwargs, QuantizeTensorToFloat8Kwargs
)
@require_torch_2_8_0
@requires_sm_ge_100
def test_quantize_model_for_ptq_nvfp4(
self,
model,
):
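        """nvfp4 PTQ should yield NVFP4Tensor weights with NVFP4 activation quant kwargs."""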
from torchao.prototype.mx_formats.nvfp4_tensor import (
NVFP4Tensor,
QuantizeTensorToNVFP4Kwargs,
)
quantize_model(model, TorchAOQuantDType.nvfp4, 16, TorchAOQuantDType.nvfp4)
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child.weight, NVFP4Tensor)
assert child.weight.act_quant_kwargs is not None and isinstance(
child.weight.act_quant_kwargs, QuantizeTensorToNVFP4Kwargs
)
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,quantize_embedding",
[
(TorchAOQuantDType.int4, None, 8, False),
(TorchAOQuantDType.int4, None, 16, True),
(TorchAOQuantDType.int4, TorchAOQuantDType.int8, 8, False),
(TorchAOQuantDType.int4, TorchAOQuantDType.int8, 16, True),
(
TorchAOQuantDType.float8_e4m3fn,
TorchAOQuantDType.float8_e4m3fn,
None,
False,
),
(TorchAOQuantDType.int4, TorchAOQuantDType.float8_e4m3fn, None, True),
],
)
@require_torch_2_8_0
@requires_cuda_ge_8_9
def test_prepare_model_for_qat(
self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
):
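        """QAT preparation should wrap linears (and optionally the embedding) in
        fake-quantized modules with the configured dtypes and group size."""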
prepare_model_for_qat(
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
)
if quantize_embedding:
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
embed_config = model.model.embed_tokens.weight_fake_quantizer.config
assert _get_fake_quant_config_dtype(embed_config) == weight_dtype.value
if group_size:
assert embed_config.group_size == group_size
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child, FakeQuantizedLinear)
assert hasattr(child, "weight_fake_quantizer")
w_config = child.weight_fake_quantizer.config
assert _get_fake_quant_config_dtype(w_config) == weight_dtype.value
if group_size:
assert w_config.group_size == group_size
if activation_dtype:
assert hasattr(child, "activation_fake_quantizer")
a_config = child.activation_fake_quantizer.config
assert (
_get_fake_quant_config_dtype(a_config) == activation_dtype.value
)
else:
assert child.activation_fake_quantizer is None
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,quantize_embedding",
[
(TorchAOQuantDType.mxfp4, None, 32, False),
],
)
@require_torch_2_8_0
@requires_sm_ge_100
def test_prepare_model_for_qat_mxfp4(
self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
):
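        """mxfp4 QAT preparation should swap linears for MXFakeQuantizedLinear modules."""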
prepare_model_for_qat(
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
)
from torchao.prototype.qat import MXFakeQuantizedLinear
if quantize_embedding:
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child, MXFakeQuantizedLinear)
assert hasattr(child, "weight_config")
@require_torch_2_8_0
@requires_cuda_ge_8_9
def test_convert_qat_model(self, model):
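        """Conversion should swap the fake-quantized modules back out, leaving plain
        nn.Parameter weights."""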
config = QATConfig(
weight_dtype="int4",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
)
        # prepare the model for QAT (swap in fake-quantized modules)
prepare_model_for_qat(
model,
config.weight_dtype,
config.group_size,
config.activation_dtype,
config.quantize_embedding,
)
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert isinstance(model.lm_head, FakeQuantizedLinear)
# apply conversion
convert_qat_model(
model,
config.quantize_embedding,
)
# ensure modules have been swapped out
assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert not isinstance(model.lm_head, FakeQuantizedLinear)
        # ensure weights are registered as plain nn.Parameter after conversion
assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
assert isinstance(model.lm_head.weight, nn.Parameter)


class TestQuantizationCallback:
"""
Test QATCallback
"""
@pytest.fixture()
def trainer_state(self):
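        """A fresh TrainerState starting at global_step 0."""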
return TrainerState(
global_step=0,
)
@require_torch_2_8_0
def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state):
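        """Fake quantization should be disabled until fake_quant_after_n_steps is
        reached, then re-enabled by the callback."""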
cfg = QATConfig(
weight_dtype="int4",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
fake_quant_after_n_steps=100,
)
prepare_model_for_qat(
model,
cfg.weight_dtype,
cfg.group_size,
cfg.activation_dtype,
cfg.quantize_embedding,
)
        # ensure fake-quantized QAT modules have been swapped in
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert isinstance(model.lm_head, FakeQuantizedLinear)
# Only test enable/disable toggling if the fake quantizer supports it
# (Int4WeightFakeQuantizer does not have an 'enabled' attribute)
supports_toggle = hasattr(
model.model.embed_tokens.weight_fake_quantizer, "enabled"
)
if supports_toggle:
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
qat_callback = QATCallback(cfg)
# simulate first training step
qat_callback.on_step_begin(
args=None,
state=trainer_state,
control=None,
model=model,
)
if supports_toggle:
# quantization should have been disabled
assert not model.model.embed_tokens.weight_fake_quantizer.enabled
assert not model.lm_head.weight_fake_quantizer.enabled
trainer_state.global_step = 100
qat_callback.on_step_begin(
args=None,
state=trainer_state,
control=None,
model=model,
)
if supports_toggle:
# quantization should have been enabled
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
@require_torch_2_8_0
def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
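        """With fake_quant_after_n_steps=None, fake quantization should stay
        enabled from the first step."""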
cfg = QATConfig(
weight_dtype="int4",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
fake_quant_after_n_steps=None,
)
prepare_model_for_qat(
model,
cfg.weight_dtype,
cfg.group_size,
cfg.activation_dtype,
cfg.quantize_embedding,
)
        # ensure fake-quantized QAT modules have been swapped in
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert isinstance(model.lm_head, FakeQuantizedLinear)
if hasattr(model.model.embed_tokens.weight_fake_quantizer, "enabled"):
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
qat_callback = QATCallback(cfg)
# simulate first training step
qat_callback.on_step_begin(
args=None,
state=trainer_state,
control=None,
model=model,
)
# quantization should be enabled from the get-go
if hasattr(model.model.embed_tokens.weight_fake_quantizer, "enabled"):
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled