Add an "explainchoice" dataset type (multiple-choice question with explanation).

* scripts/finetune.py: do_inference now takes the first item yielded by
  build_prompt() via next() before passing the prompt to the tokenizer.
* src/axolotl/prompt_tokenizers.py: add AlpacaMultipleChoicePromptTokenizingStrategy,
  whose parse_instruction_fields returns prompt["question"], the entries of
  prompt["choices"] rendered one per line as `- "choice"`, and prompt["explanation"].
* src/axolotl/prompters.py: add MultipleChoiceExplainPrompter, an AlpacaPrompter
  subclass that overrides prompt_input only.
* src/axolotl/utils/data.py: import both new classes and wire them up for
  d.type == "explainchoice".

NOTE(review): line breaks, blank context lines and leading indentation were
reconstructed from the hunk line counts and Python syntax; confirm against the
target tree, or apply with `git apply --ignore-whitespace`.
NOTE(review): the two data.py import hunks place two names on one line while the
surrounding import lists use one name per line; consider splitting them.

diff --git a/scripts/finetune.py b/scripts/finetune.py
index 5b7f8f2ab..5fb38b6f6 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -67,7 +67,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         instruction = get_multi_line_input()
         if not instruction:
             return
-        prompt = prompter_module().build_prompt(instruction=instruction)
+        prompt: str = next(prompter_module().build_prompt(instruction=instruction))
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
         model.eval()
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 5792d191b..9cc4003c2 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -92,6 +92,15 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
         )
 
 
+class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["question"],
+            "\n".join(f'- "{choice}"' for choice in prompt["choices"]),
+            prompt["explanation"],
+        )
+
+
 class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     def parse_instruction_fields(self, prompt) -> (str, str, str):
         return (
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index a52ed4ad9..e4411b2c4 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -35,6 +35,10 @@ class JeopardyPrompter(AlpacaPrompter):
     prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
 
 
+class MultipleChoiceExplainPrompter(AlpacaPrompter):
+    prompt_input = "Choose the answer that best answers the question. Explain your reasoning.\n\n### Question:\n{instruction}\n\n### Choices:\n{input}\n\n### Response:\n"
+
+
 class CompletionPrompter(AlpacaPrompter):
     def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]:
         yield instruction
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 306213b19..2d2e7f2fd 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -19,7 +19,7 @@ from axolotl.prompt_tokenizers import (
     AlpacaReflectionPTStrategy,
     ShareGPTPromptTokenizingStrategy,
     JeopardyPromptTokenizingStrategy,
-    CompletionPromptTokenizingStrategy,
+    CompletionPromptTokenizingStrategy, AlpacaMultipleChoicePromptTokenizingStrategy,
 )
 from axolotl.prompters import (
     AlpacaPrompter,
@@ -27,7 +27,7 @@ from axolotl.prompters import (
     ReflectAlpacaPrompter,
     ShareGPTPrompter,
     JeopardyPrompter,
-    CompletionPrompter,
+    CompletionPrompter, MultipleChoiceExplainPrompter,
 )
 
 
@@ -88,6 +88,12 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
                 )
                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                 datasets.append(ds_wrapper)
+            elif d.type == "explainchoice":
+                ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
+                    MultipleChoiceExplainPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                )
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                datasets.append(ds_wrapper)
             elif d.type == "jeopardy":
                 ds_strategy = JeopardyPromptTokenizingStrategy(
                     JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len