feat: support training with JSON string tool arguments (#3136)
* feat: support training with JSON string tool arguments; fix PyArrow data type inconsistent error * feat: raise error for tool call arguments decode * Add test_chat_templates_tool_call_string_arguments.py Add test for string arguments * fix: change to correct qwen3 tokenizer * fix: update docs to clarify arguments json * chore: lint * fix: duplicate * chore: revert * feat: add error to faq * fix: remove duplicate fixture --------- Co-authored-by: caoqinping <caoqinping@lixiang.com> Co-authored-by: gamersover-blog <1611885128@qq.com> Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -177,6 +177,15 @@ def fixture_devstral_1_1_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_tokenizer")
|
||||
def qwen3_tokenizer_fixture(
|
||||
download_qwen3_half_billion_model,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
|
||||
def fixture_mistralv03_chat_template_jinja_w_system() -> str:
|
||||
return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'
|
||||
|
||||
@@ -6,7 +6,6 @@ import json
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import StrategyLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -23,15 +22,6 @@ def fixture_messages_w_tools():
|
||||
return Dataset.from_list(rows)
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_tokenizer")
|
||||
def qwen3_tokenizer_fixture(
|
||||
download_qwen3_half_billion_model,
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_prompt_strategy")
|
||||
def qwen3_chat_template_strategy(qwen3_tokenizer):
|
||||
cfg = DictDefault(
|
||||
|
||||
@@ -4,7 +4,6 @@ Tests for splitting reasoning/thinking from content into separate field
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import (
|
||||
load,
|
||||
@@ -56,15 +55,6 @@ def messages_w_reasoning_fixture():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_tokenizer")
|
||||
def qwen3_tokenizer_fixture(
|
||||
download_qwen3_half_billion_model,
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
class TestSplitThinking:
|
||||
"""
|
||||
test class to make sure datasets with reasoning content conforms to the chat_template strategy
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
Tests for handling json tool content
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import (
|
||||
load,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_instruct_prompt_strategy")
|
||||
def qwen3_instruct_chat_template_strategy(qwen3_tokenizer):
|
||||
strategy = load(
|
||||
qwen3_tokenizer,
|
||||
DictDefault(
|
||||
{
|
||||
"train_on_inputs": False,
|
||||
"sequence_len": 512,
|
||||
}
|
||||
),
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "qwen3",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"message_property_mappings": {
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
"roles": {
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
"system": ["system"],
|
||||
},
|
||||
"field_messages": "messages",
|
||||
}
|
||||
),
|
||||
)
|
||||
return strategy
|
||||
|
||||
|
||||
class TestQwen3IdenticalConversationArgs:
|
||||
"""
|
||||
Test Qwen3 tools is identical between JSON and dict
|
||||
"""
|
||||
|
||||
@pytest.fixture(name="conversation_dict_args_dataset")
|
||||
def fixture_conversation_dict_args_dataset(self):
|
||||
"""
|
||||
Provides a dataset with conversation where arguments is a dict.
|
||||
"""
|
||||
user_content = "What is the weather in Boston?"
|
||||
function_name = "get_current_weather"
|
||||
arguments_dict = {"location": "Boston, MA", "unit": "celsius"}
|
||||
|
||||
data = [
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": user_content},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": arguments_dict, # dict格式
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
return Dataset.from_list(data)
|
||||
|
||||
@pytest.fixture(name="conversation_str_args_dataset")
|
||||
def fixture_conversation_str_args_dataset(self):
|
||||
"""
|
||||
Provides a dataset with conversation where arguments is a JSON string.
|
||||
"""
|
||||
user_content = "What is the weather in Boston?"
|
||||
function_name = "get_current_weather"
|
||||
arguments_dict = {"location": "Boston, MA", "unit": "celsius"}
|
||||
arguments_str = json.dumps(arguments_dict)
|
||||
|
||||
data = [
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": user_content},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": arguments_str, # str格式
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
return Dataset.from_list(data)
|
||||
|
||||
@pytest.fixture(name="conversation_mixed_time_types_dataset")
|
||||
def fixture_conversation_mixed_time_types_dataset(self):
|
||||
"""
|
||||
Provides a dataset where 'time' field has different types in different tool calls.
|
||||
"""
|
||||
data = [
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Get weather information at different times",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "func1",
|
||||
"arguments": json.dumps(
|
||||
{"time": "2025-08-01"}
|
||||
), # string type
|
||||
}
|
||||
},
|
||||
{
|
||||
"function": {
|
||||
"name": "func2",
|
||||
"arguments": json.dumps(
|
||||
{"time": 1690876800}
|
||||
), # number type
|
||||
}
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
return Dataset.from_list(data)
|
||||
|
||||
def test_dict_and_str_args_produce_identical_output(
|
||||
self,
|
||||
conversation_dict_args_dataset,
|
||||
conversation_str_args_dataset,
|
||||
qwen3_instruct_prompt_strategy,
|
||||
qwen3_tokenizer,
|
||||
):
|
||||
"""
|
||||
Tests that after tokenization and decoding, the outputs for both
|
||||
dict and string `arguments` are exactly the same.
|
||||
"""
|
||||
processed_dict_args = conversation_dict_args_dataset.map(
|
||||
qwen3_instruct_prompt_strategy.tokenize_prompt,
|
||||
batched=True,
|
||||
remove_columns=["messages"],
|
||||
)
|
||||
|
||||
processed_str_args = conversation_str_args_dataset.map(
|
||||
qwen3_instruct_prompt_strategy.tokenize_prompt,
|
||||
batched=True,
|
||||
remove_columns=["messages"],
|
||||
)
|
||||
|
||||
decoded_prompt_from_dict = qwen3_tokenizer.decode(
|
||||
processed_dict_args[0]["input_ids"]
|
||||
)
|
||||
|
||||
decoded_prompt_from_str = qwen3_tokenizer.decode(
|
||||
processed_str_args[0]["input_ids"]
|
||||
)
|
||||
|
||||
assert decoded_prompt_from_dict == decoded_prompt_from_str, (
|
||||
f"Dict format output:\n{decoded_prompt_from_dict}\n"
|
||||
f"String format output:\n{decoded_prompt_from_str}"
|
||||
)
|
||||
|
||||
assert (
|
||||
processed_dict_args[0]["input_ids"] == processed_str_args[0]["input_ids"]
|
||||
), "The tokenized input_ids should be identical for dict and str arguments"
|
||||
|
||||
def test_str_args_with_mixed_time_types_no_error(
|
||||
self,
|
||||
conversation_mixed_time_types_dataset,
|
||||
qwen3_instruct_prompt_strategy,
|
||||
qwen3_tokenizer,
|
||||
):
|
||||
"""
|
||||
Tests that when 'time' field has different types (string vs number)
|
||||
in different tool calls, str format arguments don't cause errors.
|
||||
"""
|
||||
processed = conversation_mixed_time_types_dataset.map(
|
||||
qwen3_instruct_prompt_strategy.tokenize_prompt,
|
||||
batched=True,
|
||||
remove_columns=["messages"],
|
||||
)
|
||||
|
||||
assert len(processed) == 1
|
||||
assert "input_ids" in processed[0]
|
||||
assert len(processed[0]["input_ids"]) > 0
|
||||
|
||||
decoded = qwen3_tokenizer.decode(processed[0]["input_ids"])
|
||||
assert "2025-08-01" in decoded, "String time value should be present"
|
||||
assert "1690876800" in decoded, "Number time value should be present"
|
||||
Reference in New Issue
Block a user