Feat: add Magistral Small 2509 and native mistral3 tokenizer support (#3165)
* feat: update mistral common
* feat: add mistral3processor
* fix: loading
* fix: cast pixel_values to fp32
* fix: image tensor conversion
* feat: add FA2 support for pixtral based models
* fix: update mistral small 3.1 to use native tokenizer
* fix: install tips
* fix: improve info on sample dataset files
* chore: move mistral configs into subfolders
* fix: remove unneeded patch
* fix: indent
* feat: add integration tests
* chore: move
* feat: add magistral 2509 docs and example
* fix: convert tensor to bool
* feat: expand tests
* chore: move tests
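For context, the user-facing switch this commit wires together can be sketched as a minimal config. The field names (`model_config_type`, `processor_type`, `tokenizer_use_mistral_common`, `tokenizer_config`) are taken from the diff below; the values and the use of `DictDefault` are illustrative, not a tested training config:

from axolotl.utils.dict import DictDefault

# Hypothetical config enabling the native mistral-common tokenizer and the
# new Mistral3Processor path; values are placeholders, not a verified recipe.
cfg = DictDefault(
    {
        "base_model": "mistralai/Magistral-Small-2509",
        "model_config_type": "mistral3",
        "processor_type": "Mistral3Processor",
        "tokenizer_use_mistral_common": True,
        "tokenizer_config": "mistralai/Magistral-Small-2509",
    }
)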
@@ -168,6 +168,13 @@ class PatchManager:
        patch_llama4_linearized_modeling()

        if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type:
            from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
                apply_mistral_tokenizer_image_patch,
            )

            apply_mistral_tokenizer_image_patch()

    def _apply_fp8_patches(self):
        """Apply patches for FP8 support."""
        if self.cfg.fp8:
@@ -334,6 +341,13 @@ class PatchManager:
        replace_stablelm_attn_with_flash_attn(self.cfg.base_model)

        if self.model_config.model_type in ("mistral3", "llava"):
            from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (
                apply_patch_is_packed_sequence,
            )

            apply_patch_is_packed_sequence()

    def _patch_loss_llama(self):
        """Patch loss functions and other optimizations for LLaMA models."""
        if not self.cfg.is_llama_derived_model:
@@ -21,6 +21,13 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
    if cfg.processor_type:
        processor_cls = getattr(transformers, cfg.processor_type)

        if cfg.tokenizer_use_mistral_common:
            from axolotl.utils.mistral import Mistral3Processor

            return Mistral3Processor(
                tokenizer=tokenizer,
            )

        processor = processor_cls.from_pretrained(
            cfg.processor_config,
            trust_remote_code=cfg.trust_remote_code or False,
@@ -124,13 +124,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
def _load_mistral_common_tokenizer(cfg: DictDefault):
    """Load mistral-common tokenizer"""
    from transformers import tokenization_mistral_common

    from axolotl.utils.mistral import HFMistralTokenizer

    # patch
    tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer

    # Load the HF-compatible wrapper around MistralTokenizer
    tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
src/axolotl/monkeypatch/models/mistral3/__init__.py (new file, 0 lines)
@@ -0,0 +1,85 @@
"""
Monkeypatch to fix inefficient tensor conversion in MistralCommonTokenizer.apply_chat_template
"""

import importlib
import inspect

from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def apply_mistral_tokenizer_image_patch():
    """Apply patch to MistralCommonTokenizer.apply_chat_template to fix image tensor conversion."""
    from transformers.tokenization_mistral_common import MistralCommonTokenizer

    # Get original source
    original_source = inspect.getsource(MistralCommonTokenizer.apply_chat_template)
    original_source, _ = detab_code(original_source)

    # Define the replacement
    original_tensor_conversion = (
        "            pixel_values = torch.tensor(images)"
    )

    patched_tensor_conversion = """            if isinstance(images, list) and len(images) > 0 and isinstance(images[0], np.ndarray):
                pixel_values = torch.tensor(np.array(images))
            else:
                pixel_values = torch.tensor(images)"""

    # Apply the replacement
    if original_tensor_conversion in original_source:
        patched_source = original_source.replace(
            original_tensor_conversion, patched_tensor_conversion
        )
        patched_source = patched_source.replace(
            "def apply_chat_template(",
            "def patched_apply_chat_template(",
            1,
        )

        # Load necessary imports from the module
        module_name = MistralCommonTokenizer.__module__
        module = importlib.import_module(module_name)

        # Detect what needs to be imported
        items_to_import = []
        for item in dir(module):
            if item in patched_source and not item.startswith("_"):
                items_to_import.append(item)

        # Execute imports in global scope
        if items_to_import:
            exec(  # nosec B102
                f"from {module_name} import ({', '.join(items_to_import)})",
                globals(),
            )

        # Also need standard imports that might be used
        exec("import numpy as np", globals())  # nosec B102
        exec("import torch", globals())  # nosec B102
        exec("from typing import Union, Optional, List, Dict, Any, Callable", globals())  # nosec B102
        exec("from pathlib import Path", globals())  # nosec B102

        # Import other dependencies that might be needed
        try:
            exec("from transformers.utils import is_torch_available", globals())  # nosec B102
            exec(
                "from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TensorType",
                globals(),
            )  # nosec B102
            exec("from transformers.utils import logging", globals())  # nosec B102
            exec("logger = logging.get_logger(__name__)", globals())  # nosec B102
        except ImportError as e:
            LOG.warning(f"Could not import some dependencies: {e}")

        # Execute the patched source
        exec(patched_source, globals())  # nosec B102

        # Replace the method
        MistralCommonTokenizer.apply_chat_template = patched_apply_chat_template
        LOG.info("Successfully applied MistralCommonTokenizer tensor conversion patch")
    else:
        LOG.warning("Could not find target code for MistralCommonTokenizer patching")
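For reference, the inefficiency this patch targets is calling `torch.tensor` on a plain Python list of `np.ndarray`s, which converts element by element (and warns in recent torch versions); routing through a single `np.array` first is one vectorized copy. A standalone sketch, not part of the commit:

import time

import numpy as np
import torch

images = [np.random.rand(3, 64, 64).astype(np.float32) for _ in range(256)]

t0 = time.perf_counter()
slow = torch.tensor(images)  # element-wise conversion from a list of ndarrays
t1 = time.perf_counter()
fast = torch.tensor(np.array(images))  # single vectorized copy, as in the patch
t2 = time.perf_counter()

assert torch.equal(slow, fast)  # same values, very different cost
print(f"list of ndarrays: {t1 - t0:.3f}s, via np.array: {t2 - t1:.3f}s")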
src/axolotl/monkeypatch/models/pixtral/__init__.py (new file, 0 lines)
@@ -0,0 +1,42 @@
"""Monkeypatch for FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid"""

import torch


def apply_patch_is_packed_sequence():
    """Apply patch to FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid"""
    from transformers import modeling_flash_attention_utils

    def fixed_is_packed_sequence(position_ids, batch_size):
        """
        Check the position ids whether packed sequences are indicated or not
        1. Position ids exist
        2. Flattened sequences only are supported
        3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences
        """
        if position_ids is None:
            return False

        if position_ids.ndim == 1:
            position_ids = position_ids.unsqueeze(0)  # [N] -> [1, N]

        increasing_position_sequences = (
            torch.arange(position_ids.shape[1], device=position_ids.device)
            + position_ids.min()
        )
        return (
            batch_size == 1
            and (increasing_position_sequences - position_ids).abs().sum().bool().item()
        )

    # Store original method
    old_fn = modeling_flash_attention_utils._is_packed_sequence

    # Apply the patch
    modeling_flash_attention_utils._is_packed_sequence = fixed_is_packed_sequence

    def unpatch():
        """Restore the original method"""
        modeling_flash_attention_utils._is_packed_sequence = old_fn

    return unpatch
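A quick sanity check of the patched predicate, assuming a transformers version that exposes the private `_is_packed_sequence` helper (the patch itself requires one); illustrative, not part of the commit:

import torch
from transformers import modeling_flash_attention_utils

from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (
    apply_patch_is_packed_sequence,
)

unpatch = apply_patch_is_packed_sequence()

# Two flattened sequences packed into one row: position ids restart at 0,
# so the patched check reports packing. Note the input is 1D, which is
# exactly what Pixtral's position_ids_in_meshgrid produces.
packed = torch.tensor([0, 1, 2, 0, 1, 2, 3])
print(modeling_flash_attention_utils._is_packed_sequence(packed, batch_size=1))  # True

# A single monotonically increasing sequence is not packed.
single = torch.arange(8)
print(modeling_flash_attention_utils._is_packed_sequence(single, batch_size=1))  # False

unpatch()  # restore the original check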
@@ -11,6 +11,7 @@ from transformers.image_utils import load_image
from axolotl.utils.dict import remove_none_values
from axolotl.utils.logging import get_logger
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor

LOG = get_logger(__name__)
@@ -421,6 +422,36 @@ class SmolVLM2ProcessingStrategy(ProcessingStrategy):
        ]


class Mistral3ProcessingStrategy(ProcessingStrategy):
    """Processing Strategy class for Mistral3"""

    def __init__(
        self,
        processor: Mistral3Processor,
        chat_template: Optional[str] = None,
        image_size: int | tuple[int, int] | None = None,
        image_resize_algorithm: Resampling | None = None,
    ):
        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
        special_ids = (
            processor.tokenizer.tokenizer.instruct_tokenizer.image_encoder.special_ids
        )

        self.image_token = special_ids.img
        self.image_break_token = special_ids.img_break
        self.image_end_token = special_ids.img_end

    def process_labels(self, input_ids):
        labels = input_ids.clone()

        labels[labels == self.processor.tokenizer.pad_token_id] = -100
        labels[labels == self.image_token] = -100
        labels[labels == self.image_break_token] = -100
        labels[labels == self.image_end_token] = -100

        return labels


def get_processing_strategy(
    processor: ProcessorMixin,
    chat_template,
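The effect of `process_labels` above can be illustrated with made-up token ids standing in for the image encoder's `special_ids` (the real ids come from the mistral-common tokenizer); illustrative only:

import torch

img, pad, img_break, img_end = 10, 11, 12, 13  # hypothetical special ids
input_ids = torch.tensor([1, img, img, img_break, img, img_end, 5, 6, pad])

labels = input_ids.clone()
for special in (pad, img, img_break, img_end):
    labels[labels == special] = -100  # masked out of the loss

print(labels)  # roughly: tensor([1, -100, -100, -100, -100, -100, 5, 6, -100])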
@@ -463,6 +494,11 @@ def get_processing_strategy(
            **processing_kwargs,
        )

    if isinstance(processor, Mistral3Processor):
        return Mistral3ProcessingStrategy(
            **processing_kwargs,
        )

    # llama3_2_vision, llama4, llava
    # mistral_v7_tekken, pixtral, lfm2vl
    return ProcessingStrategy(
@@ -1,5 +1,6 @@
"""Init for `axolotl.utils.mistral` module."""

from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer

__all__ = ["HFMistralTokenizer"]
__all__ = ["HFMistralTokenizer", "Mistral3Processor"]
src/axolotl/utils/mistral/mistral3_processor.py (new file, 169 lines)
@@ -0,0 +1,169 @@
"""Processor for Mistral3 multimodal models with image support"""

from typing import Any, Dict, Optional, Union

import torch
from transformers import ProcessorMixin
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessingKwargs
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer


class Mistral3ProcessorKwargs(ProcessingKwargs):
    _defaults: Dict[str, Dict[str, Any]] = {
        "text_kwargs": {
            "padding": True,
        },
        "common_kwargs": {
            "return_tensors": "pt",
            "return_dict": True,
            "tokenize": True,
        },
    }


class Mistral3Processor(ProcessorMixin):
    """
    Processor for Mistral3 multimodal models that handles text and images.
    Wraps HFMistralTokenizer and adds image processing capabilities.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "HFMistralTokenizer"

    def __init__(self, tokenizer: HFMistralTokenizer):
        # Don't call super().__init__ to avoid the class validation issue
        self.tokenizer = tokenizer

    @property
    def chat_template(self) -> None:
        """Chat template is not supported. Dummy method to satisfy HuggingFace API."""
        return None

    @property
    def audio_tokenizer(self) -> None:
        """Audio tokenizer is not supported. Dummy method to satisfy HuggingFace API."""
        return None

    def _merge_kwargs(
        self, processor_kwargs_class: Any, **kwargs: Any
    ) -> Dict[str, Dict[str, Any]]:
        """Merge kwargs with defaults similar to ProcessorMixin"""
        defaults = processor_kwargs_class._defaults
        output_kwargs: Dict[str, Dict[str, Any]] = {}

        for kwarg_type, default_values in defaults.items():
            output_kwargs[kwarg_type] = {**default_values}

        # Update with provided kwargs
        for key, value in kwargs.items():
            # Try to match key to appropriate kwarg type
            if key in ["padding", "truncation", "max_length"]:
                output_kwargs.setdefault("text_kwargs", {}).update({key: value})
            elif key in ["return_tensors", "return_dict", "tokenize"]:
                output_kwargs.setdefault("common_kwargs", {}).update({key: value})
            else:
                # Add to text_kwargs by default
                output_kwargs.setdefault("text_kwargs", {}).update({key: value})

        return output_kwargs

    def apply_chat_template(
        self,
        conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
        **kwargs: Any,
    ) -> Union[BatchFeature, str, list[str]]:
        """
        Apply chat template with image support for Mistral3.

        Similar to VoxtralProcessor, this method extracts images from the conversation,
        calls the tokenizer's apply_chat_template, then adds pixel_values and image_sizes
        to the result.
        """
        output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs)
        text_kwargs = output_kwargs["text_kwargs"]
        common_kwargs = output_kwargs["common_kwargs"]

        return_tensors = common_kwargs.pop("return_tensors", "pt")
        if return_tensors != "pt":
            raise ValueError(
                f"{self.__class__.__name__} only supports `return_tensors='pt'`."
            )

        return_dict = common_kwargs.pop("return_dict", False)
        tokenize = common_kwargs.pop("tokenize", False)

        # Determine if batched
        if isinstance(conversation, (list, tuple)) and (
            isinstance(conversation[0], (list, tuple))
            or hasattr(conversation[0], "content")
        ):
            is_batched = True
            conversations = conversation
        else:
            is_batched = False
            conversations = [conversation]  # type: ignore

        # Call tokenizer's apply_chat_template
        tokenizer_kwargs = {**text_kwargs, **common_kwargs}
        tokenizer_kwargs["return_tensors"] = return_tensors
        tokenizer_kwargs["tokenize"] = tokenize
        tokenizer_kwargs["return_dict"] = return_dict

        encoded_instruct_inputs = self.tokenizer.apply_chat_template(
            conversations,
            **tokenizer_kwargs,
        )

        if tokenize:
            if return_dict:
                # The tokenizer already handles pixel_values, we just need to add image_sizes
                if hasattr(encoded_instruct_inputs, "items"):
                    data: Dict[str, Any] = dict(encoded_instruct_inputs)  # type: ignore
                elif hasattr(encoded_instruct_inputs, "data"):
                    data = encoded_instruct_inputs.data  # type: ignore
                else:
                    raise ValueError("Unknown data type")

                if "pixel_values" in data:
                    pixel_values = data["pixel_values"]

                    # MistralTokenizer returns a Double, so we convert to fp32
                    data["pixel_values"] = pixel_values.to(dtype=torch.float32)

                    # Always batched: [B, C, H, W] -> image_sizes: [B, 2]
                    # Since tensor is homogeneous, all images have same H, W
                    batch_size = pixel_values.shape[0]
                    image_sizes = torch.tensor([pixel_values.shape[-2:]] * batch_size)
                    data["image_sizes"] = image_sizes

                return BatchFeature(data=data, tensor_type=return_tensors)

            if not is_batched:
                return encoded_instruct_inputs[0]

        return encoded_instruct_inputs

    def __call__(
        self,
        text: Optional[
            Union[
                TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]
            ]
        ],
        **kwargs: Any,
    ) -> BatchFeature:
        """
        Forward text processing to the tokenizer.
        This method does not support images - use apply_chat_template instead.
        """
        output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs)
        text_kwargs = output_kwargs["text_kwargs"]
        common_kwargs = output_kwargs["common_kwargs"]

        out = self.tokenizer(text, **text_kwargs)
        return BatchFeature(
            data=out, tensor_type=common_kwargs.pop("return_tensors", None)
        )
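A minimal usage sketch of the new processor; the checkpoint id and the message schema are assumptions based on this diff and mistral-common conventions, not a verified example:

from axolotl.utils.mistral import HFMistralTokenizer, Mistral3Processor

tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Magistral-Small-2509")
processor = Mistral3Processor(tokenizer=tokenizer)

# A single (unbatched) conversation with one image; the image URL is a placeholder.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/cat.png"},
            {"type": "text", "text": "Describe the image."},
        ],
    },
]

batch = processor.apply_chat_template(conversation, tokenize=True, return_dict=True)
# pixel_values arrive as float32 (cast from the tokenizer's float64 output),
# with image_sizes derived from the homogeneous [B, C, H, W] tensor.
print(batch["input_ids"].shape, batch["pixel_values"].dtype, batch["image_sizes"])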