Refactor duplicate code between Prompter and Pygmalion

This commit is contained in:
NanoCode012
2023-05-29 15:08:26 +09:00
parent 1f3c3f5ea0
commit 8e46c0fb0d
2 changed files with 86 additions and 76 deletions

View File

@@ -5,7 +5,11 @@ import logging
from collections import defaultdict from collections import defaultdict
from typing import Generator from typing import Generator
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import (
PromptTokenizingStrategy,
parse_tokenized_to_result,
tokenize_prompt_default,
)
IGNORE_TOKEN_ID = -100 IGNORE_TOKEN_ID = -100
@@ -23,12 +27,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
self.bot_prefix_token_ids = res["input_ids"] self.bot_prefix_token_ids = res["input_ids"]
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
result = { result, current_len = tokenize_prompt_default()
"input_ids": [],
"attention_mask": [],
"labels": [],
}
current_len = 0
for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])): for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
role, message = part role, message = part
if role == "system": if role == "system":
@@ -67,37 +66,15 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
else: else:
logging.warning(f"unknown role in conversation: {role}") logging.warning(f"unknown role in conversation: {role}")
res = defaultdict(lambda: []) res = defaultdict(lambda: [])
input_ids = res["input_ids"]
input_len = len(input_ids)
result["input_ids"][current_len : current_len + input_len] = input_ids
result["attention_mask"][current_len : current_len + input_len] = [
1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
]
result["labels"][current_len : current_len + input_len] = labels
current_len += input_len
return result
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): # pylint: disable=duplicate-code
result = self.tokenizer( result, current_len = parse_tokenized_to_result(
prompt, result,
truncation=True, current_len,
max_length=self.sequence_len, res,
padding=False, labels,
return_tensors=None, pad_token_id=self.tokenizer.pad_token_id,
) )
if (
result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
result["labels"] = result["input_ids"].copy()
return result return result

View File

@@ -4,7 +4,7 @@ import abc
import copy import copy
import functools import functools
import logging import logging
from typing import Tuple from typing import Dict, List, Tuple
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
@@ -58,6 +58,29 @@ class PromptTokenizingStrategy(abc.ABC):
return id_or_ids return id_or_ids
return False return False
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
result["labels"] = result["input_ids"].copy()
return result
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
""" """
@@ -106,29 +129,6 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
) )
) )
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
result["labels"] = result["input_ids"].copy()
return result
class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
""" """
@@ -295,7 +295,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
) )
) )
def _tokenize(self, prompt, add_eos_token=True): def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer( result = self.tokenizer(
prompt, prompt,
truncation=True, truncation=True,
@@ -339,12 +339,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
return prompt["conversations"] return prompt["conversations"]
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
result = { result, current_len = tokenize_prompt_default()
"input_ids": [],
"attention_mask": [],
"labels": [],
}
current_len = 0
user_token = self._get_user_token() user_token = self._get_user_token()
assistant_token = self._get_assistant_token() assistant_token = self._get_assistant_token()
try: try:
@@ -382,14 +377,15 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
) )
# everything from this is masked out from the labels # everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
input_ids = res["input_ids"]
input_len = len(input_ids) # pylint: disable=duplicate-code
result["input_ids"][current_len : current_len + input_len] = input_ids result, current_len = parse_tokenized_to_result(
result["attention_mask"][current_len : current_len + input_len] = [ result,
1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids current_len,
] res,
result["labels"][current_len : current_len + input_len] = labels labels,
current_len += input_len pad_token_id=self.tokenizer.pad_token_id,
)
return result return result
except (KeyError, AssertionError, IndexError) as err: except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err raise InvalidDataException(str(err)) from err
@@ -416,3 +412,40 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
result["labels"] = result["input_ids"].copy() result["labels"] = result["input_ids"].copy()
return result return result
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
    """
    Build the empty starting state for a tokenize_prompt accumulation loop.

    Returns a fresh dict holding empty "input_ids", "attention_mask" and
    "labels" lists, paired with an initial running length of 0.
    """
    empty_state: Dict[str, List[int]] = {
        key: [] for key in ("input_ids", "attention_mask", "labels")
    }
    return empty_state, 0
def parse_tokenized_to_result(
    result: Dict[str, List[int]],
    current_len: int,
    res: Dict[str, List[int]],
    labels: List[int],
    # Quoted so the PEP 604 union is not evaluated at definition time, which
    # would raise TypeError on Python versions before 3.10.
    pad_token_id: "int | None" = None,
) -> Tuple[Dict[str, List[int]], int]:
    """
    Write one tokenized segment into the running result at ``current_len``.

    The segment's ``input_ids``, a derived ``attention_mask`` and the given
    ``labels`` are slice-assigned into ``result`` starting at ``current_len``.
    ``result`` is modified in place and also returned.

    :param result: accumulator with "input_ids", "attention_mask", "labels"
    :param current_len: offset in the accumulator lists to write at
    :param res: tokenized segment; only its "input_ids" entry is read
    :param labels: label ids for the segment (expected to match the length of
        ``res["input_ids"]``; this is not checked here)
    :param pad_token_id: token id that gets attention 0; with ``None`` every
        token gets attention 1
    :return: ``(result, current_len + len(res["input_ids"]))``
    """
    input_ids = res["input_ids"]
    end = current_len + len(input_ids)

    result["input_ids"][current_len:end] = input_ids
    result["attention_mask"][current_len:end] = [
        0 if token == pad_token_id else 1 for token in input_ids
    ]
    result["labels"][current_len:end] = labels

    return result, end