From cb78a36374252333f1b1534c79d5fee946aec73e Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 7 May 2024 11:30:40 -0400
Subject: [PATCH] improve tool handling roles (#1587)

---
 src/axolotl/prompt_strategies/sharegpt.py | 121 +++++++++-------
 src/axolotl/prompters.py                  |   5 +-
 2 files changed, 54 insertions(+), 72 deletions(-)

diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py
index 55bdd37b4..b556b3583 100644
--- a/src/axolotl/prompt_strategies/sharegpt.py
+++ b/src/axolotl/prompt_strategies/sharegpt.py
@@ -1,7 +1,7 @@
 """Module containing the SimpleShareGPTPromptTokenizingStrategy class"""

 import logging
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Type

 from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template

@@ -39,76 +39,40 @@ def register_chatml_template(system_message=None):
     )


-def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
-    conversation = (
-        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
-    )
-    field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
-    field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
-    roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
-    strategy = SimpleShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2(
-            conversation=conversation,
-            role_key_model=field_model,
-            role_key_human=field_human,
-            roles=roles,
-        ),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
-    if ds_cfg and "strict" in ds_cfg:
-        strategy.strict = ds_cfg["strict"]
-    return strategy
+def build_loader(
+    tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
+    prompter_cls: Type["ShareGPTPrompterV2"],
+    default_conversation: Optional[str] = None,
+):
+    def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+        conversation = (
+            ds_cfg["conversation"]
+            if ds_cfg and "conversation" in ds_cfg
+            else default_conversation
+        )
+        field_human = (
+            ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
+        )
+        field_model = (
+            ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
+        )
+        roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
+        strategy = tokenization_strategy_cls(
+            prompter_cls(
+                conversation=conversation,
+                role_key_model=field_model,
+                role_key_human=field_human,
+                roles=roles,
+            ),
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
+            strategy.strict = ds_cfg["strict"]
+        return strategy

-
-def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
-    conversation = (
-        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
-    )
-    strategy = UltrachatShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2(
-            conversation=conversation,
-        ),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
-    if ds_cfg and "strict" in ds_cfg:
-        strategy.strict = ds_cfg["strict"]
-    return strategy
-
-
-def load_role(tokenizer, cfg):
-    return SimpleRoleShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2(),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
-
-
-def load_guanaco(tokenizer, cfg):
-    return GuanacoShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2(),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
-
-
-def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
-    conversation = (
-        ds_cfg["conversation"]
-        if ds_cfg and "conversation" in ds_cfg
-        else "chatml_glaive"
-    )
-    return GlaiveShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2(conversation=conversation),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
+    return _load


 class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
@@ -158,7 +122,9 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
         return turns


-class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+class SimpleRoleShareGPTPromptTokenizingStrategy(
+    SimpleShareGPTPromptTokenizingStrategy
+):
     """
     basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
     """
@@ -209,3 +175,16 @@ class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrat
         conversation = merge_consecutive_messages(conversation)

         return conversation
+
+
+load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
+load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
+load_ultrachat = build_loader(
+    UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
+)
+load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
+load_glaive = build_loader(
+    GlaiveShareGPTPromptTokenizingStrategy,
+    ShareGPTPrompterV2,
+    default_conversation="chatml_glaive",
+)
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 2b6b4f857..7a089c0ec 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -348,7 +348,10 @@ class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
             )

         if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
-            LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
+            if (
+                role != "assistant"
+            ):  # back to back assistant calls may be okay for tool calls
+                LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")

         conv.append_message(role, sentence["value"])