Migrate QAT API; fix axolotl quantize for QAT-ed models; add NVFP4 (#3107)

salman
2025-09-12 10:55:50 +01:00
committed by GitHub
parent 0401a15888
commit 58d67bf98d
16 changed files with 554 additions and 339 deletions
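Only one of the 16 changed files, the quantization test module, is reproduced below. As reflected in this file's import changes, the core of the migration is a set of renames in axolotl's quantization helpers plus a move of the dtype enum. A before/after sketch of the imports, taken directly from the diff (the paths are exactly as they appear there):

    # before this commit
    from axolotl.utils.quantization import (
        convert_qat_model_for_ptq,
        get_ptq_config,
        prepare_model_for_qat,
        quantize_model_for_ptq,
    )
    from axolotl.utils.schemas.enums import TorchIntDType

    # after this commit
    from axolotl.utils.quantization import (
        convert_qat_model,
        get_quantization_config,
        prepare_model_for_qat,
        quantize_model,
    )
    from axolotl.utils.schemas.enums import TorchAOQuantDType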


@@ -5,41 +5,40 @@ Tests for axolotl.utils.quantization
 import pytest
 import torch
 from torch import nn
-from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
-from torchao.quantization.granularity import PerAxis, PerGroup
-from torchao.quantization.linear_activation_quantized_tensor import (
-    LinearActivationQuantizedTensor,
-)
+from torchao.quantization import LinearActivationQuantizedTensor
 from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
 from torchao.quantization.qat.linear import FakeQuantizedLinear
 from torchao.quantization.quant_api import (
-    Int4DynamicActivationInt4WeightConfig,
-    Int4WeightOnlyConfig,
-    Int8DynamicActivationInt8WeightConfig,
-    Int8WeightOnlyConfig,
-    UIntXWeightOnlyConfig,
+    Float8DynamicActivationFloat8WeightConfig,
+    Float8DynamicActivationInt4WeightConfig,
+    Int8DynamicActivationInt4WeightConfig,
 )
+from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
 from transformers import AutoModelForCausalLM
 from transformers.trainer_callback import TrainerState

 from axolotl.utils.callbacks.qat import QATCallback
 from axolotl.utils.quantization import (
-    convert_qat_model_for_ptq,
-    get_ptq_config,
+    convert_qat_model,
+    get_quantization_config,
     prepare_model_for_qat,
-    quantize_model_for_ptq,
+    quantize_model,
 )
-from axolotl.utils.schemas.enums import TorchIntDType
+from axolotl.utils.schemas.enums import TorchAOQuantDType
 from axolotl.utils.schemas.quantization import QATConfig
-from tests.e2e.utils import require_torch_2_6_0
+from tests.e2e.utils import (
+    require_torch_2_8_0,
+    requires_cuda_ge_8_9,
+    requires_sm_ge_100,
+)


 @pytest.fixture()
 def model():
     dummy_model = AutoModelForCausalLM.from_pretrained(
-        "HuggingFaceTB/SmolLM2-135M",
-        device_map="cuda",
+        "Qwen/Qwen2-0.5B",
+        device_map="auto",
         torch_dtype=torch.bfloat16,
     )
     with torch.device(dummy_model.device):
@@ -48,45 +47,56 @@ def model():
             dummy_model.model.embed_tokens.weight.shape[1],
             dtype=dummy_model.model.embed_tokens.weight.dtype,
         )
-    return dummy_model
+    yield dummy_model
+    del dummy_model


 ptq_config_test_cases = [
-    # weight_dtype, activation_dtype, group_size, expected_type, expected_params
-    (
-        TorchIntDType.uint4,
-        None,
-        None,
-        UIntXWeightOnlyConfig,
-        {"dtype": torch.uint4, "group_size": None},
-    ),
-    (TorchIntDType.int8, None, 32, Int8WeightOnlyConfig, {"group_size": 32}),
-    (TorchIntDType.int4, None, 4, Int4WeightOnlyConfig, {"group_size": 4}),
-    (
-        TorchIntDType.int4,
-        TorchIntDType.int4,
-        None,
-        Int4DynamicActivationInt4WeightConfig,
-        {},
-    ),
-    (
-        TorchIntDType.int8,
-        TorchIntDType.int8,
-        None,
-        Int8DynamicActivationInt8WeightConfig,
-        {},
-    ),
+    # weight_dtype, activation_dtype, group_size, expected_type
+    (
+        TorchAOQuantDType.int4,
+        TorchAOQuantDType.int8,
+        None,
+        Int8DynamicActivationInt4WeightConfig,
+    ),
+    (
+        TorchAOQuantDType.float8_e4m3fn,
+        TorchAOQuantDType.float8_e4m3fn,
+        None,
+        Float8DynamicActivationFloat8WeightConfig,
+    ),
+    (
+        TorchAOQuantDType.int4,
+        TorchAOQuantDType.float8_e4m3fn,
+        None,
+        Float8DynamicActivationInt4WeightConfig,
+    ),
 ]

 ptq_test_cases = [
-    # weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception
-    (TorchIntDType.int8, None, 8, False, None),
-    (TorchIntDType.int4, None, 4, True, None),
-    (TorchIntDType.uint4, None, 8, False, None),
-    (TorchIntDType.int4, TorchIntDType.int4, 8, False, None),
-    (TorchIntDType.int8, TorchIntDType.int8, 8, True, None),
-    (TorchIntDType.int8, None, None, False, ValueError),
-    (TorchIntDType.int4, None, None, False, ValueError),
+    # weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception, expected_tensor_class
+    (TorchAOQuantDType.int4, None, 4, True, None, Int4Tensor),
+    (
+        TorchAOQuantDType.int4,
+        TorchAOQuantDType.int8,
+        8,
+        False,
+        None,
+        LinearActivationQuantizedTensor,
+    ),
+    # (
+    #     TorchAOQuantDType.int4,
+    #     TorchAOQuantDType.float8_e4m3fn,
+    #     None,
+    #     False,
+    #     None,
+    #     Int4Tensor,
+    # ),
+    (TorchAOQuantDType.int4, None, None, False, None, Int4Tensor),
+    # Deprecated configs
+    (TorchAOQuantDType.int8, None, 8, False, ValueError, None),
+    (TorchAOQuantDType.int4, TorchAOQuantDType.int4, 8, False, ValueError, None),
+    (TorchAOQuantDType.int8, TorchAOQuantDType.int8, 8, True, ValueError, None),
 ]
@@ -96,44 +106,132 @@ class TestQuantization:
"""
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,expected_type,expected_params",
"weight_dtype,activation_dtype,group_size,expected_type",
ptq_config_test_cases,
)
@require_torch_2_6_0
@requires_cuda_ge_8_9
@require_torch_2_8_0
def test_get_ptq_config(
self, weight_dtype, activation_dtype, group_size, expected_type, expected_params
self, weight_dtype, activation_dtype, group_size, expected_type
):
config = get_ptq_config(weight_dtype, activation_dtype, group_size)
config = get_quantization_config(weight_dtype, activation_dtype, group_size)
assert isinstance(config, expected_type)
for param_name, param_value in expected_params.items():
if isinstance(param_value, (PerAxis, PerGroup)):
if isinstance(param_value, PerAxis):
assert isinstance(getattr(config, param_name), PerAxis)
assert getattr(config, param_name).axis == param_value.axis
else:
assert isinstance(getattr(config, param_name), PerGroup)
assert (
getattr(config, param_name).group_size == param_value.group_size
)
else:
assert getattr(config, param_name) == param_value
@requires_cuda_ge_8_9
@require_torch_2_8_0
def test_get_ptq_config_int4_weight_only(self):
from torchao.quantization.quant_api import Int4WeightOnlyConfig
config = get_quantization_config(TorchAOQuantDType.int4, None, 4)
assert isinstance(config, Int4WeightOnlyConfig)
@pytest.mark.parametrize(
"weight_dtype", [TorchIntDType.int8, TorchIntDType.int4, TorchIntDType.uint4]
"weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception,expected_tensor_class",
ptq_test_cases,
)
@requires_cuda_ge_8_9
@require_torch_2_8_0
def test_quantize_model_for_ptq(
self,
model,
weight_dtype,
activation_dtype,
group_size,
quantize_embedding,
expected_exception,
expected_tensor_class,
):
if expected_exception:
with pytest.raises(expected_exception):
quantize_model(
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
)
else:
quantize_model(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
if quantize_embedding:
assert isinstance(
model.model.embed_tokens.weight, expected_tensor_class
), "Embedding weight should be quantized"
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child.weight, expected_tensor_class)
@require_torch_2_8_0
@requires_sm_ge_100
def test_quantize_model_for_ptq_fp8(
self,
model,
):
from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
Float8Tensor,
QuantizeTensorToFloat8Kwargs,
)
quantize_model(
model,
TorchAOQuantDType.float8_e4m3fn,
None,
TorchAOQuantDType.float8_e4m3fn,
)
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child.weight, Float8Tensor)
assert child.weight.act_quant_kwargs is not None and isinstance(
child.weight.act_quant_kwargs, QuantizeTensorToFloat8Kwargs
)
@require_torch_2_8_0
@requires_sm_ge_100
def test_quantize_model_for_ptq_nvfp4(
self,
model,
):
from torchao.prototype.mx_formats.nvfp4_tensor import (
NVFP4Tensor,
QuantizeTensorToNVFP4Kwargs,
)
quantize_model(model, TorchAOQuantDType.nvfp4, 16, TorchAOQuantDType.nvfp4)
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child.weight, NVFP4Tensor)
assert child.weight.act_quant_kwargs is not None and isinstance(
child.weight.act_quant_kwargs, QuantizeTensorToNVFP4Kwargs
)
@pytest.mark.parametrize(
"activation_dtype", [None, TorchIntDType.int4, TorchIntDType.int8]
"weight_dtype,activation_dtype,group_size,quantize_embedding",
[
(TorchAOQuantDType.int4, None, 8, False),
(TorchAOQuantDType.int4, None, 16, True),
(TorchAOQuantDType.int4, TorchAOQuantDType.int8, 8, False),
(TorchAOQuantDType.int4, TorchAOQuantDType.int8, 16, True),
(
TorchAOQuantDType.float8_e4m3fn,
TorchAOQuantDType.float8_e4m3fn,
None,
False,
),
(TorchAOQuantDType.int4, TorchAOQuantDType.float8_e4m3fn, None, True),
],
)
@pytest.mark.parametrize("group_size", [4, 8])
@pytest.mark.parametrize("quantize_embedding", [False, True])
@require_torch_2_6_0
@require_torch_2_8_0
@requires_cuda_ge_8_9
def test_prepare_model_for_qat(
self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
):
prepare_model_for_qat(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
)
if quantize_embedding:
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
@@ -142,17 +240,19 @@ class TestQuantization:
                 model.model.embed_tokens.weight_fake_quantizer.config.dtype
                 == weight_dtype.value
             )
-            assert (
-                model.model.embed_tokens.weight_fake_quantizer.config.group_size
-                == group_size
-            )
+            if group_size:
+                assert (
+                    model.model.embed_tokens.weight_fake_quantizer.config.group_size
+                    == group_size
+                )
         for child in list(model.children()):
             if isinstance(child, torch.nn.Linear):
                 assert isinstance(child, FakeQuantizedLinear)
                 assert hasattr(child, "weight_fake_quantizer")
                 assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
-                assert child.weight_fake_quantizer.config.group_size == group_size
+                if group_size:
+                    assert child.weight_fake_quantizer.config.group_size == group_size
                 if activation_dtype:
                     assert hasattr(child, "activation_fake_quantizer")
                     assert (
@@ -162,49 +262,40 @@ class TestQuantization:
                 else:
                     assert child.activation_fake_quantizer is None

-    @pytest.mark.parametrize(
-        "weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception",
-        ptq_test_cases,
-    )
-    @require_torch_2_6_0
-    def test_quantize_model_for_ptq(
-        self,
-        model,
-        weight_dtype,
-        activation_dtype,
-        group_size,
-        quantize_embedding,
-        expected_exception,
-    ):
-        if expected_exception:
-            with pytest.raises(expected_exception):
-                quantize_model_for_ptq(
-                    model,
-                    weight_dtype,
-                    group_size,
-                    activation_dtype,
-                    quantize_embedding,
-                )
-        else:
-            quantize_model_for_ptq(
-                model, weight_dtype, group_size, activation_dtype, quantize_embedding
-            )
-            if quantize_embedding:
-                assert isinstance(
-                    model.model.embed_tokens.weight, AffineQuantizedTensor
-                ), "Embedding weight should be quantized"
-            for child in list(model.children()):
-                if isinstance(child, torch.nn.Linear):
-                    if activation_dtype:
-                        assert isinstance(
-                            child.weight, LinearActivationQuantizedTensor
-                        ), (
-                            "Linear weight should be quantized with activation quantization"
-                        )
-                    else:
-                        assert isinstance(child.weight, AffineQuantizedTensor), (
-                            "Linear weight should be quantized without activation quantization"
-                        )
+    @require_torch_2_8_0
+    @requires_cuda_ge_8_9
+    def test_convert_qat_model(self, model):
+        config = QATConfig(
+            weight_dtype="int4",
+            activation_dtype="int8",
+            group_size=8,
+            quantize_embedding=True,
+        )
+        # quantize model for qat
+        prepare_model_for_qat(
+            model,
+            config.weight_dtype,
+            config.group_size,
+            config.activation_dtype,
+            config.quantize_embedding,
+        )
+        assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
+        assert isinstance(model.lm_head, FakeQuantizedLinear)
+        # apply conversion
+        convert_qat_model(
+            model,
+            config.quantize_embedding,
+        )
+        # ensure modules have been swapped out
+        assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
+        assert not isinstance(model.lm_head, FakeQuantizedLinear)
+        # ensure weights have been quantized
+        assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
+        assert isinstance(model.lm_head.weight, nn.Parameter)


 class TestQuantizationCallback:
@@ -218,10 +309,10 @@ class TestQuantizationCallback:
             global_step=0,
         )

-    @require_torch_2_6_0
+    @require_torch_2_8_0
     def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state):
         cfg = QATConfig(
-            weight_dtype="int8",
+            weight_dtype="int4",
             activation_dtype="int8",
             group_size=8,
             quantize_embedding=True,
@@ -268,10 +359,10 @@ class TestQuantizationCallback:
         assert model.model.embed_tokens.weight_fake_quantizer.enabled
         assert model.lm_head.weight_fake_quantizer.enabled

-    @require_torch_2_6_0
+    @require_torch_2_8_0
     def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
         cfg = QATConfig(
-            weight_dtype="int8",
+            weight_dtype="int4",
             activation_dtype="int8",
             group_size=8,
             quantize_embedding=True,
@@ -304,43 +395,3 @@ class TestQuantizationCallback:
         # quantization should be enabled from the get-go
         assert model.model.embed_tokens.weight_fake_quantizer.enabled
         assert model.lm_head.weight_fake_quantizer.enabled
-
-
-class TestConvertQATModelForPTQ:
-    """
-    Test convert_qat_model_for_ptq
-    """
-
-    @require_torch_2_6_0
-    def test_convert_qat_model_for_ptq(self, model):
-        config = QATConfig(
-            weight_dtype="int8",
-            activation_dtype="int8",
-            group_size=8,
-            quantize_embedding=True,
-        )
-        # quantize model for qat
-        prepare_model_for_qat(
-            model,
-            config.weight_dtype,
-            config.group_size,
-            config.activation_dtype,
-            config.quantize_embedding,
-        )
-        assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
-        assert isinstance(model.lm_head, FakeQuantizedLinear)
-        # apply conversion
-        convert_qat_model_for_ptq(
-            model,
-            quantize_embedding=config.quantize_embedding,
-        )
-        # ensure modules have been swapped out
-        assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
-        assert not isinstance(model.lm_head, FakeQuantizedLinear)
-        # ensure weights have been quantized
-        assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
-        assert isinstance(model.lm_head.weight, nn.Parameter)
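
For reference, a minimal end-to-end sketch of the migrated flow these tests exercise: QAT preparation, conversion back to plain modules, then quantization of the trained weights. The positional argument order is copied from the test calls above; the training loop is elided and keyword defaults are assumed, so treat this as a sketch rather than the definitive API:

    import torch
    from transformers import AutoModelForCausalLM

    from axolotl.utils.quantization import (
        convert_qat_model,
        prepare_model_for_qat,
        quantize_model,
    )
    from axolotl.utils.schemas.enums import TorchAOQuantDType

    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2-0.5B", device_map="auto", torch_dtype=torch.bfloat16
    )

    # 1) Swap Linear/Embedding modules for fake-quantized QAT variants
    #    (int4 weights, int8 activations, group size 8, embedding included).
    prepare_model_for_qat(model, TorchAOQuantDType.int4, 8, TorchAOQuantDType.int8, True)

    # ... training loop; QATCallback toggles fake quantization after n steps ...

    # 2) Swap the fake-quantized modules back to plain nn modules.
    convert_qat_model(model, True)

    # 3) Quantize the trained weights for inference.
    quantize_model(model, TorchAOQuantDType.int4, 8, TorchAOQuantDType.int8, False)

    # NVFP4, added in this commit, uses the same entry point:
    # quantize_model(model, TorchAOQuantDType.nvfp4, 16, TorchAOQuantDType.nvfp4)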