From c4e4f8115c9deaf4f831afbfbeaf0d0fe2eac7b8 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 10 Jun 2023 15:07:40 -0400
Subject: [PATCH] pass a prompt in from stdin for inference

---
 README.md           |  5 +++++
 scripts/finetune.py | 23 ++++++++++++++++++-----
 2 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index de929f237..180d97932 100644
--- a/README.md
+++ b/README.md
@@ -495,6 +495,11 @@ Pass the appropriate flag to the train command:
   ```bash
   --inference --base_model ./completed-model
   ```
+- Full weights finetune w/ a prompt from a text file:
+  ```bash
+  cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
+    --base_model ./completed-model --inference --prompter=None --load_in_8bit=True
+  ```

 ### Merge LORA to base

diff --git a/scripts/finetune.py b/scripts/finetune.py
index fa2dcf903..8a458890c 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -71,7 +71,11 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         if not (cfg.special_tokens and token in cfg.special_tokens):
             tokenizer.add_special_tokens({token: symbol})

-    prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
+    prompter_module = None
+    if prompter:
+        prompter_module = getattr(
+            importlib.import_module("axolotl.prompters"), prompter
+        )

     while True:
         print("=" * 80)
@@ -79,9 +83,12 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         instruction = get_multi_line_input()
         if not instruction:
             return
-        prompt: str = next(
-            prompter_module().build_prompt(instruction=instruction.strip("\n"))
-        )
+        if prompter_module:
+            prompt: str = next(
+                prompter_module().build_prompt(instruction=instruction.strip("\n"))
+            )
+        else:
+            prompt = instruction.strip()
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
         print("=" * 40)
         model.eval()
@@ -242,7 +249,13 @@ def train(

     if "inference" in kwargs:
         logging.info("calling do_inference function")
-        do_inference(cfg, model, tokenizer)
+        inf_kwargs: Dict[str, Any] = {}
+        if "prompter" in kwargs:
+            if kwargs["prompter"] == "None":
+                inf_kwargs["prompter"] = None
+            else:
+                inf_kwargs["prompter"] = kwargs["prompter"]
+        do_inference(cfg, model, tokenizer, **inf_kwargs)
         return

     if "shard" in kwargs: