diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py
new file mode 100644
index 000000000..88acf0d0e
--- /dev/null
+++ b/src/axolotl/prompt_strategies/alpaca_w_system.py
@@ -0,0 +1,83 @@
+"""
+Prompt strategies loader for alpaca instruction datasets with system prompts
+"""
+from typing import Generator, Tuple, Union
+
+from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter, PromptStyle
+
+
+class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
+    """
+    Tokenizing strategy for instruction-based prompts.
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
+        return (
+            prompt["instruction"],
+            prompt["input"] if "input" in prompt else "",
+            prompt["output"],
+            prompt["system"],
+        )
+
+    def tokenize_prompt(self, prompt):
+        (
+            instruction,
+            input,  # pylint: disable=redefined-builtin
+            response,
+            system,
+        ) = self.parse_instruction_fields(prompt)
+        user_prompt = next(
+            iter(
+                self.prompter.build_prompt_w_system(
+                    system,
+                    instruction,
+                    input,
+                )
+            )
+        )
+        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
+        if not self.train_on_inputs:
+            user_prompt_len = len(tokenized_prompt["input_ids"])
+            # TODO this could be sped up using numpy array slicing
+            tokenized_prompt["labels"] = [-100] * user_prompt_len
+        tokenized_res_prompt = self._tokenize(
+            response, strip_bos_token=True, add_eos_token=True
+        )
+        tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
+        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
+        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
+
+        return tokenized_prompt
+
+
+class SystemDataPrompter(AlpacaPrompter):
+    """
+    Alpaca Style Prompter that uses system prompts from the dataset
+    """
+
+    def build_prompt_w_system(
+        self,
+        system: str,
+        instruction: str,
+        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
+        output: Union[None, str] = None,
+    ) -> Generator[str, None, None]:
+        # returns the full prompt from instruction and optional input
+        # if a label (=response, =output) is provided, it's also appended.
+        if input:
+            res = system + self.turn_format.format(instruction=instruction, input=input)
+        else:
+            res = system + self.turn_no_input_format.format(instruction=instruction)
+        if output:
+            res = f"{res}{output}"
+        yield res
+
+
+def load(tokenizer, cfg):
+    return InstructionWSystemPromptTokenizingStrategy(
+        SystemDataPrompter(PromptStyle.CHAT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 4db915238..715a227c8 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -63,29 +63,6 @@ class AlpacaPrompter:
         yield res
 
 
-class SystemDataPrompter(AlpacaPrompter):
-    """
-    Alpaca Style Prompter that uses system prompts from the dataset
-    """
-
-    def build_prompt_w_system(
-        self,
-        system: str,
-        instruction: str,
-        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
-        output: Union[None, str] = None,
-    ) -> Generator[str, None, None]:
-        # returns the full prompt from instruction and optional input
-        # if a label (=response, =output) is provided, it's also appended.
-        if input:
-            res = system + self.turn_format.format(instruction=instruction, input=input)
-        else:
-            res = system + self.turn_no_input_format.format(instruction=instruction)
-        if output:
-            res = f"{res}{output}"
-        yield res
-
-
 class UnpromptedPrompter(AlpacaPrompter):
     """
     Prompter for alpaca no system prompt
diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
index 1c535eb1b..7d0d1dd83 100644
--- a/src/axolotl/utils/tokenization.py
+++ b/src/axolotl/utils/tokenization.py
@@ -34,3 +34,5 @@ def check_example_labels(example, tokenizer):
 
     logging.info(" ".join(colored_tokens))
     logging.info("\n\n\n")
+
+    return " ".join(colored_tokens)
diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py
index aba340eee..3ddbe77bf 100644
--- a/tests/test_prompt_tokenizers.py
+++ b/tests/test_prompt_tokenizers.py
@@ -7,11 +7,15 @@ from pathlib import Path
 from transformers import AutoTokenizer
 
 from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
+from axolotl.prompt_strategies.alpaca_w_system import (
+    InstructionWSystemPromptTokenizingStrategy,
+    SystemDataPrompter,
+)
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     ShareGPTPromptTokenizingStrategy,
 )
-from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
+from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
 
 logging.basicConfig(level="INFO")
 
@@ -96,5 +100,39 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
         assert example["labels"][world_idx - 1] == -100
 
 
+class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
+    """
+    Test class for prompt tokenization strategies with sys prompt from the dataset
+    """
+
+    def setUp(self) -> None:
+        # pylint: disable=duplicate-code
+        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.add_special_tokens(
+            {
+                "bos_token": "<s>",
+                "eos_token": "</s>",
+                "unk_token": "<unk>",
+            }
+        )
+
+    def test_system_alpaca(self):
+        prompter = SystemDataPrompter(PromptStyle.CHAT.value)
+        strat = InstructionWSystemPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        sample = {
+            "system": "use cot",
+            "instruction": "hello!",
+            "output": "Hi! How can I help?",
+        }
+        example = strat.tokenize_prompt(sample)
+        assert example["input_ids"][0:3] == [1, 671, 20118]  # use cot
+        assert example["input_ids"][3] == 11889  # USER
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_prompters.py b/tests/test_prompters.py
index bb33afbb6..756b6f81b 100644
--- a/tests/test_prompters.py
+++ b/tests/test_prompters.py
@@ -2,11 +2,11 @@
 
 import unittest
 
+from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
 from axolotl.prompters import (
     AlpacaPrompter,
     MultipleChoiceExplainPrompter,
     PromptStyle,
-    SystemDataPrompter,
     UnpromptedPrompter,
 )
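For reviewers, a minimal end-to-end sketch of how the new strategy is exercised, mirroring the unit test above. The tokenizer checkpoint and the sample row are illustrative only, and the exact token ids depend on the tokenizer and prompt style; presumably the loader would be selected in a training config through a dataset type matching the new module name (alpaca_w_system).

from transformers import AutoTokenizer

from axolotl.prompt_strategies.alpaca_w_system import (
    InstructionWSystemPromptTokenizingStrategy,
    SystemDataPrompter,
)
from axolotl.prompters import PromptStyle

# Illustrative tokenizer; the new test uses huggyllama/llama-7b.
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

strategy = InstructionWSystemPromptTokenizingStrategy(
    SystemDataPrompter(PromptStyle.CHAT.value),
    tokenizer,
    False,  # train_on_inputs=False: prompt tokens are masked out of the loss
    2048,  # sequence_len
)

# Dataset rows carry their own system prompt alongside the usual alpaca fields.
row = {
    "system": "use cot",
    "instruction": "hello!",
    "input": "",
    "output": "Hi! How can I help?",
}

example = strategy.tokenize_prompt(row)

# With train_on_inputs=False every prompt token is labeled -100, so only the
# response (plus EOS) contributes to the loss.
prompt_len = example["labels"].count(-100)
assert example["labels"][prompt_len:] == example["input_ids"][prompt_len:]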