Feat: add tool calling support via tools column (#2774)

* feat: add tool_calling field support

* fix: add tests
This commit is contained in:
NanoCode012
2025-06-09 21:42:05 -07:00
committed by GitHub
parent 92afa4fa27
commit 83632f71d8
5 changed files with 327 additions and 27 deletions

View File

@@ -1280,3 +1280,162 @@ class TestChatTemplateConfigurations:
assert (
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
class TestChatTemplateToolCalling:
    """
    Test class for tool calling functionality with chat templates.
    """

    def test_tool_calling_with_llama4_template(
        self,
        llama3_tokenizer,
    ):
        """
        Tokenize a tool-calling conversation with the llama4 chat template.

        The sample carries a ``tools`` column (two function schemas) next to
        its ``messages`` column. The test checks that:

        * the tool schemas, the assistant tool call and the tool response
          are all rendered into the tokenized prompt, and
        * every token belonging to an assistant turn is left unmasked in
          ``labels`` when ``roles_to_train=["assistant"]``.

        ``llama3_tokenizer`` is presumably a pytest fixture defined outside
        this class; it is deep-copied so that adding a special token here
        does not mutate the shared object.
        """
        LOG.info("Testing tool calling with llama3 tokenizer and llama4 chat template")

        # One sample: user request -> assistant tool call -> tool result ->
        # assistant final answer.
        sample = {
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "xml_escape",
                        "description": 'Replaces any "<", ">", or "&" characters in the input string with their corresponding XML entities.',
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "s": {
                                    "type": "string",
                                    "description": "The input string to be XML-escaped.",
                                }
                            },
                            "required": ["s"],
                        },
                    },
                },
                {
                    "type": "function",
                    "function": {
                        "name": "multiples",
                        "description": "Generates a list of all the multiples of a number that are less than a given limit.",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "number": {
                                    "type": "integer",
                                    "description": "The number to find multiples of.",
                                },
                                "limit": {
                                    "type": "integer",
                                    "description": "The upper limit for the multiples.",
                                },
                            },
                            "required": ["number", "limit"],
                        },
                    },
                },
            ],
            "messages": [
                {
                    "role": "user",
                    "content": "Can you help me find multiples of 5 that are less than 20?",
                },
                {
                    # Tool-call turn: deliberately has no "content" key.
                    "role": "assistant",
                    "tool_calls": [
                        {
                            "type": "function",
                            "function": {
                                "name": "multiples",
                                "arguments": {
                                    "number": 5,
                                    "limit": 20,
                                },
                            },
                        }
                    ],
                },
                {"role": "tool", "name": "multiples", "content": "5,10,15"},
                {
                    "role": "assistant",
                    "content": "The multiples of 5 less than 20 are: 5, 10, and 15.",
                },
            ],
        }

        # Work on a copy so the special token added below stays local.
        tokenizer = deepcopy(llama3_tokenizer)

        # Register the end-of-turn token that is handed to the strategy via
        # ``eot_tokens``.
        eot_token = "<|eot_id|>"
        tokenizer.add_special_tokens({"additional_special_tokens": [eot_token]})

        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                tokenizer,
                chat_template=get_chat_template("llama4"),
                message_property_mappings={"role": "role", "content": "content"},
                field_messages="messages",
                field_tools="tools",
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,
            sequence_len=512,
            roles_to_train=["assistant"],
            eot_tokens=[eot_token],
        )

        res = strategy.tokenize_prompt(sample)
        input_ids = res["input_ids"]
        labels = res["labels"]

        # Basic shape checks on the tokenized output.
        assert len(input_ids) > 0, "Input IDs should not be empty"
        assert len(labels) == len(input_ids), "Labels should match input_ids length"

        # Decode the full conversation to verify its rendered structure.
        decoded_conversation = tokenizer.decode(input_ids)

        assert (
            '"type": "function",' in decoded_conversation
        ), "Tool type function should be in conversation"

        assert (
            '"name": "multiples",' in decoded_conversation
        ), "Tool function name should be in conversation"

        assert (
            '<|python_start|><|python_end|>{"name": "multiples", "parameters": {"number": 5, "limit": 20}}<|eot|>'
            in decoded_conversation
        ), "Assistant tool call should be in conversation"

        assert (
            "<|header_start|>ipython<|header_end|>" in decoded_conversation
        ), "IPython header should be in conversation"

        assert (
            '"5,10,15"' in decoded_conversation
        ), "Tool response should be in conversation"

        # Inputs that ``find_turn`` needs to locate each turn's token span.
        turns = strategy.get_conversation_thread(sample)
        tools = strategy._get_tools(sample)  # pylint: disable=protected-access

        # Every assistant turn must be found and fully unmasked.
        for i, turn in enumerate(sample["messages"]):
            if turn["role"] != "assistant":
                continue

            start_idx, end_idx = strategy.find_turn(
                turns=turns, turn_idx=i, tools=tools
            )

            assert (
                start_idx != -1 and end_idx != -1
            ), f"Assistant turn {i} should be found"

            turn_labels = labels[start_idx:end_idx]

            # ``all()`` over an empty slice is vacuously true, so make sure
            # there is at least one label to inspect before checking them.
            assert turn_labels, f"Assistant turn {i} should cover at least one token"

            assert all(
                label != IGNORE_TOKEN_ID for label in turn_labels
            ), f"Assistant turn {i} should be unmasked"