QAT and quantization w/torchao
This commit is contained in:
salman
2025-05-28 12:35:47 +01:00
committed by GitHub
parent 20fda75917
commit 5fca214108
26 changed files with 1372 additions and 13 deletions

View File

@@ -43,6 +43,7 @@ quartodoc:
- cli.vllm_serve
- cli.cloud.base
- cli.cloud.modal_
- cli.quantize
- title: Trainers
desc: Training implementations
contents:
@@ -147,6 +148,7 @@ quartodoc:
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.sft
- utils.quantization
- title: Schemas
desc: Pydantic data models for Axolotl config
contents:
@@ -196,7 +198,7 @@ quartodoc:
- utils.callbacks.lisa
- utils.callbacks.mlflow_
- utils.callbacks.comet_
- utils.callbacks.qat
website:
title: "Axolotl"
description: "We make fine-tuning accessible, scalable, and fun"
@@ -256,6 +258,8 @@ website:
- docs/lr_groups.qmd
- docs/lora_optims.qmd
- docs/dataset_loading.qmd
- docs/qat.qmd
- docs/quantize.qmd
- section: "Core Concepts"
contents:

View File

@@ -209,6 +209,16 @@ axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
This is necessary if you want to use the model with other frameworks. If you have an adapter, merge it with the non-quantized, linearized model before delinearizing.
### quantize
Quantizes a model using the quantization configuration specified in your YAML file.
```bash
axolotl quantize config.yml
```
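The config must define either a `qat` or a `quantization` section describing how the model should be quantized. A minimal, illustrative post-training quantization config (the paths and filename below are placeholders) might look like:

```yaml
# config.yml
base_model: ./outputs/model_out
quantization:
  weight_dtype: int8
  group_size: 32
output_dir: ./outputs/model_out
```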
See [Quantization](./quantize.qmd) for more details.
## Legacy CLI Usage

View File

@@ -65,6 +65,20 @@ bnb_config_kwargs:
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# quantization aware training
qat:
activation_dtype: # Optional[str] = None. Fake quantization dtype to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization dtype to use for weight quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
quantize_embedding: # Optional[bool] = False. Whether to also fake-quantize the embedding layer
fake_quant_after_n_steps: # Optional[int] = None. If set, fake quantization is disabled until this training step and enabled afterwards
# post-training quantization
quantization:
weight_dtype: # Optional[str] = "int8". Quantization dtype to use for weight quantization. Valid options are "uintX" for X in [1, 2, 3, 4, 5, 6, 7], "int4", or "int8"
activation_dtype: # Optional[str] = None. Quantization dtype to use for activation quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group quantization
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
# Whether you are training a 4-bit GPTQ quantized model
gptq: true

docs/qat.qmd Normal file
View File

@@ -0,0 +1,32 @@
---
title: "Quantization Aware Training (QAT)"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
## Overview
[Quantization Aware Training](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/#quantization-aware-training) (QAT) is a technique for improving the accuracy of quantized models by applying "fake" quantization to the model's weights (and, optionally, its activations) during training. This fake quantization lets the model adapt to the noise introduced by quantization, so that when the model is eventually quantized for real, the loss in accuracy is minimized. We use the quantization techniques implemented in [torchao](https://github.com/pytorch/ao) to provide support for QAT and post-training quantization (PTQ) in axolotl.
We recommend reviewing the excellent QAT tutorial in the [torchtune library](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#quantizing-the-qat-model), and the QAT documentation in the [torchao library](https://github.com/pytorch/ao/tree/main/torchao/quantization/qat), for more details.
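To make the idea concrete, here is a rough, illustrative sketch of symmetric per-group int8 fake quantization (this is not axolotl's or torchao's implementation, and the function name is made up): weights are rounded onto an int8 grid and immediately dequantized, so the forward pass sees the rounding error while the underlying weights stay in full precision.

```python
import torch

def fake_quantize_int8(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Quantize-dequantize `w` in groups of `group_size` elements."""
    groups = w.reshape(-1, group_size)
    scale = groups.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127
    q = torch.clamp(torch.round(groups / scale), -128, 127)
    return (q * scale).reshape(w.shape)

w = torch.randn(4, 64)
print((w - fake_quantize_int8(w)).abs().max())  # small, non-zero rounding error
```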
## Configuring QAT in Axolotl
To enable QAT in axolotl, add the following to your configuration file:
```yaml
qat:
activation_dtype: # Optional[str] = None. Fake quantization dtype to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization dtype to use for weight quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
quantize_embedding: # Optional[bool] = False. Whether to also fake-quantize the embedding layer
fake_quant_after_n_steps: # Optional[int] = None. If set, fake quantization is disabled until this training step and enabled afterwards
```
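For example, an illustrative configuration that fake-quantizes weights to int4 and activations to int8 with a group size of 32:

```yaml
qat:
  activation_dtype: int8
  weight_dtype: int4
  group_size: 32
```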
Once you have finished training, you must quantize your model using the same quantization configuration you used during training. You can use the [`quantize` command](./quantize.qmd) to do this.

docs/quantize.qmd Normal file
View File

@@ -0,0 +1,53 @@
---
title: "Quantization with torchao"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
Quantization is a technique for lowering the memory footprint of your model, potentially at the cost of some accuracy or model performance. We support quantizing your model using the [torchao](https://github.com/pytorch/ao) library, both as post-training quantization (PTQ) and as quantization-aware training (QAT).
::: {.callout-note}
We do not currently support quantization formats such as GGUF, GPTQ, or EXL2.
:::
## Configuring Quantization in Axolotl
Quantization is configured using the `quantization` key in your configuration file.
```yaml
base_model: # The path to the model to quantize.
quantization:
weight_dtype: # Optional[str] = "int8". Quantization dtype to use for weight quantization. Valid options are "uintX" for X in [1, 2, 3, 4, 5, 6, 7], "int4", or "int8"
activation_dtype: # Optional[str] = None. Quantization dtype to use for activation quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group quantization
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
output_dir: # The path to the output directory.
```
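For example, to quantize a trained checkpoint to int8 weights with a group size of 32 (the paths and filename below are illustrative):

```yaml
# quantize.yml
base_model: ./outputs/sft_out
quantization:
  weight_dtype: int8
  group_size: 32
output_dir: ./outputs/sft_out
```

```bash
axolotl quantize quantize.yml
```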
Once quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory.
You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.qmd). To do this, reuse the QAT configuration file you used to train the model:
```yaml
# qat.yml
qat:
activation_dtype: int8
weight_dtype: int8
group_size: 256
quantize_embedding: true
output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
```
```bash
axolotl quantize qat.yml
```
This ensures that the model is quantized with the same quantization configuration that was used to train it.

View File

@@ -0,0 +1,79 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
output_dir: ./outputs/qat_out/
sample_packing: true
pad_to_sequence_len: true
sequence_len: 512
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
qat:
activation_dtype: int8
weight_dtype: int4
group_size: 32
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_steps: 10
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -0,0 +1,78 @@
base_model: Qwen/Qwen3-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/qat_out/
sequence_len: 2048
sample_packing: true
flex_attention: true
pad_to_sequence_len: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
qat:
activation_dtype: int8
weight_dtype: int4
group_size: 256
fake_quant_after_n_steps: 1000
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
max_steps: 2000
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_steps: 10
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true
special_tokens:

View File

@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.9.0
torchao==0.10.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6

View File

@@ -90,6 +90,18 @@ class VllmServeCliArgs:
)
@dataclass
class QuantizeCliArgs:
"""Dataclass with CLI arguments for `axolotl quantize` command."""
base_model: Optional[str] = field(default=None)
weight_dtype: Optional[str] = field(default=None)
activation_dtype: Optional[str] = field(default=None)
quantize_embedding: Optional[bool] = field(default=None)
group_size: Optional[int] = field(default=None)
output_dir: Optional[str] = field(default=None)
@dataclass
class EvaluateCliArgs:
"""Dataclass with CLI arguments for `axolotl evaluate` command."""

View File

@@ -17,6 +17,7 @@ import axolotl
from axolotl.cli.args import (
EvaluateCliArgs,
PreprocessCliArgs,
QuantizeCliArgs,
TrainerCliArgs,
VllmServeCliArgs,
)
@@ -333,6 +334,16 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
do_vllm_serve(config, cli_args)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(QuantizeCliArgs)
@filter_none_kwargs
def quantize(config: str, **cli_args: QuantizeCliArgs):
from axolotl.cli.quantize import do_quantize
do_quantize(config, cli_args)
@cli.command()
@click.argument("model", type=click.Path(exists=True, path_type=str))
@click.argument("output", type=click.Path(exists=False, path_type=str))

View File

@@ -0,0 +1,90 @@
"""
CLI for post-training quantization of a model using torchao
"""
import logging
from pathlib import Path
from typing import Union
from transformers import AutoModelForCausalLM
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
LOG = logging.getLogger(__name__)
def do_quantize(
config: Union[Path, str],
cli_args: dict,
):
"""
Quantizes a model's weights
Args:
config (Union[Path, str]): The path to the config file
cli_args (dict): Additional command-line arguments
"""
print_axolotl_text_art()
cfg = load_cfg(config)
if cfg.qat and cfg.quantization:
raise ValueError(
"QAT and quantization cannot be used together. Please specify only one of qat or quantization in your config file."
)
if cfg.qat:
quantize_cfg = cfg.qat
elif cfg.quantization:
quantize_cfg = cfg.quantization
else:
raise ValueError(
"No quantization configuration found. Please specify either qat or quantization in your config file."
)
model_path = cli_args.get("base_model") or cfg.output_dir
if weight_dtype := cli_args.get("weight_dtype"):
weight_dtype = TorchIntDType[weight_dtype]
else:
weight_dtype = quantize_cfg.weight_dtype
if activation_dtype := cli_args.get("activation_dtype"):
activation_dtype = TorchIntDType[activation_dtype]
else:
activation_dtype = quantize_cfg.activation_dtype
group_size = cli_args.get("group_size") or quantize_cfg.group_size
quantize_embedding = (
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
)
output_dir = cli_args.get("output_dir") or cfg.output_dir
LOG.info(f"Loading model from {model_path}...")
tokenizer = load_tokenizer(cfg)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
LOG.info(
f"Quantizing model with configuration: \n"
f"\tweight_dtype: {weight_dtype}\n"
f"\tactivation_dtype: {activation_dtype}\n"
f"\tgroup_size: {group_size}\n"
f"\tquantize_embedding: {quantize_embedding}"
)
quantize_model_for_ptq(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
)
tokenizer.save_pretrained(str(Path(output_dir) / "quantized"))
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")

View File

@@ -79,6 +79,7 @@ from axolotl.utils.callbacks import (
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
@@ -254,6 +255,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):

View File

@@ -191,6 +191,7 @@ class ModelLoader:
self._adjust_model_config()
self._log_memory_usage()
self._configure_embedding_dtypes()
self._configure_qat()
def _resize_token_embeddings(self):
"""Resize token embeddings if needed."""
@@ -305,6 +306,19 @@ class ModelLoader:
before_kbit_train_or_finetune=False,
)
def _configure_qat(self):
"""Configure QAT."""
if self.cfg.qat:
from axolotl.utils.quantization import prepare_model_for_qat
prepare_model_for_qat(
self.model,
self.cfg.qat.weight_dtype,
self.cfg.qat.group_size,
self.cfg.qat.activation_dtype,
self.cfg.qat.quantize_embedding,
)
def _load_adapters(self) -> PeftConfig | None:
"""Load LoRA or other adapters."""
# Load LoRA or adapter

View File

@@ -80,9 +80,9 @@ class PatchManager:
def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations."""
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
patch_accelerate_fsdp_utils()
patch_accelerate_fsdp2()
def _apply_adapter_patches(self):
"""Apply patches for adapter configurations."""

View File

@@ -1,5 +1,5 @@
"""
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration, and saving full state dicts
"""
import logging
@@ -52,7 +52,146 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
model.load_state_dict(sharded_sd, assign=True)
def patch_accelerate_fsdp_utils():
def set_state_dict_type(self, state_dict_type=None):
"""
Set the state dict config based on the `StateDictType`.
"""
import os
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig,
FullStateDictConfig,
ShardedOptimStateDictConfig,
ShardedStateDictConfig,
StateDictType,
)
# Override the state_dict_type if provided, typical use case:
# user trains with sharded, but final save is with full
if state_dict_type is not None:
self.state_dict_type = state_dict_type
if self.state_dict_type is None:
self.state_dict_type = os.environ.get(
"FSDP_STATE_DICT_TYPE",
"FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT",
)
if isinstance(self.state_dict_type, str):
if self.state_dict_type.isdigit():
self.state_dict_type = StateDictType(int(self.state_dict_type))
else:
self.state_dict_type = StateDictType[self.state_dict_type.upper()]
if self.state_dict_type == StateDictType.FULL_STATE_DICT:
if self.state_dict_config is None:
self.state_dict_config = FullStateDictConfig(
offload_to_cpu=True, rank0_only=True
)
if self.optim_state_dict_config is None:
self.optim_state_dict_config = FullOptimStateDictConfig(
offload_to_cpu=True, rank0_only=True
)
elif self.state_dict_type == StateDictType.SHARDED_STATE_DICT:
if self.state_dict_config is None:
self.state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
if self.optim_state_dict_config is None:
self.optim_state_dict_config = ShardedOptimStateDictConfig(
offload_to_cpu=True
)
def get_state_dict(self, model, unwrap=True):
"""
Returns the state dictionary of a model sent through [`Accelerator.prepare`] potentially without full
precision.
Args:
model (`torch.nn.Module`):
A PyTorch model sent through [`Accelerator.prepare`]
unwrap (`bool`, *optional*, defaults to `True`):
Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict
Returns:
`dict`: The state dictionary of the model potentially without full precision.
Example:
```python
>>> import torch
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> net = torch.nn.Linear(2, 2)
>>> net = accelerator.prepare(net)
>>> state_dict = accelerator.get_state_dict(net)
```
"""
from accelerate import DistributedType
from accelerate.utils import compare_versions
if self.distributed_type == DistributedType.DEEPSPEED:
zero3_sharding = self.deepspeed_config["zero_optimization"]["stage"] == 3
tp_sharding = (
self.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0) > 1
)
if zero3_sharding or tp_sharding:
if model.zero_gather_16bit_weights_on_model_save():
if tp_sharding and not compare_versions("deepspeed", ">=", "0.16.4"):
raise ImportError(
"Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`."
)
state_dict = (
model._consolidated_16bit_state_dict() # pylint: disable=protected-access
if tp_sharding
else model._zero3_consolidated_16bit_state_dict() # pylint: disable=protected-access
)
else:
raise ValueError(
"Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. "
"To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or "
"set `zero3_save_16bit_model` to True when using `accelerate config`. "
"To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights."
)
else:
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
state_dict = clone_tensors_for_torch_save(
self.unwrap_model(model).state_dict()
)
elif self.is_fsdp2:
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
state_dict = {}
sharded_state_dict = model.state_dict()
for param_name, param in sharded_state_dict.items():
if param.is_cpu:
param = param.to(torch.device("cuda"))
param = param.full_tensor()
if torch.distributed.get_rank() == 0:
state_dict[param_name] = param.cpu()
torch.distributed.barrier()
elif self.distributed_type == DistributedType.FSDP:
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
full_state_dict_config = FullStateDictConfig(
offload_to_cpu=True, rank0_only=True
)
with FSDP.state_dict_type(
model, StateDictType.FULL_STATE_DICT, full_state_dict_config
):
state_dict = model.state_dict()
else:
if unwrap:
model = self.unwrap_model(model)
state_dict = model.state_dict()
return state_dict
def patch_accelerate_fsdp2():
import accelerate
from accelerate.utils import fsdp_utils
fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict
@@ -61,3 +200,19 @@ def patch_accelerate_fsdp_utils():
"fsdp2_load_full_state_dict",
fsdp2_load_full_state_dict,
)
accelerate.Accelerator.get_state_dict = get_state_dict
setattr(
sys.modules["accelerate"],
"Accelerator.get_state_dict",
get_state_dict,
)
accelerate.utils.dataclasses.FullyShardedDataParallelPlugin.set_state_dict_type = (
set_state_dict_type
)
setattr(
sys.modules["accelerate.utils.dataclasses"],
"FullyShardedDataParallelPlugin.set_state_dict_type",
set_state_dict_type,
)

View File

@@ -238,13 +238,27 @@ def save_trained_model(
model: The trained model to save.
safe_serialization: Whether to use safe serialization.
"""
LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
# Post training module hooks
for name, module in model.named_modules():
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
# handle QAT
if cfg.qat:
from axolotl.utils.quantization import convert_qat_model_for_ptq
LOG.info("Processing QAT model for saving...")
convert_qat_model_for_ptq(
model,
quantize_embedding=cfg.qat.quantize_embedding,
)
LOG.info(
"QAT modules have been converted for PTQ. Please ensure you quantize "
"your model weights with `axolotl quantize`."
)
# Handle FSDP state dict type
state_dict_type = "FULL_STATE_DICT"
if trainer.is_fsdp_enabled and str(cfg.fsdp_config.fsdp_version) != "2":
@@ -321,6 +335,8 @@ def save_trained_model(
save_compressed=cfg.llmcompressor.save_compressed,
)
LOG.info(f"Model successfully saved to {cfg.output_dir}")
def create_model_card(cfg: DictDefault, trainer: Trainer):
"""

View File

@@ -0,0 +1,50 @@
"""QAT Callback for HF Causal Trainer"""
import logging
from functools import partial
from torch import nn
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
from transformers import TrainerCallback
from axolotl.utils.schemas.quantization import QATConfig
LOG = logging.getLogger(__name__)
def toggle_fake_quant(mod: nn.Module, enable: bool):
"""
Toggle fake quantization for any fake quantized linear or embedding layers in the model.
Args:
mod: The module to toggle fake quantization for.
enable: Whether to enable or disable fake quantization.
"""
if isinstance(mod, (FakeQuantizedLinear, FakeQuantizedEmbedding)):
if (
isinstance(mod, FakeQuantizedLinear)
and mod.activation_fake_quantizer is not None
):
mod.activation_fake_quantizer.enabled = enable
mod.weight_fake_quantizer.enabled = enable
class QATCallback(TrainerCallback):
"""
Callback to toggle fake quantization for the model.
"""
def __init__(self, cfg: QATConfig):
self.cfg = cfg
def on_step_begin(
self, args, state, control, model, **kwargs
): # pylint: disable=unused-argument
if self.cfg.fake_quant_after_n_steps is not None:
if state.global_step == 0:
LOG.info(f"Disabling fake quantization at step {state.global_step}")
model.apply(partial(toggle_fake_quant, enable=False))
elif state.global_step == self.cfg.fake_quant_after_n_steps:
LOG.info(f"Enabling fake quantization at step {state.global_step}")
model.apply(partial(toggle_fake_quant, enable=True))

View File

@@ -103,7 +103,12 @@ def cleanup_distributed():
termination or when training successfully completes.
"""
# Ensure that all operations are completed before destroying the process group
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.synchronize()
if torch.xpu.is_available():
torch.xpu.synchronize()
# Destroy the process group
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()

View File

@@ -0,0 +1,189 @@
"""
Utilities for quantization including QAT and PTQ using torchao.
"""
import logging
import torch
from torch import nn
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.qat import (
FakeQuantizeConfig,
FromIntXQuantizationAwareTrainingConfig,
IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.quant_api import (
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
UIntXWeightOnlyConfig,
_is_linear,
)
from axolotl.utils.schemas.enums import TorchIntDType
LOG = logging.getLogger(__name__)
def get_ptq_config(
weight_dtype: TorchIntDType,
activation_dtype: TorchIntDType | None = None,
group_size: int | None = None,
) -> AOBaseConfig:
"""
This function is used to build a post-training quantization config.
Args:
weight_dtype: The dtype to use for weight quantization.
activation_dtype: The dtype to use for activation quantization.
group_size: The group size to use for weight quantization.
Returns:
The post-training quantization config.
Raises:
ValueError: If the activation dtype is not specified and the weight dtype is not int8 or int4,
or if the group size is not specified for int8 or int4 weight only quantization.
"""
if activation_dtype is None:
if not weight_dtype.value.is_signed: # type: ignore[attr-defined,union-attr]
return UIntXWeightOnlyConfig(
dtype=weight_dtype.value,
group_size=group_size,
set_inductor_config=False,
)
if weight_dtype == TorchIntDType.int8:
if group_size is None:
raise ValueError(
"group_size must be specified for int8 weight only quantization"
)
return Int8WeightOnlyConfig(
group_size=group_size,
)
if weight_dtype == TorchIntDType.int4:
if group_size is None:
raise ValueError(
"group_size must be specified for int4 weight only quantization"
)
return Int4WeightOnlyConfig(
group_size=group_size,
)
if activation_dtype == TorchIntDType.int4 and weight_dtype == TorchIntDType.int4:
return Int4DynamicActivationInt4WeightConfig()
if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int8:
return Int8DynamicActivationInt8WeightConfig()
if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int4:
return Int8DynamicActivationInt4WeightConfig()
raise ValueError(
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
)
def prepare_model_for_qat(
model,
weight_dtype: TorchIntDType,
group_size: int,
activation_dtype: TorchIntDType | None = None,
quantize_embedding: bool = False,
):
"""
This function is used to prepare a model for QAT by swapping the model's linear
layers with fake quantized linear layers, and optionally the embedding weights with
fake quantized embedding weights.
Args:
model: The model to quantize.
weight_dtype: The dtype to use for weight quantization.
group_size: The group size to use for weight quantization.
activation_dtype: The dtype to use for activation quantization.
quantize_embedding: Whether to quantize the model's embedding weights.
Raises:
ValueError: If the activation/weight dtype combination is invalid.
"""
if activation_dtype:
activation_config = FakeQuantizeConfig(
dtype=activation_dtype.value, granularity="per_token", is_symmetric=False
)
weight_config = FakeQuantizeConfig(dtype=weight_dtype.value, group_size=group_size)
linear_quantize_config = IntXQuantizationAwareTrainingConfig(
activation_config=None if activation_dtype is None else activation_config,
weight_config=weight_config,
)
quantize_(model, linear_quantize_config)
if quantize_embedding:
# activation fake quantization is not supported for embedding layers
embedding_quantize_config = IntXQuantizationAwareTrainingConfig(
activation_config=None,
weight_config=weight_config,
)
quantize_(
model,
embedding_quantize_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
def quantize_model_for_ptq(
model,
weight_dtype: TorchIntDType,
group_size: int | None = None,
activation_dtype: TorchIntDType | None = None,
quantize_embedding: bool | None = None,
):
"""
This function applies post-training quantization to a model.
It quantizes the model's linear layers in place using the appropriate torchao quantization config.
If `quantize_embedding` is True, it will also quantize the model's embedding weights.
Args:
model: The model to quantize.
weight_dtype: The dtype to use for weight quantization.
group_size: The group size to use for weight quantization.
activation_dtype: The dtype to use for activation quantization.
quantize_embedding: Whether to quantize the model's embedding weights.
"""
linear_ptq_config = get_ptq_config(
weight_dtype=weight_dtype,
activation_dtype=activation_dtype,
group_size=group_size,
)
quantize_(model, linear_ptq_config)
if quantize_embedding:
embedding_quantize_config = get_ptq_config(
weight_dtype=weight_dtype,
activation_dtype=None,
group_size=group_size,
)
quantize_(
model,
embedding_quantize_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
def convert_qat_model_for_ptq(
model,
*,
quantize_embedding: bool | None = None,
):
"""
This function swaps the fake-quantized modules in a model that has been trained
with QAT back to their original module types, ready for PTQ.
Args:
model: The model to convert.
quantize_embedding: Whether the model's embedding layer was also fake-quantized and should be converted back.
"""
if quantize_embedding:
def filter_fn(m, _):
return isinstance(m, nn.Embedding) or _is_linear(m)
else:
filter_fn = _is_linear
quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
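A rough usage sketch of these helpers (not part of the PR; it assumes a small HF causal LM, torchao >= 0.10, and elides the actual fine-tuning loop):

```python
from transformers import AutoModelForCausalLM

from axolotl.utils.quantization import (
    convert_qat_model_for_ptq,
    prepare_model_for_qat,
    quantize_model_for_ptq,
)
from axolotl.utils.schemas.enums import TorchIntDType

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")

# swap nn.Linear layers for fake-quantized equivalents, then fine-tune
prepare_model_for_qat(model, TorchIntDType.int8, group_size=32)
# ... fine-tuning loop elided ...

# swap the fake-quantized modules back, then apply real int8 weight-only PTQ
convert_qat_model_for_ptq(model)
quantize_model_for_ptq(model, TorchIntDType.int8, group_size=32)
```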

View File

@@ -44,6 +44,7 @@ from axolotl.utils.schemas.model import (
)
from axolotl.utils.schemas.multimodal import MultiModalConfig
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
from axolotl.utils.schemas.quantization import PTQConfig, QATConfig
from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.vllm import VllmConfig
@@ -91,6 +92,8 @@ class AxolotlInputConfig(
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(), # pylint: disable=unnecessary-lambda
)
qat: QATConfig | None = None
quantization: PTQConfig | None = None
reward_model: bool | None = None
process_reward_model: bool | None = None
num_labels: int | None = None
@@ -126,7 +129,7 @@ class AxolotlInputConfig(
default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"},
)
dataset_processes: int | None = Field(default=min(32, os.cpu_count())) # type: ignore[type-var]
dataset_processes: int | None = Field(default=min(32, os.cpu_count() or 1))
dataset_exact_deduplication: bool | None = None
dataset_keep_in_memory: bool | None = None
dataloader_pin_memory: bool | None = None
@@ -1481,3 +1484,42 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return self
@model_validator(mode="before")
@classmethod
def check_qat_config(cls, data):
qat_cfg = data.get("qat", {})
if not qat_cfg:
return data
if data.get("peft"):
raise ValueError("QAT and PEFT cannot be used together.")
if data.get("load_in_8bit"):
raise ValueError("QAT and load_in_8bit cannot be used together.")
if data.get("load_in_4bit"):
raise ValueError("QAT and load_in_4bit cannot be used together.")
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if (
data.get("fsdp")
and data.get("fsdp_config")
and str(data["fsdp_config"].get("fsdp_version")) == "2"
):
if version.parse(torch_version) < version.parse("2.7.0"):
raise ValueError(
"FSDP2 and QAT are not supported on torch version < 2.7.0"
)
if version.parse(torch_version) < version.parse("2.6.0"):
raise ValueError("QAT is not supported on torch version < 2.6.0")
return data

View File

@@ -2,6 +2,22 @@
from enum import Enum
import torch
class TorchIntDType(Enum):
"""Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4"""
uint1 = getattr(torch, "uint1", None) # pylint: disable=invalid-name
uint2 = getattr(torch, "uint2", None) # pylint: disable=invalid-name
uint3 = getattr(torch, "uint3", None) # pylint: disable=invalid-name
uint4 = getattr(torch, "uint4", None) # pylint: disable=invalid-name
uint5 = getattr(torch, "uint5", None) # pylint: disable=invalid-name
uint6 = getattr(torch, "uint6", None) # pylint: disable=invalid-name
uint7 = getattr(torch, "uint7", None) # pylint: disable=invalid-name
int4 = getattr(torch, "int4", None) # pylint: disable=invalid-name
int8 = getattr(torch, "int8", None) # pylint: disable=invalid-name
class RLType(str, Enum):
"""RL trainer type configuration subset"""

View File

@@ -0,0 +1,64 @@
"""
QAT Config Schema
"""
from typing import Any
from pydantic import BaseModel, Field, field_validator
from axolotl.utils.schemas.enums import TorchIntDType
class QATConfig(BaseModel):
"""
QAT Config Schema
"""
activation_dtype: TorchIntDType | None = Field(
default=None, description="Activation dtype"
)
weight_dtype: TorchIntDType = Field(
default=TorchIntDType.int8, description="Weight dtype"
)
quantize_embedding: bool | None = Field(
default=False, description="Quantize embedding"
)
group_size: int | None = Field(default=32, description="Group size")
fake_quant_after_n_steps: int | None = Field(
default=None, description="Fake quant after n steps"
)
@field_validator("activation_dtype", "weight_dtype", mode="before")
@classmethod
def validate_dtype(cls, v: Any) -> TorchIntDType | None:
if v == "int4":
return TorchIntDType.int4
if v == "int8":
return TorchIntDType.int8
raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")
class PTQConfig(BaseModel):
"""
PTQ Config Schema
"""
weight_dtype: TorchIntDType = Field(
default=TorchIntDType.int8, description="Weight dtype"
)
activation_dtype: TorchIntDType | None = Field(
default=None, description="Activation dtype"
)
quantize_embedding: bool | None = Field(
default=None, description="Quantize embedding"
)
group_size: int | None = Field(default=32, description="Group size")
@field_validator("activation_dtype", "weight_dtype", mode="before")
@classmethod
def validate_dtype(cls, v: Any) -> TorchIntDType | None:
if v == "int4":
return TorchIntDType.int4
if v == "int8":
return TorchIntDType.int8
raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")

View File

@@ -4,7 +4,6 @@ GRPO test suite
import os
import random
import shutil
import subprocess # nosec B404
import sys
import tempfile
@@ -118,7 +117,10 @@ def start_vllm(
recursive_kill(process)
with open("/tmp/vllm.log", "r", encoding="utf-8") as log_file:
print(log_file.read())
shutil.rmtree("/tmp/vllm.log")
try:
os.remove("/tmp/vllm.log")
except FileNotFoundError:
pass
raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
# return the process

tests/e2e/test_qat.py Normal file
View File

@@ -0,0 +1,71 @@
"""
E2E tests for QAT
"""
import unittest
from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
class TestQATLlama(unittest.TestCase):
"""
Test case for QAT Llama models
"""
@with_temp_dir
def test_qat(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"field_messages": "conversations",
"message_property_mappings": {
"role": "from",
"content": "value",
},
"drop_system_message": True,
"split": "train[:1%]",
},
],
"chat_template": "chatml",
"qat": {
"quantize_embedding": True,
"activation_dtype": "int8",
"weight_dtype": "int8",
"group_size": 8,
},
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)

View File

@@ -0,0 +1,350 @@
"""
Tests for axolotl.utils.quantization
"""
import pytest
import torch
from torch import nn
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.linear_activation_quantized_tensor import (
LinearActivationQuantizedTensor,
)
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
from torchao.quantization.quant_api import (
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
UIntXWeightOnlyConfig,
)
from transformers import AutoModelForCausalLM
from transformers.trainer_callback import TrainerState
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.quantization import (
convert_qat_model_for_ptq,
get_ptq_config,
prepare_model_for_qat,
quantize_model_for_ptq,
)
from axolotl.utils.schemas.enums import TorchIntDType
from axolotl.utils.schemas.quantization import QATConfig
from tests.e2e.utils import require_torch_2_6_0
@pytest.fixture()
def model():
dummy_model = AutoModelForCausalLM.from_pretrained(
"HuggingFaceTB/SmolLM2-135M",
device_map="cuda",
torch_dtype=torch.bfloat16,
)
with torch.device(dummy_model.device):
dummy_model.model.embed_tokens = torch.nn.Embedding(
dummy_model.model.embed_tokens.weight.shape[0],
dummy_model.model.embed_tokens.weight.shape[1],
dtype=dummy_model.model.embed_tokens.weight.dtype,
)
return dummy_model
ptq_config_test_cases = [
# weight_dtype, activation_dtype, group_size, expected_type, expected_params
(
TorchIntDType.uint4,
None,
None,
UIntXWeightOnlyConfig,
{"dtype": torch.uint4, "group_size": None},
),
(TorchIntDType.int8, None, 32, Int8WeightOnlyConfig, {"group_size": 32}),
(TorchIntDType.int4, None, 4, Int4WeightOnlyConfig, {"group_size": 4}),
(
TorchIntDType.int4,
TorchIntDType.int4,
None,
Int4DynamicActivationInt4WeightConfig,
{},
),
(
TorchIntDType.int8,
TorchIntDType.int8,
None,
Int8DynamicActivationInt8WeightConfig,
{},
),
]
ptq_test_cases = [
# weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception
(TorchIntDType.int8, None, 8, False, None),
(TorchIntDType.int4, None, 4, True, None),
(TorchIntDType.uint4, None, 8, False, None),
(TorchIntDType.int4, TorchIntDType.int4, 8, False, None),
(TorchIntDType.int8, TorchIntDType.int8, 8, True, None),
(TorchIntDType.int8, None, None, False, ValueError),
(TorchIntDType.int4, None, None, False, ValueError),
]
class TestQuantization:
"""
Test quantization utilities
"""
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,expected_type,expected_params",
ptq_config_test_cases,
)
@require_torch_2_6_0
def test_get_ptq_config(
self, weight_dtype, activation_dtype, group_size, expected_type, expected_params
):
config = get_ptq_config(weight_dtype, activation_dtype, group_size)
assert isinstance(config, expected_type)
for param_name, param_value in expected_params.items():
if isinstance(param_value, (PerAxis, PerGroup)):
if isinstance(param_value, PerAxis):
assert isinstance(getattr(config, param_name), PerAxis)
assert getattr(config, param_name).axis == param_value.axis
else:
assert isinstance(getattr(config, param_name), PerGroup)
assert (
getattr(config, param_name).group_size == param_value.group_size
)
else:
assert getattr(config, param_name) == param_value
@pytest.mark.parametrize(
"weight_dtype", [TorchIntDType.int8, TorchIntDType.int4, TorchIntDType.uint4]
)
@pytest.mark.parametrize(
"activation_dtype", [None, TorchIntDType.int4, TorchIntDType.int8]
)
@pytest.mark.parametrize("group_size", [4, 8])
@pytest.mark.parametrize("quantize_embedding", [False, True])
@require_torch_2_6_0
def test_prepare_model_for_qat(
self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
): # pylint: disable=redefined-outer-name
prepare_model_for_qat(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
if quantize_embedding:
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
assert (
model.model.embed_tokens.weight_fake_quantizer.config.dtype
== weight_dtype.value
)
assert (
model.model.embed_tokens.weight_fake_quantizer.config.group_size
== group_size
)
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child, FakeQuantizedLinear)
assert hasattr(child, "weight_fake_quantizer")
assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
assert child.weight_fake_quantizer.config.group_size == group_size
if activation_dtype:
assert hasattr(child, "activation_fake_quantizer")
assert (
child.activation_fake_quantizer.config.dtype
== activation_dtype.value
)
else:
assert child.activation_fake_quantizer is None
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception",
ptq_test_cases,
)
@require_torch_2_6_0
def test_quantize_model_for_ptq(
self,
model,
weight_dtype,
activation_dtype,
group_size,
quantize_embedding,
expected_exception,
): # pylint: disable=redefined-outer-name
if expected_exception:
with pytest.raises(expected_exception):
quantize_model_for_ptq(
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
)
else:
quantize_model_for_ptq(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
if quantize_embedding:
assert isinstance(
model.model.embed_tokens.weight, AffineQuantizedTensor
), "Embedding weight should be quantized"
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
if activation_dtype:
assert isinstance(
child.weight, LinearActivationQuantizedTensor
), "Linear weight should be quantized with activation quantization"
else:
assert isinstance(
child.weight, AffineQuantizedTensor
), "Linear weight should be quantized without activation quantization"
class TestQuantizationCallback:
"""
Test QATCallback
"""
@pytest.fixture()
def trainer_state(self):
return TrainerState(
global_step=0,
)
@require_torch_2_6_0
def test_qat_callback_fake_quant_after_n_steps(
self, model, trainer_state
): # pylint: disable=redefined-outer-name
cfg = QATConfig(
weight_dtype="int8",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
fake_quant_after_n_steps=100,
)
prepare_model_for_qat(
model,
cfg.weight_dtype,
cfg.group_size,
cfg.activation_dtype,
cfg.quantize_embedding,
)
# ensure model has been quantized
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert isinstance(model.lm_head, FakeQuantizedLinear)
assert model.lm_head.weight_fake_quantizer.enabled
qat_callback = QATCallback(cfg)
# simulate first training step
qat_callback.on_step_begin(
args=None,
state=trainer_state,
control=None,
model=model,
)
# quantization should have been disabled
assert not model.model.embed_tokens.weight_fake_quantizer.enabled
assert not model.lm_head.weight_fake_quantizer.enabled
trainer_state.global_step = 100
qat_callback.on_step_begin(
args=None,
state=trainer_state,
control=None,
model=model,
)
# quantization should have been enabled
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
@require_torch_2_6_0
def test_qat_callback_fake_quant_after_n_steps_is_none(
self, model, trainer_state
): # pylint: disable=redefined-outer-name
cfg = QATConfig(
weight_dtype="int8",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
fake_quant_after_n_steps=None,
)
prepare_model_for_qat(
model,
cfg.weight_dtype,
cfg.group_size,
cfg.activation_dtype,
cfg.quantize_embedding,
)
# ensure model has been quantized
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert isinstance(model.lm_head, FakeQuantizedLinear)
assert model.lm_head.weight_fake_quantizer.enabled
qat_callback = QATCallback(cfg)
# simulate first training step
qat_callback.on_step_begin(
args=None,
state=trainer_state,
control=None,
model=model,
)
# quantization should be enabled from the get-go
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
class TestConvertQATModelForPTQ:
"""
Test convert_qat_model_for_ptq
"""
@require_torch_2_6_0
def test_convert_qat_model_for_ptq(
self, model
): # pylint: disable=redefined-outer-name
config = QATConfig(
weight_dtype="int8",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
)
# quantize model for qat
prepare_model_for_qat(
model,
config.weight_dtype,
config.group_size,
config.activation_dtype,
config.quantize_embedding,
)
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert isinstance(model.lm_head, FakeQuantizedLinear)
# apply conversion
convert_qat_model_for_ptq(
model,
quantize_embedding=config.quantize_embedding,
)
# ensure modules have been swapped out
assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert not isinstance(model.lm_head, FakeQuantizedLinear)
# ensure weights have been quantized
assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
assert isinstance(model.lm_head.weight, nn.Parameter)

View File

@@ -10,8 +10,6 @@ from functools import wraps
from pathlib import Path
import torch
# from importlib.metadata import version
from packaging import version
from tbparse import SummaryReader