Lint test_prompters
This commit is contained in:
@@ -1,9 +1,15 @@
|
||||
"""Module testing prompters"""
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
|
||||
|
||||
class AlpacaPrompterTest(unittest.TestCase):
|
||||
"""
|
||||
Test AlpacaPrompter
|
||||
"""
|
||||
|
||||
def test_prompt_style_w_none(self):
|
||||
prompter = AlpacaPrompter(prompt_style=None)
|
||||
res = next(prompter.build_prompt("tell me a joke"))
|
||||
@@ -11,7 +17,7 @@ class AlpacaPrompterTest(unittest.TestCase):
|
||||
assert "### Instruction:" in res
|
||||
|
||||
def test_prompt_style_w_instruct(self):
|
||||
prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value)
|
||||
prompter = AlpacaPrompter(prompt_style=PromptStyle.INSTRUCT.value)
|
||||
res = next(
|
||||
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
||||
)
|
||||
@@ -31,7 +37,7 @@ class AlpacaPrompterTest(unittest.TestCase):
|
||||
assert "ASSISTANT:" not in res
|
||||
|
||||
def test_prompt_style_w_chat(self):
|
||||
prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value)
|
||||
prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value)
|
||||
res = next(
|
||||
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user