Migrate QAT API; fix axolotl quantize for QAT-ed models; add NVFP4 (#3107)

This commit is contained in:
salman
2025-09-12 10:55:50 +01:00
committed by GitHub
parent 0401a15888
commit 58d67bf98d
16 changed files with 554 additions and 339 deletions

View File

@@ -44,7 +44,7 @@ jobs:
cuda_version: 12.8.1 cuda_version: 12.8.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.8.0 pytorch: 2.8.0
axolotl_extras: axolotl_extras: fbgemm-gpu
num_gpus: 2 num_gpus: 2
nightly_build: "true" nightly_build: "true"
runs-on: [self-hosted, modal] runs-on: [self-hosted, modal]

View File

@@ -304,7 +304,7 @@ jobs:
pytorch: 2.8.0 pytorch: 2.8.0
num_gpus: 1 num_gpus: 1
gpu_type: "B200" gpu_type: "B200"
axolotl_extras: axolotl_extras: fbgemm-gpu
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4

View File

@@ -51,3 +51,11 @@ axolotl quantize qat.yml
``` ```
This ensures that an identical quantization configuration is used to quantize the model as was used to train it. This ensures that an identical quantization configuration is used to quantize the model as was used to train it.
::: {.callout-note}
If you have configured pushing to hub with `hub_model_id`, your model hub name will have the quantization schema appended to it,
e.g. `axolotl-ai-cloud/qat-nvfp4-llama3B` will become `axolotl-ai-cloud/qat-nvfp4-llama3B-nvfp4w`
:::

View File

@@ -0,0 +1,64 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
split: train[:95%]
output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/dataset_prepared
sequence_len: 8192
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_checkpointing: true
gradient_accumulation_steps: 1
micro_batch_size: 64
num_epochs: 1
optimizer: adamw_torch_fused
cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -15,20 +15,18 @@ liger_glu_activation: true
liger_layer_norm: true liger_layer_norm: true
liger_fused_linear_cross_entropy: true liger_fused_linear_cross_entropy: true
datasets: datasets:
- path: yahma/alpaca-cleaned - path: yahma/alpaca-cleaned
type: alpaca type: alpaca
split: train[:95%]
output_dir: ./outputs/qat_out/ output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/qat_out/dataset_prepared
sample_packing: true sample_packing: false
sequence_len: 8192
sequence_len: 512 flash_attention: true
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
qat: qat:
activation_dtype: int8 activation_dtype: int8
@@ -67,7 +65,7 @@ fsdp:
fsdp_config: fsdp_config:
fsdp_version: 2 fsdp_version: 2
fsdp_offload_params: false fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true fsdp_cpu_ram_efficient_loading: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
@@ -76,6 +74,6 @@ fsdp_config:
fsdp_activation_checkpointing: true fsdp_activation_checkpointing: true
special_tokens: special_tokens:
pad_token: <|end_of_text|> pad_token: <|finetune_right_pad_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config # save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,7 +64,7 @@ langdetect==1.0.9
immutabledict==4.2.0 immutabledict==4.2.0
antlr4-python3-runtime==4.13.2 antlr4-python3-runtime==4.13.2
torchao==0.12.0 torchao==0.13.0
schedulefree==1.4.1 schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6 axolotl-contribs-lgpl==0.0.6

View File

@@ -162,6 +162,7 @@ extras_require = {
"llmcompressor": [ "llmcompressor": [
"llmcompressor==0.5.1", "llmcompressor==0.5.1",
], ],
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
} }
install_requires, dependency_links, extras_require_build = parse_requirements( install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require extras_require

View File

@@ -115,6 +115,7 @@ class QuantizeCliArgs:
quantize_embedding: Optional[bool] = field(default=None) quantize_embedding: Optional[bool] = field(default=None)
group_size: Optional[int] = field(default=None) group_size: Optional[int] = field(default=None)
output_dir: Optional[str] = field(default=None) output_dir: Optional[str] = field(default=None)
hub_model_id: Optional[str] = field(default=None)
@dataclass @dataclass

View File

@@ -5,12 +5,17 @@ CLI to post-training quantize a model using torchao
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
from transformers import AutoModelForCausalLM from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq from axolotl.utils.quantization import (
TorchAOQuantDType,
get_quantization_config,
quantization_config_to_str,
quantize_model,
)
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -43,13 +48,13 @@ def do_quantize(
"No quantization configuration found. Please specify either qat or quantization in your config file." "No quantization configuration found. Please specify either qat or quantization in your config file."
) )
model_path = cli_args.get("model_path") or cfg.output_dir model_path = cli_args.get("base_model") or cfg.output_dir
if weight_dtype := cli_args.get("weight_dtype"): if weight_dtype := cli_args.get("weight_dtype"):
weight_dtype = TorchIntDType[weight_dtype] weight_dtype = TorchAOQuantDType.from_string(weight_dtype)
else: else:
weight_dtype = quantize_cfg.weight_dtype weight_dtype = quantize_cfg.weight_dtype
if activation_dtype := cli_args.get("activation_dtype"): if activation_dtype := cli_args.get("activation_dtype"):
activation_dtype = TorchIntDType[activation_dtype] activation_dtype = TorchAOQuantDType.from_string(activation_dtype)
else: else:
activation_dtype = quantize_cfg.activation_dtype activation_dtype = quantize_cfg.activation_dtype
group_size = cli_args.get("group_size") or quantize_cfg.group_size group_size = cli_args.get("group_size") or quantize_cfg.group_size
@@ -57,10 +62,15 @@ def do_quantize(
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
) )
output_dir = cli_args.get("output_dir") or cfg.output_dir output_dir = cli_args.get("output_dir") or cfg.output_dir
hub_model_id = cli_args.get("hub_model_id") or cfg.hub_model_id
LOG.info(f"Loading model from {model_path}...") LOG.info(f"Loading model from {model_path}.")
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") config = AutoConfig.from_pretrained(model_path)
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", torch_dtype=torch_dtype
)
LOG.info( LOG.info(
f"Quantizing model with configuration: \n" f"Quantizing model with configuration: \n"
@@ -70,11 +80,21 @@ def do_quantize(
f"\tquantize_embedding: {quantize_embedding}" f"\tquantize_embedding: {quantize_embedding}"
) )
quantize_model_for_ptq( quantize_model(
model, weight_dtype, group_size, activation_dtype, quantize_embedding model, weight_dtype, group_size, activation_dtype, quantize_embedding
) )
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...") quantization_config = get_quantization_config(
weight_dtype, activation_dtype, group_size
)
ao_config = TorchAoConfig(
quant_type=quantization_config,
include_input_output_embeddings=quantize_embedding,
)
model.config.quantization_config = ao_config
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
model.save_pretrained( model.save_pretrained(
str(Path(output_dir) / "quantized"), str(Path(output_dir) / "quantized"),
safe_serialization=False, safe_serialization=False,
@@ -86,4 +106,14 @@ def do_quantize(
progressbar=True, progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files, save_jinja_files=cfg.tokenizer_save_jinja_files,
) )
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")
if hub_model_id:
hub_model_id = (
hub_model_id.rstrip("-")
+ f"-{quantization_config_to_str[type(quantization_config)]}"
)
model.push_to_hub(hub_model_id, safe_serialization=False)
tokenizer.push_to_hub(hub_model_id)
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")

View File

@@ -30,11 +30,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens, fix_untrained_tokens,
) )
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.loaders import ( from axolotl.loaders import ModelLoader, load_processor, load_tokenizer
ModelLoader,
load_processor,
load_tokenizer,
)
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.distributed import cleanup_distributed
@@ -234,16 +230,15 @@ def save_trained_model(
# handle QAT # handle QAT
if cfg.qat: if cfg.qat:
from axolotl.utils.quantization import convert_qat_model_for_ptq from axolotl.utils.quantization import convert_qat_model
LOG.info("Processing QAT model for saving...") convert_qat_model(
convert_qat_model_for_ptq(
model, model,
quantize_embedding=cfg.qat.quantize_embedding, quantize_embedding=cfg.qat.quantize_embedding,
) )
LOG.info( LOG.info(
"QAT modules have been converted for PTQ. Please ensure you quantize " "QAT usage note: please ensure you quantize your model fine-tuned using QAT by running `axolotl quantize`"
"your model weights with `axolotl quantize`." " with the same config which you used for training."
) )
# Handle ReLoRA early return case # Handle ReLoRA early return case
if cfg.relora: if cfg.relora:
@@ -337,9 +332,7 @@ def save_trained_model(
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor: if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin # TODO: add integration support so this can be implemented completely within the plugin
from axolotl.integrations.llm_compressor.utils import ( from axolotl.integrations.llm_compressor.utils import save_compressed_model
save_compressed_model,
)
save_compressed_model( save_compressed_model(
model=model, model=model,

View File

@@ -3,30 +3,47 @@ Utilities for quantization including QAT and PTQ using torchao.
""" """
import torch import torch
from torch import nn from packaging import version
from torchao.core.config import AOBaseConfig from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_ from torchao.quantization import quantize_
from torchao.quantization.qat import ( from torchao.quantization.qat import (
FakeQuantizeConfig, QATConfig,
FromIntXQuantizationAwareTrainingConfig,
IntXQuantizationAwareTrainingConfig,
) )
from torchao.quantization.quant_api import ( from torchao.quantization.quant_api import (
Int4DynamicActivationInt4WeightConfig, Float8DynamicActivationFloat8WeightConfig,
Int4WeightOnlyConfig, Float8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
UIntXWeightOnlyConfig,
_is_linear,
) )
from axolotl.utils.schemas.enums import TorchIntDType from axolotl.utils.schemas.enums import TorchAOQuantDType
quantization_config_to_str = {
Int8DynamicActivationInt4WeightConfig: "int8int4",
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
Float8DynamicActivationInt4WeightConfig: "fp8int4",
}
if version.parse(torch.__version__) >= version.parse("2.8.0"):
try:
from torchao.prototype.mx_formats import NVFP4InferenceConfig
quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
except:
pass
# int4 weight config imports will fail on machines with fbgemm-gpu installed
# without a CUDA runtime available so we do this safely
try:
from torchao.quantization.quant_api import Int4WeightOnlyConfig
quantization_config_to_str[Int4WeightOnlyConfig] = "int4"
except:
pass
def get_ptq_config( def get_quantization_config(
weight_dtype: TorchIntDType, weight_dtype: TorchAOQuantDType,
activation_dtype: TorchIntDType | None = None, activation_dtype: TorchAOQuantDType | None = None,
group_size: int | None = None, group_size: int | None = None,
) -> AOBaseConfig: ) -> AOBaseConfig:
""" """
@@ -45,44 +62,101 @@ def get_ptq_config(
or if the group size is not specified for int8 or int4 weight only quantization. or if the group size is not specified for int8 or int4 weight only quantization.
""" """
if activation_dtype is None: if activation_dtype is None:
if not weight_dtype.value.is_signed: # type: ignore[attr-defined,union-attr] if weight_dtype == TorchAOQuantDType.int8:
return UIntXWeightOnlyConfig( raise ValueError("Int8WeightOnlyConfig is not supported by torchao QAT.")
dtype=weight_dtype.value, if weight_dtype == TorchAOQuantDType.int4:
group_size=group_size, from torchao.quantization.quant_api import Int4WeightOnlyConfig
set_inductor_config=False,
) if group_size is not None:
if weight_dtype == TorchIntDType.int8: return Int4WeightOnlyConfig(group_size=group_size, version=2)
if group_size is None: else:
raise ValueError( return Int4WeightOnlyConfig(version=2)
"group_size must be specified for int8 weight only quantization" if (
) activation_dtype == TorchAOQuantDType.int4
return Int8WeightOnlyConfig( and weight_dtype == TorchAOQuantDType.int4
group_size=group_size, ):
) raise ValueError(
if weight_dtype == TorchIntDType.int4: "Int4DynamicActivationInt4WeightConfig is not supported by torchao QAT."
if group_size is None: )
raise ValueError( if (
"group_size must be specified for int4 weight only quantization" activation_dtype == TorchAOQuantDType.int8
) and weight_dtype == TorchAOQuantDType.int8
return Int4WeightOnlyConfig( ):
group_size=group_size, raise ValueError(
) "Int8DynamicActivationInt8WeightConfig is not supported by torchao QAT."
if activation_dtype == TorchIntDType.int4 and weight_dtype == TorchIntDType.int4: )
return Int4DynamicActivationInt4WeightConfig() if (
if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int8: activation_dtype == TorchAOQuantDType.int8
return Int8DynamicActivationInt8WeightConfig() and weight_dtype == TorchAOQuantDType.int4
if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int4: ):
return Int8DynamicActivationInt4WeightConfig() if group_size is not None:
return Int8DynamicActivationInt4WeightConfig(group_size=group_size)
else:
return Int8DynamicActivationInt4WeightConfig()
if (
activation_dtype == TorchAOQuantDType.float8_e4m3fn
and weight_dtype == TorchAOQuantDType.float8_e4m3fn
):
return Float8DynamicActivationFloat8WeightConfig()
if (
activation_dtype == TorchAOQuantDType.float8_e4m3fn
and weight_dtype == TorchAOQuantDType.int4
):
return Float8DynamicActivationInt4WeightConfig()
if weight_dtype == TorchAOQuantDType.nvfp4:
from torchao.prototype.mx_formats import NVFP4InferenceConfig
if group_size is not None and group_size != 16:
raise ValueError("NVFP4 quantization must use a group_size of 16")
return NVFP4InferenceConfig()
raise ValueError( raise ValueError(
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}" f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
) )
def quantize_model(
model,
weight_dtype: TorchAOQuantDType,
group_size: int | None = None,
activation_dtype: TorchAOQuantDType | None = None,
quantize_embedding: bool | None = None,
):
"""
This function is used to quantize a model.
Args:
model: The model to quantize.
weight_dtype: The dtype to use for weight quantization.
group_size: The group size to use for weight quantization.
activation_dtype: The dtype to use for activation quantization.
quantize_embedding: Whether to quantize the model's embedding weights.
"""
linear_ptq_config = get_quantization_config(
weight_dtype=weight_dtype,
activation_dtype=activation_dtype,
group_size=group_size,
)
quantize_(model, linear_ptq_config)
if quantize_embedding:
# activation fake quantization is not supported for embedding layers
embedding_quantize_config = get_quantization_config(
weight_dtype=weight_dtype,
activation_dtype=None,
group_size=group_size,
)
quantize_(
model,
embedding_quantize_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
def prepare_model_for_qat( def prepare_model_for_qat(
model, model,
weight_dtype: TorchIntDType, weight_dtype: TorchAOQuantDType,
group_size: int, group_size: int | None = None,
activation_dtype: TorchIntDType | None = None, activation_dtype: TorchAOQuantDType | None = None,
quantize_embedding: bool = False, quantize_embedding: bool = False,
): ):
""" """
@@ -100,86 +174,40 @@ def prepare_model_for_qat(
Raises: Raises:
ValueError: If the activation/weight dtype combination is invalid. ValueError: If the activation/weight dtype combination is invalid.
""" """
if activation_dtype: base_config = get_quantization_config(
activation_config = FakeQuantizeConfig(
dtype=activation_dtype.value, granularity="per_token", is_symmetric=False
)
weight_config = FakeQuantizeConfig(dtype=weight_dtype.value, group_size=group_size)
linear_quantize_config = IntXQuantizationAwareTrainingConfig(
activation_config=None if activation_dtype is None else activation_config,
weight_config=weight_config,
)
quantize_(model, linear_quantize_config)
if quantize_embedding:
# activation fake quantization is not supported for embedding layers
embedding_quantize_config = IntXQuantizationAwareTrainingConfig(
activation_config=None,
weight_config=weight_config,
)
quantize_(
model,
embedding_quantize_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
def quantize_model_for_ptq(
model,
weight_dtype: TorchIntDType,
group_size: int | None = None,
activation_dtype: TorchIntDType | None = None,
quantize_embedding: bool | None = None,
):
"""
This function is used to quantize a model for post-training quantization.
It swaps the model's linear layers with fake quantized linear layers.
If `quantize_embedding` is True, it will also swap the model's embedding weights with fake quantized embedding weights.
Args:
model: The model to quantize.
weight_dtype: The dtype to use for weight quantization.
group_size: The group size to use for weight quantization.
activation_dtype: The dtype to use for activation quantization.
quantize_embedding: Whether to quantize the model's embedding weights.
"""
linear_ptq_config = get_ptq_config(
weight_dtype=weight_dtype, weight_dtype=weight_dtype,
activation_dtype=activation_dtype, activation_dtype=activation_dtype,
group_size=group_size, group_size=group_size,
) )
quantize_(model, linear_ptq_config) qat_config = QATConfig(base_config)
quantize_(model, qat_config)
if quantize_embedding: if quantize_embedding:
embedding_quantize_config = get_ptq_config( # activation fake quantization is not supported for embedding layers
embedding_base_config = get_quantization_config(
weight_dtype=weight_dtype, weight_dtype=weight_dtype,
activation_dtype=None, activation_dtype=None,
group_size=group_size, group_size=group_size,
) )
embedding_qat_config = QATConfig(embedding_base_config)
quantize_( quantize_(
model, model,
embedding_quantize_config, embedding_qat_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
) )
def convert_qat_model_for_ptq( def convert_qat_model(
model, model,
*, quantize_embedding: bool = False,
quantize_embedding: bool | None = None,
): ):
""" """
This function is used to convert a swap fake-quantized modules in a model This function converts a QAT model which has fake quantized layers back to the original model.
which has been trained with QAT back to the original modules, ready for PTQ.
Args:
model: The model to convert.
quantize_embedding: Whether to quantize the model's embedding weights.
""" """
config = QATConfig(step="convert")
quantize_(model, config)
if quantize_embedding: if quantize_embedding:
quantize_(
def filter_fn(m, _): model,
return isinstance(m, nn.Embedding) or _is_linear(m) config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
else: )
filter_fn = _is_linear
quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)

View File

@@ -5,18 +5,21 @@ from enum import Enum
import torch import torch
class TorchIntDType(Enum): class TorchAOQuantDType(Enum):
"""Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4""" int4 = torch.int4
int8 = torch.int8
float8_e4m3fn = torch.float8_e4m3fn
nvfp4 = "nvfp4"
uint1 = getattr(torch, "uint1", None) def from_string(str):
uint2 = getattr(torch, "uint2", None) if str == "int4":
uint3 = getattr(torch, "uint3", None) return TorchAOQuantDType.int4
uint4 = getattr(torch, "uint4", None) if str == "int8":
uint5 = getattr(torch, "uint5", None) return TorchAOQuantDType.int8
uint6 = getattr(torch, "uint6", None) if str in ["float8_e4m3fn", "fp8", "float8"]:
uint7 = getattr(torch, "uint7", None) return TorchAOQuantDType.float8_e4m3fn
int4 = getattr(torch, "int4", None) if str == "nvfp4":
int8 = getattr(torch, "int8", None) return TorchAOQuantDType.nvfp4
class RLType(str, Enum): class RLType(str, Enum):

View File

@@ -6,7 +6,23 @@ from typing import Any
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from axolotl.utils.schemas.enums import TorchIntDType from axolotl.utils.schemas.enums import TorchAOQuantDType
def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
if v is None:
return None
if v == "int4":
return TorchAOQuantDType.int4
if v == "int8":
return TorchAOQuantDType.int8
if v in ["float8_e4m3fn", "fp8", "float8"]:
return TorchAOQuantDType.float8_e4m3fn
if v == "nvfp4":
return TorchAOQuantDType.nvfp4
raise ValueError(
f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}"
)
class QATConfig(BaseModel): class QATConfig(BaseModel):
@@ -14,13 +30,13 @@ class QATConfig(BaseModel):
QAT Config Schema QAT Config Schema
""" """
activation_dtype: TorchIntDType | None = Field( activation_dtype: TorchAOQuantDType | None = Field(
default=None, default=None,
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"', description="Fake quantization layout to use for activation quantization.",
) )
weight_dtype: TorchIntDType = Field( weight_dtype: TorchAOQuantDType = Field(
default=TorchIntDType.int8, default=TorchAOQuantDType.int8,
description='Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"', description="Fake quantization layout to use for weight quantization.",
) )
quantize_embedding: bool | None = Field( quantize_embedding: bool | None = Field(
default=False, description="Quantize embedding" default=False, description="Quantize embedding"
@@ -35,12 +51,8 @@ class QATConfig(BaseModel):
@field_validator("activation_dtype", "weight_dtype", mode="before") @field_validator("activation_dtype", "weight_dtype", mode="before")
@classmethod @classmethod
def validate_dtype(cls, v: Any) -> TorchIntDType | None: def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None:
if v == "int4": return validate_ao_dtype(v)
return TorchIntDType.int4
if v == "int8":
return TorchIntDType.int8
raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")
class PTQConfig(BaseModel): class PTQConfig(BaseModel):
@@ -48,13 +60,13 @@ class PTQConfig(BaseModel):
PTQ Config Schema PTQ Config Schema
""" """
weight_dtype: TorchIntDType = Field( weight_dtype: TorchAOQuantDType = Field(
default=TorchIntDType.int8, default=TorchAOQuantDType.int8,
description="Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8", description="Fake quantization layout to use for weight quantization.",
) )
activation_dtype: TorchIntDType | None = Field( activation_dtype: TorchAOQuantDType | None = Field(
default=None, default=None,
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"', description="Fake quantization layout to use for activation quantization.",
) )
quantize_embedding: bool | None = Field( quantize_embedding: bool | None = Field(
default=None, description="Whether to quantize the embedding layer." default=None, description="Whether to quantize the embedding layer."
@@ -66,9 +78,5 @@ class PTQConfig(BaseModel):
@field_validator("activation_dtype", "weight_dtype", mode="before") @field_validator("activation_dtype", "weight_dtype", mode="before")
@classmethod @classmethod
def validate_dtype(cls, v: Any) -> TorchIntDType | None: def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None:
if v == "int4": return validate_ao_dtype(v)
return TorchIntDType.int4
if v == "int8":
return TorchIntDType.int8
raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")

View File

@@ -43,7 +43,7 @@ class TestQATLlama:
"qat": { "qat": {
"quantize_embedding": True, "quantize_embedding": True,
"activation_dtype": "int8", "activation_dtype": "int8",
"weight_dtype": "int8", "weight_dtype": "int4",
"group_size": 8, "group_size": 8,
}, },
"num_epochs": 1, "num_epochs": 1,
@@ -111,7 +111,7 @@ class TestQATLlama:
"qat": { "qat": {
"quantize_embedding": True, "quantize_embedding": True,
"activation_dtype": "int8", "activation_dtype": "int8",
"weight_dtype": "int8", "weight_dtype": "int4",
"group_size": 8, "group_size": 8,
}, },
"save_first_step": False, "save_first_step": False,

View File

@@ -5,41 +5,40 @@ Tests for axolotl.utils.quantization
import pytest import pytest
import torch import torch
from torch import nn from torch import nn
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.linear_activation_quantized_tensor import (
LinearActivationQuantizedTensor,
)
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear from torchao.quantization.qat.linear import FakeQuantizedLinear
from torchao.quantization.quant_api import ( from torchao.quantization.quant_api import (
Int4DynamicActivationInt4WeightConfig, Float8DynamicActivationFloat8WeightConfig,
Int4WeightOnlyConfig, Float8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig, Int8DynamicActivationInt4WeightConfig,
Int8WeightOnlyConfig,
UIntXWeightOnlyConfig,
) )
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
from transformers.trainer_callback import TrainerState from transformers.trainer_callback import TrainerState
from axolotl.utils.callbacks.qat import QATCallback from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.quantization import ( from axolotl.utils.quantization import (
convert_qat_model_for_ptq, convert_qat_model,
get_ptq_config, get_quantization_config,
prepare_model_for_qat, prepare_model_for_qat,
quantize_model_for_ptq, quantize_model,
) )
from axolotl.utils.schemas.enums import TorchIntDType from axolotl.utils.schemas.enums import TorchAOQuantDType
from axolotl.utils.schemas.quantization import QATConfig from axolotl.utils.schemas.quantization import QATConfig
from tests.e2e.utils import require_torch_2_6_0 from tests.e2e.utils import (
require_torch_2_8_0,
requires_cuda_ge_8_9,
requires_sm_ge_100,
)
@pytest.fixture() @pytest.fixture()
def model(): def model():
dummy_model = AutoModelForCausalLM.from_pretrained( dummy_model = AutoModelForCausalLM.from_pretrained(
"HuggingFaceTB/SmolLM2-135M", "Qwen/Qwen2-0.5B",
device_map="cuda", device_map="auto",
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
) )
with torch.device(dummy_model.device): with torch.device(dummy_model.device):
@@ -48,45 +47,56 @@ def model():
dummy_model.model.embed_tokens.weight.shape[1], dummy_model.model.embed_tokens.weight.shape[1],
dtype=dummy_model.model.embed_tokens.weight.dtype, dtype=dummy_model.model.embed_tokens.weight.dtype,
) )
return dummy_model yield dummy_model
del dummy_model
ptq_config_test_cases = [ ptq_config_test_cases = [
# weight_dtype, activation_dtype, group_size, expected_type, expected_params # weight_dtype, activation_dtype, group_size, expected_type
( (
TorchIntDType.uint4, TorchAOQuantDType.int4,
TorchAOQuantDType.int8,
None, None,
None, Int8DynamicActivationInt4WeightConfig,
UIntXWeightOnlyConfig,
{"dtype": torch.uint4, "group_size": None},
),
(TorchIntDType.int8, None, 32, Int8WeightOnlyConfig, {"group_size": 32}),
(TorchIntDType.int4, None, 4, Int4WeightOnlyConfig, {"group_size": 4}),
(
TorchIntDType.int4,
TorchIntDType.int4,
None,
Int4DynamicActivationInt4WeightConfig,
{},
), ),
( (
TorchIntDType.int8, TorchAOQuantDType.float8_e4m3fn,
TorchIntDType.int8, TorchAOQuantDType.float8_e4m3fn,
None, None,
Int8DynamicActivationInt8WeightConfig, Float8DynamicActivationFloat8WeightConfig,
{}, ),
(
TorchAOQuantDType.int4,
TorchAOQuantDType.float8_e4m3fn,
None,
Float8DynamicActivationInt4WeightConfig,
), ),
] ]
ptq_test_cases = [ ptq_test_cases = [
# weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception # weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception, expected_tensor_class
(TorchIntDType.int8, None, 8, False, None), (TorchAOQuantDType.int4, None, 4, True, None, Int4Tensor),
(TorchIntDType.int4, None, 4, True, None), (
(TorchIntDType.uint4, None, 8, False, None), TorchAOQuantDType.int4,
(TorchIntDType.int4, TorchIntDType.int4, 8, False, None), TorchAOQuantDType.int8,
(TorchIntDType.int8, TorchIntDType.int8, 8, True, None), 8,
(TorchIntDType.int8, None, None, False, ValueError), False,
(TorchIntDType.int4, None, None, False, ValueError), None,
LinearActivationQuantizedTensor,
),
# (
# TorchAOQuantDType.int4,
# TorchAOQuantDType.float8_e4m3fn,
# None,
# False,
# None,
# Int4Tensor,
# ),
(TorchAOQuantDType.int4, None, None, False, None, Int4Tensor),
# Deprecated configs
(TorchAOQuantDType.int8, None, 8, False, ValueError, None),
(TorchAOQuantDType.int4, TorchAOQuantDType.int4, 8, False, ValueError, None),
(TorchAOQuantDType.int8, TorchAOQuantDType.int8, 8, True, ValueError, None),
] ]
@@ -96,44 +106,132 @@ class TestQuantization:
""" """
@pytest.mark.parametrize( @pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,expected_type,expected_params", "weight_dtype,activation_dtype,group_size,expected_type",
ptq_config_test_cases, ptq_config_test_cases,
) )
@require_torch_2_6_0 @requires_cuda_ge_8_9
@require_torch_2_8_0
def test_get_ptq_config( def test_get_ptq_config(
self, weight_dtype, activation_dtype, group_size, expected_type, expected_params self, weight_dtype, activation_dtype, group_size, expected_type
): ):
config = get_ptq_config(weight_dtype, activation_dtype, group_size) config = get_quantization_config(weight_dtype, activation_dtype, group_size)
assert isinstance(config, expected_type) assert isinstance(config, expected_type)
for param_name, param_value in expected_params.items(): @requires_cuda_ge_8_9
if isinstance(param_value, (PerAxis, PerGroup)): @require_torch_2_8_0
if isinstance(param_value, PerAxis): def test_get_ptq_config_int4_weight_only(self):
assert isinstance(getattr(config, param_name), PerAxis) from torchao.quantization.quant_api import Int4WeightOnlyConfig
assert getattr(config, param_name).axis == param_value.axis
else: config = get_quantization_config(TorchAOQuantDType.int4, None, 4)
assert isinstance(getattr(config, param_name), PerGroup) assert isinstance(config, Int4WeightOnlyConfig)
assert (
getattr(config, param_name).group_size == param_value.group_size
)
else:
assert getattr(config, param_name) == param_value
@pytest.mark.parametrize( @pytest.mark.parametrize(
"weight_dtype", [TorchIntDType.int8, TorchIntDType.int4, TorchIntDType.uint4] "weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception,expected_tensor_class",
ptq_test_cases,
) )
@requires_cuda_ge_8_9
@require_torch_2_8_0
def test_quantize_model_for_ptq(
self,
model,
weight_dtype,
activation_dtype,
group_size,
quantize_embedding,
expected_exception,
expected_tensor_class,
):
if expected_exception:
with pytest.raises(expected_exception):
quantize_model(
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
)
else:
quantize_model(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
if quantize_embedding:
assert isinstance(
model.model.embed_tokens.weight, expected_tensor_class
), "Embedding weight should be quantized"
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child.weight, expected_tensor_class)
@require_torch_2_8_0
@requires_sm_ge_100
def test_quantize_model_for_ptq_fp8(
self,
model,
):
from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
Float8Tensor,
QuantizeTensorToFloat8Kwargs,
)
quantize_model(
model,
TorchAOQuantDType.float8_e4m3fn,
None,
TorchAOQuantDType.float8_e4m3fn,
)
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child.weight, Float8Tensor)
assert child.weight.act_quant_kwargs is not None and isinstance(
child.weight.act_quant_kwargs, QuantizeTensorToFloat8Kwargs
)
@require_torch_2_8_0
@requires_sm_ge_100
def test_quantize_model_for_ptq_nvfp4(
self,
model,
):
from torchao.prototype.mx_formats.nvfp4_tensor import (
NVFP4Tensor,
QuantizeTensorToNVFP4Kwargs,
)
quantize_model(model, TorchAOQuantDType.nvfp4, 16, TorchAOQuantDType.nvfp4)
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child.weight, NVFP4Tensor)
assert child.weight.act_quant_kwargs is not None and isinstance(
child.weight.act_quant_kwargs, QuantizeTensorToNVFP4Kwargs
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"activation_dtype", [None, TorchIntDType.int4, TorchIntDType.int8] "weight_dtype,activation_dtype,group_size,quantize_embedding",
[
(TorchAOQuantDType.int4, None, 8, False),
(TorchAOQuantDType.int4, None, 16, True),
(TorchAOQuantDType.int4, TorchAOQuantDType.int8, 8, False),
(TorchAOQuantDType.int4, TorchAOQuantDType.int8, 16, True),
(
TorchAOQuantDType.float8_e4m3fn,
TorchAOQuantDType.float8_e4m3fn,
None,
False,
),
(TorchAOQuantDType.int4, TorchAOQuantDType.float8_e4m3fn, None, True),
],
) )
@pytest.mark.parametrize("group_size", [4, 8]) @require_torch_2_8_0
@pytest.mark.parametrize("quantize_embedding", [False, True]) @requires_cuda_ge_8_9
@require_torch_2_6_0
def test_prepare_model_for_qat( def test_prepare_model_for_qat(
self, model, weight_dtype, activation_dtype, group_size, quantize_embedding self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
): ):
prepare_model_for_qat( prepare_model_for_qat(
model, weight_dtype, group_size, activation_dtype, quantize_embedding model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
) )
if quantize_embedding: if quantize_embedding:
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
@@ -142,17 +240,19 @@ class TestQuantization:
model.model.embed_tokens.weight_fake_quantizer.config.dtype model.model.embed_tokens.weight_fake_quantizer.config.dtype
== weight_dtype.value == weight_dtype.value
) )
assert ( if group_size:
model.model.embed_tokens.weight_fake_quantizer.config.group_size assert (
== group_size model.model.embed_tokens.weight_fake_quantizer.config.group_size
) == group_size
)
for child in list(model.children()): for child in list(model.children()):
if isinstance(child, torch.nn.Linear): if isinstance(child, torch.nn.Linear):
assert isinstance(child, FakeQuantizedLinear) assert isinstance(child, FakeQuantizedLinear)
assert hasattr(child, "weight_fake_quantizer") assert hasattr(child, "weight_fake_quantizer")
assert child.weight_fake_quantizer.config.dtype == weight_dtype.value assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
assert child.weight_fake_quantizer.config.group_size == group_size if group_size:
assert child.weight_fake_quantizer.config.group_size == group_size
if activation_dtype: if activation_dtype:
assert hasattr(child, "activation_fake_quantizer") assert hasattr(child, "activation_fake_quantizer")
assert ( assert (
@@ -162,49 +262,40 @@ class TestQuantization:
else: else:
assert child.activation_fake_quantizer is None assert child.activation_fake_quantizer is None
@pytest.mark.parametrize( @require_torch_2_8_0
"weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception", @requires_cuda_ge_8_9
ptq_test_cases, def test_convert_qat_model(self, model):
) config = QATConfig(
@require_torch_2_6_0 weight_dtype="int4",
def test_quantize_model_for_ptq( activation_dtype="int8",
self, group_size=8,
model, quantize_embedding=True,
weight_dtype, )
activation_dtype,
group_size, # quantize model for qat
quantize_embedding, prepare_model_for_qat(
expected_exception, model,
): config.weight_dtype,
if expected_exception: config.group_size,
with pytest.raises(expected_exception): config.activation_dtype,
quantize_model_for_ptq( config.quantize_embedding,
model, )
weight_dtype,
group_size, assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
activation_dtype, assert isinstance(model.lm_head, FakeQuantizedLinear)
quantize_embedding,
) # apply conversion
else: convert_qat_model(
quantize_model_for_ptq( model,
model, weight_dtype, group_size, activation_dtype, quantize_embedding config.quantize_embedding,
) )
if quantize_embedding: # ensure modules have been swapped out
assert isinstance( assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
model.model.embed_tokens.weight, AffineQuantizedTensor assert not isinstance(model.lm_head, FakeQuantizedLinear)
), "Embedding weight should be quantized"
for child in list(model.children()): # ensure weights have been quantized
if isinstance(child, torch.nn.Linear): assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
if activation_dtype: assert isinstance(model.lm_head.weight, nn.Parameter)
assert isinstance(
child.weight, LinearActivationQuantizedTensor
), (
"Linear weight should be quantized with activation quantization"
)
else:
assert isinstance(child.weight, AffineQuantizedTensor), (
"Linear weight should be quantized without activation quantization"
)
class TestQuantizationCallback: class TestQuantizationCallback:
@@ -218,10 +309,10 @@ class TestQuantizationCallback:
global_step=0, global_step=0,
) )
@require_torch_2_6_0 @require_torch_2_8_0
def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state): def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state):
cfg = QATConfig( cfg = QATConfig(
weight_dtype="int8", weight_dtype="int4",
activation_dtype="int8", activation_dtype="int8",
group_size=8, group_size=8,
quantize_embedding=True, quantize_embedding=True,
@@ -268,10 +359,10 @@ class TestQuantizationCallback:
assert model.model.embed_tokens.weight_fake_quantizer.enabled assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled assert model.lm_head.weight_fake_quantizer.enabled
@require_torch_2_6_0 @require_torch_2_8_0
def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state): def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
cfg = QATConfig( cfg = QATConfig(
weight_dtype="int8", weight_dtype="int4",
activation_dtype="int8", activation_dtype="int8",
group_size=8, group_size=8,
quantize_embedding=True, quantize_embedding=True,
@@ -304,43 +395,3 @@ class TestQuantizationCallback:
# quantization should be enabled from the get-go # quantization should be enabled from the get-go
assert model.model.embed_tokens.weight_fake_quantizer.enabled assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled assert model.lm_head.weight_fake_quantizer.enabled
class TestConvertQATModelForPTQ:
"""
Test convert_qat_model_for_ptq
"""
@require_torch_2_6_0
def test_convert_qat_model_for_ptq(self, model):
config = QATConfig(
weight_dtype="int8",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
)
# quantize model for qat
prepare_model_for_qat(
model,
config.weight_dtype,
config.group_size,
config.activation_dtype,
config.quantize_embedding,
)
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert isinstance(model.lm_head, FakeQuantizedLinear)
# apply conversion
convert_qat_model_for_ptq(
model,
quantize_embedding=config.quantize_embedding,
)
# ensure modules have been swapped out
assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert not isinstance(model.lm_head, FakeQuantizedLinear)
# ensure weights have been quantized
assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
assert isinstance(model.lm_head.weight, nn.Parameter)

View File

@@ -90,6 +90,18 @@ def require_torch_2_7_0(test_case):
return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case) return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case)
def require_torch_2_8_0(test_case):
"""
Decorator marking a test that requires torch >= 2.7.0
"""
def is_min_2_8_0():
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.8.0")
return unittest.skipUnless(is_min_2_8_0(), "test requires torch>=2.8.0")(test_case)
def require_torch_lt_2_6_0(test_case): def require_torch_lt_2_6_0(test_case):
""" """
Decorator marking a test that requires torch < 2.6.0 Decorator marking a test that requires torch < 2.6.0
@@ -128,6 +140,24 @@ def require_llmcompressor(test_case):
)(test_case) )(test_case)
def requires_sm_ge_100(test_case):
is_sm_ge_100 = (
torch.cuda.is_available()
and torch.version.cuda
and torch.cuda.get_device_capability() >= (10, 0)
)
return unittest.skipUnless(is_sm_ge_100, "test requires sm>=100")(test_case)
def requires_cuda_ge_8_9(test_case):
is_cuda_ge_8_9 = (
torch.cuda.is_available()
and torch.version.cuda
and torch.cuda.get_device_capability() >= (8, 9)
)
return unittest.skipUnless(is_cuda_ge_8_9, "test requires cuda>=8.9")(test_case)
def is_hopper(): def is_hopper():
compute_capability = torch.cuda.get_device_capability() compute_capability = torch.cuda.get_device_capability()
return compute_capability == (9, 0) return compute_capability == (9, 0)