improve tool handling roles (#1587)

This commit is contained in:
Wing Lian
2024-05-07 11:30:40 -04:00
committed by GitHub
parent 8b9c15b17f
commit cb78a36374
2 changed files with 54 additions and 72 deletions

View File

@@ -1,7 +1,7 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class""" """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, Type
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
@@ -39,76 +39,40 @@ def register_chatml_template(system_message=None):
) )
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): def build_loader(
conversation = ( tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None prompter_cls: Type["ShareGPTPrompterV2"],
) default_conversation: Optional[str] = None,
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None ):
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None conversation = (
strategy = SimpleShareGPTPromptTokenizingStrategy( ds_cfg["conversation"]
ShareGPTPrompterV2( if ds_cfg and "conversation" in ds_cfg
conversation=conversation, else default_conversation
role_key_model=field_model, )
role_key_human=field_human, field_human = (
roles=roles, ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
), )
tokenizer, field_model = (
cfg.train_on_inputs, ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
cfg.sequence_len, )
) roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
if ds_cfg and "strict" in ds_cfg: strategy = tokenization_strategy_cls(
strategy.strict = ds_cfg["strict"] prompter_cls(
return strategy conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
strategy.strict = ds_cfg["strict"]
return strategy
return _load
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
strategy = UltrachatShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy
def load_role(tokenizer, cfg):
return SimpleRoleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_guanaco(tokenizer, cfg):
return GuanacoShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else "chatml_glaive"
)
return GlaiveShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(conversation=conversation),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
@@ -158,7 +122,9 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
return turns return turns
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): class SimpleRoleShareGPTPromptTokenizingStrategy(
SimpleShareGPTPromptTokenizingStrategy
):
""" """
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
""" """
@@ -209,3 +175,16 @@ class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrat
conversation = merge_consecutive_messages(conversation) conversation = merge_consecutive_messages(conversation)
return conversation return conversation
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_ultrachat = build_loader(
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
)
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_glaive = build_loader(
GlaiveShareGPTPromptTokenizingStrategy,
ShareGPTPrompterV2,
default_conversation="chatml_glaive",
)

View File

@@ -348,7 +348,10 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
) )
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])): if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}") if (
role != "assistant"
): # back to back assistant calls may be okay for tool calls
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])