Compare commits

...

1 Commits

Author SHA1 Message Date
Wing Lian
5a51852af1 set torchao quant config on config.json of saved model 2025-07-17 16:46:25 -04:00
2 changed files with 5 additions and 1 deletion

View File

@@ -43,7 +43,7 @@ def do_quantize(
"No quantization configuration found. Please specify either qat or quantization in your config file." "No quantization configuration found. Please specify either qat or quantization in your config file."
) )
model_path = cli_args.get("model_path") or cfg.output_dir model_path = cli_args.get("base_model") or cfg.output_dir
if weight_dtype := cli_args.get("weight_dtype"): if weight_dtype := cli_args.get("weight_dtype"):
weight_dtype = TorchIntDType[weight_dtype] weight_dtype = TorchIntDType[weight_dtype]
else: else:

View File

@@ -20,6 +20,7 @@ from torchao.quantization.quant_api import (
      UIntXWeightOnlyConfig,
      _is_linear,
  )
+ from transformers import TorchAoConfig
  from axolotl.utils.schemas.enums import TorchIntDType
@@ -149,7 +150,9 @@ def quantize_model_for_ptq(
      group_size=group_size,
  )
  quantize_(model, linear_ptq_config)
+ quantization_config = TorchAoConfig(linear_ptq_config)
  if quantize_embedding:
+     quantization_config.include_input_output_embeddings = True
      embedding_quantize_config = get_ptq_config(
          weight_dtype=weight_dtype,
          activation_dtype=None,
@@ -160,6 +163,7 @@ def quantize_model_for_ptq(
      embedding_quantize_config,
      filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
  )
+ model.config.quantization_config = quantization_config
  def convert_qat_model_for_ptq(