diff --git a/tests/test_prompters.py b/tests/test_prompters.py index b4a34c6c0..11610ccc5 100644 --- a/tests/test_prompters.py +++ b/tests/test_prompters.py @@ -1,9 +1,15 @@ +"""Module testing prompters""" + import unittest from axolotl.prompters import AlpacaPrompter, PromptStyle class AlpacaPrompterTest(unittest.TestCase): + """ + Test AlpacaPrompter + """ + def test_prompt_style_w_none(self): prompter = AlpacaPrompter(prompt_style=None) res = next(prompter.build_prompt("tell me a joke")) @@ -11,7 +17,7 @@ class AlpacaPrompterTest(unittest.TestCase): assert "### Instruction:" in res def test_prompt_style_w_instruct(self): - prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value) + prompter = AlpacaPrompter(prompt_style=PromptStyle.INSTRUCT.value) res = next( prompter.build_prompt("tell me a joke about the following", "alpacas") ) @@ -31,7 +37,7 @@ class AlpacaPrompterTest(unittest.TestCase): assert "ASSISTANT:" not in res def test_prompt_style_w_chat(self): - prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value) + prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value) res = next( prompter.build_prompt("tell me a joke about the following", "alpacas") )