Feat: add tool calling support via tools column (#2774)
* feat: add tool_calling field support
* fix: add tests
This commit is contained in:
@@ -34,6 +34,7 @@ class ChatTemplatePrompter(Prompter):
|
||||
message_field_training_detail: str | None = None,
|
||||
field_messages: str = "messages",
|
||||
field_system: str = "system",
|
||||
field_tools: str = "tools",
|
||||
roles: dict[str, list[str]] | None = None,
|
||||
chat_template_kwargs: dict[str, Any] | None = None,
|
||||
drop_system_message: bool = False,
|
||||
@@ -66,6 +67,7 @@ class ChatTemplatePrompter(Prompter):
|
||||
self.message_field_training_detail = message_field_training_detail
|
||||
self.field_messages = field_messages
|
||||
self.field_system = field_system
|
||||
self.field_tools = field_tools
|
||||
self.tokenizer = tokenizer
|
||||
self.processor: ProcessorMixin | None = processor
|
||||
self.chat_template = chat_template
|
||||
@@ -77,17 +79,38 @@ class ChatTemplatePrompter(Prompter):
|
||||
def chat_template_msg_variables(self) -> Set[str]:
|
||||
return self._chat_template_msg_variables
|
||||
|
||||
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
|
||||
def build_prompt(
|
||||
self,
|
||||
conversation,
|
||||
add_generation_prompt=False,
|
||||
images=None,
|
||||
tools=None,
|
||||
):
|
||||
"""
|
||||
Build a prompt from a conversation.
|
||||
|
||||
Args:
|
||||
conversation: A list of messages.
|
||||
add_generation_prompt: Whether to add a generation prompt.
|
||||
images: A list of images. (optional)
|
||||
tools: A list of tools. (optional)
|
||||
"""
|
||||
chat_template_kwargs = {
|
||||
"chat_template": self.chat_template,
|
||||
"add_generation_prompt": add_generation_prompt,
|
||||
}
|
||||
|
||||
if tools:
|
||||
chat_template_kwargs["tools"] = tools
|
||||
|
||||
if self.processor:
|
||||
if not callable(self.processor):
|
||||
raise TypeError("Processor must be callable")
|
||||
|
||||
text = self.processor.apply_chat_template(
|
||||
conversation,
|
||||
chat_template=self.chat_template,
|
||||
tokenize=False,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
**self.chat_template_kwargs,
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
batch = self.processor(
|
||||
text=text,
|
||||
@@ -104,9 +127,7 @@ class ChatTemplatePrompter(Prompter):
|
||||
|
||||
return self.tokenizer.apply_chat_template(
|
||||
conversation,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
chat_template=self.chat_template,
|
||||
**self.chat_template_kwargs,
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
|
||||
def get_offsets_for_train_detail(
|
||||
@@ -376,7 +397,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
and not self.prompter.message_field_training_detail # type: ignore
|
||||
):
|
||||
turns = self.get_conversation_thread(prompt)
|
||||
images = self.get_images(prompt)
|
||||
images = self._get_images(prompt)
|
||||
prompt_ids = self.prompter.build_prompt( # type: ignore
|
||||
turns[:-1],
|
||||
add_generation_prompt=True,
|
||||
@@ -405,7 +426,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
return tokenized_prompt
|
||||
|
||||
turns = self.get_conversation_thread(prompt)
|
||||
input_ids = self.prompter.build_prompt(turns) # type: ignore
|
||||
tools = self._get_tools(prompt)
|
||||
input_ids = self.prompter.build_prompt(turns, tools=tools) # type: ignore
|
||||
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
||||
|
||||
last_eos_idx = -1
|
||||
@@ -444,7 +466,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
continue
|
||||
|
||||
turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)
|
||||
turn_start_idx, turn_end_idx = self.find_turn(
|
||||
turns=turns, turn_idx=index, tools=tools
|
||||
)
|
||||
|
||||
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
|
||||
|
||||
@@ -546,7 +570,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
return i
|
||||
return -1
|
||||
|
||||
def find_turn(self, turns: list[dict], turn_idx: int):
|
||||
def find_turn(
|
||||
self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None
|
||||
):
|
||||
"""
|
||||
Locate the starting and ending indices of the specified turn in a conversation.
|
||||
"""
|
||||
@@ -577,10 +603,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
turns_with_content = turns[: turn_idx + 1]
|
||||
|
||||
# Generate the conversation up to the turn, with final turn replaced with dummy content
|
||||
dummy_ids = self.prompter.build_prompt(turns_with_empty) # type: ignore
|
||||
dummy_ids = self.prompter.build_prompt(turns_with_empty, tools=tools) # type: ignore
|
||||
|
||||
# Generate the conversation up to the turn, with final turn included
|
||||
full_ids = self.prompter.build_prompt(turns_with_content) # type: ignore
|
||||
full_ids = self.prompter.build_prompt(turns_with_content, tools=tools) # type: ignore
|
||||
|
||||
if not full_ids or not dummy_ids:
|
||||
LOG.warning(f"Empty template generated for turn {turn_idx}")
|
||||
@@ -633,9 +659,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
def get_conversation_thread(self, prompt):
|
||||
turns = []
|
||||
|
||||
possible_sys_turn = self.transform_message(
|
||||
prompt[self.prompter.field_messages][0]
|
||||
)
|
||||
messages = self._get_messages(prompt)
|
||||
|
||||
possible_sys_turn = self.transform_message(messages[0])
|
||||
|
||||
if (
|
||||
possible_sys_turn["role"] != "system"
|
||||
and self.prompter.field_system in prompt
|
||||
@@ -643,7 +670,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
turn = {"role": "system", "content": prompt[self.prompter.field_system]}
|
||||
turns.append(turn)
|
||||
|
||||
for message in prompt[self.prompter.field_messages]:
|
||||
for message in messages:
|
||||
transformed_message = self.transform_message(message)
|
||||
|
||||
turn = {
|
||||
@@ -661,7 +688,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
return turns
|
||||
|
||||
def transform_message(self, message):
|
||||
def transform_message(self, message: dict) -> dict:
|
||||
# Build the initial transformed message from the mappings
|
||||
transformed_message = {}
|
||||
for key, value in self.prompter.message_property_mappings.items():
|
||||
@@ -738,9 +765,36 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
return transformed_message
|
||||
|
||||
def get_images(self, prompt):
|
||||
def _get_images(self, prompt):
    """
    Fetch the images entry from a dataset row.

    Args:
        prompt: Row to read from; the key used is ``self.images``.

    Returns:
        The value stored under that key, or None when the lookup finds nothing.
    """
    images_key = self.images
    return prompt.get(images_key, None)
|
||||
|
||||
def _get_tools(self, prompt) -> list[dict] | None:
|
||||
"""Get tools from prompt if available."""
|
||||
tools = prompt.get(self.prompter.field_tools, None)
|
||||
if tools is None:
|
||||
return None
|
||||
|
||||
if isinstance(tools, list):
|
||||
return tools
|
||||
|
||||
raise ValueError(
|
||||
"Unknown tools format. Please convert it into a list[dict].\n"
|
||||
f"Current format: {type(tools)}"
|
||||
)
|
||||
|
||||
def _get_messages(self, prompt):
    """
    Fetch the conversation messages from a dataset row.

    Args:
        prompt: Row to read from; the key used is ``self.prompter.field_messages``.

    Returns:
        The stored list of messages, unchanged.

    Raises:
        ValueError: If the lookup yields None, or a value that is not a list.
    """
    conversation = prompt.get(self.prompter.field_messages, None)

    # Happy path first: a list is returned untouched.
    if isinstance(conversation, list):
        return conversation

    if conversation is None:
        raise ValueError("Messages is null. Please check `field_messages`.")

    raise ValueError(
        "Unknown messages format. Please convert it into a list[dict].\n"
        f"Current format: {type(conversation)}"
    )
|
||||
|
||||
|
||||
class StrategyLoader:
|
||||
"""
|
||||
|
||||
@@ -43,6 +43,7 @@ class SFTDataset(BaseModel):
|
||||
field_human: str | None = None
|
||||
field_model: str | None = None
|
||||
field_messages: str | None = None
|
||||
field_tools: str | None = None
|
||||
# deprecated, use message_property_mappings
|
||||
message_field_role: str | None = None
|
||||
# deprecated, use message_property_mappings
|
||||
|
||||
Reference in New Issue
Block a user