plain input/output prompt strategy w/o chat templates (#1346)
* plain input/output prompt strategy w/o chat templates * disable duplicate code check * make sure to add an eos/eot token to the end of the output so it will stop * multi turn segement support and test
This commit is contained in:
54
src/axolotl/prompt_strategies/input_output.py
Normal file
54
src/axolotl/prompt_strategies/input_output.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""Module for plain input/output prompt pairs"""
|
||||||
|
from typing import Generator, Tuple
|
||||||
|
|
||||||
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
|
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||||
|
|
||||||
|
|
||||||
|
class RawInputOutputStrategy(PromptTokenizingStrategy):
|
||||||
|
"""Prompt Strategy class for input/output pairs"""
|
||||||
|
|
||||||
|
def __init__(self, *args, eos_token=None, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.eos_token = eos_token
|
||||||
|
if not eos_token:
|
||||||
|
self.eos_token = self.tokenizer.eos_token
|
||||||
|
|
||||||
|
def tokenize_prompt(self, prompt):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
input_ids = []
|
||||||
|
labels = []
|
||||||
|
for label, text in self.prompter.build_prompt(prompt["segments"]):
|
||||||
|
tokenized_output = self.tokenizer(
|
||||||
|
text, add_special_tokens=False, return_tensors=None
|
||||||
|
)["input_ids"]
|
||||||
|
input_ids += tokenized_output
|
||||||
|
if label or self.train_on_inputs:
|
||||||
|
labels += tokenized_output
|
||||||
|
else:
|
||||||
|
labels += [IGNORE_TOKEN_ID] * len(tokenized_output)
|
||||||
|
|
||||||
|
tokenized_prompt = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"labels": labels,
|
||||||
|
"attention_mask": [1] * len(input_ids),
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenized_prompt
|
||||||
|
|
||||||
|
|
||||||
|
class RawInputOutputPrompter(Prompter):
|
||||||
|
"""prompter for raw i/o data"""
|
||||||
|
|
||||||
|
def build_prompt(self, source) -> Generator[Tuple[bool, str], None, None]:
|
||||||
|
for segment in source:
|
||||||
|
yield segment["label"], segment["text"]
|
||||||
|
|
||||||
|
|
||||||
|
def load(tokenizer, cfg):
|
||||||
|
return RawInputOutputStrategy(
|
||||||
|
RawInputOutputPrompter(),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
116
tests/prompt_strategies/test_raw_io.py
Normal file
116
tests/prompt_strategies/test_raw_io.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
"""
|
||||||
|
Test module for raw i/o data for prompts
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from datasets import Dataset
|
||||||
|
from tokenizers import AddedToken
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from axolotl.datasets import TokenizedPromptDataset
|
||||||
|
from axolotl.prompt_strategies.input_output import (
|
||||||
|
RawInputOutputPrompter,
|
||||||
|
RawInputOutputStrategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="segments_dataset")
|
||||||
|
def fixture_sharegpt_dataset():
|
||||||
|
return Dataset.from_list(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"segments": [
|
||||||
|
{
|
||||||
|
"label": False,
|
||||||
|
"text": "<s>hello ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"label": True,
|
||||||
|
"text": "hi there.<eot>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"label": False,
|
||||||
|
"text": "goodbye ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"label": True,
|
||||||
|
"text": "farewell<eot>",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="tokenizer")
|
||||||
|
def fixture_tokenizer():
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||||
|
tokenizer.add_tokens(
|
||||||
|
[
|
||||||
|
AddedToken("<eot>", rstrip=False, lstrip=False, normalized=False),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class TestRawInputOutputPrompts:
|
||||||
|
"""
|
||||||
|
Test class for raw i/o prompter
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_segment_prompts(self, segments_dataset, tokenizer):
|
||||||
|
strategy = RawInputOutputStrategy(
|
||||||
|
RawInputOutputPrompter(),
|
||||||
|
tokenizer,
|
||||||
|
False, # train_on_inputs
|
||||||
|
2048, # sequence_len
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_wrapper = TokenizedPromptDataset(
|
||||||
|
strategy, segments_dataset, process_count=1
|
||||||
|
)
|
||||||
|
|
||||||
|
input_ids = dataset_wrapper[0]["input_ids"]
|
||||||
|
labels = dataset_wrapper[0]["labels"]
|
||||||
|
|
||||||
|
assert (
|
||||||
|
tokenizer.decode(input_ids)
|
||||||
|
== "<s> hello hi there.<eot> goodbye farewell<eot>"
|
||||||
|
)
|
||||||
|
# fmt: off
|
||||||
|
assert input_ids == [
|
||||||
|
1, # <s>
|
||||||
|
6312, # hell
|
||||||
|
28709, # o
|
||||||
|
28705, #
|
||||||
|
12014, # hi
|
||||||
|
736, # there
|
||||||
|
28723, # .
|
||||||
|
32000, # <eot>
|
||||||
|
1179, # good
|
||||||
|
17664, # bye
|
||||||
|
28705, #
|
||||||
|
19111, # fare
|
||||||
|
5458, # well
|
||||||
|
32000, # <eot>
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
assert labels == [
|
||||||
|
-100, # <s>
|
||||||
|
-100, # hell
|
||||||
|
-100, # o
|
||||||
|
-100, #
|
||||||
|
12014, # hi
|
||||||
|
736, # there
|
||||||
|
28723, # .
|
||||||
|
32000, # <eot>
|
||||||
|
-100, # good
|
||||||
|
-100, # bye
|
||||||
|
-100, #
|
||||||
|
19111, # fare
|
||||||
|
5458, # well
|
||||||
|
32000, # <eot>
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
Reference in New Issue
Block a user