Feat: Add sharegpt multirole (#1137)

* feat(prompt): support multiple roles for sharegpt

* fix: add handling of empty role back

* feat: rebased and allowed more dynamic roles via config

* fix: variable

* chore: update message

* feat: add vicuna format

* fix: JSON serializable error

* fix: typing

* fix: don't remap for unknown keys

* fix: add roles to pydantic

* feat: add test

* chore: remove leftover print

* chore: remove leftover comment

* chore: remove print

* fix: update test to use chatml
This commit is contained in:
NanoCode012
2024-03-19 20:51:49 +09:00
committed by GitHub
parent 43bdc5d3de
commit 40a88e8c4a
6 changed files with 146 additions and 26 deletions

View File

@@ -651,9 +651,13 @@ datasets:
train_on_split: train # Optional[str] name of dataset split to load from train_on_split: train # Optional[str] name of dataset split to load from
# Optional[str] fastchat conversation type, only used with type: sharegpt # Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
field_human: # Optional[str]. Human key to use for conversation. field_human: # Optional[str]. Human key to use for conversation.
field_model: # Optional[str]. Assistant key to use for conversation. field_model: # Optional[str]. Assistant key to use for conversation.
# Add additional keys from your dataset as input or output roles
roles:
input: # Optional[List[str]]. These will be masked based on train_on_inputs
output: # Optional[List[str]].
# Custom user instruction prompt # Custom user instruction prompt
- path: repo - path: repo

View File

@@ -1,5 +1,6 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class""" """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import logging
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
@@ -11,6 +12,8 @@ from axolotl.utils.tokenization import (
merge_consecutive_messages, merge_consecutive_messages,
) )
LOG = logging.getLogger("axolotl")
def register_chatml_template(system_message=None): def register_chatml_template(system_message=None):
system_message = system_message or "You are a helpful assistant." system_message = system_message or "You are a helpful assistant."
@@ -42,11 +45,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
) )
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
strategy = SimpleShareGPTPromptTokenizingStrategy( strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2( ShareGPTPrompterV2(
conversation=conversation, conversation=conversation,
role_key_model=field_model, role_key_model=field_model,
role_key_human=field_human, role_key_human=field_human,
roles=roles,
), ),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
@@ -142,7 +147,12 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"system": "system", "system": "system",
} }
turns = [ turns = [
{"from": role_map[t[role_key]], "value": t[value_key]} {
"from": (
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
),
"value": t[value_key],
}
for t in conversations for t in conversations
] ]
return turns return turns

View File

@@ -11,7 +11,7 @@ from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.monkeypatch.fastchat_conversation_turns import ( from axolotl.monkeypatch.fastchat_conversation_turns import (
add_get_turns_to_conversation, add_get_turns_to_conversation,
) )
from axolotl.prompters import IGNORE_TOKEN_ID from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC):
def __init__( def __init__(
self, self,
prompter, prompter: Prompter,
tokenizer, tokenizer,
train_on_inputs: bool = False, train_on_inputs: bool = False,
sequence_len: int = 2048, sequence_len: int = 2048,
@@ -340,6 +340,23 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
self.prompter._conversation.copy() # pylint: disable=protected-access self.prompter._conversation.copy() # pylint: disable=protected-access
) )
input_roles = {conversation.roles[0]}
output_roles = {conversation.roles[1]}
if len(conversation.roles) == 3:
tool_role_label = conversation.roles[2]
input_roles.add(tool_role_label)
# Add roles from the config
if self.prompter.roles:
if "input" in self.prompter.roles and self.prompter.roles["input"]:
for role in self.prompter.roles["input"]:
input_roles.add(role)
if "output" in self.prompter.roles and self.prompter.roles["output"]:
for role in self.prompter.roles["output"]:
output_roles.add(role)
# support for custom roles from the dataset, only useful for vicuna style prompts/roles # support for custom roles from the dataset, only useful for vicuna style prompts/roles
role_remap = [] role_remap = []
if ( if (
@@ -360,19 +377,18 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
LOG.warning(f"expected tuple, got {part}") LOG.warning(f"expected tuple, got {part}")
continue continue
tool_role_label = None
if len(conversation.roles) == 3:
(
user_role_label,
assistant_role_label,
tool_role_label,
) = conversation.roles
else:
user_role_label, assistant_role_label = conversation.roles
role, content = part role, content = part
# Uses "in" because role contains extra characters # Uses "in" because role contains extra characters
if user_role_label in role: input_turn = any(r.lower() in role.lower() for r in input_roles)
output_turn = any(r.lower() in role.lower() for r in output_roles)
empty_role = role.strip() == ""
if not any([input_turn, output_turn, empty_role]):
LOG.warning(f"unhandled role: {role}")
continue
if input_turn:
role = ( role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"]) role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap if role_remap
@@ -392,7 +408,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
else: else:
# everything from this is masked out from the labels # everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif assistant_role_label in role: elif output_turn:
role = ( role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"]) role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap if role_remap
@@ -423,7 +439,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
labels[:len_role] = [IGNORE_TOKEN_ID] * min( labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels) len_role, len(labels)
) )
elif role == "": elif empty_role:
turn = content turn = content
# this is only ever the first part, should include the bos token and the user query # this is only ever the first part, should include the bos token and the user query
res = self._tokenize( res = self._tokenize(
@@ -434,11 +450,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
else: else:
# everything from this is masked out from the labels # everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif tool_role_label and tool_role_label in role:
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result( result, current_len = parse_tokenized_to_result(

View File

@@ -259,6 +259,12 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data." "Role did not alternate between turns (gpt and human). Please check your data."
) )
# Keyed by conversation template name. Each value is the pattern used to render
# a dataset role that is not in the template's default role mapping; ``{ROLE}``
# is substituted with the raw role name via ``str.format``.
CONVERSATION_ROLE_FORMAT = {
    "chatml": "<|im_start|>{ROLE}",
    "zephyr": "<|{ROLE}|>",
    "vicuna_v1.1": "{ROLE}",
}
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
""" """
@@ -268,7 +274,9 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
role_key_human = "human" role_key_human = "human"
role_key_model = "gpt" role_key_model = "gpt"
# Optional, only used for tool usage datasets. # Optional, only used for tool usage datasets.
role_key_tool = None role_key_tool: Optional[str] = None
# Optional, role input/output mapping
roles: Optional[dict] = None
def __init__( def __init__(
self, self,
@@ -277,6 +285,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
role_key_human: Optional[str] = None, role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None, role_key_model: Optional[str] = None,
role_key_tool: Optional[str] = None, role_key_tool: Optional[str] = None,
roles: Optional[dict] = None,
): ):
if conversation: if conversation:
if isinstance(conversation, Conversation): if isinstance(conversation, Conversation):
@@ -291,6 +300,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
self.role_key_model = role_key_model self.role_key_model = role_key_model
if role_key_tool: if role_key_tool:
self.role_key_tool = role_key_tool self.role_key_tool = role_key_tool
if roles:
self.roles = roles
def _build_result(self, source): def _build_result(self, source):
if len(source) < 2: if len(source) < 2:
@@ -322,11 +333,23 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
conv.messages = [] conv.messages = []
for _, sentence in enumerate(source): for _, sentence in enumerate(source):
role = roles[sentence["from"]] from_role = sentence["from"]
if len(conv.messages) > 0 and ( if from_role in roles:
(role == conv.messages[-1][0]) or (role not in conv.roles) role = roles[from_role]
): else:
# A role missing from the default mapping can only be rendered when the
# active conversation template has an entry in CONVERSATION_ROLE_FORMAT.
if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
    raise NotImplementedError(
        # Report `from_role`: `role` is not assigned on this branch (unbound
        # on the first turn, stale from the previous turn afterwards).
        f"Role ({from_role}) not in default roles, and {self._conversation.name} does not support role remapping yet. "
        "Please help us by creating an Issue to add support for this conversation type."
    )
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
ROLE=from_role
)
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}") LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])
return conv.get_turns() return conv.get_turns()
@@ -354,11 +377,13 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
conversation: Optional[Union[str, Conversation]] = None, conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None, role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None, role_key_model: Optional[str] = None,
roles: Optional[dict] = None,
): ):
super().__init__( super().__init__(
conversation=conversation, conversation=conversation,
role_key_human=role_key_human, role_key_human=role_key_human,
role_key_model=role_key_model, role_key_model=role_key_model,
roles=roles,
) )

View File

@@ -96,6 +96,8 @@ class SFTDataset(BaseModel):
field_human: Optional[str] = None field_human: Optional[str] = None
field_model: Optional[str] = None field_model: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
class UserDefinedDPOType(BaseModel): class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO""" """User defined typing for DPO"""

View File

@@ -62,6 +62,38 @@ def fixture_sharegpt_glaive_dataset():
) )
@pytest.fixture(name="multi_role_dataset")
def fixture_multi_role_dataset():
    """One-conversation dataset with a ``tool`` turn between two ``gpt`` turns."""
    # (role, text) pairs, in conversation order.
    turns = [
        ("system", "use get_weather(city) to get the weather for a city"),
        ("human", "hello, what's the weather in New York?"),
        ("gpt", "let me get that for you"),
        ("tool", "get_weather(New York)"),
        ("gpt", "the weather in New York is 70 degrees and sunny"),
    ]
    conversation = [{"from": sender, "value": text} for sender, text in turns]
    return Dataset.from_list([{"conversations": conversation}])
@pytest.fixture(name="tokenizer") @pytest.fixture(name="tokenizer")
def fixture_tokenizer(): def fixture_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
@@ -196,3 +228,39 @@ class TestSharegpt:
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt 32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
] ]
# fmt: on # fmt: on
def test_multi_role_dataset(self, multi_role_dataset, tokenizer):
    """Check token ids and label masking when ``tool`` is configured as an input role."""
    prompter = ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]})
    strategy = SimpleShareGPTPromptTokenizingStrategy(
        prompter,
        tokenizer,
        False,  # train_on_inputs
        2048,  # sequence_len
    )

    tokenized = TokenizedPromptDataset(
        strategy, multi_role_dataset, process_count=1
    )
    sample = tokenized[0]

    # fmt: off
    expected_input_ids = [
        1,  # bos
        32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13,  # system
        32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13,  # human
        32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13,  # gpt
        32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13,  # tool
        32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13  # gpt
    ]
    # fmt: on
    assert sample["input_ids"] == expected_input_ids

    # Only the two gpt turns carry real labels; every other turn is -100.
    # fmt: off
    expected_labels = [
        -100,  # bos
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # system
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # human
        -100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13,  # gpt
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # tool
        -100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13  # gpt
    ]
    # fmt: on
    assert sample["labels"] == expected_labels