"""Module for plain input/output prompt pairs"""
from typing import Generator, Tuple

from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter


class RawInputOutputStrategy(PromptTokenizingStrategy):
    """Prompt Strategy class for input/output pairs"""

    def __init__(self, *args, eos_token=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Fall back to the tokenizer's EOS token when none is supplied explicitly.
        self.eos_token = eos_token or self.tokenizer.eos_token

    def tokenize_prompt(self, prompt):
        """Tokenize each (label, text) segment of *prompt*.

        Segments with a truthy label (or all segments when ``train_on_inputs``
        is set) contribute to the loss; all other positions are masked with
        ``IGNORE_TOKEN_ID`` in ``labels``.
        """
        # pylint: disable=duplicate-code
        input_ids = []
        labels = []
        for label, text in self.prompter.build_prompt(prompt["segments"]):
            # No special tokens here: the raw segments are expected to carry
            # any markers themselves.
            segment_ids = self.tokenizer(
                text, add_special_tokens=False, return_tensors=None
            )["input_ids"]
            input_ids += segment_ids
            if label or self.train_on_inputs:
                labels += segment_ids
            else:
                labels += [IGNORE_TOKEN_ID] * len(segment_ids)

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": [1] * len(input_ids),
        }


class RawInputOutputPrompter(Prompter):
    """prompter for raw i/o data"""

    def build_prompt(self, source) -> Generator[Tuple[bool, str], None, None]:
        # Pass the raw segments straight through as (label, text) pairs.
        for segment in source:
            yield segment["label"], segment["text"]


def load(tokenizer, cfg):
    """Build the raw input/output tokenizing strategy from the axolotl config."""
    return RawInputOutputStrategy(
        RawInputOutputPrompter(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
"""
Test module for raw i/o data for prompts
"""
import pytest

from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer

from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.input_output import (
    RawInputOutputPrompter,
    RawInputOutputStrategy,
)


@pytest.fixture(name="segments_dataset")
def fixture_sharegpt_dataset():
    """One example of alternating unlabelled (input) and labelled (output) segments.

    The special-token markers are part of the raw text on purpose: the
    strategy tokenizes with ``add_special_tokens=False``, so ``<s>`` and the
    custom ``<eot>`` token must appear in the segments themselves.
    """
    return Dataset.from_list(
        [
            {
                "segments": [
                    {
                        "label": False,
                        "text": "<s>hello ",
                    },
                    {
                        "label": True,
                        "text": "hi there.<eot>",
                    },
                    {
                        "label": False,
                        "text": "goodbye ",
                    },
                    {
                        "label": True,
                        "text": "farewell<eot>",
                    },
                ]
            }
        ]
    )


@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    """Mistral tokenizer extended with a custom ``<eot>`` token (id 32000)."""
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    tokenizer.add_tokens(
        [
            AddedToken("<eot>", rstrip=False, lstrip=False, normalized=False),
        ]
    )

    return tokenizer


class TestRawInputOutputPrompts:
    """
    Test class for raw i/o prompter
    """

    def test_segment_prompts(self, segments_dataset, tokenizer):
        strategy = RawInputOutputStrategy(
            RawInputOutputPrompter(),
            tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, segments_dataset, process_count=1
        )

        input_ids = dataset_wrapper[0]["input_ids"]
        labels = dataset_wrapper[0]["labels"]

        assert (
            tokenizer.decode(input_ids)
            == "<s> hello hi there.<eot> goodbye farewell<eot>"
        )
        # fmt: off
        assert input_ids == [
            1,  # <s>
            6312,  # hell
            28709,  # o
            28705,  # (space)
            12014,  # hi
            736,  # there
            28723,  # .
            32000,  # <eot>
            1179,  # good
            17664,  # bye
            28705,  # (space)
            19111,  # fare
            5458,  # well
            32000,  # <eot>
        ]
        # fmt: on

        # fmt: off
        assert labels == [
            -100,  # <s>
            -100,  # hell
            -100,  # o
            -100,  # (space)
            12014,  # hi
            736,  # there
            28723,  # .
            32000,  # <eot>
            -100,  # good
            -100,  # bye
            -100,  # (space)
            19111,  # fare
            5458,  # well
            32000,  # <eot>
        ]
        # fmt: on