Tokenization open assistant (#1)

* refactor prompt tokenization to more easily support open assistant

* add open assisstant handling, more logging, black formatting
This commit is contained in:
Wing Lian
2023-04-18 01:45:49 -04:00
committed by GitHub
parent eb808903e5
commit 87d7825435
2 changed files with 149 additions and 51 deletions

View File

@@ -31,14 +31,18 @@ class PromptTokenizingStrategy(abc.ABC):
pass
class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
raise NotImplementedError
def tokenize_prompt(self, prompt):
full_prompt = self._tokenize_full_prompt(prompt)
instruction, input, response = self.parse_instruction_fields(prompt)
full_prompt = self._build_full_prompt(instruction, input, response)
tokenized_full_prompt = self._tokenize(full_prompt)
if not self.train_on_inputs:
user_prompt = self.prompter.build_prompt(
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
instruction,
input,
)
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
user_prompt_len = len(tokenized_user_prompt["input_ids"])
@@ -49,11 +53,11 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
return tokenized_full_prompt
def _tokenize_full_prompt(self, prompt):
def _build_full_prompt(self, instruction, input, response):
return self.prompter.build_prompt(
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
prompt["output"],
instruction,
input,
response,
)
def _tokenize(self, prompt, add_eos_token=True):
@@ -76,11 +80,29 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
return result
class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
def _tokenize_full_prompt(self, prompt):
return self.prompter.build_prompt(
class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
return (
prompt["instruction"],
prompt["input"],
prompt["input"] if "input" in prompt else "",
prompt["output"],
)
class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
return (
prompt["INSTRUCTION"],
"",
prompt["RESPONSE"],
)
class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
prompt["response"],
)