diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd index 870a2b67d..34fde45fb 100644 --- a/docs/dataset-formats/conversation.qmd +++ b/docs/dataset-formats/conversation.qmd @@ -218,6 +218,13 @@ If you have tool arguments with same name but different dtypes (like `"time": st ``` "arguments": "{\"...\": \"...\"}" ``` + +The same is applicable for tool parameters. + +``` +"parameters": "{\"...\": \"...\"}" +``` + ::: Example config for Llama4: diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index f4dcbd7cd..28155810f 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -823,6 +823,23 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): return None if isinstance(tools, list): + # Process each tool to handle JSON string parameters + for tool in tools: + if isinstance(tool, dict) and "function" in tool: + function = tool["function"] + if "parameters" in function: + params = function["parameters"] + if isinstance(params, str): + try: + function["parameters"] = json.loads(params) + except json.JSONDecodeError as e: + LOG.error( + f"Error parsing tool parameters as JSON. " + f"Function: {function.get('name', 'unknown')}, " + f"Parameters string: {params!r}, " + f"Error: {e}" + ) + raise return tools raise ValueError( diff --git a/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py b/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py index 7de21b940..5866cc367 100644 --- a/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py +++ b/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py @@ -69,7 +69,7 @@ class TestQwen3IdenticalConversationArgs: { "function": { "name": function_name, - "arguments": arguments_dict, # dict格式 + "arguments": arguments_dict, # dict } } ], @@ -100,7 +100,7 @@ class TestQwen3IdenticalConversationArgs: { "function": { "name": function_name, - "arguments": arguments_str, # str格式 + "arguments": arguments_str, # str } } ], @@ -212,3 +212,294 @@ class TestQwen3IdenticalConversationArgs: decoded = qwen3_tokenizer.decode(processed[0]["input_ids"]) assert "2025-08-01" in decoded, "String time value should be present" assert "1690876800" in decoded, "Number time value should be present" + + +class TestQwen3IdenticalToolsParameters: + """ + Test Qwen3 tools parameters handling is identical between JSON string and dict + """ + + @pytest.fixture(name="tools_dict_params_dataset") + def fixture_tools_dict_params_dataset(self): + """ + Provides a dataset with tools where parameters is a dict. + """ + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + }, + } + ] + + data = [ + { + "tools": tools, + "messages": [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "type": "function", + "function": { + "name": "get_weather", + "arguments": {"location": "Boston, MA"}, + }, + } + ], + }, + { + "role": "tool", + "name": "get_weather", + "content": "72°F and sunny", + }, + ], + } + ] + return Dataset.from_list(data) + + @pytest.fixture(name="tools_str_params_dataset") + def fixture_tools_str_params_dataset(self): + """ + Provides a dataset with tools where parameters is a JSON string. + """ + parameters_dict = { + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + } + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information", + "parameters": json.dumps(parameters_dict), + }, + } + ] + + data = [ + { + "tools": tools, + "messages": [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "type": "function", + "function": { + "name": "get_weather", + "arguments": {"location": "Boston, MA"}, + }, + } + ], + }, + { + "role": "tool", + "name": "get_weather", + "content": "72°F and sunny", + }, + ], + } + ] + return Dataset.from_list(data) + + @pytest.fixture(name="tools_mixed_type_params_dataset") + def fixture_tools_mixed_type_params_dataset(self): + """ + Provides a dataset where different tools have the same parameter name with different types. + This tests that JSON string format prevents casting issues. + """ + tools = [ + { + "type": "function", + "function": { + "name": "tool_with_string_arg", + "description": "Tool expecting string argument", + "parameters": json.dumps( + { + "type": "object", + "properties": { + "arg1": { + "type": "string", + "description": "A string parameter", + } + }, + "required": ["arg1"], + } + ), + }, + }, + { + "type": "function", + "function": { + "name": "tool_with_number_arg", + "description": "Tool expecting number argument", + "parameters": json.dumps( + { + "type": "object", + "properties": { + "arg1": { + "type": "number", + "description": "A numeric parameter", + } + }, + "required": ["arg1"], + } + ), + }, + }, + ] + + data = [ + { + "tools": tools, + "messages": [ + {"role": "user", "content": "Use both tools"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "type": "function", + "function": { + "name": "tool_with_string_arg", + "arguments": json.dumps({"arg1": "hello"}), + }, + }, + { + "type": "function", + "function": { + "name": "tool_with_number_arg", + "arguments": json.dumps({"arg1": 42}), + }, + }, + ], + }, + ], + } + ] + return Dataset.from_list(data) + + def test_dict_and_str_params_produce_equivalent_output( + self, + tools_dict_params_dataset, + tools_str_params_dataset, + qwen3_instruct_prompt_strategy, + qwen3_tokenizer, + ): + """ + Tests that after tokenization and decoding, the outputs for both + dict and string `parameters` in tools are semantically equivalent. + """ + import re + + processed_dict_params = tools_dict_params_dataset.map( + qwen3_instruct_prompt_strategy.tokenize_prompt, + batched=True, + remove_columns=["messages", "tools"], + ) + + processed_str_params = tools_str_params_dataset.map( + qwen3_instruct_prompt_strategy.tokenize_prompt, + batched=True, + remove_columns=["messages", "tools"], + ) + + decoded_dict = qwen3_tokenizer.decode(processed_dict_params[0]["input_ids"]) + decoded_str = qwen3_tokenizer.decode(processed_str_params[0]["input_ids"]) + + # Extract the tool JSON from both outputs + tools_pattern = r"\n(.*?)\n" + + dict_tools_match = re.search(tools_pattern, decoded_dict, re.DOTALL) + str_tools_match = re.search(tools_pattern, decoded_str, re.DOTALL) + + assert dict_tools_match and str_tools_match, ( + "Could not find tools section in output" + ) + + # Parse the JSON and compare as objects (order-independent) + dict_tools_json = json.loads(dict_tools_match.group(1)) + str_tools_json = json.loads(str_tools_match.group(1)) + + # Deep comparison of the tool definitions + assert dict_tools_json == str_tools_json, ( + f"Tool definitions are not equivalent:\n" + f"Dict format: {json.dumps(dict_tools_json, indent=2)}\n" + f"String format: {json.dumps(str_tools_json, indent=2)}" + ) + + # Verify the rest of the structure is the same (excluding the tools JSON part) + # The tools JSON can have different order, so we remove it here. + dict_normalized = re.sub( + r".*?", + "TOOLS_PLACEHOLDER", + decoded_dict, + flags=re.DOTALL, + ) + str_normalized = re.sub( + r".*?", + "TOOLS_PLACEHOLDER", + decoded_str, + flags=re.DOTALL, + ) + + assert dict_normalized == str_normalized, ( + "The overall structure differs between dict and string parameter formats" + ) + + def test_str_params_with_mixed_types_no_error( + self, + tools_mixed_type_params_dataset, + qwen3_instruct_prompt_strategy, + qwen3_tokenizer, + ): + """ + Tests that when different tools have the same parameter name with different types, + JSON string format for parameters doesn't cause casting errors. + """ + processed = tools_mixed_type_params_dataset.map( + qwen3_instruct_prompt_strategy.tokenize_prompt, + batched=True, + remove_columns=["messages", "tools"], + ) + + assert len(processed) == 1 + assert "input_ids" in processed[0] + assert len(processed[0]["input_ids"]) > 0 + + decoded = qwen3_tokenizer.decode(processed[0]["input_ids"]) + + # Check that both tools are present + assert "tool_with_string_arg" in decoded + assert "tool_with_number_arg" in decoded + + # Check that both argument values are present + assert "hello" in decoded + assert "42" in decoded