plain input/output prompt strategy w/o chat templates (#1346)
* plain input/output prompt strategy w/o chat templates * disable duplicate code check * make sure to add an eos/eot token to the end of the output so it will stop * multi-turn segment support and test
This commit is contained in:
54
src/axolotl/prompt_strategies/input_output.py
Normal file
54
src/axolotl/prompt_strategies/input_output.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Module for plain input/output prompt pairs"""
|
||||
from typing import Generator, Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||
|
||||
|
||||
class RawInputOutputStrategy(PromptTokenizingStrategy):
    """Tokenizing strategy for template-free, pre-segmented input/output data.

    A sample carries a ``segments`` list. Each segment is tokenized verbatim
    (no special tokens are added) and the results are concatenated in order.
    """

    def __init__(self, *args, eos_token=None, **kwargs):
        """Forward all arguments to the base strategy and record an EOS token.

        ``eos_token`` is keyword-only; when it is missing or empty the
        tokenizer's own ``eos_token`` is stored instead.
        """
        super().__init__(*args, **kwargs)
        # NOTE(review): self.tokenizer is presumably set by the base
        # __init__ -- it is not assigned anywhere in this class.
        self.eos_token = eos_token or self.tokenizer.eos_token

    def tokenize_prompt(self, prompt):
        """Tokenize ``prompt["segments"]`` into ids, labels and attention mask.

        Segments flagged with a truthy label keep their token ids as labels.
        Other segments are masked with ``IGNORE_TOKEN_ID`` unless
        ``self.train_on_inputs`` is truthy.

        Returns a dict with ``input_ids``, ``labels`` and ``attention_mask``
        (all ones), each a flat list of the same length.
        """
        # pylint: disable=duplicate-code
        token_ids = []
        label_ids = []
        segments = self.prompter.build_prompt(prompt["segments"])
        for is_trainable, segment_text in segments:
            encoded = self.tokenizer(
                segment_text, add_special_tokens=False, return_tensors=None
            )
            segment_ids = encoded["input_ids"]
            token_ids.extend(segment_ids)
            if not (is_trainable or self.train_on_inputs):
                # Input-only segment: keep the tokens but hide them from the loss.
                label_ids.extend([IGNORE_TOKEN_ID] * len(segment_ids))
            else:
                label_ids.extend(segment_ids)

        return {
            "input_ids": token_ids,
            "labels": label_ids,
            "attention_mask": [1] * len(token_ids),
        }
|
||||
|
||||
|
||||
class RawInputOutputPrompter(Prompter):
    """Prompter that passes raw input/output segments through unchanged."""

    def build_prompt(self, source) -> Generator[Tuple[bool, str], None, None]:
        """Yield a ``(label, text)`` pair for every segment in ``source``.

        Each segment must be subscriptable with the keys ``"label"`` and
        ``"text"``; the values are yielded as-is, in iteration order.
        """
        yield from ((piece["label"], piece["text"]) for piece in source)
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
    """Build the raw input/output tokenizing strategy for ``tokenizer``.

    Reads ``train_on_inputs`` and ``sequence_len`` from ``cfg`` and passes
    them, together with a fresh ``RawInputOutputPrompter``, to
    ``RawInputOutputStrategy``.
    """
    prompter = RawInputOutputPrompter()
    strategy = RawInputOutputStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
    return strategy
|
||||
116
tests/prompt_strategies/test_raw_io.py
Normal file
116
tests/prompt_strategies/test_raw_io.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Test module for raw i/o data for prompts
|
||||
"""
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from tokenizers import AddedToken
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies.input_output import (
|
||||
RawInputOutputPrompter,
|
||||
RawInputOutputStrategy,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="segments_dataset")
def fixture_sharegpt_dataset():
    """Single-row dataset holding a two-turn conversation as raw segments.

    Segments alternate between untrained input (``label`` False) and trained
    output (``label`` True); outputs end with the ``<eot>`` marker.
    """
    turns = [
        (False, "<s>hello "),
        (True, "hi there.<eot>"),
        (False, "goodbye "),
        (True, "farewell<eot>"),
    ]
    segments = [{"label": label, "text": text} for label, text in turns]
    return Dataset.from_list([{"segments": segments}])
|
||||
|
||||
|
||||
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    """Mistral-7B tokenizer extended with an ``<eot>`` added token."""
    tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    eot_token = AddedToken("<eot>", rstrip=False, lstrip=False, normalized=False)
    tok.add_tokens([eot_token])

    return tok
|
||||
|
||||
|
||||
class TestRawInputOutputPrompts:
    """
    Tests for the raw input/output prompt strategy
    """

    def test_segment_prompts(self, segments_dataset, tokenizer):
        """Multi-turn segments tokenize in order and only outputs are trained on."""
        prompter = RawInputOutputPrompter()
        strategy = RawInputOutputStrategy(
            prompter,
            tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )

        tokenized = TokenizedPromptDataset(
            strategy, segments_dataset, process_count=1
        )

        sample = tokenized[0]
        input_ids = sample["input_ids"]
        labels = sample["labels"]

        expected_text = "<s> hello hi there.<eot> goodbye farewell<eot>"
        assert tokenizer.decode(input_ids) == expected_text

        # fmt: off
        expected_input_ids = [
            1,      # <s>
            6312,   # hell
            28709,  # o
            28705,  # trailing space of the first input segment
            12014,  # hi
            736,    # there
            28723,  # .
            32000,  # <eot>
            1179,   # good
            17664,  # bye
            28705,  # trailing space of the second input segment
            19111,  # fare
            5458,   # well
            32000,  # <eot>
        ]
        # fmt: on
        assert input_ids == expected_input_ids

        # Input segments are masked with -100; output segments keep their ids.
        # fmt: off
        expected_labels = [
            -100,   # <s>
            -100,   # hell
            -100,   # o
            -100,   # trailing space of the first input segment
            12014,  # hi
            736,    # there
            28723,  # .
            32000,  # <eot>
            -100,   # good
            -100,   # bye
            -100,   # trailing space of the second input segment
            19111,  # fare
            5458,   # well
            32000,  # <eot>
        ]
        # fmt: on
        assert labels == expected_labels
|
||||
Reference in New Issue
Block a user