plain input/output prompt strategy w/o chat templates (#1346)
* plain input/output prompt strategy w/o chat templates * disable duplicate code check * make sure to add an eos/eot token to the end of the output so it will stop * multi turn segement support and test
This commit is contained in:
54
src/axolotl/prompt_strategies/input_output.py
Normal file
54
src/axolotl/prompt_strategies/input_output.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Module for plain input/output prompt pairs"""
|
||||
from typing import Generator, Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||
|
||||
|
||||
class RawInputOutputStrategy(PromptTokenizingStrategy):
|
||||
"""Prompt Strategy class for input/output pairs"""
|
||||
|
||||
def __init__(self, *args, eos_token=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.eos_token = eos_token
|
||||
if not eos_token:
|
||||
self.eos_token = self.tokenizer.eos_token
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# pylint: disable=duplicate-code
|
||||
input_ids = []
|
||||
labels = []
|
||||
for label, text in self.prompter.build_prompt(prompt["segments"]):
|
||||
tokenized_output = self.tokenizer(
|
||||
text, add_special_tokens=False, return_tensors=None
|
||||
)["input_ids"]
|
||||
input_ids += tokenized_output
|
||||
if label or self.train_on_inputs:
|
||||
labels += tokenized_output
|
||||
else:
|
||||
labels += [IGNORE_TOKEN_ID] * len(tokenized_output)
|
||||
|
||||
tokenized_prompt = {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": [1] * len(input_ids),
|
||||
}
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
class RawInputOutputPrompter(Prompter):
|
||||
"""prompter for raw i/o data"""
|
||||
|
||||
def build_prompt(self, source) -> Generator[Tuple[bool, str], None, None]:
|
||||
for segment in source:
|
||||
yield segment["label"], segment["text"]
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return RawInputOutputStrategy(
|
||||
RawInputOutputPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
Reference in New Issue
Block a user