"""Tests for Mistral3Processor with transformers v5 ProcessorMixin integration"""
from unittest.mock import MagicMock
import pytest
import torch
from transformers.feature_extraction_utils import BatchFeature
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer
@pytest.fixture()
def mock_tokenizer():
    """Build a MagicMock standing in for HFMistralTokenizer.

    Using ``spec=HFMistralTokenizer`` makes the mock satisfy the isinstance
    checks that the transformers v5 ProcessorMixin runs on its tokenizer.
    """
    tokenizer = MagicMock(spec=HFMistralTokenizer)
    return tokenizer
@pytest.fixture()
def processor(mock_tokenizer):
    """Instantiate a Mistral3Processor wired to the mocked tokenizer."""
    proc = Mistral3Processor(tokenizer=mock_tokenizer)
    return proc
class TestMistral3ProcessorInit:
    """Sanity checks on the attributes of a freshly constructed processor."""

    def test_tokenizer_is_set(self, processor, mock_tokenizer):
        """The processor must hold the exact tokenizer instance it was given."""
        assert processor.tokenizer is mock_tokenizer

    def test_chat_template_is_none(self, processor):
        """No chat template is configured by default."""
        template = processor.chat_template
        assert template is None

    def test_audio_tokenizer_is_none(self, processor):
        """No audio tokenizer is configured by default."""
        audio = processor.audio_tokenizer
        assert audio is None
class TestApplyChatTemplateTokenized:
    """Test apply_chat_template with tokenize=True, return_dict=True"""

    @pytest.fixture()
    def batched_conversations(self):
        """Two short user/assistant exchanges, packaged as a batch."""
        first = [
            {"role": "user", "content": "Describe this image."},
            {"role": "assistant", "content": "It is red."},
        ]
        second = [
            {"role": "user", "content": "What is this?"},
            {"role": "assistant", "content": "A cat."},
        ]
        return [first, second]

    def test_returns_batch_feature_with_pixel_values(
        self, processor, mock_tokenizer, batched_conversations
    ):
        """Pixel values are cast to float32 and image_sizes is derived from them."""
        images = torch.randn(2, 3, 224, 224, dtype=torch.float64)
        mock_tokenizer.apply_chat_template.return_value = {
            "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
            "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1]]),
            "pixel_values": images,
        }

        result = processor.apply_chat_template(
            batched_conversations, tokenize=True, return_dict=True
        )

        assert isinstance(result, BatchFeature)
        assert "pixel_values" in result
        assert "image_sizes" in result
        # float64 mock input should come back downcast to float32
        assert result["pixel_values"].dtype == torch.float32
        # one (height, width) pair per image in the batch
        assert result["image_sizes"].shape == (2, 2)
        assert result["image_sizes"][0].tolist() == [224, 224]

    def test_returns_batch_feature_without_pixel_values(
        self, processor, mock_tokenizer, batched_conversations
    ):
        """Text-only batches still come back as a BatchFeature, without image keys."""
        mock_tokenizer.apply_chat_template.return_value = {
            "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
            "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1]]),
        }

        result = processor.apply_chat_template(
            batched_conversations, tokenize=True, return_dict=True
        )

        assert isinstance(result, BatchFeature)
        assert "input_ids" in result
        assert "image_sizes" not in result
class TestApplyChatTemplateNotTokenized:
    """Test apply_chat_template with tokenize=False, return_dict=False."""

    def test_single_conversation_returns_unwrapped(self, processor, mock_tokenizer):
        """Single conversation (not batched) should return unwrapped result."""
        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi"},
        ]
        # Tokenizer always returns a list; the processor unwraps it for a
        # single conversation.
        mock_tokenizer.apply_chat_template.return_value = ["[INST]Hello[/INST]Hi"]

        result = processor.apply_chat_template(
            conversation, tokenize=False, return_dict=False
        )

        assert result == "[INST]Hello[/INST]Hi"

    def test_batched_conversations_returns_list(self, processor, mock_tokenizer):
        """A batch of conversations yields one rendered string per conversation."""
        conversations = [
            [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi"},
            ],
            [
                {"role": "user", "content": "Bye"},
                {"role": "assistant", "content": "Bye"},
            ],
        ]
        mock_tokenizer.apply_chat_template.return_value = ["text1", "text2"]

        result = processor.apply_chat_template(
            conversations, tokenize=False, return_dict=False
        )

        assert result == ["text1", "text2"]
class TestCall:
    """Direct __call__ should proxy through to the underlying tokenizer."""

    def test_delegates_to_tokenizer(self, processor, mock_tokenizer):
        """Calling the processor forwards once to the tokenizer and wraps the output."""
        mock_tokenizer.return_value = {
            "input_ids": [1, 2, 3],
            "attention_mask": [1, 1, 1],
        }

        result = processor("Hello world")

        mock_tokenizer.assert_called_once()
        assert isinstance(result, BatchFeature)
class TestReturnTensorsValidation:
    """return_tensors must be 'pt'; anything else is rejected."""

    def test_rejects_non_pt_return_tensors(self, processor):
        """Requesting numpy tensors raises a descriptive ValueError."""
        conversation = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi"},
        ]
        expected = r"only supports.*return_tensors='pt'"
        with pytest.raises(ValueError, match=expected):
            processor.apply_chat_template(
                conversation, tokenize=True, return_dict=True, return_tensors="np"
            )