Feat: add tool calling support via tools column (#2774)
* feat: add tool_calling field support
* fix: add tests
@@ -173,6 +173,10 @@ datasets:
|
|||||||
# Key containing the messages (default: "messages")
|
# Key containing the messages (default: "messages")
|
||||||
field_messages: messages
|
field_messages: messages
|
||||||
|
|
||||||
|
# Key containing the tools (default: "tools")
|
||||||
|
# Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||||
|
field_tools: tools
|
||||||
|
|
||||||
# Key containing the system message (default: "system")
|
# Key containing the system message (default: "system")
|
||||||
# If the system message is not present in the dataset sample, it will be loaded from the field_system property.
|
# If the system message is not present in the dataset sample, it will be loaded from the field_system property.
|
||||||
field_system: system
|
field_system: system
|
||||||
|
|||||||
@@ -52,7 +52,9 @@ We recommend checking the below examples for other usecases.
|
|||||||
|
|
||||||
### Examples
|
### Examples
|
||||||
|
|
||||||
1. (Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
#### Training on last message
|
||||||
|
|
||||||
|
(Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
datasets:
|
datasets:
|
||||||
@@ -66,7 +68,9 @@ datasets:
|
|||||||
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
|
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
#### Overriding default chat template
|
||||||
|
|
||||||
|
Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
chat_template: gemma # this overwrites the tokenizer's chat_template
|
chat_template: gemma # this overwrites the tokenizer's chat_template
|
||||||
@@ -76,7 +80,13 @@ datasets:
|
|||||||
roles_to_train: ["assistant"] # default value
|
roles_to_train: ["assistant"] # default value
|
||||||
```
|
```
|
||||||
|
|
||||||
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
|
::: {.callout-note}
|
||||||
|
To use the tokenizer's built-in chat template, set `chat_template: tokenizer_default` (this is the default).
|
||||||
|
:::
|
||||||
|
|
||||||
|
#### Using default chat template with fallback
|
||||||
|
|
||||||
|
Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
|
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
|
||||||
@@ -85,7 +95,9 @@ datasets:
|
|||||||
type: chat_template
|
type: chat_template
|
||||||
```
|
```
|
||||||
|
|
||||||
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
|
#### Custom Jinja template
|
||||||
|
|
||||||
|
Using a custom jinja template on OpenAI messages format, training on all assistant messages.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
|
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
|
||||||
@@ -100,7 +112,9 @@ datasets:
|
|||||||
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
|
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
5. If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
|
#### Using template with different token for EOT and EOS
|
||||||
|
|
||||||
|
- If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
eot_tokens:
|
eot_tokens:
|
||||||
@@ -125,7 +139,7 @@ Using `eot_tokens` requires each token that exists in `chat_template` to be a si
|
|||||||
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
|
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
6. Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
|
- Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
eot_tokens:
|
eot_tokens:
|
||||||
@@ -145,7 +159,73 @@ If EOS token only appears at the end of a prompt, `train_on_eos: last` is equiva
|
|||||||
:::
|
:::
|
||||||
|
|
||||||
|
|
||||||
7. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
#### Using tool use
|
||||||
|
|
||||||
|
Instead of passing `tools` via the system prompt, you can store them in a separate column and let the `chat_template` render them dynamically.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "...",
|
||||||
|
"function": {
|
||||||
|
"name": "...",
|
||||||
|
"description": "...",
|
||||||
|
"parameters": {
|
||||||
|
"type": "...",
|
||||||
|
"properties": {
|
||||||
|
// ...
|
||||||
|
},
|
||||||
|
"required": ["..."],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"messages": [
|
||||||
|
// ...
|
||||||
|
{
|
||||||
|
"role": "assistant", // call the function via assistant
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "...",
|
||||||
|
"arguments": {
|
||||||
|
"...": "...",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"name": "...",
|
||||||
|
"content": "..."
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||||
|
:::
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
chat_template: llama4
|
||||||
|
datasets:
|
||||||
|
- path: ...
|
||||||
|
type: chat_template
|
||||||
|
# field_tools: tools # default is `tools`
|
||||||
|
```
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Check the `chat_template` you are using to see whether it supports `tools` and which role it expects for the tool answer. In the example above, the `llama4` template expects the tool answer in the `tool` or `ipython` role.
|
||||||
|
:::
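A minimal sketch of how one such row could be assembled and serialized as a JSONL line, reusing the `multiples` tool from this commit's test. The helper name `build_row` is an illustrative assumption and is not part of axolotl; only the default column names `tools` and `messages` come from the docs above.

```python
import json


def build_row(tools: list[dict], messages: list[dict]) -> dict:
    """Bundle tools and messages under the default column names (hypothetical helper)."""
    if not isinstance(tools, list) or not all(isinstance(t, dict) for t in tools):
        raise ValueError("tools must be a list[dict]")
    return {"tools": tools, "messages": messages}


# Tool definition in the JSON-schema style shown above.
multiples_tool = {
    "type": "function",
    "function": {
        "name": "multiples",
        "description": "Generates a list of all the multiples of a number that are less than a given limit.",
        "parameters": {
            "type": "object",
            "properties": {
                "number": {"type": "integer", "description": "The number to find multiples of."},
                "limit": {"type": "integer", "description": "The upper limit for the multiples."},
            },
            "required": ["number", "limit"],
        },
    },
}

row = build_row(
    tools=[multiples_tool],
    messages=[
        {"role": "user", "content": "Can you help me find multiples of 5 that are less than 20?"},
        {
            "role": "assistant",  # the assistant issues the call
            "tool_calls": [
                {
                    "type": "function",
                    "function": {"name": "multiples", "arguments": {"number": 5, "limit": 20}},
                }
            ],
        },
        {"role": "tool", "name": "multiples", "content": "5,10,15"},
        {"role": "assistant", "content": "The multiples of 5 less than 20 are: 5, 10, and 15."},
    ],
)

# One JSON object per line, as in a JSONL dataset file.
line = json.dumps(row)
```

Each row carries its own `tools` list, so different samples can expose different tools.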
|
||||||
|
|
||||||
|
|
||||||
|
#### Using fine-grained control over token masking
|
||||||
|
|
||||||
|
(Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||||
|
|
||||||
For a data sample that looks like:
|
For a data sample that looks like:
|
||||||
|
|
||||||
@@ -196,7 +276,9 @@ datasets:
|
|||||||
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
8. (For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
|
#### Reasoning split
|
||||||
|
|
||||||
|
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
datasets:
|
datasets:
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
message_field_training_detail: str | None = None,
|
message_field_training_detail: str | None = None,
|
||||||
field_messages: str = "messages",
|
field_messages: str = "messages",
|
||||||
field_system: str = "system",
|
field_system: str = "system",
|
||||||
|
field_tools: str = "tools",
|
||||||
roles: dict[str, list[str]] | None = None,
|
roles: dict[str, list[str]] | None = None,
|
||||||
chat_template_kwargs: dict[str, Any] | None = None,
|
chat_template_kwargs: dict[str, Any] | None = None,
|
||||||
drop_system_message: bool = False,
|
drop_system_message: bool = False,
|
||||||
@@ -66,6 +67,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
self.message_field_training_detail = message_field_training_detail
|
self.message_field_training_detail = message_field_training_detail
|
||||||
self.field_messages = field_messages
|
self.field_messages = field_messages
|
||||||
self.field_system = field_system
|
self.field_system = field_system
|
||||||
|
self.field_tools = field_tools
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.processor: ProcessorMixin | None = processor
|
self.processor: ProcessorMixin | None = processor
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
@@ -77,17 +79,38 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
def chat_template_msg_variables(self) -> Set[str]:
|
def chat_template_msg_variables(self) -> Set[str]:
|
||||||
return self._chat_template_msg_variables
|
return self._chat_template_msg_variables
|
||||||
|
|
||||||
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
|
def build_prompt(
|
||||||
|
self,
|
||||||
|
conversation,
|
||||||
|
add_generation_prompt=False,
|
||||||
|
images=None,
|
||||||
|
tools=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Build a prompt from a conversation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation: A list of messages.
|
||||||
|
add_generation_prompt: Whether to add a generation prompt.
|
||||||
|
images: A list of images. (optional)
|
||||||
|
tools: A list of tools. (optional)
|
||||||
|
"""
|
||||||
|
chat_template_kwargs = {
|
||||||
|
"chat_template": self.chat_template,
|
||||||
|
"add_generation_prompt": add_generation_prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
chat_template_kwargs["tools"] = tools
|
||||||
|
|
||||||
if self.processor:
|
if self.processor:
|
||||||
if not callable(self.processor):
|
if not callable(self.processor):
|
||||||
raise TypeError("Processor must be callable")
|
raise TypeError("Processor must be callable")
|
||||||
|
|
||||||
text = self.processor.apply_chat_template(
|
text = self.processor.apply_chat_template(
|
||||||
conversation,
|
conversation,
|
||||||
chat_template=self.chat_template,
|
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
add_generation_prompt=add_generation_prompt,
|
**chat_template_kwargs,
|
||||||
**self.chat_template_kwargs,
|
|
||||||
)
|
)
|
||||||
batch = self.processor(
|
batch = self.processor(
|
||||||
text=text,
|
text=text,
|
||||||
@@ -104,9 +127,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
|
|
||||||
return self.tokenizer.apply_chat_template(
|
return self.tokenizer.apply_chat_template(
|
||||||
conversation,
|
conversation,
|
||||||
add_generation_prompt=add_generation_prompt,
|
**chat_template_kwargs,
|
||||||
chat_template=self.chat_template,
|
|
||||||
**self.chat_template_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_offsets_for_train_detail(
|
def get_offsets_for_train_detail(
|
||||||
@@ -376,7 +397,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
and not self.prompter.message_field_training_detail # type: ignore
|
and not self.prompter.message_field_training_detail # type: ignore
|
||||||
):
|
):
|
||||||
turns = self.get_conversation_thread(prompt)
|
turns = self.get_conversation_thread(prompt)
|
||||||
images = self.get_images(prompt)
|
images = self._get_images(prompt)
|
||||||
prompt_ids = self.prompter.build_prompt( # type: ignore
|
prompt_ids = self.prompter.build_prompt( # type: ignore
|
||||||
turns[:-1],
|
turns[:-1],
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
@@ -405,7 +426,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
return tokenized_prompt
|
return tokenized_prompt
|
||||||
|
|
||||||
turns = self.get_conversation_thread(prompt)
|
turns = self.get_conversation_thread(prompt)
|
||||||
input_ids = self.prompter.build_prompt(turns) # type: ignore
|
tools = self._get_tools(prompt)
|
||||||
|
input_ids = self.prompter.build_prompt(turns, tools=tools) # type: ignore
|
||||||
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
||||||
|
|
||||||
last_eos_idx = -1
|
last_eos_idx = -1
|
||||||
@@ -444,7 +466,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)
|
turn_start_idx, turn_end_idx = self.find_turn(
|
||||||
|
turns=turns, turn_idx=index, tools=tools
|
||||||
|
)
|
||||||
|
|
||||||
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
|
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
|
||||||
|
|
||||||
@@ -546,7 +570,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
return i
|
return i
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
def find_turn(self, turns: list[dict], turn_idx: int):
|
def find_turn(
|
||||||
|
self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Locate the starting and ending indices of the specified turn in a conversation.
|
Locate the starting and ending indices of the specified turn in a conversation.
|
||||||
"""
|
"""
|
||||||
@@ -577,10 +603,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
turns_with_content = turns[: turn_idx + 1]
|
turns_with_content = turns[: turn_idx + 1]
|
||||||
|
|
||||||
# Generate the conversation up to the turn, with final turn replaced with dummy content
|
# Generate the conversation up to the turn, with final turn replaced with dummy content
|
||||||
dummy_ids = self.prompter.build_prompt(turns_with_empty) # type: ignore
|
dummy_ids = self.prompter.build_prompt(turns_with_empty, tools=tools) # type: ignore
|
||||||
|
|
||||||
# Generate the conversation up to the turn, with final turn included
|
# Generate the conversation up to the turn, with final turn included
|
||||||
full_ids = self.prompter.build_prompt(turns_with_content) # type: ignore
|
full_ids = self.prompter.build_prompt(turns_with_content, tools=tools) # type: ignore
|
||||||
|
|
||||||
if not full_ids or not dummy_ids:
|
if not full_ids or not dummy_ids:
|
||||||
LOG.warning(f"Empty template generated for turn {turn_idx}")
|
LOG.warning(f"Empty template generated for turn {turn_idx}")
|
||||||
@@ -633,9 +659,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
def get_conversation_thread(self, prompt):
|
def get_conversation_thread(self, prompt):
|
||||||
turns = []
|
turns = []
|
||||||
|
|
||||||
possible_sys_turn = self.transform_message(
|
messages = self._get_messages(prompt)
|
||||||
prompt[self.prompter.field_messages][0]
|
|
||||||
)
|
possible_sys_turn = self.transform_message(messages[0])
|
||||||
|
|
||||||
if (
|
if (
|
||||||
possible_sys_turn["role"] != "system"
|
possible_sys_turn["role"] != "system"
|
||||||
and self.prompter.field_system in prompt
|
and self.prompter.field_system in prompt
|
||||||
@@ -643,7 +670,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
turn = {"role": "system", "content": prompt[self.prompter.field_system]}
|
turn = {"role": "system", "content": prompt[self.prompter.field_system]}
|
||||||
turns.append(turn)
|
turns.append(turn)
|
||||||
|
|
||||||
for message in prompt[self.prompter.field_messages]:
|
for message in messages:
|
||||||
transformed_message = self.transform_message(message)
|
transformed_message = self.transform_message(message)
|
||||||
|
|
||||||
turn = {
|
turn = {
|
||||||
@@ -661,7 +688,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
return turns
|
return turns
|
||||||
|
|
||||||
def transform_message(self, message):
|
def transform_message(self, message: dict) -> dict:
|
||||||
# Build the initial transformed message from the mappings
|
# Build the initial transformed message from the mappings
|
||||||
transformed_message = {}
|
transformed_message = {}
|
||||||
for key, value in self.prompter.message_property_mappings.items():
|
for key, value in self.prompter.message_property_mappings.items():
|
||||||
@@ -738,9 +765,36 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
return transformed_message
|
return transformed_message
|
||||||
|
|
||||||
def get_images(self, prompt):
|
def _get_images(self, prompt):
|
||||||
return prompt.get(self.images, None)
|
return prompt.get(self.images, None)
|
||||||
|
|
||||||
|
def _get_tools(self, prompt) -> list[dict] | None:
|
||||||
|
"""Get tools from prompt if available."""
|
||||||
|
tools = prompt.get(self.prompter.field_tools, None)
|
||||||
|
if tools is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(tools, list):
|
||||||
|
return tools
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
"Unknown tools format. Please convert it into a list[dict].\n"
|
||||||
|
f"Current format: {type(tools)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_messages(self, prompt):
|
||||||
|
messages = prompt.get(self.prompter.field_messages, None)
|
||||||
|
if messages is None:
|
||||||
|
raise ValueError("Messages is null. Please check `field_messages`.")
|
||||||
|
|
||||||
|
if isinstance(messages, list):
|
||||||
|
return messages
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
"Unknown messages format. Please convert it into a list[dict].\n"
|
||||||
|
f"Current format: {type(messages)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StrategyLoader:
|
class StrategyLoader:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ class SFTDataset(BaseModel):
|
|||||||
field_human: str | None = None
|
field_human: str | None = None
|
||||||
field_model: str | None = None
|
field_model: str | None = None
|
||||||
field_messages: str | None = None
|
field_messages: str | None = None
|
||||||
|
field_tools: str | None = None
|
||||||
# deprecated, use message_property_mappings
|
# deprecated, use message_property_mappings
|
||||||
message_field_role: str | None = None
|
message_field_role: str | None = None
|
||||||
# deprecated, use message_property_mappings
|
# deprecated, use message_property_mappings
|
||||||
|
|||||||
@@ -1280,3 +1280,162 @@ class TestChatTemplateConfigurations:
|
|||||||
assert (
|
assert (
|
||||||
labels[eos_idx] != IGNORE_TOKEN_ID
|
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||||
), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
|
), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatTemplateToolCalling:
|
||||||
|
"""
|
||||||
|
Test class for tool calling functionality with chat templates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_tool_calling_with_llama4_template(
|
||||||
|
self,
|
||||||
|
llama3_tokenizer,
|
||||||
|
):
|
||||||
|
LOG.info("Testing tool calling with llama3 tokenizer and llama4 chat template")
|
||||||
|
|
||||||
|
# Create tool calling dataset
|
||||||
|
tool_calling_dataset = [
|
||||||
|
{
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "xml_escape",
|
||||||
|
"description": 'Replaces any "<", ">", or "&" characters in the input string with their corresponding XML entities.',
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"s": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The input string to be XML-escaped.",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["s"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "multiples",
|
||||||
|
"description": "Generates a list of all the multiples of a number that are less than a given limit.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"number": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "The number to find multiples of.",
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "The upper limit for the multiples.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["number", "limit"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Can you help me find multiples of 5 that are less than 20?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "multiples",
|
||||||
|
"arguments": {
|
||||||
|
"number": 5,
|
||||||
|
"limit": 20,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{"role": "tool", "name": "multiples", "content": "5,10,15"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "The multiples of 5 less than 20 are: 5, 10, and 15.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Setup tokenizer with llama4 chat template
|
||||||
|
tokenizer = deepcopy(llama3_tokenizer)
|
||||||
|
|
||||||
|
# Add EOT token to the tokenizer
|
||||||
|
eot_token = "<|eot_id|>"
|
||||||
|
tokenizer.add_special_tokens({"additional_special_tokens": [eot_token]})
|
||||||
|
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
tokenizer,
|
||||||
|
chat_template=get_chat_template("llama4"),
|
||||||
|
message_property_mappings={"role": "role", "content": "content"},
|
||||||
|
field_messages="messages",
|
||||||
|
field_tools="tools",
|
||||||
|
),
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
eot_tokens=[eot_token],
|
||||||
|
)
|
||||||
|
|
||||||
|
res = strategy.tokenize_prompt(tool_calling_dataset[0])
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
labels = res["labels"]
|
||||||
|
|
||||||
|
# Verify that the input_ids contain expected tokens
|
||||||
|
assert len(input_ids) > 0, "Input IDs should not be empty"
|
||||||
|
assert len(labels) == len(input_ids), "Labels should match input_ids length"
|
||||||
|
|
||||||
|
# Decode the full conversation to verify structure
|
||||||
|
decoded_conversation = tokenizer.decode(input_ids)
|
||||||
|
|
||||||
|
# Verify tool calling structure is present in the decoded conversation
|
||||||
|
assert (
|
||||||
|
'"type": "function",' in decoded_conversation
|
||||||
|
), "Tool type function should be in conversation"
|
||||||
|
assert (
|
||||||
|
'"name": "multiples",' in decoded_conversation
|
||||||
|
), "Tool function name should be in conversation"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
'<|python_start|><|python_end|>{"name": "multiples", "parameters": {"number": 5, "limit": 20}}<|eot|>'
|
||||||
|
in decoded_conversation
|
||||||
|
), "Assistant tool call should be in conversation"
|
||||||
|
assert (
|
||||||
|
"<|header_start|>ipython<|header_end|>" in decoded_conversation
|
||||||
|
), "IPython header should be in conversation"
|
||||||
|
assert (
|
||||||
|
'"5,10,15"' in decoded_conversation
|
||||||
|
), "Tool response should be in conversation"
|
||||||
|
|
||||||
|
# Get conversation turns to verify labeling
|
||||||
|
turns = strategy.get_conversation_thread(tool_calling_dataset[0])
|
||||||
|
tools = strategy._get_tools( # pylint: disable=protected-access
|
||||||
|
tool_calling_dataset[0]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that assistant responses are properly labeled
|
||||||
|
for i, turn in enumerate(tool_calling_dataset[0]["messages"]):
|
||||||
|
if turn["role"] == "assistant":
|
||||||
|
start_idx, end_idx = strategy.find_turn(
|
||||||
|
turns=turns, turn_idx=i, tools=tools
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
start_idx != -1 and end_idx != -1
|
||||||
|
), f"Assistant turn {i} should be found"
|
||||||
|
|
||||||
|
# Verify that assistant responses have proper labels
|
||||||
|
turn_labels = labels[start_idx:end_idx]
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID for label in turn_labels
|
||||||
|
), f"Assistant turn {i} should be unmasked"
|
||||||
|
|||||||
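Review note: the diff passes `tools` to every `build_prompt` call inside `find_turn`, not only to the full-conversation call. The toy sketch below suggests why, under stated assumptions: `render` is a hypothetical stand-in for a chat template, offsets are characters rather than token ids, and the prefix comparison is a simplification rather than axolotl's actual matching logic.

```python
def render(turns, tools=None):
    """Toy chat template: an optional tools header followed by each turn."""
    parts = []
    if tools:
        names = ", ".join(tool["function"]["name"] for tool in tools)
        parts.append(f"<tools>{names}</tools>")
    for turn in turns:
        parts.append(f"<{turn['role']}>{turn['content']}</{turn['role']}>")
    return "".join(parts)


def find_turn_span(turns, turn_idx, tools=None):
    """Locate the content of turns[turn_idx] by comparing a dummy render with the full one."""
    dummy_turn = {**turns[turn_idx], "content": ""}
    dummy = render(turns[:turn_idx] + [dummy_turn], tools=tools)
    full = render(turns[: turn_idx + 1], tools=tools)
    # The first differing position marks where the real content starts.
    start = next(
        (i for i, (a, b) in enumerate(zip(dummy, full)) if a != b),
        len(dummy),
    )
    end = start + len(full) - len(dummy)
    return start, end


turns = [
    {"role": "user", "content": "Multiples of 5 below 20?"},
    {"role": "assistant", "content": "5, 10, 15"},
]
tools = [{"type": "function", "function": {"name": "multiples"}}]

text = render(turns, tools=tools)
with_tools = find_turn_span(turns, 1, tools=tools)
without_tools = find_turn_span(turns, 1)  # offsets computed against a shorter prompt
```

In this toy, `text[with_tools[0]:with_tools[1]]` recovers the assistant content, while the span computed without `tools` is shifted by the length of the tools header and points at the wrong text.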