add alpaca multiple choice instruct dataset support

This commit is contained in:
Wing Lian
2023-05-16 21:45:34 -04:00
parent f98e173b59
commit b46bc02f0a
4 changed files with 22 additions and 3 deletions

View File

@@ -67,7 +67,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
instruction = get_multi_line_input() instruction = get_multi_line_input()
if not instruction: if not instruction:
return return
prompt = prompter_module().build_prompt(instruction=instruction) prompt: str = next(prompter_module().build_prompt(instruction=instruction))
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval() model.eval()

View File

@@ -92,6 +92,15 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
) )
class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """Tokenizing strategy for multiple-choice rows that carry an explanation."""

    def parse_instruction_fields(self, prompt) -> (str, str, str):
        """Pull the three prompt fields out of one dataset row.

        Returns a 3-tuple of:
          * the row's "question" value, unchanged;
          * the row's "choices" rendered one per line, each as `- "<choice>"`;
          * the row's "explanation" value, unchanged.

        A missing key raises ``KeyError`` from the row lookup.
        """
        question = prompt["question"]
        # Each choice becomes a quoted bullet; bullets are newline-separated.
        bullets = [f'- "{choice}"' for choice in prompt["choices"]]
        explanation = prompt["explanation"]
        return question, "\n".join(bullets), explanation
class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str): def parse_instruction_fields(self, prompt) -> (str, str, str):
return ( return (

View File

@@ -35,6 +35,10 @@ class JeopardyPrompter(AlpacaPrompter):
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
class MultipleChoiceExplainPrompter(AlpacaPrompter):
    """Alpaca-style prompter for "pick an answer and explain it" prompts.

    Only ``prompt_input`` is overridden; everything else comes from
    ``AlpacaPrompter``.
    """

    # The template has two placeholders: {instruction} goes under the
    # "### Question:" heading and {input} under the "### Choices:" heading.
    # Adjacent literals are concatenated into one string at compile time.
    prompt_input = (
        "Choose the answer that best answers the question. Explain your reasoning.\n\n"
        "### Question:\n{instruction}\n\n"
        "### Choices:\n{input}\n\n"
        "### Response:\n"
    )
class CompletionPrompter(AlpacaPrompter): class CompletionPrompter(AlpacaPrompter):
def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]: def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]:
yield instruction yield instruction

View File

@@ -19,7 +19,7 @@ from axolotl.prompt_tokenizers import (
AlpacaReflectionPTStrategy, AlpacaReflectionPTStrategy,
ShareGPTPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy,
CompletionPromptTokenizingStrategy, CompletionPromptTokenizingStrategy, AlpacaMultipleChoicePromptTokenizingStrategy,
) )
from axolotl.prompters import ( from axolotl.prompters import (
AlpacaPrompter, AlpacaPrompter,
@@ -27,7 +27,7 @@ from axolotl.prompters import (
ReflectAlpacaPrompter, ReflectAlpacaPrompter,
ShareGPTPrompter, ShareGPTPrompter,
JeopardyPrompter, JeopardyPrompter,
CompletionPrompter, CompletionPrompter, MultipleChoiceExplainPrompter,
) )
@@ -88,6 +88,12 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
elif d.type == "explainchoice":
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
MultipleChoiceExplainPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper)
elif d.type == "jeopardy": elif d.type == "jeopardy":
ds_strategy = JeopardyPromptTokenizingStrategy( ds_strategy = JeopardyPromptTokenizingStrategy(
JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len