diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 000000000..8e5836f97 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,26 @@ +name: PyTest +on: push + +jobs: + test: + runs-on: ubuntu-latest + timeout-minutes: 10 + + steps: + - name: Check out repository code + uses: actions/checkout@v2 + + - name: Setup Python + uses: actions/setup-python@v2 + with: + python-version: "3.9" + cache: 'pip' # caching pip dependencies + + - name: Install dependencies + run: | + pip install -e . + pip install -r requirements-tests.txt + + - name: Run tests + run: | + pytest tests/ diff --git a/requirements-tests.txt b/requirements-tests.txt new file mode 100644 index 000000000..e079f8a60 --- /dev/null +++ b/requirements-tests.txt @@ -0,0 +1 @@ +pytest diff --git a/tests/test_prompters.py b/tests/test_prompters.py new file mode 100644 index 000000000..1c3c13852 --- /dev/null +++ b/tests/test_prompters.py @@ -0,0 +1,49 @@ +import unittest + +from axolotl.prompters import AlpacaPrompter, PromptStyle + + +class AlpacaPrompterTest(unittest.TestCase): + def test_prompt_style_w_none(self): + prompter = AlpacaPrompter(prompt_style=None) + res = next(prompter.build_prompt("tell me a joke")) + # just testing that it uses instruct style + assert "### Instruction:" in res + + def test_prompt_style_w_instruct(self): + prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value) + res = next(prompter.build_prompt("tell me a joke about the following", "alpacas")) + assert "Below is an instruction" in res + assert "### Instruction:" in res + assert "### Input:" in res + assert "alpacas" in res + assert "### Response:" in res + assert "USER:" not in res + assert "ASSISTANT:" not in res + res = next(prompter.build_prompt("tell me a joke about the following")) + assert "Below is an instruction" in res + assert "### Instruction:" in res + assert "### Input:" not in res + assert "### Response:" in res + assert "USER:" not in res + assert "ASSISTANT:" not in res + + def test_prompt_style_w_chat(self): + prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value) + res = next(prompter.build_prompt("tell me a joke about the following", "alpacas")) + assert "Below is an instruction" in res + assert "### Instruction:" not in res + assert "### Input:" not in res + assert "alpacas" in res + assert "### Response:" not in res + assert "USER:" in res + assert "ASSISTANT:" in res + res = next(prompter.build_prompt("tell me a joke about the following")) + assert "Below is an instruction" in res + assert "### Instruction:" not in res + assert "### Input:" not in res + assert "### Response:" not in res + assert "USER:" in res + assert "ASSISTANT:" in res + +