initial wip to get sys prompt from dataset

This commit is contained in:
Wing Lian
2023-06-17 19:22:58 -04:00
parent de8ed229c3
commit 8d20e0a3d3
4 changed files with 126 additions and 40 deletions

View File

@@ -45,8 +45,10 @@ class NoSystemPrompter(AlpacaPrompter):
Null Prompter with no system prompts Null Prompter with no system prompts
""" """
prompt_input = "{instruction} {input} " system_prompt = ""
prompt_no_input = "{instruction} " system_no_input_prompt = ""
turn_format = "{instruction} {input} "
turn_no_input_format = "{instruction} "
def __init__(self): # pylint: disable=super-init-not-called def __init__(self): # pylint: disable=super-init-not-called
pass pass

View File

@@ -87,7 +87,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
Tokenizing strategy for instruction-based prompts. Tokenizing strategy for instruction-based prompts.
""" """
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: def parse_instruction_fields(
self, prompt
) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
raise NotImplementedError raise NotImplementedError
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):

View File

@@ -24,6 +24,8 @@ class AlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
turn_format: str
turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None prompt_style: Optional[PromptStyle] = None
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value): def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
@@ -32,23 +34,13 @@ class AlpacaPrompter:
def match_prompt_style(self): def match_prompt_style(self):
if self.prompt_style == PromptStyle.INSTRUCT.value: if self.prompt_style == PromptStyle.INSTRUCT.value:
self.prompt_input = ( self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
self.system_prompt self.turn_no_input_format = (
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" "### Instruction:\n{instruction}\n\n### Response:\n"
) )
self.prompt_no_input = (
self.system_no_input_prompt
+ "### Instruction:\n{instruction}\n\n### Response:\n"
)
self.response_split = "### Response:"
if self.prompt_style == PromptStyle.CHAT.value: if self.prompt_style == PromptStyle.CHAT.value:
self.prompt_input = ( self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:" self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
)
self.prompt_no_input = (
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
)
self.response_split = "ASSISTANT:"
def build_prompt( def build_prompt(
self, self,
@@ -59,15 +51,39 @@ class AlpacaPrompter:
# returns the full prompt from instruction and optional input # returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended. # if a label (=response, =output) is provided, it's also appended.
if input: if input:
res = self.prompt_input.format(instruction=instruction, input=input) res = self.system_prompt + self.turn_format.format(
instruction=instruction, input=input
)
else: else:
res = self.prompt_no_input.format(instruction=instruction) res = self.system_no_input_prompt + self.turn_no_input_format.format(
instruction=instruction
)
if output: if output:
res = f"{res}{output}" res = f"{res}{output}"
yield res yield res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip() class SystemDataPrompter(AlpacaPrompter):
"""
Alpaca Style Prompter that uses system prompts from the dataset
"""
def build_prompt_w_system(
self,
system: str,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = system + self.turn_format.format(instruction=instruction, input=input)
else:
res = system + self.turn_no_input_format.format(instruction=instruction)
if output:
res = f"{res}{output}"
yield res
class UnpromptedPrompter(AlpacaPrompter): class UnpromptedPrompter(AlpacaPrompter):
@@ -93,7 +109,10 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
""" """
system_prompt = ( system_prompt = (
"Choose the answer that best answers the question. Explain your reasoning." "Choose the answer that best answers the question. Explain your reasoning.\n"
)
system_no_input_prompt = (
"Choose the answer that best answers the question. Explain your reasoning.\n"
) )
@@ -102,7 +121,12 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
Prompter for multiple choice concise Prompter for multiple choice concise
""" """
prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n" system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
def match_prompt_style(self):
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
class SummarizeTLDRPrompter(AlpacaPrompter): class SummarizeTLDRPrompter(AlpacaPrompter):
@@ -110,9 +134,12 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
Prompter for summarize TLDR Prompter for summarize TLDR
""" """
prompt_no_input = ( system_prompt = ""
"USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:" system_no_input_prompt = ""
)
def match_prompt_style(self):
self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
class CompletionPrompter: class CompletionPrompter:
@@ -128,9 +155,6 @@ class CompletionPrompter:
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
yield instruction yield instruction
def get_response(self, output: str) -> str:
return output.strip()
class GPTeacherPrompter(AlpacaPrompter): class GPTeacherPrompter(AlpacaPrompter):
""" """
@@ -210,9 +234,6 @@ class ReflectAlpacaPrompter:
res = f"{res}{label}" res = f"{res}{label}"
yield res yield res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()
class SeparatorStyle(Enum): class SeparatorStyle(Enum):
"""Different separator style.""" """Different separator style."""
@@ -289,12 +310,6 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
sep2=" ", sep2=" ",
) )
# def match_prompt_style(self):
# if self.prompt_style == PromptStyle.chat.value:
# self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
# self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
# self.response_split = "ASSISTANT:"
def build_prompt(self, source) -> Generator[str, None, None]: def build_prompt(self, source) -> Generator[str, None, None]:
# ignore the system prompt if provided # ignore the system prompt if provided
if source[0]["from"] == "system": if source[0]["from"] == "system":

View File

@@ -2,7 +2,13 @@
import unittest import unittest
from axolotl.prompters import AlpacaPrompter, PromptStyle from axolotl.prompters import (
AlpacaPrompter,
MultipleChoiceExplainPrompter,
PromptStyle,
SystemDataPrompter,
UnpromptedPrompter,
)
class AlpacaPrompterTest(unittest.TestCase): class AlpacaPrompterTest(unittest.TestCase):
@@ -55,3 +61,64 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "### Response:" not in res assert "### Response:" not in res
assert "USER:" in res assert "USER:" in res
assert "ASSISTANT:" in res assert "ASSISTANT:" in res
def test_system_prompt(self):
prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(
prompter.build_prompt_w_system(
"use cot", "tell me a joke about the following", "alpacas"
)
)
assert "use cot" in res
assert res.startswith("use cot")
assert "### Instruction:" not in res
assert "### Input:" not in res
assert "alpacas" in res
assert "### Response:" not in res
assert "USER:" in res
assert "ASSISTANT:" in res
class UnpromptedPrompterTest(unittest.TestCase):
"""
Test class for UnpromptedPrompter with no system prompts
"""
def test_prompt_style_w_none(self):
prompter = UnpromptedPrompter(prompt_style=None)
res = next(prompter.build_prompt("tell me a joke"))
assert "### Instruction:" in res
assert "tell me a joke" in res
assert res.startswith("###")
def test_prompt_style_w_instruct(self):
prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)
res = next(
prompter.build_prompt("tell me a joke about the following", "alpacas")
)
assert "### Instruction:" in res
assert "tell me a joke" in res
assert res.startswith("###")
def test_prompt_style_w_chat(self):
prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(
prompter.build_prompt("tell me a joke about the following", "alpacas")
)
assert "USER:" in res
assert "tell me a joke" in res
assert res.startswith("USER:")
class MultipleChoiceExplainPrompterTest(unittest.TestCase):
"""
Test class for MultipleChoiceExplainPrompter
"""
def test_prompt_style_w_chat(self):
prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C"))
assert "USER:" in res
assert "choose one" in res
assert "Choose the answer that best answers the question." in res
assert "- A\n- B\n- C" in res