Merge pull request #199 from NanoCode012/chore/prompter-arg

chore: Refactor inf_kwargs out
This commit is contained in:
NanoCode012
2023-06-13 17:56:22 +09:00
committed by GitHub

View File

@@ -63,7 +63,7 @@ def get_multi_line_input() -> Optional[str]:
return instruction return instruction
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"} default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items(): for token, symbol in default_tokens.items():
@@ -257,13 +257,13 @@ def train(
if cfg.inference: if cfg.inference:
logging.info("calling do_inference function") logging.info("calling do_inference function")
inf_kwargs: Dict[str, Any] = {} prompter: Optional[str] = "AlpacaPrompter"
if "prompter" in kwargs: if "prompter" in kwargs:
if kwargs["prompter"] == "None": if kwargs["prompter"] == "None":
inf_kwargs["prompter"] = None prompter = None
else: else:
inf_kwargs["prompter"] = kwargs["prompter"] prompter = kwargs["prompter"]
do_inference(cfg, model, tokenizer, **inf_kwargs) do_inference(cfg, model, tokenizer, prompter=prompter)
return return
if "shard" in kwargs: if "shard" in kwargs: