diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 0ffa3e55f..13ff450f8 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -65,8 +65,10 @@ class AlpacaPrompter(Prompter): self.system_format = "<|im_start|>system\n{system}<|im_end|>\n" elif self.prompt_style == PromptStyle.PHI.value: self.turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>" - self.turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>" - self.system_format = "<|system|>{system}\n" + self.turn_no_input_format = ( + "<|user|>\n{instruction}<|end|>\n<|assistant|>\n" + ) + self.system_format = "<|system|>\n{system}<|end|>\n" def _build_result(self, instruction, input_text, output): # returns the full prompt from instruction and optional input diff --git a/tests/test_prompters.py b/tests/test_prompters.py index 6c5b8f27c..3d61398e0 100644 --- a/tests/test_prompters.py +++ b/tests/test_prompters.py @@ -42,6 +42,19 @@ class AlpacaPrompterTest(unittest.TestCase): assert "USER:" not in res assert "ASSISTANT:" not in res + def test_prompt_style_w_phi(self): + prompter = AlpacaPrompter(prompt_style=PromptStyle.PHI.value) + res = next(prompter.build_prompt("tell me a joke about the following")) + assert ( + """<|system|> +Below is an instruction that describes a task. Write a response that appropriately completes the request.<|end|> +<|user|> +tell me a joke about the following<|end|> +<|assistant|> +""" + == res + ) + def test_prompt_style_w_chat(self): prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value) res = next(