_get_tools in ChatTemplateStrategy: function "parameters" can be dict or string (#3238)

* When training on function calls, the "tools" elements of a dataset can contain the same parameter name but with different types, and datasets fail to load such a training set. This fix allows the "parameters" element of a function call to be a string (produced by running "json.dumps" when preparing the training dataset). The _get_tools function iterates over the tool definitions: if the "parameters" element is a dict, it is kept as-is; if it is a string, it is converted to a dict by invoking "json.loads" on the string value.

* feat: add doc on tool parameters json loading

* feat: add tests for parameters json string

---------

Co-authored-by: ezlotnik <eduard_zlotnik@intuit.com>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Eduard Zl
2025-11-11 04:04:28 +02:00
committed by GitHub
parent 11eb36585a
commit b54f9c942b
3 changed files with 317 additions and 2 deletions

View File

@@ -218,6 +218,13 @@ If you have tool arguments with same name but different dtypes (like `"time": st
```
"arguments": "{\"...\": \"...\"}"
```
The same is applicable for tool parameters.
```
"parameters": "{\"...\": \"...\"}"
```
:::
Example config for Llama4:

View File

@@ -823,6 +823,23 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return None
if isinstance(tools, list):
# Process each tool to handle JSON string parameters
for tool in tools:
if isinstance(tool, dict) and "function" in tool:
function = tool["function"]
if "parameters" in function:
params = function["parameters"]
if isinstance(params, str):
try:
function["parameters"] = json.loads(params)
except json.JSONDecodeError as e:
LOG.error(
f"Error parsing tool parameters as JSON. "
f"Function: {function.get('name', 'unknown')}, "
f"Parameters string: {params!r}, "
f"Error: {e}"
)
raise
return tools
raise ValueError(

View File

@@ -69,7 +69,7 @@ class TestQwen3IdenticalConversationArgs:
{
"function": {
"name": function_name,
"arguments": arguments_dict, # dict格式
"arguments": arguments_dict, # dict
}
}
],
@@ -100,7 +100,7 @@ class TestQwen3IdenticalConversationArgs:
{
"function": {
"name": function_name,
"arguments": arguments_str, # str格式
"arguments": arguments_str, # str
}
}
],
@@ -212,3 +212,294 @@ class TestQwen3IdenticalConversationArgs:
decoded = qwen3_tokenizer.decode(processed[0]["input_ids"])
assert "2025-08-01" in decoded, "String time value should be present"
assert "1690876800" in decoded, "Number time value should be present"
class TestQwen3IdenticalToolsParameters:
    """
    Verify Qwen3 renders tool definitions identically whether a tool's
    ``parameters`` field is supplied as a dict or as a JSON string.
    """

    @pytest.fixture(name="tools_dict_params_dataset")
    def fixture_tools_dict_params_dataset(self):
        """
        Dataset with a single tool whose ``parameters`` field is a dict.
        """
        weather_params = {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state",
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                },
            },
            "required": ["location"],
        }
        weather_tool = {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get weather information",
                "parameters": weather_params,
            },
        }
        conversation = [
            {"role": "user", "content": "What's the weather?"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {"location": "Boston, MA"},
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "name": "get_weather",
                "content": "72°F and sunny",
            },
        ]
        return Dataset.from_list(
            [{"tools": [weather_tool], "messages": conversation}]
        )

    @pytest.fixture(name="tools_str_params_dataset")
    def fixture_tools_str_params_dataset(self):
        """
        Dataset with a single tool whose ``parameters`` field is a JSON string.
        """
        # Same schema as the dict fixture, serialized before insertion.
        parameters_dict = {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "The city and state"},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        }
        weather_tool = {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get weather information",
                "parameters": json.dumps(parameters_dict),
            },
        }
        conversation = [
            {"role": "user", "content": "What's the weather?"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {"location": "Boston, MA"},
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "name": "get_weather",
                "content": "72°F and sunny",
            },
        ]
        return Dataset.from_list(
            [{"tools": [weather_tool], "messages": conversation}]
        )

    @pytest.fixture(name="tools_mixed_type_params_dataset")
    def fixture_tools_mixed_type_params_dataset(self):
        """
        Dataset whose two tools reuse the same parameter name (``arg1``) with
        different declared types. Exercises the JSON-string format that avoids
        type-casting issues during dataset loading.
        """

        def _make_tool(name, description, arg_schema):
            # Each tool serializes its schema to a JSON string.
            return {
                "type": "function",
                "function": {
                    "name": name,
                    "description": description,
                    "parameters": json.dumps(
                        {
                            "type": "object",
                            "properties": {"arg1": arg_schema},
                            "required": ["arg1"],
                        }
                    ),
                },
            }

        tool_defs = [
            _make_tool(
                "tool_with_string_arg",
                "Tool expecting string argument",
                {"type": "string", "description": "A string parameter"},
            ),
            _make_tool(
                "tool_with_number_arg",
                "Tool expecting number argument",
                {"type": "number", "description": "A numeric parameter"},
            ),
        ]
        conversation = [
            {"role": "user", "content": "Use both tools"},
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "tool_with_string_arg",
                            "arguments": json.dumps({"arg1": "hello"}),
                        },
                    },
                    {
                        "type": "function",
                        "function": {
                            "name": "tool_with_number_arg",
                            "arguments": json.dumps({"arg1": 42}),
                        },
                    },
                ],
            },
        ]
        return Dataset.from_list(
            [{"tools": tool_defs, "messages": conversation}]
        )

    def test_dict_and_str_params_produce_equivalent_output(
        self,
        tools_dict_params_dataset,
        tools_str_params_dataset,
        qwen3_instruct_prompt_strategy,
        qwen3_tokenizer,
    ):
        """
        Tokenize and decode both fixture datasets, then check the outputs are
        semantically equivalent: identical tool JSON (order-independent) and
        identical surrounding template text.
        """
        import re

        def _tokenize_and_decode(dataset):
            # Map through the prompt strategy and decode the first example.
            processed = dataset.map(
                qwen3_instruct_prompt_strategy.tokenize_prompt,
                batched=True,
                remove_columns=["messages", "tools"],
            )
            return qwen3_tokenizer.decode(processed[0]["input_ids"])

        decoded_dict = _tokenize_and_decode(tools_dict_params_dataset)
        decoded_str = _tokenize_and_decode(tools_str_params_dataset)

        # Pull the tool JSON out of each rendered prompt.
        tools_pattern = r"<tools>\n(.*?)\n</tools>"
        dict_tools_match = re.search(tools_pattern, decoded_dict, re.DOTALL)
        str_tools_match = re.search(tools_pattern, decoded_str, re.DOTALL)
        assert dict_tools_match and str_tools_match, (
            "Could not find tools section in output"
        )

        # Compare as parsed objects so key order does not matter.
        dict_tools_json = json.loads(dict_tools_match.group(1))
        str_tools_json = json.loads(str_tools_match.group(1))
        assert dict_tools_json == str_tools_json, (
            f"Tool definitions are not equivalent:\n"
            f"Dict format: {json.dumps(dict_tools_json, indent=2)}\n"
            f"String format: {json.dumps(str_tools_json, indent=2)}"
        )

        def _mask_tools_section(text):
            # The tools JSON may serialize keys in a different order, so mask
            # it out before comparing the remaining template structure.
            return re.sub(
                r"<tools>.*?</tools>",
                "<tools>TOOLS_PLACEHOLDER</tools>",
                text,
                flags=re.DOTALL,
            )

        assert _mask_tools_section(decoded_dict) == _mask_tools_section(decoded_str), (
            "The overall structure differs between dict and string parameter formats"
        )

    def test_str_params_with_mixed_types_no_error(
        self,
        tools_mixed_type_params_dataset,
        qwen3_instruct_prompt_strategy,
        qwen3_tokenizer,
    ):
        """
        Tokenizing tools whose parameters are JSON strings must not raise even
        when two tools share a parameter name with conflicting types, and both
        tool names and argument values must appear in the decoded prompt.
        """
        processed = tools_mixed_type_params_dataset.map(
            qwen3_instruct_prompt_strategy.tokenize_prompt,
            batched=True,
            remove_columns=["messages", "tools"],
        )
        assert len(processed) == 1
        assert "input_ids" in processed[0]
        assert len(processed[0]["input_ids"]) > 0

        decoded = qwen3_tokenizer.decode(processed[0]["input_ids"])
        # Both tool names and both argument values should survive rendering.
        for expected in (
            "tool_with_string_arg",
            "tool_with_number_arg",
            "hello",
            "42",
        ):
            assert expected in decoded