MX QAT patch (#3553)
* qat patch * tests fixes * fixup per PR code review * use state dict hooks to handle dequant for saving safetensors from transformers * use transformers torch ao quantizer hooks to save mx quantized model --------- Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -5,7 +5,7 @@ CLI to post-training quantize a model using torchao
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
|
from transformers import AutoConfig, AutoModelForCausalLM
|
||||||
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.loaders import load_processor, load_tokenizer
|
from axolotl.loaders import load_processor, load_tokenizer
|
||||||
@@ -93,17 +93,22 @@ def do_quantize(
|
|||||||
weight_dtype, activation_dtype, group_size
|
weight_dtype, activation_dtype, group_size
|
||||||
)
|
)
|
||||||
|
|
||||||
ao_config = TorchAoConfig(
|
|
||||||
quant_type=quantization_config,
|
|
||||||
include_input_output_embeddings=quantize_embedding,
|
|
||||||
)
|
|
||||||
model.config.quantization_config = ao_config
|
|
||||||
|
|
||||||
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
|
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
|
||||||
model.save_pretrained(
|
try:
|
||||||
str(Path(output_dir) / "quantized"),
|
model.save_pretrained(
|
||||||
progressbar=True,
|
str(Path(output_dir) / "quantized"),
|
||||||
)
|
progressbar=True,
|
||||||
|
)
|
||||||
|
except NotImplementedError:
|
||||||
|
LOG.warning(
|
||||||
|
"Model weight conversions do not support reverse_op, "
|
||||||
|
"retrying save with save_original_format=False"
|
||||||
|
)
|
||||||
|
model.save_pretrained(
|
||||||
|
str(Path(output_dir) / "quantized"),
|
||||||
|
progressbar=True,
|
||||||
|
save_original_format=False,
|
||||||
|
)
|
||||||
tokenizer.save_pretrained(
|
tokenizer.save_pretrained(
|
||||||
str(Path(output_dir) / "quantized"),
|
str(Path(output_dir) / "quantized"),
|
||||||
progressbar=True,
|
progressbar=True,
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ Utilities for quantization including QAT and PTQ using torchao.
|
|||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from torchao.core.config import AOBaseConfig
|
from torchao.core.config import AOBaseConfig
|
||||||
from torchao.prototype.qat import MXFakeQuantizeConfig
|
|
||||||
from torchao.quantization import quantize_
|
from torchao.quantization import quantize_
|
||||||
from torchao.quantization.qat import (
|
from torchao.quantization.qat import (
|
||||||
QATConfig,
|
QATConfig,
|
||||||
@@ -15,9 +14,15 @@ from torchao.quantization.quant_api import (
|
|||||||
Float8DynamicActivationFloat8WeightConfig,
|
Float8DynamicActivationFloat8WeightConfig,
|
||||||
Float8DynamicActivationInt4WeightConfig,
|
Float8DynamicActivationInt4WeightConfig,
|
||||||
Int4WeightOnlyConfig,
|
Int4WeightOnlyConfig,
|
||||||
Int8DynamicActivationInt4WeightConfig,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
|
||||||
|
except ImportError:
|
||||||
|
from torchao.quantization.quant_api import (
|
||||||
|
Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
|
||||||
|
)
|
||||||
|
|
||||||
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||||
|
|
||||||
quantization_config_to_str = {
|
quantization_config_to_str = {
|
||||||
@@ -28,7 +33,9 @@ quantization_config_to_str = {
|
|||||||
|
|
||||||
if version.parse(torch.__version__) >= version.parse("2.8.0"):
|
if version.parse(torch.__version__) >= version.parse("2.8.0"):
|
||||||
try:
|
try:
|
||||||
from torchao.prototype.mx_formats import NVFP4InferenceConfig
|
from torchao.prototype.mx_formats import (
|
||||||
|
NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
|
||||||
|
)
|
||||||
|
|
||||||
quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
|
quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
|
||||||
except (ImportError, RuntimeError):
|
except (ImportError, RuntimeError):
|
||||||
@@ -44,10 +51,12 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torchao.prototype.qat import MXFakeQuantizeConfig
|
from torchao.prototype.mx_formats import (
|
||||||
|
MXDynamicActivationMXWeightConfig as MXLinearConfig,
|
||||||
|
)
|
||||||
|
|
||||||
quantization_config_to_str[MXFakeQuantizeConfig] = "mxfp4"
|
quantization_config_to_str[MXLinearConfig] = "mxfp4"
|
||||||
except ImportError:
|
except (ImportError, RuntimeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -114,15 +123,15 @@ def get_quantization_config(
|
|||||||
):
|
):
|
||||||
return Float8DynamicActivationInt4WeightConfig()
|
return Float8DynamicActivationInt4WeightConfig()
|
||||||
if weight_dtype == TorchAOQuantDType.nvfp4:
|
if weight_dtype == TorchAOQuantDType.nvfp4:
|
||||||
from torchao.prototype.mx_formats import NVFP4InferenceConfig
|
from torchao.prototype.mx_formats import (
|
||||||
|
NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
|
||||||
|
)
|
||||||
|
|
||||||
if group_size is not None and group_size != 16:
|
if group_size is not None and group_size != 16:
|
||||||
raise ValueError("NVFP4 quantization must use a group_size of 16")
|
raise ValueError("NVFP4 quantization must use a group_size of 16")
|
||||||
return NVFP4InferenceConfig()
|
return NVFP4InferenceConfig()
|
||||||
|
|
||||||
if weight_dtype == TorchAOQuantDType.mxfp4:
|
if weight_dtype == TorchAOQuantDType.mxfp4:
|
||||||
from torchao.prototype.qat import MXFakeQuantizeConfig
|
|
||||||
|
|
||||||
# MXFP4 uses block_size=32 by default (vs NVFP4's 16)
|
# MXFP4 uses block_size=32 by default (vs NVFP4's 16)
|
||||||
block_size = group_size if group_size is not None else 32
|
block_size = group_size if group_size is not None else 32
|
||||||
if block_size != 32:
|
if block_size != 32:
|
||||||
@@ -130,13 +139,41 @@ def get_quantization_config(
|
|||||||
"MXFP4 quantization must use a block_size (group_size) of 32"
|
"MXFP4 quantization must use a block_size (group_size) of 32"
|
||||||
)
|
)
|
||||||
|
|
||||||
return MXFakeQuantizeConfig(dtype=torch.float4_e2m1fn_x2, block_size=block_size)
|
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
|
||||||
|
|
||||||
|
return MXDynamicActivationMXWeightConfig(
|
||||||
|
activation_dtype=torch.float4_e2m1fn_x2,
|
||||||
|
weight_dtype=torch.float4_e2m1fn_x2,
|
||||||
|
block_size=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
|
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _attach_torchao_quantizer(
|
||||||
|
model, quantization_config, include_input_output_embeddings=False
|
||||||
|
):
|
||||||
|
"""Attach a TorchAoHfQuantizer to the model so save_pretrained uses
|
||||||
|
torchao's flatten_tensor_state_dict path, preserving quantized weights
|
||||||
|
(e.g. MXTensor qdata+scale) in the safetensors file.
|
||||||
|
|
||||||
|
Without this, save_pretrained falls through to the default path which
|
||||||
|
calls safetensors storage_ptr() on tensor subclasses and crashes.
|
||||||
|
"""
|
||||||
|
from transformers import TorchAoConfig
|
||||||
|
from transformers.quantizers.quantizer_torchao import TorchAoHfQuantizer
|
||||||
|
|
||||||
|
ao_config = TorchAoConfig(
|
||||||
|
quant_type=quantization_config,
|
||||||
|
include_input_output_embeddings=include_input_output_embeddings,
|
||||||
|
)
|
||||||
|
model.config.quantization_config = ao_config
|
||||||
|
quantizer = TorchAoHfQuantizer(ao_config)
|
||||||
|
model.hf_quantizer = quantizer
|
||||||
|
|
||||||
|
|
||||||
def quantize_model(
|
def quantize_model(
|
||||||
model,
|
model,
|
||||||
weight_dtype: TorchAOQuantDType,
|
weight_dtype: TorchAOQuantDType,
|
||||||
@@ -174,6 +211,12 @@ def quantize_model(
|
|||||||
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_attach_torchao_quantizer(
|
||||||
|
model,
|
||||||
|
linear_ptq_config,
|
||||||
|
include_input_output_embeddings=bool(quantize_embedding),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _make_qat_config(
|
def _make_qat_config(
|
||||||
base_config: AOBaseConfig,
|
base_config: AOBaseConfig,
|
||||||
@@ -189,11 +232,14 @@ def _make_qat_config(
|
|||||||
IntxFakeQuantizeConfig,
|
IntxFakeQuantizeConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(base_config, MXFakeQuantizeConfig):
|
if weight_dtype == TorchAOQuantDType.mxfp4:
|
||||||
return QATConfig(
|
from torchao.prototype.qat import MXFakeQuantizeConfig
|
||||||
activation_config=base_config,
|
|
||||||
weight_config=base_config,
|
block_size = getattr(base_config, "block_size", 32)
|
||||||
|
mx_fq = MXFakeQuantizeConfig(
|
||||||
|
dtype=torch.float4_e2m1fn_x2, block_size=block_size
|
||||||
)
|
)
|
||||||
|
return QATConfig(activation_config=mx_fq, weight_config=mx_fq)
|
||||||
|
|
||||||
# Build explicit weight config
|
# Build explicit weight config
|
||||||
weight_fq_config: (
|
weight_fq_config: (
|
||||||
|
|||||||
@@ -5,15 +5,20 @@ Tests for axolotl.utils.quantization
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torchao.prototype.qat import MXFakeQuantizeConfig
|
|
||||||
from torchao.quantization import LinearActivationQuantizedTensor
|
from torchao.quantization import LinearActivationQuantizedTensor
|
||||||
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
|
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
|
||||||
from torchao.quantization.qat.linear import FakeQuantizedLinear
|
from torchao.quantization.qat.linear import FakeQuantizedLinear
|
||||||
from torchao.quantization.quant_api import (
|
from torchao.quantization.quant_api import (
|
||||||
Float8DynamicActivationFloat8WeightConfig,
|
Float8DynamicActivationFloat8WeightConfig,
|
||||||
Float8DynamicActivationInt4WeightConfig,
|
Float8DynamicActivationInt4WeightConfig,
|
||||||
Int8DynamicActivationInt4WeightConfig,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
|
||||||
|
except ImportError:
|
||||||
|
from torchao.quantization.quant_api import (
|
||||||
|
Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
|
||||||
|
)
|
||||||
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
|
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
from transformers.trainer_callback import TrainerState
|
from transformers.trainer_callback import TrainerState
|
||||||
@@ -129,8 +134,11 @@ class TestQuantization:
|
|||||||
@require_torch_2_8_0
|
@require_torch_2_8_0
|
||||||
@requires_sm_ge_100
|
@requires_sm_ge_100
|
||||||
def test_get_ptq_config_mxfp4(self):
|
def test_get_ptq_config_mxfp4(self):
|
||||||
|
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
|
||||||
|
|
||||||
config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)
|
config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)
|
||||||
assert isinstance(config, MXFakeQuantizeConfig)
|
assert isinstance(config, MXDynamicActivationMXWeightConfig)
|
||||||
|
assert config.weight_dtype == torch.float4_e2m1fn_x2
|
||||||
assert config.block_size == 32
|
assert config.block_size == 32
|
||||||
|
|
||||||
@require_torch_2_8_0
|
@require_torch_2_8_0
|
||||||
@@ -298,7 +306,6 @@ class TestQuantization:
|
|||||||
"weight_dtype,activation_dtype,group_size,quantize_embedding",
|
"weight_dtype,activation_dtype,group_size,quantize_embedding",
|
||||||
[
|
[
|
||||||
(TorchAOQuantDType.mxfp4, None, 32, False),
|
(TorchAOQuantDType.mxfp4, None, 32, False),
|
||||||
(TorchAOQuantDType.mxfp4, None, 32, True),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@require_torch_2_8_0
|
@require_torch_2_8_0
|
||||||
@@ -314,14 +321,16 @@ class TestQuantization:
|
|||||||
quantize_embedding,
|
quantize_embedding,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from torchao.prototype.qat import MXFakeQuantizedLinear
|
||||||
|
|
||||||
if quantize_embedding:
|
if quantize_embedding:
|
||||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||||
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
|
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
|
||||||
|
|
||||||
for child in list(model.children()):
|
for child in list(model.children()):
|
||||||
if isinstance(child, torch.nn.Linear):
|
if isinstance(child, torch.nn.Linear):
|
||||||
assert isinstance(child, FakeQuantizedLinear)
|
assert isinstance(child, MXFakeQuantizedLinear)
|
||||||
assert hasattr(child, "weight_fake_quantizer")
|
assert hasattr(child, "weight_config")
|
||||||
|
|
||||||
@require_torch_2_8_0
|
@require_torch_2_8_0
|
||||||
@requires_cuda_ge_8_9
|
@requires_cuda_ge_8_9
|
||||||
|
|||||||
Reference in New Issue
Block a user