Feat: add tool calling support via tools column (#2774)
* feat: add tool_calling field support * fix: add tests
This commit is contained in:
@@ -1280,3 +1280,162 @@ class TestChatTemplateConfigurations:
|
||||
assert (
|
||||
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
|
||||
|
||||
|
||||
class TestChatTemplateToolCalling:
|
||||
"""
|
||||
Test class for tool calling functionality with chat templates.
|
||||
"""
|
||||
|
||||
def test_tool_calling_with_llama4_template(
|
||||
self,
|
||||
llama3_tokenizer,
|
||||
):
|
||||
LOG.info("Testing tool calling with llama3 tokenizer and llama4 chat template")
|
||||
|
||||
# Create tool calling dataset
|
||||
tool_calling_dataset = [
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "xml_escape",
|
||||
"description": 'Replaces any "<", ">", or "&" characters in the input string with their corresponding XML entities.',
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"s": {
|
||||
"type": "string",
|
||||
"description": "The input string to be XML-escaped.",
|
||||
}
|
||||
},
|
||||
"required": ["s"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multiples",
|
||||
"description": "Generates a list of all the multiples of a number that are less than a given limit.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"number": {
|
||||
"type": "integer",
|
||||
"description": "The number to find multiples of.",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "The upper limit for the multiples.",
|
||||
},
|
||||
},
|
||||
"required": ["number", "limit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Can you help me find multiples of 5 that are less than 20?",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multiples",
|
||||
"arguments": {
|
||||
"number": 5,
|
||||
"limit": 20,
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "name": "multiples", "content": "5,10,15"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The multiples of 5 less than 20 are: 5, 10, and 15.",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Setup tokenizer with llama4 chat template
|
||||
tokenizer = deepcopy(llama3_tokenizer)
|
||||
|
||||
# Add EOS token to the tokenizer
|
||||
eot_token = "<|eot_id|>"
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [eot_token]})
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
tokenizer,
|
||||
chat_template=get_chat_template("llama4"),
|
||||
message_property_mappings={"role": "role", "content": "content"},
|
||||
field_messages="messages",
|
||||
field_tools="tools",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
eot_tokens=[eot_token],
|
||||
)
|
||||
|
||||
res = strategy.tokenize_prompt(tool_calling_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
labels = res["labels"]
|
||||
|
||||
# Verify that the input_ids contain expected tokens
|
||||
assert len(input_ids) > 0, "Input IDs should not be empty"
|
||||
assert len(labels) == len(input_ids), "Labels should match input_ids length"
|
||||
|
||||
# Decode the full conversation to verify structure
|
||||
decoded_conversation = tokenizer.decode(input_ids)
|
||||
|
||||
# Verify tool calling structure is present in the decoded conversation
|
||||
assert (
|
||||
'"type": "function",' in decoded_conversation
|
||||
), "Tool type function should be in conversation"
|
||||
assert (
|
||||
'"name": "multiples",' in decoded_conversation
|
||||
), "Tool function name should be in conversation"
|
||||
|
||||
assert (
|
||||
'<|python_start|><|python_end|>{"name": "multiples", "parameters": {"number": 5, "limit": 20}}<|eot|>'
|
||||
in decoded_conversation
|
||||
), "Assistant tool call should be in conversation"
|
||||
assert (
|
||||
"<|header_start|>ipython<|header_end|>" in decoded_conversation
|
||||
), "IPython header should be in conversation"
|
||||
assert (
|
||||
'"5,10,15"' in decoded_conversation
|
||||
), "Tool response should be in conversation"
|
||||
|
||||
# Get conversation turns to verify labeling
|
||||
turns = strategy.get_conversation_thread(tool_calling_dataset[0])
|
||||
tools = strategy._get_tools( # pylint: disable=protected-access
|
||||
tool_calling_dataset[0]
|
||||
)
|
||||
|
||||
# Check that assistant responses are properly labeled
|
||||
for i, turn in enumerate(tool_calling_dataset[0]["messages"]):
|
||||
if turn["role"] == "assistant":
|
||||
start_idx, end_idx = strategy.find_turn(
|
||||
turns=turns, turn_idx=i, tools=tools
|
||||
)
|
||||
|
||||
assert (
|
||||
start_idx != -1 and end_idx != -1
|
||||
), f"Assistant turn {i} should be found"
|
||||
|
||||
# Verify that assistant responses have proper labels
|
||||
turn_labels = labels[start_idx:end_idx]
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in turn_labels
|
||||
), f"Assistant turn {i} should be unmasked"
|
||||
|
||||
Reference in New Issue
Block a user