misc sharegpt fixes (#723)
* support for sharegpt with assistant talking first, better masking of assistant token, allow remap of roles from dataset
* invalid role is actually not possible
* update tokenized fixture for corrected labels
This commit is contained in:
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import abc
|
import abc
|
||||||
import copy
|
import copy
|
||||||
import functools
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Dict, List, Tuple, Union
|
||||||
|
|
||||||
@@ -57,26 +56,6 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
def supports_batched(self):
|
def supports_batched(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=128)
|
|
||||||
def _get_user_token(self):
|
|
||||||
try:
|
|
||||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
|
|
||||||
if isinstance(id_or_ids, (int,)):
|
|
||||||
return id_or_ids
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return False
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=128)
|
|
||||||
def _get_assistant_token(self):
|
|
||||||
try:
|
|
||||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
|
|
||||||
if isinstance(id_or_ids, (int,)):
|
|
||||||
return id_or_ids
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _tokenize(
|
def _tokenize(
|
||||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||||
) -> BatchEncoding:
|
) -> BatchEncoding:
|
||||||
@@ -356,18 +335,34 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
result, current_len = tokenize_prompt_default()
|
result, current_len = tokenize_prompt_default()
|
||||||
user_token = self._get_user_token()
|
|
||||||
assistant_token = self._get_assistant_token()
|
|
||||||
conversation: Conversation = (
|
conversation: Conversation = (
|
||||||
self.prompter._conversation # pylint: disable=protected-access
|
self.prompter._conversation.copy() # pylint: disable=protected-access
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
||||||
|
role_remap = []
|
||||||
|
if (
|
||||||
|
conversation.name == "vicuna_v1.1"
|
||||||
|
and "roles" in prompt
|
||||||
|
and len(prompt["roles"]) >= 2
|
||||||
|
):
|
||||||
|
role_remap = [
|
||||||
|
{"from": conversation.roles[0], "to": prompt["roles"][0]},
|
||||||
|
{"from": conversation.roles[1], "to": prompt["roles"][1]},
|
||||||
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for _, part in enumerate(
|
for _, part in enumerate(
|
||||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
||||||
):
|
):
|
||||||
if isinstance(part, tuple):
|
if isinstance(part, tuple):
|
||||||
if conversation.roles[0] in part[0]:
|
if conversation.roles[0] in part[0]:
|
||||||
turn = part[0] + part[1] if not user_token else part[1]
|
role = (
|
||||||
|
part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||||
|
if role_remap
|
||||||
|
else part[0]
|
||||||
|
)
|
||||||
|
turn = role + part[1]
|
||||||
# this is still the user query, we should
|
# this is still the user query, we should
|
||||||
if not part[1].strip():
|
if not part[1].strip():
|
||||||
LOG.warning(f"user turn has empty text: {prompt}")
|
LOG.warning(f"user turn has empty text: {prompt}")
|
||||||
@@ -376,13 +371,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
add_eos_token=False,
|
add_eos_token=False,
|
||||||
strip_bos_token=True,
|
strip_bos_token=True,
|
||||||
)
|
)
|
||||||
if user_token:
|
|
||||||
res["input_ids"] = [user_token, *res["input_ids"]]
|
|
||||||
# everything from this is masked out from the labels
|
# everything from this is masked out from the labels
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
elif conversation.roles[1] in part[0]:
|
elif conversation.roles[1] in part[0]:
|
||||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||||
turn = part[0] + part[1] if not assistant_token else part[1]
|
role = (
|
||||||
|
part[0].replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||||
|
if role_remap
|
||||||
|
else part[0]
|
||||||
|
)
|
||||||
|
turn = role + part[1]
|
||||||
# this should be the assistant response, should end with an eos token
|
# this should be the assistant response, should end with an eos token
|
||||||
if not part[1].strip():
|
if not part[1].strip():
|
||||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
LOG.warning(f"assistant turn has empty text: {prompt}")
|
||||||
@@ -391,13 +389,17 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
add_eos_token=True,
|
add_eos_token=True,
|
||||||
strip_bos_token=True,
|
strip_bos_token=True,
|
||||||
)
|
)
|
||||||
if assistant_token:
|
role_res = self._tokenize(
|
||||||
res["input_ids"] = [
|
role.rstrip(),
|
||||||
assistant_token,
|
add_eos_token=False,
|
||||||
*res["input_ids"],
|
strip_bos_token=True,
|
||||||
]
|
)
|
||||||
# not masked out from labels
|
# not masked out from labels
|
||||||
labels = copy.deepcopy(res["input_ids"])
|
labels = copy.deepcopy(res["input_ids"])
|
||||||
|
len_role = len(role_res["input_ids"])
|
||||||
|
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
||||||
|
len_role, len(labels)
|
||||||
|
)
|
||||||
elif part[0] == "":
|
elif part[0] == "":
|
||||||
turn = part[1]
|
turn = part[1]
|
||||||
# this is only ever the first part, should include the bos token and the user query
|
# this is only ever the first part, should include the bos token and the user query
|
||||||
|
|||||||
@@ -274,9 +274,11 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
raise err
|
raise err
|
||||||
|
|
||||||
conv.messages = []
|
conv.messages = []
|
||||||
for j, sentence in enumerate(source):
|
for _, sentence in enumerate(source):
|
||||||
role = roles[sentence["from"]]
|
role = roles[sentence["from"]]
|
||||||
if role != conv.roles[j % 2]:
|
if len(conv.messages) > 0 and (
|
||||||
|
(role == conv.messages[-1][0]) or (role not in conv.roles)
|
||||||
|
):
|
||||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
|
|
||||||
|
|||||||
2
tests/fixtures/conversation.tokenized.json
vendored
2
tests/fixtures/conversation.tokenized.json
vendored
File diff suppressed because one or more lines are too long
@@ -90,6 +90,73 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
strat.tokenize_prompt(conversation)
|
strat.tokenize_prompt(conversation)
|
||||||
assert "assistant turn has empty text" in self._caplog.records[1].message
|
assert "assistant turn has empty text" in self._caplog.records[1].message
|
||||||
|
|
||||||
|
def test_sharegpt_warnings_turns(self):
|
||||||
|
conversation = {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "lorem"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
{"from": "human", "value": "dolor"},
|
||||||
|
{"from": "human", "value": "dolor"},
|
||||||
|
{"from": "gpt", "value": "sit"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
prompter = ShareGPTPrompterV2()
|
||||||
|
strat = ShareGPTPromptTokenizingStrategy(
|
||||||
|
prompter,
|
||||||
|
self.tokenizer,
|
||||||
|
False,
|
||||||
|
2048,
|
||||||
|
)
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
strat.tokenize_prompt(conversation)
|
||||||
|
assert (
|
||||||
|
"Role did not alternate between turns (gpt and human)"
|
||||||
|
in self._caplog.records[0].message
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_sharegpt_changes_roles(self):
|
||||||
|
conversation = {
|
||||||
|
"roles": ["USER", "CHARACTER"],
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "lorem"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
{"from": "human", "value": "dolor"},
|
||||||
|
{"from": "gpt", "value": "sit"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
prompter = ShareGPTPrompterV2()
|
||||||
|
strat = ShareGPTPromptTokenizingStrategy(
|
||||||
|
prompter,
|
||||||
|
self.tokenizer,
|
||||||
|
False,
|
||||||
|
2048,
|
||||||
|
)
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
res = strat.tokenize_prompt(conversation)
|
||||||
|
assert "CHARACTER" in self.tokenizer.decode(res["input_ids"])
|
||||||
|
|
||||||
|
def test_sharegpt_assistant_label_ignore(self):
|
||||||
|
conversation = {
|
||||||
|
"roles": ["user", "assistant"],
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "lorem"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
{"from": "human", "value": "dolor"},
|
||||||
|
{"from": "gpt", "value": "sit"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
prompter = ShareGPTPrompterV2()
|
||||||
|
strat = ShareGPTPromptTokenizingStrategy(
|
||||||
|
prompter,
|
||||||
|
self.tokenizer,
|
||||||
|
False,
|
||||||
|
2048,
|
||||||
|
)
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
res = strat.tokenize_prompt(conversation)
|
||||||
|
idx = res["input_ids"].index(20255) # assistant token
|
||||||
|
assert res["labels"][idx] == -100
|
||||||
|
|
||||||
def test_no_sys_prompt(self):
|
def test_no_sys_prompt(self):
|
||||||
"""
|
"""
|
||||||
tests the interface between the user and assistant parts
|
tests the interface between the user and assistant parts
|
||||||
|
|||||||
Reference in New Issue
Block a user