diff --git a/tests/prompt_strategies/test_alpacha.py b/tests/prompt_strategies/test_alpacha.py
index 882307d69..dd618be38 100644
--- a/tests/prompt_strategies/test_alpacha.py
+++ b/tests/prompt_strategies/test_alpacha.py
@@ -3,12 +3,11 @@ Test module for alpacha integration w chatml
 """
 import pytest
 from datasets import Dataset
-from tokenizers import AddedToken
-from transformers import AutoTokenizer
 
 from axolotl.datasets import TokenizedPromptDataset
 from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
 from axolotl.prompters import AlpacaPrompter
+from utils import fixture_tokenizer
 
 
 @pytest.fixture(name="alpacha_dataset")
@@ -24,25 +23,6 @@ def fixture_alpacha_dataset():
     )
 
 
-@pytest.fixture(name="tokenizer")
-def fixture_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
-    tokenizer.add_special_tokens(
-        {
-            "eos_token": AddedToken(
-                "<|im_end|>", rstrip=False, lstrip=False, normalized=False
-            )
-        }
-    )
-    tokenizer.add_tokens(
-        [
-            AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
-        ]
-    )
-
-    return tokenizer
-
-
 class TestAlpacha:
     """
     Test class for alpacha prompter
diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py
index ee62ab5d0..5be5c1012 100644
--- a/tests/prompt_strategies/test_sharegpt.py
+++ b/tests/prompt_strategies/test_sharegpt.py
@@ -3,12 +3,11 @@ Test module for sharegpt integration w chatml
 """
 import pytest
 from datasets import Dataset
-from tokenizers import AddedToken
-from transformers import AutoTokenizer
 
 from axolotl.datasets import TokenizedPromptDataset
 from axolotl.prompt_strategies.sharegpt import SimpleShareGPTPromptTokenizingStrategy
 from axolotl.prompters import ShareGPTPrompterV2
+from utils import fixture_tokenizer
 
 
 @pytest.fixture(name="sharegpt_dataset")
@@ -43,25 +42,6 @@ def fixture_sharegpt_dataset():
     )
 
 
-@pytest.fixture(name="tokenizer")
-def fixture_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
-    tokenizer.add_special_tokens(
-        {
-            "eos_token": AddedToken(
-                "<|im_end|>", rstrip=False, lstrip=False, normalized=False
-            )
-        }
-    )
-    tokenizer.add_tokens(
-        [
-            AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
-        ]
-    )
-
-    return tokenizer
-
-
 class TestSharegpt:
     """
     Test class for sharegpt prompter
@@ -96,7 +76,7 @@ class TestSharegpt:
         ]
         # fmt: on
 
-    def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
+    def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
         strategy = SimpleShareGPTPromptTokenizingStrategy(
             ShareGPTPrompterV2(
                 conversation="chatml",
@@ -124,7 +104,7 @@ class TestSharegpt:
         ]
         # fmt: on
 
-    def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
+    def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
         strategy = SimpleShareGPTPromptTokenizingStrategy(
             ShareGPTPrompterV2(
                 conversation="chatml",