Add streaming inference & fix stopping at EOS

This commit is contained in:
Glavin Wiechert
2023-06-10 08:14:47 +00:00
parent 931e606459
commit fec6bcc3e6

View File

@@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
import fire import fire
import torch import torch
import yaml import yaml
from transformers import GenerationConfig from transformers import GenerationConfig, TextStreamer
from axolotl.utils.data import load_prepare_datasets from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -64,13 +64,21 @@ def get_multi_line_input() -> Optional[str]:
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
tokenizer.add_special_tokens({"unk_token": "<unk>"}) default_tokens = {
tokenizer.add_special_tokens({"bos_token": "<s>"}) "unk_token": "<unk>",
tokenizer.add_special_tokens({"eos_token": "</s>"}) "bos_token": "<s>",
"eos_token": "</s>"
}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter) prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
while True: while True:
print("=" * 80)
# support for multiline inputs # support for multiline inputs
instruction = get_multi_line_input() instruction = get_multi_line_input()
if not instruction: if not instruction:
@@ -79,7 +87,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
prompter_module().build_prompt(instruction=instruction.strip("\n")) prompter_module().build_prompt(instruction=instruction.strip("\n"))
) )
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
generation_config = GenerationConfig( generation_config = GenerationConfig(
@@ -98,10 +106,13 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
output_hidden_states=False, output_hidden_states=False,
output_scores=False, output_scores=False,
) )
streamer = TextStreamer(tokenizer)
generated = model.generate( generated = model.generate(
inputs=batch["input_ids"].to(cfg.device), inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config, generation_config=generation_config,
streamer=streamer,
) )
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))