From 5a51852af1c827a143c1de551237d040719fb1d2 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 17 Jul 2025 16:46:25 -0400
Subject: [PATCH] set torchao quant config on config.json of saved model

---
 src/axolotl/cli/quantize.py       | 2 +-
 src/axolotl/utils/quantization.py | 4 ++++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/cli/quantize.py b/src/axolotl/cli/quantize.py
index 0782976fe..43f7d5267 100644
--- a/src/axolotl/cli/quantize.py
+++ b/src/axolotl/cli/quantize.py
@@ -43,7 +43,7 @@ def do_quantize(
             "No quantization configuration found. Please specify either qat or quantization in your config file."
         )
 
-    model_path = cli_args.get("model_path") or cfg.output_dir
+    model_path = cli_args.get("base_model") or cfg.output_dir
     if weight_dtype := cli_args.get("weight_dtype"):
         weight_dtype = TorchIntDType[weight_dtype]
     else:
diff --git a/src/axolotl/utils/quantization.py b/src/axolotl/utils/quantization.py
index f9a30b660..65b4c33ad 100644
--- a/src/axolotl/utils/quantization.py
+++ b/src/axolotl/utils/quantization.py
@@ -20,6 +20,7 @@ from torchao.quantization.quant_api import (
     UIntXWeightOnlyConfig,
     _is_linear,
 )
+from transformers import TorchAoConfig
 
 from axolotl.utils.schemas.enums import TorchIntDType
 
@@ -149,7 +150,9 @@ def quantize_model_for_ptq(
         group_size=group_size,
     )
     quantize_(model, linear_ptq_config)
+    quantization_config = TorchAoConfig(linear_ptq_config)
     if quantize_embedding:
+        quantization_config.include_input_output_embeddings = True
         embedding_quantize_config = get_ptq_config(
             weight_dtype=weight_dtype,
             activation_dtype=None,
@@ -160,6 +163,7 @@ def quantize_model_for_ptq(
             embedding_quantize_config,
             filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
         )
+    model.config.quantization_config = quantization_config
 
 
 def convert_qat_model_for_ptq(
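
Note (not part of the patch): a minimal standalone sketch of the behavior this
change enables, assuming a transformers version that accepts a torchao
AOBaseConfig instance inside TorchAoConfig (as the patch itself relies on).
The model name, quant config, and output path below are placeholders, not
values from the patch.

    # Sketch: after quantize_() rewrites the weights in place, recording a
    # TorchAoConfig on model.config means save_pretrained() serializes a
    # "quantization_config" entry into config.json, so the checkpoint can be
    # reloaded without re-specifying the quantization scheme.
    from torchao.quantization.quant_api import Int8WeightOnlyConfig, quantize_
    from transformers import AutoModelForCausalLM, TorchAoConfig

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder model

    linear_ptq_config = Int8WeightOnlyConfig()  # stand-in for get_ptq_config(...)
    quantize_(model, linear_ptq_config)         # torchao quantizes linear layers in place

    # The effect of the line this patch adds in quantize_model_for_ptq:
    model.config.quantization_config = TorchAoConfig(linear_ptq_config)

    # torchao-quantized tensors are not safetensors-compatible, hence
    # safe_serialization=False; config.json now carries the quant config.
    model.save_pretrained("./quantized-model", safe_serialization=False)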