From 3d4984b9a54ee27f21dec7fa4f94fc6bd1431ddd Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 22 Jul 2023 13:49:11 -0400 Subject: [PATCH] update prompts for open orca to match the paper (#317) fix the test for the updated system tokenizer --- .../prompt_strategies/alpaca_w_system.py | 25 ++++++++++++++++--- tests/test_prompt_tokenizers.py | 5 ++-- tests/test_prompters.py | 2 +- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py index 7e2d2ad42..ea7151366 100644 --- a/src/axolotl/prompt_strategies/alpaca_w_system.py +++ b/src/axolotl/prompt_strategies/alpaca_w_system.py @@ -66,15 +66,34 @@ class SystemDataPrompter(AlpacaPrompter): ) -> Generator[str, None, None]: # returns the full prompt from instruction and optional input # if a label (=response, =output) is provided, it's also appended. + formatted_sys_prompt = f"### System:\n{system}\n\n" if system else "" if input: - res = system + self.turn_format.format(instruction=instruction, input=input) + res = formatted_sys_prompt + self.turn_format.format( + instruction=instruction, input=input + ) else: - res = system + self.turn_no_input_format.format(instruction=instruction) + res = formatted_sys_prompt + self.turn_no_input_format.format( + instruction=instruction + ) if output: res = f"{res}{output}" yield res +class OpenOrcaSystemDataPrompter(SystemDataPrompter): + """ + Alpaca Style Prompter that uses system prompts from the dataset, with OpenOrca prompts + """ + + def match_prompt_style(self): + if self.prompt_style == PromptStyle.INSTRUCT.value: + self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n" + self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n" + if self.prompt_style == PromptStyle.CHAT.value: + self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" + self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" + + class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy): """ Tokenizing strategy for OpenOrca datasets @@ -113,7 +132,7 @@ def load_chat(tokenizer, cfg): def load_open_orca(tokenizer, cfg): return OpenOrcaPromptTokenizingStrategy( - SystemDataPrompter(PromptStyle.INSTRUCT.value), + OpenOrcaSystemDataPrompter(PromptStyle.INSTRUCT.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len, diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index a3e4cdbdf..0b9545f43 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -130,8 +130,9 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase): "output": "Hi! How can I help?", } example = strat.tokenize_prompt(sample) - assert example["input_ids"][0:3] == [1, 671, 20118] # use cot - assert example["input_ids"][3] == 11889 # USER + assert example["input_ids"][0:4] == [1, 835, 2184, 29901] # "### System:" + assert example["input_ids"][5:7] == [1509, 20118] # "use cot" + assert example["input_ids"][9] == 11889 # USER if __name__ == "__main__": diff --git a/tests/test_prompters.py b/tests/test_prompters.py index 756b6f81b..112f25d33 100644 --- a/tests/test_prompters.py +++ b/tests/test_prompters.py @@ -70,7 +70,7 @@ class AlpacaPrompterTest(unittest.TestCase): ) ) assert "use cot" in res - assert res.startswith("use cot") + assert res.startswith("### System:") assert "### Instruction:" not in res assert "### Input:" not in res assert "alpacas" in res