fix pylint
This commit is contained in:
@@ -3,12 +3,11 @@ Test module for alpacha integration w chatml
|
||||
"""
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from tokenizers import AddedToken
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter
|
||||
from utils import fixture_tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="alpacha_dataset")
|
||||
@@ -24,25 +23,6 @@ def fixture_alpacha_dataset():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    """
    Provide a Mistral-7B tokenizer extended with the ChatML delimiter tokens.

    ``<|im_end|>`` is registered as the ``eos_token`` special token and
    ``<|im_start|>`` is added as an additional token. Both are created with
    ``rstrip``, ``lstrip`` and ``normalized`` all set to ``False``.
    """
    # Options shared by both ChatML delimiter tokens.
    token_opts = {"rstrip": False, "lstrip": False, "normalized": False}
    im_end = AddedToken("<|im_end|>", **token_opts)
    im_start = AddedToken("<|im_start|>", **token_opts)

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    # The EOS special token is registered first, then the extra start token.
    tokenizer.add_special_tokens({"eos_token": im_end})
    tokenizer.add_tokens([im_start])
    return tokenizer
|
||||
|
||||
|
||||
class TestAlpacha:
|
||||
"""
|
||||
Test class for alpacha prompter
|
||||
|
||||
@@ -3,12 +3,11 @@ Test module for sharegpt integration w chatml
|
||||
"""
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from tokenizers import AddedToken
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies.sharegpt import SimpleShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
from utils import fixture_tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="sharegpt_dataset")
|
||||
@@ -43,25 +42,6 @@ def fixture_sharegpt_dataset():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    """
    Provide a Mistral-7B tokenizer extended with the ChatML delimiter tokens.

    ``<|im_end|>`` is registered as the ``eos_token`` special token and
    ``<|im_start|>`` is added as an additional token. Both are created with
    ``rstrip``, ``lstrip`` and ``normalized`` all set to ``False``.
    """
    # Options shared by both ChatML delimiter tokens.
    token_opts = {"rstrip": False, "lstrip": False, "normalized": False}
    im_end = AddedToken("<|im_end|>", **token_opts)
    im_start = AddedToken("<|im_start|>", **token_opts)

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    # The EOS special token is registered first, then the extra start token.
    tokenizer.add_special_tokens({"eos_token": im_end})
    tokenizer.add_tokens([im_start])
    return tokenizer
|
||||
|
||||
|
||||
class TestSharegpt:
|
||||
"""
|
||||
Test class for sharegpt prompter
|
||||
@@ -96,7 +76,7 @@ class TestSharegpt:
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
|
||||
def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="chatml",
|
||||
@@ -124,7 +104,7 @@ class TestSharegpt:
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
|
||||
def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="chatml",
|
||||
|
||||
Reference in New Issue
Block a user