improve inference
This commit is contained in:
@@ -79,31 +79,31 @@ def do_inference(cfg, model, tokenizer):
|
||||
|
||||
from axolotl.prompters import ReflectAlpacaPrompter
|
||||
|
||||
instruction = str(input("Give me an instruction: "))
|
||||
instruction = (
|
||||
instruction if not instruction else "Tell me a joke about dromedaries."
|
||||
)
|
||||
prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
while True:
|
||||
instruction = str(input("Give me an instruction: "))
|
||||
if not instruction:
|
||||
return
|
||||
prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# gc = GenerationConfig() # TODO swap out and use this
|
||||
generated = model.generate(
|
||||
inputs=batch["input_ids"].to("cuda"),
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=100,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# gc = GenerationConfig() # TODO swap out and use this
|
||||
generated = model.generate(
|
||||
inputs=batch["input_ids"].to("cuda"),
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=100,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
def choose_config(path: Path):
|
||||
|
||||
Reference in New Issue
Block a user