diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index 40c58bc9c..5593a8dd3 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -126,6 +126,7 @@ class ConstantLengthDataset(IterableDataset):
buffer_len = 0
if example:
+ # FIXME
# just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"]
diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py
index 952a55961..17fe69be7 100644
--- a/src/axolotl/prompt_strategies/alpaca_chat.py
+++ b/src/axolotl/prompt_strategies/alpaca_chat.py
@@ -45,8 +45,10 @@ class NoSystemPrompter(AlpacaPrompter):
Null Prompter with no system prompts
"""
- prompt_input = "{instruction} {input} "
- prompt_no_input = "{instruction} "
+ system_prompt = ""
+ system_no_input_prompt = ""
+ turn_format = "{instruction} {input} "
+ turn_no_input_format = "{instruction} "
def __init__(self): # pylint: disable=super-init-not-called
pass
diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py
new file mode 100644
index 000000000..aacae8739
--- /dev/null
+++ b/src/axolotl/prompt_strategies/alpaca_w_system.py
@@ -0,0 +1,84 @@
+"""
+Prompt strategies loader for alpaca instruction datasets with system prompts
+"""
+from typing import Generator, Tuple, Union
+
+from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter, PromptStyle
+
+
+class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
+ """
+    Tokenizing strategy for instruction prompts that carry a per-example system prompt.
+ """
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
+ return (
+ prompt["instruction"],
+ prompt["input"] if "input" in prompt else "",
+ prompt["output"],
+ prompt["system"],
+ )
+
+ def tokenize_prompt(self, prompt):
+ # pylint: disable=duplicate-code
+ (
+ instruction,
+ input, # pylint: disable=redefined-builtin
+ response,
+ system,
+ ) = self.parse_instruction_fields(prompt)
+ user_prompt = next(
+ iter(
+ self.prompter.build_prompt_w_system(
+ system,
+ instruction,
+ input,
+ )
+ )
+ )
+ tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
+ if not self.train_on_inputs:
+ user_prompt_len = len(tokenized_prompt["input_ids"])
+ # TODO this could be sped up using numpy array slicing
+ tokenized_prompt["labels"] = [-100] * user_prompt_len
+ tokenized_res_prompt = self._tokenize(
+ response, strip_bos_token=True, add_eos_token=True
+ )
+ tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
+ tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
+ tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
+
+ return tokenized_prompt
+
+
+class SystemDataPrompter(AlpacaPrompter):
+ """
+ Alpaca Style Prompter that uses system prompts from the dataset
+ """
+
+ def build_prompt_w_system(
+ self,
+ system: str,
+ instruction: str,
+ input: Union[None, str] = None, # pylint: disable=redefined-builtin
+ output: Union[None, str] = None,
+ ) -> Generator[str, None, None]:
+ # returns the full prompt from instruction and optional input
+ # if a label (=response, =output) is provided, it's also appended.
+ if input:
+ res = system + self.turn_format.format(instruction=instruction, input=input)
+ else:
+ res = system + self.turn_no_input_format.format(instruction=instruction)
+ if output:
+ res = f"{res}{output}"
+ yield res
+
+
+def load(tokenizer, cfg):
+ return InstructionWSystemPromptTokenizingStrategy(
+ SystemDataPrompter(PromptStyle.CHAT.value),
+ tokenizer,
+ cfg.train_on_inputs,
+ cfg.sequence_len,
+ )
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index 6408620d7..cf80539eb 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -87,7 +87,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
Tokenizing strategy for instruction-based prompts.
"""
- def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+ def parse_instruction_fields(
+ self, prompt
+ ) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
raise NotImplementedError
def tokenize_prompt(self, prompt):
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 29cc4446b..715a227c8 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -24,6 +24,8 @@ class AlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
+ turn_format: str
+ turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
@@ -32,23 +34,13 @@ class AlpacaPrompter:
def match_prompt_style(self):
if self.prompt_style == PromptStyle.INSTRUCT.value:
- self.prompt_input = (
- self.system_prompt
- + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+ self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+ self.turn_no_input_format = (
+ "### Instruction:\n{instruction}\n\n### Response:\n"
)
- self.prompt_no_input = (
- self.system_no_input_prompt
- + "### Instruction:\n{instruction}\n\n### Response:\n"
- )
- self.response_split = "### Response:"
if self.prompt_style == PromptStyle.CHAT.value:
- self.prompt_input = (
- self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
- )
- self.prompt_no_input = (
- self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
- )
- self.response_split = "ASSISTANT:"
+ self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
+ self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
def build_prompt(
self,
@@ -59,16 +51,17 @@ class AlpacaPrompter:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
- res = self.prompt_input.format(instruction=instruction, input=input)
+ res = self.system_prompt + self.turn_format.format(
+ instruction=instruction, input=input
+ )
else:
- res = self.prompt_no_input.format(instruction=instruction)
+ res = self.system_no_input_prompt + self.turn_no_input_format.format(
+ instruction=instruction
+ )
if output:
res = f"{res}{output}"
yield res
- def get_response(self, output: str) -> str:
- return output.split(self.response_split)[1].strip()
-
class UnpromptedPrompter(AlpacaPrompter):
"""
@@ -93,7 +86,10 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
"""
system_prompt = (
- "Choose the answer that best answers the question. Explain your reasoning."
+ "Choose the answer that best answers the question. Explain your reasoning.\n"
+ )
+ system_no_input_prompt = (
+ "Choose the answer that best answers the question. Explain your reasoning.\n"
)
@@ -102,7 +98,12 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
Prompter for multiple choice concise
"""
- prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
+ system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
+ system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
+
+ def match_prompt_style(self):
+ self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
+ self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
class SummarizeTLDRPrompter(AlpacaPrompter):
@@ -110,9 +111,12 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
Prompter for summarize TLDR
"""
- prompt_no_input = (
- "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
- )
+ system_prompt = ""
+ system_no_input_prompt = ""
+
+ def match_prompt_style(self):
+ self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
+ self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
class CompletionPrompter:
@@ -128,9 +132,6 @@ class CompletionPrompter:
) -> Generator[str, None, None]:
yield instruction
- def get_response(self, output: str) -> str:
- return output.strip()
-
class GPTeacherPrompter(AlpacaPrompter):
"""
@@ -210,9 +211,6 @@ class ReflectAlpacaPrompter:
res = f"{res}{label}"
yield res
- def get_response(self, output: str) -> str:
- return output.split(self.response_split)[1].strip()
-
class SeparatorStyle(Enum):
"""Different separator style."""
@@ -289,12 +287,6 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
sep2=" ",
)
- # def match_prompt_style(self):
- # if self.prompt_style == PromptStyle.chat.value:
- # self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
- # self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
- # self.response_split = "ASSISTANT:"
-
def build_prompt(self, source) -> Generator[str, None, None]:
# ignore the system prompt if provided
if source[0]["from"] == "system":
diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
index 1c535eb1b..7d0d1dd83 100644
--- a/src/axolotl/utils/tokenization.py
+++ b/src/axolotl/utils/tokenization.py
@@ -34,3 +34,5 @@ def check_example_labels(example, tokenizer):
logging.info(" ".join(colored_tokens))
logging.info("\n\n\n")
+
+ return " ".join(colored_tokens)
diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py
index aba340eee..3ddbe77bf 100644
--- a/tests/test_prompt_tokenizers.py
+++ b/tests/test_prompt_tokenizers.py
@@ -7,11 +7,15 @@ from pathlib import Path
from transformers import AutoTokenizer
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
+from axolotl.prompt_strategies.alpaca_w_system import (
+ InstructionWSystemPromptTokenizingStrategy,
+ SystemDataPrompter,
+)
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
)
-from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
+from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
logging.basicConfig(level="INFO")
@@ -96,5 +100,39 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
assert example["labels"][world_idx - 1] == -100
+class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
+ """
+ Test class for prompt tokenization strategies with sys prompt from the dataset
+ """
+
+ def setUp(self) -> None:
+ # pylint: disable=duplicate-code
+ self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+ self.tokenizer.add_special_tokens(
+ {
+ "bos_token": "",
+ "eos_token": "",
+ "unk_token": "",
+ }
+ )
+
+ def test_system_alpaca(self):
+ prompter = SystemDataPrompter(PromptStyle.CHAT.value)
+ strat = InstructionWSystemPromptTokenizingStrategy(
+ prompter,
+ self.tokenizer,
+ False,
+ 2048,
+ )
+ sample = {
+ "system": "use cot",
+ "instruction": "hello!",
+ "output": "Hi! How can I help?",
+ }
+ example = strat.tokenize_prompt(sample)
+ assert example["input_ids"][0:3] == [1, 671, 20118] # use cot
+ assert example["input_ids"][3] == 11889 # USER
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_prompters.py b/tests/test_prompters.py
index 11610ccc5..756b6f81b 100644
--- a/tests/test_prompters.py
+++ b/tests/test_prompters.py
@@ -2,7 +2,13 @@
import unittest
-from axolotl.prompters import AlpacaPrompter, PromptStyle
+from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
+from axolotl.prompters import (
+ AlpacaPrompter,
+ MultipleChoiceExplainPrompter,
+ PromptStyle,
+ UnpromptedPrompter,
+)
class AlpacaPrompterTest(unittest.TestCase):
@@ -55,3 +61,64 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "### Response:" not in res
assert "USER:" in res
assert "ASSISTANT:" in res
+
+ def test_system_prompt(self):
+ prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)
+ res = next(
+ prompter.build_prompt_w_system(
+ "use cot", "tell me a joke about the following", "alpacas"
+ )
+ )
+ assert "use cot" in res
+ assert res.startswith("use cot")
+ assert "### Instruction:" not in res
+ assert "### Input:" not in res
+ assert "alpacas" in res
+ assert "### Response:" not in res
+ assert "USER:" in res
+ assert "ASSISTANT:" in res
+
+
+class UnpromptedPrompterTest(unittest.TestCase):
+ """
+ Test class for UnpromptedPrompter with no system prompts
+ """
+
+ def test_prompt_style_w_none(self):
+ prompter = UnpromptedPrompter(prompt_style=None)
+ res = next(prompter.build_prompt("tell me a joke"))
+ assert "### Instruction:" in res
+ assert "tell me a joke" in res
+ assert res.startswith("###")
+
+ def test_prompt_style_w_instruct(self):
+ prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)
+ res = next(
+ prompter.build_prompt("tell me a joke about the following", "alpacas")
+ )
+ assert "### Instruction:" in res
+ assert "tell me a joke" in res
+ assert res.startswith("###")
+
+ def test_prompt_style_w_chat(self):
+ prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)
+ res = next(
+ prompter.build_prompt("tell me a joke about the following", "alpacas")
+ )
+ assert "USER:" in res
+ assert "tell me a joke" in res
+ assert res.startswith("USER:")
+
+
+class MultipleChoiceExplainPrompterTest(unittest.TestCase):
+ """
+ Test class for MultipleChoiceExplainPrompter
+ """
+
+ def test_prompt_style_w_chat(self):
+ prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)
+ res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C"))
+ assert "USER:" in res
+ assert "choose one" in res
+ assert "Choose the answer that best answers the question." in res
+ assert "- A\n- B\n- C" in res