Fix: Prevents merging of tool arguments during preprocessing (#2909)
This commit is contained in:
@@ -379,6 +379,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
Public method that can handle either a single prompt or a batch of prompts.
|
||||
"""
|
||||
|
||||
def _remove_none_values(obj):
|
||||
"""
|
||||
Remove null from a dictionary-like obj or list.
|
||||
These can appear due to Dataset loading causing schema merge.
|
||||
See https://github.com/axolotl-ai-cloud/axolotl/pull/2909
|
||||
"""
|
||||
if hasattr(obj, "items"):
|
||||
return {
|
||||
k: _remove_none_values(v) for k, v in obj.items() if v is not None
|
||||
}
|
||||
if isinstance(obj, list):
|
||||
return [_remove_none_values(elem) for elem in obj]
|
||||
return obj
|
||||
|
||||
prompt = _remove_none_values(prompt)
|
||||
|
||||
if not self.is_prompt_batched(prompt) or not self.supports_batched:
|
||||
return self._tokenize_single_prompt(prompt)
|
||||
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Tests that the chat template prompt strategy drops None fields introduced by dataset schema unification
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import StrategyLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="messages_w_tools")
def fixture_messages_w_tools():
    """
    Build a three-row Dataset of tool-calling conversations.

    Each row shares the same ``tools`` list (``move``, ``turn``,
    ``invalid_prompt``) but its assistant turn calls a different tool, so the
    ``arguments`` objects carry different keys from row to row (``x``/``y``,
    ``theta``, ``message``). All rows are loaded into one Dataset, which is
    the schema-unification scenario these tests target.
    """
    # One JSON document per line; surrounding blank lines are stripped below.
    raw_jsonl = """
{"messages":[{"role":"user","content":"move to (0, 1)"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"move","arguments":{"x":0,"y":1}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
{"messages":[{"role":"user","content":"turn 270 degree"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"turn","arguments":{"theta": 270}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
{"messages":[{"role":"user","content":"jump high"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"invalid_prompt","arguments":{"message": "jump is not a valid action"}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
""".strip()

    parsed_rows = []
    for line in raw_jsonl.split("\n"):
        parsed_rows.append(json.loads(line))
    return Dataset.from_list(parsed_rows)
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
    download_qwen3_half_billion_model,
):  # pylint: disable=unused-argument
    """
    Load the ``Qwen/Qwen3-0.6B`` tokenizer.

    ``download_qwen3_half_billion_model`` is requested only as a fixture
    dependency and is never read here (presumably it makes the model files
    available first; confirm against the fixture's definition).
    """
    return AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_prompt_strategy")
def qwen3_chat_template_strategy(qwen3_tokenizer):
    """
    Build the ``chat_template`` prompt strategy used by these tests.

    The strategy is produced by ``StrategyLoader`` from the Qwen3 tokenizer,
    a base config (``qwen3`` chat template, 2048-token sequence length,
    ``<|im_end|>`` as the only EOT token) and a dataset config of type
    ``chat_template``.
    """
    base_cfg = DictDefault(
        sequence_len=2048,
        chat_template="qwen3",
        eot_tokens=["<|im_end|>"],
    )
    dataset_cfg = DictDefault(type="chat_template")
    return StrategyLoader()(qwen3_tokenizer, base_cfg, dataset_cfg)
|
||||
|
||||
|
||||
class TestSchemaUnification:
    """
    Test class on handling null fields for tool calling

    Both tests tokenize every row of ``messages_w_tools``, decode the result
    and verify that the last rendered tool call contains no ``null``
    argument for any argument key used by the fixture's tool calls.
    """

    # Every argument key that appears in some row's tool call in
    # ``messages_w_tools``. A key absent from one row's call is the kind of
    # field that could be rendered as ``null`` there, so all of them are
    # checked on every row.
    TOOL_ARGUMENT_KEYS = ("x", "y", "theta", "message")

    @staticmethod
    def _last_tool_call(decoded):
        """
        Return the text between the last ``<tool_call>`` marker in
        ``decoded`` and the first ``</tool_call>`` that follows it.
        """
        return decoded.split("<tool_call>")[-1].split("</tool_call>")[0]

    def _assert_no_null_arguments(self, decoded):
        """
        Assert that the last tool call in ``decoded`` has no ``null`` value
        for any key in ``TOOL_ARGUMENT_KEYS``.
        """
        tool_call = self._last_tool_call(decoded)
        for key in self.TOOL_ARGUMENT_KEYS:
            assert (
                f'"{key}": null' not in tool_call
            ), f"unexpected null value for {key!r} in tool call: {tool_call}"

    def test_schema_unification_single_prompt(
        self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer
    ):
        """Tokenize each row individually and check its rendered tool call."""
        for row in messages_w_tools:
            inputs = qwen3_prompt_strategy.tokenize_prompt(row)
            decoded = qwen3_tokenizer.decode(inputs["input_ids"])
            self._assert_no_null_arguments(decoded)

    def test_schema_unification_batched(
        self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer
    ):
        """Tokenize through ``Dataset.map(batched=True)`` and check each row."""
        rows = messages_w_tools.map(qwen3_prompt_strategy.tokenize_prompt, batched=True)
        for row in rows:
            decoded = qwen3_tokenizer.decode(row["input_ids"])
            self._assert_no_null_arguments(decoded)
|
||||
Reference in New Issue
Block a user