Add streaming inference & fix stopping at EOS

Author: Glavin Wiechert
Date:   2023-06-10 08:14:47 +00:00
Parent: 931e606459
Commit: fec6bcc3e6

@@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
 import fire
 import torch
 import yaml
-from transformers import GenerationConfig
+from transformers import GenerationConfig, TextStreamer
 
 from axolotl.utils.data import load_prepare_datasets
 from axolotl.utils.dict import DictDefault
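
TextStreamer is the transformers helper that decodes and prints each new token as soon as it is sampled, which is what turns the blocking generate() call below into streaming output. A minimal sketch of it in isolation; "gpt2" is an arbitrary small model chosen for illustration, not part of this commit:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# The streamer prints tokens to stdout incrementally instead of making the
# caller wait for generate() to return the finished sequence.
streamer = TextStreamer(tokenizer)
model.generate(**inputs, streamer=streamer, max_new_tokens=20)
```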
@@ -64,13 +64,21 @@ def get_multi_line_input() -> Optional[str]:
 def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
-    tokenizer.add_special_tokens({"unk_token": "<unk>"})
-    tokenizer.add_special_tokens({"bos_token": "<s>"})
-    tokenizer.add_special_tokens({"eos_token": "</s>"})
+    default_tokens = {
+        "unk_token": "<unk>",
+        "bos_token": "<s>",
+        "eos_token": "</s>"
+    }
+
+    for token, symbol in default_tokens.items():
+        # If the token isn't already specified in the config, add it
+        if not (cfg.special_tokens and token in cfg.special_tokens):
+            tokenizer.add_special_tokens({token: symbol})
 
     prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
 
     while True:
         print("=" * 80)
         # support for multiline inputs
         instruction = get_multi_line_input()
         if not instruction:
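
This guard is also what fixes stopping at EOS: the old code unconditionally overwrote the tokenizer's special tokens, so an eos_token supplied in the user's config could be clobbered by the "</s>" default and generate() would run until max_new_tokens instead of halting. A standalone sketch of the fallback logic, with a plain dict standing in for cfg.special_tokens (values hypothetical):

```python
# cfg_special_tokens stands in for the DictDefault cfg.special_tokens that
# axolotl loads from the user's YAML config; the value here is made up.
cfg_special_tokens = {"eos_token": "</s>"}

default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}

for token, symbol in default_tokens.items():
    # The config wins; defaults only fill in tokens the user did not set.
    if not (cfg_special_tokens and token in cfg_special_tokens):
        print(f"falling back to default {token} = {symbol}")
    else:
        print(f"keeping config-supplied {token} = {cfg_special_tokens[token]}")
```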
@@ -79,7 +87,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
             prompter_module().build_prompt(instruction=instruction.strip("\n"))
         )
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
-
+        print("=" * 40)
         model.eval()
         with torch.no_grad():
             generation_config = GenerationConfig(
@@ -98,10 +106,13 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
                 output_hidden_states=False,
                 output_scores=False,
             )
+            streamer = TextStreamer(tokenizer)
             generated = model.generate(
                 inputs=batch["input_ids"].to(cfg.device),
                 generation_config=generation_config,
+                streamer=streamer,
             )
         print("=" * 40)
         print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
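
Taken together, the changed inference path looks roughly like the sketch below, assuming a generic Hugging Face causal LM; the model name, prompt, max_new_tokens, and the pared-down GenerationConfig are illustrative stand-ins, not values from this diff:

```python
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TextStreamer,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

batch = tokenizer("### Instruction:\nSay hello.\n\n### Response:\n", return_tensors="pt")
streamer = TextStreamer(tokenizer)
with torch.no_grad():
    generated = model.generate(
        inputs=batch["input_ids"],
        generation_config=GenerationConfig(
            max_new_tokens=64,  # illustrative cap
            eos_token_id=tokenizer.eos_token_id,  # the id generation halts on
            return_dict_in_generate=True,  # makes generated["sequences"] valid
        ),
        streamer=streamer,  # tokens are printed incrementally as they arrive
    )
print("=" * 40)
# The full decoded sequence is still printed once generation finishes.
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
```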