From 5fca214108af6ae51f2dfc2cc035e8accde94e9b Mon Sep 17 00:00:00 2001 From: salman Date: Wed, 28 May 2025 12:35:47 +0100 Subject: [PATCH] QAT (#2590) QAT and quantization w/torchao --- _quarto.yml | 6 +- docs/cli.qmd | 10 + docs/config.qmd | 14 + docs/qat.qmd | 32 ++ docs/quantize.qmd | 53 +++ examples/llama-3/3b-qat-fsdp2.yaml | 79 +++++ examples/qwen3/8b-qat-fsdp2.yml | 78 +++++ requirements.txt | 2 +- src/axolotl/cli/args.py | 12 + src/axolotl/cli/main.py | 11 + src/axolotl/cli/quantize.py | 90 +++++ src/axolotl/core/trainer_builder.py | 4 + src/axolotl/loaders/model.py | 14 + src/axolotl/loaders/patch_manager.py | 4 +- src/axolotl/monkeypatch/accelerate/fsdp2.py | 159 ++++++++- src/axolotl/train.py | 18 +- src/axolotl/utils/callbacks/qat.py | 50 +++ src/axolotl/utils/distributed.py | 7 +- src/axolotl/utils/quantization.py | 189 +++++++++++ src/axolotl/utils/schemas/config.py | 44 ++- src/axolotl/utils/schemas/enums.py | 16 + src/axolotl/utils/schemas/quantization.py | 64 ++++ tests/e2e/multigpu/solo/test_grpo.py | 6 +- tests/e2e/test_qat.py | 71 ++++ tests/e2e/test_quantization.py | 350 ++++++++++++++++++++ tests/e2e/utils.py | 2 - 26 files changed, 1372 insertions(+), 13 deletions(-) create mode 100644 docs/qat.qmd create mode 100644 docs/quantize.qmd create mode 100644 examples/llama-3/3b-qat-fsdp2.yaml create mode 100644 examples/qwen3/8b-qat-fsdp2.yml create mode 100644 src/axolotl/cli/quantize.py create mode 100644 src/axolotl/utils/callbacks/qat.py create mode 100644 src/axolotl/utils/quantization.py create mode 100644 src/axolotl/utils/schemas/quantization.py create mode 100644 tests/e2e/test_qat.py create mode 100644 tests/e2e/test_quantization.py diff --git a/_quarto.yml b/_quarto.yml index 696cedd51..e05a1c35c 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -43,6 +43,7 @@ quartodoc: - cli.vllm_serve - cli.cloud.base - cli.cloud.modal_ + - cli.quantize - title: Trainers desc: Training implementations contents: @@ -147,6 +148,7 @@ quartodoc: - utils.optimizers.adopt - utils.data.pretraining - utils.data.sft + - utils.quantization - title: Schemas desc: Pydantic data models for Axolotl config contents: @@ -196,7 +198,7 @@ quartodoc: - utils.callbacks.lisa - utils.callbacks.mlflow_ - utils.callbacks.comet_ - + - utils.callbacks.qat website: title: "Axolotl" description: "We make fine-tuning accessible, scalable, and fun" @@ -256,6 +258,8 @@ website: - docs/lr_groups.qmd - docs/lora_optims.qmd - docs/dataset_loading.qmd + - docs/qat.qmd + - docs/quantize.qmd - section: "Core Concepts" contents: diff --git a/docs/cli.qmd b/docs/cli.qmd index 1003a210c..f6f9b3481 100644 --- a/docs/cli.qmd +++ b/docs/cli.qmd @@ -209,6 +209,16 @@ axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir This would be necessary to use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing. +### quantize + +Quantizes a model using the quantization configuration specified in your YAML file. + +```bash +axolotl quantize config.yml +``` + +See [Quantization](./quantize.qmd) for more details. + ## Legacy CLI Usage diff --git a/docs/config.qmd b/docs/config.qmd index 369d3db43..5a36ca8aa 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -65,6 +65,20 @@ bnb_config_kwargs: bnb_4bit_quant_type: nf4 bnb_4bit_use_double_quant: true +# quantization aware training +qat: + activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. 
Valid options are "int4" and "int8"
+  weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
+  group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
+  quantize_embedding: # Optional[bool] = False. Whether to apply fake quantization to the embedding layer
+  fake_quant_after_n_steps: # Optional[int] = None. The number of steps after which fake quantization is enabled
+
+# post-training quantization
+quantization:
+  weight_dtype: # Optional[str] = "int8". Data type to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
+  activation_dtype: # Optional[str] = None. Data type to use for activation quantization. Valid options are "int4" and "int8"
+  group_size: # Optional[int] = 32. The number of elements in each group for per-group quantization
+  quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
+
 # Whether you are training a 4-bit GPTQ quantized model
 gptq: true

diff --git a/docs/qat.qmd b/docs/qat.qmd
new file mode 100644
index 000000000..0531388de
--- /dev/null
+++ b/docs/qat.qmd
@@ -0,0 +1,32 @@
+---
+title: "Quantization Aware Training (QAT)"
+back-to-top-navigation: true
+toc: true
+toc-expand: 2
+toc-depth: 4
+---
+
+## Overview
+
+[Quantization Aware Training](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/#quantization-aware-training) (QAT) is a technique for improving the accuracy of quantized models
+by applying "fake" quantization to the model's weights (and optionally, activations) during training. This fake
+quantization lets the model adjust to the noise introduced by quantization, so that when the model is
+eventually quantized, the accuracy loss is minimized. We use the quantization techniques implemented in [torchao](https://github.com/pytorch/ao) to provide
+support for QAT and post-training quantization (PTQ) in axolotl.
+
+We recommend reviewing the excellent QAT tutorial in the [torchtune library](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#quantizing-the-qat-model),
+and the QAT documentation in the [torchao library](https://github.com/pytorch/ao/tree/main/torchao/quantization/qat) for more details.
+
+## Configuring QAT in Axolotl
+
+To enable QAT in axolotl, add the following to your configuration file:
+
+```yaml
+qat:
+  activation_dtype: # Optional[str] = None. Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
+  weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
+  group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
+  quantize_embedding: # Optional[bool] = False. Whether to apply fake quantization to the embedding layer
+  fake_quant_after_n_steps: # Optional[int] = None. The number of steps after which fake quantization is enabled
+```
+
+Once you have finished training, you must quantize your model with the same quantization configuration that you used during training. You can use the [`quantize` command](./quantize.qmd) to do this.
diff --git a/docs/quantize.qmd b/docs/quantize.qmd
new file mode 100644
index 000000000..294efda8b
--- /dev/null
+++ b/docs/quantize.qmd
@@ -0,0 +1,53 @@
+---
+title: "Quantization with torchao"
+back-to-top-navigation: true
+toc: true
+toc-expand: 2
+toc-depth: 4
+---
+
+Quantization is a technique to lower the memory footprint of your model, potentially at the cost of accuracy or model performance.
+We support quantizing your model using the [torchao](https://github.com/pytorch/ao) library. Both post-training quantization (PTQ) and quantization-aware training (QAT) are supported.
+
+
+::: {.callout-note}
+
+We do not currently support quantization formats such as GGUF, GPTQ, or EXL2.
+
+:::
+
+## Configuring Quantization in Axolotl
+
+Quantization is configured using the `quantization` key in your configuration file.
+
+```yaml
+base_model: # The path to the model to quantize.
+quantization:
+  weight_dtype: # Optional[str] = "int8". Data type to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
+  activation_dtype: # Optional[str] = None. Data type to use for activation quantization. Valid options are "int4" and "int8"
+  group_size: # Optional[int] = 32. The number of elements in each group for per-group quantization
+  quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
+
+output_dir: # The path to the output directory.
+```
+
+Once quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory.
+
+You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.qmd). To do this, reuse the QAT configuration file
+that you used to train the model:
+
+```yaml
+# qat.yml
+qat:
+  activation_dtype: int8
+  weight_dtype: int8
+  group_size: 256
+  quantize_embedding: true
+
+output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
+```
+
+```bash
+axolotl quantize qat.yml
+```
+
+This ensures that the model is quantized with exactly the same quantization configuration that was used to train it.
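+
+For reference, assuming your QAT training config is saved as `qat.yml` (as in the example above), a typical end-to-end workflow might look like:
+
+```bash
+# 1. Fine-tune with fake quantization enabled via the `qat` section of the config
+axolotl train qat.yml
+
+# 2. Quantize the final checkpoint saved in `output_dir`, reusing the same config
+axolotl quantize qat.yml
+```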
diff --git a/examples/llama-3/3b-qat-fsdp2.yaml b/examples/llama-3/3b-qat-fsdp2.yaml new file mode 100644 index 000000000..5d979c96c --- /dev/null +++ b/examples/llama-3/3b-qat-fsdp2.yaml @@ -0,0 +1,79 @@ +base_model: meta-llama/Llama-3.2-3B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true + +datasets: + - path: yahma/alpaca-cleaned + type: alpaca + +output_dir: ./outputs/qat_out/ + +sample_packing: true +pad_to_sequence_len: true +sequence_len: 512 + +flex_attention: true +flex_attn_compile_kwargs: + dynamic: false + mode: max-autotune-no-cudagraphs + +qat: + activation_dtype: int8 + weight_dtype: int4 + group_size: 32 + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 +num_epochs: 1 +optimizer: adamw_torch_fused + +cosine_constant_lr_ratio: 0 +cosine_min_lr_ratio: 1.0 +learning_rate: 2e-5 +save_only_model: true +bf16: true + +resume_from_checkpoint: +logging_steps: 1 + +evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_steps: 10 +weight_decay: 0.0 +fsdp: + - full_shard + - auto_wrap + +fsdp_config: + fsdp_version: 2 + fsdp_offload_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD + fsdp_reshard_after_forward: true + fsdp_activation_checkpointing: true + +special_tokens: + pad_token: <|end_of_text|> diff --git a/examples/qwen3/8b-qat-fsdp2.yml b/examples/qwen3/8b-qat-fsdp2.yml new file mode 100644 index 000000000..6832b6af7 --- /dev/null +++ b/examples/qwen3/8b-qat-fsdp2.yml @@ -0,0 +1,78 @@ +base_model: Qwen/Qwen3-8B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true + +datasets: + - path: tatsu-lab/alpaca + type: alpaca + +output_dir: ./outputs/qat_out/ + +sequence_len: 2048 +sample_packing: true +flex_attention: true +pad_to_sequence_len: true + +flex_attn_compile_kwargs: + dynamic: false + mode: max-autotune-no-cudagraphs + +qat: + activation_dtype: int8 + weight_dtype: int4 + group_size: 256 + fake_quant_after_n_steps: 1000 + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 2 +max_steps: 2000 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 2e-5 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_steps: 10 +weight_decay: 0.0 +fsdp: + - full_shard + - auto_wrap + +fsdp_config: + fsdp_version: 2 + fsdp_offload_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD + fsdp_reshard_after_forward: true + fsdp_activation_checkpointing: true + +special_tokens: diff --git a/requirements.txt 
b/requirements.txt index 4ae82dd49..4e632b0f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -63,7 +63,7 @@ langdetect==1.0.9 immutabledict==4.2.0 antlr4-python3-runtime==4.13.2 -torchao==0.9.0 +torchao==0.10.0 schedulefree==1.4.1 axolotl-contribs-lgpl==0.0.6 diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py index 088e337e4..4be3704ac 100644 --- a/src/axolotl/cli/args.py +++ b/src/axolotl/cli/args.py @@ -90,6 +90,18 @@ class VllmServeCliArgs: ) +@dataclass +class QuantizeCliArgs: + """Dataclass with CLI arguments for `axolotl quantize` command.""" + + base_model: Optional[str] = field(default=None) + weight_dtype: Optional[str] = field(default=None) + activation_dtype: Optional[str] = field(default=None) + quantize_embedding: Optional[bool] = field(default=None) + group_size: Optional[int] = field(default=None) + output_dir: Optional[str] = field(default=None) + + @dataclass class EvaluateCliArgs: """Dataclass with CLI arguments for `axolotl evaluate` command.""" diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 601add709..e61dad5d6 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -17,6 +17,7 @@ import axolotl from axolotl.cli.args import ( EvaluateCliArgs, PreprocessCliArgs, + QuantizeCliArgs, TrainerCliArgs, VllmServeCliArgs, ) @@ -333,6 +334,16 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs): do_vllm_serve(config, cli_args) +@cli.command() +@click.argument("config", type=click.Path(exists=True, path_type=str)) +@add_options_from_dataclass(QuantizeCliArgs) +@filter_none_kwargs +def quantize(config: str, **cli_args: QuantizeCliArgs): + from axolotl.cli.quantize import do_quantize + + do_quantize(config, cli_args) + + @cli.command() @click.argument("model", type=click.Path(exists=True, path_type=str)) @click.argument("output", type=click.Path(exists=False, path_type=str)) diff --git a/src/axolotl/cli/quantize.py b/src/axolotl/cli/quantize.py new file mode 100644 index 000000000..2036fddea --- /dev/null +++ b/src/axolotl/cli/quantize.py @@ -0,0 +1,90 @@ +""" +CLI to post-training quantize a model using torchao +""" + +import logging +from pathlib import Path +from typing import Union + +from transformers import AutoModelForCausalLM + +from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.config import load_cfg +from axolotl.loaders import load_tokenizer +from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq + +LOG = logging.getLogger(__name__) + + +def do_quantize( + config: Union[Path, str], + cli_args: dict, +): + """ + Quantizes a model's model's weights + + Args: + config (Union[Path, str]): The path to the config file + cli_args (dict): Additional command-line arguments + """ + print_axolotl_text_art() + + cfg = load_cfg(config) + + if cfg.qat and cfg.quantization: + raise ValueError( + "QAT and quantization cannot be used together. Please specify only one of qat or quantization in your config file." + ) + + if cfg.qat: + quantize_cfg = cfg.qat + elif cfg.quantization: + quantize_cfg = cfg.quantization + else: + raise ValueError( + "No quantization configuration found. Please specify either qat or quantization in your config file." 
+ ) + + model_path = cli_args.get("model_path") or cfg.output_dir + if weight_dtype := cli_args.get("weight_dtype"): + weight_dtype = TorchIntDType[weight_dtype] + else: + weight_dtype = quantize_cfg.weight_dtype + if activation_dtype := cli_args.get("activation_dtype"): + activation_dtype = TorchIntDType[activation_dtype] + else: + activation_dtype = quantize_cfg.activation_dtype + group_size = cli_args.get("group_size") or quantize_cfg.group_size + quantize_embedding = ( + cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding + ) + output_dir = cli_args.get("output_dir") or cfg.output_dir + + LOG.info(f"Loading model from {model_path}...") + tokenizer = load_tokenizer(cfg) + model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") + + LOG.info( + f"Quantizing model with configuration: \n" + f"\tweight_dtype: {weight_dtype}\n" + f"\tactivation_dtype: {activation_dtype}\n" + f"\tgroup_size: {group_size}\n" + f"\tquantize_embedding: {quantize_embedding}" + ) + + quantize_model_for_ptq( + model, weight_dtype, group_size, activation_dtype, quantize_embedding + ) + + LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...") + model.save_pretrained( + str(Path(output_dir) / "quantized"), + safe_serialization=False, + progressbar=True, + ) + tokenizer.save_pretrained( + str(Path(output_dir) / "quantized"), + safe_serialization=False, + progressbar=True, + ) + LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...") diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 9709f0fd4..08759d9f9 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -79,6 +79,7 @@ from axolotl.utils.callbacks import ( ) from axolotl.utils.callbacks.lisa import lisa_callback_factory from axolotl.utils.callbacks.profiler import PytorchProfilerCallback +from axolotl.utils.callbacks.qat import QATCallback from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, @@ -254,6 +255,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.loss_watchdog_threshold is not None: callbacks.append(LossWatchDogCallback(self.cfg)) + if self.cfg.qat: + callbacks.append(QATCallback(self.cfg.qat)) + return callbacks def get_post_trainer_create_callbacks(self, trainer): diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index d7ac84a6d..8d8f927a7 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -191,6 +191,7 @@ class ModelLoader: self._adjust_model_config() self._log_memory_usage() self._configure_embedding_dtypes() + self._configure_qat() def _resize_token_embeddings(self): """Resize token embeddings if needed.""" @@ -305,6 +306,19 @@ class ModelLoader: before_kbit_train_or_finetune=False, ) + def _configure_qat(self): + """Configure QAT.""" + if self.cfg.qat: + from axolotl.utils.quantization import prepare_model_for_qat + + prepare_model_for_qat( + self.model, + self.cfg.qat.weight_dtype, + self.cfg.qat.group_size, + self.cfg.qat.activation_dtype, + self.cfg.qat.quantize_embedding, + ) + def _load_adapters(self) -> PeftConfig | None: """Load LoRA or other adapters.""" # Load LoRA or adapter diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index f251f958d..36813bafd 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -80,9 +80,9 @@ class PatchManager: 
def _apply_fsdp_patches(self): """Apply patches for FSDP configurations.""" if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2": - from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils + from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2 - patch_accelerate_fsdp_utils() + patch_accelerate_fsdp2() def _apply_adapter_patches(self): """Apply patches for adapter configurations.""" diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index d8ec00c69..ffde17aeb 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -1,5 +1,5 @@ """ -monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation +monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation, and saving full state dicts """ import logging @@ -52,7 +52,146 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic model.load_state_dict(sharded_sd, assign=True) -def patch_accelerate_fsdp_utils(): +def set_state_dict_type(self, state_dict_type=None): + """ + Set the state dict config based on the `StateDictType`. + """ + import os + + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullOptimStateDictConfig, + FullStateDictConfig, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + StateDictType, + ) + + # Override the state_dict_type if provided, typical use case: + # user trains with sharded, but final save is with full + if state_dict_type is not None: + self.state_dict_type = state_dict_type + + if self.state_dict_type is None: + self.state_dict_type = os.environ.get( + "FSDP_STATE_DICT_TYPE", + "FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT", + ) + if isinstance(self.state_dict_type, str): + if self.state_dict_type.isdigit(): + self.state_dict_type = StateDictType(int(self.state_dict_type)) + else: + self.state_dict_type = StateDictType[self.state_dict_type.upper()] + + if self.state_dict_type == StateDictType.FULL_STATE_DICT: + if self.state_dict_config is None: + self.state_dict_config = FullStateDictConfig( + offload_to_cpu=True, rank0_only=True + ) + if self.optim_state_dict_config is None: + self.optim_state_dict_config = FullOptimStateDictConfig( + offload_to_cpu=True, rank0_only=True + ) + elif self.state_dict_type == StateDictType.SHARDED_STATE_DICT: + if self.state_dict_config is None: + self.state_dict_config = ShardedStateDictConfig(offload_to_cpu=True) + if self.optim_state_dict_config is None: + self.optim_state_dict_config = ShardedOptimStateDictConfig( + offload_to_cpu=True + ) + + +def get_state_dict(self, model, unwrap=True): + """ + Returns the state dictionary of a model sent through [`Accelerator.prepare`] potentially without full + precision. + + Args: + model (`torch.nn.Module`): + A PyTorch model sent through [`Accelerator.prepare`] + unwrap (`bool`, *optional*, defaults to `True`): + Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict + + Returns: + `dict`: The state dictionary of the model potentially without full precision. 
+ + Example: + + ```python + >>> import torch + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> net = torch.nn.Linear(2, 2) + >>> net = accelerator.prepare(net) + >>> state_dict = accelerator.get_state_dict(net) + ``` + """ + from accelerate import DistributedType + from accelerate.utils import compare_versions + + if self.distributed_type == DistributedType.DEEPSPEED: + zero3_sharding = self.deepspeed_config["zero_optimization"]["stage"] == 3 + tp_sharding = ( + self.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0) > 1 + ) + if zero3_sharding or tp_sharding: + if model.zero_gather_16bit_weights_on_model_save(): + if tp_sharding and not compare_versions("deepspeed", ">=", "0.16.4"): + raise ImportError( + "Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`." + ) + state_dict = ( + model._consolidated_16bit_state_dict() # pylint: disable=protected-access + if tp_sharding + else model._zero3_consolidated_16bit_state_dict() # pylint: disable=protected-access + ) + else: + raise ValueError( + "Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. " + "To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or " + "set `zero3_save_16bit_model` to True when using `accelerate config`. " + "To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights." + ) + else: + from deepspeed.checkpoint.utils import clone_tensors_for_torch_save + + state_dict = clone_tensors_for_torch_save( + self.unwrap_model(model).state_dict() + ) + elif self.is_fsdp2: + # https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465 + state_dict = {} + sharded_state_dict = model.state_dict() + for param_name, param in sharded_state_dict.items(): + if param.is_cpu: + param = param.to(torch.device("cuda")) + + param = param.full_tensor() + if torch.distributed.get_rank() == 0: + state_dict[param_name] = param.cpu() + torch.distributed.barrier() + elif self.distributed_type == DistributedType.FSDP: + from torch.distributed.fsdp import FullStateDictConfig + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp import StateDictType + + full_state_dict_config = FullStateDictConfig( + offload_to_cpu=True, rank0_only=True + ) + with FSDP.state_dict_type( + model, StateDictType.FULL_STATE_DICT, full_state_dict_config + ): + state_dict = model.state_dict() + else: + if unwrap: + model = self.unwrap_model(model) + state_dict = model.state_dict() + + return state_dict + + +def patch_accelerate_fsdp2(): + import accelerate from accelerate.utils import fsdp_utils fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict @@ -61,3 +200,19 @@ def patch_accelerate_fsdp_utils(): "fsdp2_load_full_state_dict", fsdp2_load_full_state_dict, ) + + accelerate.Accelerator.get_state_dict = get_state_dict + setattr( + sys.modules["accelerate"], + "Accelerator.get_state_dict", + get_state_dict, + ) + + accelerate.utils.dataclasses.FullyShardedDataParallelPlugin.set_state_dict_type = ( + set_state_dict_type + ) + setattr( + sys.modules["accelerate.utils.dataclasses"], + "FullyShardedDataParallelPlugin.set_state_dict_type", + set_state_dict_type, + ) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 52ec8f22b..8a4c0040d 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -238,13 +238,27 @@ 
def save_trained_model( model: The trained model to save. safe_serialization: Whether to use safe serialization. """ - LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.") + LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.") # Post training module hooks for name, module in model.named_modules(): if hasattr(module, "_post_training"): module._post_training(model, name) # pylint: disable=protected-access + # handle QAT + if cfg.qat: + from axolotl.utils.quantization import convert_qat_model_for_ptq + + LOG.info("Processing QAT model for saving...") + convert_qat_model_for_ptq( + model, + quantize_embedding=cfg.qat.quantize_embedding, + ) + LOG.info( + "QAT modules have been converted for PTQ. Please ensure you quantize " + "your model weights with `axolotl quantize`." + ) + # Handle FSDP state dict type state_dict_type = "FULL_STATE_DICT" if trainer.is_fsdp_enabled and str(cfg.fsdp_config.fsdp_version) != "2": @@ -321,6 +335,8 @@ def save_trained_model( save_compressed=cfg.llmcompressor.save_compressed, ) + LOG.info(f"Model successfully saved to {cfg.output_dir}") + def create_model_card(cfg: DictDefault, trainer: Trainer): """ diff --git a/src/axolotl/utils/callbacks/qat.py b/src/axolotl/utils/callbacks/qat.py new file mode 100644 index 000000000..da4f2612b --- /dev/null +++ b/src/axolotl/utils/callbacks/qat.py @@ -0,0 +1,50 @@ +"""QAT Callback for HF Causal Trainer""" + +import logging +from functools import partial + +from torch import nn +from torchao.quantization.qat.embedding import FakeQuantizedEmbedding +from torchao.quantization.qat.linear import FakeQuantizedLinear +from transformers import TrainerCallback + +from axolotl.utils.schemas.quantization import QATConfig + +LOG = logging.getLogger(__name__) + + +def toggle_fake_quant(mod: nn.Module, enable: bool): + """ + Toggle fake quantization for any fake quantized linear or embedding layers in the model. + + Args: + mod: The module to toggle fake quantization for. + enable: Whether to enable or disable fake quantization. + """ + if isinstance(mod, (FakeQuantizedLinear, FakeQuantizedEmbedding)): + if ( + isinstance(mod, FakeQuantizedLinear) + and mod.activation_fake_quantizer is not None + ): + mod.activation_fake_quantizer.enabled = enable + mod.weight_fake_quantizer.enabled = enable + + +class QATCallback(TrainerCallback): + """ + Callback to toggle fake quantization for the model. + """ + + def __init__(self, cfg: QATConfig): + self.cfg = cfg + + def on_step_begin( + self, args, state, control, model, **kwargs + ): # pylint: disable=unused-argument + if self.cfg.fake_quant_after_n_steps is not None: + if state.global_step == 0: + LOG.info(f"Disabling fake quantization at step {state.global_step}") + model.apply(partial(toggle_fake_quant, enable=False)) + elif state.global_step == self.cfg.fake_quant_after_n_steps: + LOG.info(f"Enabling fake quantization at step {state.global_step}") + model.apply(partial(toggle_fake_quant, enable=True)) diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index 8c52102c8..0673c6e95 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -103,7 +103,12 @@ def cleanup_distributed(): termination or when training successfully completes. 
""" # Ensure that all operations are completed before destroying the process group - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() + + if torch.xpu.is_available(): + torch.xpu.synchronize() + # Destroy the process group if torch.distributed.is_initialized(): torch.distributed.destroy_process_group() diff --git a/src/axolotl/utils/quantization.py b/src/axolotl/utils/quantization.py new file mode 100644 index 000000000..612b1d44e --- /dev/null +++ b/src/axolotl/utils/quantization.py @@ -0,0 +1,189 @@ +""" +Utilities for quantization including QAT and PTQ using torchao. +""" + +import logging + +import torch +from torch import nn +from torchao.core.config import AOBaseConfig +from torchao.quantization import quantize_ +from torchao.quantization.qat import ( + FakeQuantizeConfig, + FromIntXQuantizationAwareTrainingConfig, + IntXQuantizationAwareTrainingConfig, +) +from torchao.quantization.quant_api import ( + Int4DynamicActivationInt4WeightConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, + UIntXWeightOnlyConfig, + _is_linear, +) + +from axolotl.utils.schemas.enums import TorchIntDType + +LOG = logging.getLogger(__name__) + + +def get_ptq_config( + weight_dtype: TorchIntDType, + activation_dtype: TorchIntDType | None = None, + group_size: int | None = None, +) -> AOBaseConfig: + """ + This function is used to build a post-training quantization config. + + Args: + weight_dtype: The dtype to use for weight quantization. + activation_dtype: The dtype to use for activation quantization. + group_size: The group size to use for weight quantization. + + Returns: + The post-training quantization config. + + Raises: + ValueError: If the activation dtype is not specified and the weight dtype is not int8 or int4, + or if the group size is not specified for int8 or int4 weight only quantization. + """ + if activation_dtype is None: + if not weight_dtype.value.is_signed: # type: ignore[attr-defined,union-attr] + return UIntXWeightOnlyConfig( + dtype=weight_dtype.value, + group_size=group_size, + set_inductor_config=False, + ) + if weight_dtype == TorchIntDType.int8: + if group_size is None: + raise ValueError( + "group_size must be specified for int8 weight only quantization" + ) + return Int8WeightOnlyConfig( + group_size=group_size, + ) + if weight_dtype == TorchIntDType.int4: + if group_size is None: + raise ValueError( + "group_size must be specified for int4 weight only quantization" + ) + return Int4WeightOnlyConfig( + group_size=group_size, + ) + if activation_dtype == TorchIntDType.int4 and weight_dtype == TorchIntDType.int4: + return Int4DynamicActivationInt4WeightConfig() + if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int8: + return Int8DynamicActivationInt8WeightConfig() + if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int4: + return Int8DynamicActivationInt4WeightConfig() + raise ValueError( + f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}" + ) + + +def prepare_model_for_qat( + model, + weight_dtype: TorchIntDType, + group_size: int, + activation_dtype: TorchIntDType | None = None, + quantize_embedding: bool = False, +): + """ + This function is used to prepare a model for QAT by swapping the model's linear + layers with fake quantized linear layers, and optionally the embedding weights with + fake quantized embedding weights. + + Args: + model: The model to quantize. 
+ weight_dtype: The dtype to use for weight quantization. + group_size: The group size to use for weight quantization. + activation_dtype: The dtype to use for activation quantization. + quantize_embedding: Whether to quantize the model's embedding weights. + + Raises: + ValueError: If the activation/weight dtype combination is invalid. + """ + if activation_dtype: + activation_config = FakeQuantizeConfig( + dtype=activation_dtype.value, granularity="per_token", is_symmetric=False + ) + weight_config = FakeQuantizeConfig(dtype=weight_dtype.value, group_size=group_size) + linear_quantize_config = IntXQuantizationAwareTrainingConfig( + activation_config=None if activation_dtype is None else activation_config, + weight_config=weight_config, + ) + quantize_(model, linear_quantize_config) + if quantize_embedding: + # activation fake quantization is not supported for embedding layers + embedding_quantize_config = IntXQuantizationAwareTrainingConfig( + activation_config=None, + weight_config=weight_config, + ) + quantize_( + model, + embedding_quantize_config, + filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), + ) + + +def quantize_model_for_ptq( + model, + weight_dtype: TorchIntDType, + group_size: int | None = None, + activation_dtype: TorchIntDType | None = None, + quantize_embedding: bool | None = None, +): + """ + This function is used to quantize a model for post-training quantization. + It swaps the model's linear layers with fake quantized linear layers. + If `quantize_embedding` is True, it will also swap the model's embedding weights with fake quantized embedding weights. + + Args: + model: The model to quantize. + weight_dtype: The dtype to use for weight quantization. + group_size: The group size to use for weight quantization. + activation_dtype: The dtype to use for activation quantization. + quantize_embedding: Whether to quantize the model's embedding weights. + + """ + linear_ptq_config = get_ptq_config( + weight_dtype=weight_dtype, + activation_dtype=activation_dtype, + group_size=group_size, + ) + quantize_(model, linear_ptq_config) + if quantize_embedding: + embedding_quantize_config = get_ptq_config( + weight_dtype=weight_dtype, + activation_dtype=None, + group_size=group_size, + ) + quantize_( + model, + embedding_quantize_config, + filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), + ) + + +def convert_qat_model_for_ptq( + model, + *, + quantize_embedding: bool | None = None, +): + """ + This function is used to convert a swap fake-quantized modules in a model + which has been trained with QAT back to the original modules, ready for PTQ. + + Args: + model: The model to convert. + quantize_embedding: Whether to quantize the model's embedding weights. 
+ """ + if quantize_embedding: + + def filter_fn(m, _): + return isinstance(m, nn.Embedding) or _is_linear(m) + + else: + filter_fn = _is_linear + quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index e68185323..8a4d6d63f 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -44,6 +44,7 @@ from axolotl.utils.schemas.model import ( ) from axolotl.utils.schemas.multimodal import MultiModalConfig from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig +from axolotl.utils.schemas.quantization import PTQConfig, QATConfig from axolotl.utils.schemas.training import HyperparametersConfig from axolotl.utils.schemas.trl import TRLConfig from axolotl.utils.schemas.vllm import VllmConfig @@ -91,6 +92,8 @@ class AxolotlInputConfig( vllm: VllmConfig | None = Field( default_factory=lambda: VllmConfig(), # pylint: disable=unnecessary-lambda ) + qat: QATConfig | None = None + quantization: PTQConfig | None = None reward_model: bool | None = None process_reward_model: bool | None = None num_labels: int | None = None @@ -126,7 +129,7 @@ class AxolotlInputConfig( default=None, json_schema_extra={"description": "streaming dataset to use for pretraining"}, ) - dataset_processes: int | None = Field(default=min(32, os.cpu_count())) # type: ignore[type-var] + dataset_processes: int | None = Field(default=min(32, os.cpu_count() or 1)) dataset_exact_deduplication: bool | None = None dataset_keep_in_memory: bool | None = None dataloader_pin_memory: bool | None = None @@ -1481,3 +1484,42 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): ) return self + + @model_validator(mode="before") + @classmethod + def check_qat_config(cls, data): + qat_cfg = data.get("qat", {}) + if not qat_cfg: + return data + + if data.get("peft"): + raise ValueError("QAT and PEFT cannot be used together.") + + if data.get("load_in_8bit"): + raise ValueError("QAT and load_in_8bit cannot be used together.") + + if data.get("load_in_4bit"): + raise ValueError("QAT and load_in_4bit cannot be used together.") + + env_capabilities = data.get("env_capabilities", {}) + torch_version = env_capabilities.get("torch_version") + + if torch_version is None: + import torch + + torch_version = str(torch.__version__).split("+", maxsplit=1)[0] + + if ( + data.get("fsdp") + and data.get("fsdp_config") + and str(data["fsdp_config"].get("fsdp_version")) == "2" + ): + if version.parse(torch_version) < version.parse("2.7.0"): + raise ValueError( + "FSDP2 and QAT are not supported on torch version < 2.7.0" + ) + + if version.parse(torch_version) < version.parse("2.6.0"): + raise ValueError("QAT is not supported on torch version < 2.6.0") + + return data diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index 526872412..91fdce161 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -2,6 +2,22 @@ from enum import Enum +import torch + + +class TorchIntDType(Enum): + """Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4""" + + uint1 = getattr(torch, "uint1", None) # pylint: disable=invalid-name + uint2 = getattr(torch, "uint2", None) # pylint: disable=invalid-name + uint3 = getattr(torch, "uint3", None) # pylint: disable=invalid-name + uint4 = getattr(torch, "uint4", None) # pylint: disable=invalid-name + uint5 = getattr(torch, "uint5", None) # pylint: disable=invalid-name + uint6 = 
getattr(torch, "uint6", None) # pylint: disable=invalid-name + uint7 = getattr(torch, "uint7", None) # pylint: disable=invalid-name + int4 = getattr(torch, "int4", None) # pylint: disable=invalid-name + int8 = getattr(torch, "int8", None) # pylint: disable=invalid-name + class RLType(str, Enum): """RL trainer type configuration subset""" diff --git a/src/axolotl/utils/schemas/quantization.py b/src/axolotl/utils/schemas/quantization.py new file mode 100644 index 000000000..fe2cdb1fe --- /dev/null +++ b/src/axolotl/utils/schemas/quantization.py @@ -0,0 +1,64 @@ +""" +QAT Config Schema +""" + +from typing import Any + +from pydantic import BaseModel, Field, field_validator + +from axolotl.utils.schemas.enums import TorchIntDType + + +class QATConfig(BaseModel): + """ + QAT Config Schema + """ + + activation_dtype: TorchIntDType | None = Field( + default=None, description="Activation dtype" + ) + weight_dtype: TorchIntDType = Field( + default=TorchIntDType.int8, description="Weight dtype" + ) + quantize_embedding: bool | None = Field( + default=False, description="Quantize embedding" + ) + group_size: int | None = Field(default=32, description="Group size") + fake_quant_after_n_steps: int | None = Field( + default=None, description="Fake quant after n steps" + ) + + @field_validator("activation_dtype", "weight_dtype", mode="before") + @classmethod + def validate_dtype(cls, v: Any) -> TorchIntDType | None: + if v == "int4": + return TorchIntDType.int4 + if v == "int8": + return TorchIntDType.int8 + raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']") + + +class PTQConfig(BaseModel): + """ + PTQ Config Schema + """ + + weight_dtype: TorchIntDType = Field( + default=TorchIntDType.int8, description="Weight dtype" + ) + activation_dtype: TorchIntDType | None = Field( + default=None, description="Activation dtype" + ) + quantize_embedding: bool | None = Field( + default=None, description="Quantize embedding" + ) + group_size: int | None = Field(default=32, description="Group size") + + @field_validator("activation_dtype", "weight_dtype", mode="before") + @classmethod + def validate_dtype(cls, v: Any) -> TorchIntDType | None: + if v == "int4": + return TorchIntDType.int4 + if v == "int8": + return TorchIntDType.int8 + raise ValueError(f"Invalid dtype: '{v}'. 
Must be one of: ['int4', 'int8']") diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py index 575b7a620..6c7a9b2e4 100644 --- a/tests/e2e/multigpu/solo/test_grpo.py +++ b/tests/e2e/multigpu/solo/test_grpo.py @@ -4,7 +4,6 @@ GRPO test suite import os import random -import shutil import subprocess # nosec B404 import sys import tempfile @@ -118,7 +117,10 @@ def start_vllm( recursive_kill(process) with open("/tmp/vllm.log", "r", encoding="utf-8") as log_file: print(log_file.read()) - shutil.rmtree("/tmp/vllm.log") + try: + os.remove("/tmp/vllm.log") + except FileNotFoundError: + pass raise RuntimeError(f"VLLM server process did not start within {wait} seconds.") # return the process diff --git a/tests/e2e/test_qat.py b/tests/e2e/test_qat.py new file mode 100644 index 000000000..f9e7993be --- /dev/null +++ b/tests/e2e/test_qat.py @@ -0,0 +1,71 @@ +""" +E2E tests for QAT +""" + +import unittest +from pathlib import Path + +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from .utils import check_model_output_exists, with_temp_dir + + +class TestQATLlama(unittest.TestCase): + """ + Test case for QAT Llama models + """ + + @with_temp_dir + def test_qat_lora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "sequence_len": 1024, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mlabonne/FineTome-100k", + "type": "chat_template", + "field_messages": "conversations", + "message_property_mappings": { + "role": "from", + "content": "value", + }, + "drop_system_message": True, + "split": "train[:1%]", + }, + ], + "chat_template": "chatml", + "qat": { + "quantize_embedding": True, + "activation_dtype": "int8", + "weight_dtype": "int8", + "group_size": 8, + }, + "num_epochs": 1, + "micro_batch_size": 1, + "gradient_accumulation_steps": 2, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "max_steps": 5, + "save_safetensors": True, + "bf16": True, + } + ) + cfg = validate_config(cfg) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg) diff --git a/tests/e2e/test_quantization.py b/tests/e2e/test_quantization.py new file mode 100644 index 000000000..500b7e556 --- /dev/null +++ b/tests/e2e/test_quantization.py @@ -0,0 +1,350 @@ +""" +Tests for axolotl.utils.quantization +""" + +import pytest +import torch +from torch import nn +from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor +from torchao.quantization.granularity import PerAxis, PerGroup +from torchao.quantization.linear_activation_quantized_tensor import ( + LinearActivationQuantizedTensor, +) +from torchao.quantization.qat.embedding import FakeQuantizedEmbedding +from torchao.quantization.qat.linear import FakeQuantizedLinear +from torchao.quantization.quant_api import ( + Int4DynamicActivationInt4WeightConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, + UIntXWeightOnlyConfig, +) +from transformers import AutoModelForCausalLM +from transformers.trainer_callback import 
TrainerState + +from axolotl.utils.callbacks.qat import QATCallback +from axolotl.utils.quantization import ( + convert_qat_model_for_ptq, + get_ptq_config, + prepare_model_for_qat, + quantize_model_for_ptq, +) +from axolotl.utils.schemas.enums import TorchIntDType +from axolotl.utils.schemas.quantization import QATConfig + +from tests.e2e.utils import require_torch_2_6_0 + + +@pytest.fixture() +def model(): + dummy_model = AutoModelForCausalLM.from_pretrained( + "HuggingFaceTB/SmolLM2-135M", + device_map="cuda", + torch_dtype=torch.bfloat16, + ) + with torch.device(dummy_model.device): + dummy_model.model.embed_tokens = torch.nn.Embedding( + dummy_model.model.embed_tokens.weight.shape[0], + dummy_model.model.embed_tokens.weight.shape[1], + dtype=dummy_model.model.embed_tokens.weight.dtype, + ) + return dummy_model + + +ptq_config_test_cases = [ + # weight_dtype, activation_dtype, group_size, expected_type, expected_params + ( + TorchIntDType.uint4, + None, + None, + UIntXWeightOnlyConfig, + {"dtype": torch.uint4, "group_size": None}, + ), + (TorchIntDType.int8, None, 32, Int8WeightOnlyConfig, {"group_size": 32}), + (TorchIntDType.int4, None, 4, Int4WeightOnlyConfig, {"group_size": 4}), + ( + TorchIntDType.int4, + TorchIntDType.int4, + None, + Int4DynamicActivationInt4WeightConfig, + {}, + ), + ( + TorchIntDType.int8, + TorchIntDType.int8, + None, + Int8DynamicActivationInt8WeightConfig, + {}, + ), +] + +ptq_test_cases = [ + # weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception + (TorchIntDType.int8, None, 8, False, None), + (TorchIntDType.int4, None, 4, True, None), + (TorchIntDType.uint4, None, 8, False, None), + (TorchIntDType.int4, TorchIntDType.int4, 8, False, None), + (TorchIntDType.int8, TorchIntDType.int8, 8, True, None), + (TorchIntDType.int8, None, None, False, ValueError), + (TorchIntDType.int4, None, None, False, ValueError), +] + + +class TestQuantization: + """ + Test quantization utilities + """ + + @pytest.mark.parametrize( + "weight_dtype,activation_dtype,group_size,expected_type,expected_params", + ptq_config_test_cases, + ) + @require_torch_2_6_0 + def test_get_ptq_config( + self, weight_dtype, activation_dtype, group_size, expected_type, expected_params + ): + config = get_ptq_config(weight_dtype, activation_dtype, group_size) + + assert isinstance(config, expected_type) + + for param_name, param_value in expected_params.items(): + if isinstance(param_value, (PerAxis, PerGroup)): + if isinstance(param_value, PerAxis): + assert isinstance(getattr(config, param_name), PerAxis) + assert getattr(config, param_name).axis == param_value.axis + else: + assert isinstance(getattr(config, param_name), PerGroup) + assert ( + getattr(config, param_name).group_size == param_value.group_size + ) + else: + assert getattr(config, param_name) == param_value + + @pytest.mark.parametrize( + "weight_dtype", [TorchIntDType.int8, TorchIntDType.int4, TorchIntDType.uint4] + ) + @pytest.mark.parametrize( + "activation_dtype", [None, TorchIntDType.int4, TorchIntDType.int8] + ) + @pytest.mark.parametrize("group_size", [4, 8]) + @pytest.mark.parametrize("quantize_embedding", [False, True]) + @require_torch_2_6_0 + def test_prepare_model_for_qat( + self, model, weight_dtype, activation_dtype, group_size, quantize_embedding + ): # pylint: disable=redefined-outer-name + prepare_model_for_qat( + model, weight_dtype, group_size, activation_dtype, quantize_embedding + ) + if quantize_embedding: + assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) + assert 
hasattr(model.model.embed_tokens, "weight_fake_quantizer") + assert ( + model.model.embed_tokens.weight_fake_quantizer.config.dtype + == weight_dtype.value + ) + assert ( + model.model.embed_tokens.weight_fake_quantizer.config.group_size + == group_size + ) + + for child in list(model.children()): + if isinstance(child, torch.nn.Linear): + assert isinstance(child, FakeQuantizedLinear) + assert hasattr(child, "weight_fake_quantizer") + assert child.weight_fake_quantizer.config.dtype == weight_dtype.value + assert child.weight_fake_quantizer.config.group_size == group_size + if activation_dtype: + assert hasattr(child, "activation_fake_quantizer") + assert ( + child.activation_fake_quantizer.config.dtype + == activation_dtype.value + ) + else: + assert child.activation_fake_quantizer is None + + @pytest.mark.parametrize( + "weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception", + ptq_test_cases, + ) + @require_torch_2_6_0 + def test_quantize_model_for_ptq( + self, + model, + weight_dtype, + activation_dtype, + group_size, + quantize_embedding, + expected_exception, + ): # pylint: disable=redefined-outer-name + if expected_exception: + with pytest.raises(expected_exception): + quantize_model_for_ptq( + model, + weight_dtype, + group_size, + activation_dtype, + quantize_embedding, + ) + else: + quantize_model_for_ptq( + model, weight_dtype, group_size, activation_dtype, quantize_embedding + ) + if quantize_embedding: + assert isinstance( + model.model.embed_tokens.weight, AffineQuantizedTensor + ), "Embedding weight should be quantized" + for child in list(model.children()): + if isinstance(child, torch.nn.Linear): + if activation_dtype: + assert isinstance( + child.weight, LinearActivationQuantizedTensor + ), "Linear weight should be quantized with activation quantization" + else: + assert isinstance( + child.weight, AffineQuantizedTensor + ), "Linear weight should be quantized without activation quantization" + + +class TestQuantizationCallback: + """ + Test QATCallback + """ + + @pytest.fixture() + def trainer_state(self): + return TrainerState( + global_step=0, + ) + + @require_torch_2_6_0 + def test_qat_callback_fake_quant_after_n_steps( + self, model, trainer_state + ): # pylint: disable=redefined-outer-name + cfg = QATConfig( + weight_dtype="int8", + activation_dtype="int8", + group_size=8, + quantize_embedding=True, + fake_quant_after_n_steps=100, + ) + + prepare_model_for_qat( + model, + cfg.weight_dtype, + cfg.group_size, + cfg.activation_dtype, + cfg.quantize_embedding, + ) + + # ensure model has been quantized + assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) + assert model.model.embed_tokens.weight_fake_quantizer.enabled + assert isinstance(model.lm_head, FakeQuantizedLinear) + assert model.lm_head.weight_fake_quantizer.enabled + + qat_callback = QATCallback(cfg) + + # simulate first training step + qat_callback.on_step_begin( + args=None, + state=trainer_state, + control=None, + model=model, + ) + + # quantization should have been disabled + assert not model.model.embed_tokens.weight_fake_quantizer.enabled + assert not model.lm_head.weight_fake_quantizer.enabled + + trainer_state.global_step = 100 + qat_callback.on_step_begin( + args=None, + state=trainer_state, + control=None, + model=model, + ) + + # quantization should have been enabled + assert model.model.embed_tokens.weight_fake_quantizer.enabled + assert model.lm_head.weight_fake_quantizer.enabled + + @require_torch_2_6_0 + def 
test_qat_callback_fake_quant_after_n_steps_is_none( + self, model, trainer_state + ): # pylint: disable=redefined-outer-name + cfg = QATConfig( + weight_dtype="int8", + activation_dtype="int8", + group_size=8, + quantize_embedding=True, + fake_quant_after_n_steps=None, + ) + + prepare_model_for_qat( + model, + cfg.weight_dtype, + cfg.group_size, + cfg.activation_dtype, + cfg.quantize_embedding, + ) + + # ensure model has been quantized + assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) + assert model.model.embed_tokens.weight_fake_quantizer.enabled + assert isinstance(model.lm_head, FakeQuantizedLinear) + assert model.lm_head.weight_fake_quantizer.enabled + + qat_callback = QATCallback(cfg) + # simulate first training step + qat_callback.on_step_begin( + args=None, + state=trainer_state, + control=None, + model=model, + ) + + # quantization should be enabled from the get-go + assert model.model.embed_tokens.weight_fake_quantizer.enabled + assert model.lm_head.weight_fake_quantizer.enabled + + +class TestConvertQATModelForPTQ: + """ + Test convert_qat_model_for_ptq + """ + + @require_torch_2_6_0 + def test_convert_qat_model_for_ptq( + self, model + ): # pylint: disable=redefined-outer-name + config = QATConfig( + weight_dtype="int8", + activation_dtype="int8", + group_size=8, + quantize_embedding=True, + ) + + # quantize model for qat + prepare_model_for_qat( + model, + config.weight_dtype, + config.group_size, + config.activation_dtype, + config.quantize_embedding, + ) + + assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) + assert isinstance(model.lm_head, FakeQuantizedLinear) + + # apply conversion + convert_qat_model_for_ptq( + model, + quantize_embedding=config.quantize_embedding, + ) + # ensure modules have been swapped out + assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding) + assert not isinstance(model.lm_head, FakeQuantizedLinear) + + # ensure weights have been quantized + assert isinstance(model.model.embed_tokens.weight, nn.Parameter) + assert isinstance(model.lm_head.weight, nn.Parameter) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 61df1d8fe..65069eb16 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -10,8 +10,6 @@ from functools import wraps from pathlib import Path import torch - -# from importlib.metadata import version from packaging import version from tbparse import SummaryReader