refactor inference, warn if model is frozen

This commit is contained in:
Wing Lian
2023-05-07 01:53:30 -04:00
parent cb9a887047
commit 247825bd57
3 changed files with 20 additions and 4 deletions

View File

@@ -6,9 +6,11 @@ import random
import signal
import sys
from pathlib import Path
from typing import Optional
import fire
import torch
import transformers
import yaml
from attrdict import AttrDefault
@@ -46,6 +48,15 @@ def choose_device(cfg):
cfg.device_map = {"": cfg.device}
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + Z to finish): ")
instruction = ""
for line in sys.stdin:
instruction += line
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
tokenizer.add_special_tokens({"unk_token": "<unk>"})
tokenizer.add_special_tokens({"bos_token": "<s>"})
@@ -55,8 +66,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
while True:
# support for multiline inputs
print("Give me an instruction (Ctrl + D to finish): ")
instruction = pathlib.Path("/proc/self/fd/0").read_text()
instruction = get_multi_line_input()
if not instruction:
return
prompt = prompter_module().build_prompt(instruction=instruction)
@@ -66,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
with torch.no_grad():
# gc = GenerationConfig() # TODO swap out and use this
generated = model.generate(
inputs=batch["input_ids"].to("cuda"),
inputs=batch["input_ids"].to(cfg.device),
do_sample=True,
use_cache=True,
repetition_penalty=1.1,