improve tool handling roles (#1587)
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, Type
|
||||||
|
|
||||||
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||||
|
|
||||||
@@ -39,76 +39,40 @@ def register_chatml_template(system_message=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
def build_loader(
|
||||||
conversation = (
|
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
|
||||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
prompter_cls: Type["ShareGPTPrompterV2"],
|
||||||
)
|
default_conversation: Optional[str] = None,
|
||||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
):
|
||||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
|
conversation = (
|
||||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
ds_cfg["conversation"]
|
||||||
ShareGPTPrompterV2(
|
if ds_cfg and "conversation" in ds_cfg
|
||||||
conversation=conversation,
|
else default_conversation
|
||||||
role_key_model=field_model,
|
)
|
||||||
role_key_human=field_human,
|
field_human = (
|
||||||
roles=roles,
|
ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||||
),
|
)
|
||||||
tokenizer,
|
field_model = (
|
||||||
cfg.train_on_inputs,
|
ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||||
cfg.sequence_len,
|
)
|
||||||
)
|
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
|
||||||
if ds_cfg and "strict" in ds_cfg:
|
strategy = tokenization_strategy_cls(
|
||||||
strategy.strict = ds_cfg["strict"]
|
prompter_cls(
|
||||||
return strategy
|
conversation=conversation,
|
||||||
|
role_key_model=field_model,
|
||||||
|
role_key_human=field_human,
|
||||||
|
roles=roles,
|
||||||
|
),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
|
||||||
|
strategy.strict = ds_cfg["strict"]
|
||||||
|
return strategy
|
||||||
|
|
||||||
|
return _load
|
||||||
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
||||||
conversation = (
|
|
||||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
|
||||||
)
|
|
||||||
strategy = UltrachatShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompterV2(
|
|
||||||
conversation=conversation,
|
|
||||||
),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
if ds_cfg and "strict" in ds_cfg:
|
|
||||||
strategy.strict = ds_cfg["strict"]
|
|
||||||
return strategy
|
|
||||||
|
|
||||||
|
|
||||||
def load_role(tokenizer, cfg):
|
|
||||||
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompterV2(),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_guanaco(tokenizer, cfg):
|
|
||||||
return GuanacoShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompterV2(),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
||||||
conversation = (
|
|
||||||
ds_cfg["conversation"]
|
|
||||||
if ds_cfg and "conversation" in ds_cfg
|
|
||||||
else "chatml_glaive"
|
|
||||||
)
|
|
||||||
return GlaiveShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompterV2(conversation=conversation),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||||
@@ -158,7 +122,9 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|||||||
return turns
|
return turns
|
||||||
|
|
||||||
|
|
||||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
class SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||||
|
SimpleShareGPTPromptTokenizingStrategy
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
||||||
"""
|
"""
|
||||||
@@ -209,3 +175,16 @@ class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrat
|
|||||||
conversation = merge_consecutive_messages(conversation)
|
conversation = merge_consecutive_messages(conversation)
|
||||||
|
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
|
|
||||||
|
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||||
|
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||||
|
load_ultrachat = build_loader(
|
||||||
|
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
|
||||||
|
)
|
||||||
|
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||||
|
load_glaive = build_loader(
|
||||||
|
GlaiveShareGPTPromptTokenizingStrategy,
|
||||||
|
ShareGPTPrompterV2,
|
||||||
|
default_conversation="chatml_glaive",
|
||||||
|
)
|
||||||
|
|||||||
@@ -348,7 +348,10 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|||||||
)
|
)
|
||||||
|
|
||||||
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
|
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
|
||||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
if (
|
||||||
|
role != "assistant"
|
||||||
|
): # back to back assistant calls may be okay for tool calls
|
||||||
|
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||||
|
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user