Feat: Add sharegpt multirole (#1137)

* feat(prompt): support multiple roles for sharegpt

* fix: add handling of empty role back

* feat: rebased and allowed more dynamic roles via config

* fix: variable

* chore: update message

* feat: add vicuna format

* fix: JSON serializable error

* fix: typing

* fix: don't remap for unknown keys

* fix: add roles to pydantic

* feat: add test

* chore: remove leftover print

* chore: remove leftover comment

* chore: remove print

* fix: update test to use chatml
This commit is contained in:
NanoCode012
2024-03-19 20:51:49 +09:00
committed by GitHub
parent 43bdc5d3de
commit 40a88e8c4a
6 changed files with 146 additions and 26 deletions

View File

@@ -1,5 +1,6 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import logging
from typing import Any, Dict, Optional
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
@@ -11,6 +12,8 @@ from axolotl.utils.tokenization import (
merge_consecutive_messages,
)
LOG = logging.getLogger("axolotl")
def register_chatml_template(system_message=None):
system_message = system_message or "You are a helpful assistant."
@@ -42,11 +45,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
@@ -142,7 +147,12 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"system": "system",
}
turns = [
{"from": role_map[t[role_key]], "value": t[value_key]}
{
"from": (
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
),
"value": t[value_key],
}
for t in conversations
]
return turns

View File

@@ -11,7 +11,7 @@ from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.monkeypatch.fastchat_conversation_turns import (
add_get_turns_to_conversation,
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
LOG = logging.getLogger("axolotl")
@@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC):
def __init__(
self,
prompter,
prompter: Prompter,
tokenizer,
train_on_inputs: bool = False,
sequence_len: int = 2048,
@@ -340,6 +340,23 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
self.prompter._conversation.copy() # pylint: disable=protected-access
)
input_roles = {conversation.roles[0]}
output_roles = {conversation.roles[1]}
if len(conversation.roles) == 3:
tool_role_label = conversation.roles[2]
input_roles.add(tool_role_label)
# Add roles from the config
if self.prompter.roles:
if "input" in self.prompter.roles and self.prompter.roles["input"]:
for role in self.prompter.roles["input"]:
input_roles.add(role)
if "output" in self.prompter.roles and self.prompter.roles["output"]:
for role in self.prompter.roles["output"]:
output_roles.add(role)
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
role_remap = []
if (
@@ -360,19 +377,18 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
LOG.warning(f"expected tuple, got {part}")
continue
tool_role_label = None
if len(conversation.roles) == 3:
(
user_role_label,
assistant_role_label,
tool_role_label,
) = conversation.roles
else:
user_role_label, assistant_role_label = conversation.roles
role, content = part
# Uses "in" because role contains extra characters
if user_role_label in role:
input_turn = any(r.lower() in role.lower() for r in input_roles)
output_turn = any(r.lower() in role.lower() for r in output_roles)
empty_role = role.strip() == ""
if not any([input_turn, output_turn, empty_role]):
LOG.warning(f"unhandled role: {role}")
continue
if input_turn:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
@@ -392,7 +408,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif assistant_role_label in role:
elif output_turn:
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
@@ -423,7 +439,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
elif role == "":
elif empty_role:
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
@@ -434,11 +450,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif tool_role_label and tool_role_label in role:
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(

View File

@@ -259,6 +259,12 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
CONVERSATION_ROLE_FORMAT = {
"chatml": "<|im_start|>{ROLE}",
"zephyr": "<|{ROLE}|>",
"vicuna_v1.1": "{ROLE}",
}
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
"""
@@ -268,7 +274,9 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
role_key_human = "human"
role_key_model = "gpt"
# Optional, only used for tool usage datasets.
role_key_tool = None
role_key_tool: Optional[str] = None
# Optional, role input/output mapping
roles: Optional[dict] = None
def __init__(
self,
@@ -277,6 +285,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
role_key_tool: Optional[str] = None,
roles: Optional[dict] = None,
):
if conversation:
if isinstance(conversation, Conversation):
@@ -291,6 +300,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
self.role_key_model = role_key_model
if role_key_tool:
self.role_key_tool = role_key_tool
if roles:
self.roles = roles
def _build_result(self, source):
if len(source) < 2:
@@ -322,11 +333,23 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
conv.messages = []
for _, sentence in enumerate(source):
role = roles[sentence["from"]]
if len(conv.messages) > 0 and (
(role == conv.messages[-1][0]) or (role not in conv.roles)
):
from_role = sentence["from"]
if from_role in roles:
role = roles[from_role]
else:
if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
raise NotImplementedError(
f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
"Please help us by creating an Issue to add support for this conversation type."
)
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
ROLE=from_role
)
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"])
return conv.get_turns()
@@ -354,11 +377,13 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
roles: Optional[dict] = None,
):
super().__init__(
conversation=conversation,
role_key_human=role_key_human,
role_key_model=role_key_model,
roles=roles,
)

View File

@@ -96,6 +96,8 @@ class SFTDataset(BaseModel):
field_human: Optional[str] = None
field_model: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO"""