initial wip to get sys prompt from dataset

This commit is contained in:
Wing Lian
2023-06-17 19:22:58 -04:00
parent de8ed229c3
commit 8d20e0a3d3
4 changed files with 126 additions and 40 deletions

View File

@@ -45,8 +45,10 @@ class NoSystemPrompter(AlpacaPrompter):
Null Prompter with no system prompts
"""
prompt_input = "{instruction} {input} "
prompt_no_input = "{instruction} "
system_prompt = ""
system_no_input_prompt = ""
turn_format = "{instruction} {input} "
turn_no_input_format = "{instruction} "
def __init__(self): # pylint: disable=super-init-not-called
pass

View File

@@ -87,7 +87,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
Tokenizing strategy for instruction-based prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
def parse_instruction_fields(
self, prompt
) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
raise NotImplementedError
def tokenize_prompt(self, prompt):

View File

@@ -24,6 +24,8 @@ class AlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
turn_format: str
turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
@@ -32,23 +34,13 @@ class AlpacaPrompter:
def match_prompt_style(self):
if self.prompt_style == PromptStyle.INSTRUCT.value:
self.prompt_input = (
self.system_prompt
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
self.turn_no_input_format = (
"### Instruction:\n{instruction}\n\n### Response:\n"
)
self.prompt_no_input = (
self.system_no_input_prompt
+ "### Instruction:\n{instruction}\n\n### Response:\n"
)
self.response_split = "### Response:"
if self.prompt_style == PromptStyle.CHAT.value:
self.prompt_input = (
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
)
self.prompt_no_input = (
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
)
self.response_split = "ASSISTANT:"
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
def build_prompt(
self,
@@ -59,15 +51,39 @@ class AlpacaPrompter:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = self.prompt_input.format(instruction=instruction, input=input)
res = self.system_prompt + self.turn_format.format(
instruction=instruction, input=input
)
else:
res = self.prompt_no_input.format(instruction=instruction)
res = self.system_no_input_prompt + self.turn_no_input_format.format(
instruction=instruction
)
if output:
res = f"{res}{output}"
yield res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()
class SystemDataPrompter(AlpacaPrompter):
"""
Alpaca Style Prompter that uses system prompts from the dataset
"""
def build_prompt_w_system(
self,
system: str,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = system + self.turn_format.format(instruction=instruction, input=input)
else:
res = system + self.turn_no_input_format.format(instruction=instruction)
if output:
res = f"{res}{output}"
yield res
class UnpromptedPrompter(AlpacaPrompter):
@@ -93,7 +109,10 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
"""
system_prompt = (
"Choose the answer that best answers the question. Explain your reasoning."
"Choose the answer that best answers the question. Explain your reasoning.\n"
)
system_no_input_prompt = (
"Choose the answer that best answers the question. Explain your reasoning.\n"
)
@@ -102,7 +121,12 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
Prompter for multiple choice concise
"""
prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
def match_prompt_style(self):
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
class SummarizeTLDRPrompter(AlpacaPrompter):
@@ -110,9 +134,12 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
Prompter for summarize TLDR
"""
prompt_no_input = (
"USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
)
system_prompt = ""
system_no_input_prompt = ""
def match_prompt_style(self):
self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
class CompletionPrompter:
@@ -128,9 +155,6 @@ class CompletionPrompter:
) -> Generator[str, None, None]:
yield instruction
def get_response(self, output: str) -> str:
return output.strip()
class GPTeacherPrompter(AlpacaPrompter):
"""
@@ -210,9 +234,6 @@ class ReflectAlpacaPrompter:
res = f"{res}{label}"
yield res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()
class SeparatorStyle(Enum):
"""Different separator style."""
@@ -289,12 +310,6 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
sep2=" ",
)
# def match_prompt_style(self):
# if self.prompt_style == PromptStyle.chat.value:
# self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
# self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
# self.response_split = "ASSISTANT:"
def build_prompt(self, source) -> Generator[str, None, None]:
# ignore the system prompt if provided
if source[0]["from"] == "system":