Save processor in quantizer CLI (#3290)
This commit is contained in:
@@ -8,7 +8,7 @@ from typing import Union
|
|||||||
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
|
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
|
||||||
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.loaders import load_tokenizer
|
from axolotl.loaders import load_processor, load_tokenizer
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.quantization import (
|
from axolotl.utils.quantization import (
|
||||||
TorchAOQuantDType,
|
TorchAOQuantDType,
|
||||||
@@ -66,6 +66,11 @@ def do_quantize(
|
|||||||
|
|
||||||
LOG.info(f"Loading model from {model_path}.")
|
LOG.info(f"Loading model from {model_path}.")
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
|
||||||
|
processor = None
|
||||||
|
if cfg.is_multimodal:
|
||||||
|
processor = load_processor(cfg, tokenizer)
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(model_path)
|
config = AutoConfig.from_pretrained(model_path)
|
||||||
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
|
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
@@ -107,6 +112,10 @@ def do_quantize(
|
|||||||
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if processor:
|
||||||
|
LOG.info(f"Saving processor to: {str(Path(output_dir) / 'quantized')}.")
|
||||||
|
processor.save_pretrained(str(Path(output_dir) / "quantized"))
|
||||||
|
|
||||||
if hub_model_id:
|
if hub_model_id:
|
||||||
hub_model_id = (
|
hub_model_id = (
|
||||||
hub_model_id.rstrip("-")
|
hub_model_id.rstrip("-")
|
||||||
@@ -114,6 +123,8 @@ def do_quantize(
|
|||||||
)
|
)
|
||||||
model.push_to_hub(hub_model_id, safe_serialization=False)
|
model.push_to_hub(hub_model_id, safe_serialization=False)
|
||||||
tokenizer.push_to_hub(hub_model_id)
|
tokenizer.push_to_hub(hub_model_id)
|
||||||
|
if processor:
|
||||||
|
processor.push_to_hub(hub_model_id)
|
||||||
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
|
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
|
||||||
|
|
||||||
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")
|
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")
|
||||||
|
|||||||
Reference in New Issue
Block a user