diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 9cc4003c2..7f79ef192 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -97,7 +97,7 @@ class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingSt
         return (
             prompt["question"],
             "\n".join(f'- "{choice}"' for choice in prompt["choices"]),
-            prompt["explanation"],
+            prompt["solution"] if "solution" in prompt else prompt["explanation"],
         )
 
 
@@ -119,6 +119,15 @@ class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
         )
 
 
+class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["article"],
+            "",
+            prompt["summary"],
+        )
+
+
 class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
     def parse_instruction_fields(self, prompt) -> (str, str, str):
         return (
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index e4411b2c4..8a8cfa247 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -39,6 +39,14 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
     prompt_input = "Choose the answer that best answers the question. Explain your reasoning.\n\n### Question:\n{instruction}\n\n### Choices:\n{input}\n\n### Response:\n"
 
 
+class MultipleChoiceConcisePrompter(AlpacaPrompter):
+    prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
+
+
+class SummarizeTLDRPrompter(AlpacaPrompter):
+    prompt_no_input = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
+
+
 class CompletionPrompter(AlpacaPrompter):
     def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]:
         yield instruction
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 2d2e7f2fd..28b6ee072 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -19,7 +19,9 @@ from axolotl.prompt_tokenizers import (
     AlpacaReflectionPTStrategy,
     ShareGPTPromptTokenizingStrategy,
     JeopardyPromptTokenizingStrategy,
-    CompletionPromptTokenizingStrategy, AlpacaMultipleChoicePromptTokenizingStrategy,
+    CompletionPromptTokenizingStrategy,
+    AlpacaMultipleChoicePromptTokenizingStrategy,
+    SummarizeTLDRPromptTokenizingStrategy,
 )
 from axolotl.prompters import (
     AlpacaPrompter,
@@ -27,7 +29,9 @@
     ReflectAlpacaPrompter,
     ShareGPTPrompter,
     JeopardyPrompter,
-    CompletionPrompter, MultipleChoiceExplainPrompter,
+    CompletionPrompter,
+    MultipleChoiceExplainPrompter,
+    SummarizeTLDRPrompter, MultipleChoiceConcisePrompter,
 )
 
 
@@ -94,6 +98,18 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
+        elif d.type == "concisechoice":
+            ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
+                MultipleChoiceConcisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+            )
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            datasets.append(ds_wrapper)
+        elif d.type == "summarizetldr":
+            ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
+                SummarizeTLDRPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+            )
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            datasets.append(ds_wrapper)
         elif d.type == "jeopardy":
             ds_strategy = JeopardyPromptTokenizingStrategy(
                 JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len