misc sharegpt fixes (#723)

* support for sharegpt with assistant talking first, better masking of assistant token, allow remap of roles from dataset

* invalid role is actually not possible

* update tokenized fixture for corrected labels
This commit is contained in:
Wing Lian
2023-10-13 11:04:39 -04:00
committed by GitHub
parent bfbdba8614
commit f30afe4544
4 changed files with 107 additions and 36 deletions

View File

@@ -2,7 +2,6 @@
import abc
import copy
import functools
import logging
from typing import Dict, List, Tuple, Union
@@ -57,26 +56,6 @@ class PromptTokenizingStrategy(abc.ABC):
def supports_batched(self):
return False
@functools.lru_cache(maxsize=128)
def _get_user_token(self):
try:
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False
@functools.lru_cache(maxsize=128)
def _get_assistant_token(self):
try:
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False
def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:
@@ -356,18 +335,34 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
def tokenize_prompt(self, prompt):
result, current_len = tokenize_prompt_default()
user_token = self._get_user_token()
assistant_token = self._get_assistant_token()
conversation: Conversation = (
self.prompter._conversation # pylint: disable=protected-access
self.prompter._conversation.copy() # pylint: disable=protected-access
)
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
role_remap = []
if (
conversation.name == "vicuna_v1.1"
and "roles" in prompt
and len(prompt["roles"]) >= 2
):
role_remap = [
{"from": conversation.roles[0], "to": prompt["roles"][0]},
{"from": conversation.roles[1], "to": prompt["roles"][1]},
]
try:
for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if isinstance(part, tuple):
if conversation.roles[0] in part[0]:
turn = part[0] + part[1] if not user_token else part[1]
role = (
part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
else part[0]
)
turn = role + part[1]
# this is still the user query, we should
if not part[1].strip():
LOG.warning(f"user turn has empty text: {prompt}")
@@ -376,13 +371,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
add_eos_token=False,
strip_bos_token=True,
)
if user_token:
res["input_ids"] = [user_token, *res["input_ids"]]
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif conversation.roles[1] in part[0]:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
turn = part[0] + part[1] if not assistant_token else part[1]
role = (
part[0].replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
else part[0]
)
turn = role + part[1]
# this should be the assistant response, should end with an eos token
if not part[1].strip():
LOG.warning(f"assistant turn has empty text: {prompt}")
@@ -391,13 +389,17 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
add_eos_token=True,
strip_bos_token=True,
)
if assistant_token:
res["input_ids"] = [
assistant_token,
*res["input_ids"],
]
role_res = self._tokenize(
role.rstrip(),
add_eos_token=False,
strip_bos_token=True,
)
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
len_role = len(role_res["input_ids"])
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
elif part[0] == "":
turn = part[1]
# this is only ever the first part, should include the bos token and the user query

View File

@@ -274,9 +274,11 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
raise err
conv.messages = []
for j, sentence in enumerate(source):
for _, sentence in enumerate(source):
role = roles[sentence["from"]]
if role != conv.roles[j % 2]:
if len(conv.messages) > 0 and (
(role == conv.messages[-1][0]) or (role not in conv.roles)
):
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"])

File diff suppressed because one or more lines are too long

View File

@@ -90,6 +90,73 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
strat.tokenize_prompt(conversation)
assert "assistant turn has empty text" in self._caplog.records[1].message
def test_sharegpt_warnings_turns(self):
conversation = {
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "dolor"},
{"from": "human", "value": "dolor"},
{"from": "gpt", "value": "sit"},
]
}
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
strat.tokenize_prompt(conversation)
assert (
"Role did not alternate between turns (gpt and human)"
in self._caplog.records[0].message
)
def test_sharegpt_changes_roles(self):
conversation = {
"roles": ["USER", "CHARACTER"],
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "dolor"},
{"from": "gpt", "value": "sit"},
],
}
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
res = strat.tokenize_prompt(conversation)
assert "CHARACTER" in self.tokenizer.decode(res["input_ids"])
def test_sharegpt_assistant_label_ignore(self):
conversation = {
"roles": ["user", "assistant"],
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "dolor"},
{"from": "gpt", "value": "sit"},
],
}
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
res = strat.tokenize_prompt(conversation)
idx = res["input_ids"].index(20255) # assistant token
assert res["labels"][idx] == -100
def test_no_sys_prompt(self):
"""
tests the interface between the user and assistant parts