From 75b20fb66f4b37e01217e01dc6fb7ef40ff5227f Mon Sep 17 00:00:00 2001 From: salman Date: Sat, 6 Dec 2025 16:27:18 +0000 Subject: [PATCH] Save processor in quantizer CLI (#3290) --- src/axolotl/cli/quantize.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/axolotl/cli/quantize.py b/src/axolotl/cli/quantize.py index c11bcc6d9..f4fcc6d7d 100644 --- a/src/axolotl/cli/quantize.py +++ b/src/axolotl/cli/quantize.py @@ -8,7 +8,7 @@ from typing import Union from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig from axolotl.cli.config import load_cfg -from axolotl.loaders import load_tokenizer +from axolotl.loaders import load_processor, load_tokenizer from axolotl.utils.logging import get_logger from axolotl.utils.quantization import ( TorchAOQuantDType, @@ -66,6 +66,11 @@ def do_quantize( LOG.info(f"Loading model from {model_path}.") tokenizer = load_tokenizer(cfg) + + processor = None + if cfg.is_multimodal: + processor = load_processor(cfg, tokenizer) + config = AutoConfig.from_pretrained(model_path) torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None model = AutoModelForCausalLM.from_pretrained( @@ -107,6 +112,10 @@ def do_quantize( save_jinja_files=cfg.tokenizer_save_jinja_files, ) + if processor: + LOG.info(f"Saving processor to: {str(Path(output_dir) / 'quantized')}.") + processor.save_pretrained(str(Path(output_dir) / "quantized")) + if hub_model_id: hub_model_id = ( hub_model_id.rstrip("-") @@ -114,6 +123,8 @@ def do_quantize( ) model.push_to_hub(hub_model_id, safe_serialization=False) tokenizer.push_to_hub(hub_model_id) + if processor: + processor.push_to_hub(hub_model_id) LOG.info(f"Quantized model pushed to: {hub_model_id}.") LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")