Feat: add tool calling support via tools column (#2774)

* feat: add tool_calling field support

* fix: add tests
NanoCode012
2025-06-09 21:42:05 -07:00
committed by GitHub
parent 92afa4fa27
commit 83632f71d8
5 changed files with 327 additions and 27 deletions

View File

@@ -173,6 +173,10 @@ datasets:
    # Key containing the messages (default: "messages")
    field_messages: messages
    # Key containing the tools (default: "tools")
    # Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
    field_tools: tools
    # Key containing the system message (default: "system")
    # If the system message is not present among the messages, it is loaded from this key of the dataset sample.
    field_system: system
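For concreteness, a dataset row that these default keys map onto might look like the following. This is a minimal sketch; the `get_weather` tool and all field values are hypothetical.

```python
# A minimal sketch of one dataset row using the default keys
# ("messages", "tools", "system"). The get_weather tool is hypothetical.
sample_row = {
    "system": "You are a helpful assistant.",
    "tools": [  # must be a list[dict] following JSON schema
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    "messages": [
        {"role": "user", "content": "What's the weather in Paris?"},
    ],
}
```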

View File

@@ -52,7 +52,9 @@ We recommend checking the below examples for other usecases.
### Examples
#### Training on last message
(Legacy) Uses the default chat template from the tokenizer_config.json on the OpenAI messages format, training only on the last message.
```yaml
datasets:
@@ -66,7 +68,9 @@ datasets:
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
:::
#### Overriding default chat template
Uses the `gemma` chat template to override the tokenizer_config.json's chat template on the OpenAI messages format, training on all assistant messages.
```yaml
chat_template: gemma # overrides the tokenizer's chat_template
@@ -76,7 +80,13 @@ datasets:
roles_to_train: ["assistant"] # default value
```
::: {.callout-note}
If you want to use the tokenizer's built-in chat template, set `chat_template: tokenizer_default` (this is the default).
:::
#### Using default chat template with fallback
Uses the tokenizer_config.json's chat template, or `chatml` as a fallback if the former does not exist, on the OpenAI messages format, training on all assistant messages.
```yaml
chat_template: tokenizer_default_fallback_chatml # uses the tokenizer's chat_template, falling back to chatml
@@ -85,7 +95,9 @@ datasets:
type: chat_template
```
#### Custom Jinja template
Uses a custom Jinja template on the OpenAI messages format, training on all assistant messages.
```yaml
# chat_template: jinja # `jinja` is implied if `chat_template_jinja` is set and this field is empty
@@ -100,7 +112,9 @@ datasets:
Please make sure that your `tokenizer.eos_token` is the same as the EOS (End-of-Sequence) token in the template. Otherwise, set `eos_token` under `special_tokens:`.
:::
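A quick way to check this is to render a tiny conversation and look for the tokenizer's EOS token in the output. A sketch; the model id is purely illustrative:

```python
from transformers import AutoTokenizer

# Illustrative model id; substitute the model you are training.
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")

rendered = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ],
    tokenize=False,
)
# If the template closes turns with a token other than tokenizer.eos_token,
# set `eos_token` under `special_tokens:` in the config.
print("eos_token:", tokenizer.eos_token)
print("appears in rendered template:", tokenizer.eos_token in rendered)
```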
#### Using template with different token for EOT and EOS
- If you are using a template whose EOT (End-of-Turn) token differs from the EOS token, or which uses multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens:` config. The handling of EOT tokens follows `train_on_eos:`, which defaults to `turn`.
```yaml
eot_tokens:
@@ -125,7 +139,7 @@ Using `eot_tokens` requires each token that exists in `chat_template` to be a si
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
:::
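One way to verify the single-token requirement before training is to encode each candidate EOT string and count the resulting ids. A sketch assuming an already-loaded tokenizer; the `[/INST]` string is only an example:

```python
def is_single_token(tokenizer, token_str: str) -> bool:
    # Each entry in `eot_tokens:` must encode to exactly one id; otherwise
    # add it under `tokens:` or remap it via `added_tokens_overrides:`.
    ids = tokenizer.encode(token_str, add_special_tokens=False)
    return len(ids) == 1

# usage: is_single_token(tokenizer, "[/INST]")
```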
- Continuing from the previous example, if you want to train on the EOT tokens of all trainable turns but only on the last EOS token, set `train_on_eos: last`.
```yaml
eot_tokens:
@@ -145,7 +159,73 @@ If EOS token only appears at the end of a prompt, `train_on_eos: last` is equiva
:::
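The effect on labels can be illustrated in isolation. A simplified sketch, not axolotl's actual masking code, using `-100` as the ignore sentinel:

```python
IGNORE_TOKEN_ID = -100  # the usual "do not train" sentinel

def mask_eos(input_ids, labels, eos_id, train_on_eos="last"):
    """Keep labels on every EOS ("all") or only on the final one ("last")."""
    eos_positions = [i for i, tok in enumerate(input_ids) if tok == eos_id]
    keep = eos_positions if train_on_eos == "all" else eos_positions[-1:]
    return [
        lab if (tok != eos_id or i in keep) else IGNORE_TOKEN_ID
        for i, (tok, lab) in enumerate(zip(input_ids, labels))
    ]

print(mask_eos([5, 6, 2, 7, 8, 2], [5, 6, 2, 7, 8, 2], eos_id=2))
# -> [5, 6, -100, 7, 8, 2]: only the last EOS keeps its label
```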
#### Using tool calling
Instead of passing `tools` via the system prompt, an alternative is to keep the `tools` in a separate column and load them via the `chat_template`, letting the template build the tool prompt dynamically.
```json
{
  "tools": [
    {
      "type": "...",
      "function": {
        "name": "...",
        "description": "...",
        "parameters": {
          "type": "...",
          "properties": {
            // ...
          },
          "required": ["..."]
        }
      }
    }
  ],
  "messages": [
    // ...
    {
      "role": "assistant", // call the function via assistant
      "tool_calls": [
        {
          "type": "function",
          "function": {
            "name": "...",
            "arguments": {
              "...": "..."
            }
          }
        }
      ]
    },
    {
      "role": "tool",
      "name": "...",
      "content": "..."
    }
  ]
}
```
::: {.callout-note}
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
:::
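Since each tool's `parameters` block is itself a JSON schema, it can be checked up front. A sketch using the third-party `jsonschema` package; this is an extra dependency and a validation step axolotl does not perform for you:

```python
from jsonschema import Draft202012Validator

def check_tools(tools: list[dict]) -> None:
    """Raise SchemaError if any tool's `parameters` is not a valid JSON schema."""
    for tool in tools:
        Draft202012Validator.check_schema(tool["function"]["parameters"])
```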
```yaml
chat_template: llama4
datasets:
  - path: ...
    type: chat_template
    # field_tools: tools # default is `tools`
```
::: {.callout-tip}
Check the `chat_template` you are using to see whether it supports `tools` and which role it expects for the tool answer. In the example above, the `llama4` template expects the tool answer in the `tool` or `ipython` role.
:::
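Under the hood, the column ends up as the `tools` argument of Hugging Face's `apply_chat_template`, so you can preview what the template renders. A sketch reusing the `multiples` tool from the test below; the model id is illustrative (and gated):

```python
from transformers import AutoTokenizer

# Illustrative model id; any chat model whose template accepts `tools` works.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")

tools = [
    {
        "type": "function",
        "function": {
            "name": "multiples",
            "description": "Generates a list of all the multiples of a number that are less than a given limit.",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {"type": "integer"},
                    "limit": {"type": "integer"},
                },
                "required": ["number", "limit"],
            },
        },
    }
]
messages = [{"role": "user", "content": "Multiples of 5 under 20?"}]

# The template decides where and how the tool spec is injected into the prompt.
print(tokenizer.apply_chat_template(messages, tools=tools, tokenize=False))
```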
#### Using fine-grained control over token masking
(Advanced) Using fine-grained control over which tokens and turns to train on in a conversation.
For a data sample that looks like:
@@ -196,7 +276,9 @@ datasets:
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::
#### Reasoning split
(For the Qwen3 template only) Enables reasoning split, where the reasoning is separated from the content and passed into the template as its own field.
```yaml
datasets:

View File

@@ -34,6 +34,7 @@ class ChatTemplatePrompter(Prompter):
        message_field_training_detail: str | None = None,
        field_messages: str = "messages",
        field_system: str = "system",
        field_tools: str = "tools",
        roles: dict[str, list[str]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        drop_system_message: bool = False,
@@ -66,6 +67,7 @@ class ChatTemplatePrompter(Prompter):
        self.message_field_training_detail = message_field_training_detail
        self.field_messages = field_messages
        self.field_system = field_system
        self.field_tools = field_tools
        self.tokenizer = tokenizer
        self.processor: ProcessorMixin | None = processor
        self.chat_template = chat_template
@@ -77,17 +79,38 @@ class ChatTemplatePrompter(Prompter):
    def chat_template_msg_variables(self) -> Set[str]:
        return self._chat_template_msg_variables

    def build_prompt(
        self,
        conversation,
        add_generation_prompt=False,
        images=None,
        tools=None,
    ):
        """
        Build a prompt from a conversation.

        Args:
            conversation: A list of messages.
            add_generation_prompt: Whether to add a generation prompt.
            images: A list of images. (optional)
            tools: A list of tools. (optional)
        """
        chat_template_kwargs = {
            "chat_template": self.chat_template,
            "add_generation_prompt": add_generation_prompt,
        }
        if tools:
            chat_template_kwargs["tools"] = tools

        if self.processor:
            if not callable(self.processor):
                raise TypeError("Processor must be callable")
            text = self.processor.apply_chat_template(
                conversation,
                tokenize=False,
                **self.chat_template_kwargs,
                **chat_template_kwargs,
            )
            batch = self.processor(
                text=text,
@@ -104,9 +127,7 @@ class ChatTemplatePrompter(Prompter):
        return self.tokenizer.apply_chat_template(
            conversation,
            **self.chat_template_kwargs,
            **chat_template_kwargs,
        )

    def get_offsets_for_train_detail(
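Condensed, the new flow collects the template kwargs once, attaches `tools` only when the row provides them, and forwards everything to `apply_chat_template`. A standalone sketch of the same idea, not the class itself:

```python
def build_prompt_ids(
    tokenizer, conversation, chat_template=None, add_generation_prompt=False, tools=None
):
    # Mirrors ChatTemplatePrompter.build_prompt: `tools` is only passed
    # through when the dataset row actually provides a tools column.
    kwargs = {
        "chat_template": chat_template,
        "add_generation_prompt": add_generation_prompt,
    }
    if tools:
        kwargs["tools"] = tools
    return tokenizer.apply_chat_template(conversation, **kwargs)
```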
@@ -376,7 +397,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
            and not self.prompter.message_field_training_detail  # type: ignore
        ):
            turns = self.get_conversation_thread(prompt)
            images = self._get_images(prompt)
            prompt_ids = self.prompter.build_prompt(  # type: ignore
                turns[:-1],
                add_generation_prompt=True,
@@ -405,7 +426,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
            return tokenized_prompt

        turns = self.get_conversation_thread(prompt)
        tools = self._get_tools(prompt)
        input_ids = self.prompter.build_prompt(turns, tools=tools)  # type: ignore
        labels = [IGNORE_TOKEN_ID] * len(input_ids)

        last_eos_idx = -1
@@ -444,7 +466,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                continue

            turn_start_idx, turn_end_idx = self.find_turn(
                turns=turns, turn_idx=index, tools=tools
            )
            LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
@@ -546,7 +570,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                return i
        return -1

    def find_turn(
        self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None
    ):
        """
        Locate the starting and ending indices of the specified turn in a conversation.
        """
@@ -577,10 +603,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
        turns_with_content = turns[: turn_idx + 1]

        # Generate the conversation up to the turn, with final turn replaced with dummy content
        dummy_ids = self.prompter.build_prompt(turns_with_empty, tools=tools)  # type: ignore

        # Generate the conversation up to the turn, with final turn included
        full_ids = self.prompter.build_prompt(turns_with_content, tools=tools)  # type: ignore

        if not full_ids or not dummy_ids:
            LOG.warning(f"Empty template generated for turn {turn_idx}")
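This boundary detection renders the conversation twice, once with the target turn's content swapped for a dummy string and once with the real content, then diffs the two token sequences. Roughly, a simplified sketch of the idea rather than the exact implementation:

```python
def turn_span(dummy_ids: list[int], full_ids: list[int]) -> tuple[int, int]:
    """Return [start, end) of the region where the two renderings diverge."""
    start = next(
        (i for i, (a, b) in enumerate(zip(dummy_ids, full_ids)) if a != b),
        min(len(dummy_ids), len(full_ids)),
    )
    # Walk back past the shared suffix (e.g. the trailing EOT tokens).
    end_d, end_f = len(dummy_ids), len(full_ids)
    while end_d > start and end_f > start and dummy_ids[end_d - 1] == full_ids[end_f - 1]:
        end_d -= 1
        end_f -= 1
    return start, end_f

print(turn_span([1, 9, 9, 4], [1, 2, 3, 5, 4]))  # -> (1, 4)
```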
@@ -633,9 +659,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
    def get_conversation_thread(self, prompt):
        turns = []

        messages = self._get_messages(prompt)
        possible_sys_turn = self.transform_message(messages[0])

        if (
            possible_sys_turn["role"] != "system"
            and self.prompter.field_system in prompt
@@ -643,7 +670,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
            turn = {"role": "system", "content": prompt[self.prompter.field_system]}
            turns.append(turn)

        for message in messages:
            transformed_message = self.transform_message(message)
            turn = {
@@ -661,7 +688,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
        return turns

    def transform_message(self, message: dict) -> dict:
        # Build the initial transformed message from the mappings
        transformed_message = {}
        for key, value in self.prompter.message_property_mappings.items():
@@ -738,9 +765,36 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
        return transformed_message

    def _get_images(self, prompt):
        return prompt.get(self.images, None)

    def _get_tools(self, prompt) -> list[dict] | None:
        """Get tools from prompt if available."""
        tools = prompt.get(self.prompter.field_tools, None)
        if tools is None:
            return None
        if isinstance(tools, list):
            return tools
        raise ValueError(
            "Unknown tools format. Please convert it into a list[dict].\n"
            f"Current format: {type(tools)}"
        )

    def _get_messages(self, prompt):
        messages = prompt.get(self.prompter.field_messages, None)
        if messages is None:
            raise ValueError("Messages is null. Please check `field_messages`.")
        if isinstance(messages, list):
            return messages
        raise ValueError(
            "Unknown messages format. Please convert it into a list[dict].\n"
            f"Current format: {type(messages)}"
        )


class StrategyLoader:
    """

View File

@@ -43,6 +43,7 @@ class SFTDataset(BaseModel):
    field_human: str | None = None
    field_model: str | None = None
    field_messages: str | None = None
    field_tools: str | None = None
    # deprecated, use message_property_mappings
    message_field_role: str | None = None
    # deprecated, use message_property_mappings
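Schema-side, the change is just one more optional key on the dataset config. A minimal standalone sketch of the behavior, using a stand-in model rather than axolotl's actual `SFTDataset`:

```python
from pydantic import BaseModel

class DatasetCfg(BaseModel):
    # Hypothetical stand-in mirroring the relevant SFTDataset fields.
    field_messages: str | None = None
    field_tools: str | None = None

print(DatasetCfg(field_tools="tools").field_tools)  # "tools"
print(DatasetCfg().field_tools)                     # None: the key stays optional
```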

View File

@@ -1280,3 +1280,162 @@ class TestChatTemplateConfigurations:
        assert (
            labels[eos_idx] != IGNORE_TOKEN_ID
        ), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"


class TestChatTemplateToolCalling:
    """
    Test class for tool calling functionality with chat templates.
    """

    def test_tool_calling_with_llama4_template(
        self,
        llama3_tokenizer,
    ):
        LOG.info("Testing tool calling with llama3 tokenizer and llama4 chat template")

        # Create tool calling dataset
        tool_calling_dataset = [
            {
                "tools": [
                    {
                        "type": "function",
                        "function": {
                            "name": "xml_escape",
                            "description": 'Replaces any "<", ">", or "&" characters in the input string with their corresponding XML entities.',
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "s": {
                                        "type": "string",
                                        "description": "The input string to be XML-escaped.",
                                    }
                                },
                                "required": ["s"],
                            },
                        },
                    },
                    {
                        "type": "function",
                        "function": {
                            "name": "multiples",
                            "description": "Generates a list of all the multiples of a number that are less than a given limit.",
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "number": {
                                        "type": "integer",
                                        "description": "The number to find multiples of.",
                                    },
                                    "limit": {
                                        "type": "integer",
                                        "description": "The upper limit for the multiples.",
                                    },
                                },
                                "required": ["number", "limit"],
                            },
                        },
                    },
                ],
                "messages": [
                    {
                        "role": "user",
                        "content": "Can you help me find multiples of 5 that are less than 20?",
                    },
                    {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "type": "function",
                                "function": {
                                    "name": "multiples",
                                    "arguments": {
                                        "number": 5,
                                        "limit": 20,
                                    },
                                },
                            }
                        ],
                    },
                    {"role": "tool", "name": "multiples", "content": "5,10,15"},
                    {
                        "role": "assistant",
                        "content": "The multiples of 5 less than 20 are: 5, 10, and 15.",
                    },
                ],
            }
        ]
        # Setup tokenizer with llama4 chat template
        tokenizer = deepcopy(llama3_tokenizer)

        # Add the EOT token to the tokenizer
        eot_token = "<|eot_id|>"
        tokenizer.add_special_tokens({"additional_special_tokens": [eot_token]})
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                tokenizer,
                chat_template=get_chat_template("llama4"),
                message_property_mappings={"role": "role", "content": "content"},
                field_messages="messages",
                field_tools="tools",
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,
            sequence_len=512,
            roles_to_train=["assistant"],
            eot_tokens=[eot_token],
        )

        res = strategy.tokenize_prompt(tool_calling_dataset[0])
        input_ids = res["input_ids"]
        labels = res["labels"]

        # Verify that the input_ids contain expected tokens
        assert len(input_ids) > 0, "Input IDs should not be empty"
        assert len(labels) == len(input_ids), "Labels should match input_ids length"

        # Decode the full conversation to verify structure
        decoded_conversation = tokenizer.decode(input_ids)

        # Verify tool calling structure is present in the decoded conversation
        assert (
            '"type": "function",' in decoded_conversation
        ), "Tool type function should be in conversation"
        assert (
            '"name": "multiples",' in decoded_conversation
        ), "Tool function name should be in conversation"
        assert (
            '<|python_start|><|python_end|>{"name": "multiples", "parameters": {"number": 5, "limit": 20}}<|eot|>'
            in decoded_conversation
        ), "Assistant tool call should be in conversation"
        assert (
            "<|header_start|>ipython<|header_end|>" in decoded_conversation
        ), "IPython header should be in conversation"
        assert (
            '"5,10,15"' in decoded_conversation
        ), "Tool response should be in conversation"

        # Get conversation turns to verify labeling
        turns = strategy.get_conversation_thread(tool_calling_dataset[0])
        tools = strategy._get_tools(  # pylint: disable=protected-access
            tool_calling_dataset[0]
        )

        # Check that assistant responses are properly labeled
        for i, turn in enumerate(tool_calling_dataset[0]["messages"]):
            if turn["role"] == "assistant":
                start_idx, end_idx = strategy.find_turn(
                    turns=turns, turn_idx=i, tools=tools
                )
                assert (
                    start_idx != -1 and end_idx != -1
                ), f"Assistant turn {i} should be found"

                # Verify that assistant responses have proper labels
                turn_labels = labels[start_idx:end_idx]
                assert all(
                    label != IGNORE_TOKEN_ID for label in turn_labels
                ), f"Assistant turn {i} should be unmasked"