Tokenization open assistant (#1)
* refactor prompt tokenization to more easily support open assistant * add open assistant handling, more logging, black formatting
This commit is contained in:
@@ -31,14 +31,18 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
pass
|
||||
|
||||
|
||||
class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
full_prompt = self._tokenize_full_prompt(prompt)
|
||||
instruction, input, response = self.parse_instruction_fields(prompt)
|
||||
full_prompt = self._build_full_prompt(instruction, input, response)
|
||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||
if not self.train_on_inputs:
|
||||
user_prompt = self.prompter.build_prompt(
|
||||
prompt["instruction"],
|
||||
prompt["input"] if "input" in prompt else "",
|
||||
instruction,
|
||||
input,
|
||||
)
|
||||
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||
@@ -49,11 +53,11 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
|
||||
return tokenized_full_prompt
|
||||
|
||||
def _tokenize_full_prompt(self, prompt):
|
||||
def _build_full_prompt(self, instruction, input, response):
|
||||
return self.prompter.build_prompt(
|
||||
prompt["instruction"],
|
||||
prompt["input"] if "input" in prompt else "",
|
||||
prompt["output"],
|
||||
instruction,
|
||||
input,
|
||||
response,
|
||||
)
|
||||
|
||||
def _tokenize(self, prompt, add_eos_token=True):
|
||||
@@ -76,11 +80,29 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
return result
|
||||
|
||||
|
||||
class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
|
||||
def _tokenize_full_prompt(self, prompt):
|
||||
return self.prompter.build_prompt(
|
||||
class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """Tokenizing strategy for rows keyed by "instruction", "input" and "output"."""

    def parse_instruction_fields(self, prompt) -> (str, str, str):
        """Extract the (instruction, input, response) triple from a row.

        Args:
            prompt: Mapping holding the "instruction" and "output" keys and,
                optionally, an "input" key.

        Returns:
            A three-element tuple ``(instruction, input, output)``. The input
            element is the empty string when the row has no "input" key.

        Raises:
            KeyError: If "instruction" or "output" is missing from the row.
        """
        return (
            prompt["instruction"],
            # "input" is optional: only the guarded lookup is kept, so the
            # result stays a three-element tuple and a missing key cannot raise.
            prompt["input"] if "input" in prompt else "",
            prompt["output"],
        )
|
||||
|
||||
|
||||
class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """Tokenizing strategy for rows keyed by "INSTRUCTION" and "RESPONSE"."""

    def parse_instruction_fields(self, prompt) -> (str, str, str):
        """Return ``(instruction, "", response)`` read from the row.

        The middle (input) element is always the empty string; the other two
        come from the upper-case "INSTRUCTION" and "RESPONSE" keys, and a
        missing key raises ``KeyError``.
        """
        instruction_text = prompt["INSTRUCTION"]
        response_text = prompt["RESPONSE"]
        return instruction_text, "", response_text
|
||||
|
||||
|
||||
class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """Tokenizing strategy for rows keyed by "instruction", "input" and "response"."""

    def parse_instruction_fields(self, prompt) -> (str, str, str):
        """Return ``(instruction, input, response)`` read from the row.

        "input" is optional and falls back to the empty string when the key
        is absent; a missing "instruction" or "response" raises ``KeyError``.
        """
        instruction_text = prompt["instruction"]
        # Optional field: start from the fallback and overwrite if present.
        input_text = ""
        if "input" in prompt:
            input_text = prompt["input"]
        response_text = prompt["response"]
        return instruction_text, input_text, response_text
|
||||
|
||||
|
||||
Reference in New Issue
Block a user