diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py
new file mode 100644
index 000000000..88acf0d0e
--- /dev/null
+++ b/src/axolotl/prompt_strategies/alpaca_w_system.py
@@ -0,0 +1,92 @@
+"""
+Prompt strategies loader for alpaca instruction datasets with system prompts
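+
+Example dataset row (mirrors the sample used in tests/test_prompt_tokenizers.py):
+    {"system": "use cot", "instruction": "hello!", "input": "", "output": "Hi! How can I help?"}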
+"""
+from typing import Generator, Tuple, Union
+
+from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter, PromptStyle
+
+
+class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
+ """
+ Tokenizing strategy for instruction-based prompts.
+ """
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
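+        # "input" is optional; "instruction", "output", and "system" must be present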
+        return (
+            prompt["instruction"],
+            prompt["input"] if "input" in prompt else "",
+            prompt["output"],
+            prompt["system"],
+        )
+
+    def tokenize_prompt(self, prompt):
+        (
+            instruction,
+            input,  # pylint: disable=redefined-builtin
+            response,
+            system,
+        ) = self.parse_instruction_fields(prompt)
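+        # build_prompt_w_system yields a single formatted string; next(iter(...)) unwraps it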
+        user_prompt = next(
+            iter(
+                self.prompter.build_prompt_w_system(
+                    system,
+                    instruction,
+                    input,
+                )
+            )
+        )
+        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
+        if not self.train_on_inputs:
+            user_prompt_len = len(tokenized_prompt["input_ids"])
+            # TODO this could be sped up using numpy array slicing
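+            # -100 is the standard ignore_index for the loss, so prompt tokens are not trained on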
+ tokenized_prompt["labels"] = [-100] * user_prompt_len
+        tokenized_res_prompt = self._tokenize(
+            response, strip_bos_token=True, add_eos_token=True
+        )
+        tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
+        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
+        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
+
+        return tokenized_prompt
+
+
+class SystemDataPrompter(AlpacaPrompter):
+ """
+ Alpaca Style Prompter that uses system prompts from the dataset
+ """
+
+    def build_prompt_w_system(
+        self,
+        system: str,
+        instruction: str,
+        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
+        output: Union[None, str] = None,
+    ) -> Generator[str, None, None]:
+        # returns the full prompt from instruction and optional input
+        # if a label (=response, =output) is provided, it's also appended.
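+        # turn_format / turn_no_input_format are inherited from AlpacaPrompter and set by
+        # the prompt style; the CHAT style renders "USER: ... ASSISTANT:" turns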
+        if input:
+            res = system + self.turn_format.format(instruction=instruction, input=input)
+        else:
+            res = system + self.turn_no_input_format.format(instruction=instruction)
+        if output:
+            res = f"{res}{output}"
+        yield res
+
+
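+# entry point: axolotl's prompt strategy loader imports this module and calls load(tokenizer, cfg)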
+def load(tokenizer, cfg):
+    return InstructionWSystemPromptTokenizingStrategy(
+        SystemDataPrompter(PromptStyle.CHAT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 4db915238..715a227c8 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -63,29 +63,6 @@ class AlpacaPrompter:
         yield res


-class SystemDataPrompter(AlpacaPrompter):
-    """
-    Alpaca Style Prompter that uses system prompts from the dataset
-    """
-
-    def build_prompt_w_system(
-        self,
-        system: str,
-        instruction: str,
-        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
-        output: Union[None, str] = None,
-    ) -> Generator[str, None, None]:
-        # returns the full prompt from instruction and optional input
-        # if a label (=response, =output) is provided, it's also appended.
-        if input:
-            res = system + self.turn_format.format(instruction=instruction, input=input)
-        else:
-            res = system + self.turn_no_input_format.format(instruction=instruction)
-        if output:
-            res = f"{res}{output}"
-        yield res
-
-
 class UnpromptedPrompter(AlpacaPrompter):
     """
     Prompter for alpaca no system prompt
diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
index 1c535eb1b..7d0d1dd83 100644
--- a/src/axolotl/utils/tokenization.py
+++ b/src/axolotl/utils/tokenization.py
@@ -34,3 +34,6 @@ def check_example_labels(example, tokenizer):

     logging.info(" ".join(colored_tokens))
     logging.info("\n\n\n")
+
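+    # return the rendered string so callers and tests can assert on it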
+ return " ".join(colored_tokens)
diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py
index aba340eee..3ddbe77bf 100644
--- a/tests/test_prompt_tokenizers.py
+++ b/tests/test_prompt_tokenizers.py
@@ -7,11 +7,15 @@ from pathlib import Path
 from transformers import AutoTokenizer

 from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
+from axolotl.prompt_strategies.alpaca_w_system import (
+    InstructionWSystemPromptTokenizingStrategy,
+    SystemDataPrompter,
+)
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     ShareGPTPromptTokenizingStrategy,
 )
-from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
+from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter


 logging.basicConfig(level="INFO")
@@ -96,5 +100,40 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
         assert example["labels"][world_idx - 1] == -100


+class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
+    """
+    Test class for prompt tokenization strategies with sys prompt from the dataset
+    """
+
+    def setUp(self) -> None:
+        # pylint: disable=duplicate-code
+        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.add_special_tokens(
+            {
+                "bos_token": "<s>",
+                "eos_token": "</s>",
+                "unk_token": "<unk>",
+            }
+        )
+
+    def test_system_alpaca(self):
+        prompter = SystemDataPrompter(PromptStyle.CHAT.value)
+        strat = InstructionWSystemPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        sample = {
+            "system": "use cot",
+            "instruction": "hello!",
+            "output": "Hi! How can I help?",
+        }
+        example = strat.tokenize_prompt(sample)
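+        # BOS (1) comes first, then the system prompt tokens, then the "USER:" turn marker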
+ assert example["input_ids"][0:3] == [1, 671, 20118] # use cot
+ assert example["input_ids"][3] == 11889 # USER
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_prompters.py b/tests/test_prompters.py
index bb33afbb6..756b6f81b 100644
--- a/tests/test_prompters.py
+++ b/tests/test_prompters.py
@@ -2,11 +2,11 @@

 import unittest

+from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
 from axolotl.prompters import (
     AlpacaPrompter,
     MultipleChoiceExplainPrompter,
     PromptStyle,
-    SystemDataPrompter,
     UnpromptedPrompter,
 )
