Feat: add Magistral Small 2509 and native mistral3 tokenizer support (#3165)

* feat: update mistral common

* feat: add mistral3processor

* fix: loading

* fix: cast pixel_values to fp32

* fix: image tensor conversion

* feat: add FA2 support for pixtral based models

* fix: update mistral small 3.1 to use native tokenizer

* fix: install tips

* fix: improve info on sample dataset files

* chore: move mistral configs into subfolders

* fix: remove unneeded patch

* fix: indent

* feat: add integration tests

* chore: move

* feat: add magistral 2509 docs and example

* fix: convert tensor to bool

* feat: expand tests

* chore: move tests
This commit is contained in:
NanoCode012
2025-09-18 15:42:20 +07:00
committed by GitHub
parent 4065bc14c6
commit 09959fac70
32 changed files with 757 additions and 39 deletions

View File

@@ -168,6 +168,13 @@ class PatchManager:
patch_llama4_linearized_modeling()
if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type:
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
apply_mistral_tokenizer_image_patch,
)
apply_mistral_tokenizer_image_patch()
def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:
@@ -334,6 +341,13 @@ class PatchManager:
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
if self.model_config.model_type in ("mistral3", "llava"):
from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (
apply_patch_is_packed_sequence,
)
apply_patch_is_packed_sequence()
def _patch_loss_llama(self):
"""Patch loss functions and other optimizations for LLaMA models."""
if not self.cfg.is_llama_derived_model:

View File

@@ -21,6 +21,13 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
if cfg.processor_type:
processor_cls = getattr(transformers, cfg.processor_type)
if cfg.tokenizer_use_mistral_common:
from axolotl.utils.mistral import Mistral3Processor
return Mistral3Processor(
tokenizer=tokenizer,
)
processor = processor_cls.from_pretrained(
cfg.processor_config,
trust_remote_code=cfg.trust_remote_code or False,

View File

@@ -124,13 +124,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
def _load_mistral_common_tokenizer(cfg: DictDefault):
"""Load mistral-common tokenizer"""
from transformers import tokenization_mistral_common
from axolotl.utils.mistral import HFMistralTokenizer
# patch
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
# Load the HF-compatible wrapper around MistralTokenizer
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)

View File

@@ -0,0 +1,85 @@
"""
Monkeypatch to fix inefficient tensor conversion in MistralCommonTokenizer.apply_chat_template
"""
import importlib
import inspect
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def apply_mistral_tokenizer_image_patch():
    """Patch ``MistralCommonTokenizer.apply_chat_template`` to fix image tensor conversion.

    The method's source is read with ``inspect.getsource`` and every statement
    ``pixel_values = torch.tensor(images)`` is rewritten so that a non-empty
    list of ``np.ndarray`` images is first stacked with ``np.array`` before
    being passed to ``torch.tensor``; any other input keeps the original
    conversion. The rewrite reuses the indentation of the matched line, so it
    does not depend on how deeply the statement is nested upstream.

    The rewritten source is compiled inside a *copy* of the tokenizer module's
    namespace (plus ``np``, ``torch`` and a few typing fallbacks) instead of
    this module's globals, so every name the original method could resolve,
    public or private, remains available and nothing leaks into this module.

    If the target statement is not found, a warning is logged and the
    tokenizer is left untouched. Calling this function again after a
    successful patch is a no-op.
    """
    # Function-scope imports: only needed while the patch is being built.
    import re
    import typing
    from pathlib import Path

    import numpy as np
    import torch
    from transformers.tokenization_mistral_common import MistralCommonTokenizer

    # The source of an exec-compiled function is generally not retrievable via
    # `inspect`, so never try to re-patch a method this function already replaced.
    if getattr(
        MistralCommonTokenizer.apply_chat_template, "_axolotl_image_patch", False
    ):
        LOG.debug("MistralCommonTokenizer tensor conversion patch already applied")
        return

    # Get original source
    original_source = inspect.getsource(MistralCommonTokenizer.apply_chat_template)
    original_source, _ = detab_code(original_source)

    # Match the whole target line at any indentation (optional trailing comment).
    target_statement = re.compile(
        r"^(?P<indent>[ \t]*)pixel_values = torch\.tensor\(images\)[ \t]*(?:#.*)?$",
        flags=re.MULTILINE,
    )

    def _guarded_conversion(match):
        """Build the replacement block at the indentation of the matched line."""
        indent = match.group("indent")
        nested = indent + "    "
        return (
            f"{indent}if isinstance(images, list) and len(images) > 0"
            " and isinstance(images[0], np.ndarray):\n"
            f"{nested}pixel_values = torch.tensor(np.array(images))\n"
            f"{indent}else:\n"
            f"{nested}pixel_values = torch.tensor(images)"
        )

    patched_source, replacements = target_statement.subn(
        _guarded_conversion, original_source
    )
    if not replacements:
        LOG.warning("Could not find target code for MistralCommonTokenizer patching")
        return

    # Rename the function so the compiled object can be looked up unambiguously.
    patched_source = patched_source.replace(
        "def apply_chat_template(",
        "def patched_apply_chat_template(",
        1,
    )

    # Compile against a snapshot of the defining module's globals. The typing
    # names and `Path` are fallbacks only; the module's own bindings win.
    module = importlib.import_module(MistralCommonTokenizer.__module__)
    typing_fallbacks = {
        name: getattr(typing, name)
        for name in ("Union", "Optional", "List", "Dict", "Any", "Callable")
    }
    namespace = {
        "Path": Path,
        **typing_fallbacks,
        **vars(module),
        "np": np,
        "torch": torch,
    }

    # Execute the patched source
    exec(patched_source, namespace)  # nosec B102

    # Replace the method and mark it so repeated calls can detect the patch.
    patched_method = namespace["patched_apply_chat_template"]
    patched_method._axolotl_image_patch = True
    MistralCommonTokenizer.apply_chat_template = patched_method
    LOG.info("Successfully applied MistralCommonTokenizer tensor conversion patch")

View File

@@ -0,0 +1,42 @@
"""Monkeypatch for FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid"""
import torch
def apply_patch_is_packed_sequence():
    """Install a `_is_packed_sequence` in the FA utils that also accepts 1D position_ids.

    Pixtral's position_ids_in_meshgrid yields 1D position_ids; the replacement
    promotes them to 2D before running the packed-sequence check.

    Returns:
        A zero-argument callable that puts the previously installed
        `_is_packed_sequence` back.
    """
    from transformers import modeling_flash_attention_utils as fa_utils

    def fixed_is_packed_sequence(position_ids, batch_size):
        """
        Tell whether the position ids indicate packed sequences:
        1. Position ids must exist
        2. Only flattened sequences (batch_size == 1) are supported
        3. The ids must deviate from a single increasing range starting at
           their minimum, i.e. several increasing sequences are present
        """
        if position_ids is None:
            return False

        # [N] -> [1, N]
        ids_2d = position_ids.unsqueeze(0) if position_ids.ndim == 1 else position_ids

        seq_len = ids_2d.shape[1]
        expected_ids = torch.arange(seq_len, device=ids_2d.device) + ids_2d.min()

        single_sequence_batch = batch_size == 1
        return (
            single_sequence_batch
            and (expected_ids - ids_2d).abs().sum().bool().item()
        )

    # Remember what was installed before swapping in the replacement
    previous_impl = fa_utils._is_packed_sequence
    fa_utils._is_packed_sequence = fixed_is_packed_sequence

    def unpatch():
        """Restore the previously installed method"""
        fa_utils._is_packed_sequence = previous_impl

    return unpatch

View File

@@ -11,6 +11,7 @@ from transformers.image_utils import load_image
from axolotl.utils.dict import remove_none_values
from axolotl.utils.logging import get_logger
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
LOG = get_logger(__name__)
@@ -421,6 +422,36 @@ class SmolVLM2ProcessingStrategy(ProcessingStrategy):
]
class Mistral3ProcessingStrategy(ProcessingStrategy):
    """Processing strategy for Mistral3 that keeps pad and image tokens out of the labels"""

    def __init__(
        self,
        processor: Mistral3Processor,
        chat_template: Optional[str] = None,
        image_size: int | tuple[int, int] | None = None,
        image_resize_algorithm: Resampling | None = None,
    ):
        super().__init__(processor, chat_template, image_size, image_resize_algorithm)

        # The special image token ids are read from the wrapped tokenizer's image encoder
        image_encoder = processor.tokenizer.tokenizer.instruct_tokenizer.image_encoder
        image_special_ids = image_encoder.special_ids

        self.image_token = image_special_ids.img
        self.image_break_token = image_special_ids.img_break
        self.image_end_token = image_special_ids.img_end

    def process_labels(self, input_ids):
        """Return a copy of `input_ids` with pad and image token positions set to -100."""
        ignored_token_ids = (
            self.processor.tokenizer.pad_token_id,
            self.image_token,
            self.image_break_token,
            self.image_end_token,
        )

        labels = input_ids.clone()
        for token_id in ignored_token_ids:
            labels[labels == token_id] = -100
        return labels
def get_processing_strategy(
processor: ProcessorMixin,
chat_template,
@@ -463,6 +494,11 @@ def get_processing_strategy(
**processing_kwargs,
)
if isinstance(processor, Mistral3Processor):
return Mistral3ProcessingStrategy(
**processing_kwargs,
)
# llama3_2_vision, llama4, llava
# mistral_v7_tekken, pixtral, lfm2vl
return ProcessingStrategy(

View File

@@ -1,5 +1,6 @@
"""Init for `axolotl.utils.mistral` module."""
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer
__all__ = ["HFMistralTokenizer"]
__all__ = ["HFMistralTokenizer", "Mistral3Processor"]

View File

@@ -0,0 +1,169 @@
"""Processor for Mistral3 multimodal models with image support"""
from typing import Any, Dict, Optional, Union
import torch
from transformers import ProcessorMixin
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessingKwargs
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer
class Mistral3ProcessorKwargs(ProcessingKwargs):
    """Default keyword arguments for `Mistral3Processor`, grouped by kwarg type."""

    # `Mistral3Processor._merge_kwargs` copies each group below and then
    # overlays the keyword arguments supplied at call time.
    _defaults: Dict[str, Dict[str, Any]] = {
        # Defaults forwarded to the tokenizer
        "text_kwargs": {
            "padding": True,
        },
        # Defaults controlling the form of the returned output
        "common_kwargs": {
            "return_tensors": "pt",
            "return_dict": True,
            "tokenize": True,
        },
    }
class Mistral3Processor(ProcessorMixin):
    """
    Text-and-image processor for Mistral3 multimodal models.

    A thin wrapper over HFMistralTokenizer: chat templating is delegated to the
    tokenizer and the returned features are post-processed for images.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "HFMistralTokenizer"

    def __init__(self, tokenizer: HFMistralTokenizer):
        # super().__init__ is intentionally skipped to avoid the class validation issue
        self.tokenizer = tokenizer

    @property
    def chat_template(self) -> None:
        """Always None: chat templates are unsupported; exists to satisfy the HuggingFace API."""
        return None

    @property
    def audio_tokenizer(self) -> None:
        """Always None: audio tokenizers are unsupported; exists to satisfy the HuggingFace API."""
        return None

    def _merge_kwargs(
        self, processor_kwargs_class: Any, **kwargs: Any
    ) -> Dict[str, Dict[str, Any]]:
        """Overlay `kwargs` on copies of the class defaults, similar to ProcessorMixin.

        `return_tensors`, `return_dict` and `tokenize` land in "common_kwargs";
        every other key lands in "text_kwargs".
        """
        merged: Dict[str, Dict[str, Any]] = {
            group: dict(group_defaults)
            for group, group_defaults in processor_kwargs_class._defaults.items()
        }

        common_keys = ("return_tensors", "return_dict", "tokenize")
        for name, value in kwargs.items():
            group = "common_kwargs" if name in common_keys else "text_kwargs"
            merged.setdefault(group, {})[name] = value

        return merged

    def apply_chat_template(
        self,
        conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
        **kwargs: Any,
    ) -> Union[BatchFeature, str, list[str]]:
        """
        Apply the tokenizer's chat template with image support for Mistral3.

        The conversation is wrapped into a batch when needed and handed to the
        tokenizer's apply_chat_template. When tokenizing with `return_dict`, the
        result is returned as a BatchFeature in which `pixel_values` (if present)
        is cast to fp32 and `image_sizes` is added. Otherwise the tokenizer
        output is returned, unwrapped for a non-batched conversation.

        Raises:
            ValueError: if `return_tensors` is not "pt", or if the tokenizer
                output exposes neither `items` nor `data`.
        """
        merged = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs)
        text_kwargs = merged["text_kwargs"]
        common_kwargs = merged["common_kwargs"]

        return_tensors = common_kwargs.pop("return_tensors", "pt")
        if return_tensors != "pt":
            raise ValueError(
                f"{self.__class__.__name__} only supports `return_tensors='pt'`."
            )
        return_dict = common_kwargs.pop("return_dict", False)
        tokenize = common_kwargs.pop("tokenize", False)

        # A list/tuple whose first element is itself a list/tuple (or has a
        # `content` attribute) is treated as a batch of conversations
        is_batched = isinstance(conversation, (list, tuple)) and (
            isinstance(conversation[0], (list, tuple))
            or hasattr(conversation[0], "content")
        )
        conversations = conversation if is_batched else [conversation]  # type: ignore

        # Delegate to the tokenizer's apply_chat_template
        tokenizer_kwargs = {
            **text_kwargs,
            **common_kwargs,
            "return_tensors": return_tensors,
            "tokenize": tokenize,
            "return_dict": return_dict,
        }
        encoded = self.tokenizer.apply_chat_template(
            conversations,
            **tokenizer_kwargs,
        )

        if tokenize and return_dict:
            # pixel_values already come from the tokenizer; only image_sizes is added here
            if hasattr(encoded, "items"):
                data: Dict[str, Any] = dict(encoded)  # type: ignore
            elif hasattr(encoded, "data"):
                data = encoded.data  # type: ignore
            else:
                raise ValueError("Unknown data type")

            if "pixel_values" in data:
                pixel_values = data["pixel_values"]

                # MistralTokenizer returns a Double, so convert to fp32
                data["pixel_values"] = pixel_values.to(dtype=torch.float32)

                # Always batched: [B, C, H, W] -> image_sizes: [B, 2]
                # The tensor is homogeneous, so every image shares the same H, W
                height_width = pixel_values.shape[-2:]
                data["image_sizes"] = torch.tensor(
                    [height_width] * pixel_values.shape[0]
                )

            return BatchFeature(data=data, tensor_type=return_tensors)

        return encoded if is_batched else encoded[0]

    def __call__(
        self,
        text: Optional[
            Union[
                TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]
            ]
        ],
        **kwargs: Any,
    ) -> BatchFeature:
        """
        Tokenize plain text by forwarding it to the tokenizer.

        Images are not handled here - use apply_chat_template instead.
        """
        merged = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs)

        encoded = self.tokenizer(text, **merged["text_kwargs"])
        tensor_type = merged["common_kwargs"].pop("return_tensors", None)
        return BatchFeature(data=encoded, tensor_type=tensor_type)