misc sharegpt fixes (#723)

* support for sharegpt with assistant talking first, better masking of assistant token, allow remap of roles from dataset

* invalid role is actually not possible

* update tokenized fixture for corrected labels
This commit is contained in:
Wing Lian
2023-10-13 11:04:39 -04:00
committed by GitHub
parent bfbdba8614
commit f30afe4544
4 changed files with 107 additions and 36 deletions

View File

@@ -2,7 +2,6 @@
import abc
import copy
import functools
import logging
from typing import Dict, List, Tuple, Union
@@ -57,26 +56,6 @@ class PromptTokenizingStrategy(abc.ABC):
def supports_batched(self):
return False
@functools.lru_cache(maxsize=128)
def _get_user_token(self):
try:
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False
@functools.lru_cache(maxsize=128)
def _get_assistant_token(self):
try:
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False
def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:
@@ -356,18 +335,34 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
def tokenize_prompt(self, prompt):
result, current_len = tokenize_prompt_default()
user_token = self._get_user_token()
assistant_token = self._get_assistant_token()
conversation: Conversation = (
self.prompter._conversation # pylint: disable=protected-access
self.prompter._conversation.copy() # pylint: disable=protected-access
)
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
role_remap = []
if (
conversation.name == "vicuna_v1.1"
and "roles" in prompt
and len(prompt["roles"]) >= 2
):
role_remap = [
{"from": conversation.roles[0], "to": prompt["roles"][0]},
{"from": conversation.roles[1], "to": prompt["roles"][1]},
]
try:
for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if isinstance(part, tuple):
if conversation.roles[0] in part[0]:
turn = part[0] + part[1] if not user_token else part[1]
role = (
part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
else part[0]
)
turn = role + part[1]
# this is still the user query, we should
if not part[1].strip():
LOG.warning(f"user turn has empty text: {prompt}")
@@ -376,13 +371,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
add_eos_token=False,
strip_bos_token=True,
)
if user_token:
res["input_ids"] = [user_token, *res["input_ids"]]
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif conversation.roles[1] in part[0]:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
turn = part[0] + part[1] if not assistant_token else part[1]
role = (
part[0].replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
else part[0]
)
turn = role + part[1]
# this should be the assistant response, should end with an eos token
if not part[1].strip():
LOG.warning(f"assistant turn has empty text: {prompt}")
@@ -391,13 +389,17 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
add_eos_token=True,
strip_bos_token=True,
)
if assistant_token:
res["input_ids"] = [
assistant_token,
*res["input_ids"],
]
role_res = self._tokenize(
role.rstrip(),
add_eos_token=False,
strip_bos_token=True,
)
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
len_role = len(role_res["input_ids"])
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
elif part[0] == "":
turn = part[1]
# this is only ever the first part, should include the bos token and the user query

View File

@@ -274,9 +274,11 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
raise err
conv.messages = []
for j, sentence in enumerate(source):
for _, sentence in enumerate(source):
role = roles[sentence["from"]]
if role != conv.roles[j % 2]:
if len(conv.messages) > 0 and (
(role == conv.messages[-1][0]) or (role not in conv.roles)
):
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"])