Merge pull request #119 from NanoCode012/feat/update-inference

Feat(inference): Swap to GenerationConfig
This commit is contained in:
NanoCode012
2023-05-31 14:09:18 +09:00
committed by GitHub

View File

@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Union
 import fire
 import torch
 import yaml
+from transformers import GenerationConfig
 from axolotl.utils.data import load_prepare_datasets
 from axolotl.utils.dict import DictDefault
@@ -73,26 +74,33 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
 instruction = get_multi_line_input()
 if not instruction:
 return
-prompt: str = next(prompter_module().build_prompt(instruction=instruction))
+prompt: str = next(
+prompter_module().build_prompt(instruction=instruction.strip("\n"))
+)
 batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 model.eval()
 with torch.no_grad():
-# gc = GenerationConfig() # TODO swap out and use this
+generation_config = GenerationConfig(
-generated = model.generate(
-inputs=batch["input_ids"].to(cfg.device),
-do_sample=True,
-use_cache=True,
 repetition_penalty=1.1,
-max_new_tokens=100,
+max_new_tokens=1024,
 temperature=0.9,
 top_p=0.95,
 top_k=40,
+bos_token_id=tokenizer.bos_token_id,
+eos_token_id=tokenizer.eos_token_id,
+pad_token_id=tokenizer.pad_token_id,
+do_sample=True,
+use_cache=True,
 return_dict_in_generate=True,
 output_attentions=False,
 output_hidden_states=False,
 output_scores=False,
 )
+generated = model.generate(
+inputs=batch["input_ids"].to(cfg.device),
+generation_config=generation_config,
+)
 print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))