Lint prompt_tokenizers

This commit is contained in:
NanoCode012
2023-05-29 14:20:11 +09:00
parent 01c8a333b3
commit 5d86137f70

View File

@@ -1,7 +1,10 @@
"""Module containing PromptTokenizingStrategy and Prompter classes"""
import abc
import copy
import functools
import logging
from typing import Tuple
from transformers import PreTrainedTokenizer
@@ -15,10 +18,16 @@ LLAMA_DEFAULT_UNK_TOKEN = "<unk>"
class InvalidDataException(Exception):
pass
"""
Exception raised when the data is invalid
"""
class PromptTokenizingStrategy(abc.ABC):
"""
Abstract class for tokenizing strategies
"""
def __init__(
self,
prompter,
@@ -35,14 +44,14 @@ class PromptTokenizingStrategy(abc.ABC):
def tokenize_prompt(self, prompt):
pass
@functools.cache
@functools.lru_cache(maxsize=128)
def _get_user_token(self):
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
return False
@functools.cache
@functools.lru_cache(maxsize=128)
def _get_assistant_token(self):
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
if isinstance(id_or_ids, (int,)):
@@ -51,11 +60,19 @@ class PromptTokenizingStrategy(abc.ABC):
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for instruction-based prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
raise NotImplementedError
def tokenize_prompt(self, prompt):
instruction, input, response = self.parse_instruction_fields(prompt)
(
instruction,
input, # pylint: disable=redefined-builtin
response,
) = self.parse_instruction_fields(prompt)
full_prompt = self._build_full_prompt(instruction, input, response)
tokenized_full_prompt = self._tokenize(full_prompt)
if not self.train_on_inputs:
@@ -76,7 +93,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
return tokenized_full_prompt
def _build_full_prompt(self, instruction, input, response):
def _build_full_prompt(
self, instruction, input, response # pylint: disable=redefined-builtin
):
return next(
iter(
self.prompter.build_prompt(
@@ -112,7 +131,11 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for Alpaca prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
@@ -121,7 +144,11 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for Alpaca Multiple Choice prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["question"],
"\n".join(f'- "{choice}"' for choice in prompt["choices"]),
@@ -130,7 +157,11 @@ class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingSt
class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for Jeopardy prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["question"],
prompt["category"],
@@ -139,7 +170,11 @@ class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for OpenAssistant prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["INSTRUCTION"],
"",
@@ -148,7 +183,11 @@ class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for SummarizeTLDR prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["article"],
"",
@@ -157,7 +196,11 @@ class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for GPTeacher prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
@@ -166,7 +209,11 @@ class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for NomicGPT4All prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["prompt"],
"",
@@ -175,6 +222,10 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
"""
Tokenizing strategy for Completion prompts.
"""
def parse_instruction_fields(self, prompt) -> str:
return prompt["text"]
@@ -185,18 +236,24 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
return tokenized_full_prompt
def _build_full_prompt(self, instruction, input, response):
def _build_full_prompt(
self, instruction, input, response
): # pylint: disable=unused-argument, redefined-builtin
return next(iter(self.prompter.build_prompt(instruction)))
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
"""
Tokenizing strategy for Reflection prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
raise NotImplementedError
def tokenize_prompt(self, prompt):
(
instruction,
input,
input, # pylint: disable=redefined-builtin
output,
reflection,
corrected,
@@ -223,7 +280,9 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
return tokenized_full_prompt
def _build_full_prompt(self, instruction, input, output, reflection, corrected):
def _build_full_prompt(
self, instruction, input, output, reflection, corrected
): # pylint: disable=redefined-builtin
return next(
iter(
self.prompter.build_prompt(
@@ -257,7 +316,11 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
"""
Tokenizing strategy for Alpaca Reflection prompts.
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
@@ -268,6 +331,10 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for ShareGPT prompts.
"""
def get_conversation_thread(self, prompt):
return prompt["conversations"]
@@ -281,7 +348,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
user_token = self._get_user_token()
assistant_token = self._get_assistant_token()
try:
for i, part in enumerate(
for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if isinstance(part, tuple):
@@ -307,7 +374,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
else:
logging.warning("unhandled role: " + part[0])
logging.warning(f"unhandled role: {part[0]}")
else:
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
@@ -324,8 +391,8 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
result["labels"][current_len : current_len + input_len] = labels
current_len += input_len
return result
except (KeyError, AssertionError, IndexError) as e:
raise InvalidDataException(str(e))
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer(