Feat: Add sharegpt multirole (#1137)
* feat(prompt): support multiple roles for sharegpt
* fix: add handling of empty role back
* feat: rebased and allowed more dynamic roles via config
* fix: variable
* chore: update message
* feat: add vicuna format
* fix: JSON serializable error
* fix: typing
* fix: don't remap for unknown keys
* fix: add roles to pydantic
* feat: add test
* chore: remove leftover print
* chore: remove leftover comment
* chore: remove print
* fix: update test to use chatml
This commit is contained in:
@@ -651,9 +651,13 @@ datasets:
|
|||||||
train_on_split: train # Optional[str] name of dataset split to load from
|
train_on_split: train # Optional[str] name of dataset split to load from
|
||||||
|
|
||||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
field_human: # Optional[str]. Human key to use for conversation.
|
field_human: # Optional[str]. Human key to use for conversation.
|
||||||
field_model: # Optional[str]. Assistant key to use for conversation.
|
field_model: # Optional[str]. Assistant key to use for conversation.
|
||||||
|
# Add additional keys from your dataset as input or output roles
|
||||||
|
roles:
|
||||||
|
input: # Optional[List[str]]. These will be masked based on train_on_inputs
|
||||||
|
output: # Optional[List[str]]. Turns from these roles are handled like the assistant (output) role.
|
||||||
|
|
||||||
# Custom user instruction prompt
|
# Custom user instruction prompt
|
||||||
- path: repo
|
- path: repo
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||||
@@ -11,6 +12,8 @@ from axolotl.utils.tokenization import (
|
|||||||
merge_consecutive_messages,
|
merge_consecutive_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def register_chatml_template(system_message=None):
|
def register_chatml_template(system_message=None):
|
||||||
system_message = system_message or "You are a helpful assistant."
|
system_message = system_message or "You are a helpful assistant."
|
||||||
@@ -42,11 +45,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
)
|
)
|
||||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||||
|
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
|
||||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||||
ShareGPTPrompterV2(
|
ShareGPTPrompterV2(
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
role_key_model=field_model,
|
role_key_model=field_model,
|
||||||
role_key_human=field_human,
|
role_key_human=field_human,
|
||||||
|
roles=roles,
|
||||||
),
|
),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
@@ -142,7 +147,12 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|||||||
"system": "system",
|
"system": "system",
|
||||||
}
|
}
|
||||||
turns = [
|
turns = [
|
||||||
{"from": role_map[t[role_key]], "value": t[value_key]}
|
{
|
||||||
|
"from": (
|
||||||
|
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
|
||||||
|
),
|
||||||
|
"value": t[value_key],
|
||||||
|
}
|
||||||
for t in conversations
|
for t in conversations
|
||||||
]
|
]
|
||||||
return turns
|
return turns
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from transformers import BatchEncoding, PreTrainedTokenizer
|
|||||||
from axolotl.monkeypatch.fastchat_conversation_turns import (
|
from axolotl.monkeypatch.fastchat_conversation_turns import (
|
||||||
add_get_turns_to_conversation,
|
add_get_turns_to_conversation,
|
||||||
)
|
)
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
prompter,
|
prompter: Prompter,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
train_on_inputs: bool = False,
|
train_on_inputs: bool = False,
|
||||||
sequence_len: int = 2048,
|
sequence_len: int = 2048,
|
||||||
@@ -340,6 +340,23 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
self.prompter._conversation.copy() # pylint: disable=protected-access
|
self.prompter._conversation.copy() # pylint: disable=protected-access
|
||||||
)
|
)
|
||||||
|
|
||||||
|
input_roles = {conversation.roles[0]}
|
||||||
|
output_roles = {conversation.roles[1]}
|
||||||
|
|
||||||
|
if len(conversation.roles) == 3:
|
||||||
|
tool_role_label = conversation.roles[2]
|
||||||
|
input_roles.add(tool_role_label)
|
||||||
|
|
||||||
|
# Add roles from the config
|
||||||
|
if self.prompter.roles:
|
||||||
|
if "input" in self.prompter.roles and self.prompter.roles["input"]:
|
||||||
|
for role in self.prompter.roles["input"]:
|
||||||
|
input_roles.add(role)
|
||||||
|
|
||||||
|
if "output" in self.prompter.roles and self.prompter.roles["output"]:
|
||||||
|
for role in self.prompter.roles["output"]:
|
||||||
|
output_roles.add(role)
|
||||||
|
|
||||||
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
||||||
role_remap = []
|
role_remap = []
|
||||||
if (
|
if (
|
||||||
@@ -360,19 +377,18 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
LOG.warning(f"expected tuple, got {part}")
|
LOG.warning(f"expected tuple, got {part}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tool_role_label = None
|
|
||||||
if len(conversation.roles) == 3:
|
|
||||||
(
|
|
||||||
user_role_label,
|
|
||||||
assistant_role_label,
|
|
||||||
tool_role_label,
|
|
||||||
) = conversation.roles
|
|
||||||
else:
|
|
||||||
user_role_label, assistant_role_label = conversation.roles
|
|
||||||
role, content = part
|
role, content = part
|
||||||
|
|
||||||
# Uses "in" because role contains extra characters
|
# Uses "in" because role contains extra characters
|
||||||
if user_role_label in role:
|
input_turn = any(r.lower() in role.lower() for r in input_roles)
|
||||||
|
output_turn = any(r.lower() in role.lower() for r in output_roles)
|
||||||
|
empty_role = role.strip() == ""
|
||||||
|
|
||||||
|
if not any([input_turn, output_turn, empty_role]):
|
||||||
|
LOG.warning(f"unhandled role: {role}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if input_turn:
|
||||||
role = (
|
role = (
|
||||||
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||||
if role_remap
|
if role_remap
|
||||||
@@ -392,7 +408,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
else:
|
else:
|
||||||
# everything from this is masked out from the labels
|
# everything from this is masked out from the labels
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
elif assistant_role_label in role:
|
elif output_turn:
|
||||||
role = (
|
role = (
|
||||||
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||||
if role_remap
|
if role_remap
|
||||||
@@ -423,7 +439,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
||||||
len_role, len(labels)
|
len_role, len(labels)
|
||||||
)
|
)
|
||||||
elif role == "":
|
elif empty_role:
|
||||||
turn = content
|
turn = content
|
||||||
# this is only ever the first part, should include the bos token and the user query
|
# this is only ever the first part, should include the bos token and the user query
|
||||||
res = self._tokenize(
|
res = self._tokenize(
|
||||||
@@ -434,11 +450,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
else:
|
else:
|
||||||
# everything from this is masked out from the labels
|
# everything from this is masked out from the labels
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
elif tool_role_label and tool_role_label in role:
|
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
|
||||||
else:
|
|
||||||
LOG.warning(f"unhandled role: {role}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
result, current_len = parse_tokenized_to_result(
|
result, current_len = parse_tokenized_to_result(
|
||||||
|
|||||||
@@ -259,6 +259,12 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
|
|||||||
"Role did not alternate between turns (gpt and human). Please check your data."
|
"Role did not alternate between turns (gpt and human). Please check your data."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
CONVERSATION_ROLE_FORMAT = {
|
||||||
|
"chatml": "<|im_start|>{ROLE}",
|
||||||
|
"zephyr": "<|{ROLE}|>",
|
||||||
|
"vicuna_v1.1": "{ROLE}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||||
"""
|
"""
|
||||||
@@ -268,7 +274,9 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|||||||
role_key_human = "human"
|
role_key_human = "human"
|
||||||
role_key_model = "gpt"
|
role_key_model = "gpt"
|
||||||
# Optional, only used for tool usage datasets.
|
# Optional, only used for tool usage datasets.
|
||||||
role_key_tool = None
|
role_key_tool: Optional[str] = None
|
||||||
|
# Optional, role input/output mapping
|
||||||
|
roles: Optional[dict] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -277,6 +285,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|||||||
role_key_human: Optional[str] = None,
|
role_key_human: Optional[str] = None,
|
||||||
role_key_model: Optional[str] = None,
|
role_key_model: Optional[str] = None,
|
||||||
role_key_tool: Optional[str] = None,
|
role_key_tool: Optional[str] = None,
|
||||||
|
roles: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
if conversation:
|
if conversation:
|
||||||
if isinstance(conversation, Conversation):
|
if isinstance(conversation, Conversation):
|
||||||
@@ -291,6 +300,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|||||||
self.role_key_model = role_key_model
|
self.role_key_model = role_key_model
|
||||||
if role_key_tool:
|
if role_key_tool:
|
||||||
self.role_key_tool = role_key_tool
|
self.role_key_tool = role_key_tool
|
||||||
|
if roles:
|
||||||
|
self.roles = roles
|
||||||
|
|
||||||
def _build_result(self, source):
|
def _build_result(self, source):
|
||||||
if len(source) < 2:
|
if len(source) < 2:
|
||||||
@@ -322,11 +333,23 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|||||||
|
|
||||||
conv.messages = []
|
conv.messages = []
|
||||||
for _, sentence in enumerate(source):
|
for _, sentence in enumerate(source):
|
||||||
role = roles[sentence["from"]]
|
from_role = sentence["from"]
|
||||||
if len(conv.messages) > 0 and (
|
if from_role in roles:
|
||||||
(role == conv.messages[-1][0]) or (role not in conv.roles)
|
role = roles[from_role]
|
||||||
):
|
else:
|
||||||
|
if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
|
||||||
|
"Please help us by creating an Issue to add support for this conversation type."
|
||||||
|
)
|
||||||
|
|
||||||
|
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
|
||||||
|
ROLE=from_role
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
|
||||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||||
|
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
|
|
||||||
return conv.get_turns()
|
return conv.get_turns()
|
||||||
@@ -354,11 +377,13 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
|
|||||||
conversation: Optional[Union[str, Conversation]] = None,
|
conversation: Optional[Union[str, Conversation]] = None,
|
||||||
role_key_human: Optional[str] = None,
|
role_key_human: Optional[str] = None,
|
||||||
role_key_model: Optional[str] = None,
|
role_key_model: Optional[str] = None,
|
||||||
|
roles: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
role_key_human=role_key_human,
|
role_key_human=role_key_human,
|
||||||
role_key_model=role_key_model,
|
role_key_model=role_key_model,
|
||||||
|
roles=roles,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -96,6 +96,8 @@ class SFTDataset(BaseModel):
|
|||||||
field_human: Optional[str] = None
|
field_human: Optional[str] = None
|
||||||
field_model: Optional[str] = None
|
field_model: Optional[str] = None
|
||||||
|
|
||||||
|
roles: Optional[Dict[str, List[str]]] = None
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedDPOType(BaseModel):
|
class UserDefinedDPOType(BaseModel):
|
||||||
"""User defined typing for DPO"""
|
"""User defined typing for DPO"""
|
||||||
|
|||||||
@@ -62,6 +62,38 @@ def fixture_sharegpt_glaive_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="multi_role_dataset")
|
||||||
|
def fixture_multi_role_dataset():
|
||||||
|
return Dataset.from_list(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"conversations": [
|
||||||
|
{
|
||||||
|
"from": "system",
|
||||||
|
"value": "use get_weather(city) to get the weather for a city",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "human",
|
||||||
|
"value": "hello, what's the weather in New York?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "gpt",
|
||||||
|
"value": "let me get that for you",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "tool",
|
||||||
|
"value": "get_weather(New York)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "gpt",
|
||||||
|
"value": "the weather in New York is 70 degrees and sunny",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="tokenizer")
|
@pytest.fixture(name="tokenizer")
|
||||||
def fixture_tokenizer():
|
def fixture_tokenizer():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||||
@@ -196,3 +228,39 @@ class TestSharegpt:
|
|||||||
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
|
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
|
||||||
]
|
]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
def test_multi_role_dataset(self, multi_role_dataset, tokenizer):
|
||||||
|
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||||
|
ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}),
|
||||||
|
tokenizer,
|
||||||
|
False, # train_on_inputs
|
||||||
|
2048, # sequence_len
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_wrapper = TokenizedPromptDataset(
|
||||||
|
strategy, multi_role_dataset, process_count=1
|
||||||
|
)
|
||||||
|
|
||||||
|
input_ids = dataset_wrapper[0]["input_ids"]
|
||||||
|
# fmt: off
|
||||||
|
assert input_ids == [
|
||||||
|
1, # bos
|
||||||
|
32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13, # system
|
||||||
|
32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13, # human
|
||||||
|
32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
|
||||||
|
32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13, # tool
|
||||||
|
32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
labels = dataset_wrapper[0]["labels"]
|
||||||
|
# fmt: off
|
||||||
|
assert labels == [
|
||||||
|
-100, # bos
|
||||||
|
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # system
|
||||||
|
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # human
|
||||||
|
-100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
|
||||||
|
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool
|
||||||
|
-100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|||||||
Reference in New Issue
Block a user