Feat: Add sharegpt multirole (#1137)

* feat(prompt): support multiple roles for sharegpt

* fix: add handling of empty role back

* feat: rebased and allowed more dynamic roles via config

* fix: variable

* chore: update message

* feat: add vicuna format

* fix: JSON serializable error

* fix: typing

* fix: don't remap for unknown keys

* fix: add roles to pydantic

* feat: add test

* chore: remove leftover print

* chore: remove leftover comment

* chore: remove print

* fix: update test to use chatml
This commit is contained in:
NanoCode012
2024-03-19 20:51:49 +09:00
committed by GitHub
parent 43bdc5d3de
commit 40a88e8c4a
6 changed files with 146 additions and 26 deletions

View File

@@ -651,9 +651,13 @@ datasets:
train_on_split: train # Optional[str] name of dataset split to load from
# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
field_human: # Optional[str]. Human key to use for conversation.
field_model: # Optional[str]. Assistant key to use for conversation.
# Add additional keys from your dataset as input or output roles
roles:
input: # Optional[List[str]]. These will be masked based on train_on_input
output: # Optional[List[str]].
# Custom user instruction prompt
- path: repo

View File

@@ -1,5 +1,6 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import logging
from typing import Any, Dict, Optional
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
@@ -11,6 +12,8 @@ from axolotl.utils.tokenization import (
merge_consecutive_messages,
)
LOG = logging.getLogger("axolotl")
def register_chatml_template(system_message=None):
system_message = system_message or "You are a helpful assistant."
@@ -42,11 +45,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
@@ -142,7 +147,12 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"system": "system",
}
turns = [
{"from": role_map[t[role_key]], "value": t[value_key]}
{
"from": (
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
),
"value": t[value_key],
}
for t in conversations
]
return turns

View File

@@ -11,7 +11,7 @@ from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.monkeypatch.fastchat_conversation_turns import (
add_get_turns_to_conversation,
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
LOG = logging.getLogger("axolotl")
@@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC):
def __init__(
self,
prompter,
prompter: Prompter,
tokenizer,
train_on_inputs: bool = False,
sequence_len: int = 2048,
@@ -340,6 +340,23 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
self.prompter._conversation.copy() # pylint: disable=protected-access
)
input_roles = {conversation.roles[0]}
output_roles = {conversation.roles[1]}
if len(conversation.roles) == 3:
tool_role_label = conversation.roles[2]
input_roles.add(tool_role_label)
# Add roles from the config
if self.prompter.roles:
if "input" in self.prompter.roles and self.prompter.roles["input"]:
for role in self.prompter.roles["input"]:
input_roles.add(role)
if "output" in self.prompter.roles and self.prompter.roles["output"]:
for role in self.prompter.roles["output"]:
output_roles.add(role)
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
role_remap = []
if (
@@ -360,19 +377,18 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
LOG.warning(f"expected tuple, got {part}")
continue
tool_role_label = None
if len(conversation.roles) == 3:
(
user_role_label,
assistant_role_label,
tool_role_label,
) = conversation.roles
else:
user_role_label, assistant_role_label = conversation.roles
role, content = part
# Uses "in" because role contains extra characters
if user_role_label in role:
input_turn = any(r.lower() in role.lower() for r in input_roles)
output_turn = any(r.lower() in role.lower() for r in output_roles)
empty_role = role.strip() == ""
if not any([input_turn, output_turn, empty_role]):
LOG.warning(f"unhandled role: {role}")
continue
if input_turn:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
@@ -392,7 +408,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif assistant_role_label in role:
elif output_turn:
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
@@ -423,7 +439,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
elif role == "":
elif empty_role:
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
@@ -434,11 +450,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif tool_role_label and tool_role_label in role:
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(

View File

@@ -259,6 +259,12 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
CONVERSATION_ROLE_FORMAT = {
"chatml": "<|im_start|>{ROLE}",
"zephyr": "<|{ROLE}|>",
"vicuna_v1.1": "{ROLE}",
}
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
"""
@@ -268,7 +274,9 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
role_key_human = "human"
role_key_model = "gpt"
# Optional, only used for tool usage datasets.
role_key_tool = None
role_key_tool: Optional[str] = None
# Optional, role input/output mapping
roles: Optional[dict] = None
def __init__(
self,
@@ -277,6 +285,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
role_key_tool: Optional[str] = None,
roles: Optional[dict] = None,
):
if conversation:
if isinstance(conversation, Conversation):
@@ -291,6 +300,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
self.role_key_model = role_key_model
if role_key_tool:
self.role_key_tool = role_key_tool
if roles:
self.roles = roles
def _build_result(self, source):
if len(source) < 2:
@@ -322,11 +333,23 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
conv.messages = []
for _, sentence in enumerate(source):
role = roles[sentence["from"]]
if len(conv.messages) > 0 and (
(role == conv.messages[-1][0]) or (role not in conv.roles)
):
from_role = sentence["from"]
if from_role in roles:
role = roles[from_role]
else:
if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
raise NotImplementedError(
f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
"Please help us by creating an Issue to add support for this conversation type."
)
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
ROLE=from_role
)
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"])
return conv.get_turns()
@@ -354,11 +377,13 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
roles: Optional[dict] = None,
):
super().__init__(
conversation=conversation,
role_key_human=role_key_human,
role_key_model=role_key_model,
roles=roles,
)

View File

@@ -96,6 +96,8 @@ class SFTDataset(BaseModel):
field_human: Optional[str] = None
field_model: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO"""

View File

@@ -62,6 +62,38 @@ def fixture_sharegpt_glaive_dataset():
)
@pytest.fixture(name="multi_role_dataset")
def fixture_multi_role_dataset():
return Dataset.from_list(
[
{
"conversations": [
{
"from": "system",
"value": "use get_weather(city) to get the weather for a city",
},
{
"from": "human",
"value": "hello, what's the weather in New York?",
},
{
"from": "gpt",
"value": "let me get that for you",
},
{
"from": "tool",
"value": "get_weather(New York)",
},
{
"from": "gpt",
"value": "the weather in New York is 70 degrees and sunny",
},
]
}
]
)
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
@@ -196,3 +228,39 @@ class TestSharegpt:
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
]
# fmt: on
def test_multi_role_dataset(self, multi_role_dataset, tokenizer):
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}),
tokenizer,
False, # train_on_inputs
2048, # sequence_len
)
dataset_wrapper = TokenizedPromptDataset(
strategy, multi_role_dataset, process_count=1
)
input_ids = dataset_wrapper[0]["input_ids"]
# fmt: off
assert input_ids == [
1, # bos
32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13, # system
32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13, # human
32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13, # tool
32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
]
# fmt: on
labels = dataset_wrapper[0]["labels"]
# fmt: off
assert labels == [
-100, # bos
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # system
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # human
-100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool
-100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
]
# fmt: on