From 8d20e0a3d3f44721bb3e45f4a6d51577dd7099bc Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 17 Jun 2023 19:22:58 -0400 Subject: [PATCH] initial wip to get sys prompt from dataset --- src/axolotl/prompt_strategies/alpaca_chat.py | 6 +- src/axolotl/prompt_tokenizers.py | 4 +- src/axolotl/prompters.py | 87 ++++++++++++-------- tests/test_prompters.py | 69 +++++++++++++++- 4 files changed, 126 insertions(+), 40 deletions(-) diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py index 6161d7e37..32801c3c3 100644 --- a/src/axolotl/prompt_strategies/alpaca_chat.py +++ b/src/axolotl/prompt_strategies/alpaca_chat.py @@ -45,8 +45,10 @@ class NoSystemPrompter(AlpacaPrompter): Null Prompter with no system prompts """ - prompt_input = "{instruction} {input} " - prompt_no_input = "{instruction} " + system_prompt = "" + system_no_input_prompt = "" + turn_format = "{instruction} {input} " + turn_no_input_format = "{instruction} " def __init__(self): # pylint: disable=super-init-not-called pass diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 6408620d7..cf80539eb 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -87,7 +87,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): Tokenizing strategy for instruction-based prompts. """ - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: + def parse_instruction_fields( + self, prompt + ) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]: raise NotImplementedError def tokenize_prompt(self, prompt): diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 29cc4446b..4db915238 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -24,6 +24,8 @@ class AlpacaPrompter: system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" + turn_format: str + turn_no_input_format: str prompt_style: Optional[PromptStyle] = None def __init__(self, prompt_style=PromptStyle.INSTRUCT.value): @@ -32,23 +34,13 @@ class AlpacaPrompter: def match_prompt_style(self): if self.prompt_style == PromptStyle.INSTRUCT.value: - self.prompt_input = ( - self.system_prompt - + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" + self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" + self.turn_no_input_format = ( + "### Instruction:\n{instruction}\n\n### Response:\n" ) - self.prompt_no_input = ( - self.system_no_input_prompt - + "### Instruction:\n{instruction}\n\n### Response:\n" - ) - self.response_split = "### Response:" if self.prompt_style == PromptStyle.CHAT.value: - self.prompt_input = ( - self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:" - ) - self.prompt_no_input = ( - self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:" - ) - self.response_split = "ASSISTANT:" + self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" + self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" def build_prompt( self, @@ -59,15 +51,39 @@ class AlpacaPrompter: # returns the full prompt from instruction and optional input # if a label (=response, =output) is provided, it's also appended. if input: - res = self.prompt_input.format(instruction=instruction, input=input) + res = self.system_prompt + self.turn_format.format( + instruction=instruction, input=input + ) else: - res = self.prompt_no_input.format(instruction=instruction) + res = self.system_no_input_prompt + self.turn_no_input_format.format( + instruction=instruction + ) if output: res = f"{res}{output}" yield res - def get_response(self, output: str) -> str: - return output.split(self.response_split)[1].strip() + +class SystemDataPrompter(AlpacaPrompter): + """ + Alpaca Style Prompter that uses system prompts from the dataset + """ + + def build_prompt_w_system( + self, + system: str, + instruction: str, + input: Union[None, str] = None, # pylint: disable=redefined-builtin + output: Union[None, str] = None, + ) -> Generator[str, None, None]: + # returns the full prompt from instruction and optional input + # if a label (=response, =output) is provided, it's also appended. + if input: + res = system + self.turn_format.format(instruction=instruction, input=input) + else: + res = system + self.turn_no_input_format.format(instruction=instruction) + if output: + res = f"{res}{output}" + yield res class UnpromptedPrompter(AlpacaPrompter): @@ -93,7 +109,10 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter): """ system_prompt = ( - "Choose the answer that best answers the question. Explain your reasoning." + "Choose the answer that best answers the question. Explain your reasoning.\n" + ) + system_no_input_prompt = ( + "Choose the answer that best answers the question. Explain your reasoning.\n" ) @@ -102,7 +121,12 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter): Prompter for multiple choice concise """ - prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n" + system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n" + system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n" + + def match_prompt_style(self): + self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" + self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" class SummarizeTLDRPrompter(AlpacaPrompter): @@ -110,9 +134,12 @@ class SummarizeTLDRPrompter(AlpacaPrompter): Prompter for summarize TLDR """ - prompt_no_input = ( - "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:" - ) + system_prompt = "" + system_no_input_prompt = "" + + def match_prompt_style(self): + self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:" + self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:" class CompletionPrompter: @@ -128,9 +155,6 @@ class CompletionPrompter: ) -> Generator[str, None, None]: yield instruction - def get_response(self, output: str) -> str: - return output.strip() - class GPTeacherPrompter(AlpacaPrompter): """ @@ -210,9 +234,6 @@ class ReflectAlpacaPrompter: res = f"{res}{label}" yield res - def get_response(self, output: str) -> str: - return output.split(self.response_split)[1].strip() - class SeparatorStyle(Enum): """Different separator style.""" @@ -289,12 +310,6 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods sep2=" ", ) - # def match_prompt_style(self): - # if self.prompt_style == PromptStyle.chat.value: - # self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:" - # self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:" - # self.response_split = "ASSISTANT:" - def build_prompt(self, source) -> Generator[str, None, None]: # ignore the system prompt if provided if source[0]["from"] == "system": diff --git a/tests/test_prompters.py b/tests/test_prompters.py index 11610ccc5..bb33afbb6 100644 --- a/tests/test_prompters.py +++ b/tests/test_prompters.py @@ -2,7 +2,13 @@ import unittest -from axolotl.prompters import AlpacaPrompter, PromptStyle +from axolotl.prompters import ( + AlpacaPrompter, + MultipleChoiceExplainPrompter, + PromptStyle, + SystemDataPrompter, + UnpromptedPrompter, +) class AlpacaPrompterTest(unittest.TestCase): @@ -55,3 +61,64 @@ class AlpacaPrompterTest(unittest.TestCase): assert "### Response:" not in res assert "USER:" in res assert "ASSISTANT:" in res + + def test_system_prompt(self): + prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value) + res = next( + prompter.build_prompt_w_system( + "use cot", "tell me a joke about the following", "alpacas" + ) + ) + assert "use cot" in res + assert res.startswith("use cot") + assert "### Instruction:" not in res + assert "### Input:" not in res + assert "alpacas" in res + assert "### Response:" not in res + assert "USER:" in res + assert "ASSISTANT:" in res + + +class UnpromptedPrompterTest(unittest.TestCase): + """ + Test class for UnpromptedPrompter with no system prompts + """ + + def test_prompt_style_w_none(self): + prompter = UnpromptedPrompter(prompt_style=None) + res = next(prompter.build_prompt("tell me a joke")) + assert "### Instruction:" in res + assert "tell me a joke" in res + assert res.startswith("###") + + def test_prompt_style_w_instruct(self): + prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value) + res = next( + prompter.build_prompt("tell me a joke about the following", "alpacas") + ) + assert "### Instruction:" in res + assert "tell me a joke" in res + assert res.startswith("###") + + def test_prompt_style_w_chat(self): + prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value) + res = next( + prompter.build_prompt("tell me a joke about the following", "alpacas") + ) + assert "USER:" in res + assert "tell me a joke" in res + assert res.startswith("USER:") + + +class MultipleChoiceExplainPrompterTest(unittest.TestCase): + """ + Test class for MultipleChoiceExplainPrompter + """ + + def test_prompt_style_w_chat(self): + prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value) + res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C")) + assert "USER:" in res + assert "choose one" in res + assert "Choose the answer that best answers the question." in res + assert "- A\n- B\n- C" in res