@@ -43,6 +43,7 @@ quartodoc:
        - cli.vllm_serve
        - cli.cloud.base
        - cli.cloud.modal_
        - cli.quantize
    - title: Trainers
      desc: Training implementations
      contents:
@@ -147,6 +148,7 @@ quartodoc:
        - utils.optimizers.adopt
        - utils.data.pretraining
        - utils.data.sft
        - utils.quantization
    - title: Schemas
      desc: Pydantic data models for Axolotl config
      contents:
@@ -196,7 +198,7 @@ quartodoc:
        - utils.callbacks.lisa
        - utils.callbacks.mlflow_
        - utils.callbacks.comet_
        - utils.callbacks.qat

website:
  title: "Axolotl"
  description: "We make fine-tuning accessible, scalable, and fun"
@@ -256,6 +258,8 @@ website:
          - docs/lr_groups.qmd
          - docs/lora_optims.qmd
          - docs/dataset_loading.qmd
          - docs/qat.qmd
          - docs/quantize.qmd

        - section: "Core Concepts"
          contents:
docs/cli.qmd
@@ -209,6 +209,16 @@ axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir

This is necessary if you want to use the model with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.

### quantize

Quantizes a model using the quantization configuration specified in your YAML file.

```bash
axolotl quantize config.yml
```

See [Quantization](./quantize.qmd) for more details.
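
The command is a thin wrapper around `do_quantize` from `src/axolotl/cli/quantize.py` in this change; a sketch of the equivalent programmatic call (the config path is only a placeholder):

```python
from axolotl.cli.quantize import do_quantize

# Equivalent to `axolotl quantize config.yml`; the empty dict means no
# CLI overrides (base_model, weight_dtype, group_size, output_dir, ...).
do_quantize("config.yml", {})
```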

## Legacy CLI Usage
@@ -65,6 +65,20 @@ bnb_config_kwargs:
  bnb_4bit_quant_type: nf4
  bnb_4bit_use_double_quant: true

# quantization aware training
qat:
  activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
  weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
  group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
  fake_quant_after_n_steps: # Optional[int] = None. The number of steps after which fake quantization is enabled

# post-training quantization
quantization:
  weight_dtype: # Optional[str] = "int8". Quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
  activation_dtype: # Optional[str] = "int8". Quantization layout to use for activation quantization. Valid options are "int4" and "int8"
  group_size: # Optional[int] = 32. The number of elements in each group for per-group quantization
  quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.


# Whether you are training a 4-bit GPTQ quantized model
gptq: true
docs/qat.qmd (new file)
@@ -0,0 +1,32 @@
---
title: "Quantization Aware Training (QAT)"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---

## Overview

[Quantization Aware Training](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/#quantization-aware-training) (QAT) is a technique for improving the accuracy of quantized models by applying "fake" quantization to the model's weights (and optionally, activations) during training. This fake quantization lets the model adjust to the noise that quantization introduces, so that when the model is eventually quantized, the accuracy loss is minimized. We use the quantization techniques implemented in [torchao](https://github.com/pytorch/ao) to support QAT and post-training quantization (PTQ) in axolotl.

We recommend reviewing the excellent QAT tutorial in the [torchtune library](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#quantizing-the-qat-model) and the QAT documentation in the [torchao library](https://github.com/pytorch/ao/tree/main/torchao/quantization/qat) for more details.
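
Under the hood, enabling QAT swaps the model's `nn.Linear` (and optionally `nn.Embedding`) modules for torchao fake-quantized equivalents before training starts. A minimal sketch of that preparation step, simplified from `prepare_model_for_qat` in `src/axolotl/utils/quantization.py` from this change (the model name is just a small placeholder):

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M", torch_dtype=torch.bfloat16
)

# Fake-quantize weights per-group to int8; activations per-token, asymmetric.
weight_config = FakeQuantizeConfig(dtype=torch.int8, group_size=32)
activation_config = FakeQuantizeConfig(
    dtype=torch.int8, granularity="per_token", is_symmetric=False
)

# Swap every nn.Linear for a FakeQuantizedLinear in place.
quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(
        activation_config=activation_config,
        weight_config=weight_config,
    ),
)
# ...train as usual; the fake-quant noise is what the model learns to absorb.
```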

## Configuring QAT in Axolotl

To enable QAT in axolotl, add the following to your configuration file:

```yaml
qat:
  activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
  weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
  group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
  fake_quant_after_n_steps: # Optional[int] = None. The number of steps after which fake quantization is enabled
```

Once you have finished training, you must quantize your model using the same quantization configuration you trained with. You can use the [`quantize` command](./quantize.qmd) to do this.
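
For reference, this works because at the end of training axolotl converts the fake-quantized modules back to their original classes before saving (see `save_trained_model` in this change); a sketch of that lifecycle using the helpers added in this PR:

```python
import torch
from transformers import AutoModelForCausalLM

from axolotl.utils.quantization import (
    convert_qat_model_for_ptq,
    prepare_model_for_qat,
)
from axolotl.utils.schemas.enums import TorchIntDType

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M", torch_dtype=torch.bfloat16
)
# 1. Swap in fake-quantized modules (int8 weights, group size 32).
prepare_model_for_qat(model, TorchIntDType.int8, 32, TorchIntDType.int8, True)
# 2. ... training loop runs here ...
# 3. Restore the original module classes, keeping the QAT-adjusted weights.
convert_qat_model_for_ptq(model, quantize_embedding=True)
# 4. Save, then run `axolotl quantize` with the same qat config.
```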
docs/quantize.qmd (new file)
@@ -0,0 +1,53 @@
---
title: "Quantization with torchao"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---

Quantization is a technique for lowering the memory footprint of your model, potentially at the cost of accuracy or model performance. We support quantizing your model using the [torchao](https://github.com/pytorch/ao) library, for both post-training quantization (PTQ) and quantization-aware training (QAT).

::: {.callout-note}

We do not currently support quantization formats such as GGUF, GPTQ, or EXL2.

:::
## Configuring Quantization in Axolotl

Quantization is configured using the `quantization` key in your configuration file.

```yaml
base_model: # The path to the model to quantize.
quantization:
  weight_dtype: # Optional[str] = "int8". Quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
  activation_dtype: # Optional[str] = "int8". Quantization layout to use for activation quantization. Valid options are "int4" and "int8"
  group_size: # Optional[int] = 32. The number of elements in each group for per-group quantization
  quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.

output_dir: # The path to the output directory.
```

Once quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory.

You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.qmd). To do this, reuse the QAT configuration file you used to train the model:

```yaml
# qat.yml
qat:
  activation_dtype: int8
  weight_dtype: int8
  group_size: 256
  quantize_embedding: true

output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
```

```bash
axolotl quantize qat.yml
```

This ensures that the model is quantized with a configuration identical to the one it was trained with.
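
Internally, the `quantize` command maps the configured dtypes onto a torchao quantization config and applies it in place with `quantize_`; a condensed sketch of that path, following `quantize_model_for_ptq` from `src/axolotl/utils/quantization.py` in this change (the checkpoint path is a placeholder matching the example configs):

```python
import torch
from transformers import AutoModelForCausalLM

from axolotl.utils.quantization import quantize_model_for_ptq
from axolotl.utils.schemas.enums import TorchIntDType

# Assumes a QAT run saved its converted checkpoint to ./outputs/qat_out.
model = AutoModelForCausalLM.from_pretrained(
    "./outputs/qat_out", torch_dtype=torch.bfloat16
)

# int8 weights + int8 dynamic activations, matching qat.yml above.
quantize_model_for_ptq(
    model,
    weight_dtype=TorchIntDType.int8,
    group_size=256,
    activation_dtype=TorchIntDType.int8,
    quantize_embedding=True,
)

# torchao tensor subclasses require non-safetensors serialization here.
model.save_pretrained("./outputs/qat_out/quantized", safe_serialization=False)
```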
examples/llama-3/3b-qat-fsdp2.yaml (new file)
@@ -0,0 +1,79 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
  - axolotl.integrations.liger.LigerPlugin

liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true

datasets:
  - path: yahma/alpaca-cleaned
    type: alpaca

output_dir: ./outputs/qat_out/

sample_packing: true
pad_to_sequence_len: true
sequence_len: 512

flex_attention: true
flex_attn_compile_kwargs:
  dynamic: false
  mode: max-autotune-no-cudagraphs

qat:
  activation_dtype: int8
  weight_dtype: int4
  group_size: 32

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused

cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true

resume_from_checkpoint:
logging_steps: 1

evals_per_epoch: 1
saves_per_epoch: 1

warmup_steps: 10
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap

fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_reshard_after_forward: true
  fsdp_activation_checkpointing: true

special_tokens:
  pad_token: <|end_of_text|>
examples/qwen3/8b-qat-fsdp2.yml (new file)
@@ -0,0 +1,78 @@
base_model: Qwen/Qwen3-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
  - axolotl.integrations.liger.LigerPlugin

liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca

output_dir: ./outputs/qat_out/

sequence_len: 2048
sample_packing: true
flex_attention: true
pad_to_sequence_len: true

flex_attn_compile_kwargs:
  dynamic: false
  mode: max-autotune-no-cudagraphs

qat:
  activation_dtype: int8
  weight_dtype: int4
  group_size: 256
  fake_quant_after_n_steps: 1000

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 2
max_steps: 2000
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1

evals_per_epoch: 1
saves_per_epoch: 1

warmup_steps: 10
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap

fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_reshard_after_forward: true
  fsdp_activation_checkpointing: true

special_tokens:
@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2

torchao==0.9.0
torchao==0.10.0
schedulefree==1.4.1

axolotl-contribs-lgpl==0.0.6
@@ -90,6 +90,18 @@ class VllmServeCliArgs:
    )


@dataclass
class QuantizeCliArgs:
    """Dataclass with CLI arguments for `axolotl quantize` command."""

    base_model: Optional[str] = field(default=None)
    weight_dtype: Optional[str] = field(default=None)
    activation_dtype: Optional[str] = field(default=None)
    quantize_embedding: Optional[bool] = field(default=None)
    group_size: Optional[int] = field(default=None)
    output_dir: Optional[str] = field(default=None)


@dataclass
class EvaluateCliArgs:
    """Dataclass with CLI arguments for `axolotl evaluate` command."""
@@ -17,6 +17,7 @@ import axolotl
from axolotl.cli.args import (
    EvaluateCliArgs,
    PreprocessCliArgs,
    QuantizeCliArgs,
    TrainerCliArgs,
    VllmServeCliArgs,
)
@@ -333,6 +334,16 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
    do_vllm_serve(config, cli_args)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(QuantizeCliArgs)
@filter_none_kwargs
def quantize(config: str, **cli_args: QuantizeCliArgs):
    """Quantize model weights using the quantization config in the given config file."""
    from axolotl.cli.quantize import do_quantize

    do_quantize(config, cli_args)


@cli.command()
@click.argument("model", type=click.Path(exists=True, path_type=str))
@click.argument("output", type=click.Path(exists=False, path_type=str))
src/axolotl/cli/quantize.py (new file)
@@ -0,0 +1,90 @@
"""
CLI for post-training quantization of a model using torchao
"""

import logging
from pathlib import Path
from typing import Union

from transformers import AutoModelForCausalLM

from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq

LOG = logging.getLogger(__name__)


def do_quantize(
    config: Union[Path, str],
    cli_args: dict,
):
    """
    Quantizes a model's weights.

    Args:
        config (Union[Path, str]): The path to the config file
        cli_args (dict): Additional command-line arguments
    """
    print_axolotl_text_art()

    cfg = load_cfg(config)

    if cfg.qat and cfg.quantization:
        raise ValueError(
            "QAT and quantization cannot be used together. Please specify only one of qat or quantization in your config file."
        )

    if cfg.qat:
        quantize_cfg = cfg.qat
    elif cfg.quantization:
        quantize_cfg = cfg.quantization
    else:
        raise ValueError(
            "No quantization configuration found. Please specify either qat or quantization in your config file."
        )

    # prefer the CLI override, falling back to the training output dir
    model_path = cli_args.get("base_model") or cfg.output_dir
    if weight_dtype := cli_args.get("weight_dtype"):
        weight_dtype = TorchIntDType[weight_dtype]
    else:
        weight_dtype = quantize_cfg.weight_dtype
    if activation_dtype := cli_args.get("activation_dtype"):
        activation_dtype = TorchIntDType[activation_dtype]
    else:
        activation_dtype = quantize_cfg.activation_dtype
    group_size = cli_args.get("group_size") or quantize_cfg.group_size
    quantize_embedding = (
        cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
    )
    output_dir = cli_args.get("output_dir") or cfg.output_dir

    LOG.info(f"Loading model from {model_path}...")
    tokenizer = load_tokenizer(cfg)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")

    LOG.info(
        f"Quantizing model with configuration: \n"
        f"\tweight_dtype: {weight_dtype}\n"
        f"\tactivation_dtype: {activation_dtype}\n"
        f"\tgroup_size: {group_size}\n"
        f"\tquantize_embedding: {quantize_embedding}"
    )

    quantize_model_for_ptq(
        model, weight_dtype, group_size, activation_dtype, quantize_embedding
    )

    LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
    model.save_pretrained(
        str(Path(output_dir) / "quantized"),
        safe_serialization=False,
        progressbar=True,
    )
    tokenizer.save_pretrained(
        str(Path(output_dir) / "quantized"),
        safe_serialization=False,
        progressbar=True,
    )
    LOG.info(f"Quantized model saved to {str(Path(output_dir) / 'quantized')}")
@@ -79,6 +79,7 @@ from axolotl.utils.callbacks import (
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
    BatchSamplerDataCollatorForSeq2Seq,
@@ -254,6 +255,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        if self.cfg.loss_watchdog_threshold is not None:
            callbacks.append(LossWatchDogCallback(self.cfg))

        if self.cfg.qat:
            callbacks.append(QATCallback(self.cfg.qat))

        return callbacks

    def get_post_trainer_create_callbacks(self, trainer):
@@ -191,6 +191,7 @@ class ModelLoader:
        self._adjust_model_config()
        self._log_memory_usage()
        self._configure_embedding_dtypes()
        self._configure_qat()

    def _resize_token_embeddings(self):
        """Resize token embeddings if needed."""
@@ -305,6 +306,19 @@ class ModelLoader:
            before_kbit_train_or_finetune=False,
        )

    def _configure_qat(self):
        """Configure QAT."""
        if self.cfg.qat:
            from axolotl.utils.quantization import prepare_model_for_qat

            prepare_model_for_qat(
                self.model,
                self.cfg.qat.weight_dtype,
                self.cfg.qat.group_size,
                self.cfg.qat.activation_dtype,
                self.cfg.qat.quantize_embedding,
            )

    def _load_adapters(self) -> PeftConfig | None:
        """Load LoRA or other adapters."""
        # Load LoRA or adapter
@@ -80,9 +80,9 @@ class PatchManager:
    def _apply_fsdp_patches(self):
        """Apply patches for FSDP configurations."""
        if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
            from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
            from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2

            patch_accelerate_fsdp_utils()
            patch_accelerate_fsdp2()

    def _apply_adapter_patches(self):
        """Apply patches for adapter configurations."""
@@ -1,5 +1,5 @@
"""
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration, and saving full state dicts
"""

import logging
@@ -52,7 +52,146 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
    model.load_state_dict(sharded_sd, assign=True)


def patch_accelerate_fsdp_utils():
def set_state_dict_type(self, state_dict_type=None):
    """
    Set the state dict config based on the `StateDictType`.
    """
    import os

    from torch.distributed.fsdp.fully_sharded_data_parallel import (
        FullOptimStateDictConfig,
        FullStateDictConfig,
        ShardedOptimStateDictConfig,
        ShardedStateDictConfig,
        StateDictType,
    )

    # Override the state_dict_type if provided, typical use case:
    # user trains with sharded, but final save is with full
    if state_dict_type is not None:
        self.state_dict_type = state_dict_type

    if self.state_dict_type is None:
        self.state_dict_type = os.environ.get(
            "FSDP_STATE_DICT_TYPE",
            "FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT",
        )
    if isinstance(self.state_dict_type, str):
        if self.state_dict_type.isdigit():
            self.state_dict_type = StateDictType(int(self.state_dict_type))
        else:
            self.state_dict_type = StateDictType[self.state_dict_type.upper()]

    if self.state_dict_type == StateDictType.FULL_STATE_DICT:
        if self.state_dict_config is None:
            self.state_dict_config = FullStateDictConfig(
                offload_to_cpu=True, rank0_only=True
            )
        if self.optim_state_dict_config is None:
            self.optim_state_dict_config = FullOptimStateDictConfig(
                offload_to_cpu=True, rank0_only=True
            )
    elif self.state_dict_type == StateDictType.SHARDED_STATE_DICT:
        if self.state_dict_config is None:
            self.state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
        if self.optim_state_dict_config is None:
            self.optim_state_dict_config = ShardedOptimStateDictConfig(
                offload_to_cpu=True
            )


def get_state_dict(self, model, unwrap=True):
    """
    Returns the state dictionary of a model sent through [`Accelerator.prepare`] potentially without full
    precision.

    Args:
        model (`torch.nn.Module`):
            A PyTorch model sent through [`Accelerator.prepare`]
        unwrap (`bool`, *optional*, defaults to `True`):
            Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict

    Returns:
        `dict`: The state dictionary of the model potentially without full precision.

    Example:

    ```python
    >>> import torch
    >>> from accelerate import Accelerator

    >>> accelerator = Accelerator()
    >>> net = torch.nn.Linear(2, 2)
    >>> net = accelerator.prepare(net)
    >>> state_dict = accelerator.get_state_dict(net)
    ```
    """
    from accelerate import DistributedType
    from accelerate.utils import compare_versions

    if self.distributed_type == DistributedType.DEEPSPEED:
        zero3_sharding = self.deepspeed_config["zero_optimization"]["stage"] == 3
        tp_sharding = (
            self.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0) > 1
        )
        if zero3_sharding or tp_sharding:
            if model.zero_gather_16bit_weights_on_model_save():
                if tp_sharding and not compare_versions("deepspeed", ">=", "0.16.4"):
                    raise ImportError(
                        "Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`."
                    )
                state_dict = (
                    model._consolidated_16bit_state_dict()  # pylint: disable=protected-access
                    if tp_sharding
                    else model._zero3_consolidated_16bit_state_dict()  # pylint: disable=protected-access
                )
            else:
                raise ValueError(
                    "Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. "
                    "To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or "
                    "set `zero3_save_16bit_model` to True when using `accelerate config`. "
                    "To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights."
                )
        else:
            from deepspeed.checkpoint.utils import clone_tensors_for_torch_save

            state_dict = clone_tensors_for_torch_save(
                self.unwrap_model(model).state_dict()
            )
    elif self.is_fsdp2:
        # https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
        state_dict = {}
        sharded_state_dict = model.state_dict()
        for param_name, param in sharded_state_dict.items():
            if param.is_cpu:
                param = param.to(torch.device("cuda"))

            param = param.full_tensor()
            if torch.distributed.get_rank() == 0:
                state_dict[param_name] = param.cpu()
            torch.distributed.barrier()
    elif self.distributed_type == DistributedType.FSDP:
        from torch.distributed.fsdp import FullStateDictConfig
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.fsdp import StateDictType

        full_state_dict_config = FullStateDictConfig(
            offload_to_cpu=True, rank0_only=True
        )
        with FSDP.state_dict_type(
            model, StateDictType.FULL_STATE_DICT, full_state_dict_config
        ):
            state_dict = model.state_dict()
    else:
        if unwrap:
            model = self.unwrap_model(model)
        state_dict = model.state_dict()

    return state_dict


def patch_accelerate_fsdp2():
    import accelerate
    from accelerate.utils import fsdp_utils

    fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict
@@ -61,3 +200,19 @@ def patch_accelerate_fsdp_utils():
        "fsdp2_load_full_state_dict",
        fsdp2_load_full_state_dict,
    )

    accelerate.Accelerator.get_state_dict = get_state_dict
    setattr(
        sys.modules["accelerate"],
        "Accelerator.get_state_dict",
        get_state_dict,
    )

    accelerate.utils.dataclasses.FullyShardedDataParallelPlugin.set_state_dict_type = (
        set_state_dict_type
    )
    setattr(
        sys.modules["accelerate.utils.dataclasses"],
        "FullyShardedDataParallelPlugin.set_state_dict_type",
        set_state_dict_type,
    )
@@ -238,13 +238,27 @@ def save_trained_model(
        model: The trained model to save.
        safe_serialization: Whether to use safe serialization.
    """
    LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
    LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")

    # Post training module hooks
    for name, module in model.named_modules():
        if hasattr(module, "_post_training"):
            module._post_training(model, name)  # pylint: disable=protected-access

    # handle QAT
    if cfg.qat:
        from axolotl.utils.quantization import convert_qat_model_for_ptq

        LOG.info("Processing QAT model for saving...")
        convert_qat_model_for_ptq(
            model,
            quantize_embedding=cfg.qat.quantize_embedding,
        )
        LOG.info(
            "QAT modules have been converted for PTQ. Please ensure you quantize "
            "your model weights with `axolotl quantize`."
        )

    # Handle FSDP state dict type
    state_dict_type = "FULL_STATE_DICT"
    if trainer.is_fsdp_enabled and str(cfg.fsdp_config.fsdp_version) != "2":
@@ -321,6 +335,8 @@ def save_trained_model(
            save_compressed=cfg.llmcompressor.save_compressed,
        )

    LOG.info(f"Model successfully saved to {cfg.output_dir}")


def create_model_card(cfg: DictDefault, trainer: Trainer):
    """
src/axolotl/utils/callbacks/qat.py (new file)
@@ -0,0 +1,50 @@
"""QAT Callback for HF Causal Trainer"""

import logging
from functools import partial

from torch import nn
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
from transformers import TrainerCallback

from axolotl.utils.schemas.quantization import QATConfig

LOG = logging.getLogger(__name__)


def toggle_fake_quant(mod: nn.Module, enable: bool):
    """
    Toggle fake quantization for any fake quantized linear or embedding layers in the model.

    Args:
        mod: The module to toggle fake quantization for.
        enable: Whether to enable or disable fake quantization.
    """
    if isinstance(mod, (FakeQuantizedLinear, FakeQuantizedEmbedding)):
        if (
            isinstance(mod, FakeQuantizedLinear)
            and mod.activation_fake_quantizer is not None
        ):
            mod.activation_fake_quantizer.enabled = enable
        mod.weight_fake_quantizer.enabled = enable


class QATCallback(TrainerCallback):
    """
    Callback to toggle fake quantization for the model.
    """

    def __init__(self, cfg: QATConfig):
        self.cfg = cfg

    def on_step_begin(
        self, args, state, control, model, **kwargs
    ):  # pylint: disable=unused-argument
        if self.cfg.fake_quant_after_n_steps is not None:
            if state.global_step == 0:
                LOG.info(f"Disabling fake quantization at step {state.global_step}")
                model.apply(partial(toggle_fake_quant, enable=False))
            elif state.global_step == self.cfg.fake_quant_after_n_steps:
                LOG.info(f"Enabling fake quantization at step {state.global_step}")
                model.apply(partial(toggle_fake_quant, enable=True))
@@ -103,7 +103,12 @@ def cleanup_distributed():
    termination or when training successfully completes.
    """
    # Ensure that all operations are completed before destroying the process group
    torch.cuda.synchronize()
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    if torch.xpu.is_available():
        torch.xpu.synchronize()

    # Destroy the process group
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
src/axolotl/utils/quantization.py (new file)
@@ -0,0 +1,189 @@
"""
Utilities for quantization including QAT and PTQ using torchao.
"""

import logging

import torch
from torch import nn
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    FromIntXQuantizationAwareTrainingConfig,
    IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.quant_api import (
    Int4DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    UIntXWeightOnlyConfig,
    _is_linear,
)

from axolotl.utils.schemas.enums import TorchIntDType

LOG = logging.getLogger(__name__)


def get_ptq_config(
    weight_dtype: TorchIntDType,
    activation_dtype: TorchIntDType | None = None,
    group_size: int | None = None,
) -> AOBaseConfig:
    """
    This function is used to build a post-training quantization config.

    Args:
        weight_dtype: The dtype to use for weight quantization.
        activation_dtype: The dtype to use for activation quantization.
        group_size: The group size to use for weight quantization.

    Returns:
        The post-training quantization config.

    Raises:
        ValueError: If the activation dtype is not specified and the weight dtype is not int8 or int4,
            or if the group size is not specified for int8 or int4 weight only quantization.
    """
    if activation_dtype is None:
        if not weight_dtype.value.is_signed:  # type: ignore[attr-defined,union-attr]
            return UIntXWeightOnlyConfig(
                dtype=weight_dtype.value,
                group_size=group_size,
                set_inductor_config=False,
            )
        if weight_dtype == TorchIntDType.int8:
            if group_size is None:
                raise ValueError(
                    "group_size must be specified for int8 weight only quantization"
                )
            return Int8WeightOnlyConfig(
                group_size=group_size,
            )
        if weight_dtype == TorchIntDType.int4:
            if group_size is None:
                raise ValueError(
                    "group_size must be specified for int4 weight only quantization"
                )
            return Int4WeightOnlyConfig(
                group_size=group_size,
            )
    if activation_dtype == TorchIntDType.int4 and weight_dtype == TorchIntDType.int4:
        return Int4DynamicActivationInt4WeightConfig()
    if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int8:
        return Int8DynamicActivationInt8WeightConfig()
    if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int4:
        return Int8DynamicActivationInt4WeightConfig()
    raise ValueError(
        f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
    )


def prepare_model_for_qat(
    model,
    weight_dtype: TorchIntDType,
    group_size: int,
    activation_dtype: TorchIntDType | None = None,
    quantize_embedding: bool = False,
):
    """
    This function is used to prepare a model for QAT by swapping the model's linear
    layers with fake quantized linear layers, and optionally the embedding weights with
    fake quantized embedding weights.

    Args:
        model: The model to quantize.
        weight_dtype: The dtype to use for weight quantization.
        group_size: The group size to use for weight quantization.
        activation_dtype: The dtype to use for activation quantization.
        quantize_embedding: Whether to quantize the model's embedding weights.

    Raises:
        ValueError: If the activation/weight dtype combination is invalid.
    """
    if activation_dtype:
        activation_config = FakeQuantizeConfig(
            dtype=activation_dtype.value, granularity="per_token", is_symmetric=False
        )
    weight_config = FakeQuantizeConfig(dtype=weight_dtype.value, group_size=group_size)
    linear_quantize_config = IntXQuantizationAwareTrainingConfig(
        activation_config=None if activation_dtype is None else activation_config,
        weight_config=weight_config,
    )
    quantize_(model, linear_quantize_config)
    if quantize_embedding:
        # activation fake quantization is not supported for embedding layers
        embedding_quantize_config = IntXQuantizationAwareTrainingConfig(
            activation_config=None,
            weight_config=weight_config,
        )
        quantize_(
            model,
            embedding_quantize_config,
            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
        )


def quantize_model_for_ptq(
    model,
    weight_dtype: TorchIntDType,
    group_size: int | None = None,
    activation_dtype: TorchIntDType | None = None,
    quantize_embedding: bool | None = None,
):
    """
    This function is used to quantize a model for post-training quantization.
    It quantizes the model's linear layers with the configured torchao quantization scheme.
    If `quantize_embedding` is True, it will also quantize the model's embedding weights.

    Args:
        model: The model to quantize.
        weight_dtype: The dtype to use for weight quantization.
        group_size: The group size to use for weight quantization.
        activation_dtype: The dtype to use for activation quantization.
        quantize_embedding: Whether to quantize the model's embedding weights.

    """
    linear_ptq_config = get_ptq_config(
        weight_dtype=weight_dtype,
        activation_dtype=activation_dtype,
        group_size=group_size,
    )
    quantize_(model, linear_ptq_config)
    if quantize_embedding:
        embedding_quantize_config = get_ptq_config(
            weight_dtype=weight_dtype,
            activation_dtype=None,
            group_size=group_size,
        )
        quantize_(
            model,
            embedding_quantize_config,
            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
        )


def convert_qat_model_for_ptq(
    model,
    *,
    quantize_embedding: bool | None = None,
):
    """
    This function swaps the fake-quantized modules in a model that has been
    trained with QAT back to the original modules, ready for PTQ.

    Args:
        model: The model to convert.
        quantize_embedding: Whether the model's embedding weights were fake quantized.
    """
    if quantize_embedding:

        def filter_fn(m, _):
            return isinstance(m, nn.Embedding) or _is_linear(m)

    else:
        filter_fn = _is_linear
    quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
@@ -44,6 +44,7 @@ from axolotl.utils.schemas.model import (
)
from axolotl.utils.schemas.multimodal import MultiModalConfig
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
from axolotl.utils.schemas.quantization import PTQConfig, QATConfig
from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.vllm import VllmConfig
@@ -91,6 +92,8 @@ class AxolotlInputConfig(
    vllm: VllmConfig | None = Field(
        default_factory=lambda: VllmConfig(),  # pylint: disable=unnecessary-lambda
    )
    qat: QATConfig | None = None
    quantization: PTQConfig | None = None
    reward_model: bool | None = None
    process_reward_model: bool | None = None
    num_labels: int | None = None
@@ -126,7 +129,7 @@ class AxolotlInputConfig(
        default=None,
        json_schema_extra={"description": "streaming dataset to use for pretraining"},
    )
    dataset_processes: int | None = Field(default=min(32, os.cpu_count()))  # type: ignore[type-var]
    dataset_processes: int | None = Field(default=min(32, os.cpu_count() or 1))
    dataset_exact_deduplication: bool | None = None
    dataset_keep_in_memory: bool | None = None
    dataloader_pin_memory: bool | None = None
@@ -1481,3 +1484,42 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
        )

        return self

    @model_validator(mode="before")
    @classmethod
    def check_qat_config(cls, data):
        qat_cfg = data.get("qat", {})
        if not qat_cfg:
            return data

        if data.get("peft"):
            raise ValueError("QAT and PEFT cannot be used together.")

        if data.get("load_in_8bit"):
            raise ValueError("QAT and load_in_8bit cannot be used together.")

        if data.get("load_in_4bit"):
            raise ValueError("QAT and load_in_4bit cannot be used together.")

        env_capabilities = data.get("env_capabilities", {})
        torch_version = env_capabilities.get("torch_version")

        if torch_version is None:
            import torch

            torch_version = str(torch.__version__).split("+", maxsplit=1)[0]

        if (
            data.get("fsdp")
            and data.get("fsdp_config")
            and str(data["fsdp_config"].get("fsdp_version")) == "2"
        ):
            if version.parse(torch_version) < version.parse("2.7.0"):
                raise ValueError(
                    "FSDP2 and QAT are not supported on torch version < 2.7.0"
                )

        if version.parse(torch_version) < version.parse("2.6.0"):
            raise ValueError("QAT is not supported on torch version < 2.6.0")

        return data
@@ -2,6 +2,22 @@

from enum import Enum

import torch


class TorchIntDType(Enum):
    """Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4"""

    uint1 = getattr(torch, "uint1", None)  # pylint: disable=invalid-name
    uint2 = getattr(torch, "uint2", None)  # pylint: disable=invalid-name
    uint3 = getattr(torch, "uint3", None)  # pylint: disable=invalid-name
    uint4 = getattr(torch, "uint4", None)  # pylint: disable=invalid-name
    uint5 = getattr(torch, "uint5", None)  # pylint: disable=invalid-name
    uint6 = getattr(torch, "uint6", None)  # pylint: disable=invalid-name
    uint7 = getattr(torch, "uint7", None)  # pylint: disable=invalid-name
    int4 = getattr(torch, "int4", None)  # pylint: disable=invalid-name
    int8 = getattr(torch, "int8", None)  # pylint: disable=invalid-name


class RLType(str, Enum):
    """RL trainer type configuration subset"""
src/axolotl/utils/schemas/quantization.py (new file)
@@ -0,0 +1,64 @@
"""
Quantization config schemas (QAT and PTQ)
"""

from typing import Any

from pydantic import BaseModel, Field, field_validator

from axolotl.utils.schemas.enums import TorchIntDType


class QATConfig(BaseModel):
    """
    QAT Config Schema
    """

    activation_dtype: TorchIntDType | None = Field(
        default=None, description="Activation dtype"
    )
    weight_dtype: TorchIntDType = Field(
        default=TorchIntDType.int8, description="Weight dtype"
    )
    quantize_embedding: bool | None = Field(
        default=False, description="Quantize embedding"
    )
    group_size: int | None = Field(default=32, description="Group size")
    fake_quant_after_n_steps: int | None = Field(
        default=None, description="Fake quant after n steps"
    )

    @field_validator("activation_dtype", "weight_dtype", mode="before")
    @classmethod
    def validate_dtype(cls, v: Any) -> TorchIntDType | None:
        if v == "int4":
            return TorchIntDType.int4
        if v == "int8":
            return TorchIntDType.int8
        raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")


class PTQConfig(BaseModel):
    """
    PTQ Config Schema
    """

    weight_dtype: TorchIntDType = Field(
        default=TorchIntDType.int8, description="Weight dtype"
    )
    activation_dtype: TorchIntDType | None = Field(
        default=None, description="Activation dtype"
    )
    quantize_embedding: bool | None = Field(
        default=None, description="Quantize embedding"
    )
    group_size: int | None = Field(default=32, description="Group size")

    @field_validator("activation_dtype", "weight_dtype", mode="before")
    @classmethod
    def validate_dtype(cls, v: Any) -> TorchIntDType | None:
        # PTQ also documents uintX weight dtypes, so accept any TorchIntDType name
        try:
            return TorchIntDType[v]
        except KeyError as exc:
            valid = [dtype.name for dtype in TorchIntDType]
            raise ValueError(f"Invalid dtype: '{v}'. Must be one of: {valid}") from exc
@@ -4,7 +4,6 @@ GRPO test suite
"""

import os
import random
import shutil
import subprocess  # nosec B404
import sys
import tempfile
@@ -118,7 +117,10 @@ def start_vllm(
        recursive_kill(process)
        with open("/tmp/vllm.log", "r", encoding="utf-8") as log_file:
            print(log_file.read())
        shutil.rmtree("/tmp/vllm.log")
        try:
            os.remove("/tmp/vllm.log")
        except FileNotFoundError:
            pass
        raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")

    # return the process
tests/e2e/test_qat.py (new file)
@@ -0,0 +1,71 @@
"""
E2E tests for QAT
"""

import unittest
from pathlib import Path

from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir


class TestQATLlama(unittest.TestCase):
    """
    Test case for QAT Llama models
    """

    @with_temp_dir
    def test_qat_lora(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "tokenizer_type": "AutoTokenizer",
                "sequence_len": 1024,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
                        "path": "mlabonne/FineTome-100k",
                        "type": "chat_template",
                        "field_messages": "conversations",
                        "message_property_mappings": {
                            "role": "from",
                            "content": "value",
                        },
                        "drop_system_message": True,
                        "split": "train[:1%]",
                    },
                ],
                "chat_template": "chatml",
                "qat": {
                    "quantize_embedding": True,
                    "activation_dtype": "int8",
                    "weight_dtype": "int8",
                    "group_size": 8,
                },
                "num_epochs": 1,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 2,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_safetensors": True,
                "bf16": True,
            }
        )
        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)
tests/e2e/test_quantization.py (new file)
@@ -0,0 +1,350 @@
"""
Tests for axolotl.utils.quantization
"""

import pytest
import torch
from torch import nn
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.linear_activation_quantized_tensor import (
    LinearActivationQuantizedTensor,
)
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
from torchao.quantization.quant_api import (
    Int4DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    UIntXWeightOnlyConfig,
)
from transformers import AutoModelForCausalLM
from transformers.trainer_callback import TrainerState

from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.quantization import (
    convert_qat_model_for_ptq,
    get_ptq_config,
    prepare_model_for_qat,
    quantize_model_for_ptq,
)
from axolotl.utils.schemas.enums import TorchIntDType
from axolotl.utils.schemas.quantization import QATConfig

from tests.e2e.utils import require_torch_2_6_0


@pytest.fixture()
def model():
    dummy_model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceTB/SmolLM2-135M",
        device_map="cuda",
        torch_dtype=torch.bfloat16,
    )
    with torch.device(dummy_model.device):
        dummy_model.model.embed_tokens = torch.nn.Embedding(
            dummy_model.model.embed_tokens.weight.shape[0],
            dummy_model.model.embed_tokens.weight.shape[1],
            dtype=dummy_model.model.embed_tokens.weight.dtype,
        )
    return dummy_model


ptq_config_test_cases = [
    # weight_dtype, activation_dtype, group_size, expected_type, expected_params
    (
        TorchIntDType.uint4,
        None,
        None,
        UIntXWeightOnlyConfig,
        {"dtype": torch.uint4, "group_size": None},
    ),
    (TorchIntDType.int8, None, 32, Int8WeightOnlyConfig, {"group_size": 32}),
    (TorchIntDType.int4, None, 4, Int4WeightOnlyConfig, {"group_size": 4}),
    (
        TorchIntDType.int4,
        TorchIntDType.int4,
        None,
        Int4DynamicActivationInt4WeightConfig,
        {},
    ),
    (
        TorchIntDType.int8,
        TorchIntDType.int8,
        None,
        Int8DynamicActivationInt8WeightConfig,
        {},
    ),
]

ptq_test_cases = [
    # weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception
    (TorchIntDType.int8, None, 8, False, None),
    (TorchIntDType.int4, None, 4, True, None),
    (TorchIntDType.uint4, None, 8, False, None),
    (TorchIntDType.int4, TorchIntDType.int4, 8, False, None),
    (TorchIntDType.int8, TorchIntDType.int8, 8, True, None),
    (TorchIntDType.int8, None, None, False, ValueError),
    (TorchIntDType.int4, None, None, False, ValueError),
]


class TestQuantization:
    """
    Test quantization utilities
    """

    @pytest.mark.parametrize(
        "weight_dtype,activation_dtype,group_size,expected_type,expected_params",
        ptq_config_test_cases,
    )
    @require_torch_2_6_0
    def test_get_ptq_config(
        self, weight_dtype, activation_dtype, group_size, expected_type, expected_params
    ):
        config = get_ptq_config(weight_dtype, activation_dtype, group_size)

        assert isinstance(config, expected_type)

        for param_name, param_value in expected_params.items():
            if isinstance(param_value, (PerAxis, PerGroup)):
                if isinstance(param_value, PerAxis):
                    assert isinstance(getattr(config, param_name), PerAxis)
                    assert getattr(config, param_name).axis == param_value.axis
                else:
                    assert isinstance(getattr(config, param_name), PerGroup)
                    assert (
                        getattr(config, param_name).group_size == param_value.group_size
                    )
            else:
                assert getattr(config, param_name) == param_value

    @pytest.mark.parametrize(
        "weight_dtype", [TorchIntDType.int8, TorchIntDType.int4, TorchIntDType.uint4]
    )
    @pytest.mark.parametrize(
        "activation_dtype", [None, TorchIntDType.int4, TorchIntDType.int8]
    )
    @pytest.mark.parametrize("group_size", [4, 8])
    @pytest.mark.parametrize("quantize_embedding", [False, True])
    @require_torch_2_6_0
    def test_prepare_model_for_qat(
        self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
    ):  # pylint: disable=redefined-outer-name
        prepare_model_for_qat(
            model, weight_dtype, group_size, activation_dtype, quantize_embedding
        )
        if quantize_embedding:
            assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
            assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
            assert (
                model.model.embed_tokens.weight_fake_quantizer.config.dtype
                == weight_dtype.value
            )
            assert (
                model.model.embed_tokens.weight_fake_quantizer.config.group_size
                == group_size
            )

        for child in list(model.children()):
            if isinstance(child, torch.nn.Linear):
                assert isinstance(child, FakeQuantizedLinear)
                assert hasattr(child, "weight_fake_quantizer")
                assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
                assert child.weight_fake_quantizer.config.group_size == group_size
                if activation_dtype:
                    assert hasattr(child, "activation_fake_quantizer")
                    assert (
                        child.activation_fake_quantizer.config.dtype
                        == activation_dtype.value
                    )
                else:
                    assert child.activation_fake_quantizer is None

    @pytest.mark.parametrize(
        "weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception",
        ptq_test_cases,
    )
    @require_torch_2_6_0
    def test_quantize_model_for_ptq(
        self,
        model,
        weight_dtype,
        activation_dtype,
        group_size,
        quantize_embedding,
        expected_exception,
    ):  # pylint: disable=redefined-outer-name
        if expected_exception:
            with pytest.raises(expected_exception):
                quantize_model_for_ptq(
                    model,
                    weight_dtype,
                    group_size,
                    activation_dtype,
                    quantize_embedding,
                )
        else:
            quantize_model_for_ptq(
                model, weight_dtype, group_size, activation_dtype, quantize_embedding
            )
            if quantize_embedding:
                assert isinstance(
                    model.model.embed_tokens.weight, AffineQuantizedTensor
                ), "Embedding weight should be quantized"
            for child in list(model.children()):
                if isinstance(child, torch.nn.Linear):
                    if activation_dtype:
                        assert isinstance(
                            child.weight, LinearActivationQuantizedTensor
                        ), "Linear weight should be quantized with activation quantization"
                    else:
                        assert isinstance(
                            child.weight, AffineQuantizedTensor
                        ), "Linear weight should be quantized without activation quantization"


class TestQuantizationCallback:
    """
    Test QATCallback
    """

    @pytest.fixture()
    def trainer_state(self):
        return TrainerState(
            global_step=0,
        )

    @require_torch_2_6_0
    def test_qat_callback_fake_quant_after_n_steps(
        self, model, trainer_state
    ):  # pylint: disable=redefined-outer-name
        cfg = QATConfig(
            weight_dtype="int8",
            activation_dtype="int8",
            group_size=8,
            quantize_embedding=True,
            fake_quant_after_n_steps=100,
        )

        prepare_model_for_qat(
            model,
            cfg.weight_dtype,
            cfg.group_size,
            cfg.activation_dtype,
            cfg.quantize_embedding,
        )

        # ensure model has been quantized
        assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
        assert model.model.embed_tokens.weight_fake_quantizer.enabled
        assert isinstance(model.lm_head, FakeQuantizedLinear)
        assert model.lm_head.weight_fake_quantizer.enabled

        qat_callback = QATCallback(cfg)

        # simulate first training step
        qat_callback.on_step_begin(
            args=None,
            state=trainer_state,
            control=None,
            model=model,
        )

        # quantization should have been disabled
        assert not model.model.embed_tokens.weight_fake_quantizer.enabled
        assert not model.lm_head.weight_fake_quantizer.enabled

        trainer_state.global_step = 100
        qat_callback.on_step_begin(
            args=None,
            state=trainer_state,
            control=None,
            model=model,
        )

        # quantization should have been enabled
        assert model.model.embed_tokens.weight_fake_quantizer.enabled
        assert model.lm_head.weight_fake_quantizer.enabled

    @require_torch_2_6_0
    def test_qat_callback_fake_quant_after_n_steps_is_none(
        self, model, trainer_state
    ):  # pylint: disable=redefined-outer-name
        cfg = QATConfig(
            weight_dtype="int8",
            activation_dtype="int8",
            group_size=8,
            quantize_embedding=True,
            fake_quant_after_n_steps=None,
        )

        prepare_model_for_qat(
            model,
            cfg.weight_dtype,
            cfg.group_size,
            cfg.activation_dtype,
            cfg.quantize_embedding,
        )

        # ensure model has been quantized
        assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
        assert model.model.embed_tokens.weight_fake_quantizer.enabled
        assert isinstance(model.lm_head, FakeQuantizedLinear)
        assert model.lm_head.weight_fake_quantizer.enabled

        qat_callback = QATCallback(cfg)
        # simulate first training step
        qat_callback.on_step_begin(
            args=None,
            state=trainer_state,
            control=None,
            model=model,
        )

        # quantization should be enabled from the get-go
        assert model.model.embed_tokens.weight_fake_quantizer.enabled
        assert model.lm_head.weight_fake_quantizer.enabled


class TestConvertQATModelForPTQ:
    """
    Test convert_qat_model_for_ptq
    """

    @require_torch_2_6_0
    def test_convert_qat_model_for_ptq(
        self, model
    ):  # pylint: disable=redefined-outer-name
        config = QATConfig(
            weight_dtype="int8",
            activation_dtype="int8",
            group_size=8,
            quantize_embedding=True,
        )

        # quantize model for qat
        prepare_model_for_qat(
            model,
            config.weight_dtype,
            config.group_size,
            config.activation_dtype,
            config.quantize_embedding,
        )

        assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
        assert isinstance(model.lm_head, FakeQuantizedLinear)

        # apply conversion
        convert_qat_model_for_ptq(
            model,
            quantize_embedding=config.quantize_embedding,
        )
        # ensure modules have been swapped out
        assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
        assert not isinstance(model.lm_head, FakeQuantizedLinear)

        # ensure the weights are plain parameters again
        assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
        assert isinstance(model.lm_head.weight, nn.Parameter)
@@ -10,8 +10,6 @@ from functools import wraps
from pathlib import Path

import torch

# from importlib.metadata import version
from packaging import version
from tbparse import SummaryReader
Block a user