Migrate QAT API; fix axolotl quantize for QAT-ed models; add NVFP4 (#3107)
@@ -5,41 +5,40 @@ Tests for axolotl.utils.quantization
 import pytest
 import torch
 from torch import nn
-from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
-from torchao.quantization.granularity import PerAxis, PerGroup
-from torchao.quantization.linear_activation_quantized_tensor import (
-    LinearActivationQuantizedTensor,
-)
+from torchao.quantization import LinearActivationQuantizedTensor
 from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
 from torchao.quantization.qat.linear import FakeQuantizedLinear
 from torchao.quantization.quant_api import (
-    Int4DynamicActivationInt4WeightConfig,
-    Int4WeightOnlyConfig,
-    Int8DynamicActivationInt8WeightConfig,
-    Int8WeightOnlyConfig,
-    UIntXWeightOnlyConfig,
+    Float8DynamicActivationFloat8WeightConfig,
+    Float8DynamicActivationInt4WeightConfig,
+    Int8DynamicActivationInt4WeightConfig,
 )
+from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
 from transformers import AutoModelForCausalLM
 from transformers.trainer_callback import TrainerState

 from axolotl.utils.callbacks.qat import QATCallback
 from axolotl.utils.quantization import (
-    convert_qat_model_for_ptq,
-    get_ptq_config,
+    convert_qat_model,
+    get_quantization_config,
     prepare_model_for_qat,
-    quantize_model_for_ptq,
+    quantize_model,
 )
-from axolotl.utils.schemas.enums import TorchIntDType
+from axolotl.utils.schemas.enums import TorchAOQuantDType
 from axolotl.utils.schemas.quantization import QATConfig

-from tests.e2e.utils import require_torch_2_6_0
+from tests.e2e.utils import (
+    require_torch_2_8_0,
+    requires_cuda_ge_8_9,
+    requires_sm_ge_100,
+)


 @pytest.fixture()
 def model():
     dummy_model = AutoModelForCausalLM.from_pretrained(
-        "HuggingFaceTB/SmolLM2-135M",
-        device_map="cuda",
+        "Qwen/Qwen2-0.5B",
+        device_map="auto",
         torch_dtype=torch.bfloat16,
     )
     with torch.device(dummy_model.device):
@@ -48,45 +47,56 @@ def model():
             dummy_model.model.embed_tokens.weight.shape[1],
             dtype=dummy_model.model.embed_tokens.weight.dtype,
         )
-    return dummy_model
+    yield dummy_model
+    del dummy_model


 ptq_config_test_cases = [
-    # weight_dtype, activation_dtype, group_size, expected_type, expected_params
-    (
-        TorchIntDType.uint4,
-        None,
-        None,
-        UIntXWeightOnlyConfig,
-        {"dtype": torch.uint4, "group_size": None},
-    ),
-    (TorchIntDType.int8, None, 32, Int8WeightOnlyConfig, {"group_size": 32}),
-    (TorchIntDType.int4, None, 4, Int4WeightOnlyConfig, {"group_size": 4}),
-    (
-        TorchIntDType.int4,
-        TorchIntDType.int4,
-        None,
-        Int4DynamicActivationInt4WeightConfig,
-        {},
-    ),
-    (
-        TorchIntDType.int8,
-        TorchIntDType.int8,
-        None,
-        Int8DynamicActivationInt8WeightConfig,
-        {},
-    ),
+    # weight_dtype, activation_dtype, group_size, expected_type
+    (
+        TorchAOQuantDType.int4,
+        TorchAOQuantDType.int8,
+        None,
+        Int8DynamicActivationInt4WeightConfig,
+    ),
+    (
+        TorchAOQuantDType.float8_e4m3fn,
+        TorchAOQuantDType.float8_e4m3fn,
+        None,
+        Float8DynamicActivationFloat8WeightConfig,
+    ),
+    (
+        TorchAOQuantDType.int4,
+        TorchAOQuantDType.float8_e4m3fn,
+        None,
+        Float8DynamicActivationInt4WeightConfig,
+    ),
 ]

 ptq_test_cases = [
-    # weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception
-    (TorchIntDType.int8, None, 8, False, None),
-    (TorchIntDType.int4, None, 4, True, None),
-    (TorchIntDType.uint4, None, 8, False, None),
-    (TorchIntDType.int4, TorchIntDType.int4, 8, False, None),
-    (TorchIntDType.int8, TorchIntDType.int8, 8, True, None),
-    (TorchIntDType.int8, None, None, False, ValueError),
-    (TorchIntDType.int4, None, None, False, ValueError),
+    # weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception, expected_tensor_class
+    (TorchAOQuantDType.int4, None, 4, True, None, Int4Tensor),
+    (
+        TorchAOQuantDType.int4,
+        TorchAOQuantDType.int8,
+        8,
+        False,
+        None,
+        LinearActivationQuantizedTensor,
+    ),
+    # (
+    #     TorchAOQuantDType.int4,
+    #     TorchAOQuantDType.float8_e4m3fn,
+    #     None,
+    #     False,
+    #     None,
+    #     Int4Tensor,
+    # ),
+    (TorchAOQuantDType.int4, None, None, False, None, Int4Tensor),
+    # Deprecated configs
+    (TorchAOQuantDType.int8, None, 8, False, ValueError, None),
+    (TorchAOQuantDType.int4, TorchAOQuantDType.int4, 8, False, ValueError, None),
+    (TorchAOQuantDType.int8, TorchAOQuantDType.int8, 8, True, ValueError, None),
 ]

@@ -96,44 +106,132 @@ class TestQuantization:
     """

     @pytest.mark.parametrize(
-        "weight_dtype,activation_dtype,group_size,expected_type,expected_params",
+        "weight_dtype,activation_dtype,group_size,expected_type",
         ptq_config_test_cases,
     )
-    @require_torch_2_6_0
+    @requires_cuda_ge_8_9
+    @require_torch_2_8_0
     def test_get_ptq_config(
-        self, weight_dtype, activation_dtype, group_size, expected_type, expected_params
+        self, weight_dtype, activation_dtype, group_size, expected_type
     ):
-        config = get_ptq_config(weight_dtype, activation_dtype, group_size)
-
+        config = get_quantization_config(weight_dtype, activation_dtype, group_size)
         assert isinstance(config, expected_type)

-        for param_name, param_value in expected_params.items():
-            if isinstance(param_value, (PerAxis, PerGroup)):
-                if isinstance(param_value, PerAxis):
-                    assert isinstance(getattr(config, param_name), PerAxis)
-                    assert getattr(config, param_name).axis == param_value.axis
-                else:
-                    assert isinstance(getattr(config, param_name), PerGroup)
-                    assert (
-                        getattr(config, param_name).group_size == param_value.group_size
-                    )
-            else:
-                assert getattr(config, param_name) == param_value
+    @requires_cuda_ge_8_9
+    @require_torch_2_8_0
+    def test_get_ptq_config_int4_weight_only(self):
+        from torchao.quantization.quant_api import Int4WeightOnlyConfig
+
+        config = get_quantization_config(TorchAOQuantDType.int4, None, 4)
+        assert isinstance(config, Int4WeightOnlyConfig)

     @pytest.mark.parametrize(
-        "weight_dtype", [TorchIntDType.int8, TorchIntDType.int4, TorchIntDType.uint4]
-    )
-    @pytest.mark.parametrize(
-        "activation_dtype", [None, TorchIntDType.int4, TorchIntDType.int8]
+        "weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception,expected_tensor_class",
+        ptq_test_cases,
     )
-    @pytest.mark.parametrize("group_size", [4, 8])
-    @pytest.mark.parametrize("quantize_embedding", [False, True])
-    @require_torch_2_6_0
+    @requires_cuda_ge_8_9
+    @require_torch_2_8_0
+    def test_quantize_model_for_ptq(
+        self,
+        model,
+        weight_dtype,
+        activation_dtype,
+        group_size,
+        quantize_embedding,
+        expected_exception,
+        expected_tensor_class,
+    ):
+        if expected_exception:
+            with pytest.raises(expected_exception):
+                quantize_model(
+                    model,
+                    weight_dtype,
+                    group_size,
+                    activation_dtype,
+                    quantize_embedding,
+                )
+        else:
+            quantize_model(
+                model, weight_dtype, group_size, activation_dtype, quantize_embedding
+            )
+            if quantize_embedding:
+                assert isinstance(
+                    model.model.embed_tokens.weight, expected_tensor_class
+                ), "Embedding weight should be quantized"
+            for child in list(model.children()):
+                if isinstance(child, torch.nn.Linear):
+                    assert isinstance(child.weight, expected_tensor_class)
+
+    @require_torch_2_8_0
+    @requires_sm_ge_100
+    def test_quantize_model_for_ptq_fp8(
+        self,
+        model,
+    ):
+        from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
+            Float8Tensor,
+            QuantizeTensorToFloat8Kwargs,
+        )
+
+        quantize_model(
+            model,
+            TorchAOQuantDType.float8_e4m3fn,
+            None,
+            TorchAOQuantDType.float8_e4m3fn,
+        )
+        for child in list(model.children()):
+            if isinstance(child, torch.nn.Linear):
+                assert isinstance(child.weight, Float8Tensor)
+                assert child.weight.act_quant_kwargs is not None and isinstance(
+                    child.weight.act_quant_kwargs, QuantizeTensorToFloat8Kwargs
+                )
+
+    @require_torch_2_8_0
+    @requires_sm_ge_100
+    def test_quantize_model_for_ptq_nvfp4(
+        self,
+        model,
+    ):
+        from torchao.prototype.mx_formats.nvfp4_tensor import (
+            NVFP4Tensor,
+            QuantizeTensorToNVFP4Kwargs,
+        )
+
+        quantize_model(model, TorchAOQuantDType.nvfp4, 16, TorchAOQuantDType.nvfp4)
+        for child in list(model.children()):
+            if isinstance(child, torch.nn.Linear):
+                assert isinstance(child.weight, NVFP4Tensor)
+                assert child.weight.act_quant_kwargs is not None and isinstance(
+                    child.weight.act_quant_kwargs, QuantizeTensorToNVFP4Kwargs
+                )
+
+    @pytest.mark.parametrize(
+        "weight_dtype,activation_dtype,group_size,quantize_embedding",
+        [
+            (TorchAOQuantDType.int4, None, 8, False),
+            (TorchAOQuantDType.int4, None, 16, True),
+            (TorchAOQuantDType.int4, TorchAOQuantDType.int8, 8, False),
+            (TorchAOQuantDType.int4, TorchAOQuantDType.int8, 16, True),
+            (
+                TorchAOQuantDType.float8_e4m3fn,
+                TorchAOQuantDType.float8_e4m3fn,
+                None,
+                False,
+            ),
+            (TorchAOQuantDType.int4, TorchAOQuantDType.float8_e4m3fn, None, True),
+        ],
+    )
+    @require_torch_2_8_0
+    @requires_cuda_ge_8_9
     def test_prepare_model_for_qat(
         self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
     ):
         prepare_model_for_qat(
-            model, weight_dtype, group_size, activation_dtype, quantize_embedding
+            model,
+            weight_dtype,
+            group_size,
+            activation_dtype,
+            quantize_embedding,
         )
         if quantize_embedding:
             assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
@@ -142,17 +240,19 @@ class TestQuantization:
                 model.model.embed_tokens.weight_fake_quantizer.config.dtype
                 == weight_dtype.value
             )
-            assert (
-                model.model.embed_tokens.weight_fake_quantizer.config.group_size
-                == group_size
-            )
+            if group_size:
+                assert (
+                    model.model.embed_tokens.weight_fake_quantizer.config.group_size
+                    == group_size
+                )

         for child in list(model.children()):
             if isinstance(child, torch.nn.Linear):
                 assert isinstance(child, FakeQuantizedLinear)
                 assert hasattr(child, "weight_fake_quantizer")
                 assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
-                assert child.weight_fake_quantizer.config.group_size == group_size
+                if group_size:
+                    assert child.weight_fake_quantizer.config.group_size == group_size
                 if activation_dtype:
                     assert hasattr(child, "activation_fake_quantizer")
                     assert (
@@ -162,49 +262,40 @@ class TestQuantization:
             else:
                 assert child.activation_fake_quantizer is None

-    @pytest.mark.parametrize(
-        "weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception",
-        ptq_test_cases,
-    )
-    @require_torch_2_6_0
-    def test_quantize_model_for_ptq(
-        self,
-        model,
-        weight_dtype,
-        activation_dtype,
-        group_size,
-        quantize_embedding,
-        expected_exception,
-    ):
-        if expected_exception:
-            with pytest.raises(expected_exception):
-                quantize_model_for_ptq(
-                    model,
-                    weight_dtype,
-                    group_size,
-                    activation_dtype,
-                    quantize_embedding,
-                )
-        else:
-            quantize_model_for_ptq(
-                model, weight_dtype, group_size, activation_dtype, quantize_embedding
-            )
-            if quantize_embedding:
-                assert isinstance(
-                    model.model.embed_tokens.weight, AffineQuantizedTensor
-                ), "Embedding weight should be quantized"
-            for child in list(model.children()):
-                if isinstance(child, torch.nn.Linear):
-                    if activation_dtype:
-                        assert isinstance(
-                            child.weight, LinearActivationQuantizedTensor
-                        ), (
-                            "Linear weight should be quantized with activation quantization"
-                        )
-                    else:
-                        assert isinstance(child.weight, AffineQuantizedTensor), (
-                            "Linear weight should be quantized without activation quantization"
-                        )
+    @require_torch_2_8_0
+    @requires_cuda_ge_8_9
+    def test_convert_qat_model(self, model):
+        config = QATConfig(
+            weight_dtype="int4",
+            activation_dtype="int8",
+            group_size=8,
+            quantize_embedding=True,
+        )
+
+        # quantize model for qat
+        prepare_model_for_qat(
+            model,
+            config.weight_dtype,
+            config.group_size,
+            config.activation_dtype,
+            config.quantize_embedding,
+        )
+
+        assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
+        assert isinstance(model.lm_head, FakeQuantizedLinear)
+
+        # apply conversion
+        convert_qat_model(
+            model,
+            config.quantize_embedding,
+        )
+        # ensure modules have been swapped out
+        assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
+        assert not isinstance(model.lm_head, FakeQuantizedLinear)
+
+        # ensure weights have been quantized
+        assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
+        assert isinstance(model.lm_head.weight, nn.Parameter)


 class TestQuantizationCallback:
@@ -218,10 +309,10 @@ class TestQuantizationCallback:
             global_step=0,
         )

-    @require_torch_2_6_0
+    @require_torch_2_8_0
    def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state):
         cfg = QATConfig(
-            weight_dtype="int8",
+            weight_dtype="int4",
             activation_dtype="int8",
             group_size=8,
             quantize_embedding=True,
@@ -268,10 +359,10 @@ class TestQuantizationCallback:
         assert model.model.embed_tokens.weight_fake_quantizer.enabled
         assert model.lm_head.weight_fake_quantizer.enabled

-    @require_torch_2_6_0
+    @require_torch_2_8_0
     def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
         cfg = QATConfig(
-            weight_dtype="int8",
+            weight_dtype="int4",
             activation_dtype="int8",
             group_size=8,
             quantize_embedding=True,
@@ -304,43 +395,3 @@ class TestQuantizationCallback:
         # quantization should be enabled from the get-go
         assert model.model.embed_tokens.weight_fake_quantizer.enabled
         assert model.lm_head.weight_fake_quantizer.enabled
-
-
-class TestConvertQATModelForPTQ:
-    """
-    Test convert_qat_model_for_ptq
-    """
-
-    @require_torch_2_6_0
-    def test_convert_qat_model_for_ptq(self, model):
-        config = QATConfig(
-            weight_dtype="int8",
-            activation_dtype="int8",
-            group_size=8,
-            quantize_embedding=True,
-        )
-
-        # quantize model for qat
-        prepare_model_for_qat(
-            model,
-            config.weight_dtype,
-            config.group_size,
-            config.activation_dtype,
-            config.quantize_embedding,
-        )
-
-        assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
-        assert isinstance(model.lm_head, FakeQuantizedLinear)
-
-        # apply conversion
-        convert_qat_model_for_ptq(
-            model,
-            quantize_embedding=config.quantize_embedding,
-        )
-        # ensure modules have been swapped out
-        assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
-        assert not isinstance(model.lm_head, FakeQuantizedLinear)
-
-        # ensure weights have been quantized
-        assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
-        assert isinstance(model.lm_head.weight, nn.Parameter)
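
For reference, a minimal sketch of the QAT flow these tests exercise after the migration, using the renamed entry points prepare_model_for_qat and convert_qat_model with the TorchAOQuantDType enum. The positional argument order (model, weight_dtype, group_size, activation_dtype, quantize_embedding) is taken from the call sites in the diff above; the training step itself is elided:

    import torch
    from transformers import AutoModelForCausalLM

    from axolotl.utils.quantization import convert_qat_model, prepare_model_for_qat
    from axolotl.utils.schemas.enums import TorchAOQuantDType

    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2-0.5B", device_map="auto", torch_dtype=torch.bfloat16
    )

    # Swap Linear/Embedding modules for fake-quantized equivalents:
    # int4 weights, int8 activations, weight group size 8, embedding included.
    prepare_model_for_qat(
        model, TorchAOQuantDType.int4, 8, TorchAOQuantDType.int8, True
    )

    # ... fine-tune with fake quantization in the loop ...

    # After training, swap the fake-quantized modules back out; weights
    # become plain nn.Parameter again, ready for real quantization.
    convert_qat_model(model, True)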
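
Post-training quantization then goes through the single quantize_model entry point. A sketch of the int8-activation path and the NVFP4 path added in this commit (per the test markers, the FP8 and NVFP4 paths expect torch >= 2.8 and requires_sm_ge_100-class hardware); each call assumes a freshly loaded, not-yet-quantized model:

    from axolotl.utils.quantization import quantize_model

    # int4 weights with int8 dynamic activations, group size 8,
    # leaving the embedding unquantized.
    quantize_model(model, TorchAOQuantDType.int4, 8, TorchAOQuantDType.int8, False)

    # NVFP4 weights and activations with group size 16; linear weights end up
    # as torchao NVFP4Tensor subclasses carrying act_quant_kwargs.
    quantize_model(model, TorchAOQuantDType.nvfp4, 16, TorchAOQuantDType.nvfp4)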