Lint prompt_tokenizers

NanoCode012
2023-05-29 14:20:11 +09:00
parent 01c8a333b3
commit 5d86137f70


@@ -1,7 +1,10 @@
+"""Module containing PromptTokenizingStrategy and Prompter classes"""
 import abc
 import copy
 import functools
 import logging
+from typing import Tuple

 from transformers import PreTrainedTokenizer
@@ -15,10 +18,16 @@ LLAMA_DEFAULT_UNK_TOKEN = "<unk>"
 class InvalidDataException(Exception):
-    pass
+    """
+    Exception raised when the data is invalid
+    """


 class PromptTokenizingStrategy(abc.ABC):
+    """
+    Abstract class for tokenizing strategies
+    """
+
     def __init__(
         self,
         prompter,
@@ -35,14 +44,14 @@ class PromptTokenizingStrategy(abc.ABC):
     def tokenize_prompt(self, prompt):
         pass

-    @functools.cache
+    @functools.lru_cache(maxsize=128)
     def _get_user_token(self):
         id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
         if isinstance(id_or_ids, (int,)):
             return id_or_ids
         return False

-    @functools.cache
+    @functools.lru_cache(maxsize=128)
     def _get_assistant_token(self):
         id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
         if isinstance(id_or_ids, (int,)):
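The two decorator changes above swap functools.cache for functools.lru_cache(maxsize=128): functools.cache only exists on Python 3.9+, while lru_cache works back to much older versions and additionally bounds the cache, evicting least-recently-used entries past 128 distinct arguments. A minimal standalone sketch of the bounded-cache behavior (the square function is illustrative, not part of this file):

import functools

@functools.lru_cache(maxsize=128)
def square(x: int) -> int:
    # Computed once per distinct argument; repeat calls hit the cache.
    return x * x

square(4)
square(4)
print(square.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)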
@@ -51,11 +60,19 @@ class PromptTokenizingStrategy(abc.ABC):
 class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
-    def parse_instruction_fields(self, prompt) -> (str, str, str):
+    """
+    Tokenizing strategy for instruction-based prompts.
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
         raise NotImplementedError

     def tokenize_prompt(self, prompt):
-        instruction, input, response = self.parse_instruction_fields(prompt)
+        (
+            instruction,
+            input,  # pylint: disable=redefined-builtin
+            response,
+        ) = self.parse_instruction_fields(prompt)
         full_prompt = self._build_full_prompt(instruction, input, response)
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
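The annotation change here is more than cosmetic: (str, str, str) evaluates to a tuple of type objects, which is not a valid type annotation and is useless to type checkers, whereas typing.Tuple[str, str, str] is the correct spelling for a three-string return (pre-PEP 585). A hypothetical minimal example of the corrected form:

from typing import Tuple

def parse(record: dict) -> Tuple[str, str, str]:
    # Hypothetical record layout mirroring the Alpaca-style fields.
    return record["instruction"], record.get("input", ""), record["output"]

print(parse({"instruction": "add", "input": "2 2", "output": "4"}))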
@@ -76,7 +93,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
         return tokenized_full_prompt

-    def _build_full_prompt(self, instruction, input, response):
+    def _build_full_prompt(
+        self, instruction, input, response  # pylint: disable=redefined-builtin
+    ):
         return next(
             iter(
                 self.prompter.build_prompt(
@@ -112,7 +131,11 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
 class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
-    def parse_instruction_fields(self, prompt) -> (str, str, str):
+    """
+    Tokenizing strategy for Alpaca prompts.
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
         return (
             prompt["instruction"],
             prompt["input"] if "input" in prompt else "",
@@ -121,7 +144,11 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
 class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
-    def parse_instruction_fields(self, prompt) -> (str, str, str):
+    """
+    Tokenizing strategy for Alpaca Multiple Choice prompts.
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
         return (
             prompt["question"],
             "\n".join(f'- "{choice}"' for choice in prompt["choices"]),
@@ -130,7 +157,11 @@ class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingSt
 class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
-    def parse_instruction_fields(self, prompt) -> (str, str, str):
+    """
+    Tokenizing strategy for Jeopardy prompts.
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
         return (
             prompt["question"],
             prompt["category"],
@@ -139,7 +170,11 @@ class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
 class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
-    def parse_instruction_fields(self, prompt) -> (str, str, str):
+    """
+    Tokenizing strategy for OpenAssistant prompts.
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
         return (
             prompt["INSTRUCTION"],
             "",
@@ -148,7 +183,11 @@ class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
 class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
-    def parse_instruction_fields(self, prompt) -> (str, str, str):
+    """
+    Tokenizing strategy for SummarizeTLDR prompts.
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
         return (
             prompt["article"],
             "",
@@ -157,7 +196,11 @@ class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
 class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
-    def parse_instruction_fields(self, prompt) -> (str, str, str):
+    """
+    Tokenizing strategy for GPTeacher prompts.
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
         return (
             prompt["instruction"],
             prompt["input"] if "input" in prompt else "",
@@ -166,7 +209,11 @@ class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
 class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
-    def parse_instruction_fields(self, prompt) -> (str, str, str):
+    """
+    Tokenizing strategy for NomicGPT4All prompts.
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
         return (
             prompt["prompt"],
             "",
@@ -175,6 +222,10 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
 class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    """
+    Tokenizing strategy for Completion prompts.
+    """
+
     def parse_instruction_fields(self, prompt) -> str:
         return prompt["text"]
@@ -185,18 +236,24 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
         return tokenized_full_prompt

-    def _build_full_prompt(self, instruction, input, response):
+    def _build_full_prompt(
+        self, instruction, input, response
+    ):  # pylint: disable=unused-argument, redefined-builtin
         return next(iter(self.prompter.build_prompt(instruction)))


 class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
-    def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
+    """
+    Tokenizing strategy for Reflection prompts.
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
         raise NotImplementedError

     def tokenize_prompt(self, prompt):
         (
             instruction,
-            input,
+            input,  # pylint: disable=redefined-builtin
             output,
             reflection,
             corrected,
@@ -223,7 +280,9 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         return tokenized_full_prompt

-    def _build_full_prompt(self, instruction, input, output, reflection, corrected):
+    def _build_full_prompt(
+        self, instruction, input, output, reflection, corrected
+    ):  # pylint: disable=redefined-builtin
         return next(
             iter(
                 self.prompter.build_prompt(
@@ -257,7 +316,11 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
 class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
-    def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
+    """
+    Tokenizing strategy for Alpaca Reflection prompts.
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
         return (
             prompt["instruction"],
             prompt["input"] if "input" in prompt else "",
@@ -268,6 +331,10 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
 class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
+    """
+    Tokenizing strategy for ShareGPT prompts.
+    """
+
     def get_conversation_thread(self, prompt):
         return prompt["conversations"]
@@ -281,7 +348,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
         user_token = self._get_user_token()
         assistant_token = self._get_assistant_token()
         try:
-            for i, part in enumerate(
+            for _, part in enumerate(
                 self.prompter.build_prompt(self.get_conversation_thread(prompt))
             ):
                 if isinstance(part, tuple):
@@ -307,7 +374,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                         # not masked out from labels
                         labels = copy.deepcopy(res["input_ids"])
                     else:
-                        logging.warning("unhandled role: " + part[0])
+                        logging.warning(f"unhandled role: {part[0]}")
                 else:
                     # this is only ever the first part, should include the bos token and the user query
                     res = self._tokenize(
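The warning call above replaces string concatenation with an f-string. Worth noting that pylint's logging checkers generally favor lazy %-style arguments, which defer interpolation until the record is actually emitted; a small illustrative sketch of the two styles:

import logging

role = "tool"
logging.warning(f"unhandled role: {role}")   # eager: string is built unconditionally
logging.warning("unhandled role: %s", role)  # lazy: built only if the record is emitted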
@@ -324,8 +391,8 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                 result["labels"][current_len : current_len + input_len] = labels
                 current_len += input_len
             return result
-        except (KeyError, AssertionError, IndexError) as e:
-            raise InvalidDataException(str(e))
+        except (KeyError, AssertionError, IndexError) as err:
+            raise InvalidDataException(str(err)) from err

     def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
         result = self.tokenizer(
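The final hunk renames the caught exception and adds explicit chaining: raise ... from err records the original error as the new exception's __cause__, so the traceback shows both failures instead of the misleading implicit "During handling of the above exception, another exception occurred" message. A self-contained sketch of the pattern (the load helper is hypothetical):

class InvalidDataException(Exception):
    """Exception raised when the data is invalid"""

def load(data: dict) -> list:
    try:
        return data["conversations"]
    except KeyError as err:
        # "from err" preserves the original KeyError as __cause__.
        raise InvalidDataException(str(err)) from err

try:
    load({})
except InvalidDataException as exc:
    print(exc, "| caused by:", repr(exc.__cause__))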