From 8d20e0a3d3f44721bb3e45f4a6d51577dd7099bc Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 17 Jun 2023 19:22:58 -0400 Subject: [PATCH 1/3] initial wip to get sys prompt from dataset --- src/axolotl/prompt_strategies/alpaca_chat.py | 6 +- src/axolotl/prompt_tokenizers.py | 4 +- src/axolotl/prompters.py | 87 ++++++++++++-------- tests/test_prompters.py | 69 +++++++++++++++- 4 files changed, 126 insertions(+), 40 deletions(-) diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py index 6161d7e37..32801c3c3 100644 --- a/src/axolotl/prompt_strategies/alpaca_chat.py +++ b/src/axolotl/prompt_strategies/alpaca_chat.py @@ -45,8 +45,10 @@ class NoSystemPrompter(AlpacaPrompter): Null Prompter with no system prompts """ - prompt_input = "{instruction} {input} " - prompt_no_input = "{instruction} " + system_prompt = "" + system_no_input_prompt = "" + turn_format = "{instruction} {input} " + turn_no_input_format = "{instruction} " def __init__(self): # pylint: disable=super-init-not-called pass diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 6408620d7..cf80539eb 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -87,7 +87,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): Tokenizing strategy for instruction-based prompts. """ - def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: + def parse_instruction_fields( + self, prompt + ) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]: raise NotImplementedError def tokenize_prompt(self, prompt): diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 29cc4446b..4db915238 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -24,6 +24,8 @@ class AlpacaPrompter: system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. 
Write a response that appropriately completes the request.\n\n" system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" + turn_format: str + turn_no_input_format: str prompt_style: Optional[PromptStyle] = None def __init__(self, prompt_style=PromptStyle.INSTRUCT.value): @@ -32,23 +34,13 @@ class AlpacaPrompter: def match_prompt_style(self): if self.prompt_style == PromptStyle.INSTRUCT.value: - self.prompt_input = ( - self.system_prompt - + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" + self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" + self.turn_no_input_format = ( + "### Instruction:\n{instruction}\n\n### Response:\n" ) - self.prompt_no_input = ( - self.system_no_input_prompt - + "### Instruction:\n{instruction}\n\n### Response:\n" - ) - self.response_split = "### Response:" if self.prompt_style == PromptStyle.CHAT.value: - self.prompt_input = ( - self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:" - ) - self.prompt_no_input = ( - self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:" - ) - self.response_split = "ASSISTANT:" + self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" + self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" def build_prompt( self, @@ -59,15 +51,39 @@ class AlpacaPrompter: # returns the full prompt from instruction and optional input # if a label (=response, =output) is provided, it's also appended. 
if input: - res = self.prompt_input.format(instruction=instruction, input=input) + res = self.system_prompt + self.turn_format.format( + instruction=instruction, input=input + ) else: - res = self.prompt_no_input.format(instruction=instruction) + res = self.system_no_input_prompt + self.turn_no_input_format.format( + instruction=instruction + ) if output: res = f"{res}{output}" yield res - def get_response(self, output: str) -> str: - return output.split(self.response_split)[1].strip() + +class SystemDataPrompter(AlpacaPrompter): + """ + Alpaca Style Prompter that uses system prompts from the dataset + """ + + def build_prompt_w_system( + self, + system: str, + instruction: str, + input: Union[None, str] = None, # pylint: disable=redefined-builtin + output: Union[None, str] = None, + ) -> Generator[str, None, None]: + # returns the full prompt from instruction and optional input + # if a label (=response, =output) is provided, it's also appended. + if input: + res = system + self.turn_format.format(instruction=instruction, input=input) + else: + res = system + self.turn_no_input_format.format(instruction=instruction) + if output: + res = f"{res}{output}" + yield res class UnpromptedPrompter(AlpacaPrompter): @@ -93,7 +109,10 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter): """ system_prompt = ( - "Choose the answer that best answers the question. Explain your reasoning." + "Choose the answer that best answers the question. Explain your reasoning.\n" + ) + system_no_input_prompt = ( + "Choose the answer that best answers the question. Explain your reasoning.\n" ) @@ -102,7 +121,12 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter): Prompter for multiple choice concise """ - prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n" + system_prompt = "Choose the answer that best answers the question. 
Be concise in your response.\n\n" + system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n" + + def match_prompt_style(self): + self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" + self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" class SummarizeTLDRPrompter(AlpacaPrompter): @@ -110,9 +134,12 @@ class SummarizeTLDRPrompter(AlpacaPrompter): Prompter for summarize TLDR """ - prompt_no_input = ( - "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:" - ) + system_prompt = "" + system_no_input_prompt = "" + + def match_prompt_style(self): + self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:" + self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:" class CompletionPrompter: @@ -128,9 +155,6 @@ class CompletionPrompter: ) -> Generator[str, None, None]: yield instruction - def get_response(self, output: str) -> str: - return output.strip() - class GPTeacherPrompter(AlpacaPrompter): """ @@ -210,9 +234,6 @@ class ReflectAlpacaPrompter: res = f"{res}{label}" yield res - def get_response(self, output: str) -> str: - return output.split(self.response_split)[1].strip() - class SeparatorStyle(Enum): """Different separator style.""" @@ -289,12 +310,6 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods sep2=" ", ) - # def match_prompt_style(self): - # if self.prompt_style == PromptStyle.chat.value: - # self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:" - # self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:" - # self.response_split = "ASSISTANT:" - def build_prompt(self, source) -> Generator[str, None, None]: # ignore the system prompt if provided if source[0]["from"] == "system": diff --git a/tests/test_prompters.py b/tests/test_prompters.py index 11610ccc5..bb33afbb6 100644 --- 
a/tests/test_prompters.py +++ b/tests/test_prompters.py @@ -2,7 +2,13 @@ import unittest -from axolotl.prompters import AlpacaPrompter, PromptStyle +from axolotl.prompters import ( + AlpacaPrompter, + MultipleChoiceExplainPrompter, + PromptStyle, + SystemDataPrompter, + UnpromptedPrompter, +) class AlpacaPrompterTest(unittest.TestCase): @@ -55,3 +61,64 @@ class AlpacaPrompterTest(unittest.TestCase): assert "### Response:" not in res assert "USER:" in res assert "ASSISTANT:" in res + + def test_system_prompt(self): + prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value) + res = next( + prompter.build_prompt_w_system( + "use cot", "tell me a joke about the following", "alpacas" + ) + ) + assert "use cot" in res + assert res.startswith("use cot") + assert "### Instruction:" not in res + assert "### Input:" not in res + assert "alpacas" in res + assert "### Response:" not in res + assert "USER:" in res + assert "ASSISTANT:" in res + + +class UnpromptedPrompterTest(unittest.TestCase): + """ + Test class for UnpromptedPrompter with no system prompts + """ + + def test_prompt_style_w_none(self): + prompter = UnpromptedPrompter(prompt_style=None) + res = next(prompter.build_prompt("tell me a joke")) + assert "### Instruction:" in res + assert "tell me a joke" in res + assert res.startswith("###") + + def test_prompt_style_w_instruct(self): + prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value) + res = next( + prompter.build_prompt("tell me a joke about the following", "alpacas") + ) + assert "### Instruction:" in res + assert "tell me a joke" in res + assert res.startswith("###") + + def test_prompt_style_w_chat(self): + prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value) + res = next( + prompter.build_prompt("tell me a joke about the following", "alpacas") + ) + assert "USER:" in res + assert "tell me a joke" in res + assert res.startswith("USER:") + + +class MultipleChoiceExplainPrompterTest(unittest.TestCase): + """ + Test 
class for MultipleChoiceExplainPrompter + """ + + def test_prompt_style_w_chat(self): + prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value) + res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C")) + assert "USER:" in res + assert "choose one" in res + assert "Choose the answer that best answers the question." in res + assert "- A\n- B\n- C" in res From 3a38271276224741fc9b2766b322a9bc54bba9c3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 17 Jun 2023 23:52:40 -0400 Subject: [PATCH 2/3] add tests and support for loader for sys prompt data --- .../prompt_strategies/alpaca_w_system.py | 83 +++++++++++++++++++ src/axolotl/prompters.py | 23 ----- src/axolotl/utils/tokenization.py | 2 + tests/test_prompt_tokenizers.py | 40 ++++++++- tests/test_prompters.py | 2 +- 5 files changed, 125 insertions(+), 25 deletions(-) create mode 100644 src/axolotl/prompt_strategies/alpaca_w_system.py diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py new file mode 100644 index 000000000..88acf0d0e --- /dev/null +++ b/src/axolotl/prompt_strategies/alpaca_w_system.py @@ -0,0 +1,83 @@ +""" +Prompt strategies loader for alpaca instruction datasets with system prompts +""" +from typing import Generator, Tuple, Union + +from axolotl.prompt_tokenizers import PromptTokenizingStrategy +from axolotl.prompters import AlpacaPrompter, PromptStyle + + +class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy): + """ + Tokenizing strategy for instruction-based prompts. 
+ """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]: + return ( + prompt["instruction"], + prompt["input"] if "input" in prompt else "", + prompt["output"], + prompt["system"], + ) + + def tokenize_prompt(self, prompt): + ( + instruction, + input, # pylint: disable=redefined-builtin + response, + system, + ) = self.parse_instruction_fields(prompt) + user_prompt = next( + iter( + self.prompter.build_prompt_w_system( + system, + instruction, + input, + ) + ) + ) + tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False) + if not self.train_on_inputs: + user_prompt_len = len(tokenized_prompt["input_ids"]) + # TODO this could be sped up using numpy array slicing + tokenized_prompt["labels"] = [-100] * user_prompt_len + tokenized_res_prompt = self._tokenize( + response, strip_bos_token=True, add_eos_token=True + ) + tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"] + tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"] + tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"] + + return tokenized_prompt + + +class SystemDataPrompter(AlpacaPrompter): + """ + Alpaca Style Prompter that uses system prompts from the dataset + """ + + def build_prompt_w_system( + self, + system: str, + instruction: str, + input: Union[None, str] = None, # pylint: disable=redefined-builtin + output: Union[None, str] = None, + ) -> Generator[str, None, None]: + # returns the full prompt from instruction and optional input + # if a label (=response, =output) is provided, it's also appended. 
+ if input: + res = system + self.turn_format.format(instruction=instruction, input=input) + else: + res = system + self.turn_no_input_format.format(instruction=instruction) + if output: + res = f"{res}{output}" + yield res + + +def load(tokenizer, cfg): + return InstructionWSystemPromptTokenizingStrategy( + SystemDataPrompter(PromptStyle.CHAT.value), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 4db915238..715a227c8 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -63,29 +63,6 @@ class AlpacaPrompter: yield res -class SystemDataPrompter(AlpacaPrompter): - """ - Alpaca Style Prompter that uses system prompts from the dataset - """ - - def build_prompt_w_system( - self, - system: str, - instruction: str, - input: Union[None, str] = None, # pylint: disable=redefined-builtin - output: Union[None, str] = None, - ) -> Generator[str, None, None]: - # returns the full prompt from instruction and optional input - # if a label (=response, =output) is provided, it's also appended. 
- if input: - res = system + self.turn_format.format(instruction=instruction, input=input) - else: - res = system + self.turn_no_input_format.format(instruction=instruction) - if output: - res = f"{res}{output}" - yield res - - class UnpromptedPrompter(AlpacaPrompter): """ Prompter for alpaca no system prompt diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index 1c535eb1b..7d0d1dd83 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -34,3 +34,5 @@ def check_example_labels(example, tokenizer): logging.info(" ".join(colored_tokens)) logging.info("\n\n\n") + + return " ".join(colored_tokens) diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index aba340eee..3ddbe77bf 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -7,11 +7,15 @@ from pathlib import Path from transformers import AutoTokenizer from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter +from axolotl.prompt_strategies.alpaca_w_system import ( + InstructionWSystemPromptTokenizingStrategy, + SystemDataPrompter, +) from axolotl.prompt_tokenizers import ( AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, ) -from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter +from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter logging.basicConfig(level="INFO") @@ -96,5 +100,39 @@ class TestPromptTokenizationStrategies(unittest.TestCase): assert example["labels"][world_idx - 1] == -100 +class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase): + """ + Test class for prompt tokenization strategies with sys prompt from the dataset + """ + + def setUp(self) -> None: + # pylint: disable=duplicate-code + self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") + self.tokenizer.add_special_tokens( + { + "bos_token": "", + "eos_token": "", + "unk_token": "", + } + ) + + def test_system_alpaca(self): + prompter 
= SystemDataPrompter(PromptStyle.CHAT.value) + strat = InstructionWSystemPromptTokenizingStrategy( + prompter, + self.tokenizer, + False, + 2048, + ) + sample = { + "system": "use cot", + "instruction": "hello!", + "output": "Hi! How can I help?", + } + example = strat.tokenize_prompt(sample) + assert example["input_ids"][0:3] == [1, 671, 20118] # use cot + assert example["input_ids"][3] == 11889 # USER + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_prompters.py b/tests/test_prompters.py index bb33afbb6..756b6f81b 100644 --- a/tests/test_prompters.py +++ b/tests/test_prompters.py @@ -2,11 +2,11 @@ import unittest +from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter from axolotl.prompters import ( AlpacaPrompter, MultipleChoiceExplainPrompter, PromptStyle, - SystemDataPrompter, UnpromptedPrompter, ) From 7b57ed761882b4492659eeafffbf8ffddd3f0fbb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 18 Jun 2023 06:40:28 -0400 Subject: [PATCH 3/3] pylint for duplicated code for system prompts --- src/axolotl/datasets.py | 1 + src/axolotl/prompt_strategies/alpaca_w_system.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 40c58bc9c..5593a8dd3 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -126,6 +126,7 @@ class ConstantLengthDataset(IterableDataset): buffer_len = 0 if example: + # FIXME # just going to drop data points that are too long if len(example["input_ids"]) <= self.seq_length: input_ids = example["input_ids"] diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py index 88acf0d0e..aacae8739 100644 --- a/src/axolotl/prompt_strategies/alpaca_w_system.py +++ b/src/axolotl/prompt_strategies/alpaca_w_system.py @@ -21,6 +21,7 @@ class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy): ) def tokenize_prompt(self, prompt): + # pylint: disable=duplicate-code ( instruction, 
input, # pylint: disable=redefined-builtin