From 58d67bf98ddca63cb082374a04f8b2250ffc2c59 Mon Sep 17 00:00:00 2001 From: salman Date: Fri, 12 Sep 2025 10:55:50 +0100 Subject: [PATCH] Migrate QAT API; fix `axolotl quantize` for QAT-ed models; add NVFP4 (#3107) --- .github/workflows/multi-gpu-e2e.yml | 2 +- .github/workflows/tests.yml | 2 +- docs/quantize.qmd | 8 + examples/llama-3/3b-qat-fsdp2-nvfp4.yaml | 64 ++++ examples/llama-3/3b-qat-fsdp2.yaml | 18 +- requirements.txt | 2 +- setup.py | 1 + src/axolotl/cli/args.py | 1 + src/axolotl/cli/quantize.py | 50 ++- src/axolotl/train.py | 19 +- src/axolotl/utils/quantization.py | 244 +++++++------- src/axolotl/utils/schemas/enums.py | 25 +- src/axolotl/utils/schemas/quantization.py | 54 ++-- tests/e2e/test_qat.py | 4 +- tests/e2e/test_quantization.py | 369 ++++++++++++---------- tests/e2e/utils.py | 30 ++ 16 files changed, 554 insertions(+), 339 deletions(-) create mode 100644 examples/llama-3/3b-qat-fsdp2-nvfp4.yaml diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml index 6492e5d3e..05f9e0761 100644 --- a/.github/workflows/multi-gpu-e2e.yml +++ b/.github/workflows/multi-gpu-e2e.yml @@ -44,7 +44,7 @@ jobs: cuda_version: 12.8.1 python_version: "3.11" pytorch: 2.8.0 - axolotl_extras: + axolotl_extras: fbgemm-gpu num_gpus: 2 nightly_build: "true" runs-on: [self-hosted, modal] diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 337230d4a..cfd2c715d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -304,7 +304,7 @@ jobs: pytorch: 2.8.0 num_gpus: 1 gpu_type: "B200" - axolotl_extras: + axolotl_extras: fbgemm-gpu steps: - name: Checkout uses: actions/checkout@v4 diff --git a/docs/quantize.qmd b/docs/quantize.qmd index 113fcafbe..43c817a5b 100644 --- a/docs/quantize.qmd +++ b/docs/quantize.qmd @@ -51,3 +51,11 @@ axolotl quantize qat.yml ``` This ensures that an identical quantization configuration is used to quantize the model as was used to train it. 
+ + +::: {.callout-note} + +If you have configured pushing to hub with `hub_model_id`, your model hub name will have the quantization schema appended to it, +e.g. `axolotl-ai-cloud/qat-nvfp4-llama3B` will become `axolotl-ai-cloud/qat-nvfp4-llama3B-nvfp4` + +::: diff --git a/examples/llama-3/3b-qat-fsdp2-nvfp4.yaml b/examples/llama-3/3b-qat-fsdp2-nvfp4.yaml new file mode 100644 index 000000000..1ec809bbe --- /dev/null +++ b/examples/llama-3/3b-qat-fsdp2-nvfp4.yaml @@ -0,0 +1,64 @@ +base_model: meta-llama/Llama-3.2-3B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true + +datasets: + - path: yahma/alpaca-cleaned + type: alpaca + split: train[:95%] + +output_dir: ./outputs/qat_out/ +dataset_prepared_path: ./outputs/dataset_prepared + +sequence_len: 8192 +flash_attention: true + +qat: + activation_dtype: nvfp4 + weight_dtype: nvfp4 + group_size: 16 # only group_size of 16 is supported with nvfp4 + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_checkpointing: true +gradient_accumulation_steps: 1 +micro_batch_size: 64 +num_epochs: 1 +optimizer: adamw_torch_fused + +cosine_constant_lr_ratio: 0 +cosine_min_lr_ratio: 1.0 +learning_rate: 2e-5 +save_only_model: true +bf16: true + +resume_from_checkpoint: +logging_steps: 1 + +evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 + +special_tokens: + pad_token: <|finetune_right_pad_id|> + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/llama-3/3b-qat-fsdp2.yaml b/examples/llama-3/3b-qat-fsdp2.yaml index 35e3461e2..0c5a87891 100644 --- a/examples/llama-3/3b-qat-fsdp2.yaml +++ 
b/examples/llama-3/3b-qat-fsdp2.yaml @@ -15,20 +15,18 @@ liger_glu_activation: true liger_layer_norm: true liger_fused_linear_cross_entropy: true + datasets: - path: yahma/alpaca-cleaned type: alpaca + split: train[:95%] output_dir: ./outputs/qat_out/ +dataset_prepared_path: ./outputs/qat_out/dataset_prepared -sample_packing: true - -sequence_len: 512 - -flex_attention: true -flex_attn_compile_kwargs: - dynamic: false - mode: max-autotune-no-cudagraphs +sample_packing: false +sequence_len: 8192 +flash_attention: true qat: activation_dtype: int8 @@ -67,7 +65,7 @@ fsdp: fsdp_config: fsdp_version: 2 fsdp_offload_params: false - fsdp_cpu_ram_efficient_loading: true + fsdp_cpu_ram_efficient_loading: false fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_state_dict_type: FULL_STATE_DICT @@ -76,6 +74,6 @@ fsdp_config: fsdp_activation_checkpointing: true special_tokens: - pad_token: <|end_of_text|> + pad_token: <|finetune_right_pad_id|> # save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/requirements.txt b/requirements.txt index 1292a179a..6138707af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -64,7 +64,7 @@ langdetect==1.0.9 immutabledict==4.2.0 antlr4-python3-runtime==4.13.2 -torchao==0.12.0 +torchao==0.13.0 schedulefree==1.4.1 axolotl-contribs-lgpl==0.0.6 diff --git a/setup.py b/setup.py index 4cbc562e0..3a44f0ae9 100644 --- a/setup.py +++ b/setup.py @@ -162,6 +162,7 @@ extras_require = { "llmcompressor": [ "llmcompressor==0.5.1", ], + "fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"], } install_requires, dependency_links, extras_require_build = parse_requirements( extras_require diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py index 396e9a8af..14dafa43f 100644 --- a/src/axolotl/cli/args.py +++ b/src/axolotl/cli/args.py @@ -115,6 +115,7 @@ class QuantizeCliArgs: quantize_embedding: Optional[bool] = field(default=None) group_size: Optional[int] 
= field(default=None) output_dir: Optional[str] = field(default=None) + hub_model_id: Optional[str] = field(default=None) @dataclass diff --git a/src/axolotl/cli/quantize.py b/src/axolotl/cli/quantize.py index b8a8de781..6838f47d8 100644 --- a/src/axolotl/cli/quantize.py +++ b/src/axolotl/cli/quantize.py @@ -5,12 +5,17 @@ CLI to post-training quantize a model using torchao from pathlib import Path from typing import Union -from transformers import AutoModelForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig from axolotl.cli.config import load_cfg from axolotl.loaders import load_tokenizer from axolotl.utils.logging import get_logger -from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq +from axolotl.utils.quantization import ( + TorchAOQuantDType, + get_quantization_config, + quantization_config_to_str, + quantize_model, +) LOG = get_logger(__name__) @@ -43,13 +48,13 @@ def do_quantize( "No quantization configuration found. Please specify either qat or quantization in your config file." 
) - model_path = cli_args.get("model_path") or cfg.output_dir + model_path = cli_args.get("base_model") or cfg.output_dir if weight_dtype := cli_args.get("weight_dtype"): - weight_dtype = TorchIntDType[weight_dtype] + weight_dtype = TorchAOQuantDType.from_string(weight_dtype) else: weight_dtype = quantize_cfg.weight_dtype if activation_dtype := cli_args.get("activation_dtype"): - activation_dtype = TorchIntDType[activation_dtype] + activation_dtype = TorchAOQuantDType.from_string(activation_dtype) else: activation_dtype = quantize_cfg.activation_dtype group_size = cli_args.get("group_size") or quantize_cfg.group_size @@ -57,10 +62,15 @@ def do_quantize( cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding ) output_dir = cli_args.get("output_dir") or cfg.output_dir + hub_model_id = cli_args.get("hub_model_id") or cfg.hub_model_id - LOG.info(f"Loading model from {model_path}...") + LOG.info(f"Loading model from {model_path}.") tokenizer = load_tokenizer(cfg) - model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") + config = AutoConfig.from_pretrained(model_path) + torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None + model = AutoModelForCausalLM.from_pretrained( + model_path, device_map="auto", torch_dtype=torch_dtype + ) LOG.info( f"Quantizing model with configuration: \n" @@ -70,11 +80,21 @@ def do_quantize( f"\tquantize_embedding: {quantize_embedding}" ) - quantize_model_for_ptq( + quantize_model( model, weight_dtype, group_size, activation_dtype, quantize_embedding ) - LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...") + quantization_config = get_quantization_config( + weight_dtype, activation_dtype, group_size + ) + + ao_config = TorchAoConfig( + quant_type=quantization_config, + include_input_output_embeddings=quantize_embedding, + ) + model.config.quantization_config = ao_config + + LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.") 
model.save_pretrained( str(Path(output_dir) / "quantized"), safe_serialization=False, @@ -86,4 +106,14 @@ def do_quantize( progressbar=True, save_jinja_files=cfg.tokenizer_save_jinja_files, ) - LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...") + + if hub_model_id: + hub_model_id = ( + hub_model_id.rstrip("-") + + f"-{quantization_config_to_str[type(quantization_config)]}" + ) + model.push_to_hub(hub_model_id, safe_serialization=False) + tokenizer.push_to_hub(hub_model_id) + LOG.info(f"Quantized model pushed to: {hub_model_id}.") + + LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.") diff --git a/src/axolotl/train.py b/src/axolotl/train.py index e8e314579..b0482bb1e 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -30,11 +30,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) from axolotl.integrations.base import PluginManager -from axolotl.loaders import ( - ModelLoader, - load_processor, - load_tokenizer, -) +from axolotl.loaders import ModelLoader, load_processor, load_tokenizer from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed @@ -234,16 +230,15 @@ def save_trained_model( # handle QAT if cfg.qat: - from axolotl.utils.quantization import convert_qat_model_for_ptq + from axolotl.utils.quantization import convert_qat_model - LOG.info("Processing QAT model for saving...") - convert_qat_model_for_ptq( + convert_qat_model( model, quantize_embedding=cfg.qat.quantize_embedding, ) LOG.info( - "QAT modules have been converted for PTQ. Please ensure you quantize " - "your model weights with `axolotl quantize`." + "QAT usage note: please ensure you quantize your model fine-tuned using QAT by running `axolotl quantize`" + " with the same config which you used for training." 
) # Handle ReLoRA early return case if cfg.relora: @@ -337,9 +332,7 @@ def save_trained_model( if hasattr(cfg, "llmcompressor") and cfg.llmcompressor: # TODO: add integration support so this can be implemented completely within the plugin - from axolotl.integrations.llm_compressor.utils import ( - save_compressed_model, - ) + from axolotl.integrations.llm_compressor.utils import save_compressed_model save_compressed_model( model=model, diff --git a/src/axolotl/utils/quantization.py b/src/axolotl/utils/quantization.py index f9a30b660..6c29a5442 100644 --- a/src/axolotl/utils/quantization.py +++ b/src/axolotl/utils/quantization.py @@ -3,30 +3,47 @@ Utilities for quantization including QAT and PTQ using torchao. """ import torch -from torch import nn +from packaging import version from torchao.core.config import AOBaseConfig from torchao.quantization import quantize_ from torchao.quantization.qat import ( - FakeQuantizeConfig, - FromIntXQuantizationAwareTrainingConfig, - IntXQuantizationAwareTrainingConfig, + QATConfig, ) from torchao.quantization.quant_api import ( - Int4DynamicActivationInt4WeightConfig, - Int4WeightOnlyConfig, + Float8DynamicActivationFloat8WeightConfig, + Float8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt4WeightConfig, - Int8DynamicActivationInt8WeightConfig, - Int8WeightOnlyConfig, - UIntXWeightOnlyConfig, - _is_linear, ) -from axolotl.utils.schemas.enums import TorchIntDType +from axolotl.utils.schemas.enums import TorchAOQuantDType + +quantization_config_to_str = { + Int8DynamicActivationInt4WeightConfig: "int8int4", + Float8DynamicActivationFloat8WeightConfig: "fp8fp8", + Float8DynamicActivationInt4WeightConfig: "fp8int4", +} + +if version.parse(torch.__version__) >= version.parse("2.8.0"): + try: + from torchao.prototype.mx_formats import NVFP4InferenceConfig + + quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4" + except: + pass + + # int4 weight config imports will fail on machines with fbgemm-gpu installed + # 
without a CUDA runtime available so we do this safely + try: + from torchao.quantization.quant_api import Int4WeightOnlyConfig + + quantization_config_to_str[Int4WeightOnlyConfig] = "int4" + except: + pass -def get_ptq_config( - weight_dtype: TorchIntDType, - activation_dtype: TorchIntDType | None = None, +def get_quantization_config( + weight_dtype: TorchAOQuantDType, + activation_dtype: TorchAOQuantDType | None = None, group_size: int | None = None, ) -> AOBaseConfig: """ @@ -45,44 +62,101 @@ def get_ptq_config( or if the group size is not specified for int8 or int4 weight only quantization. """ if activation_dtype is None: - if not weight_dtype.value.is_signed: # type: ignore[attr-defined,union-attr] - return UIntXWeightOnlyConfig( - dtype=weight_dtype.value, - group_size=group_size, - set_inductor_config=False, - ) - if weight_dtype == TorchIntDType.int8: - if group_size is None: - raise ValueError( - "group_size must be specified for int8 weight only quantization" - ) - return Int8WeightOnlyConfig( - group_size=group_size, - ) - if weight_dtype == TorchIntDType.int4: - if group_size is None: - raise ValueError( - "group_size must be specified for int4 weight only quantization" - ) - return Int4WeightOnlyConfig( - group_size=group_size, - ) - if activation_dtype == TorchIntDType.int4 and weight_dtype == TorchIntDType.int4: - return Int4DynamicActivationInt4WeightConfig() - if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int8: - return Int8DynamicActivationInt8WeightConfig() - if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int4: - return Int8DynamicActivationInt4WeightConfig() + if weight_dtype == TorchAOQuantDType.int8: + raise ValueError("Int8WeightOnlyConfig is not supported by torchao QAT.") + if weight_dtype == TorchAOQuantDType.int4: + from torchao.quantization.quant_api import Int4WeightOnlyConfig + + if group_size is not None: + return Int4WeightOnlyConfig(group_size=group_size, version=2) + else: 
+ return Int4WeightOnlyConfig(version=2) + if ( + activation_dtype == TorchAOQuantDType.int4 + and weight_dtype == TorchAOQuantDType.int4 + ): + raise ValueError( + "Int4DynamicActivationInt4WeightConfig is not supported by torchao QAT." + ) + if ( + activation_dtype == TorchAOQuantDType.int8 + and weight_dtype == TorchAOQuantDType.int8 + ): + raise ValueError( + "Int8DynamicActivationInt8WeightConfig is not supported by torchao QAT." + ) + if ( + activation_dtype == TorchAOQuantDType.int8 + and weight_dtype == TorchAOQuantDType.int4 + ): + if group_size is not None: + return Int8DynamicActivationInt4WeightConfig(group_size=group_size) + else: + return Int8DynamicActivationInt4WeightConfig() + if ( + activation_dtype == TorchAOQuantDType.float8_e4m3fn + and weight_dtype == TorchAOQuantDType.float8_e4m3fn + ): + return Float8DynamicActivationFloat8WeightConfig() + if ( + activation_dtype == TorchAOQuantDType.float8_e4m3fn + and weight_dtype == TorchAOQuantDType.int4 + ): + return Float8DynamicActivationInt4WeightConfig() + if weight_dtype == TorchAOQuantDType.nvfp4: + from torchao.prototype.mx_formats import NVFP4InferenceConfig + + if group_size is not None and group_size != 16: + raise ValueError("NVFP4 quantization must use a group_size of 16") + return NVFP4InferenceConfig() raise ValueError( f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}" ) +def quantize_model( + model, + weight_dtype: TorchAOQuantDType, + group_size: int | None = None, + activation_dtype: TorchAOQuantDType | None = None, + quantize_embedding: bool | None = None, +): + """ + This function is used to quantize a model. + + Args: + model: The model to quantize. + weight_dtype: The dtype to use for weight quantization. + group_size: The group size to use for weight quantization. + activation_dtype: The dtype to use for activation quantization. + quantize_embedding: Whether to quantize the model's embedding weights. 
+ + """ + linear_ptq_config = get_quantization_config( + weight_dtype=weight_dtype, + activation_dtype=activation_dtype, + group_size=group_size, + ) + quantize_(model, linear_ptq_config) + if quantize_embedding: + # activation fake quantization is not supported for embedding layers + embedding_quantize_config = get_quantization_config( + weight_dtype=weight_dtype, + activation_dtype=None, + group_size=group_size, + ) + quantize_( + model, + embedding_quantize_config, + filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), + ) + + def prepare_model_for_qat( model, - weight_dtype: TorchIntDType, - group_size: int, - activation_dtype: TorchIntDType | None = None, + weight_dtype: TorchAOQuantDType, + group_size: int | None = None, + activation_dtype: TorchAOQuantDType | None = None, quantize_embedding: bool = False, ): """ @@ -100,86 +174,40 @@ def prepare_model_for_qat( Raises: ValueError: If the activation/weight dtype combination is invalid. """ - if activation_dtype: - activation_config = FakeQuantizeConfig( - dtype=activation_dtype.value, granularity="per_token", is_symmetric=False - ) - weight_config = FakeQuantizeConfig(dtype=weight_dtype.value, group_size=group_size) - linear_quantize_config = IntXQuantizationAwareTrainingConfig( - activation_config=None if activation_dtype is None else activation_config, - weight_config=weight_config, - ) - quantize_(model, linear_quantize_config) - if quantize_embedding: - # activation fake quantization is not supported for embedding layers - embedding_quantize_config = IntXQuantizationAwareTrainingConfig( - activation_config=None, - weight_config=weight_config, - ) - quantize_( - model, - embedding_quantize_config, - filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), - ) - - -def quantize_model_for_ptq( - model, - weight_dtype: TorchIntDType, - group_size: int | None = None, - activation_dtype: TorchIntDType | None = None, - quantize_embedding: bool | None = None, -): - """ - This function is used to quantize a 
model for post-training quantization. - It swaps the model's linear layers with fake quantized linear layers. - If `quantize_embedding` is True, it will also swap the model's embedding weights with fake quantized embedding weights. - - Args: - model: The model to quantize. - weight_dtype: The dtype to use for weight quantization. - group_size: The group size to use for weight quantization. - activation_dtype: The dtype to use for activation quantization. - quantize_embedding: Whether to quantize the model's embedding weights. - - """ - linear_ptq_config = get_ptq_config( + base_config = get_quantization_config( weight_dtype=weight_dtype, activation_dtype=activation_dtype, group_size=group_size, ) - quantize_(model, linear_ptq_config) + qat_config = QATConfig(base_config) + quantize_(model, qat_config) if quantize_embedding: - embedding_quantize_config = get_ptq_config( + # activation fake quantization is not supported for embedding layers + embedding_base_config = get_quantization_config( weight_dtype=weight_dtype, activation_dtype=None, group_size=group_size, ) + embedding_qat_config = QATConfig(embedding_base_config) quantize_( model, - embedding_quantize_config, + embedding_qat_config, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), ) -def convert_qat_model_for_ptq( +def convert_qat_model( model, - *, - quantize_embedding: bool | None = None, + quantize_embedding: bool = False, ): """ - This function is used to convert a swap fake-quantized modules in a model - which has been trained with QAT back to the original modules, ready for PTQ. - - Args: - model: The model to convert. - quantize_embedding: Whether to quantize the model's embedding weights. + This function converts a QAT model which has fake quantized layers back to the original model. 
""" + config = QATConfig(step="convert") + quantize_(model, config) if quantize_embedding: - - def filter_fn(m, _): - return isinstance(m, nn.Embedding) or _is_linear(m) - - else: - filter_fn = _is_linear - quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn) + quantize_( + model, + config, + filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), + ) diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index 8f4718aa9..bcd03e1a2 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -5,18 +5,21 @@ from enum import Enum import torch -class TorchIntDType(Enum): - """Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4""" +class TorchAOQuantDType(Enum): + int4 = torch.int4 + int8 = torch.int8 + float8_e4m3fn = torch.float8_e4m3fn + nvfp4 = "nvfp4" - uint1 = getattr(torch, "uint1", None) - uint2 = getattr(torch, "uint2", None) - uint3 = getattr(torch, "uint3", None) - uint4 = getattr(torch, "uint4", None) - uint5 = getattr(torch, "uint5", None) - uint6 = getattr(torch, "uint6", None) - uint7 = getattr(torch, "uint7", None) - int4 = getattr(torch, "int4", None) - int8 = getattr(torch, "int8", None) + def from_string(str): + if str == "int4": + return TorchAOQuantDType.int4 + if str == "int8": + return TorchAOQuantDType.int8 + if str in ["float8_e4m3fn", "fp8", "float8"]: + return TorchAOQuantDType.float8_e4m3fn + if str == "nvfp4": + return TorchAOQuantDType.nvfp4 class RLType(str, Enum): diff --git a/src/axolotl/utils/schemas/quantization.py b/src/axolotl/utils/schemas/quantization.py index 090640c7b..a7c130574 100644 --- a/src/axolotl/utils/schemas/quantization.py +++ b/src/axolotl/utils/schemas/quantization.py @@ -6,7 +6,23 @@ from typing import Any from pydantic import BaseModel, Field, field_validator -from axolotl.utils.schemas.enums import TorchIntDType +from axolotl.utils.schemas.enums import TorchAOQuantDType + + +def 
validate_ao_dtype(v: Any) -> TorchAOQuantDType | None: + if v is None: + return None + if v == "int4": + return TorchAOQuantDType.int4 + if v == "int8": + return TorchAOQuantDType.int8 + if v in ["float8_e4m3fn", "fp8", "float8"]: + return TorchAOQuantDType.float8_e4m3fn + if v == "nvfp4": + return TorchAOQuantDType.nvfp4 + raise ValueError( + f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}" + ) class QATConfig(BaseModel): @@ -14,13 +30,13 @@ class QATConfig(BaseModel): QAT Config Schema """ - activation_dtype: TorchIntDType | None = Field( + activation_dtype: TorchAOQuantDType | None = Field( default=None, - description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"', + description="Fake quantization layout to use for activation quantization.", ) - weight_dtype: TorchIntDType = Field( - default=TorchIntDType.int8, - description='Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"', + weight_dtype: TorchAOQuantDType = Field( + default=TorchAOQuantDType.int8, + description="Fake quantization layout to use for weight quantization.", ) quantize_embedding: bool | None = Field( default=False, description="Quantize embedding" @@ -35,12 +51,8 @@ class QATConfig(BaseModel): @field_validator("activation_dtype", "weight_dtype", mode="before") @classmethod - def validate_dtype(cls, v: Any) -> TorchIntDType | None: - if v == "int4": - return TorchIntDType.int4 - if v == "int8": - return TorchIntDType.int8 - raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']") + def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None: + return validate_ao_dtype(v) class PTQConfig(BaseModel): @@ -48,13 +60,13 @@ class PTQConfig(BaseModel): PTQ Config Schema """ - weight_dtype: TorchIntDType = Field( - default=TorchIntDType.int8, - description="Fake quantization layout to use for weight quantization. 
Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8", + weight_dtype: TorchAOQuantDType = Field( + default=TorchAOQuantDType.int8, + description="Fake quantization layout to use for weight quantization.", ) - activation_dtype: TorchIntDType | None = Field( + activation_dtype: TorchAOQuantDType | None = Field( default=None, - description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"', + description="Fake quantization layout to use for activation quantization.", ) quantize_embedding: bool | None = Field( default=None, description="Whether to quantize the embedding layer." @@ -66,9 +78,5 @@ class PTQConfig(BaseModel): @field_validator("activation_dtype", "weight_dtype", mode="before") @classmethod - def validate_dtype(cls, v: Any) -> TorchIntDType | None: - if v == "int4": - return TorchIntDType.int4 - if v == "int8": - return TorchIntDType.int8 - raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']") + def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None: + return validate_ao_dtype(v) diff --git a/tests/e2e/test_qat.py b/tests/e2e/test_qat.py index 7d41dfb50..2f8398ef7 100644 --- a/tests/e2e/test_qat.py +++ b/tests/e2e/test_qat.py @@ -43,7 +43,7 @@ class TestQATLlama: "qat": { "quantize_embedding": True, "activation_dtype": "int8", - "weight_dtype": "int8", + "weight_dtype": "int4", "group_size": 8, }, "num_epochs": 1, @@ -111,7 +111,7 @@ class TestQATLlama: "qat": { "quantize_embedding": True, "activation_dtype": "int8", - "weight_dtype": "int8", + "weight_dtype": "int4", "group_size": 8, }, "save_first_step": False, diff --git a/tests/e2e/test_quantization.py b/tests/e2e/test_quantization.py index cfbdfec38..b64aef51a 100644 --- a/tests/e2e/test_quantization.py +++ b/tests/e2e/test_quantization.py @@ -5,41 +5,40 @@ Tests for axolotl.utils.quantization import pytest import torch from torch import nn -from torchao.dtypes.affine_quantized_tensor import 
AffineQuantizedTensor -from torchao.quantization.granularity import PerAxis, PerGroup -from torchao.quantization.linear_activation_quantized_tensor import ( - LinearActivationQuantizedTensor, -) +from torchao.quantization import LinearActivationQuantizedTensor from torchao.quantization.qat.embedding import FakeQuantizedEmbedding from torchao.quantization.qat.linear import FakeQuantizedLinear from torchao.quantization.quant_api import ( - Int4DynamicActivationInt4WeightConfig, - Int4WeightOnlyConfig, - Int8DynamicActivationInt8WeightConfig, - Int8WeightOnlyConfig, - UIntXWeightOnlyConfig, + Float8DynamicActivationFloat8WeightConfig, + Float8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt4WeightConfig, ) +from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor from transformers import AutoModelForCausalLM from transformers.trainer_callback import TrainerState from axolotl.utils.callbacks.qat import QATCallback from axolotl.utils.quantization import ( - convert_qat_model_for_ptq, - get_ptq_config, + convert_qat_model, + get_quantization_config, prepare_model_for_qat, - quantize_model_for_ptq, + quantize_model, ) -from axolotl.utils.schemas.enums import TorchIntDType +from axolotl.utils.schemas.enums import TorchAOQuantDType from axolotl.utils.schemas.quantization import QATConfig -from tests.e2e.utils import require_torch_2_6_0 +from tests.e2e.utils import ( + require_torch_2_8_0, + requires_cuda_ge_8_9, + requires_sm_ge_100, +) @pytest.fixture() def model(): dummy_model = AutoModelForCausalLM.from_pretrained( - "HuggingFaceTB/SmolLM2-135M", - device_map="cuda", + "Qwen/Qwen2-0.5B", + device_map="auto", torch_dtype=torch.bfloat16, ) with torch.device(dummy_model.device): @@ -48,45 +47,56 @@ def model(): dummy_model.model.embed_tokens.weight.shape[1], dtype=dummy_model.model.embed_tokens.weight.dtype, ) - return dummy_model + yield dummy_model + del dummy_model ptq_config_test_cases = [ - # weight_dtype, activation_dtype, 
group_size, expected_type, expected_params + # weight_dtype, activation_dtype, group_size, expected_type ( - TorchIntDType.uint4, + TorchAOQuantDType.int4, + TorchAOQuantDType.int8, None, - None, - UIntXWeightOnlyConfig, - {"dtype": torch.uint4, "group_size": None}, - ), - (TorchIntDType.int8, None, 32, Int8WeightOnlyConfig, {"group_size": 32}), - (TorchIntDType.int4, None, 4, Int4WeightOnlyConfig, {"group_size": 4}), - ( - TorchIntDType.int4, - TorchIntDType.int4, - None, - Int4DynamicActivationInt4WeightConfig, - {}, + Int8DynamicActivationInt4WeightConfig, ), ( - TorchIntDType.int8, - TorchIntDType.int8, + TorchAOQuantDType.float8_e4m3fn, + TorchAOQuantDType.float8_e4m3fn, None, - Int8DynamicActivationInt8WeightConfig, - {}, + Float8DynamicActivationFloat8WeightConfig, + ), + ( + TorchAOQuantDType.int4, + TorchAOQuantDType.float8_e4m3fn, + None, + Float8DynamicActivationInt4WeightConfig, ), ] ptq_test_cases = [ - # weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception - (TorchIntDType.int8, None, 8, False, None), - (TorchIntDType.int4, None, 4, True, None), - (TorchIntDType.uint4, None, 8, False, None), - (TorchIntDType.int4, TorchIntDType.int4, 8, False, None), - (TorchIntDType.int8, TorchIntDType.int8, 8, True, None), - (TorchIntDType.int8, None, None, False, ValueError), - (TorchIntDType.int4, None, None, False, ValueError), + # weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception, expected_tensor_class + (TorchAOQuantDType.int4, None, 4, True, None, Int4Tensor), + ( + TorchAOQuantDType.int4, + TorchAOQuantDType.int8, + 8, + False, + None, + LinearActivationQuantizedTensor, + ), + # ( + # TorchAOQuantDType.int4, + # TorchAOQuantDType.float8_e4m3fn, + # None, + # False, + # None, + # Int4Tensor, + # ), + (TorchAOQuantDType.int4, None, None, False, None, Int4Tensor), + # Deprecated configs + (TorchAOQuantDType.int8, None, 8, False, ValueError, None), + (TorchAOQuantDType.int4, TorchAOQuantDType.int4, 8, 
False, ValueError, None), + (TorchAOQuantDType.int8, TorchAOQuantDType.int8, 8, True, ValueError, None), ] @@ -96,44 +106,132 @@ class TestQuantization: """ @pytest.mark.parametrize( - "weight_dtype,activation_dtype,group_size,expected_type,expected_params", + "weight_dtype,activation_dtype,group_size,expected_type", ptq_config_test_cases, ) - @require_torch_2_6_0 + @requires_cuda_ge_8_9 + @require_torch_2_8_0 def test_get_ptq_config( - self, weight_dtype, activation_dtype, group_size, expected_type, expected_params + self, weight_dtype, activation_dtype, group_size, expected_type ): - config = get_ptq_config(weight_dtype, activation_dtype, group_size) - + config = get_quantization_config(weight_dtype, activation_dtype, group_size) assert isinstance(config, expected_type) - for param_name, param_value in expected_params.items(): - if isinstance(param_value, (PerAxis, PerGroup)): - if isinstance(param_value, PerAxis): - assert isinstance(getattr(config, param_name), PerAxis) - assert getattr(config, param_name).axis == param_value.axis - else: - assert isinstance(getattr(config, param_name), PerGroup) - assert ( - getattr(config, param_name).group_size == param_value.group_size - ) - else: - assert getattr(config, param_name) == param_value + @requires_cuda_ge_8_9 + @require_torch_2_8_0 + def test_get_ptq_config_int4_weight_only(self): + from torchao.quantization.quant_api import Int4WeightOnlyConfig + + config = get_quantization_config(TorchAOQuantDType.int4, None, 4) + assert isinstance(config, Int4WeightOnlyConfig) @pytest.mark.parametrize( - "weight_dtype", [TorchIntDType.int8, TorchIntDType.int4, TorchIntDType.uint4] + "weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception,expected_tensor_class", + ptq_test_cases, ) + @requires_cuda_ge_8_9 + @require_torch_2_8_0 + def test_quantize_model_for_ptq( + self, + model, + weight_dtype, + activation_dtype, + group_size, + quantize_embedding, + expected_exception, + expected_tensor_class, + ): + 
if expected_exception: + with pytest.raises(expected_exception): + quantize_model( + model, + weight_dtype, + group_size, + activation_dtype, + quantize_embedding, + ) + else: + quantize_model( + model, weight_dtype, group_size, activation_dtype, quantize_embedding + ) + if quantize_embedding: + assert isinstance( + model.model.embed_tokens.weight, expected_tensor_class + ), "Embedding weight should be quantized" + for child in list(model.children()): + if isinstance(child, torch.nn.Linear): + assert isinstance(child.weight, expected_tensor_class) + + @require_torch_2_8_0 + @requires_sm_ge_100 + def test_quantize_model_for_ptq_fp8( + self, + model, + ): + from torchao.quantization.quantize_.workflows.float8.float8_tensor import ( + Float8Tensor, + QuantizeTensorToFloat8Kwargs, + ) + + quantize_model( + model, + TorchAOQuantDType.float8_e4m3fn, + None, + TorchAOQuantDType.float8_e4m3fn, + ) + for child in list(model.children()): + if isinstance(child, torch.nn.Linear): + assert isinstance(child.weight, Float8Tensor) + assert child.weight.act_quant_kwargs is not None and isinstance( + child.weight.act_quant_kwargs, QuantizeTensorToFloat8Kwargs + ) + + @require_torch_2_8_0 + @requires_sm_ge_100 + def test_quantize_model_for_ptq_nvfp4( + self, + model, + ): + from torchao.prototype.mx_formats.nvfp4_tensor import ( + NVFP4Tensor, + QuantizeTensorToNVFP4Kwargs, + ) + + quantize_model(model, TorchAOQuantDType.nvfp4, 16, TorchAOQuantDType.nvfp4) + for child in list(model.children()): + if isinstance(child, torch.nn.Linear): + assert isinstance(child.weight, NVFP4Tensor) + assert child.weight.act_quant_kwargs is not None and isinstance( + child.weight.act_quant_kwargs, QuantizeTensorToNVFP4Kwargs + ) + @pytest.mark.parametrize( - "activation_dtype", [None, TorchIntDType.int4, TorchIntDType.int8] + "weight_dtype,activation_dtype,group_size,quantize_embedding", + [ + (TorchAOQuantDType.int4, None, 8, False), + (TorchAOQuantDType.int4, None, 16, True), + 
(TorchAOQuantDType.int4, TorchAOQuantDType.int8, 8, False), + (TorchAOQuantDType.int4, TorchAOQuantDType.int8, 16, True), + ( + TorchAOQuantDType.float8_e4m3fn, + TorchAOQuantDType.float8_e4m3fn, + None, + False, + ), + (TorchAOQuantDType.int4, TorchAOQuantDType.float8_e4m3fn, None, True), + ], ) - @pytest.mark.parametrize("group_size", [4, 8]) - @pytest.mark.parametrize("quantize_embedding", [False, True]) - @require_torch_2_6_0 + @require_torch_2_8_0 + @requires_cuda_ge_8_9 def test_prepare_model_for_qat( self, model, weight_dtype, activation_dtype, group_size, quantize_embedding ): prepare_model_for_qat( - model, weight_dtype, group_size, activation_dtype, quantize_embedding + model, + weight_dtype, + group_size, + activation_dtype, + quantize_embedding, ) if quantize_embedding: assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) @@ -142,17 +240,19 @@ class TestQuantization: model.model.embed_tokens.weight_fake_quantizer.config.dtype == weight_dtype.value ) - assert ( - model.model.embed_tokens.weight_fake_quantizer.config.group_size - == group_size - ) + if group_size: + assert ( + model.model.embed_tokens.weight_fake_quantizer.config.group_size + == group_size + ) for child in list(model.children()): if isinstance(child, torch.nn.Linear): assert isinstance(child, FakeQuantizedLinear) assert hasattr(child, "weight_fake_quantizer") assert child.weight_fake_quantizer.config.dtype == weight_dtype.value - assert child.weight_fake_quantizer.config.group_size == group_size + if group_size: + assert child.weight_fake_quantizer.config.group_size == group_size if activation_dtype: assert hasattr(child, "activation_fake_quantizer") assert ( @@ -162,49 +262,40 @@ class TestQuantization: else: assert child.activation_fake_quantizer is None - @pytest.mark.parametrize( - "weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception", - ptq_test_cases, - ) - @require_torch_2_6_0 - def test_quantize_model_for_ptq( - self, - model, - 
weight_dtype, - activation_dtype, - group_size, - quantize_embedding, - expected_exception, - ): - if expected_exception: - with pytest.raises(expected_exception): - quantize_model_for_ptq( - model, - weight_dtype, - group_size, - activation_dtype, - quantize_embedding, - ) - else: - quantize_model_for_ptq( - model, weight_dtype, group_size, activation_dtype, quantize_embedding - ) - if quantize_embedding: - assert isinstance( - model.model.embed_tokens.weight, AffineQuantizedTensor - ), "Embedding weight should be quantized" - for child in list(model.children()): - if isinstance(child, torch.nn.Linear): - if activation_dtype: - assert isinstance( - child.weight, LinearActivationQuantizedTensor - ), ( - "Linear weight should be quantized with activation quantization" - ) - else: - assert isinstance(child.weight, AffineQuantizedTensor), ( - "Linear weight should be quantized without activation quantization" - ) + @require_torch_2_8_0 + @requires_cuda_ge_8_9 + def test_convert_qat_model(self, model): + config = QATConfig( + weight_dtype="int4", + activation_dtype="int8", + group_size=8, + quantize_embedding=True, + ) + + # quantize model for qat + prepare_model_for_qat( + model, + config.weight_dtype, + config.group_size, + config.activation_dtype, + config.quantize_embedding, + ) + + assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) + assert isinstance(model.lm_head, FakeQuantizedLinear) + + # apply conversion + convert_qat_model( + model, + config.quantize_embedding, + ) + # ensure modules have been swapped out + assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) + assert not isinstance(model.lm_head, FakeQuantizedLinear) + + # ensure weights have been quantized + assert isinstance(model.model.embed_tokens.weight, nn.Parameter) + assert isinstance(model.lm_head.weight, nn.Parameter) class TestQuantizationCallback: @@ -218,10 +309,10 @@ class TestQuantizationCallback: global_step=0, ) - @require_torch_2_6_0 + 
@require_torch_2_8_0 def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state): cfg = QATConfig( - weight_dtype="int8", + weight_dtype="int4", activation_dtype="int8", group_size=8, quantize_embedding=True, @@ -268,10 +359,10 @@ class TestQuantizationCallback: assert model.model.embed_tokens.weight_fake_quantizer.enabled assert model.lm_head.weight_fake_quantizer.enabled - @require_torch_2_6_0 + @require_torch_2_8_0 def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state): cfg = QATConfig( - weight_dtype="int8", + weight_dtype="int4", activation_dtype="int8", group_size=8, quantize_embedding=True, @@ -304,43 +395,3 @@ class TestQuantizationCallback: # quantization should be enabled from the get-go assert model.model.embed_tokens.weight_fake_quantizer.enabled assert model.lm_head.weight_fake_quantizer.enabled - - -class TestConvertQATModelForPTQ: - """ - Test convert_qat_model_for_ptq - """ - - @require_torch_2_6_0 - def test_convert_qat_model_for_ptq(self, model): - config = QATConfig( - weight_dtype="int8", - activation_dtype="int8", - group_size=8, - quantize_embedding=True, - ) - - # quantize model for qat - prepare_model_for_qat( - model, - config.weight_dtype, - config.group_size, - config.activation_dtype, - config.quantize_embedding, - ) - - assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) - assert isinstance(model.lm_head, FakeQuantizedLinear) - - # apply conversion - convert_qat_model_for_ptq( - model, - quantize_embedding=config.quantize_embedding, - ) - # ensure modules have been swapped out - assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) - assert not isinstance(model.lm_head, FakeQuantizedLinear) - - # ensure weights have been quantized - assert isinstance(model.model.embed_tokens.weight, nn.Parameter) - assert isinstance(model.lm_head.weight, nn.Parameter) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 7db6cf74e..a2dd8bc5e 100644 --- a/tests/e2e/utils.py 
+++ b/tests/e2e/utils.py @@ -90,6 +90,18 @@ def require_torch_2_7_0(test_case): return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case) +def require_torch_2_8_0(test_case): + """ + Decorator marking a test that requires torch >= 2.8.0 + """ + + def is_min_2_8_0(): + torch_version = version.parse(torch.__version__) + return torch_version >= version.parse("2.8.0") + + return unittest.skipUnless(is_min_2_8_0(), "test requires torch>=2.8.0")(test_case) + + def require_torch_lt_2_6_0(test_case): """ Decorator marking a test that requires torch < 2.6.0 @@ -128,6 +140,24 @@ def require_llmcompressor(test_case): )(test_case) +def requires_sm_ge_100(test_case): + is_sm_ge_100 = ( + torch.cuda.is_available() + and torch.version.cuda + and torch.cuda.get_device_capability() >= (10, 0) + ) + return unittest.skipUnless(is_sm_ge_100, "test requires sm>=100")(test_case) + + +def requires_cuda_ge_8_9(test_case): + is_cuda_ge_8_9 = ( + torch.cuda.is_available() + and torch.version.cuda + and torch.cuda.get_device_capability() >= (8, 9) + ) + return unittest.skipUnless(is_cuda_ge_8_9, "test requires cuda>=8.9")(test_case) + + def is_hopper(): compute_capability = torch.cuda.get_device_capability() return compute_capability == (9, 0)