diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index a9d26a650..ced8c8da6 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -379,6 +379,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): Public method that can handle either a single prompt or a batch of prompts. """ + def _remove_none_values(obj): + """ + Remove null from a dictionary-like obj or list. + These can appear due to Dataset loading causing schema merge. + See https://github.com/axolotl-ai-cloud/axolotl/pull/2909 + """ + if hasattr(obj, "items"): + return { + k: _remove_none_values(v) for k, v in obj.items() if v is not None + } + if isinstance(obj, list): + return [_remove_none_values(elem) for elem in obj] + return obj + + prompt = _remove_none_values(prompt) + if not self.is_prompt_batched(prompt) or not self.supports_batched: return self._tokenize_single_prompt(prompt) diff --git a/tests/prompt_strategies/test_chat_template_ds_schema_unification.py b/tests/prompt_strategies/test_chat_template_ds_schema_unification.py new file mode 100644 index 000000000..502efae4b --- /dev/null +++ b/tests/prompt_strategies/test_chat_template_ds_schema_unification.py @@ -0,0 +1,75 @@ +""" +Tests for chat template prompt strategy with schema unification for none fields +""" + +import json + +import pytest +from datasets import Dataset +from transformers import AutoTokenizer + +from axolotl.prompt_strategies.chat_template import StrategyLoader +from axolotl.utils.dict import DictDefault + + +@pytest.fixture(name="messages_w_tools") +def fixture_messages_w_tools(): + jsons = """ +{"messages":[{"role":"user","content":"move to (0, 1)"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"move","arguments":{"x":0,"y":1}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false} +{"messages":[{"role":"user","content":"turn 270 degree"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"turn","arguments":{"theta": 270}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false} +{"messages":[{"role":"user","content":"jump high"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"invalid_prompt","arguments":{"message": "jump is not a valid action"}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false} + """.strip().split( + "\n" + ) + rows = [json.loads(row) for row in jsons] + return Dataset.from_list(rows) + + +@pytest.fixture(name="qwen3_tokenizer") +def qwen3_tokenizer_fixture( + download_qwen3_half_billion_model, +): # pylint: disable=unused-argument + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") + + return tokenizer + + +@pytest.fixture(name="qwen3_prompt_strategy") +def qwen3_chat_template_strategy(qwen3_tokenizer): + cfg = DictDefault( + sequence_len=2048, + chat_template="qwen3", + eot_tokens=["<|im_end|>"], + ) + ds_cfg = DictDefault( + type="chat_template", + ) + load = StrategyLoader() + strat = load(qwen3_tokenizer, cfg, ds_cfg) + return strat + + +class TestSchemaUnification: + """ + Test class on handling null fields for tool calling + """ + + def test_schema_unification_single_prompt( + self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer + ): + for row in messages_w_tools: + inputs = qwen3_prompt_strategy.tokenize_prompt(row) + decoded = qwen3_tokenizer.decode(inputs["input_ids"]) + tool_call = decoded.split("")[-1].split("")[0] + assert '"message": null' not in tool_call + assert '"theta": null' not in tool_call + + def test_schema_unification_batched( + self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer + ): + rows = messages_w_tools.map(qwen3_prompt_strategy.tokenize_prompt, batched=True) + for row in rows: + decoded = qwen3_tokenizer.decode(row["input_ids"]) + tool_call = decoded.split("")[-1].split("")[0] + assert '"message": null' not in tool_call + assert '"theta": null' not in tool_call