feat: support training with JSON string tool arguments (#3136)

* feat: support training with JSON string tool arguments; fix PyArrow data type inconsistent error

* feat: raise error for tool call arguments decode

* Add test_chat_templates_tool_call_string_arguments.py

Add test for string arguments

* fix: change to correct qwen3 tokenizer

* fix: update docs to clarify arguments json

* chore: lint

* fix: duplicate

* chore: revert

* feat: add error to faq

* fix: remove duplicate fixture

---------

Co-authored-by: caoqinping <caoqinping@lixiang.com>
Co-authored-by: gamersover-blog <1611885128@qq.com>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
陈华杰
2025-09-25 13:06:21 +08:00
committed by GitHub
parent 856ff12171
commit e8b962d47f
7 changed files with 252 additions and 20 deletions

View File

@@ -177,6 +177,15 @@ def fixture_devstral_1_1_tokenizer():
return tokenizer
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
return tokenizer
@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
def fixture_mistralv03_chat_template_jinja_w_system() -> str:
return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'

View File

@@ -6,7 +6,6 @@ import json
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.prompt_strategies.chat_template import StrategyLoader
from axolotl.utils.dict import DictDefault
@@ -23,15 +22,6 @@ def fixture_messages_w_tools():
return Dataset.from_list(rows)
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
return tokenizer
@pytest.fixture(name="qwen3_prompt_strategy")
def qwen3_chat_template_strategy(qwen3_tokenizer):
cfg = DictDefault(

View File

@@ -4,7 +4,6 @@ Tests for splitting reasoning/thinking from content into separate field
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.prompt_strategies.chat_template import (
load,
@@ -56,15 +55,6 @@ def messages_w_reasoning_fixture():
)
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
return tokenizer
class TestSplitThinking:
"""
test class to make sure datasets with reasoning content conforms to the chat_template strategy

View File

@@ -0,0 +1,214 @@
"""
Tests for handling json tool content
"""
import json
import pytest
from datasets import Dataset
from axolotl.prompt_strategies.chat_template import (
load,
)
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="qwen3_instruct_prompt_strategy")
def qwen3_instruct_chat_template_strategy(qwen3_tokenizer):
strategy = load(
qwen3_tokenizer,
DictDefault(
{
"train_on_inputs": False,
"sequence_len": 512,
}
),
DictDefault(
{
"chat_template": "qwen3",
"message_field_role": "role",
"message_field_content": "content",
"message_property_mappings": {
"role": "role",
"content": "content",
},
"roles": {
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
},
"field_messages": "messages",
}
),
)
return strategy
class TestQwen3IdenticalConversationArgs:
"""
Test Qwen3 tools is identical between JSON and dict
"""
@pytest.fixture(name="conversation_dict_args_dataset")
def fixture_conversation_dict_args_dataset(self):
"""
Provides a dataset with conversation where arguments is a dict.
"""
user_content = "What is the weather in Boston?"
function_name = "get_current_weather"
arguments_dict = {"location": "Boston, MA", "unit": "celsius"}
data = [
{
"messages": [
{"role": "user", "content": user_content},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": function_name,
"arguments": arguments_dict, # dict格式
}
}
],
},
],
}
]
return Dataset.from_list(data)
@pytest.fixture(name="conversation_str_args_dataset")
def fixture_conversation_str_args_dataset(self):
"""
Provides a dataset with conversation where arguments is a JSON string.
"""
user_content = "What is the weather in Boston?"
function_name = "get_current_weather"
arguments_dict = {"location": "Boston, MA", "unit": "celsius"}
arguments_str = json.dumps(arguments_dict)
data = [
{
"messages": [
{"role": "user", "content": user_content},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": function_name,
"arguments": arguments_str, # str格式
}
}
],
},
],
}
]
return Dataset.from_list(data)
@pytest.fixture(name="conversation_mixed_time_types_dataset")
def fixture_conversation_mixed_time_types_dataset(self):
"""
Provides a dataset where 'time' field has different types in different tool calls.
"""
data = [
{
"messages": [
{
"role": "user",
"content": "Get weather information at different times",
},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "func1",
"arguments": json.dumps(
{"time": "2025-08-01"}
), # string type
}
},
{
"function": {
"name": "func2",
"arguments": json.dumps(
{"time": 1690876800}
), # number type
}
},
],
},
],
}
]
return Dataset.from_list(data)
def test_dict_and_str_args_produce_identical_output(
self,
conversation_dict_args_dataset,
conversation_str_args_dataset,
qwen3_instruct_prompt_strategy,
qwen3_tokenizer,
):
"""
Tests that after tokenization and decoding, the outputs for both
dict and string `arguments` are exactly the same.
"""
processed_dict_args = conversation_dict_args_dataset.map(
qwen3_instruct_prompt_strategy.tokenize_prompt,
batched=True,
remove_columns=["messages"],
)
processed_str_args = conversation_str_args_dataset.map(
qwen3_instruct_prompt_strategy.tokenize_prompt,
batched=True,
remove_columns=["messages"],
)
decoded_prompt_from_dict = qwen3_tokenizer.decode(
processed_dict_args[0]["input_ids"]
)
decoded_prompt_from_str = qwen3_tokenizer.decode(
processed_str_args[0]["input_ids"]
)
assert decoded_prompt_from_dict == decoded_prompt_from_str, (
f"Dict format output:\n{decoded_prompt_from_dict}\n"
f"String format output:\n{decoded_prompt_from_str}"
)
assert (
processed_dict_args[0]["input_ids"] == processed_str_args[0]["input_ids"]
), "The tokenized input_ids should be identical for dict and str arguments"
def test_str_args_with_mixed_time_types_no_error(
self,
conversation_mixed_time_types_dataset,
qwen3_instruct_prompt_strategy,
qwen3_tokenizer,
):
"""
Tests that when 'time' field has different types (string vs number)
in different tool calls, str format arguments don't cause errors.
"""
processed = conversation_mixed_time_types_dataset.map(
qwen3_instruct_prompt_strategy.tokenize_prompt,
batched=True,
remove_columns=["messages"],
)
assert len(processed) == 1
assert "input_ids" in processed[0]
assert len(processed[0]["input_ids"]) > 0
decoded = qwen3_tokenizer.decode(processed[0]["input_ids"])
assert "2025-08-01" in decoded, "String time value should be present"
assert "1690876800" in decoded, "Number time value should be present"