Migrate QAT API; fix axolotl quantize for QAT-ed models; add NVFP4 (#3107)

This commit is contained in:
salman
2025-09-12 10:55:50 +01:00
committed by GitHub
parent 0401a15888
commit 58d67bf98d
16 changed files with 554 additions and 339 deletions

View File

@@ -44,7 +44,7 @@ jobs:
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]

View File

@@ -304,7 +304,7 @@ jobs:
pytorch: 2.8.0
num_gpus: 1
gpu_type: "B200"
axolotl_extras:
axolotl_extras: fbgemm-gpu
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -51,3 +51,11 @@ axolotl quantize qat.yml
```
This ensures that an identical quantization configuration is used to quantize the model as was used to train it.
::: {.callout-note}
If you have configured pushing to hub with `hub_model_id`, your model hub name will have the quantization schema appended to it,
e.g. `axolotl-ai-cloud/qat-nvfp4-llama3B` will become `axolotl-ai-cloud/qat-nvfp4-llama3B-nvfp4w`
:::

View File

@@ -0,0 +1,64 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
split: train[:95%]
output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/dataset_prepared
sequence_len: 8192
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_checkpointing: true
gradient_accumulation_steps: 1
micro_batch_size: 64
num_epochs: 1
optimizer: adamw_torch_fused
cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -15,20 +15,18 @@ liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
split: train[:95%]
output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/qat_out/dataset_prepared
sample_packing: true
sequence_len: 512
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
sample_packing: false
sequence_len: 8192
flash_attention: true
qat:
activation_dtype: int8
@@ -67,7 +65,7 @@ fsdp:
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_cpu_ram_efficient_loading: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
@@ -76,6 +74,6 @@ fsdp_config:
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|end_of_text|>
pad_token: <|finetune_right_pad_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,7 +64,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.12.0
torchao==0.13.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6

View File

@@ -162,6 +162,7 @@ extras_require = {
"llmcompressor": [
"llmcompressor==0.5.1",
],
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
}
install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require

View File

@@ -115,6 +115,7 @@ class QuantizeCliArgs:
quantize_embedding: Optional[bool] = field(default=None)
group_size: Optional[int] = field(default=None)
output_dir: Optional[str] = field(default=None)
hub_model_id: Optional[str] = field(default=None)
@dataclass

View File

@@ -5,12 +5,17 @@ CLI to post-training quantize a model using torchao
from pathlib import Path
from typing import Union
from transformers import AutoModelForCausalLM
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
from axolotl.utils.quantization import (
TorchAOQuantDType,
get_quantization_config,
quantization_config_to_str,
quantize_model,
)
LOG = get_logger(__name__)
@@ -43,13 +48,13 @@ def do_quantize(
"No quantization configuration found. Please specify either qat or quantization in your config file."
)
model_path = cli_args.get("model_path") or cfg.output_dir
model_path = cli_args.get("base_model") or cfg.output_dir
if weight_dtype := cli_args.get("weight_dtype"):
weight_dtype = TorchIntDType[weight_dtype]
weight_dtype = TorchAOQuantDType.from_string(weight_dtype)
else:
weight_dtype = quantize_cfg.weight_dtype
if activation_dtype := cli_args.get("activation_dtype"):
activation_dtype = TorchIntDType[activation_dtype]
activation_dtype = TorchAOQuantDType.from_string(activation_dtype)
else:
activation_dtype = quantize_cfg.activation_dtype
group_size = cli_args.get("group_size") or quantize_cfg.group_size
@@ -57,10 +62,15 @@ def do_quantize(
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
)
output_dir = cli_args.get("output_dir") or cfg.output_dir
hub_model_id = cli_args.get("hub_model_id") or cfg.hub_model_id
LOG.info(f"Loading model from {model_path}...")
LOG.info(f"Loading model from {model_path}.")
tokenizer = load_tokenizer(cfg)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
config = AutoConfig.from_pretrained(model_path)
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", torch_dtype=torch_dtype
)
LOG.info(
f"Quantizing model with configuration: \n"
@@ -70,11 +80,21 @@ def do_quantize(
f"\tquantize_embedding: {quantize_embedding}"
)
quantize_model_for_ptq(
quantize_model(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
quantization_config = get_quantization_config(
weight_dtype, activation_dtype, group_size
)
ao_config = TorchAoConfig(
quant_type=quantization_config,
include_input_output_embeddings=quantize_embedding,
)
model.config.quantization_config = ao_config
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
@@ -86,4 +106,14 @@ def do_quantize(
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")
if hub_model_id:
hub_model_id = (
hub_model_id.rstrip("-")
+ f"-{quantization_config_to_str[type(quantization_config)]}"
)
model.push_to_hub(hub_model_id, safe_serialization=False)
tokenizer.push_to_hub(hub_model_id)
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")

View File

@@ -30,11 +30,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.integrations.base import PluginManager
from axolotl.loaders import (
ModelLoader,
load_processor,
load_tokenizer,
)
from axolotl.loaders import ModelLoader, load_processor, load_tokenizer
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
@@ -234,16 +230,15 @@ def save_trained_model(
# handle QAT
if cfg.qat:
from axolotl.utils.quantization import convert_qat_model_for_ptq
from axolotl.utils.quantization import convert_qat_model
LOG.info("Processing QAT model for saving...")
convert_qat_model_for_ptq(
convert_qat_model(
model,
quantize_embedding=cfg.qat.quantize_embedding,
)
LOG.info(
"QAT modules have been converted for PTQ. Please ensure you quantize "
"your model weights with `axolotl quantize`."
"QAT usage note: please ensure you quantize your model fine-tuned using QAT by running `axolotl quantize`"
" with the same config which you used for training."
)
# Handle ReLoRA early return case
if cfg.relora:
@@ -337,9 +332,7 @@ def save_trained_model(
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin
from axolotl.integrations.llm_compressor.utils import (
save_compressed_model,
)
from axolotl.integrations.llm_compressor.utils import save_compressed_model
save_compressed_model(
model=model,

View File

@@ -3,30 +3,47 @@ Utilities for quantization including QAT and PTQ using torchao.
"""
import torch
from torch import nn
from packaging import version
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.qat import (
FakeQuantizeConfig,
FromIntXQuantizationAwareTrainingConfig,
IntXQuantizationAwareTrainingConfig,
QATConfig,
)
from torchao.quantization.quant_api import (
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
UIntXWeightOnlyConfig,
_is_linear,
)
from axolotl.utils.schemas.enums import TorchIntDType
from axolotl.utils.schemas.enums import TorchAOQuantDType
quantization_config_to_str = {
Int8DynamicActivationInt4WeightConfig: "int8int4",
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
Float8DynamicActivationInt4WeightConfig: "fp8int4",
}
if version.parse(torch.__version__) >= version.parse("2.8.0"):
try:
from torchao.prototype.mx_formats import NVFP4InferenceConfig
quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
except:
pass
# int4 weight config imports will fail on machines with fbgemm-gpu installed
# without a CUDA runtime available so we do this safely
try:
from torchao.quantization.quant_api import Int4WeightOnlyConfig
quantization_config_to_str[Int4WeightOnlyConfig] = "int4"
except:
pass
def get_ptq_config(
weight_dtype: TorchIntDType,
activation_dtype: TorchIntDType | None = None,
def get_quantization_config(
weight_dtype: TorchAOQuantDType,
activation_dtype: TorchAOQuantDType | None = None,
group_size: int | None = None,
) -> AOBaseConfig:
"""
@@ -45,44 +62,101 @@ def get_ptq_config(
or if the group size is not specified for int8 or int4 weight only quantization.
"""
if activation_dtype is None:
if not weight_dtype.value.is_signed: # type: ignore[attr-defined,union-attr]
return UIntXWeightOnlyConfig(
dtype=weight_dtype.value,
group_size=group_size,
set_inductor_config=False,
)
if weight_dtype == TorchIntDType.int8:
if group_size is None:
raise ValueError(
"group_size must be specified for int8 weight only quantization"
)
return Int8WeightOnlyConfig(
group_size=group_size,
)
if weight_dtype == TorchIntDType.int4:
if group_size is None:
raise ValueError(
"group_size must be specified for int4 weight only quantization"
)
return Int4WeightOnlyConfig(
group_size=group_size,
)
if activation_dtype == TorchIntDType.int4 and weight_dtype == TorchIntDType.int4:
return Int4DynamicActivationInt4WeightConfig()
if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int8:
return Int8DynamicActivationInt8WeightConfig()
if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int4:
return Int8DynamicActivationInt4WeightConfig()
if weight_dtype == TorchAOQuantDType.int8:
raise ValueError("Int8WeightOnlyConfig is not supported by torchao QAT.")
if weight_dtype == TorchAOQuantDType.int4:
from torchao.quantization.quant_api import Int4WeightOnlyConfig
if group_size is not None:
return Int4WeightOnlyConfig(group_size=group_size, version=2)
else:
return Int4WeightOnlyConfig(version=2)
if (
activation_dtype == TorchAOQuantDType.int4
and weight_dtype == TorchAOQuantDType.int4
):
raise ValueError(
"Int4DynamicActivationInt4WeightConfig is not supported by torchao QAT."
)
if (
activation_dtype == TorchAOQuantDType.int8
and weight_dtype == TorchAOQuantDType.int8
):
raise ValueError(
"Int8DynamicActivationInt8WeightConfig is not supported by torchao QAT."
)
if (
activation_dtype == TorchAOQuantDType.int8
and weight_dtype == TorchAOQuantDType.int4
):
if group_size is not None:
return Int8DynamicActivationInt4WeightConfig(group_size=group_size)
else:
return Int8DynamicActivationInt4WeightConfig()
if (
activation_dtype == TorchAOQuantDType.float8_e4m3fn
and weight_dtype == TorchAOQuantDType.float8_e4m3fn
):
return Float8DynamicActivationFloat8WeightConfig()
if (
activation_dtype == TorchAOQuantDType.float8_e4m3fn
and weight_dtype == TorchAOQuantDType.int4
):
return Float8DynamicActivationInt4WeightConfig()
if weight_dtype == TorchAOQuantDType.nvfp4:
from torchao.prototype.mx_formats import NVFP4InferenceConfig
if group_size is not None and group_size != 16:
raise ValueError("NVFP4 quantization must use a group_size of 16")
return NVFP4InferenceConfig()
raise ValueError(
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
)
def quantize_model(
model,
weight_dtype: TorchAOQuantDType,
group_size: int | None = None,
activation_dtype: TorchAOQuantDType | None = None,
quantize_embedding: bool | None = None,
):
"""
This function is used to quantize a model.
Args:
model: The model to quantize.
weight_dtype: The dtype to use for weight quantization.
group_size: The group size to use for weight quantization.
activation_dtype: The dtype to use for activation quantization.
quantize_embedding: Whether to quantize the model's embedding weights.
"""
linear_ptq_config = get_quantization_config(
weight_dtype=weight_dtype,
activation_dtype=activation_dtype,
group_size=group_size,
)
quantize_(model, linear_ptq_config)
if quantize_embedding:
# activation fake quantization is not supported for embedding layers
embedding_quantize_config = get_quantization_config(
weight_dtype=weight_dtype,
activation_dtype=None,
group_size=group_size,
)
quantize_(
model,
embedding_quantize_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
def prepare_model_for_qat(
model,
weight_dtype: TorchIntDType,
group_size: int,
activation_dtype: TorchIntDType | None = None,
weight_dtype: TorchAOQuantDType,
group_size: int | None = None,
activation_dtype: TorchAOQuantDType | None = None,
quantize_embedding: bool = False,
):
"""
@@ -100,86 +174,40 @@ def prepare_model_for_qat(
Raises:
ValueError: If the activation/weight dtype combination is invalid.
"""
if activation_dtype:
activation_config = FakeQuantizeConfig(
dtype=activation_dtype.value, granularity="per_token", is_symmetric=False
)
weight_config = FakeQuantizeConfig(dtype=weight_dtype.value, group_size=group_size)
linear_quantize_config = IntXQuantizationAwareTrainingConfig(
activation_config=None if activation_dtype is None else activation_config,
weight_config=weight_config,
)
quantize_(model, linear_quantize_config)
if quantize_embedding:
# activation fake quantization is not supported for embedding layers
embedding_quantize_config = IntXQuantizationAwareTrainingConfig(
activation_config=None,
weight_config=weight_config,
)
quantize_(
model,
embedding_quantize_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
def quantize_model_for_ptq(
model,
weight_dtype: TorchIntDType,
group_size: int | None = None,
activation_dtype: TorchIntDType | None = None,
quantize_embedding: bool | None = None,
):
"""
This function is used to quantize a model for post-training quantization.
It swaps the model's linear layers with fake quantized linear layers.
If `quantize_embedding` is True, it will also swap the model's embedding weights with fake quantized embedding weights.
Args:
model: The model to quantize.
weight_dtype: The dtype to use for weight quantization.
group_size: The group size to use for weight quantization.
activation_dtype: The dtype to use for activation quantization.
quantize_embedding: Whether to quantize the model's embedding weights.
"""
linear_ptq_config = get_ptq_config(
base_config = get_quantization_config(
weight_dtype=weight_dtype,
activation_dtype=activation_dtype,
group_size=group_size,
)
quantize_(model, linear_ptq_config)
qat_config = QATConfig(base_config)
quantize_(model, qat_config)
if quantize_embedding:
embedding_quantize_config = get_ptq_config(
# activation fake quantization is not supported for embedding layers
embedding_base_config = get_quantization_config(
weight_dtype=weight_dtype,
activation_dtype=None,
group_size=group_size,
)
embedding_qat_config = QATConfig(embedding_base_config)
quantize_(
model,
embedding_quantize_config,
embedding_qat_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
def convert_qat_model_for_ptq(
def convert_qat_model(
model,
*,
quantize_embedding: bool | None = None,
quantize_embedding: bool = False,
):
"""
This function is used to convert a swap fake-quantized modules in a model
which has been trained with QAT back to the original modules, ready for PTQ.
Args:
model: The model to convert.
quantize_embedding: Whether to quantize the model's embedding weights.
This function converts a QAT model which has fake quantized layers back to the original model.
"""
config = QATConfig(step="convert")
quantize_(model, config)
if quantize_embedding:
def filter_fn(m, _):
return isinstance(m, nn.Embedding) or _is_linear(m)
else:
filter_fn = _is_linear
quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
quantize_(
model,
config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)

View File

@@ -5,18 +5,21 @@ from enum import Enum
import torch
class TorchIntDType(Enum):
"""Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4"""
class TorchAOQuantDType(Enum):
int4 = torch.int4
int8 = torch.int8
float8_e4m3fn = torch.float8_e4m3fn
nvfp4 = "nvfp4"
uint1 = getattr(torch, "uint1", None)
uint2 = getattr(torch, "uint2", None)
uint3 = getattr(torch, "uint3", None)
uint4 = getattr(torch, "uint4", None)
uint5 = getattr(torch, "uint5", None)
uint6 = getattr(torch, "uint6", None)
uint7 = getattr(torch, "uint7", None)
int4 = getattr(torch, "int4", None)
int8 = getattr(torch, "int8", None)
def from_string(str):
if str == "int4":
return TorchAOQuantDType.int4
if str == "int8":
return TorchAOQuantDType.int8
if str in ["float8_e4m3fn", "fp8", "float8"]:
return TorchAOQuantDType.float8_e4m3fn
if str == "nvfp4":
return TorchAOQuantDType.nvfp4
class RLType(str, Enum):

View File

@@ -6,7 +6,23 @@ from typing import Any
from pydantic import BaseModel, Field, field_validator
from axolotl.utils.schemas.enums import TorchIntDType
from axolotl.utils.schemas.enums import TorchAOQuantDType
def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
if v is None:
return None
if v == "int4":
return TorchAOQuantDType.int4
if v == "int8":
return TorchAOQuantDType.int8
if v in ["float8_e4m3fn", "fp8", "float8"]:
return TorchAOQuantDType.float8_e4m3fn
if v == "nvfp4":
return TorchAOQuantDType.nvfp4
raise ValueError(
f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}"
)
class QATConfig(BaseModel):
@@ -14,13 +30,13 @@ class QATConfig(BaseModel):
QAT Config Schema
"""
activation_dtype: TorchIntDType | None = Field(
activation_dtype: TorchAOQuantDType | None = Field(
default=None,
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
description="Fake quantization layout to use for activation quantization.",
)
weight_dtype: TorchIntDType = Field(
default=TorchIntDType.int8,
description='Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"',
weight_dtype: TorchAOQuantDType = Field(
default=TorchAOQuantDType.int8,
description="Fake quantization layout to use for weight quantization.",
)
quantize_embedding: bool | None = Field(
default=False, description="Quantize embedding"
@@ -35,12 +51,8 @@ class QATConfig(BaseModel):
@field_validator("activation_dtype", "weight_dtype", mode="before")
@classmethod
def validate_dtype(cls, v: Any) -> TorchIntDType | None:
if v == "int4":
return TorchIntDType.int4
if v == "int8":
return TorchIntDType.int8
raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")
def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None:
return validate_ao_dtype(v)
class PTQConfig(BaseModel):
@@ -48,13 +60,13 @@ class PTQConfig(BaseModel):
PTQ Config Schema
"""
weight_dtype: TorchIntDType = Field(
default=TorchIntDType.int8,
description="Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8",
weight_dtype: TorchAOQuantDType = Field(
default=TorchAOQuantDType.int8,
description="Fake quantization layout to use for weight quantization.",
)
activation_dtype: TorchIntDType | None = Field(
activation_dtype: TorchAOQuantDType | None = Field(
default=None,
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
description="Fake quantization layout to use for activation quantization.",
)
quantize_embedding: bool | None = Field(
default=None, description="Whether to quantize the embedding layer."
@@ -66,9 +78,5 @@ class PTQConfig(BaseModel):
@field_validator("activation_dtype", "weight_dtype", mode="before")
@classmethod
def validate_dtype(cls, v: Any) -> TorchIntDType | None:
if v == "int4":
return TorchIntDType.int4
if v == "int8":
return TorchIntDType.int8
raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")
def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None:
return validate_ao_dtype(v)

View File

@@ -43,7 +43,7 @@ class TestQATLlama:
"qat": {
"quantize_embedding": True,
"activation_dtype": "int8",
"weight_dtype": "int8",
"weight_dtype": "int4",
"group_size": 8,
},
"num_epochs": 1,
@@ -111,7 +111,7 @@ class TestQATLlama:
"qat": {
"quantize_embedding": True,
"activation_dtype": "int8",
"weight_dtype": "int8",
"weight_dtype": "int4",
"group_size": 8,
},
"save_first_step": False,

View File

@@ -5,41 +5,40 @@ Tests for axolotl.utils.quantization
import pytest
import torch
from torch import nn
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.linear_activation_quantized_tensor import (
LinearActivationQuantizedTensor,
)
from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
from torchao.quantization.quant_api import (
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
UIntXWeightOnlyConfig,
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
from transformers import AutoModelForCausalLM
from transformers.trainer_callback import TrainerState
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.quantization import (
convert_qat_model_for_ptq,
get_ptq_config,
convert_qat_model,
get_quantization_config,
prepare_model_for_qat,
quantize_model_for_ptq,
quantize_model,
)
from axolotl.utils.schemas.enums import TorchIntDType
from axolotl.utils.schemas.enums import TorchAOQuantDType
from axolotl.utils.schemas.quantization import QATConfig
from tests.e2e.utils import require_torch_2_6_0
from tests.e2e.utils import (
require_torch_2_8_0,
requires_cuda_ge_8_9,
requires_sm_ge_100,
)
@pytest.fixture()
def model():
dummy_model = AutoModelForCausalLM.from_pretrained(
"HuggingFaceTB/SmolLM2-135M",
device_map="cuda",
"Qwen/Qwen2-0.5B",
device_map="auto",
torch_dtype=torch.bfloat16,
)
with torch.device(dummy_model.device):
@@ -48,45 +47,56 @@ def model():
dummy_model.model.embed_tokens.weight.shape[1],
dtype=dummy_model.model.embed_tokens.weight.dtype,
)
return dummy_model
yield dummy_model
del dummy_model
ptq_config_test_cases = [
# weight_dtype, activation_dtype, group_size, expected_type, expected_params
# weight_dtype, activation_dtype, group_size, expected_type
(
TorchIntDType.uint4,
TorchAOQuantDType.int4,
TorchAOQuantDType.int8,
None,
None,
UIntXWeightOnlyConfig,
{"dtype": torch.uint4, "group_size": None},
),
(TorchIntDType.int8, None, 32, Int8WeightOnlyConfig, {"group_size": 32}),
(TorchIntDType.int4, None, 4, Int4WeightOnlyConfig, {"group_size": 4}),
(
TorchIntDType.int4,
TorchIntDType.int4,
None,
Int4DynamicActivationInt4WeightConfig,
{},
Int8DynamicActivationInt4WeightConfig,
),
(
TorchIntDType.int8,
TorchIntDType.int8,
TorchAOQuantDType.float8_e4m3fn,
TorchAOQuantDType.float8_e4m3fn,
None,
Int8DynamicActivationInt8WeightConfig,
{},
Float8DynamicActivationFloat8WeightConfig,
),
(
TorchAOQuantDType.int4,
TorchAOQuantDType.float8_e4m3fn,
None,
Float8DynamicActivationInt4WeightConfig,
),
]
ptq_test_cases = [
# weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception
(TorchIntDType.int8, None, 8, False, None),
(TorchIntDType.int4, None, 4, True, None),
(TorchIntDType.uint4, None, 8, False, None),
(TorchIntDType.int4, TorchIntDType.int4, 8, False, None),
(TorchIntDType.int8, TorchIntDType.int8, 8, True, None),
(TorchIntDType.int8, None, None, False, ValueError),
(TorchIntDType.int4, None, None, False, ValueError),
# weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception, expected_tensor_class
(TorchAOQuantDType.int4, None, 4, True, None, Int4Tensor),
(
TorchAOQuantDType.int4,
TorchAOQuantDType.int8,
8,
False,
None,
LinearActivationQuantizedTensor,
),
# (
# TorchAOQuantDType.int4,
# TorchAOQuantDType.float8_e4m3fn,
# None,
# False,
# None,
# Int4Tensor,
# ),
(TorchAOQuantDType.int4, None, None, False, None, Int4Tensor),
# Deprecated configs
(TorchAOQuantDType.int8, None, 8, False, ValueError, None),
(TorchAOQuantDType.int4, TorchAOQuantDType.int4, 8, False, ValueError, None),
(TorchAOQuantDType.int8, TorchAOQuantDType.int8, 8, True, ValueError, None),
]
@@ -96,44 +106,132 @@ class TestQuantization:
"""
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,expected_type,expected_params",
"weight_dtype,activation_dtype,group_size,expected_type",
ptq_config_test_cases,
)
@require_torch_2_6_0
@requires_cuda_ge_8_9
@require_torch_2_8_0
def test_get_ptq_config(
self, weight_dtype, activation_dtype, group_size, expected_type, expected_params
self, weight_dtype, activation_dtype, group_size, expected_type
):
config = get_ptq_config(weight_dtype, activation_dtype, group_size)
config = get_quantization_config(weight_dtype, activation_dtype, group_size)
assert isinstance(config, expected_type)
for param_name, param_value in expected_params.items():
if isinstance(param_value, (PerAxis, PerGroup)):
if isinstance(param_value, PerAxis):
assert isinstance(getattr(config, param_name), PerAxis)
assert getattr(config, param_name).axis == param_value.axis
else:
assert isinstance(getattr(config, param_name), PerGroup)
assert (
getattr(config, param_name).group_size == param_value.group_size
)
else:
assert getattr(config, param_name) == param_value
@requires_cuda_ge_8_9
@require_torch_2_8_0
def test_get_ptq_config_int4_weight_only(self):
from torchao.quantization.quant_api import Int4WeightOnlyConfig
config = get_quantization_config(TorchAOQuantDType.int4, None, 4)
assert isinstance(config, Int4WeightOnlyConfig)
@pytest.mark.parametrize(
"weight_dtype", [TorchIntDType.int8, TorchIntDType.int4, TorchIntDType.uint4]
"weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception,expected_tensor_class",
ptq_test_cases,
)
@requires_cuda_ge_8_9
@require_torch_2_8_0
def test_quantize_model_for_ptq(
self,
model,
weight_dtype,
activation_dtype,
group_size,
quantize_embedding,
expected_exception,
expected_tensor_class,
):
if expected_exception:
with pytest.raises(expected_exception):
quantize_model(
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
)
else:
quantize_model(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
if quantize_embedding:
assert isinstance(
model.model.embed_tokens.weight, expected_tensor_class
), "Embedding weight should be quantized"
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child.weight, expected_tensor_class)
@require_torch_2_8_0
@requires_sm_ge_100
def test_quantize_model_for_ptq_fp8(
self,
model,
):
from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
Float8Tensor,
QuantizeTensorToFloat8Kwargs,
)
quantize_model(
model,
TorchAOQuantDType.float8_e4m3fn,
None,
TorchAOQuantDType.float8_e4m3fn,
)
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child.weight, Float8Tensor)
assert child.weight.act_quant_kwargs is not None and isinstance(
child.weight.act_quant_kwargs, QuantizeTensorToFloat8Kwargs
)
@require_torch_2_8_0
@requires_sm_ge_100
def test_quantize_model_for_ptq_nvfp4(
self,
model,
):
from torchao.prototype.mx_formats.nvfp4_tensor import (
NVFP4Tensor,
QuantizeTensorToNVFP4Kwargs,
)
quantize_model(model, TorchAOQuantDType.nvfp4, 16, TorchAOQuantDType.nvfp4)
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child.weight, NVFP4Tensor)
assert child.weight.act_quant_kwargs is not None and isinstance(
child.weight.act_quant_kwargs, QuantizeTensorToNVFP4Kwargs
)
@pytest.mark.parametrize(
"activation_dtype", [None, TorchIntDType.int4, TorchIntDType.int8]
"weight_dtype,activation_dtype,group_size,quantize_embedding",
[
(TorchAOQuantDType.int4, None, 8, False),
(TorchAOQuantDType.int4, None, 16, True),
(TorchAOQuantDType.int4, TorchAOQuantDType.int8, 8, False),
(TorchAOQuantDType.int4, TorchAOQuantDType.int8, 16, True),
(
TorchAOQuantDType.float8_e4m3fn,
TorchAOQuantDType.float8_e4m3fn,
None,
False,
),
(TorchAOQuantDType.int4, TorchAOQuantDType.float8_e4m3fn, None, True),
],
)
@pytest.mark.parametrize("group_size", [4, 8])
@pytest.mark.parametrize("quantize_embedding", [False, True])
@require_torch_2_6_0
@require_torch_2_8_0
@requires_cuda_ge_8_9
def test_prepare_model_for_qat(
self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
):
prepare_model_for_qat(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
)
if quantize_embedding:
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
@@ -142,17 +240,19 @@ class TestQuantization:
model.model.embed_tokens.weight_fake_quantizer.config.dtype
== weight_dtype.value
)
assert (
model.model.embed_tokens.weight_fake_quantizer.config.group_size
== group_size
)
if group_size:
assert (
model.model.embed_tokens.weight_fake_quantizer.config.group_size
== group_size
)
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child, FakeQuantizedLinear)
assert hasattr(child, "weight_fake_quantizer")
assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
assert child.weight_fake_quantizer.config.group_size == group_size
if group_size:
assert child.weight_fake_quantizer.config.group_size == group_size
if activation_dtype:
assert hasattr(child, "activation_fake_quantizer")
assert (
@@ -162,49 +262,40 @@ class TestQuantization:
else:
assert child.activation_fake_quantizer is None
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception",
ptq_test_cases,
)
@require_torch_2_6_0
def test_quantize_model_for_ptq(
self,
model,
weight_dtype,
activation_dtype,
group_size,
quantize_embedding,
expected_exception,
):
if expected_exception:
with pytest.raises(expected_exception):
quantize_model_for_ptq(
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
)
else:
quantize_model_for_ptq(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
if quantize_embedding:
assert isinstance(
model.model.embed_tokens.weight, AffineQuantizedTensor
), "Embedding weight should be quantized"
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
if activation_dtype:
assert isinstance(
child.weight, LinearActivationQuantizedTensor
), (
"Linear weight should be quantized with activation quantization"
)
else:
assert isinstance(child.weight, AffineQuantizedTensor), (
"Linear weight should be quantized without activation quantization"
)
@require_torch_2_8_0
@requires_cuda_ge_8_9
def test_convert_qat_model(self, model):
config = QATConfig(
weight_dtype="int4",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
)
# quantize model for qat
prepare_model_for_qat(
model,
config.weight_dtype,
config.group_size,
config.activation_dtype,
config.quantize_embedding,
)
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert isinstance(model.lm_head, FakeQuantizedLinear)
# apply conversion
convert_qat_model(
model,
config.quantize_embedding,
)
# ensure modules have been swapped out
assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert not isinstance(model.lm_head, FakeQuantizedLinear)
# ensure weights have been quantized
assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
assert isinstance(model.lm_head.weight, nn.Parameter)
class TestQuantizationCallback:
@@ -218,10 +309,10 @@ class TestQuantizationCallback:
global_step=0,
)
@require_torch_2_6_0
@require_torch_2_8_0
def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state):
cfg = QATConfig(
weight_dtype="int8",
weight_dtype="int4",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
@@ -268,10 +359,10 @@ class TestQuantizationCallback:
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
@require_torch_2_6_0
@require_torch_2_8_0
def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
cfg = QATConfig(
weight_dtype="int8",
weight_dtype="int4",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
@@ -304,43 +395,3 @@ class TestQuantizationCallback:
# quantization should be enabled from the get-go
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
class TestConvertQATModelForPTQ:
"""
Test convert_qat_model_for_ptq
"""
@require_torch_2_6_0
def test_convert_qat_model_for_ptq(self, model):
config = QATConfig(
weight_dtype="int8",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
)
# quantize model for qat
prepare_model_for_qat(
model,
config.weight_dtype,
config.group_size,
config.activation_dtype,
config.quantize_embedding,
)
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert isinstance(model.lm_head, FakeQuantizedLinear)
# apply conversion
convert_qat_model_for_ptq(
model,
quantize_embedding=config.quantize_embedding,
)
# ensure modules have been swapped out
assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert not isinstance(model.lm_head, FakeQuantizedLinear)
# ensure weights have been quantized
assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
assert isinstance(model.lm_head.weight, nn.Parameter)

View File

@@ -90,6 +90,18 @@ def require_torch_2_7_0(test_case):
return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case)
def require_torch_2_8_0(test_case):
"""
Decorator marking a test that requires torch >= 2.7.0
"""
def is_min_2_8_0():
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.8.0")
return unittest.skipUnless(is_min_2_8_0(), "test requires torch>=2.8.0")(test_case)
def require_torch_lt_2_6_0(test_case):
"""
Decorator marking a test that requires torch < 2.6.0
@@ -128,6 +140,24 @@ def require_llmcompressor(test_case):
)(test_case)
def requires_sm_ge_100(test_case):
is_sm_ge_100 = (
torch.cuda.is_available()
and torch.version.cuda
and torch.cuda.get_device_capability() >= (10, 0)
)
return unittest.skipUnless(is_sm_ge_100, "test requires sm>=100")(test_case)
def requires_cuda_ge_8_9(test_case):
is_cuda_ge_8_9 = (
torch.cuda.is_available()
and torch.version.cuda
and torch.cuda.get_device_capability() >= (8, 9)
)
return unittest.skipUnless(is_cuda_ge_8_9, "test requires cuda>=8.9")(test_case)
def is_hopper():
compute_capability = torch.cuda.get_device_capability()
return compute_capability == (9, 0)