From 86ca1e27c083d0ab22063630aa9bea02085714c1 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 23 Feb 2026 23:39:13 +0700 Subject: [PATCH] fix: update MistralProcessor to be v5 compat (#3423) * fix: update MistralProcessor to be v5 compat * feat: add test for mistral3 processor * chore: comment --- .../utils/mistral/mistral3_processor.py | 12 +- tests/utils/test_mistral3_processor.py | 149 ++++++++++++++++++ 2 files changed, 150 insertions(+), 11 deletions(-) create mode 100644 tests/utils/test_mistral3_processor.py diff --git a/src/axolotl/utils/mistral/mistral3_processor.py b/src/axolotl/utils/mistral/mistral3_processor.py index 01e8f9f10..03a155e53 100644 --- a/src/axolotl/utils/mistral/mistral3_processor.py +++ b/src/axolotl/utils/mistral/mistral3_processor.py @@ -30,18 +30,8 @@ class Mistral3Processor(ProcessorMixin): Wraps HFMistralTokenizer and adds image processing capabilities. """ - # TODO(nano): This should be removed in transformers V5 - attributes = ["tokenizer"] - tokenizer_class = "HFMistralTokenizer" - def __init__(self, tokenizer: HFMistralTokenizer): - # Don't call super().__init__ to avoid the class validation issue - self.tokenizer = tokenizer - - @property - def chat_template(self) -> None: - """Chat template is not supported. Dummy method to satisfy HuggingFace API.""" - return None + super().__init__(tokenizer) @property def audio_tokenizer(self) -> None: diff --git a/tests/utils/test_mistral3_processor.py b/tests/utils/test_mistral3_processor.py new file mode 100644 index 000000000..ae2bc1faf --- /dev/null +++ b/tests/utils/test_mistral3_processor.py @@ -0,0 +1,149 @@ +"""Tests for Mistral3Processor with transformers v5 ProcessorMixin integration""" + +from unittest.mock import MagicMock + +import pytest +import torch +from transformers.feature_extraction_utils import BatchFeature + +from axolotl.utils.mistral.mistral3_processor import Mistral3Processor +from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer + + +@pytest.fixture() +def mock_tokenizer(): + """Create a mock HFMistralTokenizer that passes v5 ProcessorMixin isinstance checks.""" + return MagicMock(spec=HFMistralTokenizer) + + +@pytest.fixture() +def processor(mock_tokenizer): + return Mistral3Processor(tokenizer=mock_tokenizer) + + +class TestMistral3ProcessorInit: + def test_tokenizer_is_set(self, processor, mock_tokenizer): + assert processor.tokenizer is mock_tokenizer + + def test_chat_template_is_none(self, processor): + assert processor.chat_template is None + + def test_audio_tokenizer_is_none(self, processor): + assert processor.audio_tokenizer is None + + +class TestApplyChatTemplateTokenized: + """Test apply_chat_template with tokenize=True, return_dict=True""" + + @pytest.fixture() + def batched_conversations(self): + return [ + [ + {"role": "user", "content": "Describe this image."}, + {"role": "assistant", "content": "It is red."}, + ], + [ + {"role": "user", "content": "What is this?"}, + {"role": "assistant", "content": "A cat."}, + ], + ] + + def test_returns_batch_feature_with_pixel_values( + self, processor, mock_tokenizer, batched_conversations + ): + pixel_values = torch.randn(2, 3, 224, 224, dtype=torch.float64) + mock_tokenizer.apply_chat_template.return_value = { + "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]), + "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1]]), + "pixel_values": pixel_values, + } + + result = processor.apply_chat_template( + batched_conversations, tokenize=True, return_dict=True + ) + + assert isinstance(result, BatchFeature) + assert "pixel_values" in result + assert "image_sizes" in result + assert result["pixel_values"].dtype == torch.float32 + assert result["image_sizes"].shape == (2, 2) + assert result["image_sizes"][0].tolist() == [224, 224] + + def test_returns_batch_feature_without_pixel_values( + self, processor, mock_tokenizer, batched_conversations + ): + mock_tokenizer.apply_chat_template.return_value = { + "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]), + "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1]]), + } + + result = processor.apply_chat_template( + batched_conversations, tokenize=True, return_dict=True + ) + + assert isinstance(result, BatchFeature) + assert "input_ids" in result + assert "image_sizes" not in result + + +class TestApplyChatTemplateNotTokenized: + def test_single_conversation_returns_unwrapped(self, processor, mock_tokenizer): + """Single conversation (not batched) should return unwrapped result.""" + single_conversation = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + ] + mock_tokenizer.apply_chat_template.return_value = [ + "[INST]Hello[/INST]Hi" + ] + + result = processor.apply_chat_template( + single_conversation, tokenize=False, return_dict=False + ) + + assert result == "[INST]Hello[/INST]Hi" + + def test_batched_conversations_returns_list(self, processor, mock_tokenizer): + batched = [ + [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + ], + [ + {"role": "user", "content": "Bye"}, + {"role": "assistant", "content": "Bye"}, + ], + ] + mock_tokenizer.apply_chat_template.return_value = ["text1", "text2"] + + result = processor.apply_chat_template( + batched, tokenize=False, return_dict=False + ) + + assert result == ["text1", "text2"] + + +class TestCall: + def test_delegates_to_tokenizer(self, processor, mock_tokenizer): + mock_tokenizer.return_value = { + "input_ids": [1, 2, 3], + "attention_mask": [1, 1, 1], + } + + result = processor("Hello world") + + mock_tokenizer.assert_called_once() + assert isinstance(result, BatchFeature) + + +class TestReturnTensorsValidation: + def test_rejects_non_pt_return_tensors(self, processor): + conversation = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + ] + + with pytest.raises(ValueError, match=r"only supports.*return_tensors='pt'"): + processor.apply_chat_template( + conversation, tokenize=True, return_dict=True, return_tensors="np" + )