diff --git a/src/axolotl/cli/quantize.py b/src/axolotl/cli/quantize.py
index c11bcc6d9..f4fcc6d7d 100644
--- a/src/axolotl/cli/quantize.py
+++ b/src/axolotl/cli/quantize.py
@@ -8,7 +8,7 @@ from typing import Union
 from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
 
 from axolotl.cli.config import load_cfg
-from axolotl.loaders import load_tokenizer
+from axolotl.loaders import load_processor, load_tokenizer
 from axolotl.utils.logging import get_logger
 from axolotl.utils.quantization import (
     TorchAOQuantDType,
@@ -66,6 +66,11 @@ def do_quantize(
     LOG.info(f"Loading model from {model_path}.")
 
     tokenizer = load_tokenizer(cfg)
+
+    processor = None
+    if cfg.is_multimodal:
+        processor = load_processor(cfg, tokenizer)
+
     config = AutoConfig.from_pretrained(model_path)
     torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
     model = AutoModelForCausalLM.from_pretrained(
@@ -107,6 +112,10 @@ def do_quantize(
         save_jinja_files=cfg.tokenizer_save_jinja_files,
     )
 
+    if processor:
+        LOG.info(f"Saving processor to: {str(Path(output_dir) / 'quantized')}.")
+        processor.save_pretrained(str(Path(output_dir) / "quantized"))
+
     if hub_model_id:
         hub_model_id = (
             hub_model_id.rstrip("-")
@@ -114,6 +123,8 @@ def do_quantize(
         )
         model.push_to_hub(hub_model_id, safe_serialization=False)
         tokenizer.push_to_hub(hub_model_id)
+        if processor:
+            processor.push_to_hub(hub_model_id)
         LOG.info(f"Quantized model pushed to: {hub_model_id}.")
 
     LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")