Feat: add tool calling support via tools column (#2774)
* feat: add tool_calling field support * fix: add tests
This commit is contained in:
@@ -173,6 +173,10 @@ datasets:
|
||||
# Key containing the messages (default: "messages")
|
||||
field_messages: messages
|
||||
|
||||
# Key containing the tools (default: "tools")
|
||||
# Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||
field_tools: tools
|
||||
|
||||
# Key containing the system message (default: "system")
|
||||
# If the system message is not present in the dataset sample, it will be loaded from the field_system property.
|
||||
field_system: system
|
||||
|
||||
@@ -52,7 +52,9 @@ We recommend checking the below examples for other usecases.
|
||||
|
||||
### Examples
|
||||
|
||||
1. (Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
||||
#### Training on last message
|
||||
|
||||
(Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
@@ -66,7 +68,9 @@ datasets:
|
||||
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
|
||||
:::
|
||||
|
||||
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||
#### Overriding default chat template
|
||||
|
||||
Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
chat_template: gemma # this overwrites the tokenizer's chat_template
|
||||
@@ -76,7 +80,13 @@ datasets:
|
||||
roles_to_train: ["assistant"] # default value
|
||||
```
|
||||
|
||||
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
|
||||
::: {.callout-note}
|
||||
If you want to use built-in chat_template, use `chat_template: tokenizer_default` (this is set by default).
|
||||
:::
|
||||
|
||||
#### Using default chat template with fallback
|
||||
|
||||
Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
|
||||
@@ -85,7 +95,9 @@ datasets:
|
||||
type: chat_template
|
||||
```
|
||||
|
||||
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
|
||||
#### Custom Jinja template
|
||||
|
||||
Using a custom jinja template on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
|
||||
@@ -100,7 +112,9 @@ datasets:
|
||||
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
|
||||
:::
|
||||
|
||||
5. If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
|
||||
#### Using template with different token for EOT and EOS
|
||||
|
||||
- If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
|
||||
|
||||
```yaml
|
||||
eot_tokens:
|
||||
@@ -125,7 +139,7 @@ Using `eot_tokens` requires each token that exists in `chat_template` to be a si
|
||||
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
|
||||
:::
|
||||
|
||||
6. Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
|
||||
- Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
|
||||
|
||||
```yaml
|
||||
eot_tokens:
|
||||
@@ -145,7 +159,73 @@ If EOS token only appears at the end of a prompt, `train_on_eos: last` is equiva
|
||||
:::
|
||||
|
||||
|
||||
7. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||
#### Using tool use
|
||||
|
||||
Instead of passing `tools` via the system prompt, an alternative method would be to have the `tools` in a separate column and loaded via `chat_template` to let the template dynamically build it.
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"type": "...",
|
||||
"function": {
|
||||
"name": "...",
|
||||
"description": "...",
|
||||
"parameters": {
|
||||
"type": "...",
|
||||
"properties": {
|
||||
// ...
|
||||
},
|
||||
"required": ["..."],
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
"messages": [
|
||||
// ...
|
||||
{
|
||||
"role": "assistant", // call the function via assistant
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "...",
|
||||
"arguments": {
|
||||
"...": "...",
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"name": "...",
|
||||
"content": "..."
|
||||
},
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||
:::
|
||||
|
||||
```yaml
|
||||
chat_template: llama4
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
# field_tools: tools # default is `tools`
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
Look into the `chat_template` you are using to see if it supports `tools` and what the expected role is for the tool answer. In the example above, the tool answer is expected to be in the `tool` or `ipython` role for `llama4` template.
|
||||
:::
|
||||
|
||||
|
||||
#### Using fine-grained control over token masking
|
||||
|
||||
(Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||
|
||||
For a data sample that looks like:
|
||||
|
||||
@@ -196,7 +276,9 @@ datasets:
|
||||
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
||||
:::
|
||||
|
||||
8. (For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
|
||||
#### Reasoning split
|
||||
|
||||
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
|
||||
@@ -34,6 +34,7 @@ class ChatTemplatePrompter(Prompter):
|
||||
message_field_training_detail: str | None = None,
|
||||
field_messages: str = "messages",
|
||||
field_system: str = "system",
|
||||
field_tools: str = "tools",
|
||||
roles: dict[str, list[str]] | None = None,
|
||||
chat_template_kwargs: dict[str, Any] | None = None,
|
||||
drop_system_message: bool = False,
|
||||
@@ -66,6 +67,7 @@ class ChatTemplatePrompter(Prompter):
|
||||
self.message_field_training_detail = message_field_training_detail
|
||||
self.field_messages = field_messages
|
||||
self.field_system = field_system
|
||||
self.field_tools = field_tools
|
||||
self.tokenizer = tokenizer
|
||||
self.processor: ProcessorMixin | None = processor
|
||||
self.chat_template = chat_template
|
||||
@@ -77,17 +79,38 @@ class ChatTemplatePrompter(Prompter):
|
||||
def chat_template_msg_variables(self) -> Set[str]:
|
||||
return self._chat_template_msg_variables
|
||||
|
||||
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
|
||||
def build_prompt(
|
||||
self,
|
||||
conversation,
|
||||
add_generation_prompt=False,
|
||||
images=None,
|
||||
tools=None,
|
||||
):
|
||||
"""
|
||||
Build a prompt from a conversation.
|
||||
|
||||
Args:
|
||||
conversation: A list of messages.
|
||||
add_generation_prompt: Whether to add a generation prompt.
|
||||
images: A list of images. (optional)
|
||||
tools: A list of tools. (optional)
|
||||
"""
|
||||
chat_template_kwargs = {
|
||||
"chat_template": self.chat_template,
|
||||
"add_generation_prompt": add_generation_prompt,
|
||||
}
|
||||
|
||||
if tools:
|
||||
chat_template_kwargs["tools"] = tools
|
||||
|
||||
if self.processor:
|
||||
if not callable(self.processor):
|
||||
raise TypeError("Processor must be callable")
|
||||
|
||||
text = self.processor.apply_chat_template(
|
||||
conversation,
|
||||
chat_template=self.chat_template,
|
||||
tokenize=False,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
**self.chat_template_kwargs,
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
batch = self.processor(
|
||||
text=text,
|
||||
@@ -104,9 +127,7 @@ class ChatTemplatePrompter(Prompter):
|
||||
|
||||
return self.tokenizer.apply_chat_template(
|
||||
conversation,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
chat_template=self.chat_template,
|
||||
**self.chat_template_kwargs,
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
|
||||
def get_offsets_for_train_detail(
|
||||
@@ -376,7 +397,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
and not self.prompter.message_field_training_detail # type: ignore
|
||||
):
|
||||
turns = self.get_conversation_thread(prompt)
|
||||
images = self.get_images(prompt)
|
||||
images = self._get_images(prompt)
|
||||
prompt_ids = self.prompter.build_prompt( # type: ignore
|
||||
turns[:-1],
|
||||
add_generation_prompt=True,
|
||||
@@ -405,7 +426,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
return tokenized_prompt
|
||||
|
||||
turns = self.get_conversation_thread(prompt)
|
||||
input_ids = self.prompter.build_prompt(turns) # type: ignore
|
||||
tools = self._get_tools(prompt)
|
||||
input_ids = self.prompter.build_prompt(turns, tools=tools) # type: ignore
|
||||
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
||||
|
||||
last_eos_idx = -1
|
||||
@@ -444,7 +466,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
continue
|
||||
|
||||
turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)
|
||||
turn_start_idx, turn_end_idx = self.find_turn(
|
||||
turns=turns, turn_idx=index, tools=tools
|
||||
)
|
||||
|
||||
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
|
||||
|
||||
@@ -546,7 +570,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
return i
|
||||
return -1
|
||||
|
||||
def find_turn(self, turns: list[dict], turn_idx: int):
|
||||
def find_turn(
|
||||
self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None
|
||||
):
|
||||
"""
|
||||
Locate the starting and ending indices of the specified turn in a conversation.
|
||||
"""
|
||||
@@ -577,10 +603,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
turns_with_content = turns[: turn_idx + 1]
|
||||
|
||||
# Generate the conversation up to the turn, with final turn replaced with dummy content
|
||||
dummy_ids = self.prompter.build_prompt(turns_with_empty) # type: ignore
|
||||
dummy_ids = self.prompter.build_prompt(turns_with_empty, tools=tools) # type: ignore
|
||||
|
||||
# Generate the conversation up to the turn, with final turn included
|
||||
full_ids = self.prompter.build_prompt(turns_with_content) # type: ignore
|
||||
full_ids = self.prompter.build_prompt(turns_with_content, tools=tools) # type: ignore
|
||||
|
||||
if not full_ids or not dummy_ids:
|
||||
LOG.warning(f"Empty template generated for turn {turn_idx}")
|
||||
@@ -633,9 +659,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
def get_conversation_thread(self, prompt):
|
||||
turns = []
|
||||
|
||||
possible_sys_turn = self.transform_message(
|
||||
prompt[self.prompter.field_messages][0]
|
||||
)
|
||||
messages = self._get_messages(prompt)
|
||||
|
||||
possible_sys_turn = self.transform_message(messages[0])
|
||||
|
||||
if (
|
||||
possible_sys_turn["role"] != "system"
|
||||
and self.prompter.field_system in prompt
|
||||
@@ -643,7 +670,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
turn = {"role": "system", "content": prompt[self.prompter.field_system]}
|
||||
turns.append(turn)
|
||||
|
||||
for message in prompt[self.prompter.field_messages]:
|
||||
for message in messages:
|
||||
transformed_message = self.transform_message(message)
|
||||
|
||||
turn = {
|
||||
@@ -661,7 +688,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
return turns
|
||||
|
||||
def transform_message(self, message):
|
||||
def transform_message(self, message: dict) -> dict:
|
||||
# Build the initial transformed message from the mappings
|
||||
transformed_message = {}
|
||||
for key, value in self.prompter.message_property_mappings.items():
|
||||
@@ -738,9 +765,36 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
return transformed_message
|
||||
|
||||
def get_images(self, prompt):
|
||||
def _get_images(self, prompt):
|
||||
return prompt.get(self.images, None)
|
||||
|
||||
def _get_tools(self, prompt) -> list[dict] | None:
|
||||
"""Get tools from prompt if available."""
|
||||
tools = prompt.get(self.prompter.field_tools, None)
|
||||
if tools is None:
|
||||
return None
|
||||
|
||||
if isinstance(tools, list):
|
||||
return tools
|
||||
|
||||
raise ValueError(
|
||||
"Unknown tools format. Please convert it into a list[dict].\n"
|
||||
f"Current format: {type(tools)}"
|
||||
)
|
||||
|
||||
def _get_messages(self, prompt):
|
||||
messages = prompt.get(self.prompter.field_messages, None)
|
||||
if messages is None:
|
||||
raise ValueError("Messages is null. Please check `field_messages`.")
|
||||
|
||||
if isinstance(messages, list):
|
||||
return messages
|
||||
|
||||
raise ValueError(
|
||||
"Unknown messages format. Please convert it into a list[dict].\n"
|
||||
f"Current format: {type(messages)}"
|
||||
)
|
||||
|
||||
|
||||
class StrategyLoader:
|
||||
"""
|
||||
|
||||
@@ -43,6 +43,7 @@ class SFTDataset(BaseModel):
|
||||
field_human: str | None = None
|
||||
field_model: str | None = None
|
||||
field_messages: str | None = None
|
||||
field_tools: str | None = None
|
||||
# deprecated, use message_property_mappings
|
||||
message_field_role: str | None = None
|
||||
# deprecated, use message_property_mappings
|
||||
|
||||
@@ -1280,3 +1280,162 @@ class TestChatTemplateConfigurations:
|
||||
assert (
|
||||
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
|
||||
|
||||
|
||||
class TestChatTemplateToolCalling:
|
||||
"""
|
||||
Test class for tool calling functionality with chat templates.
|
||||
"""
|
||||
|
||||
def test_tool_calling_with_llama4_template(
|
||||
self,
|
||||
llama3_tokenizer,
|
||||
):
|
||||
LOG.info("Testing tool calling with llama3 tokenizer and llama4 chat template")
|
||||
|
||||
# Create tool calling dataset
|
||||
tool_calling_dataset = [
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "xml_escape",
|
||||
"description": 'Replaces any "<", ">", or "&" characters in the input string with their corresponding XML entities.',
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"s": {
|
||||
"type": "string",
|
||||
"description": "The input string to be XML-escaped.",
|
||||
}
|
||||
},
|
||||
"required": ["s"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multiples",
|
||||
"description": "Generates a list of all the multiples of a number that are less than a given limit.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"number": {
|
||||
"type": "integer",
|
||||
"description": "The number to find multiples of.",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "The upper limit for the multiples.",
|
||||
},
|
||||
},
|
||||
"required": ["number", "limit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Can you help me find multiples of 5 that are less than 20?",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multiples",
|
||||
"arguments": {
|
||||
"number": 5,
|
||||
"limit": 20,
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "name": "multiples", "content": "5,10,15"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The multiples of 5 less than 20 are: 5, 10, and 15.",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Setup tokenizer with llama4 chat template
|
||||
tokenizer = deepcopy(llama3_tokenizer)
|
||||
|
||||
# Add EOS token to the tokenizer
|
||||
eot_token = "<|eot_id|>"
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [eot_token]})
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
tokenizer,
|
||||
chat_template=get_chat_template("llama4"),
|
||||
message_property_mappings={"role": "role", "content": "content"},
|
||||
field_messages="messages",
|
||||
field_tools="tools",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
eot_tokens=[eot_token],
|
||||
)
|
||||
|
||||
res = strategy.tokenize_prompt(tool_calling_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
labels = res["labels"]
|
||||
|
||||
# Verify that the input_ids contain expected tokens
|
||||
assert len(input_ids) > 0, "Input IDs should not be empty"
|
||||
assert len(labels) == len(input_ids), "Labels should match input_ids length"
|
||||
|
||||
# Decode the full conversation to verify structure
|
||||
decoded_conversation = tokenizer.decode(input_ids)
|
||||
|
||||
# Verify tool calling structure is present in the decoded conversation
|
||||
assert (
|
||||
'"type": "function",' in decoded_conversation
|
||||
), "Tool type function should be in conversation"
|
||||
assert (
|
||||
'"name": "multiples",' in decoded_conversation
|
||||
), "Tool function name should be in conversation"
|
||||
|
||||
assert (
|
||||
'<|python_start|><|python_end|>{"name": "multiples", "parameters": {"number": 5, "limit": 20}}<|eot|>'
|
||||
in decoded_conversation
|
||||
), "Assistant tool call should be in conversation"
|
||||
assert (
|
||||
"<|header_start|>ipython<|header_end|>" in decoded_conversation
|
||||
), "IPython header should be in conversation"
|
||||
assert (
|
||||
'"5,10,15"' in decoded_conversation
|
||||
), "Tool response should be in conversation"
|
||||
|
||||
# Get conversation turns to verify labeling
|
||||
turns = strategy.get_conversation_thread(tool_calling_dataset[0])
|
||||
tools = strategy._get_tools( # pylint: disable=protected-access
|
||||
tool_calling_dataset[0]
|
||||
)
|
||||
|
||||
# Check that assistant responses are properly labeled
|
||||
for i, turn in enumerate(tool_calling_dataset[0]["messages"]):
|
||||
if turn["role"] == "assistant":
|
||||
start_idx, end_idx = strategy.find_turn(
|
||||
turns=turns, turn_idx=i, tools=tools
|
||||
)
|
||||
|
||||
assert (
|
||||
start_idx != -1 and end_idx != -1
|
||||
), f"Assistant turn {i} should be found"
|
||||
|
||||
# Verify that assistant responses have proper labels
|
||||
turn_labels = labels[start_idx:end_idx]
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in turn_labels
|
||||
), f"Assistant turn {i} should be unmasked"
|
||||
|
||||
Reference in New Issue
Block a user