improve: Enhance code readability of prompt_tokenizers.py (#707)
This commit is contained in:
committed by
GitHub
parent
440c3ab527
commit
3a99495b05
@@ -45,6 +45,8 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
self.prompter = prompter
|
self.prompter = prompter
|
||||||
self.tokenizer: PreTrainedTokenizer = tokenizer
|
self.tokenizer: PreTrainedTokenizer = tokenizer
|
||||||
self.train_on_inputs = train_on_inputs
|
self.train_on_inputs = train_on_inputs
|
||||||
|
# sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.
|
||||||
|
# TODO: Document how they are different.
|
||||||
self.sequence_len = sequence_len
|
self.sequence_len = sequence_len
|
||||||
self.max_length = sequence_len
|
self.max_length = sequence_len
|
||||||
|
|
||||||
@@ -59,34 +61,31 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
def _tokenize(
|
def _tokenize(
|
||||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||||
) -> BatchEncoding:
|
) -> BatchEncoding:
|
||||||
result: BatchEncoding
|
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
||||||
if not prompt:
|
if not prompt:
|
||||||
LOG.warning("Empty text requested for tokenization.")
|
LOG.warning("Empty text requested for tokenization.")
|
||||||
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
return empty
|
||||||
else:
|
|
||||||
result = self.tokenizer(
|
result = self.tokenizer(
|
||||||
prompt,
|
prompt,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=self.max_length,
|
max_length=self.max_length,
|
||||||
padding=False,
|
padding=False,
|
||||||
return_tensors=None,
|
return_tensors=None,
|
||||||
)
|
)
|
||||||
if len(result["input_ids"]) == 0:
|
if len(result["input_ids"]) == 0:
|
||||||
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
|
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
|
||||||
|
return empty
|
||||||
|
|
||||||
if (
|
if (
|
||||||
len(result["input_ids"]) > 0
|
result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||||
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
|
||||||
and len(result["input_ids"]) < self.max_length
|
and len(result["input_ids"]) < self.max_length
|
||||||
and add_eos_token
|
and add_eos_token
|
||||||
):
|
):
|
||||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||||
result["attention_mask"].append(1)
|
result["attention_mask"].append(1)
|
||||||
|
|
||||||
if (
|
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||||
len(result["input_ids"]) > 0
|
|
||||||
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
|
||||||
and strip_bos_token
|
|
||||||
):
|
|
||||||
result["input_ids"] = result["input_ids"][1:]
|
result["input_ids"] = result["input_ids"][1:]
|
||||||
result["attention_mask"] = result["attention_mask"][1:]
|
result["attention_mask"] = result["attention_mask"][1:]
|
||||||
|
|
||||||
@@ -122,7 +121,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
if not self.train_on_inputs:
|
if not self.train_on_inputs:
|
||||||
user_prompt_len = len(tokenized_prompt["input_ids"])
|
user_prompt_len = len(tokenized_prompt["input_ids"])
|
||||||
# TODO this could be sped up using numpy array slicing
|
# TODO this could be sped up using numpy array slicing
|
||||||
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
|
||||||
tokenized_res_prompt = self._tokenize(
|
tokenized_res_prompt = self._tokenize(
|
||||||
response, strip_bos_token=True, add_eos_token=True
|
response, strip_bos_token=True, add_eos_token=True
|
||||||
)
|
)
|
||||||
@@ -270,7 +269,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||||
# TODO this could be sped up using numpy array slicing
|
# TODO this could be sped up using numpy array slicing
|
||||||
tokenized_full_prompt["labels"] = [
|
tokenized_full_prompt["labels"] = [
|
||||||
-100
|
IGNORE_INDEX
|
||||||
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
|
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
|
||||||
|
|
||||||
return tokenized_full_prompt
|
return tokenized_full_prompt
|
||||||
@@ -334,6 +333,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
return prompt["conversations"]
|
return prompt["conversations"]
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
|
# Initial values. We will append to these as we go through the conversation.
|
||||||
result, current_len = tokenize_prompt_default()
|
result, current_len = tokenize_prompt_default()
|
||||||
conversation: Conversation = (
|
conversation: Conversation = (
|
||||||
self.prompter._conversation.copy() # pylint: disable=protected-access
|
self.prompter._conversation.copy() # pylint: disable=protected-access
|
||||||
@@ -355,62 +355,67 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
for _, part in enumerate(
|
for _, part in enumerate(
|
||||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
||||||
):
|
):
|
||||||
if isinstance(part, tuple):
|
if not isinstance(part, tuple):
|
||||||
if conversation.roles[0] in part[0]:
|
LOG.warning(f"expected tuple, got {part}")
|
||||||
role = (
|
continue
|
||||||
part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
|
|
||||||
if role_remap
|
user, assistant = conversation.roles
|
||||||
else part[0]
|
role, content = part
|
||||||
)
|
|
||||||
turn = role + part[1]
|
# Uses "in" because role contains extra characters
|
||||||
# this is still the user query, we should
|
if user in role:
|
||||||
if not part[1].strip():
|
role = (
|
||||||
LOG.warning(f"user turn has empty text: {prompt}")
|
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||||
res = self._tokenize(
|
if role_remap
|
||||||
turn,
|
else role
|
||||||
add_eos_token=False,
|
)
|
||||||
strip_bos_token=True,
|
turn = role + content
|
||||||
)
|
# this is still the user query, we should
|
||||||
# everything from this is masked out from the labels
|
if not content.strip():
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
LOG.warning(f"user turn has empty text: {prompt}")
|
||||||
elif conversation.roles[1] in part[0]:
|
res = self._tokenize(
|
||||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
turn,
|
||||||
role = (
|
add_eos_token=False,
|
||||||
part[0].replace(role_remap[1]["from"], role_remap[1]["to"])
|
strip_bos_token=True,
|
||||||
if role_remap
|
)
|
||||||
else part[0]
|
# everything from this is masked out from the labels
|
||||||
)
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
turn = role + part[1]
|
elif assistant in role:
|
||||||
# this should be the assistant response, should end with an eos token
|
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||||
if not part[1].strip():
|
role = (
|
||||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||||
res = self._tokenize(
|
if role_remap
|
||||||
turn,
|
else role
|
||||||
add_eos_token=True,
|
)
|
||||||
strip_bos_token=True,
|
turn = role + content
|
||||||
)
|
# this should be the assistant response, should end with an eos token
|
||||||
role_res = self._tokenize(
|
if not content.strip():
|
||||||
role.rstrip(),
|
LOG.warning(f"assistant turn has empty text: {prompt}")
|
||||||
add_eos_token=False,
|
res = self._tokenize(
|
||||||
strip_bos_token=True,
|
turn,
|
||||||
)
|
add_eos_token=True,
|
||||||
# not masked out from labels
|
strip_bos_token=True,
|
||||||
labels = copy.deepcopy(res["input_ids"])
|
)
|
||||||
len_role = len(role_res["input_ids"])
|
role_res = self._tokenize(
|
||||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
role.rstrip(),
|
||||||
len_role, len(labels)
|
add_eos_token=False,
|
||||||
)
|
strip_bos_token=True,
|
||||||
elif part[0] == "":
|
)
|
||||||
turn = part[1]
|
# not masked out from labels
|
||||||
# this is only ever the first part, should include the bos token and the user query
|
labels = copy.deepcopy(res["input_ids"])
|
||||||
res = self._tokenize(
|
len_role = len(role_res["input_ids"])
|
||||||
turn, add_eos_token=False, strip_bos_token=False
|
labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
|
||||||
)
|
elif role == "":
|
||||||
# everything from this is masked out from the labels
|
turn = content
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
# this is only ever the first part, should include the bos token and the user query
|
||||||
else:
|
res = self._tokenize(
|
||||||
LOG.warning(f"unhandled role: {part[0]}")
|
turn, add_eos_token=False, strip_bos_token=False
|
||||||
continue
|
)
|
||||||
|
# everything from this is masked out from the labels
|
||||||
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
|
else:
|
||||||
|
LOG.warning(f"unhandled role: {role}")
|
||||||
|
continue
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
result, current_len = parse_tokenized_to_result(
|
result, current_len = parse_tokenized_to_result(
|
||||||
@@ -424,38 +429,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
except (KeyError, AssertionError, IndexError) as err:
|
except (KeyError, AssertionError, IndexError) as err:
|
||||||
raise InvalidDataException(str(err)) from err
|
raise InvalidDataException(str(err)) from err
|
||||||
|
|
||||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
|
||||||
if not prompt.strip():
|
|
||||||
LOG.warning("Empty text requested for tokenization.")
|
|
||||||
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
|
||||||
else:
|
|
||||||
result = self.tokenizer(
|
|
||||||
prompt,
|
|
||||||
truncation=True,
|
|
||||||
max_length=self.sequence_len,
|
|
||||||
padding=False,
|
|
||||||
return_tensors=None,
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
len(result["input_ids"]) > 0
|
|
||||||
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
|
||||||
and len(result["input_ids"]) < self.sequence_len
|
|
||||||
and add_eos_token
|
|
||||||
):
|
|
||||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
|
||||||
result["attention_mask"].append(1)
|
|
||||||
|
|
||||||
if (
|
|
||||||
len(result["input_ids"]) > 0
|
|
||||||
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
|
||||||
and strip_bos_token
|
|
||||||
):
|
|
||||||
result["input_ids"] = result["input_ids"][1:]
|
|
||||||
result["attention_mask"] = result["attention_mask"][1:]
|
|
||||||
|
|
||||||
result["labels"] = result["input_ids"].copy()
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
|
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user