Files
axolotl/tests/prompt_strategies/test_chat_templates_mistral.py
Dan Saunders 79ddaebe9a Add ruff, remove black, isort, flake8, pylint (#3092)
* black, isort, flake8 -> ruff

* remove unused

* add back needed import

* fix
2025-08-23 23:37:33 -04:00

852 lines
31 KiB
Python

"""Test chat templates for mistral-common wrapper tokenizer"""
import unittest
from typing import TYPE_CHECKING
import pytest
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from axolotl.utils.mistral import HFMistralTokenizer
# fmt: off
@pytest.mark.parametrize(
    ("tokenizer_str", "assistant_toolcall_ids", "tool_result_ids"),
    (
        # Each case: (fixture name, expected token ids for the assistant
        # tool-call turn, expected token ids for the tool-result turn).
        # The ids differ per model because each tokenizer version
        # serializes tool calls/results with a different wire format.
        ("magistral_tokenizer", (9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2), (7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8)),
        ("devstral_tokenizer", (9, 1091, 19227, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 61906, 2811, 16753, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 4179, 1429, 1327, 2811, 1429, 19881, 1049, 1050, 1051, 1052, 1053, 1034, 27028, 2), (7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8)),
        ("devstral_1_1_tokenizer", (9, 44627, 3684, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2,), (7, 1049, 1044, 1050, 8)),
    )
)
# fmt: on
def test_mistral_chat_template(
    tokenizer_str: str,
    assistant_toolcall_ids: tuple[int, ...],
    tool_result_ids: tuple[int, ...],
    request: pytest.FixtureRequest,
):
    """Test chat template with the Magistral/Devstral tokenizer

    Covers: special-token properties, label masking with and without a
    system prompt, tool calling, ``apply_chat_template(tokenize=False)``,
    ``encode``/``decode`` round-trips, and id-to-token conversion.
    """
    from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy

    # Resolve the parametrized tokenizer fixture by name.
    tokenizer: HFMistralTokenizer = request.getfixturevalue(tokenizer_str)

    # check bos, eos, pad, unk are accessible properties
    assert tokenizer.bos_token_id == 1
    assert tokenizer.eos_token_id == 2
    assert tokenizer.pad_token_id == 11
    assert tokenizer.unk_token_id == 0
    assert tokenizer.pad_token == "<pad>"
    assert tokenizer.eos_token == "</s>"
    assert tokenizer.bos_token == "<s>"
    assert tokenizer.unk_token == "<unk>"

    # train_on_inputs=False + roles_to_train=["assistant"] means only
    # assistant turns contribute to the loss; everything else is -100.
    strategy = MistralStrategy(
        MistralPrompter(
            tokenizer,
            chat_template=None,
            message_property_mappings={"role": "role", "content": "content"},
        ),
        tokenizer=tokenizer,
        train_on_inputs=False,
        train_on_eos="turn",
        sequence_len=512,
        roles_to_train=["assistant"],
    )

    # test chat template masking without system prompt
    res = strategy.tokenize_prompt(
        {
            "messages": [
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm doing great, thank you!"},
            ]
        }
    )
    assert res["input_ids"] == [
        1,  # bos
        3,  # [INST]
        22177,  # Hello
        1044,  # ,
        2606,  # how
        1584,  # are
        1636,  # you
        1063,  # ?
        4,  # [/INST]
        1073,  # I
        4525,  # 'm
        6965,  # doing
        4824,  # great
        1044,  # ,
        15412,  # thank
        1636,  # you
        1033,  # !
        2,  # </s>
    ]
    # Labels: user turn (and bos/control tokens) masked, assistant turn
    # kept, and the turn-final eos is trainable (train_on_eos="turn").
    assert res["labels"] == [
        -100,  # bos
        -100,  # [INST]
        -100,  # Hello
        -100,  # ,
        -100,  # how
        -100,  # are
        -100,  # you
        -100,  # ?
        -100,  # [/INST]
        1073,  # I
        4525,  # 'm
        6965,  # doing
        4824,  # great
        1044,  # ,
        15412,  # thank
        1636,  # you
        1033,  # !
        2,  # </s>
    ]

    # test chat template masking with system prompt
    res = strategy.tokenize_prompt(
        {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm doing great, thank you!"},
            ]
        }
    )
    assert res["input_ids"] == [
        1,  # bos
        17,  # [SYSTEM_PROMPT]
        4568,  # You
        1584,  # are
        1261,  # a
        20351,  # helpful
        27089,  # assistant
        1046,  # .
        18,  # [/SYSTEM_PROMPT]
        3,  # [INST]
        22177,  # Hello
        1044,  # ,
        2606,  # how
        1584,  # are
        1636,  # you
        1063,  # ?
        4,  # [/INST]
        1073,  # I
        4525,  # 'm
        6965,  # doing
        4824,  # great
        1044,  # ,
        15412,  # thank
        1636,  # you
        1033,  # !
        2,  # </s>
    ]
    assert res["labels"] == [
        -100,  # bos
        -100,  # [SYSTEM_PROMPT]
        -100,  # You
        -100,  # are
        -100,  # a
        -100,  # helpful
        -100,  # assistant
        -100,  # .
        -100,  # [/SYSTEM_PROMPT]
        -100,  # [INST]
        -100,  # Hello
        -100,  # ,
        -100,  # how
        -100,  # are
        -100,  # you
        -100,  # ?
        -100,  # [/INST]
        1073,  # I
        4525,  # 'm
        6965,  # doing
        4824,  # great
        1044,  # ,
        15412,  # thank
        1636,  # you
        1033,  # !
        2,  # </s>
    ]

    # test chat template with tools
    res = strategy.tokenize_prompt(
        {
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "multiples",
                        "description": "Generates a list of all the multiples of a number that are less than a given limit.",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "number": {
                                    "type": "integer",
                                    "description": "The number to find multiples of.",
                                },
                                "limit": {
                                    "type": "integer",
                                    "description": "The upper limit for the multiples.",
                                },
                            },
                            "required": ["number", "limit"],
                        },
                    },
                },
            ],
            "messages": [
                {
                    "role": "user",
                    "content": "Hey, can you give me a breakdown of how to throw an awesome themed party? Like, what themes work best, and how can I set everything up to really wow my guests? I want some ideas on decorations, food, and activities that will make the party unforgettable!",
                },
                {
                    "role": "assistant",
                    "tool_calls": [
                        {
                            "id": "call12345",
                            "type": "function",
                            "function": {
                                "name": "multiples",
                                "arguments": {
                                    "number": 16,
                                    "limit": 2,
                                },
                            },
                        }
                    ],
                },
                {
                    "role": "tool",
                    "tool_call_id": "call12345",
                    "name": "multiples",
                    "content": "1,2",
                },
                {"role": "assistant", "content": "The multiples of 16 is 1 and 2."},
            ],
        }
    )
    # fmt: off
    assert res["input_ids"] == [
        1, # bos
        5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 14653, 2811, 1429, 10639, 2130, 1261, 2951, 1307, 1747, 1278, 60092, 1307, 1261, 2782, 1455, 1584, 4289, 2224, 1261, 4265, 6139, 39249, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649, 2811, 16753, 12856, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 2782, 1317, 3081, 60092, 1307, 2613, 4179, 1429, 33319, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 9229, 6139, 1394, 1278, 60092, 2613, 47579, 1429, 15760, 2811, 12161, 12856, 1897, 1429, 33319, 4964, 2821, 27028, 6, # tool prompt
        3, 46634, 1044, 1710, 1636, 5628, 1639, 1261, 44433, 1307, 2606, 1317, 5388, 1420, 54191, 2424, 1286, 8967, 1063, 15621, 1044, 2549, 30305, 2196, 3560, 1044, 1321, 2606, 1710, 1362, 2016, 8605, 2015, 1317, 5524, 118931, 2036, 32951, 1063, 1362, 2933, 2269, 12106, 1408, 101987, 1044, 6939, 1044, 1321, 9216, 1455, 2084, 3180, 1278, 8967, 119141, 1689, 5935, 1033, 4, # user
        *assistant_toolcall_ids, # assistant tool calling
        *tool_result_ids, # tool result
        1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
        2 # eos
    ]
    # Both assistant turns train (tool call + final answer); the tool
    # prompt, user turn, and tool result are masked.
    assert res["labels"] == [
        -100, # bos
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool prompt
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # user prompt
        *assistant_toolcall_ids, # assistant tool calling
        *([-100] * len(tool_result_ids)), # tool result
        1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
        2 # eos
    ]
    # fmt: on

    # test chat template with tokenize=False
    res = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "I'm doing great, thank you!"},
        ],
        tokenize=False,
    )
    assert res == "<s>[INST]Hello, how are you?[/INST]I'm doing great, thank you!</s>"

    # test encode
    res = tokenizer.encode("Hello, how are you?", add_special_tokens=True)
    assert res == [
        1,  # bos
        22177,  # Hello
        1044,  # ,
        2606,  # how
        1584,  # are
        1636,  # you
        1063,  # ?
        2,  # eos
    ]
    # test decode no skip special tokens
    decoded_res = tokenizer.decode(res, skip_special_tokens=False)
    assert decoded_res == "<s>Hello, how are you?</s>"
    # test decode skip special tokens
    decoded_res = tokenizer.decode(res, skip_special_tokens=True)
    assert decoded_res == "Hello, how are you?"
    # test encode no special tokens
    res = tokenizer.encode("Hello, how are you?", add_special_tokens=False)
    assert res == [
        22177,  # Hello
        1044,  # ,
        2606,  # how
        1584,  # are
        1636,  # you
        1063,  # ?
    ]
    # test convert ids to tokens
    res = tokenizer.convert_ids_to_tokens(res)
    # spacing are needed as we are converting without decoding
    assert res == ["Hello", ",", " how", " are", " you", "?"]
@pytest.mark.skip(reason="TODO, fix for new HF wrapper call")
def test_magistral_tokenizer_pad_method(magistral_tokenizer: "HFMistralTokenizer"):
    """Test the MistralTokenizer pad method

    Verifies right-padding of input_ids/labels/attention_mask/position_ids,
    max_length padding, numpy vs torch return types, and rejection of
    unsupported feature keys.
    """
    from axolotl.utils.collators.core import IGNORE_INDEX

    magistral_pad_token_id = 11  # taken from tokenizer.pad_token_id

    # Test padding with input_ids and labels only
    features = [
        {"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
        {"input_ids": [7, 8], "labels": [9, 10]},
    ]
    result = magistral_tokenizer.pad(features, padding=True, return_tensors="pt")
    # Check that input_ids are padded correctly (pad token on the right)
    assert result["input_ids"].shape == (2, 3)
    assert result["input_ids"].tolist() == [[1, 2, 3], [7, 8, magistral_pad_token_id]]
    # Check that labels are padded correctly (ignore index, not pad id)
    assert result["labels"].shape == (2, 3)
    assert result["labels"].tolist() == [[4, 5, 6], [9, 10, IGNORE_INDEX]]
    # Check that attention_mask and position_ids are NOT created
    assert "attention_mask" not in result
    assert "position_ids" not in result

    # Test padding with attention_mask
    features_with_attention = [
        {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "attention_mask": [1, 1, 1]},
        {"input_ids": [7, 8], "labels": [9, 10], "attention_mask": [1, 1]},
    ]
    result = magistral_tokenizer.pad(
        features_with_attention, padding=True, return_tensors="pt"
    )
    # Check that attention_mask is padded correctly (0 for pad positions)
    assert result["attention_mask"].shape == (2, 3)
    assert result["attention_mask"].tolist() == [[1, 1, 1], [1, 1, 0]]

    # Test padding with position_ids
    features_with_position = [
        {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "position_ids": [0, 1, 2]},
        {"input_ids": [7, 8], "labels": [9, 10], "position_ids": [0, 1]},
    ]
    result = magistral_tokenizer.pad(
        features_with_position, padding=True, return_tensors="pt"
    )
    # Check that position_ids are padded correctly (continuing sequence)
    assert result["position_ids"].shape == (2, 3)
    assert result["position_ids"].tolist() == [[0, 1, 2], [0, 1, 2]]

    # Test padding with all fields
    features_all = [
        {
            "input_ids": [1, 2, 3],
            "labels": [4, 5, 6],
            "attention_mask": [1, 1, 1],
            "position_ids": [0, 1, 2],
        },
        {
            "input_ids": [7, 8],
            "labels": [9, 10],
            "attention_mask": [1, 1],
            "position_ids": [0, 1],
        },
    ]
    result = magistral_tokenizer.pad(features_all, padding=True, return_tensors="pt")
    # All fields should be present and correctly padded
    assert "input_ids" in result
    assert "labels" in result
    assert "attention_mask" in result
    assert "position_ids" in result

    # Test padding with all sequences same length
    features_same_length = [
        {"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
        {"input_ids": [7, 8, 9], "labels": [10, 11, 12]},
    ]
    result = magistral_tokenizer.pad(
        features_same_length, padding=True, return_tensors="pt"
    )
    # Check match when no padding is needed
    assert result["input_ids"][0].tolist() == features_same_length[0]["input_ids"]
    assert result["labels"][0].tolist() == features_same_length[0]["labels"]
    assert result["input_ids"][1].tolist() == features_same_length[1]["input_ids"]
    assert result["labels"][1].tolist() == features_same_length[1]["labels"]

    # Test padding with max_length parameter
    result = magistral_tokenizer.pad(
        features, padding="max_length", max_length=5, return_tensors="pt"
    )
    # Should pad to max_length
    assert result["input_ids"].shape == (2, 5)
    assert result["labels"].shape == (2, 5)

    # Test numpy return type
    result = magistral_tokenizer.pad(features, padding=True, return_tensors="np")
    # Should return numpy arrays
    import numpy as np

    assert isinstance(result["input_ids"], np.ndarray)
    assert isinstance(result["labels"], np.ndarray)

    # Test unsupported field rejection
    features_unsupported = [
        {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "unsupported_field": [7, 8, 9]},
    ]
    with pytest.raises(NotImplementedError, match="unsupported_field"):
        magistral_tokenizer.pad(features_unsupported, padding=True, return_tensors="pt")

    # Test token_type_ids rejection
    features_token_type = [
        {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "token_type_ids": [0, 0, 0]},
    ]
    with pytest.raises(ValueError, match="token_type_ids is not supported"):
        magistral_tokenizer.pad(features_token_type, padding=True, return_tensors="pt")
def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"):
    """Test tool calling with the Magistral tokenizer

    Covers: a single tool call, sequential multi-tool calls, tool calling
    combined with a system prompt, and the error raised when an assistant
    tool call has no matching tool response.
    """
    from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy

    strategy = MistralStrategy(
        MistralPrompter(
            magistral_tokenizer,
            chat_template=None,
            message_property_mappings={"role": "role", "content": "content"},
        ),
        tokenizer=magistral_tokenizer,
        train_on_inputs=False,
        train_on_eos="turn",
        sequence_len=512,
        roles_to_train=["assistant"],
    )

    # Test basic tool calling with single function
    basic_tool_calling = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather for a location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA",
                            },
                        },
                        "required": ["location"],
                    },
                },
            },
        ],
        "messages": [
            {
                "role": "user",
                "content": "What's the weather like in San Francisco?",
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "call12345",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                            },
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "call12345",
                "name": "get_weather",
                "content": "Sunny, 72°F",
            },
            {
                "role": "assistant",
                "content": "The weather in San Francisco is sunny and 72°F.",
            },
        ],
    }
    res = strategy.tokenize_prompt(basic_tool_calling)
    # Basic validation
    assert "input_ids" in res
    assert "labels" in res
    assert len(res["input_ids"]) > 0
    assert len(res["labels"]) == len(res["input_ids"])
    # Decode and verify the rendered template structure
    decoded = magistral_tokenizer.decode(res["input_ids"])
    assert (
        '<s>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}}, "required": ["location"]}}}][/AVAILABLE_TOOLS]'
        in decoded
    )
    assert (
        '[TOOL_CALLS]get_weather[CALL_ID]call12345[ARGS]{"location": "San Francisco, CA"}</s>'
        in decoded
    )
    assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]Sunny, 72°F[/TOOL_RESULTS]" in decoded
    assert "The weather in San Francisco is sunny and 72°F.</s>" in decoded

    # Test multiple tool calls in sequence
    multi_tool_calling = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "add_numbers",
                    "description": "Add two numbers together",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "a": {"type": "number", "description": "First number"},
                            "b": {"type": "number", "description": "Second number"},
                        },
                        "required": ["a", "b"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "multiply_numbers",
                    "description": "Multiply two numbers",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {"type": "number", "description": "First number"},
                            "y": {"type": "number", "description": "Second number"},
                        },
                        "required": ["x", "y"],
                    },
                },
            },
        ],
        "messages": [
            {
                "role": "user",
                "content": "Add 5 and 3, then multiply the result by 2",
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "call12345",
                        "type": "function",
                        "function": {
                            "name": "add_numbers",
                            "arguments": {"a": 5, "b": 3},
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "call12345",
                "name": "add_numbers",
                "content": "8",
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "call23456",
                        "type": "function",
                        "function": {
                            "name": "multiply_numbers",
                            "arguments": {"x": 8, "y": 2},
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "call23456",
                "name": "multiply_numbers",
                "content": "16",
            },
            {
                "role": "assistant",
                "content": "The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16.",
            },
        ],
    }
    res = strategy.tokenize_prompt(multi_tool_calling)
    # Validation
    assert len(res["input_ids"]) > 0
    assert len(res["labels"]) == len(res["input_ids"])
    decoded = magistral_tokenizer.decode(res["input_ids"])
    assert (
        '<s>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "add_numbers", "description": "Add two numbers together", "parameters": {"type": "object", "properties": {"a": {"type": "number", "description": "First number"}, "b": {"type": "number", "description": "Second number"}}, "required": ["a", "b"]}}}, {"type": "function", "function": {"name": "multiply_numbers", "description": "Multiply two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "First number"}, "y": {"type": "number", "description": "Second number"}}, "required": ["x", "y"]}}}][/AVAILABLE_TOOLS]'
        in decoded
    )
    assert (
        '[TOOL_CALLS]add_numbers[CALL_ID]call12345[ARGS]{"a": 5, "b": 3}</s>' in decoded
    )
    assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]8[/TOOL_RESULTS]" in decoded
    assert (
        '[TOOL_CALLS]multiply_numbers[CALL_ID]call23456[ARGS]{"x": 8, "y": 2}</s>'
        in decoded
    )
    assert "[TOOL_RESULTS]call23456[TOOL_CONTENT]16[/TOOL_RESULTS]" in decoded
    assert (
        "The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16.</s>"
        in decoded
    )

    # Test tool calling with system message
    system_tool_calling = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "search_database",
                    "description": "Search for information in database",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "query": {"type": "string", "description": "Search query"},
                        },
                        "required": ["query"],
                    },
                },
            },
        ],
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful assistant with access to a database.",
            },
            {
                "role": "user",
                "content": "Find information about Python programming",
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "search123",
                        "type": "function",
                        "function": {
                            "name": "search_database",
                            "arguments": {"query": "Python programming"},
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "search123",
                "name": "search_database",
                "content": "Python is a high-level programming language known for its simplicity.",
            },
            {
                "role": "assistant",
                "content": "Based on the database search, Python is a high-level programming language known for its simplicity and readability.",
            },
        ],
    }
    res = strategy.tokenize_prompt(system_tool_calling)
    # Validation
    assert len(res["input_ids"]) > 0
    assert len(res["labels"]) == len(res["input_ids"])
    decoded = magistral_tokenizer.decode(res["input_ids"])
    # System prompt renders before the tool definitions.
    assert (
        '<s>[SYSTEM_PROMPT]You are a helpful assistant with access to a database.[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "search_database", "description": "Search for information in database", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}}, "required": ["query"]}}}][/AVAILABLE_TOOLS]'
        in decoded
    )

    # Test error handling - missing tool response
    incomplete_tool_calling = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_time",
                    "description": "Get current time",
                    "parameters": {"type": "object", "properties": {}},
                },
            },
        ],
        "messages": [
            {
                "role": "user",
                "content": "What time is it?",
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "time12345",
                        "type": "function",
                        "function": {
                            "name": "get_time",
                            "arguments": {},
                        },
                    }
                ],
            },
            {
                "role": "assistant",
                "content": "The current time is 12:00 PM.",
            },
        ],
    }
    from mistral_common.exceptions import InvalidMessageStructureException

    # Fix: the original wrapped this in a bare try/except, which silently
    # passed when no exception was raised. pytest.raises asserts that the
    # exception actually occurs with the expected message.
    with pytest.raises(
        InvalidMessageStructureException,
        match="Not the same number of function calls and responses",
    ):
        strategy.tokenize_prompt(incomplete_tool_calling)
@pytest.mark.skip(reason="TODO, fix for new HF wrapper call")
def test_magistral_tokenizer_call_method(
    magistral_tokenizer: "HFMistralTokenizer", llama3_tokenizer: "PreTrainedTokenizer"
):
    """Test the __call__ method behavior matches HuggingFace standards

    Compares the wrapper against a real HF tokenizer for return structure
    and shapes (token ids themselves differ between vocabularies, so only
    types/shapes/masks are compared).
    """
    from copy import deepcopy

    import numpy as np
    import torch

    # Copy so mutating pad_token doesn't leak into the shared fixture.
    hf_tokenizer = deepcopy(llama3_tokenizer)
    hf_tokenizer.pad_token = hf_tokenizer.eos_token
    test_text = "Hello, how are you?"
    batch_texts = ["Hello world", "How are you?"]

    # Test single string with return_tensors=None
    hf_result: dict[str, list[int]] = hf_tokenizer(test_text, return_tensors=None)
    mistral_result: dict[str, list[int]] = magistral_tokenizer(
        test_text, return_tensors=None
    )
    assert isinstance(mistral_result, dict)
    assert set(mistral_result.keys()) == {"input_ids", "attention_mask"}
    assert isinstance(mistral_result["input_ids"], type(hf_result["input_ids"]))  # list
    assert isinstance(
        mistral_result["attention_mask"], type(hf_result["attention_mask"])
    )
    assert len(mistral_result["input_ids"]) == len(mistral_result["attention_mask"])
    assert np.all(mistral_result["attention_mask"])
    assert len(np.array(mistral_result["input_ids"]).shape) == 1  # 1D array

    # Test single string with return_tensors='pt'
    hf_result_pt: dict[str, torch.Tensor] = hf_tokenizer(test_text, return_tensors="pt")
    mistral_result_pt: dict[str, torch.Tensor] = magistral_tokenizer(
        test_text, return_tensors="pt"
    )
    # Check structure and types
    assert isinstance(mistral_result_pt["input_ids"], torch.Tensor)
    assert isinstance(mistral_result_pt["attention_mask"], torch.Tensor)
    # Check shapes match (don't compare token dimension)
    assert len(hf_result_pt["input_ids"].shape) == len(
        mistral_result_pt["input_ids"].shape
    )
    assert hf_result_pt["input_ids"].shape[0] == mistral_result_pt["input_ids"].shape[0]
    assert (
        mistral_result_pt["attention_mask"].shape
        == mistral_result_pt["input_ids"].shape
    )
    assert torch.all(mistral_result_pt["attention_mask"] == 1)

    # Test batch input with padding
    hf_batch: dict[str, torch.Tensor] = hf_tokenizer(
        batch_texts, return_tensors="pt", padding=True
    )
    mistral_batch: dict[str, torch.Tensor] = magistral_tokenizer(
        batch_texts, return_tensors="pt", padding=True
    )
    # Check batch behavior
    assert len(hf_batch["input_ids"].shape) == len(mistral_batch["input_ids"].shape)
    assert hf_batch["input_ids"].shape[0] == mistral_batch["input_ids"].shape[0]
    assert mistral_batch["attention_mask"].shape == mistral_batch["input_ids"].shape
    assert torch.any(
        mistral_batch["attention_mask"][0] == 0
    )  # padding in shorter sequence
    assert torch.all(
        mistral_batch["attention_mask"][1] == 1
    )  # no padding in longer sequence

    # Test numpy tensors
    mistral_result_np: dict[str, np.ndarray] = magistral_tokenizer(
        test_text, return_tensors="np"
    )
    assert isinstance(mistral_result_np["input_ids"], np.ndarray)
    assert isinstance(mistral_result_np["attention_mask"], np.ndarray)

    # Test consistency with encode()
    encoded: list[int] = magistral_tokenizer.encode(test_text, add_special_tokens=True)
    called: dict[str, torch.Tensor] = magistral_tokenizer(
        test_text, return_tensors="pt"
    )
    assert encoded == called["input_ids"][0].tolist()

    # Test Error handling
    with pytest.raises(ValueError, match="Unsupported kwargs"):
        magistral_tokenizer(test_text, unsupported_param=True)
    with pytest.raises(
        ValueError, match="return_tensors='pt' or 'np' requires padding or truncation"
    ):
        magistral_tokenizer(batch_texts, return_tensors="pt")
if __name__ == "__main__":
    # Fix: `unittest.main()` collects only TestCase subclasses, and this
    # module contains none — running the file directly reported
    # "Ran 0 tests". Invoke pytest so the parametrized test functions
    # (and their fixtures) are actually collected and run.
    raise SystemExit(pytest.main([__file__]))