Fix: Prevents merging of tool arguments during preprocessing (#2909)
This commit is contained in:
@@ -379,6 +379,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
Public method that can handle either a single prompt or a batch of prompts.
|
Public method that can handle either a single prompt or a batch of prompts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def _remove_none_values(obj):
|
||||||
|
"""
|
||||||
|
Remove null from a dictionary-like obj or list.
|
||||||
|
These can appear due to Dataset loading causing schema merge.
|
||||||
|
See https://github.com/axolotl-ai-cloud/axolotl/pull/2909
|
||||||
|
"""
|
||||||
|
if hasattr(obj, "items"):
|
||||||
|
return {
|
||||||
|
k: _remove_none_values(v) for k, v in obj.items() if v is not None
|
||||||
|
}
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [_remove_none_values(elem) for elem in obj]
|
||||||
|
return obj
|
||||||
|
|
||||||
|
prompt = _remove_none_values(prompt)
|
||||||
|
|
||||||
if not self.is_prompt_batched(prompt) or not self.supports_batched:
|
if not self.is_prompt_batched(prompt) or not self.supports_batched:
|
||||||
return self._tokenize_single_prompt(prompt)
|
return self._tokenize_single_prompt(prompt)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,75 @@
|
|||||||
|
"""
|
||||||
|
Tests for the chat template prompt strategy with schema unification for None fields
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datasets import Dataset
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from axolotl.prompt_strategies.chat_template import StrategyLoader
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="messages_w_tools")
def fixture_messages_w_tools():
    """
    Build a three-row Dataset of tool-calling conversations.

    Each line of the literal below is one JSON document. The rows call
    different tools (move / turn / invalid_prompt), so their ``arguments``
    objects carry different keys (x and y, theta, message).
    """
    raw_lines = """
{"messages":[{"role":"user","content":"move to (0, 1)"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"move","arguments":{"x":0,"y":1}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
{"messages":[{"role":"user","content":"turn 270 degree"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"turn","arguments":{"theta": 270}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
{"messages":[{"role":"user","content":"jump high"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"invalid_prompt","arguments":{"message": "jump is not a valid action"}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
""".strip().split(
        "\n"
    )
    # One JSON object per line; parse each into a dict before handing to Dataset.
    return Dataset.from_list([json.loads(line) for line in raw_lines])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
    download_qwen3_half_billion_model,
):  # pylint: disable=unused-argument
    """
    Provide the tokenizer loaded from the "Qwen/Qwen3-0.6B" checkpoint.

    The ``download_qwen3_half_billion_model`` fixture is requested only so that
    it runs first; its value is not used here.
    """
    return AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="qwen3_prompt_strategy")
def qwen3_chat_template_strategy(qwen3_tokenizer):
    """
    Build a chat_template prompt strategy for the Qwen3 tokenizer.

    Uses the "qwen3" chat template with a 2048-token sequence length and
    "<|im_end|>" as the only end-of-turn token.
    """
    train_cfg = DictDefault(
        sequence_len=2048,
        chat_template="qwen3",
        eot_tokens=["<|im_end|>"],
    )
    dataset_cfg = DictDefault(
        type="chat_template",
    )
    return StrategyLoader()(qwen3_tokenizer, train_cfg, dataset_cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSchemaUnification:
    """
    Checks that null fields do not leak into rendered tool calls.
    """

    def test_schema_unification_single_prompt(
        self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer
    ):
        """Tokenize each row on its own and inspect its last tool call."""
        forbidden_fragments = ('"message": null', '"theta": null')
        for example in messages_w_tools:
            tokenized = qwen3_prompt_strategy.tokenize_prompt(example)
            text = qwen3_tokenizer.decode(tokenized["input_ids"])
            # Text after the last opening tag, cut at the first closing tag.
            after_open = text.rpartition("<tool_call>")[2]
            tool_call = after_open.partition("</tool_call>")[0]
            for fragment in forbidden_fragments:
                assert fragment not in tool_call

    def test_schema_unification_batched(
        self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer
    ):
        """Tokenize through a batched ``map`` and inspect each row's last tool call."""
        forbidden_fragments = ('"message": null', '"theta": null')
        tokenized_rows = messages_w_tools.map(
            qwen3_prompt_strategy.tokenize_prompt, batched=True
        )
        for tokenized in tokenized_rows:
            text = qwen3_tokenizer.decode(tokenized["input_ids"])
            # Text after the last opening tag, cut at the first closing tag.
            after_open = text.rpartition("<tool_call>")[2]
            tool_call = after_open.partition("</tool_call>")[0]
            for fragment in forbidden_fragments:
                assert fragment not in tool_call
|
||||||
Reference in New Issue
Block a user