better handling and logging of empty sharegpt turns (#603)

This commit is contained in:
Wing Lian
2023-09-22 16:13:42 -04:00
committed by GitHub
parent 501958bb6f
commit a363604dcf
3 changed files with 105 additions and 14 deletions

View File

@@ -358,10 +358,12 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
):
if isinstance(part, tuple):
if part[0] == "USER:":
part = part[0] + part[1] if not user_token else part[1]
turn = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should
if not part[1].strip():
LOG.warning(f"user turn has empty text: {prompt}")
res = self._tokenize(
part.strip(),
turn.strip(),
add_eos_token=False,
strip_bos_token=True,
)
@@ -371,10 +373,12 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif part[0] == "ASSISTANT:":
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
part = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistent response, should end with an eos token
turn = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistant response, should end with an eos token
if not part[1].strip():
LOG.warning(f"assistant turn has empty text: {prompt}")
res = self._tokenize(
part.strip(),
turn.strip(),
add_eos_token=True,
strip_bos_token=True,
)
@@ -409,22 +413,31 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
raise InvalidDataException(str(err)) from err
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if not prompt.strip():
LOG.warning("Empty text requested for tokenization.")
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
else:
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != self.tokenizer.eos_token_id
len(result["input_ids"]) > 0
and result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
if (
len(result["input_ids"]) > 0
and result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]