Lint test_prompters

This commit is contained in:
NanoCode012
2023-05-29 14:02:43 +09:00
parent 1645a4ddd5
commit 7eb33a77dd

View File

@@ -1,9 +1,15 @@
"""Module testing prompters"""
import unittest
from axolotl.prompters import AlpacaPrompter, PromptStyle
class AlpacaPrompterTest(unittest.TestCase):
"""
Test AlpacaPrompter
"""
def test_prompt_style_w_none(self):
prompter = AlpacaPrompter(prompt_style=None)
res = next(prompter.build_prompt("tell me a joke"))
@@ -11,7 +17,7 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "### Instruction:" in res
def test_prompt_style_w_instruct(self):
prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value)
prompter = AlpacaPrompter(prompt_style=PromptStyle.INSTRUCT.value)
res = next(
prompter.build_prompt("tell me a joke about the following", "alpacas")
)
@@ -31,7 +37,7 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "ASSISTANT:" not in res
def test_prompt_style_w_chat(self):
prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value)
prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value)
res = next(
prompter.build_prompt("tell me a joke about the following", "alpacas")
)