Migrate QAT API; fix axolotl quantize for QAT-ed models; add NVFP4 (#3107)
This commit is contained in:
2
.github/workflows/multi-gpu-e2e.yml
vendored
2
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
|||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.8.0
|
pytorch: 2.8.0
|
||||||
axolotl_extras:
|
axolotl_extras: fbgemm-gpu
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
|
|||||||
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -304,7 +304,7 @@ jobs:
|
|||||||
pytorch: 2.8.0
|
pytorch: 2.8.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
gpu_type: "B200"
|
gpu_type: "B200"
|
||||||
axolotl_extras:
|
axolotl_extras: fbgemm-gpu
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
@@ -51,3 +51,11 @@ axolotl quantize qat.yml
|
|||||||
```
|
```
|
||||||
|
|
||||||
This ensures that an identical quantization configuration is used to quantize the model as was used to train it.
|
This ensures that an identical quantization configuration is used to quantize the model as was used to train it.
|
||||||
|
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
|
||||||
|
If you have configured pushing to hub with `hub_model_id`, the quantization scheme will be appended to your model's hub name,
|
||||||
|
e.g. `axolotl-ai-cloud/qat-nvfp4-llama3B` will become `axolotl-ai-cloud/qat-nvfp4-llama3B-nvfp4`
|
||||||
|
|
||||||
|
:::
|
||||||
|
|||||||
64
examples/llama-3/3b-qat-fsdp2-nvfp4.yaml
Normal file
64
examples/llama-3/3b-qat-fsdp2-nvfp4.yaml
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
base_model: meta-llama/Llama-3.2-3B
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.liger.LigerPlugin
|
||||||
|
|
||||||
|
liger_rope: true
|
||||||
|
liger_rms_norm: true
|
||||||
|
liger_glu_activation: true
|
||||||
|
liger_layer_norm: true
|
||||||
|
liger_fused_linear_cross_entropy: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: yahma/alpaca-cleaned
|
||||||
|
type: alpaca
|
||||||
|
split: train[:95%]
|
||||||
|
|
||||||
|
output_dir: ./outputs/qat_out/
|
||||||
|
dataset_prepared_path: ./outputs/dataset_prepared
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
qat:
|
||||||
|
activation_dtype: nvfp4
|
||||||
|
weight_dtype: nvfp4
|
||||||
|
group_size: 16 # only group_size of 16 is supported with nvfp4
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 64
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_fused
|
||||||
|
|
||||||
|
cosine_constant_lr_ratio: 0
|
||||||
|
cosine_min_lr_ratio: 1.0
|
||||||
|
learning_rate: 2e-5
|
||||||
|
save_only_model: true
|
||||||
|
bf16: true
|
||||||
|
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|finetune_right_pad_id|>
|
||||||
|
|
||||||
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
@@ -15,20 +15,18 @@ liger_glu_activation: true
|
|||||||
liger_layer_norm: true
|
liger_layer_norm: true
|
||||||
liger_fused_linear_cross_entropy: true
|
liger_fused_linear_cross_entropy: true
|
||||||
|
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: yahma/alpaca-cleaned
|
- path: yahma/alpaca-cleaned
|
||||||
type: alpaca
|
type: alpaca
|
||||||
|
split: train[:95%]
|
||||||
|
|
||||||
output_dir: ./outputs/qat_out/
|
output_dir: ./outputs/qat_out/
|
||||||
|
dataset_prepared_path: ./outputs/qat_out/dataset_prepared
|
||||||
|
|
||||||
sample_packing: true
|
sample_packing: false
|
||||||
|
sequence_len: 8192
|
||||||
sequence_len: 512
|
flash_attention: true
|
||||||
|
|
||||||
flex_attention: true
|
|
||||||
flex_attn_compile_kwargs:
|
|
||||||
dynamic: false
|
|
||||||
mode: max-autotune-no-cudagraphs
|
|
||||||
|
|
||||||
qat:
|
qat:
|
||||||
activation_dtype: int8
|
activation_dtype: int8
|
||||||
@@ -67,7 +65,7 @@ fsdp:
|
|||||||
fsdp_config:
|
fsdp_config:
|
||||||
fsdp_version: 2
|
fsdp_version: 2
|
||||||
fsdp_offload_params: false
|
fsdp_offload_params: false
|
||||||
fsdp_cpu_ram_efficient_loading: true
|
fsdp_cpu_ram_efficient_loading: false
|
||||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
@@ -76,6 +74,6 @@ fsdp_config:
|
|||||||
fsdp_activation_checkpointing: true
|
fsdp_activation_checkpointing: true
|
||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: <|end_of_text|>
|
pad_token: <|finetune_right_pad_id|>
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ langdetect==1.0.9
|
|||||||
immutabledict==4.2.0
|
immutabledict==4.2.0
|
||||||
antlr4-python3-runtime==4.13.2
|
antlr4-python3-runtime==4.13.2
|
||||||
|
|
||||||
torchao==0.12.0
|
torchao==0.13.0
|
||||||
schedulefree==1.4.1
|
schedulefree==1.4.1
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.6
|
axolotl-contribs-lgpl==0.0.6
|
||||||
|
|||||||
1
setup.py
1
setup.py
@@ -162,6 +162,7 @@ extras_require = {
|
|||||||
"llmcompressor": [
|
"llmcompressor": [
|
||||||
"llmcompressor==0.5.1",
|
"llmcompressor==0.5.1",
|
||||||
],
|
],
|
||||||
|
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
|
||||||
}
|
}
|
||||||
install_requires, dependency_links, extras_require_build = parse_requirements(
|
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||||
extras_require
|
extras_require
|
||||||
|
|||||||
@@ -115,6 +115,7 @@ class QuantizeCliArgs:
|
|||||||
quantize_embedding: Optional[bool] = field(default=None)
|
quantize_embedding: Optional[bool] = field(default=None)
|
||||||
group_size: Optional[int] = field(default=None)
|
group_size: Optional[int] = field(default=None)
|
||||||
output_dir: Optional[str] = field(default=None)
|
output_dir: Optional[str] = field(default=None)
|
||||||
|
hub_model_id: Optional[str] = field(default=None)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -5,12 +5,17 @@ CLI to post-training quantize a model using torchao
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
|
||||||
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.loaders import load_tokenizer
|
from axolotl.loaders import load_tokenizer
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
|
from axolotl.utils.quantization import (
|
||||||
|
TorchAOQuantDType,
|
||||||
|
get_quantization_config,
|
||||||
|
quantization_config_to_str,
|
||||||
|
quantize_model,
|
||||||
|
)
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -43,13 +48,13 @@ def do_quantize(
|
|||||||
"No quantization configuration found. Please specify either qat or quantization in your config file."
|
"No quantization configuration found. Please specify either qat or quantization in your config file."
|
||||||
)
|
)
|
||||||
|
|
||||||
model_path = cli_args.get("model_path") or cfg.output_dir
|
model_path = cli_args.get("base_model") or cfg.output_dir
|
||||||
if weight_dtype := cli_args.get("weight_dtype"):
|
if weight_dtype := cli_args.get("weight_dtype"):
|
||||||
weight_dtype = TorchIntDType[weight_dtype]
|
weight_dtype = TorchAOQuantDType.from_string(weight_dtype)
|
||||||
else:
|
else:
|
||||||
weight_dtype = quantize_cfg.weight_dtype
|
weight_dtype = quantize_cfg.weight_dtype
|
||||||
if activation_dtype := cli_args.get("activation_dtype"):
|
if activation_dtype := cli_args.get("activation_dtype"):
|
||||||
activation_dtype = TorchIntDType[activation_dtype]
|
activation_dtype = TorchAOQuantDType.from_string(activation_dtype)
|
||||||
else:
|
else:
|
||||||
activation_dtype = quantize_cfg.activation_dtype
|
activation_dtype = quantize_cfg.activation_dtype
|
||||||
group_size = cli_args.get("group_size") or quantize_cfg.group_size
|
group_size = cli_args.get("group_size") or quantize_cfg.group_size
|
||||||
@@ -57,10 +62,15 @@ def do_quantize(
|
|||||||
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
|
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
|
||||||
)
|
)
|
||||||
output_dir = cli_args.get("output_dir") or cfg.output_dir
|
output_dir = cli_args.get("output_dir") or cfg.output_dir
|
||||||
|
hub_model_id = cli_args.get("hub_model_id") or cfg.hub_model_id
|
||||||
|
|
||||||
LOG.info(f"Loading model from {model_path}...")
|
LOG.info(f"Loading model from {model_path}.")
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
|
config = AutoConfig.from_pretrained(model_path)
|
||||||
|
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_path, device_map="auto", torch_dtype=torch_dtype
|
||||||
|
)
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Quantizing model with configuration: \n"
|
f"Quantizing model with configuration: \n"
|
||||||
@@ -70,11 +80,21 @@ def do_quantize(
|
|||||||
f"\tquantize_embedding: {quantize_embedding}"
|
f"\tquantize_embedding: {quantize_embedding}"
|
||||||
)
|
)
|
||||||
|
|
||||||
quantize_model_for_ptq(
|
quantize_model(
|
||||||
model, weight_dtype, group_size, activation_dtype, quantize_embedding
|
model, weight_dtype, group_size, activation_dtype, quantize_embedding
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
|
quantization_config = get_quantization_config(
|
||||||
|
weight_dtype, activation_dtype, group_size
|
||||||
|
)
|
||||||
|
|
||||||
|
ao_config = TorchAoConfig(
|
||||||
|
quant_type=quantization_config,
|
||||||
|
include_input_output_embeddings=quantize_embedding,
|
||||||
|
)
|
||||||
|
model.config.quantization_config = ao_config
|
||||||
|
|
||||||
|
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
|
||||||
model.save_pretrained(
|
model.save_pretrained(
|
||||||
str(Path(output_dir) / "quantized"),
|
str(Path(output_dir) / "quantized"),
|
||||||
safe_serialization=False,
|
safe_serialization=False,
|
||||||
@@ -86,4 +106,14 @@ def do_quantize(
|
|||||||
progressbar=True,
|
progressbar=True,
|
||||||
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
||||||
)
|
)
|
||||||
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")
|
|
||||||
|
if hub_model_id:
|
||||||
|
hub_model_id = (
|
||||||
|
hub_model_id.rstrip("-")
|
||||||
|
+ f"-{quantization_config_to_str[type(quantization_config)]}"
|
||||||
|
)
|
||||||
|
model.push_to_hub(hub_model_id, safe_serialization=False)
|
||||||
|
tokenizer.push_to_hub(hub_model_id)
|
||||||
|
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
|
||||||
|
|
||||||
|
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")
|
||||||
|
|||||||
@@ -30,11 +30,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
|||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
)
|
)
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.loaders import (
|
from axolotl.loaders import ModelLoader, load_processor, load_tokenizer
|
||||||
ModelLoader,
|
|
||||||
load_processor,
|
|
||||||
load_tokenizer,
|
|
||||||
)
|
|
||||||
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import cleanup_distributed
|
from axolotl.utils.distributed import cleanup_distributed
|
||||||
@@ -234,16 +230,15 @@ def save_trained_model(
|
|||||||
|
|
||||||
# handle QAT
|
# handle QAT
|
||||||
if cfg.qat:
|
if cfg.qat:
|
||||||
from axolotl.utils.quantization import convert_qat_model_for_ptq
|
from axolotl.utils.quantization import convert_qat_model
|
||||||
|
|
||||||
LOG.info("Processing QAT model for saving...")
|
convert_qat_model(
|
||||||
convert_qat_model_for_ptq(
|
|
||||||
model,
|
model,
|
||||||
quantize_embedding=cfg.qat.quantize_embedding,
|
quantize_embedding=cfg.qat.quantize_embedding,
|
||||||
)
|
)
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"QAT modules have been converted for PTQ. Please ensure you quantize "
|
"QAT usage note: please ensure you quantize your model fine-tuned using QAT by running `axolotl quantize`"
|
||||||
"your model weights with `axolotl quantize`."
|
" with the same config which you used for training."
|
||||||
)
|
)
|
||||||
# Handle ReLoRA early return case
|
# Handle ReLoRA early return case
|
||||||
if cfg.relora:
|
if cfg.relora:
|
||||||
@@ -337,9 +332,7 @@ def save_trained_model(
|
|||||||
|
|
||||||
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
||||||
# TODO: add integration support so this can be implemented completely within the plugin
|
# TODO: add integration support so this can be implemented completely within the plugin
|
||||||
from axolotl.integrations.llm_compressor.utils import (
|
from axolotl.integrations.llm_compressor.utils import save_compressed_model
|
||||||
save_compressed_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
save_compressed_model(
|
save_compressed_model(
|
||||||
model=model,
|
model=model,
|
||||||
|
|||||||
@@ -3,30 +3,47 @@ Utilities for quantization including QAT and PTQ using torchao.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from packaging import version
|
||||||
from torchao.core.config import AOBaseConfig
|
from torchao.core.config import AOBaseConfig
|
||||||
from torchao.quantization import quantize_
|
from torchao.quantization import quantize_
|
||||||
from torchao.quantization.qat import (
|
from torchao.quantization.qat import (
|
||||||
FakeQuantizeConfig,
|
QATConfig,
|
||||||
FromIntXQuantizationAwareTrainingConfig,
|
|
||||||
IntXQuantizationAwareTrainingConfig,
|
|
||||||
)
|
)
|
||||||
from torchao.quantization.quant_api import (
|
from torchao.quantization.quant_api import (
|
||||||
Int4DynamicActivationInt4WeightConfig,
|
Float8DynamicActivationFloat8WeightConfig,
|
||||||
Int4WeightOnlyConfig,
|
Float8DynamicActivationInt4WeightConfig,
|
||||||
Int8DynamicActivationInt4WeightConfig,
|
Int8DynamicActivationInt4WeightConfig,
|
||||||
Int8DynamicActivationInt8WeightConfig,
|
|
||||||
Int8WeightOnlyConfig,
|
|
||||||
UIntXWeightOnlyConfig,
|
|
||||||
_is_linear,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.utils.schemas.enums import TorchIntDType
|
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||||
|
|
||||||
|
quantization_config_to_str = {
|
||||||
|
Int8DynamicActivationInt4WeightConfig: "int8int4",
|
||||||
|
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
|
||||||
|
Float8DynamicActivationInt4WeightConfig: "fp8int4",
|
||||||
|
}
|
||||||
|
|
||||||
|
if version.parse(torch.__version__) >= version.parse("2.8.0"):
|
||||||
|
try:
|
||||||
|
from torchao.prototype.mx_formats import NVFP4InferenceConfig
|
||||||
|
|
||||||
|
quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# int4 weight config imports will fail on machines with fbgemm-gpu installed
|
||||||
|
# without a CUDA runtime available so we do this safely
|
||||||
|
try:
|
||||||
|
from torchao.quantization.quant_api import Int4WeightOnlyConfig
|
||||||
|
|
||||||
|
quantization_config_to_str[Int4WeightOnlyConfig] = "int4"
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_ptq_config(
|
def get_quantization_config(
|
||||||
weight_dtype: TorchIntDType,
|
weight_dtype: TorchAOQuantDType,
|
||||||
activation_dtype: TorchIntDType | None = None,
|
activation_dtype: TorchAOQuantDType | None = None,
|
||||||
group_size: int | None = None,
|
group_size: int | None = None,
|
||||||
) -> AOBaseConfig:
|
) -> AOBaseConfig:
|
||||||
"""
|
"""
|
||||||
@@ -45,44 +62,101 @@ def get_ptq_config(
|
|||||||
or if the group size is not specified for int8 or int4 weight only quantization.
|
or if the group size is not specified for int8 or int4 weight only quantization.
|
||||||
"""
|
"""
|
||||||
if activation_dtype is None:
|
if activation_dtype is None:
|
||||||
if not weight_dtype.value.is_signed: # type: ignore[attr-defined,union-attr]
|
if weight_dtype == TorchAOQuantDType.int8:
|
||||||
return UIntXWeightOnlyConfig(
|
raise ValueError("Int8WeightOnlyConfig is not supported by torchao QAT.")
|
||||||
dtype=weight_dtype.value,
|
if weight_dtype == TorchAOQuantDType.int4:
|
||||||
group_size=group_size,
|
from torchao.quantization.quant_api import Int4WeightOnlyConfig
|
||||||
set_inductor_config=False,
|
|
||||||
)
|
if group_size is not None:
|
||||||
if weight_dtype == TorchIntDType.int8:
|
return Int4WeightOnlyConfig(group_size=group_size, version=2)
|
||||||
if group_size is None:
|
else:
|
||||||
raise ValueError(
|
return Int4WeightOnlyConfig(version=2)
|
||||||
"group_size must be specified for int8 weight only quantization"
|
if (
|
||||||
)
|
activation_dtype == TorchAOQuantDType.int4
|
||||||
return Int8WeightOnlyConfig(
|
and weight_dtype == TorchAOQuantDType.int4
|
||||||
group_size=group_size,
|
):
|
||||||
)
|
raise ValueError(
|
||||||
if weight_dtype == TorchIntDType.int4:
|
"Int4DynamicActivationInt4WeightConfig is not supported by torchao QAT."
|
||||||
if group_size is None:
|
)
|
||||||
raise ValueError(
|
if (
|
||||||
"group_size must be specified for int4 weight only quantization"
|
activation_dtype == TorchAOQuantDType.int8
|
||||||
)
|
and weight_dtype == TorchAOQuantDType.int8
|
||||||
return Int4WeightOnlyConfig(
|
):
|
||||||
group_size=group_size,
|
raise ValueError(
|
||||||
)
|
"Int8DynamicActivationInt8WeightConfig is not supported by torchao QAT."
|
||||||
if activation_dtype == TorchIntDType.int4 and weight_dtype == TorchIntDType.int4:
|
)
|
||||||
return Int4DynamicActivationInt4WeightConfig()
|
if (
|
||||||
if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int8:
|
activation_dtype == TorchAOQuantDType.int8
|
||||||
return Int8DynamicActivationInt8WeightConfig()
|
and weight_dtype == TorchAOQuantDType.int4
|
||||||
if activation_dtype == TorchIntDType.int8 and weight_dtype == TorchIntDType.int4:
|
):
|
||||||
return Int8DynamicActivationInt4WeightConfig()
|
if group_size is not None:
|
||||||
|
return Int8DynamicActivationInt4WeightConfig(group_size=group_size)
|
||||||
|
else:
|
||||||
|
return Int8DynamicActivationInt4WeightConfig()
|
||||||
|
if (
|
||||||
|
activation_dtype == TorchAOQuantDType.float8_e4m3fn
|
||||||
|
and weight_dtype == TorchAOQuantDType.float8_e4m3fn
|
||||||
|
):
|
||||||
|
return Float8DynamicActivationFloat8WeightConfig()
|
||||||
|
if (
|
||||||
|
activation_dtype == TorchAOQuantDType.float8_e4m3fn
|
||||||
|
and weight_dtype == TorchAOQuantDType.int4
|
||||||
|
):
|
||||||
|
return Float8DynamicActivationInt4WeightConfig()
|
||||||
|
if weight_dtype == TorchAOQuantDType.nvfp4:
|
||||||
|
from torchao.prototype.mx_formats import NVFP4InferenceConfig
|
||||||
|
|
||||||
|
if group_size is not None and group_size != 16:
|
||||||
|
raise ValueError("NVFP4 quantization must use a group_size of 16")
|
||||||
|
return NVFP4InferenceConfig()
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
|
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_model(
|
||||||
|
model,
|
||||||
|
weight_dtype: TorchAOQuantDType,
|
||||||
|
group_size: int | None = None,
|
||||||
|
activation_dtype: TorchAOQuantDType | None = None,
|
||||||
|
quantize_embedding: bool | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
This function is used to quantize a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model to quantize.
|
||||||
|
weight_dtype: The dtype to use for weight quantization.
|
||||||
|
group_size: The group size to use for weight quantization.
|
||||||
|
activation_dtype: The dtype to use for activation quantization.
|
||||||
|
quantize_embedding: Whether to quantize the model's embedding weights.
|
||||||
|
|
||||||
|
"""
|
||||||
|
linear_ptq_config = get_quantization_config(
|
||||||
|
weight_dtype=weight_dtype,
|
||||||
|
activation_dtype=activation_dtype,
|
||||||
|
group_size=group_size,
|
||||||
|
)
|
||||||
|
quantize_(model, linear_ptq_config)
|
||||||
|
if quantize_embedding:
|
||||||
|
# activation fake quantization is not supported for embedding layers
|
||||||
|
embedding_quantize_config = get_quantization_config(
|
||||||
|
weight_dtype=weight_dtype,
|
||||||
|
activation_dtype=None,
|
||||||
|
group_size=group_size,
|
||||||
|
)
|
||||||
|
quantize_(
|
||||||
|
model,
|
||||||
|
embedding_quantize_config,
|
||||||
|
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def prepare_model_for_qat(
|
def prepare_model_for_qat(
|
||||||
model,
|
model,
|
||||||
weight_dtype: TorchIntDType,
|
weight_dtype: TorchAOQuantDType,
|
||||||
group_size: int,
|
group_size: int | None = None,
|
||||||
activation_dtype: TorchIntDType | None = None,
|
activation_dtype: TorchAOQuantDType | None = None,
|
||||||
quantize_embedding: bool = False,
|
quantize_embedding: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -100,86 +174,40 @@ def prepare_model_for_qat(
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the activation/weight dtype combination is invalid.
|
ValueError: If the activation/weight dtype combination is invalid.
|
||||||
"""
|
"""
|
||||||
if activation_dtype:
|
base_config = get_quantization_config(
|
||||||
activation_config = FakeQuantizeConfig(
|
|
||||||
dtype=activation_dtype.value, granularity="per_token", is_symmetric=False
|
|
||||||
)
|
|
||||||
weight_config = FakeQuantizeConfig(dtype=weight_dtype.value, group_size=group_size)
|
|
||||||
linear_quantize_config = IntXQuantizationAwareTrainingConfig(
|
|
||||||
activation_config=None if activation_dtype is None else activation_config,
|
|
||||||
weight_config=weight_config,
|
|
||||||
)
|
|
||||||
quantize_(model, linear_quantize_config)
|
|
||||||
if quantize_embedding:
|
|
||||||
# activation fake quantization is not supported for embedding layers
|
|
||||||
embedding_quantize_config = IntXQuantizationAwareTrainingConfig(
|
|
||||||
activation_config=None,
|
|
||||||
weight_config=weight_config,
|
|
||||||
)
|
|
||||||
quantize_(
|
|
||||||
model,
|
|
||||||
embedding_quantize_config,
|
|
||||||
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def quantize_model_for_ptq(
|
|
||||||
model,
|
|
||||||
weight_dtype: TorchIntDType,
|
|
||||||
group_size: int | None = None,
|
|
||||||
activation_dtype: TorchIntDType | None = None,
|
|
||||||
quantize_embedding: bool | None = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
This function is used to quantize a model for post-training quantization.
|
|
||||||
It swaps the model's linear layers with fake quantized linear layers.
|
|
||||||
If `quantize_embedding` is True, it will also swap the model's embedding weights with fake quantized embedding weights.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The model to quantize.
|
|
||||||
weight_dtype: The dtype to use for weight quantization.
|
|
||||||
group_size: The group size to use for weight quantization.
|
|
||||||
activation_dtype: The dtype to use for activation quantization.
|
|
||||||
quantize_embedding: Whether to quantize the model's embedding weights.
|
|
||||||
|
|
||||||
"""
|
|
||||||
linear_ptq_config = get_ptq_config(
|
|
||||||
weight_dtype=weight_dtype,
|
weight_dtype=weight_dtype,
|
||||||
activation_dtype=activation_dtype,
|
activation_dtype=activation_dtype,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
)
|
)
|
||||||
quantize_(model, linear_ptq_config)
|
qat_config = QATConfig(base_config)
|
||||||
|
quantize_(model, qat_config)
|
||||||
if quantize_embedding:
|
if quantize_embedding:
|
||||||
embedding_quantize_config = get_ptq_config(
|
# activation fake quantization is not supported for embedding layers
|
||||||
|
embedding_base_config = get_quantization_config(
|
||||||
weight_dtype=weight_dtype,
|
weight_dtype=weight_dtype,
|
||||||
activation_dtype=None,
|
activation_dtype=None,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
)
|
)
|
||||||
|
embedding_qat_config = QATConfig(embedding_base_config)
|
||||||
quantize_(
|
quantize_(
|
||||||
model,
|
model,
|
||||||
embedding_quantize_config,
|
embedding_qat_config,
|
||||||
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def convert_qat_model_for_ptq(
|
def convert_qat_model(
|
||||||
model,
|
model,
|
||||||
*,
|
quantize_embedding: bool = False,
|
||||||
quantize_embedding: bool | None = None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
This function is used to convert a swap fake-quantized modules in a model
|
This function converts a QAT model which has fake quantized layers back to the original model.
|
||||||
which has been trained with QAT back to the original modules, ready for PTQ.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The model to convert.
|
|
||||||
quantize_embedding: Whether to quantize the model's embedding weights.
|
|
||||||
"""
|
"""
|
||||||
|
config = QATConfig(step="convert")
|
||||||
|
quantize_(model, config)
|
||||||
if quantize_embedding:
|
if quantize_embedding:
|
||||||
|
quantize_(
|
||||||
def filter_fn(m, _):
|
model,
|
||||||
return isinstance(m, nn.Embedding) or _is_linear(m)
|
config,
|
||||||
|
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
||||||
else:
|
)
|
||||||
filter_fn = _is_linear
|
|
||||||
quantize_(model, FromIntXQuantizationAwareTrainingConfig(), filter_fn=filter_fn)
|
|
||||||
|
|||||||
@@ -5,18 +5,21 @@ from enum import Enum
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class TorchIntDType(Enum):
|
class TorchAOQuantDType(Enum):
|
||||||
"""Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4"""
|
int4 = torch.int4
|
||||||
|
int8 = torch.int8
|
||||||
|
float8_e4m3fn = torch.float8_e4m3fn
|
||||||
|
nvfp4 = "nvfp4"
|
||||||
|
|
||||||
uint1 = getattr(torch, "uint1", None)
|
def from_string(str):
|
||||||
uint2 = getattr(torch, "uint2", None)
|
if str == "int4":
|
||||||
uint3 = getattr(torch, "uint3", None)
|
return TorchAOQuantDType.int4
|
||||||
uint4 = getattr(torch, "uint4", None)
|
if str == "int8":
|
||||||
uint5 = getattr(torch, "uint5", None)
|
return TorchAOQuantDType.int8
|
||||||
uint6 = getattr(torch, "uint6", None)
|
if str in ["float8_e4m3fn", "fp8", "float8"]:
|
||||||
uint7 = getattr(torch, "uint7", None)
|
return TorchAOQuantDType.float8_e4m3fn
|
||||||
int4 = getattr(torch, "int4", None)
|
if str == "nvfp4":
|
||||||
int8 = getattr(torch, "int8", None)
|
return TorchAOQuantDType.nvfp4
|
||||||
|
|
||||||
|
|
||||||
class RLType(str, Enum):
|
class RLType(str, Enum):
|
||||||
|
|||||||
@@ -6,7 +6,23 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from axolotl.utils.schemas.enums import TorchIntDType
|
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||||
|
|
||||||
|
|
||||||
|
def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
|
||||||
|
if v is None:
|
||||||
|
return None
|
||||||
|
if v == "int4":
|
||||||
|
return TorchAOQuantDType.int4
|
||||||
|
if v == "int8":
|
||||||
|
return TorchAOQuantDType.int8
|
||||||
|
if v in ["float8_e4m3fn", "fp8", "float8"]:
|
||||||
|
return TorchAOQuantDType.float8_e4m3fn
|
||||||
|
if v == "nvfp4":
|
||||||
|
return TorchAOQuantDType.nvfp4
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QATConfig(BaseModel):
|
class QATConfig(BaseModel):
|
||||||
@@ -14,13 +30,13 @@ class QATConfig(BaseModel):
|
|||||||
QAT Config Schema
|
QAT Config Schema
|
||||||
"""
|
"""
|
||||||
|
|
||||||
activation_dtype: TorchIntDType | None = Field(
|
activation_dtype: TorchAOQuantDType | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
|
description="Fake quantization layout to use for activation quantization.",
|
||||||
)
|
)
|
||||||
weight_dtype: TorchIntDType = Field(
|
weight_dtype: TorchAOQuantDType = Field(
|
||||||
default=TorchIntDType.int8,
|
default=TorchAOQuantDType.int8,
|
||||||
description='Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"',
|
description="Fake quantization layout to use for weight quantization.",
|
||||||
)
|
)
|
||||||
quantize_embedding: bool | None = Field(
|
quantize_embedding: bool | None = Field(
|
||||||
default=False, description="Quantize embedding"
|
default=False, description="Quantize embedding"
|
||||||
@@ -35,12 +51,8 @@ class QATConfig(BaseModel):
|
|||||||
|
|
||||||
@field_validator("activation_dtype", "weight_dtype", mode="before")
|
@field_validator("activation_dtype", "weight_dtype", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_dtype(cls, v: Any) -> TorchIntDType | None:
|
def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None:
|
||||||
if v == "int4":
|
return validate_ao_dtype(v)
|
||||||
return TorchIntDType.int4
|
|
||||||
if v == "int8":
|
|
||||||
return TorchIntDType.int8
|
|
||||||
raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")
|
|
||||||
|
|
||||||
|
|
||||||
class PTQConfig(BaseModel):
|
class PTQConfig(BaseModel):
|
||||||
@@ -48,13 +60,13 @@ class PTQConfig(BaseModel):
|
|||||||
PTQ Config Schema
|
PTQ Config Schema
|
||||||
"""
|
"""
|
||||||
|
|
||||||
weight_dtype: TorchIntDType = Field(
|
weight_dtype: TorchAOQuantDType = Field(
|
||||||
default=TorchIntDType.int8,
|
default=TorchAOQuantDType.int8,
|
||||||
description="Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8",
|
description="Fake quantization layout to use for weight quantization.",
|
||||||
)
|
)
|
||||||
activation_dtype: TorchIntDType | None = Field(
|
activation_dtype: TorchAOQuantDType | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
|
description="Fake quantization layout to use for activation quantization.",
|
||||||
)
|
)
|
||||||
quantize_embedding: bool | None = Field(
|
quantize_embedding: bool | None = Field(
|
||||||
default=None, description="Whether to quantize the embedding layer."
|
default=None, description="Whether to quantize the embedding layer."
|
||||||
@@ -66,9 +78,5 @@ class PTQConfig(BaseModel):
|
|||||||
|
|
||||||
@field_validator("activation_dtype", "weight_dtype", mode="before")
|
@field_validator("activation_dtype", "weight_dtype", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_dtype(cls, v: Any) -> TorchIntDType | None:
|
def validate_dtype(cls, v: Any) -> TorchAOQuantDType | None:
|
||||||
if v == "int4":
|
return validate_ao_dtype(v)
|
||||||
return TorchIntDType.int4
|
|
||||||
if v == "int8":
|
|
||||||
return TorchIntDType.int8
|
|
||||||
raise ValueError(f"Invalid dtype: '{v}'. Must be one of: ['int4', 'int8']")
|
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class TestQATLlama:
|
|||||||
"qat": {
|
"qat": {
|
||||||
"quantize_embedding": True,
|
"quantize_embedding": True,
|
||||||
"activation_dtype": "int8",
|
"activation_dtype": "int8",
|
||||||
"weight_dtype": "int8",
|
"weight_dtype": "int4",
|
||||||
"group_size": 8,
|
"group_size": 8,
|
||||||
},
|
},
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
@@ -111,7 +111,7 @@ class TestQATLlama:
|
|||||||
"qat": {
|
"qat": {
|
||||||
"quantize_embedding": True,
|
"quantize_embedding": True,
|
||||||
"activation_dtype": "int8",
|
"activation_dtype": "int8",
|
||||||
"weight_dtype": "int8",
|
"weight_dtype": "int4",
|
||||||
"group_size": 8,
|
"group_size": 8,
|
||||||
},
|
},
|
||||||
"save_first_step": False,
|
"save_first_step": False,
|
||||||
|
|||||||
@@ -5,41 +5,40 @@ Tests for axolotl.utils.quantization
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
|
from torchao.quantization import LinearActivationQuantizedTensor
|
||||||
from torchao.quantization.granularity import PerAxis, PerGroup
|
|
||||||
from torchao.quantization.linear_activation_quantized_tensor import (
|
|
||||||
LinearActivationQuantizedTensor,
|
|
||||||
)
|
|
||||||
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
|
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
|
||||||
from torchao.quantization.qat.linear import FakeQuantizedLinear
|
from torchao.quantization.qat.linear import FakeQuantizedLinear
|
||||||
from torchao.quantization.quant_api import (
|
from torchao.quantization.quant_api import (
|
||||||
Int4DynamicActivationInt4WeightConfig,
|
Float8DynamicActivationFloat8WeightConfig,
|
||||||
Int4WeightOnlyConfig,
|
Float8DynamicActivationInt4WeightConfig,
|
||||||
Int8DynamicActivationInt8WeightConfig,
|
Int8DynamicActivationInt4WeightConfig,
|
||||||
Int8WeightOnlyConfig,
|
|
||||||
UIntXWeightOnlyConfig,
|
|
||||||
)
|
)
|
||||||
|
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
from transformers.trainer_callback import TrainerState
|
from transformers.trainer_callback import TrainerState
|
||||||
|
|
||||||
from axolotl.utils.callbacks.qat import QATCallback
|
from axolotl.utils.callbacks.qat import QATCallback
|
||||||
from axolotl.utils.quantization import (
|
from axolotl.utils.quantization import (
|
||||||
convert_qat_model_for_ptq,
|
convert_qat_model,
|
||||||
get_ptq_config,
|
get_quantization_config,
|
||||||
prepare_model_for_qat,
|
prepare_model_for_qat,
|
||||||
quantize_model_for_ptq,
|
quantize_model,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.enums import TorchIntDType
|
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||||
from axolotl.utils.schemas.quantization import QATConfig
|
from axolotl.utils.schemas.quantization import QATConfig
|
||||||
|
|
||||||
from tests.e2e.utils import require_torch_2_6_0
|
from tests.e2e.utils import (
|
||||||
|
require_torch_2_8_0,
|
||||||
|
requires_cuda_ge_8_9,
|
||||||
|
requires_sm_ge_100,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def model():
|
def model():
|
||||||
dummy_model = AutoModelForCausalLM.from_pretrained(
|
dummy_model = AutoModelForCausalLM.from_pretrained(
|
||||||
"HuggingFaceTB/SmolLM2-135M",
|
"Qwen/Qwen2-0.5B",
|
||||||
device_map="cuda",
|
device_map="auto",
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
)
|
)
|
||||||
with torch.device(dummy_model.device):
|
with torch.device(dummy_model.device):
|
||||||
@@ -48,45 +47,56 @@ def model():
|
|||||||
dummy_model.model.embed_tokens.weight.shape[1],
|
dummy_model.model.embed_tokens.weight.shape[1],
|
||||||
dtype=dummy_model.model.embed_tokens.weight.dtype,
|
dtype=dummy_model.model.embed_tokens.weight.dtype,
|
||||||
)
|
)
|
||||||
return dummy_model
|
yield dummy_model
|
||||||
|
del dummy_model
|
||||||
|
|
||||||
|
|
||||||
ptq_config_test_cases = [
|
ptq_config_test_cases = [
|
||||||
# weight_dtype, activation_dtype, group_size, expected_type, expected_params
|
# weight_dtype, activation_dtype, group_size, expected_type
|
||||||
(
|
(
|
||||||
TorchIntDType.uint4,
|
TorchAOQuantDType.int4,
|
||||||
|
TorchAOQuantDType.int8,
|
||||||
None,
|
None,
|
||||||
None,
|
Int8DynamicActivationInt4WeightConfig,
|
||||||
UIntXWeightOnlyConfig,
|
|
||||||
{"dtype": torch.uint4, "group_size": None},
|
|
||||||
),
|
|
||||||
(TorchIntDType.int8, None, 32, Int8WeightOnlyConfig, {"group_size": 32}),
|
|
||||||
(TorchIntDType.int4, None, 4, Int4WeightOnlyConfig, {"group_size": 4}),
|
|
||||||
(
|
|
||||||
TorchIntDType.int4,
|
|
||||||
TorchIntDType.int4,
|
|
||||||
None,
|
|
||||||
Int4DynamicActivationInt4WeightConfig,
|
|
||||||
{},
|
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
TorchIntDType.int8,
|
TorchAOQuantDType.float8_e4m3fn,
|
||||||
TorchIntDType.int8,
|
TorchAOQuantDType.float8_e4m3fn,
|
||||||
None,
|
None,
|
||||||
Int8DynamicActivationInt8WeightConfig,
|
Float8DynamicActivationFloat8WeightConfig,
|
||||||
{},
|
),
|
||||||
|
(
|
||||||
|
TorchAOQuantDType.int4,
|
||||||
|
TorchAOQuantDType.float8_e4m3fn,
|
||||||
|
None,
|
||||||
|
Float8DynamicActivationInt4WeightConfig,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
ptq_test_cases = [
|
ptq_test_cases = [
|
||||||
# weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception
|
# weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception, expected_tensor_class
|
||||||
(TorchIntDType.int8, None, 8, False, None),
|
(TorchAOQuantDType.int4, None, 4, True, None, Int4Tensor),
|
||||||
(TorchIntDType.int4, None, 4, True, None),
|
(
|
||||||
(TorchIntDType.uint4, None, 8, False, None),
|
TorchAOQuantDType.int4,
|
||||||
(TorchIntDType.int4, TorchIntDType.int4, 8, False, None),
|
TorchAOQuantDType.int8,
|
||||||
(TorchIntDType.int8, TorchIntDType.int8, 8, True, None),
|
8,
|
||||||
(TorchIntDType.int8, None, None, False, ValueError),
|
False,
|
||||||
(TorchIntDType.int4, None, None, False, ValueError),
|
None,
|
||||||
|
LinearActivationQuantizedTensor,
|
||||||
|
),
|
||||||
|
# (
|
||||||
|
# TorchAOQuantDType.int4,
|
||||||
|
# TorchAOQuantDType.float8_e4m3fn,
|
||||||
|
# None,
|
||||||
|
# False,
|
||||||
|
# None,
|
||||||
|
# Int4Tensor,
|
||||||
|
# ),
|
||||||
|
(TorchAOQuantDType.int4, None, None, False, None, Int4Tensor),
|
||||||
|
# Deprecated configs
|
||||||
|
(TorchAOQuantDType.int8, None, 8, False, ValueError, None),
|
||||||
|
(TorchAOQuantDType.int4, TorchAOQuantDType.int4, 8, False, ValueError, None),
|
||||||
|
(TorchAOQuantDType.int8, TorchAOQuantDType.int8, 8, True, ValueError, None),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -96,44 +106,132 @@ class TestQuantization:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"weight_dtype,activation_dtype,group_size,expected_type,expected_params",
|
"weight_dtype,activation_dtype,group_size,expected_type",
|
||||||
ptq_config_test_cases,
|
ptq_config_test_cases,
|
||||||
)
|
)
|
||||||
@require_torch_2_6_0
|
@requires_cuda_ge_8_9
|
||||||
|
@require_torch_2_8_0
|
||||||
def test_get_ptq_config(
|
def test_get_ptq_config(
|
||||||
self, weight_dtype, activation_dtype, group_size, expected_type, expected_params
|
self, weight_dtype, activation_dtype, group_size, expected_type
|
||||||
):
|
):
|
||||||
config = get_ptq_config(weight_dtype, activation_dtype, group_size)
|
config = get_quantization_config(weight_dtype, activation_dtype, group_size)
|
||||||
|
|
||||||
assert isinstance(config, expected_type)
|
assert isinstance(config, expected_type)
|
||||||
|
|
||||||
for param_name, param_value in expected_params.items():
|
@requires_cuda_ge_8_9
|
||||||
if isinstance(param_value, (PerAxis, PerGroup)):
|
@require_torch_2_8_0
|
||||||
if isinstance(param_value, PerAxis):
|
def test_get_ptq_config_int4_weight_only(self):
|
||||||
assert isinstance(getattr(config, param_name), PerAxis)
|
from torchao.quantization.quant_api import Int4WeightOnlyConfig
|
||||||
assert getattr(config, param_name).axis == param_value.axis
|
|
||||||
else:
|
config = get_quantization_config(TorchAOQuantDType.int4, None, 4)
|
||||||
assert isinstance(getattr(config, param_name), PerGroup)
|
assert isinstance(config, Int4WeightOnlyConfig)
|
||||||
assert (
|
|
||||||
getattr(config, param_name).group_size == param_value.group_size
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert getattr(config, param_name) == param_value
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"weight_dtype", [TorchIntDType.int8, TorchIntDType.int4, TorchIntDType.uint4]
|
"weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception,expected_tensor_class",
|
||||||
|
ptq_test_cases,
|
||||||
)
|
)
|
||||||
|
@requires_cuda_ge_8_9
|
||||||
|
@require_torch_2_8_0
|
||||||
|
def test_quantize_model_for_ptq(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
weight_dtype,
|
||||||
|
activation_dtype,
|
||||||
|
group_size,
|
||||||
|
quantize_embedding,
|
||||||
|
expected_exception,
|
||||||
|
expected_tensor_class,
|
||||||
|
):
|
||||||
|
if expected_exception:
|
||||||
|
with pytest.raises(expected_exception):
|
||||||
|
quantize_model(
|
||||||
|
model,
|
||||||
|
weight_dtype,
|
||||||
|
group_size,
|
||||||
|
activation_dtype,
|
||||||
|
quantize_embedding,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
quantize_model(
|
||||||
|
model, weight_dtype, group_size, activation_dtype, quantize_embedding
|
||||||
|
)
|
||||||
|
if quantize_embedding:
|
||||||
|
assert isinstance(
|
||||||
|
model.model.embed_tokens.weight, expected_tensor_class
|
||||||
|
), "Embedding weight should be quantized"
|
||||||
|
for child in list(model.children()):
|
||||||
|
if isinstance(child, torch.nn.Linear):
|
||||||
|
assert isinstance(child.weight, expected_tensor_class)
|
||||||
|
|
||||||
|
@require_torch_2_8_0
|
||||||
|
@requires_sm_ge_100
|
||||||
|
def test_quantize_model_for_ptq_fp8(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
):
|
||||||
|
from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
|
||||||
|
Float8Tensor,
|
||||||
|
QuantizeTensorToFloat8Kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
quantize_model(
|
||||||
|
model,
|
||||||
|
TorchAOQuantDType.float8_e4m3fn,
|
||||||
|
None,
|
||||||
|
TorchAOQuantDType.float8_e4m3fn,
|
||||||
|
)
|
||||||
|
for child in list(model.children()):
|
||||||
|
if isinstance(child, torch.nn.Linear):
|
||||||
|
assert isinstance(child.weight, Float8Tensor)
|
||||||
|
assert child.weight.act_quant_kwargs is not None and isinstance(
|
||||||
|
child.weight.act_quant_kwargs, QuantizeTensorToFloat8Kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_torch_2_8_0
|
||||||
|
@requires_sm_ge_100
|
||||||
|
def test_quantize_model_for_ptq_nvfp4(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
):
|
||||||
|
from torchao.prototype.mx_formats.nvfp4_tensor import (
|
||||||
|
NVFP4Tensor,
|
||||||
|
QuantizeTensorToNVFP4Kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
quantize_model(model, TorchAOQuantDType.nvfp4, 16, TorchAOQuantDType.nvfp4)
|
||||||
|
for child in list(model.children()):
|
||||||
|
if isinstance(child, torch.nn.Linear):
|
||||||
|
assert isinstance(child.weight, NVFP4Tensor)
|
||||||
|
assert child.weight.act_quant_kwargs is not None and isinstance(
|
||||||
|
child.weight.act_quant_kwargs, QuantizeTensorToNVFP4Kwargs
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"activation_dtype", [None, TorchIntDType.int4, TorchIntDType.int8]
|
"weight_dtype,activation_dtype,group_size,quantize_embedding",
|
||||||
|
[
|
||||||
|
(TorchAOQuantDType.int4, None, 8, False),
|
||||||
|
(TorchAOQuantDType.int4, None, 16, True),
|
||||||
|
(TorchAOQuantDType.int4, TorchAOQuantDType.int8, 8, False),
|
||||||
|
(TorchAOQuantDType.int4, TorchAOQuantDType.int8, 16, True),
|
||||||
|
(
|
||||||
|
TorchAOQuantDType.float8_e4m3fn,
|
||||||
|
TorchAOQuantDType.float8_e4m3fn,
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
(TorchAOQuantDType.int4, TorchAOQuantDType.float8_e4m3fn, None, True),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("group_size", [4, 8])
|
@require_torch_2_8_0
|
||||||
@pytest.mark.parametrize("quantize_embedding", [False, True])
|
@requires_cuda_ge_8_9
|
||||||
@require_torch_2_6_0
|
|
||||||
def test_prepare_model_for_qat(
|
def test_prepare_model_for_qat(
|
||||||
self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
|
self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
|
||||||
):
|
):
|
||||||
prepare_model_for_qat(
|
prepare_model_for_qat(
|
||||||
model, weight_dtype, group_size, activation_dtype, quantize_embedding
|
model,
|
||||||
|
weight_dtype,
|
||||||
|
group_size,
|
||||||
|
activation_dtype,
|
||||||
|
quantize_embedding,
|
||||||
)
|
)
|
||||||
if quantize_embedding:
|
if quantize_embedding:
|
||||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||||
@@ -142,17 +240,19 @@ class TestQuantization:
|
|||||||
model.model.embed_tokens.weight_fake_quantizer.config.dtype
|
model.model.embed_tokens.weight_fake_quantizer.config.dtype
|
||||||
== weight_dtype.value
|
== weight_dtype.value
|
||||||
)
|
)
|
||||||
assert (
|
if group_size:
|
||||||
model.model.embed_tokens.weight_fake_quantizer.config.group_size
|
assert (
|
||||||
== group_size
|
model.model.embed_tokens.weight_fake_quantizer.config.group_size
|
||||||
)
|
== group_size
|
||||||
|
)
|
||||||
|
|
||||||
for child in list(model.children()):
|
for child in list(model.children()):
|
||||||
if isinstance(child, torch.nn.Linear):
|
if isinstance(child, torch.nn.Linear):
|
||||||
assert isinstance(child, FakeQuantizedLinear)
|
assert isinstance(child, FakeQuantizedLinear)
|
||||||
assert hasattr(child, "weight_fake_quantizer")
|
assert hasattr(child, "weight_fake_quantizer")
|
||||||
assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
|
assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
|
||||||
assert child.weight_fake_quantizer.config.group_size == group_size
|
if group_size:
|
||||||
|
assert child.weight_fake_quantizer.config.group_size == group_size
|
||||||
if activation_dtype:
|
if activation_dtype:
|
||||||
assert hasattr(child, "activation_fake_quantizer")
|
assert hasattr(child, "activation_fake_quantizer")
|
||||||
assert (
|
assert (
|
||||||
@@ -162,49 +262,40 @@ class TestQuantization:
|
|||||||
else:
|
else:
|
||||||
assert child.activation_fake_quantizer is None
|
assert child.activation_fake_quantizer is None
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@require_torch_2_8_0
|
||||||
"weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception",
|
@requires_cuda_ge_8_9
|
||||||
ptq_test_cases,
|
def test_convert_qat_model(self, model):
|
||||||
)
|
config = QATConfig(
|
||||||
@require_torch_2_6_0
|
weight_dtype="int4",
|
||||||
def test_quantize_model_for_ptq(
|
activation_dtype="int8",
|
||||||
self,
|
group_size=8,
|
||||||
model,
|
quantize_embedding=True,
|
||||||
weight_dtype,
|
)
|
||||||
activation_dtype,
|
|
||||||
group_size,
|
# quantize model for qat
|
||||||
quantize_embedding,
|
prepare_model_for_qat(
|
||||||
expected_exception,
|
model,
|
||||||
):
|
config.weight_dtype,
|
||||||
if expected_exception:
|
config.group_size,
|
||||||
with pytest.raises(expected_exception):
|
config.activation_dtype,
|
||||||
quantize_model_for_ptq(
|
config.quantize_embedding,
|
||||||
model,
|
)
|
||||||
weight_dtype,
|
|
||||||
group_size,
|
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||||
activation_dtype,
|
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
||||||
quantize_embedding,
|
|
||||||
)
|
# apply conversion
|
||||||
else:
|
convert_qat_model(
|
||||||
quantize_model_for_ptq(
|
model,
|
||||||
model, weight_dtype, group_size, activation_dtype, quantize_embedding
|
config.quantize_embedding,
|
||||||
)
|
)
|
||||||
if quantize_embedding:
|
# ensure modules have been swapped out
|
||||||
assert isinstance(
|
assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||||
model.model.embed_tokens.weight, AffineQuantizedTensor
|
assert not isinstance(model.lm_head, FakeQuantizedLinear)
|
||||||
), "Embedding weight should be quantized"
|
|
||||||
for child in list(model.children()):
|
# ensure weights have been quantized
|
||||||
if isinstance(child, torch.nn.Linear):
|
assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
|
||||||
if activation_dtype:
|
assert isinstance(model.lm_head.weight, nn.Parameter)
|
||||||
assert isinstance(
|
|
||||||
child.weight, LinearActivationQuantizedTensor
|
|
||||||
), (
|
|
||||||
"Linear weight should be quantized with activation quantization"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert isinstance(child.weight, AffineQuantizedTensor), (
|
|
||||||
"Linear weight should be quantized without activation quantization"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestQuantizationCallback:
|
class TestQuantizationCallback:
|
||||||
@@ -218,10 +309,10 @@ class TestQuantizationCallback:
|
|||||||
global_step=0,
|
global_step=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_torch_2_6_0
|
@require_torch_2_8_0
|
||||||
def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state):
|
def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state):
|
||||||
cfg = QATConfig(
|
cfg = QATConfig(
|
||||||
weight_dtype="int8",
|
weight_dtype="int4",
|
||||||
activation_dtype="int8",
|
activation_dtype="int8",
|
||||||
group_size=8,
|
group_size=8,
|
||||||
quantize_embedding=True,
|
quantize_embedding=True,
|
||||||
@@ -268,10 +359,10 @@ class TestQuantizationCallback:
|
|||||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||||
assert model.lm_head.weight_fake_quantizer.enabled
|
assert model.lm_head.weight_fake_quantizer.enabled
|
||||||
|
|
||||||
@require_torch_2_6_0
|
@require_torch_2_8_0
|
||||||
def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
|
def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
|
||||||
cfg = QATConfig(
|
cfg = QATConfig(
|
||||||
weight_dtype="int8",
|
weight_dtype="int4",
|
||||||
activation_dtype="int8",
|
activation_dtype="int8",
|
||||||
group_size=8,
|
group_size=8,
|
||||||
quantize_embedding=True,
|
quantize_embedding=True,
|
||||||
@@ -304,43 +395,3 @@ class TestQuantizationCallback:
|
|||||||
# quantization should be enabled from the get-go
|
# quantization should be enabled from the get-go
|
||||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||||
assert model.lm_head.weight_fake_quantizer.enabled
|
assert model.lm_head.weight_fake_quantizer.enabled
|
||||||
|
|
||||||
|
|
||||||
class TestConvertQATModelForPTQ:
|
|
||||||
"""
|
|
||||||
Test convert_qat_model_for_ptq
|
|
||||||
"""
|
|
||||||
|
|
||||||
@require_torch_2_6_0
|
|
||||||
def test_convert_qat_model_for_ptq(self, model):
|
|
||||||
config = QATConfig(
|
|
||||||
weight_dtype="int8",
|
|
||||||
activation_dtype="int8",
|
|
||||||
group_size=8,
|
|
||||||
quantize_embedding=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# quantize model for qat
|
|
||||||
prepare_model_for_qat(
|
|
||||||
model,
|
|
||||||
config.weight_dtype,
|
|
||||||
config.group_size,
|
|
||||||
config.activation_dtype,
|
|
||||||
config.quantize_embedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
|
||||||
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
|
||||||
|
|
||||||
# apply conversion
|
|
||||||
convert_qat_model_for_ptq(
|
|
||||||
model,
|
|
||||||
quantize_embedding=config.quantize_embedding,
|
|
||||||
)
|
|
||||||
# ensure modules have been swapped out
|
|
||||||
assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
|
||||||
assert not isinstance(model.lm_head, FakeQuantizedLinear)
|
|
||||||
|
|
||||||
# ensure weights have been quantized
|
|
||||||
assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
|
|
||||||
assert isinstance(model.lm_head.weight, nn.Parameter)
|
|
||||||
|
|||||||
@@ -90,6 +90,18 @@ def require_torch_2_7_0(test_case):
|
|||||||
return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case)
|
return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_torch_2_8_0(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires torch >= 2.8.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
def is_min_2_8_0():
|
||||||
|
torch_version = version.parse(torch.__version__)
|
||||||
|
return torch_version >= version.parse("2.8.0")
|
||||||
|
|
||||||
|
return unittest.skipUnless(is_min_2_8_0(), "test requires torch>=2.8.0")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_torch_lt_2_6_0(test_case):
|
def require_torch_lt_2_6_0(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires torch < 2.6.0
|
Decorator marking a test that requires torch < 2.6.0
|
||||||
@@ -128,6 +140,24 @@ def require_llmcompressor(test_case):
|
|||||||
)(test_case)
|
)(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def requires_sm_ge_100(test_case):
|
||||||
|
is_sm_ge_100 = (
|
||||||
|
torch.cuda.is_available()
|
||||||
|
and torch.version.cuda
|
||||||
|
and torch.cuda.get_device_capability() >= (10, 0)
|
||||||
|
)
|
||||||
|
return unittest.skipUnless(is_sm_ge_100, "test requires sm>=100")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def requires_cuda_ge_8_9(test_case):
|
||||||
|
is_cuda_ge_8_9 = (
|
||||||
|
torch.cuda.is_available()
|
||||||
|
and torch.version.cuda
|
||||||
|
and torch.cuda.get_device_capability() >= (8, 9)
|
||||||
|
)
|
||||||
|
return unittest.skipUnless(is_cuda_ge_8_9, "test requires cuda>=8.9")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def is_hopper():
|
def is_hopper():
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
compute_capability = torch.cuda.get_device_capability()
|
||||||
return compute_capability == (9, 0)
|
return compute_capability == (9, 0)
|
||||||
|
|||||||
Reference in New Issue
Block a user