Feat: add tool calling support via tools column (#2774)

* feat: add tool_calling field support

* fix: add tests
This commit is contained in:
NanoCode012
2025-06-09 21:42:05 -07:00
committed by GitHub
parent 92afa4fa27
commit 83632f71d8
5 changed files with 327 additions and 27 deletions

View File

@@ -34,6 +34,7 @@ class ChatTemplatePrompter(Prompter):
message_field_training_detail: str | None = None,
field_messages: str = "messages",
field_system: str = "system",
field_tools: str = "tools",
roles: dict[str, list[str]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
drop_system_message: bool = False,
@@ -66,6 +67,7 @@ class ChatTemplatePrompter(Prompter):
self.message_field_training_detail = message_field_training_detail
self.field_messages = field_messages
self.field_system = field_system
self.field_tools = field_tools
self.tokenizer = tokenizer
self.processor: ProcessorMixin | None = processor
self.chat_template = chat_template
@@ -77,17 +79,38 @@ class ChatTemplatePrompter(Prompter):
def chat_template_msg_variables(self) -> Set[str]:
return self._chat_template_msg_variables
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
def build_prompt(
self,
conversation,
add_generation_prompt=False,
images=None,
tools=None,
):
"""
Build a prompt from a conversation.
Args:
conversation: A list of messages.
add_generation_prompt: Whether to add a generation prompt.
images: A list of images. (optional)
tools: A list of tools. (optional)
"""
chat_template_kwargs = {
"chat_template": self.chat_template,
"add_generation_prompt": add_generation_prompt,
}
if tools:
chat_template_kwargs["tools"] = tools
if self.processor:
if not callable(self.processor):
raise TypeError("Processor must be callable")
text = self.processor.apply_chat_template(
conversation,
chat_template=self.chat_template,
tokenize=False,
add_generation_prompt=add_generation_prompt,
**self.chat_template_kwargs,
**chat_template_kwargs,
)
batch = self.processor(
text=text,
@@ -104,9 +127,7 @@ class ChatTemplatePrompter(Prompter):
return self.tokenizer.apply_chat_template(
conversation,
add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template,
**self.chat_template_kwargs,
**chat_template_kwargs,
)
def get_offsets_for_train_detail(
@@ -376,7 +397,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
and not self.prompter.message_field_training_detail # type: ignore
):
turns = self.get_conversation_thread(prompt)
images = self.get_images(prompt)
images = self._get_images(prompt)
prompt_ids = self.prompter.build_prompt( # type: ignore
turns[:-1],
add_generation_prompt=True,
@@ -405,7 +426,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return tokenized_prompt
turns = self.get_conversation_thread(prompt)
input_ids = self.prompter.build_prompt(turns) # type: ignore
tools = self._get_tools(prompt)
input_ids = self.prompter.build_prompt(turns, tools=tools) # type: ignore
labels = [IGNORE_TOKEN_ID] * len(input_ids)
last_eos_idx = -1
@@ -444,7 +466,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
continue
turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)
turn_start_idx, turn_end_idx = self.find_turn(
turns=turns, turn_idx=index, tools=tools
)
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
@@ -546,7 +570,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return i
return -1
def find_turn(self, turns: list[dict], turn_idx: int):
def find_turn(
self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None
):
"""
Locate the starting and ending indices of the specified turn in a conversation.
"""
@@ -577,10 +603,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
turns_with_content = turns[: turn_idx + 1]
# Generate the conversation up to the turn, with final turn replaced with dummy content
dummy_ids = self.prompter.build_prompt(turns_with_empty) # type: ignore
dummy_ids = self.prompter.build_prompt(turns_with_empty, tools=tools) # type: ignore
# Generate the conversation up to the turn, with final turn included
full_ids = self.prompter.build_prompt(turns_with_content) # type: ignore
full_ids = self.prompter.build_prompt(turns_with_content, tools=tools) # type: ignore
if not full_ids or not dummy_ids:
LOG.warning(f"Empty template generated for turn {turn_idx}")
@@ -633,9 +659,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def get_conversation_thread(self, prompt):
turns = []
possible_sys_turn = self.transform_message(
prompt[self.prompter.field_messages][0]
)
messages = self._get_messages(prompt)
possible_sys_turn = self.transform_message(messages[0])
if (
possible_sys_turn["role"] != "system"
and self.prompter.field_system in prompt
@@ -643,7 +670,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
turn = {"role": "system", "content": prompt[self.prompter.field_system]}
turns.append(turn)
for message in prompt[self.prompter.field_messages]:
for message in messages:
transformed_message = self.transform_message(message)
turn = {
@@ -661,7 +688,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return turns
def transform_message(self, message):
def transform_message(self, message: dict) -> dict:
# Build the initial transformed message from the mappings
transformed_message = {}
for key, value in self.prompter.message_property_mappings.items():
@@ -738,9 +765,36 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return transformed_message
def get_images(self, prompt):
def _get_images(self, prompt):
return prompt.get(self.images, None)
def _get_tools(self, prompt) -> list[dict] | None:
"""Get tools from prompt if available."""
tools = prompt.get(self.prompter.field_tools, None)
if tools is None:
return None
if isinstance(tools, list):
return tools
raise ValueError(
"Unknown tools format. Please convert it into a list[dict].\n"
f"Current format: {type(tools)}"
)
def _get_messages(self, prompt):
messages = prompt.get(self.prompter.field_messages, None)
if messages is None:
raise ValueError("Messages is null. Please check `field_messages`.")
if isinstance(messages, list):
return messages
raise ValueError(
"Unknown messages format. Please convert it into a list[dict].\n"
f"Current format: {type(messages)}"
)
class StrategyLoader:
"""

View File

@@ -43,6 +43,7 @@ class SFTDataset(BaseModel):
field_human: str | None = None
field_model: str | None = None
field_messages: str | None = None
field_tools: str | None = None
# deprecated, use message_property_mappings
message_field_role: str | None = None
# deprecated, use message_property_mappings