From 8c480b28043584c46ed4d9d574ce63ee35e29dea Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 13 Nov 2024 22:06:41 +0700 Subject: [PATCH] fix: inference not using chat_template (#2019) [skip ci] --- src/axolotl/cli/__init__.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 84586ccc3..589b6b575 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -190,18 +190,15 @@ def do_inference( ): model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) prompter = cli_args.prompter - default_tokens = {"unk_token": "", "bos_token": "", "eos_token": ""} - - for token, symbol in default_tokens.items(): - # If the token isn't already specified in the config, add it - if not (cfg.special_tokens and token in cfg.special_tokens): - tokenizer.add_special_tokens({token: symbol}) prompter_module = None + chat_template_str = None if prompter: prompter_module = getattr( importlib.import_module("axolotl.prompters"), prompter ) + elif cfg.chat_template: + chat_template_str = get_chat_template(cfg.chat_template) model = model.to(cfg.device, dtype=cfg.torch_dtype) @@ -211,13 +208,31 @@ def do_inference( instruction = get_multi_line_input() if not instruction: return + if prompter_module: prompt: str = next( prompter_module().build_prompt(instruction=instruction.strip("\n")) ) else: prompt = instruction.strip() - batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) + + if chat_template_str: + batch = tokenizer.apply_chat_template( + [ + { + "role": "user", + "content": prompt, + } + ], + return_tensors="pt", + add_special_tokens=True, + add_generation_prompt=True, + chat_template=chat_template_str, + tokenize=True, + return_dict=True, + ) + else: + batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) print("=" * 40) model.eval() @@ -257,13 +272,6 @@ def do_inference_gradio( model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) prompter = cli_args.prompter - # default_tokens = {"unk_token": "", "bos_token": "", "eos_token": ""} - default_tokens: Dict[str, str] = {} - - for token, symbol in default_tokens.items(): - # If the token isn't already specified in the config, add it - if not (cfg.special_tokens and token in cfg.special_tokens): - tokenizer.add_special_tokens({token: symbol}) prompter_module = None chat_template_str = None