initial wip to get sys prompt from dataset
This commit is contained in:
@@ -2,7 +2,13 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
from axolotl.prompters import (
|
||||
AlpacaPrompter,
|
||||
MultipleChoiceExplainPrompter,
|
||||
PromptStyle,
|
||||
SystemDataPrompter,
|
||||
UnpromptedPrompter,
|
||||
)
|
||||
|
||||
|
||||
class AlpacaPrompterTest(unittest.TestCase):
|
||||
@@ -55,3 +61,64 @@ class AlpacaPrompterTest(unittest.TestCase):
|
||||
assert "### Response:" not in res
|
||||
assert "USER:" in res
|
||||
assert "ASSISTANT:" in res
|
||||
|
||||
def test_system_prompt(self):
|
||||
prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)
|
||||
res = next(
|
||||
prompter.build_prompt_w_system(
|
||||
"use cot", "tell me a joke about the following", "alpacas"
|
||||
)
|
||||
)
|
||||
assert "use cot" in res
|
||||
assert res.startswith("use cot")
|
||||
assert "### Instruction:" not in res
|
||||
assert "### Input:" not in res
|
||||
assert "alpacas" in res
|
||||
assert "### Response:" not in res
|
||||
assert "USER:" in res
|
||||
assert "ASSISTANT:" in res
|
||||
|
||||
|
||||
class UnpromptedPrompterTest(unittest.TestCase):
|
||||
"""
|
||||
Test class for UnpromptedPrompter with no system prompts
|
||||
"""
|
||||
|
||||
def test_prompt_style_w_none(self):
|
||||
prompter = UnpromptedPrompter(prompt_style=None)
|
||||
res = next(prompter.build_prompt("tell me a joke"))
|
||||
assert "### Instruction:" in res
|
||||
assert "tell me a joke" in res
|
||||
assert res.startswith("###")
|
||||
|
||||
def test_prompt_style_w_instruct(self):
|
||||
prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)
|
||||
res = next(
|
||||
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
||||
)
|
||||
assert "### Instruction:" in res
|
||||
assert "tell me a joke" in res
|
||||
assert res.startswith("###")
|
||||
|
||||
def test_prompt_style_w_chat(self):
|
||||
prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)
|
||||
res = next(
|
||||
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
||||
)
|
||||
assert "USER:" in res
|
||||
assert "tell me a joke" in res
|
||||
assert res.startswith("USER:")
|
||||
|
||||
|
||||
class MultipleChoiceExplainPrompterTest(unittest.TestCase):
|
||||
"""
|
||||
Test class for MultipleChoiceExplainPrompter
|
||||
"""
|
||||
|
||||
def test_prompt_style_w_chat(self):
|
||||
prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)
|
||||
res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C"))
|
||||
assert "USER:" in res
|
||||
assert "choose one" in res
|
||||
assert "Choose the answer that best answers the question." in res
|
||||
assert "- A\n- B\n- C" in res
|
||||
|
||||
Reference in New Issue
Block a user