"""
Prompt strategies loader for alpaca instruction datasets with system prompts
"""
from typing import Generator, Tuple, Union

from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle


class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for instruction-based prompts whose system prompt
    comes from the dataset row itself (the ``system`` field) rather than
    from a fixed prompter template.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
        """Split a dataset row into (instruction, input, output, system).

        ``input`` is optional in alpaca-style rows and defaults to "";
        ``instruction``, ``output`` and ``system`` are required keys.
        """
        return (
            prompt["instruction"],
            prompt.get("input", ""),
            prompt["output"],
            prompt["system"],
        )

    def tokenize_prompt(self, prompt):
        """Tokenize one example: the system+user turn first, then the response.

        When ``train_on_inputs`` is False, the user-turn tokens are masked
        with -100 so loss is computed only on the response tokens.
        """
        # pylint: disable=duplicate-code
        (
            instruction,
            input,  # pylint: disable=redefined-builtin
            response,
            system,
        ) = self.parse_instruction_fields(prompt)
        # build_prompt_w_system yields exactly one string: system + user turn
        user_prompt = next(
            iter(
                self.prompter.build_prompt_w_system(
                    system,
                    instruction,
                    input,
                )
            )
        )
        # no EOS yet: the response is appended below and carries the EOS
        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
        if not self.train_on_inputs:
            user_prompt_len = len(tokenized_prompt["input_ids"])
            # TODO this could be sped up using numpy array slicing
            tokenized_prompt["labels"] = [-100] * user_prompt_len
        # response is always appended; strip its BOS so the concatenation has one
        tokenized_res_prompt = self._tokenize(
            response, strip_bos_token=True, add_eos_token=True
        )
        tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]

        return tokenized_prompt


class SystemDataPrompter(AlpacaPrompter):
    """
    Alpaca Style Prompter that uses system prompts from the dataset
    """

    def build_prompt_w_system(
        self,
        system: str,
        instruction: str,
        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
        output: Union[None, str] = None,
    ) -> Generator[str, None, None]:
        # returns the full prompt from instruction and optional input
        # if a label (=response, =output) is provided, it's also appended.
        if input:
            res = system + self.turn_format.format(instruction=instruction, input=input)
        else:
            res = system + self.turn_no_input_format.format(instruction=instruction)
        if output:
            res = f"{res}{output}"
        yield res


def load(tokenizer, cfg):
    """Strategy-loader entry point: wire the prompter to the tokenizer/config."""
    return InstructionWSystemPromptTokenizingStrategy(
        SystemDataPrompter(PromptStyle.CHAT.value),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )


# NOTE(review): the same patch also rewrites NoSystemPrompter in
# prompt_strategies/alpaca_chat.py; its full post-image is reproduced here
# because it is completely visible in this span.
class NoSystemPrompter(AlpacaPrompter):
    """
    Null Prompter with no system prompts
    """

    system_prompt = ""
    system_no_input_prompt = ""
    turn_format = "{instruction} {input} "
    turn_no_input_format = "{instruction} "

    def __init__(self):  # pylint: disable=super-init-not-called
        pass
class AlpacaPrompter:
    """
    Base class for alpaca prompters: a fixed system prompt plus a per-style
    turn template (instruct "### Instruction:" style or chat "USER:" style).
    """

    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
    system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
    turn_format: str
    turn_no_input_format: str
    prompt_style: Optional[PromptStyle] = None

    def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
        # NOTE(review): body hidden between patch hunks; reconstructed from the
        # visible signature and the prompt_style=None fallback the visible
        # tests exercise — confirm against upstream.
        self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
        self.match_prompt_style()

    def match_prompt_style(self):
        """Pick the turn templates matching ``self.prompt_style``."""
        if self.prompt_style == PromptStyle.INSTRUCT.value:
            self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
            self.turn_no_input_format = (
                "### Instruction:\n{instruction}\n\n### Response:\n"
            )
        if self.prompt_style == PromptStyle.CHAT.value:
            self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
            self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"

    def build_prompt(
        self,
        instruction: str,
        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
        output: Union[None, str] = None,
    ) -> Generator[str, None, None]:
        # returns the full prompt from instruction and optional input
        # if a label (=response, =output) is provided, it's also appended.
        if input:
            res = self.system_prompt + self.turn_format.format(
                instruction=instruction, input=input
            )
        else:
            res = self.system_no_input_prompt + self.turn_no_input_format.format(
                instruction=instruction
            )
        if output:
            res = f"{res}{output}"
        yield res


class MultipleChoiceExplainPrompter(AlpacaPrompter):
    """
    Prompter for multiple choice explain
    """

    # same system text for the input and no-input cases
    system_prompt = (
        "Choose the answer that best answers the question. Explain your reasoning.\n"
    )
    system_no_input_prompt = (
        "Choose the answer that best answers the question. Explain your reasoning.\n"
    )


class MultipleChoiceConcisePrompter(AlpacaPrompter):
    """
    Prompter for multiple choice concise
    """

    system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
    system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"

    # always chat-style turns, regardless of the configured prompt_style
    def match_prompt_style(self):
        self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
        self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"


class SummarizeTLDRPrompter(AlpacaPrompter):
    """
    Prompter for summarize TLDR
    """

    system_prompt = ""
    system_no_input_prompt = ""

    # always chat-style turns, regardless of the configured prompt_style
    def match_prompt_style(self):
        self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
        self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
    """
    Test class for prompt tokenization strategies with sys prompt from the dataset
    """

    def setUp(self) -> None:
        # pylint: disable=duplicate-code
        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
        # NOTE(review): the special-token strings appear as "" in the mangled
        # source (angle-bracket tags were stripped); the asserted ids below
        # (bos id == 1) require the real llama tokens — confirm upstream.
        self.tokenizer.add_special_tokens(
            {
                "bos_token": "<s>",
                "eos_token": "</s>",
                "unk_token": "<unk>",
            }
        )

    def test_system_alpaca(self):
        prompter = SystemDataPrompter(PromptStyle.CHAT.value)
        strat = InstructionWSystemPromptTokenizingStrategy(
            prompter,
            self.tokenizer,
            False,
            2048,
        )
        sample = {
            "system": "use cot",
            "instruction": "hello!",
            "output": "Hi! How can I help?",
        }
        example = strat.tokenize_prompt(sample)
        assert example["input_ids"][0:3] == [1, 671, 20118]  # use cot
        assert example["input_ids"][3] == 11889  # USER


class UnpromptedPrompterTest(unittest.TestCase):
    """
    Test class for UnpromptedPrompter with no system prompts
    """

    def test_prompt_style_w_none(self):
        prompter = UnpromptedPrompter(prompt_style=None)
        res = next(prompter.build_prompt("tell me a joke"))
        assert "### Instruction:" in res
        assert "tell me a joke" in res
        assert res.startswith("###")

    def test_prompt_style_w_instruct(self):
        prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)
        res = next(
            prompter.build_prompt("tell me a joke about the following", "alpacas")
        )
        assert "### Instruction:" in res
        assert "tell me a joke" in res
        assert res.startswith("###")

    def test_prompt_style_w_chat(self):
        prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)
        res = next(
            prompter.build_prompt("tell me a joke about the following", "alpacas")
        )
        assert "USER:" in res
        assert "tell me a joke" in res
        assert res.startswith("USER:")


class MultipleChoiceExplainPrompterTest(unittest.TestCase):
    """
    Test class for MultipleChoiceExplainPrompter
    """

    def test_prompt_style_w_chat(self):
        prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)
        res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C"))
        assert "USER:" in res
        assert "choose one" in res
        assert "Choose the answer that best answers the question." in res
        assert "- A\n- B\n- C" in res