fix: chat template jinja file not being loaded during inference (#3112)
* fix: chat template jinja file not being loaded during inference * fix: bot comment
This commit is contained in:
@@ -14,10 +14,7 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
|||||||
from axolotl.cli.args import InferenceCliArgs
|
from axolotl.cli.args import InferenceCliArgs
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.utils.chat_templates import (
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
get_chat_template,
|
|
||||||
get_chat_template_from_config,
|
|
||||||
)
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
@@ -64,7 +61,9 @@ def do_inference(
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
elif cfg.chat_template:
|
elif cfg.chat_template:
|
||||||
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
|
chat_template_str = get_chat_template_from_config(
|
||||||
|
cfg, ds_cfg=None, tokenizer=tokenizer
|
||||||
|
)
|
||||||
elif cfg.datasets[0].type == "chat_template":
|
elif cfg.datasets[0].type == "chat_template":
|
||||||
chat_template_str = get_chat_template_from_config(
|
chat_template_str = get_chat_template_from_config(
|
||||||
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
||||||
@@ -159,7 +158,13 @@ def do_inference_gradio(
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
elif cfg.chat_template:
|
elif cfg.chat_template:
|
||||||
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
|
chat_template_str = get_chat_template_from_config(
|
||||||
|
cfg, ds_cfg=None, tokenizer=tokenizer
|
||||||
|
)
|
||||||
|
elif cfg.datasets[0].type == "chat_template":
|
||||||
|
chat_template_str = get_chat_template_from_config(
|
||||||
|
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user