improve: Enhance code readability of prompt_tokenizers.py (#707)

This commit is contained in:
seungduk.kim.2304
2023-10-19 21:12:17 +09:00
committed by GitHub
parent 440c3ab527
commit 3a99495b05

View File

@@ -45,6 +45,8 @@ class PromptTokenizingStrategy(abc.ABC):
self.prompter = prompter self.prompter = prompter
self.tokenizer: PreTrainedTokenizer = tokenizer self.tokenizer: PreTrainedTokenizer = tokenizer
self.train_on_inputs = train_on_inputs self.train_on_inputs = train_on_inputs
# sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.
# TODO: Document how they are different.
self.sequence_len = sequence_len self.sequence_len = sequence_len
self.max_length = sequence_len self.max_length = sequence_len
@@ -59,34 +61,31 @@ class PromptTokenizingStrategy(abc.ABC):
def _tokenize( def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding: ) -> BatchEncoding:
result: BatchEncoding empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
if not prompt: if not prompt:
LOG.warning("Empty text requested for tokenization.") LOG.warning("Empty text requested for tokenization.")
result = BatchEncoding(data={"input_ids": [], "attention_mask": []}) return empty
else:
result = self.tokenizer( result = self.tokenizer(
prompt, prompt,
truncation=True, truncation=True,
max_length=self.max_length, max_length=self.max_length,
padding=False, padding=False,
return_tensors=None, return_tensors=None,
) )
if len(result["input_ids"]) == 0: if len(result["input_ids"]) == 0:
LOG.warning("Tokenizer result is empty. You may want to audit your dataset") LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
return empty
if ( if (
len(result["input_ids"]) > 0 result["input_ids"][-1] != self.tokenizer.eos_token_id
and result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.max_length and len(result["input_ids"]) < self.max_length
and add_eos_token and add_eos_token
): ):
result["input_ids"].append(self.tokenizer.eos_token_id) result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1) result["attention_mask"].append(1)
if ( if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
len(result["input_ids"]) > 0
and result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
result["input_ids"] = result["input_ids"][1:] result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:] result["attention_mask"] = result["attention_mask"][1:]
@@ -122,7 +121,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
if not self.train_on_inputs: if not self.train_on_inputs:
user_prompt_len = len(tokenized_prompt["input_ids"]) user_prompt_len = len(tokenized_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing # TODO this could be sped up using numpy array slicing
tokenized_prompt["labels"] = [-100] * user_prompt_len tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
tokenized_res_prompt = self._tokenize( tokenized_res_prompt = self._tokenize(
response, strip_bos_token=True, add_eos_token=True response, strip_bos_token=True, add_eos_token=True
) )
@@ -270,7 +269,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
user_prompt_len = len(tokenized_user_prompt["input_ids"]) user_prompt_len = len(tokenized_user_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing # TODO this could be sped up using numpy array slicing
tokenized_full_prompt["labels"] = [ tokenized_full_prompt["labels"] = [
-100 IGNORE_INDEX
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:] ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
return tokenized_full_prompt return tokenized_full_prompt
@@ -334,6 +333,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
return prompt["conversations"] return prompt["conversations"]
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
# Initial values. We will append to these as we go through the conversation.
result, current_len = tokenize_prompt_default() result, current_len = tokenize_prompt_default()
conversation: Conversation = ( conversation: Conversation = (
self.prompter._conversation.copy() # pylint: disable=protected-access self.prompter._conversation.copy() # pylint: disable=protected-access
@@ -355,62 +355,67 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
for _, part in enumerate( for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt)) self.prompter.build_prompt(self.get_conversation_thread(prompt))
): ):
if isinstance(part, tuple): if not isinstance(part, tuple):
if conversation.roles[0] in part[0]: LOG.warning(f"expected tuple, got {part}")
role = ( continue
part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap user, assistant = conversation.roles
else part[0] role, content = part
)
turn = role + part[1] # Uses "in" because role contains extra characters
# this is still the user query, we should if user in role:
if not part[1].strip(): role = (
LOG.warning(f"user turn has empty text: {prompt}") role.replace(role_remap[0]["from"], role_remap[0]["to"])
res = self._tokenize( if role_remap
turn, else role
add_eos_token=False, )
strip_bos_token=True, turn = role + content
) # this is still the user query, we should
# everything from this is masked out from the labels if not content.strip():
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) LOG.warning(f"user turn has empty text: {prompt}")
elif conversation.roles[1] in part[0]: res = self._tokenize(
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID turn,
role = ( add_eos_token=False,
part[0].replace(role_remap[1]["from"], role_remap[1]["to"]) strip_bos_token=True,
if role_remap )
else part[0] # everything from this is masked out from the labels
) labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
turn = role + part[1] elif assistant in role:
# this should be the assistant response, should end with an eos token # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
if not part[1].strip(): role = (
LOG.warning(f"assistant turn has empty text: {prompt}") role.replace(role_remap[1]["from"], role_remap[1]["to"])
res = self._tokenize( if role_remap
turn, else role
add_eos_token=True, )
strip_bos_token=True, turn = role + content
) # this should be the assistant response, should end with an eos token
role_res = self._tokenize( if not content.strip():
role.rstrip(), LOG.warning(f"assistant turn has empty text: {prompt}")
add_eos_token=False, res = self._tokenize(
strip_bos_token=True, turn,
) add_eos_token=True,
# not masked out from labels strip_bos_token=True,
labels = copy.deepcopy(res["input_ids"]) )
len_role = len(role_res["input_ids"]) role_res = self._tokenize(
labels[:len_role] = [IGNORE_TOKEN_ID] * min( role.rstrip(),
len_role, len(labels) add_eos_token=False,
) strip_bos_token=True,
elif part[0] == "": )
turn = part[1] # not masked out from labels
# this is only ever the first part, should include the bos token and the user query labels = copy.deepcopy(res["input_ids"])
res = self._tokenize( len_role = len(role_res["input_ids"])
turn, add_eos_token=False, strip_bos_token=False labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
) elif role == "":
# everything from this is masked out from the labels turn = content
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) # this is only ever the first part, should include the bos token and the user query
else: res = self._tokenize(
LOG.warning(f"unhandled role: {part[0]}") turn, add_eos_token=False, strip_bos_token=False
continue )
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result( result, current_len = parse_tokenized_to_result(
@@ -424,38 +429,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
except (KeyError, AssertionError, IndexError) as err: except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err raise InvalidDataException(str(err)) from err
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
if not prompt.strip():
LOG.warning("Empty text requested for tokenization.")
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
else:
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if (
len(result["input_ids"]) > 0
and result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if (
len(result["input_ids"]) > 0
and result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
result["labels"] = result["input_ids"].copy()
return result
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]: def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
""" """