fix: inference not using chat_template (#2019) [skip ci]

This commit is contained in:
NanoCode012
2024-11-13 22:06:41 +07:00
committed by GitHub
parent a4b1cc6df0
commit 8c480b2804

View File

@@ -190,18 +190,15 @@ def do_inference(
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -211,13 +208,31 @@ def do_inference(
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
@@ -257,13 +272,6 @@ def do_inference_gradio(
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
# default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
default_tokens: Dict[str, str] = {}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
chat_template_str = None