Compare commits

...

1 Commit

Author SHA1 Message Date
Wing Lian
5a51852af1 set torchao quant config on config.json of saved model 2025-07-17 16:46:25 -04:00
2 changed files with 5 additions and 1 deletion

View File

@@ -43,7 +43,7 @@ def do_quantize(
"No quantization configuration found. Please specify either qat or quantization in your config file."
)
model_path = cli_args.get("model_path") or cfg.output_dir
model_path = cli_args.get("base_model") or cfg.output_dir
if weight_dtype := cli_args.get("weight_dtype"):
weight_dtype = TorchIntDType[weight_dtype]
else:

View File

@@ -20,6 +20,7 @@ from torchao.quantization.quant_api import (
UIntXWeightOnlyConfig,
_is_linear,
)
from transformers import TorchAoConfig
from axolotl.utils.schemas.enums import TorchIntDType
@@ -149,7 +150,9 @@ def quantize_model_for_ptq(
group_size=group_size,
)
quantize_(model, linear_ptq_config)
quantization_config = TorchAoConfig(linear_ptq_config)
if quantize_embedding:
quantization_config.include_input_output_embeddings = True
embedding_quantize_config = get_ptq_config(
weight_dtype=weight_dtype,
activation_dtype=None,
@@ -160,6 +163,7 @@ def quantize_model_for_ptq(
embedding_quantize_config,
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
model.config.quantization_config = quantization_config
def convert_qat_model_for_ptq(