Feat: Add Magistral and mistral-common tokenizer support (#2780)

This commit is contained in:
NanoCode012
2025-06-12 16:18:33 -07:00
committed by GitHub
parent ace9287c96
commit eac4a61f55
15 changed files with 1213 additions and 14 deletions

View File

@@ -150,6 +150,14 @@ def fixture_gemma2_tokenizer():
return tokenizer
@pytest.fixture(name="magistral_tokenizer")
def fixture_magistral_tokenizer():
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Magistral-Small-2506")
return tokenizer
@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
def fixture_mistralv03_chat_template_jinja_w_system() -> str:
return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'

View File

@@ -0,0 +1,290 @@
"""Test chat templates for mistral-common wrapper tokenizer"""
import unittest
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"):
# pylint: disable=duplicate-code
from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy
# check bos, eos, pad, unk are accessible properties
assert magistral_tokenizer.bos_token_id == 1
assert magistral_tokenizer.eos_token_id == 2
assert magistral_tokenizer.pad_token_id == 11
assert magistral_tokenizer.unk_token_id == 0
assert magistral_tokenizer.pad_token == "<pad>"
assert magistral_tokenizer.eos_token == "</s>"
assert magistral_tokenizer.bos_token == "<s>"
assert magistral_tokenizer.unk_token == "<unk>"
strategy = MistralStrategy(
MistralPrompter(
magistral_tokenizer,
chat_template=None,
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=magistral_tokenizer,
train_on_inputs=False,
train_on_eos="turn",
sequence_len=512,
roles_to_train=["assistant"],
)
# test chat template masking without system prompt
res = strategy.tokenize_prompt(
{
"messages": [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great, thank you!"},
]
}
)
assert res["input_ids"] == [
1, # bos
3, # [INST]
22177, # Hello
1044, # ,
2606, # how
1584, # are
1636, # you
1063, # ?
4, # [/INST]
1073, # I
4525, # 'm
6965, # doing
4824, # great
1044, # ,
15412, # thank
1636, # you
1033, # !
2, # </s>
]
assert res["labels"] == [
-100, # bos
-100, # [INST]
-100, # Hello
-100, # ,
-100, # how
-100, # are
-100, # you
-100, # ?
-100, # [/INST]
1073, # I
4525, # 'm
6965, # doing
4824, # great
1044, # ,
15412, # thank
1636, # you
1033, # !
2, # </s>
]
# test chat template masking with system prompt
res = strategy.tokenize_prompt(
{
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great, thank you!"},
]
}
)
assert res["input_ids"] == [
1, # bos
17, # [SYSTEM_PROMPT]
4568, # You
1584, # are
1261, # a
20351, # helpful
27089, # assistant
1046, # .
18, # [/SYSTEM_PROMPT]
3, # [INST]
22177, # Hello
1044, # ,
2606, # how
1584, # are
1636, # you
1063, # ?
4, # [/INST]
1073, # I
4525, # 'm
6965, # doing
4824, # great
1044, # ,
15412, # thank
1636, # you
1033, # !
2, # </s>
]
assert res["labels"] == [
-100, # bos
-100, # [SYSTEM_PROMPT]
-100, # You
-100, # are
-100, # a
-100, # helpful
-100, # assistant
-100, # .
-100, # [/SYSTEM_PROMPT]
-100, # [INST]
-100, # Hello
-100, # ,
-100, # how
-100, # are
-100, # you
-100, # ?
-100, # [/INST]
1073, # I
4525, # 'm
6965, # doing
4824, # great
1044, # ,
15412, # thank
1636, # you
1033, # !
2, # </s>
]
# test chat template with tools
res = strategy.tokenize_prompt(
{
"tools": [
{
"type": "function",
"function": {
"name": "multiples",
"description": "Generates a list of all the multiples of a number that are less than a given limit.",
"parameters": {
"type": "object",
"properties": {
"number": {
"type": "integer",
"description": "The number to find multiples of.",
},
"limit": {
"type": "integer",
"description": "The upper limit for the multiples.",
},
},
"required": ["number", "limit"],
},
},
},
],
"messages": [
{
"role": "user",
"content": "Hey, can you give me a breakdown of how to throw an awesome themed party? Like, what themes work best, and how can I set everything up to really wow my guests? I want some ideas on decorations, food, and activities that will make the party unforgettable!",
},
{
"role": "assistant",
"tool_calls": [
{
"id": "call12345",
"type": "function",
"function": {
"name": "multiples",
"arguments": {
"number": 16,
"limit": 2,
},
},
}
],
},
{
"role": "tool",
"tool_call_id": "call12345",
"name": "multiples",
"content": "1,2",
},
{"role": "assistant", "content": "The multiples of 16 is 1 and 2."},
],
}
)
# fmt: off
assert res["input_ids"] == [
1, # bos
5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 14653, 2811, 1429, 10639, 2130, 1261, 2951, 1307, 1747, 1278, 60092, 1307, 1261, 2782, 1455, 1584, 4289, 2224, 1261, 4265, 6139, 39249, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649, 2811, 16753, 12856, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 2782, 1317, 3081, 60092, 1307, 2613, 4179, 1429, 33319, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 9229, 6139, 1394, 1278, 60092, 2613, 47579, 1429, 15760, 2811, 12161, 12856, 1897, 1429, 33319, 4964, 2821, 27028, 6, # tool prompt
3, 46634, 1044, 1710, 1636, 5628, 1639, 1261, 44433, 1307, 2606, 1317, 5388, 1420, 54191, 2424, 1286, 8967, 1063, 15621, 1044, 2549, 30305, 2196, 3560, 1044, 1321, 2606, 1710, 1362, 2016, 8605, 2015, 1317, 5524, 118931, 2036, 32951, 1063, 1362, 2933, 2269, 12106, 1408, 101987, 1044, 6939, 1044, 1321, 9216, 1455, 2084, 3180, 1278, 8967, 119141, 1689, 5935, 1033, 4, # user
9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling
7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8, # tool result
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
2 # eos
]
assert res["labels"] == [
-100, # bos
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool prompt
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # user prompt
9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool result
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
2 # eos
]
# fmt: on
# test chat template with tokenize=False
res = magistral_tokenizer.apply_chat_template(
[
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great, thank you!"},
],
tokenize=False,
)
assert res == "<s>[INST]Hello, how are you?[/INST]I'm doing great, thank you!</s>"
# test encode
res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=True)
assert res == [
1, # bos
22177, # Hello
1044, # ,
2606, # how
1584, # are
1636, # you
1063, # ?
2, # eos
]
# test decode no skip special tokens
decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=False)
assert decoded_res == "<s>Hello, how are you?</s>"
# test decode skip special tokens
decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=True)
assert decoded_res == "Hello, how are you?"
# test encode no special tokens
res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=False)
assert res == [
22177, # Hello
1044, # ,
2606, # how
1584, # are
1636, # you
1063, # ?
]
# test convert ids to tokens
res = magistral_tokenizer.convert_ids_to_tokens(res)
# spacing are needed as we are converting without decoding
assert res == ["Hello", ",", " how", " are", " you", "?"]
if __name__ == "__main__":
unittest.main()