plain input/output prompt strategy w/o chat templates (#1346)
* plain input/output prompt strategy w/o chat templates * disable duplicate code check * make sure to add an eos/eot token to the end of the output so it will stop * multi turn segment support and test
This commit is contained in:
116
tests/prompt_strategies/test_raw_io.py
Normal file
116
tests/prompt_strategies/test_raw_io.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Test module for raw i/o data for prompts
|
||||
"""
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from tokenizers import AddedToken
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies.input_output import (
|
||||
RawInputOutputPrompter,
|
||||
RawInputOutputStrategy,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="segments_dataset")
def fixture_sharegpt_dataset():
    """
    Single-row dataset whose "segments" column alternates unlabeled and
    labeled pieces of text
    """
    # (label, text) pairs in the order they appear in the row; the first
    # piece carries a literal "<s>" and the labeled pieces end in "<eot>".
    turns = [
        (False, "<s>hello "),
        (True, "hi there.<eot>"),
        (False, "goodbye "),
        (True, "farewell<eot>"),
    ]
    segments = [{"label": label, "text": text} for label, text in turns]
    return Dataset.from_list([{"segments": segments}])
|
||||
|
||||
|
||||
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    """
    Mistral-7B tokenizer with "<eot>" registered as an added token
    """
    # "<eot>" is added without stripping or normalization so the literal
    # marker used in the fixture text is matched as-is.
    eot_token = AddedToken("<eot>", rstrip=False, lstrip=False, normalized=False)

    mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    mistral_tokenizer.add_tokens([eot_token])

    return mistral_tokenizer
|
||||
|
||||
|
||||
class TestRawInputOutputPrompts:
    """
    Tests for the raw input/output prompt strategy
    """

    def test_segment_prompts(self, segments_dataset, tokenizer):
        """
        Tokenize the multi-segment row and check both the token ids and the
        label mask produced with train_on_inputs disabled
        """
        prompter = RawInputOutputPrompter()
        strategy = RawInputOutputStrategy(
            prompter,
            tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )

        tokenized = TokenizedPromptDataset(
            strategy, segments_dataset, process_count=1
        )
        row = tokenized[0]
        input_ids = row["input_ids"]
        labels = row["labels"]

        decoded = tokenizer.decode(input_ids)
        assert decoded == "<s> hello hi there.<eot> goodbye farewell<eot>"

        # One entry per token: (expected token id, expected label).
        # A label of -100 marks a position from a label=False segment; the
        # other positions are expected to carry the token id itself.
        # fmt: off
        expected = [
            (1,     -100),   # <s>
            (6312,  -100),   # hell
            (28709, -100),   # o
            (28705, -100),   # space
            (12014, 12014),  # hi
            (736,   736),    # there
            (28723, 28723),  # .
            (32000, 32000),  # <eot>
            (1179,  -100),   # good
            (17664, -100),   # bye
            (28705, -100),   # space
            (19111, 19111),  # fare
            (5458,  5458),   # well
            (32000, 32000),  # <eot>
        ]
        # fmt: on

        assert input_ids == [token_id for token_id, _ in expected]
        assert labels == [label for _, label in expected]
|
||||
Reference in New Issue
Block a user