improve inference

Wing Lian
2023-04-19 12:57:27 -04:00
parent 5749eb0a1c
commit d65385912e
2 changed files with 42 additions and 39 deletions


@@ -79,31 +79,31 @@ def do_inference(cfg, model, tokenizer):
     from axolotl.prompters import ReflectAlpacaPrompter
-    instruction = str(input("Give me an instruction: "))
-    instruction = (
-        instruction if not instruction else "Tell me a joke about dromedaries."
-    )
-    prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
-    batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+    while True:
+        instruction = str(input("Give me an instruction: "))
+        if not instruction:
+            return
+        prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
+        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
-    model.eval()
-    with torch.no_grad():
-        # gc = GenerationConfig() # TODO swap out and use this
-        generated = model.generate(
-            inputs=batch["input_ids"].to("cuda"),
-            do_sample=True,
-            use_cache=True,
-            repetition_penalty=1.1,
-            max_new_tokens=100,
-            temperature=0.9,
-            top_p=0.95,
-            top_k=40,
-            return_dict_in_generate=True,
-            output_attentions=False,
-            output_hidden_states=False,
-            output_scores=False,
-        )
-        print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
+        model.eval()
+        with torch.no_grad():
+            # gc = GenerationConfig() # TODO swap out and use this
+            generated = model.generate(
+                inputs=batch["input_ids"].to("cuda"),
+                do_sample=True,
+                use_cache=True,
+                repetition_penalty=1.1,
+                max_new_tokens=100,
+                temperature=0.9,
+                top_p=0.95,
+                top_k=40,
+                return_dict_in_generate=True,
+                output_attentions=False,
+                output_hidden_states=False,
+                output_scores=False,
+            )
+            print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
 def choose_config(path: Path):
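
The commented-out `# gc = GenerationConfig()` TODO hints at collecting the hard-coded sampling arguments into a reusable config object instead of passing them to every `generate()` call. A minimal sketch of that idea, assuming the `transformers` `GenerationConfig` API and reusing the `model`, `tokenizer`, and `batch` names from `do_inference` above; this is illustrative only and not part of this commit:

    import torch
    from transformers import GenerationConfig

    # Illustrative follow-up to the TODO: gather the sampling settings once
    # so they can be reused (or saved alongside the model) rather than
    # re-specified inline on each generate() call.
    gen_cfg = GenerationConfig(
        do_sample=True,
        use_cache=True,
        repetition_penalty=1.1,
        max_new_tokens=100,
        temperature=0.9,
        top_p=0.95,
        top_k=40,
        return_dict_in_generate=True,
    )

    model.eval()
    with torch.no_grad():
        # generation_config replaces the long list of keyword arguments
        generated = model.generate(
            inputs=batch["input_ids"].to("cuda"),
            generation_config=gen_cfg,
        )
    print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))

Keeping the knobs in one `GenerationConfig` would also make it easier to load per-model defaults or expose the sampling settings through the axolotl config file later.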