initial wip to get sys prompt from dataset

This commit is contained in:
Wing Lian
2023-06-17 19:22:58 -04:00
parent de8ed229c3
commit 8d20e0a3d3
4 changed files with 126 additions and 40 deletions

View File

@@ -45,8 +45,10 @@ class NoSystemPrompter(AlpacaPrompter):
Null Prompter with no system prompts
"""
prompt_input = "{instruction} {input} "
prompt_no_input = "{instruction} "
system_prompt = ""
system_no_input_prompt = ""
turn_format = "{instruction} {input} "
turn_no_input_format = "{instruction} "
def __init__(self): # pylint: disable=super-init-not-called
pass

View File

@@ -87,7 +87,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
Tokenizing strategy for instruction-based prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
def parse_instruction_fields(
self, prompt
) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
raise NotImplementedError
def tokenize_prompt(self, prompt):

View File

@@ -24,6 +24,8 @@ class AlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
turn_format: str
turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
@@ -32,23 +34,13 @@ class AlpacaPrompter:
def match_prompt_style(self):
if self.prompt_style == PromptStyle.INSTRUCT.value:
self.prompt_input = (
self.system_prompt
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
self.turn_no_input_format = (
"### Instruction:\n{instruction}\n\n### Response:\n"
)
self.prompt_no_input = (
self.system_no_input_prompt
+ "### Instruction:\n{instruction}\n\n### Response:\n"
)
self.response_split = "### Response:"
if self.prompt_style == PromptStyle.CHAT.value:
self.prompt_input = (
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
)
self.prompt_no_input = (
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
)
self.response_split = "ASSISTANT:"
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
def build_prompt(
self,
@@ -59,15 +51,39 @@ class AlpacaPrompter:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = self.prompt_input.format(instruction=instruction, input=input)
res = self.system_prompt + self.turn_format.format(
instruction=instruction, input=input
)
else:
res = self.prompt_no_input.format(instruction=instruction)
res = self.system_no_input_prompt + self.turn_no_input_format.format(
instruction=instruction
)
if output:
res = f"{res}{output}"
yield res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()
class SystemDataPrompter(AlpacaPrompter):
"""
Alpaca Style Prompter that uses system prompts from the dataset
"""
def build_prompt_w_system(
self,
system: str,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
res = system + self.turn_format.format(instruction=instruction, input=input)
else:
res = system + self.turn_no_input_format.format(instruction=instruction)
if output:
res = f"{res}{output}"
yield res
class UnpromptedPrompter(AlpacaPrompter):
@@ -93,7 +109,10 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
"""
system_prompt = (
"Choose the answer that best answers the question. Explain your reasoning."
"Choose the answer that best answers the question. Explain your reasoning.\n"
)
system_no_input_prompt = (
"Choose the answer that best answers the question. Explain your reasoning.\n"
)
@@ -102,7 +121,12 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
Prompter for multiple choice concise
"""
prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
def match_prompt_style(self):
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
class SummarizeTLDRPrompter(AlpacaPrompter):
@@ -110,9 +134,12 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
Prompter for summarize TLDR
"""
prompt_no_input = (
"USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
)
system_prompt = ""
system_no_input_prompt = ""
def match_prompt_style(self):
self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
class CompletionPrompter:
@@ -128,9 +155,6 @@ class CompletionPrompter:
) -> Generator[str, None, None]:
yield instruction
def get_response(self, output: str) -> str:
return output.strip()
class GPTeacherPrompter(AlpacaPrompter):
"""
@@ -210,9 +234,6 @@ class ReflectAlpacaPrompter:
res = f"{res}{label}"
yield res
def get_response(self, output: str) -> str:
return output.split(self.response_split)[1].strip()
class SeparatorStyle(Enum):
"""Different separator style."""
@@ -289,12 +310,6 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
sep2=" ",
)
# def match_prompt_style(self):
# if self.prompt_style == PromptStyle.chat.value:
# self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
# self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
# self.response_split = "ASSISTANT:"
def build_prompt(self, source) -> Generator[str, None, None]:
# ignore the system prompt if provided
if source[0]["from"] == "system":

View File

@@ -2,7 +2,13 @@
import unittest
from axolotl.prompters import AlpacaPrompter, PromptStyle
from axolotl.prompters import (
AlpacaPrompter,
MultipleChoiceExplainPrompter,
PromptStyle,
SystemDataPrompter,
UnpromptedPrompter,
)
class AlpacaPrompterTest(unittest.TestCase):
@@ -55,3 +61,64 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "### Response:" not in res
assert "USER:" in res
assert "ASSISTANT:" in res
def test_system_prompt(self):
prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(
prompter.build_prompt_w_system(
"use cot", "tell me a joke about the following", "alpacas"
)
)
assert "use cot" in res
assert res.startswith("use cot")
assert "### Instruction:" not in res
assert "### Input:" not in res
assert "alpacas" in res
assert "### Response:" not in res
assert "USER:" in res
assert "ASSISTANT:" in res
class UnpromptedPrompterTest(unittest.TestCase):
"""
Test class for UnpromptedPrompter with no system prompts
"""
def test_prompt_style_w_none(self):
prompter = UnpromptedPrompter(prompt_style=None)
res = next(prompter.build_prompt("tell me a joke"))
assert "### Instruction:" in res
assert "tell me a joke" in res
assert res.startswith("###")
def test_prompt_style_w_instruct(self):
prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)
res = next(
prompter.build_prompt("tell me a joke about the following", "alpacas")
)
assert "### Instruction:" in res
assert "tell me a joke" in res
assert res.startswith("###")
def test_prompt_style_w_chat(self):
prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(
prompter.build_prompt("tell me a joke about the following", "alpacas")
)
assert "USER:" in res
assert "tell me a joke" in res
assert res.startswith("USER:")
class MultipleChoiceExplainPrompterTest(unittest.TestCase):
"""
Test class for MultipleChoiceExplainPrompter
"""
def test_prompt_style_w_chat(self):
prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C"))
assert "USER:" in res
assert "choose one" in res
assert "Choose the answer that best answers the question." in res
assert "- A\n- B\n- C" in res