From 83632f71d82422819d4bd0020523a52624b1cb6b Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Mon, 9 Jun 2025 21:42:05 -0700
Subject: [PATCH] Feat: add tool calling support via tools column (#2774)

* feat: add tool_calling field support

* fix: add tests
---
 docs/config.qmd                        |   4 +
 docs/dataset-formats/conversation.qmd  |  98 ++++++++++-
 .../prompt_strategies/chat_template.py |  92 +++++++---
 src/axolotl/utils/schemas/datasets.py  |   1 +
 .../test_chat_templates_advanced.py    | 159 ++++++++++++++++++
 5 files changed, 327 insertions(+), 27 deletions(-)

diff --git a/docs/config.qmd b/docs/config.qmd
index 519065554..2ca236708 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -173,6 +173,10 @@ datasets:
   # Key containing the messages (default: "messages")
   field_messages: messages
 
+  # Key containing the tools (default: "tools")
+  # Must be a list[dict] whose entries follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
+  field_tools: tools
+
   # Key containing the system message (default: "system")
   # If the system message is not present in the dataset sample, it will be loaded from the field_system property.
   field_system: system
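The config note above requires each entry in the tools column to be a `dict` whose `parameters` member is a valid JSON schema. A quick way to sanity-check a dataset before training is sketched below; this is an editorial sketch, not part of the patch, and it assumes the third-party `jsonschema` package. The `xml_escape` tool mirrors the one used in the tests further down.

```python
"""Sanity-check a `tools` column entry before training.

Editorial sketch, not part of this patch. Assumes the third-party
`jsonschema` package is installed.
"""
from jsonschema import Draft202012Validator

tools = [
    {
        "type": "function",
        "function": {
            "name": "xml_escape",
            "description": "Replaces <, >, or & with their XML entities.",
            # `parameters` must itself be a valid JSON schema
            "parameters": {
                "type": "object",
                "properties": {
                    "s": {"type": "string", "description": "String to escape."}
                },
                "required": ["s"],
            },
        },
    }
]

assert isinstance(tools, list)  # mirrors the loader's list[dict] requirement
for tool in tools:
    # Raises jsonschema.exceptions.SchemaError on an invalid schema
    Draft202012Validator.check_schema(tool["function"]["parameters"])
```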
diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd
index 87c2941e6..290841c08 100644
--- a/docs/dataset-formats/conversation.qmd
+++ b/docs/dataset-formats/conversation.qmd
@@ -52,7 +52,9 @@ We recommend checking the below examples for other usecases.
 
 ### Examples
 
-1. (Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
+#### Training on last message
+
+(Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only the last message.
 
 ```yaml
 datasets:
@@ -66,7 +68,9 @@ datasets:
 If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
 :::
 
-2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
+#### Overriding default chat template
+
+Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
 
 ```yaml
 chat_template: gemma # this overwrites the tokenizer's chat_template
@@ -76,7 +80,13 @@ datasets:
     roles_to_train: ["assistant"] # default value
 ```
 
-3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
+::: {.callout-note}
+If you want to use the tokenizer's built-in chat_template, use `chat_template: tokenizer_default` (this is the default).
+:::
+
+#### Using default chat template with fallback
+
+Using the tokenizer_config.json's chat template, or `chatml` as a fallback if the former does not exist, on OpenAI messages format, training on all assistant messages.
 
 ```yaml
 chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
 datasets:
@@ -85,7 +95,9 @@ datasets:
     type: chat_template
 ```
 
-4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
+#### Custom Jinja template
+
+Using a custom Jinja template on OpenAI messages format, training on all assistant messages.
 
 ```yaml
 # chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
@@ -100,7 +112,9 @@ datasets:
 Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
 :::
 
-5. If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
+#### Using a template with different EOT and EOS tokens
+
+- If you are using a template whose EOT (End-of-Turn) token differs from the EOS token, or which has multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: `, which defaults to `turn`.
 
 ```yaml
 eot_tokens:
@@ -125,7 +139,7 @@ Using `eot_tokens` requires each token that exists in `chat_template` to be a si
 You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
 :::
 
-6. Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
+- Continuing from the previous example, if you want to train on all trainable EOT-token turns but only the last EOS token, set `train_on_eos: last`.
 
 ```yaml
 eot_tokens:
@@ -145,7 +159,73 @@ If EOS token only appears at the end of a prompt, `train_on_eos: last` is equiva
 
 :::
 
-7. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
+#### Using tools
+
+Instead of passing `tools` via the system prompt, an alternative is to keep the `tools` in a separate column that is passed to the `chat_template`, letting the template build the tool prompt dynamically.
+
+```json
+{
+  "tools": [
+    {
+      "type": "...",
+      "function": {
+        "name": "...",
+        "description": "...",
+        "parameters": {
+          "type": "...",
+          "properties": {
+            // ...
+          },
+          "required": ["..."]
+        }
+      }
+    }
+  ],
+  "messages": [
+    // ...
+    {
+      "role": "assistant", // call the function via assistant
+      "tool_calls": [
+        {
+          "type": "function",
+          "function": {
+            "name": "...",
+            "arguments": {
+              "...": "..."
+            }
+          }
+        }
+      ]
+    },
+    {
+      "role": "tool",
+      "name": "...",
+      "content": "..."
+    }
+  ]
+}
+```
+
+::: {.callout-note}
+Tools must follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
+:::
+
+```yaml
+chat_template: llama4
+datasets:
+  - path: ...
+    type: chat_template
+    # field_tools: tools # default is `tools`
+```
+
+::: {.callout-tip}
+Check the `chat_template` you are using to see whether it supports `tools` and which role it expects for the tool answer. In the example above, the `llama4` template expects the tool answer in the `tool` or `ipython` role.
+:::
+
+
+#### Using fine-grained control over token masking
+
+(Advanced) Using fine-grained control over tokens and turns to train in a conversation
 
 For a data sample that looks like:
 
@@ -196,7 +276,9 @@ datasets:
 It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
 :::
 
-8. (For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
+#### Reasoning split
+
+(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
 
 ```yaml
 datasets:
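Before training on a tools column, it can help to preview exactly what the chat template renders from it, since the docs above leave template support to the reader. The sketch below is editorial, not part of the patch; the checkpoint name is an assumption, and any tokenizer whose chat template accepts `tools` behaves the same way:

```python
"""Preview what a tools-aware chat template renders from a tools column.

Editorial sketch, not part of this patch. The checkpoint name is an
assumption; substitute any tokenizer whose chat template accepts `tools`.
"""
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# Same list[dict] shape as the `tools` column documented above
tools = [
    {
        "type": "function",
        "function": {
            "name": "multiples",
            "description": "List multiples of a number below a limit.",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {"type": "integer"},
                    "limit": {"type": "integer"},
                },
                "required": ["number", "limit"],
            },
        },
    }
]
messages = [{"role": "user", "content": "Find multiples of 5 below 20."}]

# `tools` is forwarded to the template, which builds the tool prompt itself;
# this mirrors what the patched build_prompt() below does during tokenization.
print(
    tokenizer.apply_chat_template(
        messages, tools=tools, tokenize=False, add_generation_prompt=True
    )
)
```

If the rendered text does not mention the tools at all, the template likely ignores the `tools` argument, and a tools-aware `chat_template` (such as `llama4` above) should be set explicitly.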
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index a0fd8d911..1fee0f7f6 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -34,6 +34,7 @@ class ChatTemplatePrompter(Prompter):
         message_field_training_detail: str | None = None,
         field_messages: str = "messages",
         field_system: str = "system",
+        field_tools: str = "tools",
         roles: dict[str, list[str]] | None = None,
         chat_template_kwargs: dict[str, Any] | None = None,
         drop_system_message: bool = False,
@@ -66,6 +67,7 @@ class ChatTemplatePrompter(Prompter):
         self.message_field_training_detail = message_field_training_detail
         self.field_messages = field_messages
         self.field_system = field_system
+        self.field_tools = field_tools
         self.tokenizer = tokenizer
         self.processor: ProcessorMixin | None = processor
         self.chat_template = chat_template
@@ -77,17 +79,39 @@ class ChatTemplatePrompter(Prompter):
     def chat_template_msg_variables(self) -> Set[str]:
         return self._chat_template_msg_variables
 
-    def build_prompt(self, conversation, add_generation_prompt=False, images=None):
+    def build_prompt(
+        self,
+        conversation,
+        add_generation_prompt=False,
+        images=None,
+        tools=None,
+    ):
+        """
+        Build a prompt from a conversation.
+
+        Args:
+            conversation: A list of messages.
+            add_generation_prompt: Whether to add a generation prompt.
+            images: A list of images. (optional)
+            tools: A list of tools. (optional)
+        """
+        chat_template_kwargs = {
+            "chat_template": self.chat_template,
+            "add_generation_prompt": add_generation_prompt,
+            **self.chat_template_kwargs,
+        }
+
+        if tools:
+            chat_template_kwargs["tools"] = tools
+
         if self.processor:
             if not callable(self.processor):
                 raise TypeError("Processor must be callable")
             text = self.processor.apply_chat_template(
                 conversation,
-                chat_template=self.chat_template,
                 tokenize=False,
-                add_generation_prompt=add_generation_prompt,
-                **self.chat_template_kwargs,
+                **chat_template_kwargs,
             )
             batch = self.processor(
                 text=text,
@@ -104,9 +127,7 @@ class ChatTemplatePrompter(Prompter):
 
         return self.tokenizer.apply_chat_template(
             conversation,
-            add_generation_prompt=add_generation_prompt,
-            chat_template=self.chat_template,
-            **self.chat_template_kwargs,
+            **chat_template_kwargs,
         )
 
     def get_offsets_for_train_detail(
@@ -376,7 +397,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
             and not self.prompter.message_field_training_detail  # type: ignore
         ):
             turns = self.get_conversation_thread(prompt)
-            images = self.get_images(prompt)
+            images = self._get_images(prompt)
             prompt_ids = self.prompter.build_prompt(  # type: ignore
                 turns[:-1],
                 add_generation_prompt=True,
@@ -405,7 +426,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
             return tokenized_prompt
 
         turns = self.get_conversation_thread(prompt)
-        input_ids = self.prompter.build_prompt(turns)  # type: ignore
+        tools = self._get_tools(prompt)
+        input_ids = self.prompter.build_prompt(turns, tools=tools)  # type: ignore
         labels = [IGNORE_TOKEN_ID] * len(input_ids)
 
         last_eos_idx = -1
@@ -444,7 +466,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
 
                 continue
 
-            turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)
+            turn_start_idx, turn_end_idx = self.find_turn(
+                turns=turns, turn_idx=index, tools=tools
+            )
 
             LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
end={turn_end_idx}") @@ -546,7 +570,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): return i return -1 - def find_turn(self, turns: list[dict], turn_idx: int): + def find_turn( + self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None + ): """ Locate the starting and ending indices of the specified turn in a conversation. """ @@ -577,10 +603,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): turns_with_content = turns[: turn_idx + 1] # Generate the conversation up to the turn, with final turn replaced with dummy content - dummy_ids = self.prompter.build_prompt(turns_with_empty) # type: ignore + dummy_ids = self.prompter.build_prompt(turns_with_empty, tools=tools) # type: ignore # Generate the conversation up to the turn, with final turn included - full_ids = self.prompter.build_prompt(turns_with_content) # type: ignore + full_ids = self.prompter.build_prompt(turns_with_content, tools=tools) # type: ignore if not full_ids or not dummy_ids: LOG.warning(f"Empty template generated for turn {turn_idx}") @@ -633,9 +659,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): def get_conversation_thread(self, prompt): turns = [] - possible_sys_turn = self.transform_message( - prompt[self.prompter.field_messages][0] - ) + messages = self._get_messages(prompt) + + possible_sys_turn = self.transform_message(messages[0]) + if ( possible_sys_turn["role"] != "system" and self.prompter.field_system in prompt @@ -643,7 +670,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): turn = {"role": "system", "content": prompt[self.prompter.field_system]} turns.append(turn) - for message in prompt[self.prompter.field_messages]: + for message in messages: transformed_message = self.transform_message(message) turn = { @@ -661,7 +688,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): return turns - def transform_message(self, message): + def transform_message(self, message: dict) -> dict: # Build the initial transformed message from the mappings transformed_message = {} for key, value in self.prompter.message_property_mappings.items(): @@ -738,9 +765,36 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): return transformed_message - def get_images(self, prompt): + def _get_images(self, prompt): return prompt.get(self.images, None) + def _get_tools(self, prompt) -> list[dict] | None: + """Get tools from prompt if available.""" + tools = prompt.get(self.prompter.field_tools, None) + if tools is None: + return None + + if isinstance(tools, list): + return tools + + raise ValueError( + "Unknown tools format. Please convert it into a list[dict].\n" + f"Current format: {type(tools)}" + ) + + def _get_messages(self, prompt): + messages = prompt.get(self.prompter.field_messages, None) + if messages is None: + raise ValueError("Messages is null. Please check `field_messages`.") + + if isinstance(messages, list): + return messages + + raise ValueError( + "Unknown messages format. 
diff --git a/src/axolotl/utils/schemas/datasets.py b/src/axolotl/utils/schemas/datasets.py
index cc5d6daba..c71f9be77 100644
--- a/src/axolotl/utils/schemas/datasets.py
+++ b/src/axolotl/utils/schemas/datasets.py
@@ -43,6 +43,7 @@ class SFTDataset(BaseModel):
     field_human: str | None = None
     field_model: str | None = None
     field_messages: str | None = None
+    field_tools: str | None = None
     # deprecated, use message_property_mappings
     message_field_role: str | None = None
     # deprecated, use message_property_mappings
diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py
index 7f011f954..fcf860f81 100644
--- a/tests/prompt_strategies/test_chat_templates_advanced.py
+++ b/tests/prompt_strategies/test_chat_templates_advanced.py
@@ -1280,3 +1280,162 @@ class TestChatTemplateConfigurations:
         assert (
             labels[eos_idx] != IGNORE_TOKEN_ID
         ), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
+
+
+class TestChatTemplateToolCalling:
+    """
+    Test class for tool calling functionality with chat templates.
+    """
+
+    def test_tool_calling_with_llama4_template(
+        self,
+        llama3_tokenizer,
+    ):
+        LOG.info("Testing tool calling with llama3 tokenizer and llama4 chat template")
+
+        # Create tool calling dataset
+        tool_calling_dataset = [
+            {
+                "tools": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "xml_escape",
+                            "description": 'Replaces any "<", ">", or "&" characters in the input string with their corresponding XML entities.',
+                            "parameters": {
+                                "type": "object",
+                                "properties": {
+                                    "s": {
+                                        "type": "string",
+                                        "description": "The input string to be XML-escaped.",
+                                    }
+                                },
+                                "required": ["s"],
+                            },
+                        },
+                    },
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "multiples",
+                            "description": "Generates a list of all the multiples of a number that are less than a given limit.",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {
+                                    "number": {
+                                        "type": "integer",
+                                        "description": "The number to find multiples of.",
+                                    },
+                                    "limit": {
+                                        "type": "integer",
+                                        "description": "The upper limit for the multiples.",
+                                    },
+                                },
+                                "required": ["number", "limit"],
+                            },
+                        },
+                    },
+                ],
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "Can you help me find multiples of 5 that are less than 20?",
+                    },
+                    {
+                        "role": "assistant",
+                        "tool_calls": [
+                            {
+                                "type": "function",
+                                "function": {
+                                    "name": "multiples",
+                                    "arguments": {
+                                        "number": 5,
+                                        "limit": 20,
+                                    },
+                                },
+                            }
+                        ],
+                    },
+                    {"role": "tool", "name": "multiples", "content": "5,10,15"},
+                    {
+                        "role": "assistant",
+                        "content": "The multiples of 5 less than 20 are: 5, 10, and 15.",
+                    },
+                ],
+            }
+        ]
+
+        # Setup tokenizer with llama4 chat template
+        tokenizer = deepcopy(llama3_tokenizer)
+
+        # Add the EOT token to the tokenizer
+        eot_token = "<|eot_id|>"
+        tokenizer.add_special_tokens({"additional_special_tokens": [eot_token]})
+
+        strategy = ChatTemplateStrategy(
+            ChatTemplatePrompter(
+                tokenizer,
+                chat_template=get_chat_template("llama4"),
+                message_property_mappings={"role": "role", "content": "content"},
+                field_messages="messages",
+                field_tools="tools",
+            ),
+            tokenizer=tokenizer,
+            train_on_inputs=False,
+            sequence_len=512,
+            roles_to_train=["assistant"],
+            eot_tokens=[eot_token],
+        )
+
+        res = strategy.tokenize_prompt(tool_calling_dataset[0])
+        input_ids = res["input_ids"]
+        labels = res["labels"]
+
+        # Verify that the input_ids contain expected tokens
+ assert len(input_ids) > 0, "Input IDs should not be empty" + assert len(labels) == len(input_ids), "Labels should match input_ids length" + + # Decode the full conversation to verify structure + decoded_conversation = tokenizer.decode(input_ids) + + # Verify tool calling structure is present in the decoded conversation + assert ( + '"type": "function",' in decoded_conversation + ), "Tool type function should be in conversation" + assert ( + '"name": "multiples",' in decoded_conversation + ), "Tool function name should be in conversation" + + assert ( + '<|python_start|><|python_end|>{"name": "multiples", "parameters": {"number": 5, "limit": 20}}<|eot|>' + in decoded_conversation + ), "Assistant tool call should be in conversation" + assert ( + "<|header_start|>ipython<|header_end|>" in decoded_conversation + ), "IPython header should be in conversation" + assert ( + '"5,10,15"' in decoded_conversation + ), "Tool response should be in conversation" + + # Get conversation turns to verify labeling + turns = strategy.get_conversation_thread(tool_calling_dataset[0]) + tools = strategy._get_tools( # pylint: disable=protected-access + tool_calling_dataset[0] + ) + + # Check that assistant responses are properly labeled + for i, turn in enumerate(tool_calling_dataset[0]["messages"]): + if turn["role"] == "assistant": + start_idx, end_idx = strategy.find_turn( + turns=turns, turn_idx=i, tools=tools + ) + + assert ( + start_idx != -1 and end_idx != -1 + ), f"Assistant turn {i} should be found" + + # Verify that assistant responses have proper labels + turn_labels = labels[start_idx:end_idx] + assert all( + label != IGNORE_TOKEN_ID for label in turn_labels + ), f"Assistant turn {i} should be unmasked"
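To try the feature end to end, a dataset only needs a `messages` column plus the optional `tools` column. A minimal editorial sketch (the file name is hypothetical; the row shape mirrors the test case above) that writes a one-row JSONL dataset:

```python
"""Write a one-row JSONL dataset with `tools` and `messages` columns.

Editorial sketch, not part of this patch; the file name is hypothetical
and the row mirrors the test case above.
"""
import json

row = {
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "multiples",
                "description": "List multiples of a number below a limit.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "number": {"type": "integer"},
                        "limit": {"type": "integer"},
                    },
                    "required": ["number", "limit"],
                },
            },
        }
    ],
    "messages": [
        {"role": "user", "content": "Find multiples of 5 below 20."},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "type": "function",
                    "function": {
                        "name": "multiples",
                        "arguments": {"number": 5, "limit": 20},
                    },
                }
            ],
        },
        {"role": "tool", "name": "multiples", "content": "5,10,15"},
        {"role": "assistant", "content": "The multiples are 5, 10, and 15."},
    ],
}

with open("tool_calls.jsonl", "w", encoding="utf-8") as fout:
    fout.write(json.dumps(row) + "\n")
```

Pointing a `datasets:` entry of `type: chat_template` at this file with a tools-aware `chat_template` should then produce the masking the test above asserts; to the best of my knowledge, axolotl's preprocess CLI with its `--debug` flag is the quickest way to eyeball the unmasked assistant turns.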