Feat: Swap to GenerationConfig

This commit is contained in:
NanoCode012
2023-05-31 10:48:19 +09:00
parent 0abcd71a85
commit 988aeb9c34

View File

@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Union
import fire
import torch
import yaml
from transformers import GenerationConfig
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
@@ -73,26 +74,33 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
instruction = get_multi_line_input()
if not instruction:
return
prompt: str = next(prompter_module().build_prompt(instruction=instruction))
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
# gc = GenerationConfig() # TODO swap out and use this
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
do_sample=True,
use_cache=True,
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=100,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))