From e8b962d47f89041fd6ca6e84c7ece38b8baa34a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=8D=8E=E6=9D=B0?= Date: Thu, 25 Sep 2025 13:06:21 +0800 Subject: [PATCH] feat: support training with JSON string tool arguments (#3136) * feat: support training with JSON string tool arguments; fix PyArrow data type inconsistent error * feat: raise error for tool call arguments decode * Add test_chat_templates_tool_call_string_arguments.py Add test for string arguments * fix: change to correct qwen3 tokenizer * fix: update docs to clarify arguments json * chore: lint * fix: duplicate * chore: revert * feat: add error to faq * fix: remove duplicate fixture --------- Co-authored-by: caoqinping Co-authored-by: gamersover-blog <1611885128@qq.com> Co-authored-by: NanoCode012 --- docs/dataset-formats/conversation.qmd | 8 + docs/faq.qmd | 4 + .../prompt_strategies/chat_template.py | 17 ++ tests/prompt_strategies/conftest.py | 9 + ...est_chat_template_ds_schema_unification.py | 10 - .../test_chat_templates_thinking.py | 10 - ...at_templates_tool_call_string_arguments.py | 214 ++++++++++++++++++ 7 files changed, 252 insertions(+), 20 deletions(-) create mode 100644 tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd index d53c68598..870a2b67d 100644 --- a/docs/dataset-formats/conversation.qmd +++ b/docs/dataset-formats/conversation.qmd @@ -212,6 +212,14 @@ Instead of passing `tools` via the system prompt, an alternative method would be Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step). ::: +::: {.callout-warning} +If you have tool arguments with same name but different dtypes (like `"time": string` and `"time": number`), please save `arguments: ` as JSON string to prevent `datasets` from having casting issues. + +``` +"arguments": "{\"...\": \"...\"}" +``` +::: + Example config for Llama4: ```yaml chat_template: llama4 diff --git a/docs/faq.qmd b/docs/faq.qmd index 08d439af7..ffc29d35d 100644 --- a/docs/faq.qmd +++ b/docs/faq.qmd @@ -140,3 +140,7 @@ description: Frequently asked questions **Q: `ValueError("Backward pass should have cleared tracker of all tensors")` > A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML. + +**Q: `Error parsing tool_calls arguments as JSON.` + +> A: There is an error parsing string arguments to a dict. Please check your dataset and the error message for more details. diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index cb3e3dfb1..f4dcbd7cd 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -2,6 +2,7 @@ HF Chat Templates prompt strategy """ +import json from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Set, Union @@ -794,6 +795,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): if val is not None: transformed_message[key] = val + if "tool_calls" in transformed_message and transformed_message["tool_calls"]: + for tool_call in transformed_message["tool_calls"]: + if "function" in tool_call and "arguments" in tool_call["function"]: + args = tool_call["function"]["arguments"] + if isinstance(args, str): + try: + tool_call["function"]["arguments"] = json.loads(args) + except json.JSONDecodeError as e: + LOG.error( + f"Error parsing tool_calls arguments as JSON. " + f"Function: {tool_call.get('function', {}).get('name', 'unknown')}, " + f"Arguments string: {args!r}, " + f"Error: {e}" + ) + raise + return transformed_message def _get_images(self, prompt): diff --git a/tests/prompt_strategies/conftest.py b/tests/prompt_strategies/conftest.py index 12c4bcd93..0af7b3e93 100644 --- a/tests/prompt_strategies/conftest.py +++ b/tests/prompt_strategies/conftest.py @@ -177,6 +177,15 @@ def fixture_devstral_1_1_tokenizer(): return tokenizer +@pytest.fixture(name="qwen3_tokenizer") +def qwen3_tokenizer_fixture( + download_qwen3_half_billion_model, +): # pylint: disable=unused-argument,redefined-outer-name + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") + + return tokenizer + + @pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja") def fixture_mistralv03_chat_template_jinja_w_system() -> str: return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n' diff --git a/tests/prompt_strategies/test_chat_template_ds_schema_unification.py b/tests/prompt_strategies/test_chat_template_ds_schema_unification.py index e8d35e974..4f4e32208 100644 --- a/tests/prompt_strategies/test_chat_template_ds_schema_unification.py +++ b/tests/prompt_strategies/test_chat_template_ds_schema_unification.py @@ -6,7 +6,6 @@ import json import pytest from datasets import Dataset -from transformers import AutoTokenizer from axolotl.prompt_strategies.chat_template import StrategyLoader from axolotl.utils.dict import DictDefault @@ -23,15 +22,6 @@ def fixture_messages_w_tools(): return Dataset.from_list(rows) -@pytest.fixture(name="qwen3_tokenizer") -def qwen3_tokenizer_fixture( - download_qwen3_half_billion_model, -): - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") - - return tokenizer - - @pytest.fixture(name="qwen3_prompt_strategy") def qwen3_chat_template_strategy(qwen3_tokenizer): cfg = DictDefault( diff --git a/tests/prompt_strategies/test_chat_templates_thinking.py b/tests/prompt_strategies/test_chat_templates_thinking.py index 5475666a5..054012e00 100644 --- a/tests/prompt_strategies/test_chat_templates_thinking.py +++ b/tests/prompt_strategies/test_chat_templates_thinking.py @@ -4,7 +4,6 @@ Tests for splitting reasoning/thinking from content into separate field import pytest from datasets import Dataset -from transformers import AutoTokenizer from axolotl.prompt_strategies.chat_template import ( load, @@ -56,15 +55,6 @@ def messages_w_reasoning_fixture(): ) -@pytest.fixture(name="qwen3_tokenizer") -def qwen3_tokenizer_fixture( - download_qwen3_half_billion_model, -): - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") - - return tokenizer - - class TestSplitThinking: """ test class to make sure datasets with reasoning content conforms to the chat_template strategy diff --git a/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py b/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py new file mode 100644 index 000000000..7de21b940 --- /dev/null +++ b/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py @@ -0,0 +1,214 @@ +""" +Tests for handling json tool content +""" + +import json + +import pytest +from datasets import Dataset + +from axolotl.prompt_strategies.chat_template import ( + load, +) +from axolotl.utils.dict import DictDefault + + +@pytest.fixture(name="qwen3_instruct_prompt_strategy") +def qwen3_instruct_chat_template_strategy(qwen3_tokenizer): + strategy = load( + qwen3_tokenizer, + DictDefault( + { + "train_on_inputs": False, + "sequence_len": 512, + } + ), + DictDefault( + { + "chat_template": "qwen3", + "message_field_role": "role", + "message_field_content": "content", + "message_property_mappings": { + "role": "role", + "content": "content", + }, + "roles": { + "user": ["user"], + "assistant": ["assistant"], + "system": ["system"], + }, + "field_messages": "messages", + } + ), + ) + return strategy + + +class TestQwen3IdenticalConversationArgs: + """ + Test Qwen3 tools is identical between JSON and dict + """ + + @pytest.fixture(name="conversation_dict_args_dataset") + def fixture_conversation_dict_args_dataset(self): + """ + Provides a dataset with conversation where arguments is a dict. + """ + user_content = "What is the weather in Boston?" + function_name = "get_current_weather" + arguments_dict = {"location": "Boston, MA", "unit": "celsius"} + + data = [ + { + "messages": [ + {"role": "user", "content": user_content}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": function_name, + "arguments": arguments_dict, # dict格式 + } + } + ], + }, + ], + } + ] + return Dataset.from_list(data) + + @pytest.fixture(name="conversation_str_args_dataset") + def fixture_conversation_str_args_dataset(self): + """ + Provides a dataset with conversation where arguments is a JSON string. + """ + user_content = "What is the weather in Boston?" + function_name = "get_current_weather" + arguments_dict = {"location": "Boston, MA", "unit": "celsius"} + arguments_str = json.dumps(arguments_dict) + + data = [ + { + "messages": [ + {"role": "user", "content": user_content}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": function_name, + "arguments": arguments_str, # str格式 + } + } + ], + }, + ], + } + ] + return Dataset.from_list(data) + + @pytest.fixture(name="conversation_mixed_time_types_dataset") + def fixture_conversation_mixed_time_types_dataset(self): + """ + Provides a dataset where 'time' field has different types in different tool calls. + """ + data = [ + { + "messages": [ + { + "role": "user", + "content": "Get weather information at different times", + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "func1", + "arguments": json.dumps( + {"time": "2025-08-01"} + ), # string type + } + }, + { + "function": { + "name": "func2", + "arguments": json.dumps( + {"time": 1690876800} + ), # number type + } + }, + ], + }, + ], + } + ] + return Dataset.from_list(data) + + def test_dict_and_str_args_produce_identical_output( + self, + conversation_dict_args_dataset, + conversation_str_args_dataset, + qwen3_instruct_prompt_strategy, + qwen3_tokenizer, + ): + """ + Tests that after tokenization and decoding, the outputs for both + dict and string `arguments` are exactly the same. + """ + processed_dict_args = conversation_dict_args_dataset.map( + qwen3_instruct_prompt_strategy.tokenize_prompt, + batched=True, + remove_columns=["messages"], + ) + + processed_str_args = conversation_str_args_dataset.map( + qwen3_instruct_prompt_strategy.tokenize_prompt, + batched=True, + remove_columns=["messages"], + ) + + decoded_prompt_from_dict = qwen3_tokenizer.decode( + processed_dict_args[0]["input_ids"] + ) + + decoded_prompt_from_str = qwen3_tokenizer.decode( + processed_str_args[0]["input_ids"] + ) + + assert decoded_prompt_from_dict == decoded_prompt_from_str, ( + f"Dict format output:\n{decoded_prompt_from_dict}\n" + f"String format output:\n{decoded_prompt_from_str}" + ) + + assert ( + processed_dict_args[0]["input_ids"] == processed_str_args[0]["input_ids"] + ), "The tokenized input_ids should be identical for dict and str arguments" + + def test_str_args_with_mixed_time_types_no_error( + self, + conversation_mixed_time_types_dataset, + qwen3_instruct_prompt_strategy, + qwen3_tokenizer, + ): + """ + Tests that when 'time' field has different types (string vs number) + in different tool calls, str format arguments don't cause errors. + """ + processed = conversation_mixed_time_types_dataset.map( + qwen3_instruct_prompt_strategy.tokenize_prompt, + batched=True, + remove_columns=["messages"], + ) + + assert len(processed) == 1 + assert "input_ids" in processed[0] + assert len(processed[0]["input_ids"]) > 0 + + decoded = qwen3_tokenizer.decode(processed[0]["input_ids"]) + assert "2025-08-01" in decoded, "String time value should be present" + assert "1690876800" in decoded, "Number time value should be present"