Feat: add Magistral Small 2509 and native mistral3 tokenizer support (#3165)
* feat: update mistral common
* feat: add mistral3processor
* fix: loading
* fix: cast pixel_values to fp32
* fix: image tensor conversion
* feat: add FA2 support for pixtral based models
* fix: update mistral small 3.1 to use native tokenizer
* fix: install tips
* fix: improve info on sample dataset files
* chore: move mistral configs into subfolders
* fix: remove unneeded patch
* fix: indent
* feat: add integration tests
* chore: move
* feat: add magistral 2509 docs and example
* fix: convert tensor to bool
* feat: expand tests
* chore: move tests
This commit is contained in:
35
tests/monkeypatch/test_mistral_tokenizer_patch.py
Normal file
35
tests/monkeypatch/test_mistral_tokenizer_patch.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Integration tests for MistralCommonTokenizer patches."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestMistralTokenizerPatchIntegration:
|
||||
"""Test MistralCommonTokenizer patch integration."""
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_mistral_tokenizer_image_patch(self):
|
||||
"""Test that MistralCommonTokenizer image patch can be applied."""
|
||||
try:
|
||||
from transformers.tokenization_mistral_common import MistralCommonTokenizer
|
||||
except ImportError:
|
||||
pytest.skip("MistralCommonTokenizer not available")
|
||||
|
||||
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
|
||||
apply_mistral_tokenizer_image_patch,
|
||||
)
|
||||
|
||||
# Store original method
|
||||
original_apply_chat_template = MistralCommonTokenizer.apply_chat_template
|
||||
|
||||
# Apply patch
|
||||
apply_mistral_tokenizer_image_patch()
|
||||
|
||||
# Verify patch was applied
|
||||
assert (
|
||||
MistralCommonTokenizer.apply_chat_template != original_apply_chat_template
|
||||
), "apply_chat_template was not patched"
|
||||
|
||||
# Verify the method is still callable
|
||||
assert callable(MistralCommonTokenizer.apply_chat_template), (
|
||||
"Patched method is not callable"
|
||||
)
|
||||
77
tests/monkeypatch/test_pixtral_flash_attention_patch.py
Normal file
77
tests/monkeypatch/test_pixtral_flash_attention_patch.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Integration tests for Pixtral Flash Attention patches."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
class TestPixtralFlashAttentionPatchIntegration:
|
||||
"""Test Pixtral Flash Attention patch integration."""
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_pixtral_flash_attention_patch(self):
|
||||
"""Test that Pixtral Flash Attention patch can be applied and works correctly."""
|
||||
try:
|
||||
from transformers import modeling_flash_attention_utils
|
||||
except ImportError:
|
||||
pytest.skip("Flash Attention utils not available")
|
||||
|
||||
from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (
|
||||
apply_patch_is_packed_sequence,
|
||||
)
|
||||
|
||||
# Store original method
|
||||
original_is_packed_sequence = modeling_flash_attention_utils._is_packed_sequence
|
||||
|
||||
# Apply patch and get unpatch function
|
||||
unpatch_fn = apply_patch_is_packed_sequence()
|
||||
|
||||
# Verify patch was applied
|
||||
assert (
|
||||
modeling_flash_attention_utils._is_packed_sequence
|
||||
!= original_is_packed_sequence
|
||||
), "_is_packed_sequence was not patched"
|
||||
|
||||
# Test the patched function with 1D position_ids
|
||||
patched_fn = modeling_flash_attention_utils._is_packed_sequence
|
||||
|
||||
# Test 1D position_ids 1 sequence
|
||||
position_ids_1d = torch.tensor([0, 1, 2, 3])
|
||||
result = patched_fn(position_ids_1d, batch_size=1)
|
||||
assert isinstance(result, bool), "Function should return a boolean"
|
||||
assert result is False, "1D sequential position_ids should not be packed"
|
||||
|
||||
# Test 1D packed 2 sequences
|
||||
position_ids_1d_packed = torch.tensor([0, 1, 2, 0, 1, 2])
|
||||
result = patched_fn(position_ids_1d_packed, batch_size=1)
|
||||
assert isinstance(result, bool), "Function should return a boolean"
|
||||
assert result is True, "1D packed position_ids should be detected as packed"
|
||||
|
||||
# Test 2D packed 2 sequences
|
||||
position_ids_2d_packed = torch.tensor([[0, 1, 2, 3, 0, 1]])
|
||||
result = patched_fn(position_ids_2d_packed, batch_size=1)
|
||||
assert isinstance(result, bool), "Function should return a boolean"
|
||||
assert result is True, "2D packed position_ids should be detected as packed"
|
||||
|
||||
# Test 2D 1 sequence
|
||||
position_ids_2d_normal = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||
result = patched_fn(position_ids_2d_normal, batch_size=1)
|
||||
assert isinstance(result, bool), "Function should return a boolean"
|
||||
assert result is False, "2D sequential position_ids should not be packed"
|
||||
|
||||
# Test 2D batch size 2
|
||||
position_ids_2d_normal = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]])
|
||||
result = patched_fn(position_ids_2d_normal, batch_size=2)
|
||||
assert isinstance(result, bool), "Function should return a boolean"
|
||||
assert result is False, "2D position_ids batch 2 should not be packed"
|
||||
|
||||
# Test None case
|
||||
result = patched_fn(None, batch_size=1)
|
||||
assert isinstance(result, bool), "Function should return a boolean"
|
||||
assert result is False, "None position_ids should return False"
|
||||
|
||||
# Test unpatch function
|
||||
unpatch_fn()
|
||||
assert (
|
||||
modeling_flash_attention_utils._is_packed_sequence
|
||||
== original_is_packed_sequence
|
||||
), "unpatch function did not restore original method"
|
||||
43
tests/monkeypatch/test_voxtral_modeling_patch.py
Normal file
43
tests/monkeypatch/test_voxtral_modeling_patch.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Integration tests for Voxtral modeling patches."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestVoxtralModelingPatchIntegration:
|
||||
"""Test Voxtral modeling patch integration."""
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_voxtral_conditional_generation_patch(self):
|
||||
"""Test that Voxtral conditional generation patch can be applied."""
|
||||
try:
|
||||
from transformers.models.voxtral.modeling_voxtral import (
|
||||
VoxtralForConditionalGeneration,
|
||||
)
|
||||
except ImportError:
|
||||
pytest.skip("VoxtralForConditionalGeneration not available")
|
||||
|
||||
from axolotl.monkeypatch.models.voxtral.modeling import (
|
||||
patch_voxtral_conditional_generation_forward,
|
||||
)
|
||||
|
||||
# Store original method
|
||||
original_forward = VoxtralForConditionalGeneration.forward
|
||||
|
||||
# Apply patch and get unpatch function
|
||||
unpatch_fn = patch_voxtral_conditional_generation_forward()
|
||||
|
||||
# Verify patch was applied
|
||||
assert VoxtralForConditionalGeneration.forward != original_forward, (
|
||||
"forward method was not patched"
|
||||
)
|
||||
|
||||
# Verify the method is still callable
|
||||
assert callable(VoxtralForConditionalGeneration.forward), (
|
||||
"Patched method is not callable"
|
||||
)
|
||||
|
||||
# Test unpatch function
|
||||
unpatch_fn()
|
||||
assert VoxtralForConditionalGeneration.forward == original_forward, (
|
||||
"unpatch function did not restore original method"
|
||||
)
|
||||
Reference in New Issue
Block a user