@@ -90,6 +90,18 @@ class VllmServeCliArgs:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class QuantizeCliArgs:
    """Dataclass with CLI arguments for `axolotl quantize` command."""

    # All fields default to None so that unset CLI flags can be filtered out
    # and fall back to values from the config file.
    base_model: Optional[str] = None
    weight_dtype: Optional[str] = None
    activation_dtype: Optional[str] = None
    quantize_embedding: Optional[bool] = None
    group_size: Optional[int] = None
    output_dir: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluateCliArgs:
|
||||
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
|
||||
|
||||
@@ -17,6 +17,7 @@ import axolotl
|
||||
from axolotl.cli.args import (
|
||||
EvaluateCliArgs,
|
||||
PreprocessCliArgs,
|
||||
QuantizeCliArgs,
|
||||
TrainerCliArgs,
|
||||
VllmServeCliArgs,
|
||||
)
|
||||
@@ -333,6 +334,16 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
|
||||
do_vllm_serve(config, cli_args)
|
||||
|
||||
|
||||
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(QuantizeCliArgs)
@filter_none_kwargs
def quantize(config: str, **cli_args: QuantizeCliArgs):
    # Imported lazily so the quantization stack is only loaded when this
    # subcommand is actually invoked.
    from axolotl.cli.quantize import do_quantize

    do_quantize(config, cli_args)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("model", type=click.Path(exists=True, path_type=str))
|
||||
@click.argument("output", type=click.Path(exists=False, path_type=str))
|
||||
|
||||
90
src/axolotl/cli/quantize.py
Normal file
90
src/axolotl/cli/quantize.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
CLI to post-training quantize a model using torchao
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.loaders import load_tokenizer
|
||||
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def do_quantize(
    config: Union[Path, str],
    cli_args: dict,
):
    """
    Quantize a trained model's weights using torchao post-training quantization.

    Quantization settings come from the config's `qat` or `quantization`
    section; individual values may be overridden via CLI arguments.

    Args:
        config (Union[Path, str]): The path to the config file.
        cli_args (dict): Additional command-line arguments. Recognized keys:
            `base_model`, `weight_dtype`, `activation_dtype`, `group_size`,
            `quantize_embedding`, `output_dir`.

    Raises:
        ValueError: If both, or neither, of `qat` / `quantization` are set in
            the config file.
    """
    print_axolotl_text_art()

    cfg = load_cfg(config)

    if cfg.qat and cfg.quantization:
        raise ValueError(
            "QAT and quantization cannot be used together. Please specify only one of qat or quantization in your config file."
        )

    if cfg.qat:
        quantize_cfg = cfg.qat
    elif cfg.quantization:
        quantize_cfg = cfg.quantization
    else:
        raise ValueError(
            "No quantization configuration found. Please specify either qat or quantization in your config file."
        )

    # `base_model` is the flag exposed by `QuantizeCliArgs`; `model_path` is
    # also accepted for backward compatibility with direct callers. (The
    # previous code only read `model_path`, so the CLI flag never took effect.)
    model_path = (
        cli_args.get("base_model") or cli_args.get("model_path") or cfg.output_dir
    )

    # CLI dtypes arrive as strings (e.g. "int8") and are mapped to enum
    # members here; config values were already coerced by the schema.
    if weight_dtype := cli_args.get("weight_dtype"):
        weight_dtype = TorchIntDType[weight_dtype]
    else:
        weight_dtype = quantize_cfg.weight_dtype
    if activation_dtype := cli_args.get("activation_dtype"):
        activation_dtype = TorchIntDType[activation_dtype]
    else:
        activation_dtype = quantize_cfg.activation_dtype
    group_size = cli_args.get("group_size") or quantize_cfg.group_size
    quantize_embedding = (
        cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
    )
    output_dir = cli_args.get("output_dir") or cfg.output_dir
    quantized_path = str(Path(output_dir) / "quantized")

    LOG.info(f"Loading model from {model_path}...")
    tokenizer = load_tokenizer(cfg)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")

    LOG.info(
        f"Quantizing model with configuration: \n"
        f"\tweight_dtype: {weight_dtype}\n"
        f"\tactivation_dtype: {activation_dtype}\n"
        f"\tgroup_size: {group_size}\n"
        f"\tquantize_embedding: {quantize_embedding}"
    )

    quantize_model_for_ptq(
        model, weight_dtype, group_size, activation_dtype, quantize_embedding
    )

    LOG.info(f"Saving quantized model to: {quantized_path}...")
    model.save_pretrained(
        quantized_path,
        safe_serialization=False,
        progressbar=True,
    )
    # Tokenizer saving takes no serialization kwargs; the previous
    # `safe_serialization`/`progressbar` arguments were silently ignored.
    tokenizer.save_pretrained(quantized_path)
    LOG.info(f"Quantized model saved to: {quantized_path}...")
|
||||
@@ -79,6 +79,7 @@ from axolotl.utils.callbacks import (
|
||||
)
|
||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.collators import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
@@ -254,6 +255,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.loss_watchdog_threshold is not None:
|
||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||
|
||||
if self.cfg.qat:
|
||||
callbacks.append(QATCallback(self.cfg.qat))
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
|
||||
@@ -191,6 +191,7 @@ class ModelLoader:
|
||||
self._adjust_model_config()
|
||||
self._log_memory_usage()
|
||||
self._configure_embedding_dtypes()
|
||||
self._configure_qat()
|
||||
|
||||
def _resize_token_embeddings(self):
|
||||
"""Resize token embeddings if needed."""
|
||||
@@ -305,6 +306,19 @@ class ModelLoader:
|
||||
before_kbit_train_or_finetune=False,
|
||||
)
|
||||
|
||||
def _configure_qat(self):
|
||||
"""Configure QAT."""
|
||||
if self.cfg.qat:
|
||||
from axolotl.utils.quantization import prepare_model_for_qat
|
||||
|
||||
prepare_model_for_qat(
|
||||
self.model,
|
||||
self.cfg.qat.weight_dtype,
|
||||
self.cfg.qat.group_size,
|
||||
self.cfg.qat.activation_dtype,
|
||||
self.cfg.qat.quantize_embedding,
|
||||
)
|
||||
|
||||
def _load_adapters(self) -> PeftConfig | None:
|
||||
"""Load LoRA or other adapters."""
|
||||
# Load LoRA or adapter
|
||||
|
||||
@@ -80,9 +80,9 @@ class PatchManager:
|
||||
def _apply_fsdp_patches(self):
|
||||
"""Apply patches for FSDP configurations."""
|
||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
|
||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
|
||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
|
||||
|
||||
patch_accelerate_fsdp_utils()
|
||||
patch_accelerate_fsdp2()
|
||||
|
||||
def _apply_adapter_patches(self):
|
||||
"""Apply patches for adapter configurations."""
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation
|
||||
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation, and saving full state dicts
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -52,7 +52,146 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
|
||||
model.load_state_dict(sharded_sd, assign=True)
|
||||
|
||||
|
||||
def patch_accelerate_fsdp_utils():
|
||||
def set_state_dict_type(self, state_dict_type=None):
|
||||
"""
|
||||
Set the state dict config based on the `StateDictType`.
|
||||
"""
|
||||
import os
|
||||
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
||||
FullOptimStateDictConfig,
|
||||
FullStateDictConfig,
|
||||
ShardedOptimStateDictConfig,
|
||||
ShardedStateDictConfig,
|
||||
StateDictType,
|
||||
)
|
||||
|
||||
# Override the state_dict_type if provided, typical use case:
|
||||
# user trains with sharded, but final save is with full
|
||||
if state_dict_type is not None:
|
||||
self.state_dict_type = state_dict_type
|
||||
|
||||
if self.state_dict_type is None:
|
||||
self.state_dict_type = os.environ.get(
|
||||
"FSDP_STATE_DICT_TYPE",
|
||||
"FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT",
|
||||
)
|
||||
if isinstance(self.state_dict_type, str):
|
||||
if self.state_dict_type.isdigit():
|
||||
self.state_dict_type = StateDictType(int(self.state_dict_type))
|
||||
else:
|
||||
self.state_dict_type = StateDictType[self.state_dict_type.upper()]
|
||||
|
||||
if self.state_dict_type == StateDictType.FULL_STATE_DICT:
|
||||
if self.state_dict_config is None:
|
||||
self.state_dict_config = FullStateDictConfig(
|
||||
offload_to_cpu=True, rank0_only=True
|
||||
)
|
||||
if self.optim_state_dict_config is None:
|
||||
self.optim_state_dict_config = FullOptimStateDictConfig(
|
||||
offload_to_cpu=True, rank0_only=True
|
||||
)
|
||||
elif self.state_dict_type == StateDictType.SHARDED_STATE_DICT:
|
||||
if self.state_dict_config is None:
|
||||
self.state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
|
||||
if self.optim_state_dict_config is None:
|
||||
self.optim_state_dict_config = ShardedOptimStateDictConfig(
|
||||
offload_to_cpu=True
|
||||
)
|
||||
|
||||
|
||||
def get_state_dict(self, model, unwrap=True):
|
||||
"""
|
||||
Returns the state dictionary of a model sent through [`Accelerator.prepare`] potentially without full
|
||||
precision.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`):
|
||||
A PyTorch model sent through [`Accelerator.prepare`]
|
||||
unwrap (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict
|
||||
|
||||
Returns:
|
||||
`dict`: The state dictionary of the model potentially without full precision.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from accelerate import Accelerator
|
||||
|
||||
>>> accelerator = Accelerator()
|
||||
>>> net = torch.nn.Linear(2, 2)
|
||||
>>> net = accelerator.prepare(net)
|
||||
>>> state_dict = accelerator.get_state_dict(net)
|
||||
```
|
||||
"""
|
||||
from accelerate import DistributedType
|
||||
from accelerate.utils import compare_versions
|
||||
|
||||
if self.distributed_type == DistributedType.DEEPSPEED:
|
||||
zero3_sharding = self.deepspeed_config["zero_optimization"]["stage"] == 3
|
||||
tp_sharding = (
|
||||
self.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0) > 1
|
||||
)
|
||||
if zero3_sharding or tp_sharding:
|
||||
if model.zero_gather_16bit_weights_on_model_save():
|
||||
if tp_sharding and not compare_versions("deepspeed", ">=", "0.16.4"):
|
||||
raise ImportError(
|
||||
"Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`."
|
||||
)
|
||||
state_dict = (
|
||||
model._consolidated_16bit_state_dict() # pylint: disable=protected-access
|
||||
if tp_sharding
|
||||
else model._zero3_consolidated_16bit_state_dict() # pylint: disable=protected-access
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. "
|
||||
"To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or "
|
||||
"set `zero3_save_16bit_model` to True when using `accelerate config`. "
|
||||
"To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights."
|
||||
)
|
||||
else:
|
||||
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
|
||||
|
||||
state_dict = clone_tensors_for_torch_save(
|
||||
self.unwrap_model(model).state_dict()
|
||||
)
|
||||
elif self.is_fsdp2:
|
||||
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
|
||||
state_dict = {}
|
||||
sharded_state_dict = model.state_dict()
|
||||
for param_name, param in sharded_state_dict.items():
|
||||
if param.is_cpu:
|
||||
param = param.to(torch.device("cuda"))
|
||||
|
||||
param = param.full_tensor()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
state_dict[param_name] = param.cpu()
|
||||
torch.distributed.barrier()
|
||||
elif self.distributed_type == DistributedType.FSDP:
|
||||
from torch.distributed.fsdp import FullStateDictConfig
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp import StateDictType
|
||||
|
||||
full_state_dict_config = FullStateDictConfig(
|
||||
offload_to_cpu=True, rank0_only=True
|
||||
)
|
||||
with FSDP.state_dict_type(
|
||||
model, StateDictType.FULL_STATE_DICT, full_state_dict_config
|
||||
):
|
||||
state_dict = model.state_dict()
|
||||
else:
|
||||
if unwrap:
|
||||
model = self.unwrap_model(model)
|
||||
state_dict = model.state_dict()
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def patch_accelerate_fsdp2():
|
||||
import accelerate
|
||||
from accelerate.utils import fsdp_utils
|
||||
|
||||
fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict
|
||||
@@ -61,3 +200,19 @@ def patch_accelerate_fsdp_utils():
|
||||
"fsdp2_load_full_state_dict",
|
||||
fsdp2_load_full_state_dict,
|
||||
)
|
||||
|
||||
accelerate.Accelerator.get_state_dict = get_state_dict
|
||||
setattr(
|
||||
sys.modules["accelerate"],
|
||||
"Accelerator.get_state_dict",
|
||||
get_state_dict,
|
||||
)
|
||||
|
||||
accelerate.utils.dataclasses.FullyShardedDataParallelPlugin.set_state_dict_type = (
|
||||
set_state_dict_type
|
||||
)
|
||||
setattr(
|
||||
sys.modules["accelerate.utils.dataclasses"],
|
||||
"FullyShardedDataParallelPlugin.set_state_dict_type",
|
||||
set_state_dict_type,
|
||||
)
|
||||
|
||||
@@ -238,13 +238,27 @@ def save_trained_model(
|
||||
model: The trained model to save.
|
||||
safe_serialization: Whether to use safe serialization.
|
||||
"""
|
||||
LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
|
||||
LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
|
||||
|
||||
# Post training module hooks
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, "_post_training"):
|
||||
module._post_training(model, name) # pylint: disable=protected-access
|
||||
|
||||
# handle QAT
|
||||
if cfg.qat:
|
||||
from axolotl.utils.quantization import convert_qat_model_for_ptq
|
||||
|
||||
LOG.info("Processing QAT model for saving...")
|
||||
convert_qat_model_for_ptq(
|
||||
model,
|
||||
quantize_embedding=cfg.qat.quantize_embedding,
|
||||
)
|
||||
LOG.info(
|
||||
"QAT modules have been converted for PTQ. Please ensure you quantize "
|
||||
"your model weights with `axolotl quantize`."
|
||||
)
|
||||
|
||||
# Handle FSDP state dict type
|
||||
state_dict_type = "FULL_STATE_DICT"
|
||||
if trainer.is_fsdp_enabled and str(cfg.fsdp_config.fsdp_version) != "2":
|
||||
@@ -321,6 +335,8 @@ def save_trained_model(
|
||||
save_compressed=cfg.llmcompressor.save_compressed,
|
||||
)
|
||||
|
||||
LOG.info(f"Model successfully saved to {cfg.output_dir}")
|
||||
|
||||
|
||||
def create_model_card(cfg: DictDefault, trainer: Trainer):
|
||||
"""
|
||||
|
||||
50
src/axolotl/utils/callbacks/qat.py
Normal file
50
src/axolotl/utils/callbacks/qat.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""QAT Callback for HF Causal Trainer"""
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
|
||||
from torch import nn
|
||||
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
|
||||
from torchao.quantization.qat.linear import FakeQuantizedLinear
|
||||
from transformers import TrainerCallback
|
||||
|
||||
from axolotl.utils.schemas.quantization import QATConfig
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def toggle_fake_quant(mod: nn.Module, enable: bool):
    """
    Toggle fake quantization for any fake quantized linear or embedding layers in the model.

    Args:
        mod: The module to toggle fake quantization for.
        enable: Whether to enable or disable fake quantization.
    """
    if not isinstance(mod, (FakeQuantizedLinear, FakeQuantizedEmbedding)):
        return
    # Only linear layers carry an activation fake-quantizer, and it may be
    # absent for weight-only QAT.
    if (
        isinstance(mod, FakeQuantizedLinear)
        and mod.activation_fake_quantizer is not None
    ):
        mod.activation_fake_quantizer.enabled = enable
    mod.weight_fake_quantizer.enabled = enable
|
||||
|
||||
|
||||
class QATCallback(TrainerCallback):
    """
    Callback to toggle fake quantization for the model.
    """

    def __init__(self, cfg: QATConfig):
        # QAT schema config; only `fake_quant_after_n_steps` is consulted here.
        self.cfg = cfg

    def on_step_begin(
        self, args, state, control, model, **kwargs
    ):  # pylint: disable=unused-argument
        delay = self.cfg.fake_quant_after_n_steps
        if delay is None:
            return
        if state.global_step == 0:
            # Train the first `delay` steps at full precision.
            LOG.info(f"Disabling fake quantization at step {state.global_step}")
            model.apply(partial(toggle_fake_quant, enable=False))
        elif state.global_step == delay:
            LOG.info(f"Enabling fake quantization at step {state.global_step}")
            model.apply(partial(toggle_fake_quant, enable=True))
|
||||
@@ -103,7 +103,12 @@ def cleanup_distributed():
|
||||
termination or when training successfully completes.
|
||||
"""
|
||||
# Ensure that all operations are completed before destroying the process group
|
||||
torch.cuda.synchronize()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
if torch.xpu.is_available():
|
||||
torch.xpu.synchronize()
|
||||
|
||||
# Destroy the process group
|
||||
if torch.distributed.is_initialized():
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
189
src/axolotl/utils/quantization.py
Normal file
189
src/axolotl/utils/quantization.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Utilities for quantization including QAT and PTQ using torchao.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchao.core.config import AOBaseConfig
|
||||
from torchao.quantization import quantize_
|
||||
from torchao.quantization.qat import (
|
||||
FakeQuantizeConfig,
|
||||
FromIntXQuantizationAwareTrainingConfig,
|
||||
IntXQuantizationAwareTrainingConfig,
|
||||
)
|
||||
from torchao.quantization.quant_api import (
|
||||
Int4DynamicActivationInt4WeightConfig,
|
||||
Int4WeightOnlyConfig,
|
||||
Int8DynamicActivationInt4WeightConfig,
|
||||
Int8DynamicActivationInt8WeightConfig,
|
||||
Int8WeightOnlyConfig,
|
||||
UIntXWeightOnlyConfig,
|
||||
_is_linear,
|
||||
)
|
||||
|
||||
from axolotl.utils.schemas.enums import TorchIntDType
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_ptq_config(
    weight_dtype: TorchIntDType,
    activation_dtype: TorchIntDType | None = None,
    group_size: int | None = None,
) -> AOBaseConfig:
    """
    Build a torchao post-training quantization config.

    Args:
        weight_dtype: The dtype to use for weight quantization.
        activation_dtype: The dtype to use for activation quantization.
        group_size: The group size to use for weight quantization.

    Returns:
        The post-training quantization config.

    Raises:
        ValueError: If the activation/weight dtype combination is unsupported,
            or if the group size is missing for int8/int4 weight-only quantization.
    """
    # Weight-only quantization: no activation dtype supplied.
    if activation_dtype is None:
        if not weight_dtype.value.is_signed:  # type: ignore[attr-defined,union-attr]
            # Unsigned sub-byte dtypes (uint1..uint7) use the generic UIntX config.
            return UIntXWeightOnlyConfig(
                dtype=weight_dtype.value,
                group_size=group_size,
                set_inductor_config=False,
            )
        if weight_dtype is TorchIntDType.int8:
            if group_size is None:
                raise ValueError(
                    "group_size must be specified for int8 weight only quantization"
                )
            return Int8WeightOnlyConfig(group_size=group_size)
        if weight_dtype is TorchIntDType.int4:
            if group_size is None:
                raise ValueError(
                    "group_size must be specified for int4 weight only quantization"
                )
            return Int4WeightOnlyConfig(group_size=group_size)

    # Dynamic-activation quantization: dispatch on the (activation, weight) pair.
    dynamic_configs = {
        (TorchIntDType.int4, TorchIntDType.int4): Int4DynamicActivationInt4WeightConfig,
        (TorchIntDType.int8, TorchIntDType.int8): Int8DynamicActivationInt8WeightConfig,
        (TorchIntDType.int8, TorchIntDType.int4): Int8DynamicActivationInt4WeightConfig,
    }
    config_cls = dynamic_configs.get((activation_dtype, weight_dtype))
    if config_cls is not None:
        return config_cls()
    raise ValueError(
        f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
    )
|
||||
|
||||
|
||||
def prepare_model_for_qat(
    model,
    weight_dtype: TorchIntDType,
    group_size: int,
    activation_dtype: TorchIntDType | None = None,
    quantize_embedding: bool = False,
):
    """
    Prepare a model for QAT by swapping its linear layers with fake quantized
    linear layers, and optionally its embedding weights with fake quantized
    embedding weights.

    Args:
        model: The model to quantize.
        weight_dtype: The dtype to use for weight quantization.
        group_size: The group size to use for weight quantization.
        activation_dtype: The dtype to use for activation quantization.
        quantize_embedding: Whether to quantize the model's embedding weights.

    Raises:
        ValueError: If the activation/weight dtype combination is invalid.
    """
    weight_config = FakeQuantizeConfig(dtype=weight_dtype.value, group_size=group_size)
    activation_config = None
    if activation_dtype:
        activation_config = FakeQuantizeConfig(
            dtype=activation_dtype.value, granularity="per_token", is_symmetric=False
        )

    quantize_(
        model,
        IntXQuantizationAwareTrainingConfig(
            activation_config=activation_config,
            weight_config=weight_config,
        ),
    )

    if quantize_embedding:
        # Activation fake quantization is not supported for embedding layers,
        # so a weight-only config is applied to them.
        quantize_(
            model,
            IntXQuantizationAwareTrainingConfig(
                activation_config=None,
                weight_config=weight_config,
            ),
            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
        )
|
||||
|
||||
|
||||
def quantize_model_for_ptq(
    model,
    weight_dtype: TorchIntDType,
    group_size: int | None = None,
    activation_dtype: TorchIntDType | None = None,
    quantize_embedding: bool | None = None,
):
    """
    Quantize a model for post-training quantization by swapping its linear
    layers for quantized versions; when `quantize_embedding` is truthy the
    embedding weights are quantized too (weight-only).

    Args:
        model: The model to quantize.
        weight_dtype: The dtype to use for weight quantization.
        group_size: The group size to use for weight quantization.
        activation_dtype: The dtype to use for activation quantization.
        quantize_embedding: Whether to quantize the model's embedding weights.
    """
    quantize_(
        model,
        get_ptq_config(
            weight_dtype=weight_dtype,
            activation_dtype=activation_dtype,
            group_size=group_size,
        ),
    )
    if quantize_embedding:
        # Embeddings never get activation quantization.
        embedding_config = get_ptq_config(
            weight_dtype=weight_dtype,
            activation_dtype=None,
            group_size=group_size,
        )
        quantize_(
            model,
            embedding_config,
            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
        )
|
||||
|
||||
|
||||
def convert_qat_model_for_ptq(
    model,
    *,
    quantize_embedding: bool | None = None,
):
    """
    Swap the fake-quantized modules of a QAT-trained model back to their
    original module types, leaving the model ready for PTQ.

    Args:
        model: The model to convert.
        quantize_embedding: Whether embedding layers were also fake-quantized
            and should be converted back.
    """

    def linear_or_embedding(m, _):
        # Embeddings only need converting when they were fake-quantized too.
        return isinstance(m, nn.Embedding) or _is_linear(m)

    selected_filter = linear_or_embedding if quantize_embedding else _is_linear
    quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=selected_filter)
|
||||
@@ -44,6 +44,7 @@ from axolotl.utils.schemas.model import (
|
||||
)
|
||||
from axolotl.utils.schemas.multimodal import MultiModalConfig
|
||||
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
|
||||
from axolotl.utils.schemas.quantization import PTQConfig, QATConfig
|
||||
from axolotl.utils.schemas.training import HyperparametersConfig
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
from axolotl.utils.schemas.vllm import VllmConfig
|
||||
@@ -91,6 +92,8 @@ class AxolotlInputConfig(
|
||||
vllm: VllmConfig | None = Field(
|
||||
default_factory=lambda: VllmConfig(), # pylint: disable=unnecessary-lambda
|
||||
)
|
||||
qat: QATConfig | None = None
|
||||
quantization: PTQConfig | None = None
|
||||
reward_model: bool | None = None
|
||||
process_reward_model: bool | None = None
|
||||
num_labels: int | None = None
|
||||
@@ -126,7 +129,7 @@ class AxolotlInputConfig(
|
||||
default=None,
|
||||
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
||||
)
|
||||
dataset_processes: int | None = Field(default=min(32, os.cpu_count())) # type: ignore[type-var]
|
||||
dataset_processes: int | None = Field(default=min(32, os.cpu_count() or 1))
|
||||
dataset_exact_deduplication: bool | None = None
|
||||
dataset_keep_in_memory: bool | None = None
|
||||
dataloader_pin_memory: bool | None = None
|
||||
@@ -1481,3 +1484,42 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_qat_config(cls, data):
|
||||
qat_cfg = data.get("qat", {})
|
||||
if not qat_cfg:
|
||||
return data
|
||||
|
||||
if data.get("peft"):
|
||||
raise ValueError("QAT and PEFT cannot be used together.")
|
||||
|
||||
if data.get("load_in_8bit"):
|
||||
raise ValueError("QAT and load_in_8bit cannot be used together.")
|
||||
|
||||
if data.get("load_in_4bit"):
|
||||
raise ValueError("QAT and load_in_4bit cannot be used together.")
|
||||
|
||||
env_capabilities = data.get("env_capabilities", {})
|
||||
torch_version = env_capabilities.get("torch_version")
|
||||
|
||||
if torch_version is None:
|
||||
import torch
|
||||
|
||||
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||
|
||||
if (
|
||||
data.get("fsdp")
|
||||
and data.get("fsdp_config")
|
||||
and str(data["fsdp_config"].get("fsdp_version")) == "2"
|
||||
):
|
||||
if version.parse(torch_version) < version.parse("2.7.0"):
|
||||
raise ValueError(
|
||||
"FSDP2 and QAT are not supported on torch version < 2.7.0"
|
||||
)
|
||||
|
||||
if version.parse(torch_version) < version.parse("2.6.0"):
|
||||
raise ValueError("QAT is not supported on torch version < 2.6.0")
|
||||
|
||||
return data
|
||||
|
||||
@@ -2,6 +2,22 @@
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TorchIntDType(Enum):
    """Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4"""

    # NOTE(review): on torch builds lacking these dtypes the values all become
    # None, which would make the members Enum *aliases* of each other — confirm
    # this class is only consulted on torch >= 2.6.
    # Sub-byte unsigned dtypes (used for UIntX weight-only quantization).
    uint1 = getattr(torch, "uint1", None)  # pylint: disable=invalid-name
    uint2 = getattr(torch, "uint2", None)  # pylint: disable=invalid-name
    uint3 = getattr(torch, "uint3", None)  # pylint: disable=invalid-name
    uint4 = getattr(torch, "uint4", None)  # pylint: disable=invalid-name
    uint5 = getattr(torch, "uint5", None)  # pylint: disable=invalid-name
    uint6 = getattr(torch, "uint6", None)  # pylint: disable=invalid-name
    uint7 = getattr(torch, "uint7", None)  # pylint: disable=invalid-name
    # Signed dtypes accepted by the QAT/PTQ schema validators.
    int4 = getattr(torch, "int4", None)  # pylint: disable=invalid-name
    int8 = getattr(torch, "int8", None)  # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class RLType(str, Enum):
|
||||
"""RL trainer type configuration subset"""
|
||||
|
||||
64
src/axolotl/utils/schemas/quantization.py
Normal file
64
src/axolotl/utils/schemas/quantization.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
QAT Config Schema
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from axolotl.utils.schemas.enums import TorchIntDType
|
||||
|
||||
|
||||
class QATConfig(BaseModel):
    """
    QAT Config Schema
    """

    # Activation fake-quantization dtype; None means weight-only QAT.
    activation_dtype: TorchIntDType | None = Field(
        default=None, description="Activation dtype"
    )
    weight_dtype: TorchIntDType = Field(
        default=TorchIntDType.int8, description="Weight dtype"
    )
    quantize_embedding: bool | None = Field(
        default=False, description="Quantize embedding"
    )
    group_size: int | None = Field(default=32, description="Group size")
    # Delay fake quantization until this many steps have run (see QATCallback).
    fake_quant_after_n_steps: int | None = Field(
        default=None, description="Fake quant after n steps"
    )

    @field_validator("activation_dtype", "weight_dtype", mode="before")
    @classmethod
    def validate_dtype(cls, v: Any) -> TorchIntDType | None:
        """Coerce string dtype names to `TorchIntDType`; only int4/int8 are allowed.

        Accepts an already-coerced `TorchIntDType` unchanged so validated
        configs can be round-tripped programmatically (the previous version
        rejected enum members).
        """
        if isinstance(v, TorchIntDType):
            return v
        if v == "int4":
            return TorchIntDType.int4
        if v == "int8":
            return TorchIntDType.int8
        raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")
|
||||
|
||||
|
||||
class PTQConfig(BaseModel):
    """
    PTQ Config Schema
    """

    weight_dtype: TorchIntDType = Field(
        default=TorchIntDType.int8, description="Weight dtype"
    )
    # None means weight-only post-training quantization.
    activation_dtype: TorchIntDType | None = Field(
        default=None, description="Activation dtype"
    )
    quantize_embedding: bool | None = Field(
        default=None, description="Quantize embedding"
    )
    group_size: int | None = Field(default=32, description="Group size")

    @field_validator("activation_dtype", "weight_dtype", mode="before")
    @classmethod
    def validate_dtype(cls, v: Any) -> TorchIntDType | None:
        """Coerce string dtype names to `TorchIntDType`; only int4/int8 are allowed.

        Accepts an already-coerced `TorchIntDType` unchanged so validated
        configs can be round-tripped programmatically (the previous version
        rejected enum members).
        """
        if isinstance(v, TorchIntDType):
            return v
        if v == "int4":
            return TorchIntDType.int4
        if v == "int8":
            return TorchIntDType.int8
        raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")
|
||||
Reference in New Issue
Block a user