chore: Refactor inf_kwargs out
This commit is contained in:
@@ -63,7 +63,7 @@ def get_multi_line_input() -> Optional[str]:
|
|||||||
return instruction
|
return instruction
|
||||||
|
|
||||||
|
|
||||||
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
|
||||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||||
|
|
||||||
for token, symbol in default_tokens.items():
|
for token, symbol in default_tokens.items():
|
||||||
@@ -257,13 +257,13 @@ def train(
|
|||||||
|
|
||||||
if cfg.inference:
|
if cfg.inference:
|
||||||
logging.info("calling do_inference function")
|
logging.info("calling do_inference function")
|
||||||
inf_kwargs: Dict[str, Any] = {}
|
prompter: Optional[str] = "AlpacaPrompter"
|
||||||
if "prompter" in kwargs:
|
if "prompter" in kwargs:
|
||||||
if kwargs["prompter"] == "None":
|
if kwargs["prompter"] == "None":
|
||||||
inf_kwargs["prompter"] = None
|
prompter = None
|
||||||
else:
|
else:
|
||||||
inf_kwargs["prompter"] = kwargs["prompter"]
|
prompter = kwargs["prompter"]
|
||||||
do_inference(cfg, model, tokenizer, **inf_kwargs)
|
do_inference(cfg, model, tokenizer, prompter=prompter)
|
||||||
return
|
return
|
||||||
|
|
||||||
if "shard" in kwargs:
|
if "shard" in kwargs:
|
||||||
|
|||||||
Reference in New Issue
Block a user