concise multiple choice and tldr summarize

This commit is contained in:
Wing Lian
2023-05-17 11:29:17 -04:00
parent 8c2f3cb0f8
commit 13650732f8
3 changed files with 36 additions and 3 deletions

View File

@@ -97,7 +97,7 @@ class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingSt
return ( return (
prompt["question"], prompt["question"],
"\n".join(f'- "{choice}"' for choice in prompt["choices"]), "\n".join(f'- "{choice}"' for choice in prompt["choices"]),
prompt["explanation"], prompt["solution"] if "solution" in prompt else prompt["explanation"],
) )
@@ -119,6 +119,15 @@ class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
) )
class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
return (
prompt["article"],
"",
prompt["summary"],
)
class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str): def parse_instruction_fields(self, prompt) -> (str, str, str):
return ( return (

View File

@@ -39,6 +39,14 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
prompt_input = "Choose the answer that best answers the question. Explain your reasoning.\n\n### Question:\n{instruction}\n\n### Choices:\n{input}\n\n### Response:\n" prompt_input = "Choose the answer that best answers the question. Explain your reasoning.\n\n### Question:\n{instruction}\n\n### Choices:\n{input}\n\n### Response:\n"
class MultipleChoiceConcisePrompter(AlpacaPrompter):
prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
class SummarizeTLDRPrompter(AlpacaPrompter):
prompt_no_input = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
class CompletionPrompter(AlpacaPrompter): class CompletionPrompter(AlpacaPrompter):
def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]: def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]:
yield instruction yield instruction

View File

@@ -19,7 +19,9 @@ from axolotl.prompt_tokenizers import (
AlpacaReflectionPTStrategy, AlpacaReflectionPTStrategy,
ShareGPTPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy,
CompletionPromptTokenizingStrategy, AlpacaMultipleChoicePromptTokenizingStrategy, CompletionPromptTokenizingStrategy,
AlpacaMultipleChoicePromptTokenizingStrategy,
SummarizeTLDRPromptTokenizingStrategy,
) )
from axolotl.prompters import ( from axolotl.prompters import (
AlpacaPrompter, AlpacaPrompter,
@@ -27,7 +29,9 @@ from axolotl.prompters import (
ReflectAlpacaPrompter, ReflectAlpacaPrompter,
ShareGPTPrompter, ShareGPTPrompter,
JeopardyPrompter, JeopardyPrompter,
CompletionPrompter, MultipleChoiceExplainPrompter, CompletionPrompter,
MultipleChoiceExplainPrompter,
SummarizeTLDRPrompter, MultipleChoiceConcisePrompter,
) )
@@ -94,6 +98,18 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
elif d.type == "concisechoice":
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
MultipleChoiceConcisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper)
elif d.type == "summarizetldr":
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
SummarizeTLDRPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
datasets.append(ds_wrapper)
elif d.type == "jeopardy": elif d.type == "jeopardy":
ds_strategy = JeopardyPromptTokenizingStrategy( ds_strategy = JeopardyPromptTokenizingStrategy(
JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len