diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd index d839ce211..413404195 100644 --- a/docs/multimodal.qmd +++ b/docs/multimodal.qmd @@ -13,6 +13,7 @@ format: - [Pixtral](#sec-pixtral) - [Llava-1.5](#sec-llava-15) - [Mistral-Small-3.1](#sec-mistral-small-31) +- [Magistral-Small-2509](#sec-magistral-small-2509) - [Voxtral](#sec-voxtral) - [Gemma-3](#sec-gemma-3) - [Gemma-3n](#sec-gemma-3n) @@ -94,10 +95,22 @@ chat_template: llava ### Mistral-Small-3.1 {#sec-mistral-small-31} +::: {.callout-tip} +Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'` +::: + ```yaml base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503 +``` -chat_template: mistral_v7_tekken +### Magistral-Small-2509 {#sec-magistral-small-2509} + +::: {.callout-tip} +Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'` +::: + +```yaml +base_model: mistralai/Magistral-Small-2509 ``` ### Voxtral {#sec-voxtral} diff --git a/examples/gemma3n/README.md b/examples/gemma3n/README.md index 8c4e02a1d..ff3946c90 100644 --- a/examples/gemma3n/README.md +++ b/examples/gemma3n/README.md @@ -23,7 +23,15 @@ pip3 install timm==1.0.17 pip3 install librosa==0.11.0 ``` -3. Run the finetuning example: +3. Download sample dataset files + +```bash +# for text + vision + audio only +wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg +wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga +``` + +4. Run the finetuning example: ```bash # text only diff --git a/examples/magistral/README.md b/examples/magistral/README.md index f4f278208..a09138744 100644 --- a/examples/magistral/README.md +++ b/examples/magistral/README.md @@ -1,10 +1,10 @@ # Finetune Magistral Small with Axolotl -Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506) and [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. +Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506), [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)), and [2509](https://huggingface.co/mistralai/Magistral-Small-2509) (see [Vision](#vision)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. MistralAI has also released a proprietary medium-sized version called Magistral Medium. -Thanks to the team at MistralAI for giving us early access to prepare for this release. +Thanks to the team at MistralAI for giving us early access to prepare for these releases. ## Getting started @@ -36,29 +36,17 @@ Let us know how it goes. Happy finetuning! 🚀 ### Thinking -MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities. The model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages. +MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps. -Example format: +📚 **[See the Thinking fine-tuning guide →](./think/README.md)** -```json -{ - "messages": [ - {"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]}, - {"role": "user", "content": [{ "type": "text", "text": "..."}]}, - {"role": "assistant", "content": [{ "type": "thinking", "thinking": "..."}, { "type": "text", "text": "..." }]}, - ], -} -``` +### Vision -Example config: `./magistral-small-think-qlora.yaml`. +MistralAI has released their [2509](https://huggingface.co/mistralai/Magistral-Small-2509) model with vision capabilities. -The `thinking` section also supports an optional arg `closed: bool` (`True` default) which controls adding the closing `[/THINK]` tag. +📚 **[See the Vision fine-tuning guide →](./vision/README.md)** -Limitations: -- You cannot mix `content: str` with `content: list[dict]` as the `dataset.load_dataset` may complain about different types for `content` key. -- This mode does not work with custom `train_detail` and `training` at the moment. - -### TIPS +### Tips - We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`. - For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`. @@ -89,5 +77,5 @@ In addition, we do not support overriding tokens yet. ## Future Work -- Add parity to Preference Tuning, RL, Multi-modal, etc. +- Add parity to Preference Tuning, RL, etc. - Add parity to other tokenizer configs like overriding tokens. diff --git a/examples/magistral/think/README.md b/examples/magistral/think/README.md new file mode 100644 index 000000000..29950f59e --- /dev/null +++ b/examples/magistral/think/README.md @@ -0,0 +1,73 @@ +# Magistral Small Thinking Fine-tuning + +This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mistralai/Magistral-Small-2507) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections. + +## Prerequisites + +Before starting, ensure you have: +- Installed Axolotl (see [main README](../README.md)) + +## Getting Started + +Run the thinking model fine-tuning: + +```bash +axolotl train magistral-small-think-qlora.yaml +``` + +This config uses about 19.1 GiB VRAM. + +### Tips + +- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below. +- You cannot mix `content: str` and `content: list[dict]`, otherwise, dataset loading will fail. Keep it consistent. + +## Dataset Format + +The thinking model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages. + +Example format: + +```json +{ + "messages": [ + { + "role": "system", + "content": [ + { "type": "text", "text": "{SYSTEM_PROMPT}"} + ] + }, + { + "role": "user", + "content": [ + { "type": "text", "text": "Solve this step by step: What is 15% of 240?"} + ] + }, + { + "role": "assistant", + "content": [ + { + "type": "thinking", + "thinking": "I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36." + }, + { + "type": "text", + "text": "To find 15% of 240, I'll multiply 240 by 0.15:\n\n240 × 0.15 = 36\n\nTherefore, 15% of 240 is 36." + } + ] + } + ] +} +``` + +### Advanced Options + +The `thinking` section supports an optional `closed` parameter: + +```json +{ + "type": "thinking", + "thinking": "Internal reasoning here...", + "closed": true // Default: true, controls adding the closing [/THINK] tag +} +``` diff --git a/examples/magistral/magistral-small-think-qlora.yaml b/examples/magistral/think/magistral-small-think-qlora.yaml similarity index 100% rename from examples/magistral/magistral-small-think-qlora.yaml rename to examples/magistral/think/magistral-small-think-qlora.yaml diff --git a/examples/magistral/vision/README.md b/examples/magistral/vision/README.md new file mode 100644 index 000000000..932a3631e --- /dev/null +++ b/examples/magistral/vision/README.md @@ -0,0 +1,60 @@ +# Magistral Small Vision Fine-tuning + +This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mistralai/Magistral-Small-2509) with vision capabilities using Axolotl. + +## Prerequisites + +Before starting, ensure you have: +- Installed Axolotl from source (see [main README](../README.md#getting-started)) + +## Getting started + +1. Install the required vision lib: + ```bash + pip install 'mistral-common[opencv]==1.8.5' + ``` + +2. Download the example dataset image: + ```bash + wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg + ``` + +3. Run the fine-tuning: + ```bash + axolotl train magistral-small-vision-24B-qlora.yml + ``` + +This config uses about 17GiB VRAM. + +WARNING: The loss and grad norm will be much higher than normal at first. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look. + +### Tips + +Key differences from text-only model: +- `max_tokens: 131072` for inference +- Multi-modal dataset format required +- Sample packing not supported + +## Dataset Format + +The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format). + +One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now. + +Example: +```json +{ + "messages": [ + {"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]}, + {"role": "user", "content": [ + { "type": "text", "text": "What's in this image?"}, + {"type": "image", "path": "path/to/image.jpg" } + ]}, + {"role": "assistant", "content": [{ "type": "text", "text": "..." }]}, + ], +} +``` + +## Limitations + +- Sample Packing is not supported for multi-modality training currently. diff --git a/examples/magistral/vision/magistral-small-vision-24B-qlora.yml b/examples/magistral/vision/magistral-small-vision-24B-qlora.yml new file mode 100644 index 000000000..397db383e --- /dev/null +++ b/examples/magistral/vision/magistral-small-vision-24B-qlora.yml @@ -0,0 +1,64 @@ +base_model: mistralai/Magistral-Small-2509 +processor_type: AutoProcessor + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_4bit: true + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +# sample dataset below requires downloading image in advance +# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg +datasets: + - path: Nanobit/text-vision-2k-test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./outputs/out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/mistral/bigstral-ds-zero3.yaml b/examples/mistral/bigstral/bigstral-ds-zero3.yaml similarity index 100% rename from examples/mistral/bigstral-ds-zero3.yaml rename to examples/mistral/bigstral/bigstral-ds-zero3.yaml diff --git a/examples/mistral/mistral-dpo-qlora.yml b/examples/mistral/dpo/mistral-dpo-qlora.yml similarity index 100% rename from examples/mistral/mistral-dpo-qlora.yml rename to examples/mistral/dpo/mistral-dpo-qlora.yml diff --git a/examples/mistral/mistral-small-3.1-24B-lora.yml b/examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml similarity index 78% rename from examples/mistral/mistral-small-3.1-24B-lora.yml rename to examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml index 3e477645e..ec197f333 100644 --- a/examples/mistral/mistral-small-3.1-24B-lora.yml +++ b/examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml @@ -1,6 +1,9 @@ base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503 processor_type: AutoProcessor +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + load_in_8bit: true # these 3 lines are needed for now to handle vision chat templates w images @@ -8,12 +11,12 @@ skip_prepare_dataset: true remove_unused_columns: false sample_packing: false -chat_template: mistral_v7_tekken +# sample dataset below requires downloading image in advance +# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg datasets: - - path: HuggingFaceH4/llava-instruct-mix-vsft + - path: Nanobit/text-vision-2k-test type: chat_template - split: train[:1%] - field_messages: messages + dataset_prepared_path: last_run_prepared val_set_size: 0.01 output_dir: ./outputs/out @@ -48,8 +51,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -# flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet. -sdp_attention: true +flash_attention: true warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml b/examples/mistral/mixtral/mixtral-8x22b-qlora-fsdp.yml similarity index 100% rename from examples/mistral/mixtral-8x22b-qlora-fsdp.yml rename to examples/mistral/mixtral/mixtral-8x22b-qlora-fsdp.yml diff --git a/examples/mistral/mixtral-qlora-fsdp.yml b/examples/mistral/mixtral/mixtral-qlora-fsdp.yml similarity index 100% rename from examples/mistral/mixtral-qlora-fsdp.yml rename to examples/mistral/mixtral/mixtral-qlora-fsdp.yml diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral/mixtral.yml similarity index 100% rename from examples/mistral/mixtral.yml rename to examples/mistral/mixtral/mixtral.yml diff --git a/examples/mistral/mixtral_22.yml b/examples/mistral/mixtral/mixtral_22.yml similarity index 100% rename from examples/mistral/mixtral_22.yml rename to examples/mistral/mixtral/mixtral_22.yml diff --git a/examples/mistral/lora-mps.yml b/examples/mistral/mps/lora-mps.yml similarity index 100% rename from examples/mistral/lora-mps.yml rename to examples/mistral/mps/lora-mps.yml diff --git a/examples/mistral/mistral-qlora-orpo.yml b/examples/mistral/orpo/mistral-qlora-orpo.yml similarity index 100% rename from examples/mistral/mistral-qlora-orpo.yml rename to examples/mistral/orpo/mistral-qlora-orpo.yml diff --git a/examples/pixtral/lora-12b.yml b/examples/pixtral/lora-12b.yml index fea2a60ff..0e6489914 100644 --- a/examples/pixtral/lora-12b.yml +++ b/examples/pixtral/lora-12b.yml @@ -45,8 +45,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -# flash_attention: # PixtralVisionModel does not support Flash Attention 2.0 yet -sdp_attention: true +flash_attention: true warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/voxtral/README.md b/examples/voxtral/README.md index 984af4ddb..b77691d72 100644 --- a/examples/voxtral/README.md +++ b/examples/voxtral/README.md @@ -27,7 +27,14 @@ pip3 install 'mistral_common[audio]==1.8.3' python scripts/cutcrossentropy_install.py | sh ``` -3. Run the finetuning example: +3. Download sample dataset files + +```bash +# for text + audio only +wget https://huggingface.co/datasets/Nanobit/text-audio-2k-test/resolve/main/En-us-African_elephant.oga +``` + +4. Run the finetuning example: ```bash # text only diff --git a/requirements.txt b/requirements.txt index 44a3c0277..86013374f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -70,4 +70,4 @@ schedulefree==1.4.1 axolotl-contribs-lgpl==0.0.6 axolotl-contribs-mit==0.0.5 -mistral-common==1.8.3 +mistral-common==1.8.5 diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index a5a630cb5..98eb07b0f 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -168,6 +168,13 @@ class PatchManager: patch_llama4_linearized_modeling() + if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type: + from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import ( + apply_mistral_tokenizer_image_patch, + ) + + apply_mistral_tokenizer_image_patch() + def _apply_fp8_patches(self): """Apply patches for FP8 support.""" if self.cfg.fp8: @@ -334,6 +341,13 @@ class PatchManager: replace_stablelm_attn_with_flash_attn(self.cfg.base_model) + if self.model_config.model_type in ("mistral3", "llava"): + from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import ( + apply_patch_is_packed_sequence, + ) + + apply_patch_is_packed_sequence() + def _patch_loss_llama(self): """Patch loss functions and other optimizations for LLaMA models.""" if not self.cfg.is_llama_derived_model: diff --git a/src/axolotl/loaders/processor.py b/src/axolotl/loaders/processor.py index 2e3ec8d7f..7580b2008 100644 --- a/src/axolotl/loaders/processor.py +++ b/src/axolotl/loaders/processor.py @@ -21,6 +21,13 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): if cfg.processor_type: processor_cls = getattr(transformers, cfg.processor_type) + if cfg.tokenizer_use_mistral_common: + from axolotl.utils.mistral import Mistral3Processor + + return Mistral3Processor( + tokenizer=tokenizer, + ) + processor = processor_cls.from_pretrained( cfg.processor_config, trust_remote_code=cfg.trust_remote_code or False, diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py index 37b66ac83..69455dd77 100644 --- a/src/axolotl/loaders/tokenizer.py +++ b/src/axolotl/loaders/tokenizer.py @@ -124,13 +124,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer: def _load_mistral_common_tokenizer(cfg: DictDefault): """Load mistral-common tokenizer""" - from transformers import tokenization_mistral_common - from axolotl.utils.mistral import HFMistralTokenizer - # patch - tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer - # Load the HF-compatible wrapper around MistralTokenizer tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config) diff --git a/src/axolotl/monkeypatch/models/mistral3/__init__.py b/src/axolotl/monkeypatch/models/mistral3/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/models/mistral3/mistral_common_tokenizer.py b/src/axolotl/monkeypatch/models/mistral3/mistral_common_tokenizer.py new file mode 100644 index 000000000..9e7259a05 --- /dev/null +++ b/src/axolotl/monkeypatch/models/mistral3/mistral_common_tokenizer.py @@ -0,0 +1,85 @@ +""" +Monkeypatch to fix inefficient tensor conversion in MistralCommonTokenizer.apply_chat_template +""" + +import importlib +import inspect + +from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def apply_mistral_tokenizer_image_patch(): + """Apply patch to MistralCommonTokenizer.apply_chat_template to fix image tensor conversion.""" + from transformers.tokenization_mistral_common import MistralCommonTokenizer + + # Get original source + original_source = inspect.getsource(MistralCommonTokenizer.apply_chat_template) + original_source, _ = detab_code(original_source) + + # Define the replacement + original_tensor_conversion = ( + " pixel_values = torch.tensor(images)" + ) + + patched_tensor_conversion = """ if isinstance(images, list) and len(images) > 0 and isinstance(images[0], np.ndarray): + pixel_values = torch.tensor(np.array(images)) + else: + pixel_values = torch.tensor(images)""" + + # Apply the replacement + if original_tensor_conversion in original_source: + patched_source = original_source.replace( + original_tensor_conversion, patched_tensor_conversion + ) + patched_source = patched_source.replace( + "def apply_chat_template(", + "def patched_apply_chat_template(", + 1, + ) + + # Load necessary imports from the module + module_name = MistralCommonTokenizer.__module__ + module = importlib.import_module(module_name) + + # Detect what needs to be imported + items_to_import = [] + for item in dir(module): + if item in patched_source and not item.startswith("_"): + items_to_import.append(item) + + # Execute imports in global scope + if items_to_import: + exec( # nosec B102 + f"from {module_name} import ({', '.join(items_to_import)})", + globals(), + ) + + # Also need standard imports that might be used + exec("import numpy as np", globals()) # nosec B102 + exec("import torch", globals()) # nosec B102 + exec("from typing import Union, Optional, List, Dict, Any, Callable", globals()) # nosec B102 + exec("from pathlib import Path", globals()) # nosec B102 + + # Import other dependencies that might be needed + try: + exec("from transformers.utils import is_torch_available", globals()) # nosec B102 + exec( + "from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TensorType", + globals(), + ) # nosec B102 + exec("from transformers.utils import logging", globals()) # nosec B102 + exec("logger = logging.get_logger(__name__)", globals()) # nosec B102 + except ImportError as e: + LOG.warning(f"Could not import some dependencies: {e}") + + # Execute the patched source + exec(patched_source, globals()) # nosec B102 + + # Replace the method + MistralCommonTokenizer.apply_chat_template = patched_apply_chat_template + LOG.info("Successfully applied MistralCommonTokenizer tensor conversion patch") + else: + LOG.warning("Could not find target code for MistralCommonTokenizer patching") diff --git a/src/axolotl/monkeypatch/models/pixtral/__init__.py b/src/axolotl/monkeypatch/models/pixtral/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/models/pixtral/modeling_flash_attention_utils.py b/src/axolotl/monkeypatch/models/pixtral/modeling_flash_attention_utils.py new file mode 100644 index 000000000..d2b482f19 --- /dev/null +++ b/src/axolotl/monkeypatch/models/pixtral/modeling_flash_attention_utils.py @@ -0,0 +1,42 @@ +"""Monkeypatch for FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid""" + +import torch + + +def apply_patch_is_packed_sequence(): + """Apply patch to FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid""" + from transformers import modeling_flash_attention_utils + + def fixed_is_packed_sequence(position_ids, batch_size): + """ + Check the position ids whether packed sequences are indicated or not + 1. Position ids exist + 2. Flattened sequences only are supported + 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences + """ + if position_ids is None: + return False + + if position_ids.ndim == 1: + position_ids = position_ids.unsqueeze(0) # [N] -> [1, N] + + increasing_position_sequences = ( + torch.arange(position_ids.shape[1], device=position_ids.device) + + position_ids.min() + ) + return ( + batch_size == 1 + and (increasing_position_sequences - position_ids).abs().sum().bool().item() + ) + + # Store original method + old_fn = modeling_flash_attention_utils._is_packed_sequence + + # Apply the patch + modeling_flash_attention_utils._is_packed_sequence = fixed_is_packed_sequence + + def unpatch(): + """Restore the original method""" + modeling_flash_attention_utils._is_packed_sequence = old_fn + + return unpatch diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py index 4b06eb4c8..5e7c1456a 100644 --- a/src/axolotl/processing_strategies.py +++ b/src/axolotl/processing_strategies.py @@ -11,6 +11,7 @@ from transformers.image_utils import load_image from axolotl.utils.dict import remove_none_values from axolotl.utils.logging import get_logger +from axolotl.utils.mistral.mistral3_processor import Mistral3Processor LOG = get_logger(__name__) @@ -421,6 +422,36 @@ class SmolVLM2ProcessingStrategy(ProcessingStrategy): ] +class Mistral3ProcessingStrategy(ProcessingStrategy): + """Processing Strategy class for Mistral3""" + + def __init__( + self, + processor: Mistral3Processor, + chat_template: Optional[str] = None, + image_size: int | tuple[int, int] | None = None, + image_resize_algorithm: Resampling | None = None, + ): + super().__init__(processor, chat_template, image_size, image_resize_algorithm) + special_ids = ( + processor.tokenizer.tokenizer.instruct_tokenizer.image_encoder.special_ids + ) + + self.image_token = special_ids.img + self.image_break_token = special_ids.img_break + self.image_end_token = special_ids.img_end + + def process_labels(self, input_ids): + labels = input_ids.clone() + + labels[labels == self.processor.tokenizer.pad_token_id] = -100 + labels[labels == self.image_token] = -100 + labels[labels == self.image_break_token] = -100 + labels[labels == self.image_end_token] = -100 + + return labels + + def get_processing_strategy( processor: ProcessorMixin, chat_template, @@ -463,6 +494,11 @@ def get_processing_strategy( **processing_kwargs, ) + if isinstance(processor, Mistral3Processor): + return Mistral3ProcessingStrategy( + **processing_kwargs, + ) + # llama3_2_vision, llama4, llava # mistral_v7_tekken, pixtral, lfm2vl return ProcessingStrategy( diff --git a/src/axolotl/utils/mistral/__init__.py b/src/axolotl/utils/mistral/__init__.py index eb1e2df89..eb51031ec 100644 --- a/src/axolotl/utils/mistral/__init__.py +++ b/src/axolotl/utils/mistral/__init__.py @@ -1,5 +1,6 @@ """Init for `axolotl.utils.mistral` module.""" +from axolotl.utils.mistral.mistral3_processor import Mistral3Processor from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer -__all__ = ["HFMistralTokenizer"] +__all__ = ["HFMistralTokenizer", "Mistral3Processor"] diff --git a/src/axolotl/utils/mistral/mistral3_processor.py b/src/axolotl/utils/mistral/mistral3_processor.py new file mode 100644 index 000000000..85479ca7b --- /dev/null +++ b/src/axolotl/utils/mistral/mistral3_processor.py @@ -0,0 +1,169 @@ +"""Processor for Mistral3 multimodal models with image support""" + +from typing import Any, Dict, Optional, Union + +import torch +from transformers import ProcessorMixin +from transformers.feature_extraction_utils import BatchFeature +from transformers.processing_utils import ProcessingKwargs +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer + + +class Mistral3ProcessorKwargs(ProcessingKwargs): + _defaults: Dict[str, Dict[str, Any]] = { + "text_kwargs": { + "padding": True, + }, + "common_kwargs": { + "return_tensors": "pt", + "return_dict": True, + "tokenize": True, + }, + } + + +class Mistral3Processor(ProcessorMixin): + """ + Processor for Mistral3 multimodal models that handles text and images. + Wraps HFMistralTokenizer and adds image processing capabilities. + """ + + attributes = ["tokenizer"] + tokenizer_class = "HFMistralTokenizer" + + def __init__(self, tokenizer: HFMistralTokenizer): + # Don't call super().__init__ to avoid the class validation issue + self.tokenizer = tokenizer + + @property + def chat_template(self) -> None: + """Chat template is not supported. Dummy method to satisfy HuggingFace API.""" + return None + + @property + def audio_tokenizer(self) -> None: + """Audio tokenizer is not supported. Dummy method to satisfy HuggingFace API.""" + return None + + def _merge_kwargs( + self, processor_kwargs_class: Any, **kwargs: Any + ) -> Dict[str, Dict[str, Any]]: + """Merge kwargs with defaults similar to ProcessorMixin""" + defaults = processor_kwargs_class._defaults + output_kwargs: Dict[str, Dict[str, Any]] = {} + + for kwarg_type, default_values in defaults.items(): + output_kwargs[kwarg_type] = {**default_values} + + # Update with provided kwargs + for key, value in kwargs.items(): + # Try to match key to appropriate kwarg type + if key in ["padding", "truncation", "max_length"]: + output_kwargs.setdefault("text_kwargs", {}).update({key: value}) + elif key in ["return_tensors", "return_dict", "tokenize"]: + output_kwargs.setdefault("common_kwargs", {}).update({key: value}) + else: + # Add to text_kwargs by default + output_kwargs.setdefault("text_kwargs", {}).update({key: value}) + + return output_kwargs + + def apply_chat_template( + self, + conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]], + **kwargs: Any, + ) -> Union[BatchFeature, str, list[str]]: + """ + Apply chat template with image support for Mistral3. + + Similar to VoxtralProcessor, this method extracts images from the conversation, + calls the tokenizer's apply_chat_template, then adds pixel_values and image_sizes + to the result. + """ + output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs) + text_kwargs = output_kwargs["text_kwargs"] + common_kwargs = output_kwargs["common_kwargs"] + + return_tensors = common_kwargs.pop("return_tensors", "pt") + if return_tensors != "pt": + raise ValueError( + f"{self.__class__.__name__} only supports `return_tensors='pt'`." + ) + + return_dict = common_kwargs.pop("return_dict", False) + tokenize = common_kwargs.pop("tokenize", False) + + # Determine if batched + if isinstance(conversation, (list, tuple)) and ( + isinstance(conversation[0], (list, tuple)) + or hasattr(conversation[0], "content") + ): + is_batched = True + conversations = conversation + else: + is_batched = False + conversations = [conversation] # type: ignore + + # Call tokenizer's apply_chat_template + tokenizer_kwargs = {**text_kwargs, **common_kwargs} + tokenizer_kwargs["return_tensors"] = return_tensors + tokenizer_kwargs["tokenize"] = tokenize + tokenizer_kwargs["return_dict"] = return_dict + + encoded_instruct_inputs = self.tokenizer.apply_chat_template( + conversations, + **tokenizer_kwargs, + ) + + if tokenize: + if return_dict: + # The tokenizer already handles pixel_values, we just need to add image_sizes + if hasattr(encoded_instruct_inputs, "items"): + data: Dict[str, Any] = dict(encoded_instruct_inputs) # type: ignore + elif hasattr(encoded_instruct_inputs, "data"): + data = encoded_instruct_inputs.data # type: ignore + else: + raise ValueError("Unknown data type") + + if "pixel_values" in data: + pixel_values = data["pixel_values"] + + # MistralTokenizer returns a Double, so we convert to fp32 + data["pixel_values"] = pixel_values.to(dtype=torch.float32) + + # Always batched: [B, C, H, W] -> image_sizes: [B, 2] + # Since tensor is homogeneous, all images have same H, W + batch_size = pixel_values.shape[0] + image_sizes = torch.tensor([pixel_values.shape[-2:]] * batch_size) + data["image_sizes"] = image_sizes + + return BatchFeature(data=data, tensor_type=return_tensors) + + if not is_batched: + return encoded_instruct_inputs[0] + + return encoded_instruct_inputs + + def __call__( + self, + text: Optional[ + Union[ + TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput] + ] + ], + **kwargs: Any, + ) -> BatchFeature: + """ + Forward text processing to the tokenizer. + This method does not support images - use apply_chat_template instead. + """ + output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs) + text_kwargs = output_kwargs["text_kwargs"] + common_kwargs = output_kwargs["common_kwargs"] + + out = self.tokenizer(text, **text_kwargs) + return BatchFeature( + data=out, tensor_type=common_kwargs.pop("return_tensors", None) + ) diff --git a/tests/monkeypatch/test_mistral_tokenizer_patch.py b/tests/monkeypatch/test_mistral_tokenizer_patch.py new file mode 100644 index 000000000..cb82c0890 --- /dev/null +++ b/tests/monkeypatch/test_mistral_tokenizer_patch.py @@ -0,0 +1,35 @@ +"""Integration tests for MistralCommonTokenizer patches.""" + +import pytest + + +class TestMistralTokenizerPatchIntegration: + """Test MistralCommonTokenizer patch integration.""" + + @pytest.mark.integration + def test_mistral_tokenizer_image_patch(self): + """Test that MistralCommonTokenizer image patch can be applied.""" + try: + from transformers.tokenization_mistral_common import MistralCommonTokenizer + except ImportError: + pytest.skip("MistralCommonTokenizer not available") + + from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import ( + apply_mistral_tokenizer_image_patch, + ) + + # Store original method + original_apply_chat_template = MistralCommonTokenizer.apply_chat_template + + # Apply patch + apply_mistral_tokenizer_image_patch() + + # Verify patch was applied + assert ( + MistralCommonTokenizer.apply_chat_template != original_apply_chat_template + ), "apply_chat_template was not patched" + + # Verify the method is still callable + assert callable(MistralCommonTokenizer.apply_chat_template), ( + "Patched method is not callable" + ) diff --git a/tests/monkeypatch/test_pixtral_flash_attention_patch.py b/tests/monkeypatch/test_pixtral_flash_attention_patch.py new file mode 100644 index 000000000..285fde41e --- /dev/null +++ b/tests/monkeypatch/test_pixtral_flash_attention_patch.py @@ -0,0 +1,77 @@ +"""Integration tests for Pixtral Flash Attention patches.""" + +import pytest +import torch + + +class TestPixtralFlashAttentionPatchIntegration: + """Test Pixtral Flash Attention patch integration.""" + + @pytest.mark.integration + def test_pixtral_flash_attention_patch(self): + """Test that Pixtral Flash Attention patch can be applied and works correctly.""" + try: + from transformers import modeling_flash_attention_utils + except ImportError: + pytest.skip("Flash Attention utils not available") + + from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import ( + apply_patch_is_packed_sequence, + ) + + # Store original method + original_is_packed_sequence = modeling_flash_attention_utils._is_packed_sequence + + # Apply patch and get unpatch function + unpatch_fn = apply_patch_is_packed_sequence() + + # Verify patch was applied + assert ( + modeling_flash_attention_utils._is_packed_sequence + != original_is_packed_sequence + ), "_is_packed_sequence was not patched" + + # Test the patched function with 1D position_ids + patched_fn = modeling_flash_attention_utils._is_packed_sequence + + # Test 1D position_ids 1 sequence + position_ids_1d = torch.tensor([0, 1, 2, 3]) + result = patched_fn(position_ids_1d, batch_size=1) + assert isinstance(result, bool), "Function should return a boolean" + assert result is False, "1D sequential position_ids should not be packed" + + # Test 1D packed 2 sequences + position_ids_1d_packed = torch.tensor([0, 1, 2, 0, 1, 2]) + result = patched_fn(position_ids_1d_packed, batch_size=1) + assert isinstance(result, bool), "Function should return a boolean" + assert result is True, "1D packed position_ids should be detected as packed" + + # Test 2D packed 2 sequences + position_ids_2d_packed = torch.tensor([[0, 1, 2, 3, 0, 1]]) + result = patched_fn(position_ids_2d_packed, batch_size=1) + assert isinstance(result, bool), "Function should return a boolean" + assert result is True, "2D packed position_ids should be detected as packed" + + # Test 2D 1 sequence + position_ids_2d_normal = torch.tensor([[0, 1, 2, 3, 4, 5]]) + result = patched_fn(position_ids_2d_normal, batch_size=1) + assert isinstance(result, bool), "Function should return a boolean" + assert result is False, "2D sequential position_ids should not be packed" + + # Test 2D batch size 2 + position_ids_2d_normal = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]]) + result = patched_fn(position_ids_2d_normal, batch_size=2) + assert isinstance(result, bool), "Function should return a boolean" + assert result is False, "2D position_ids batch 2 should not be packed" + + # Test None case + result = patched_fn(None, batch_size=1) + assert isinstance(result, bool), "Function should return a boolean" + assert result is False, "None position_ids should return False" + + # Test unpatch function + unpatch_fn() + assert ( + modeling_flash_attention_utils._is_packed_sequence + == original_is_packed_sequence + ), "unpatch function did not restore original method" diff --git a/tests/monkeypatch/test_voxtral_modeling_patch.py b/tests/monkeypatch/test_voxtral_modeling_patch.py new file mode 100644 index 000000000..878bbc185 --- /dev/null +++ b/tests/monkeypatch/test_voxtral_modeling_patch.py @@ -0,0 +1,43 @@ +"""Integration tests for Voxtral modeling patches.""" + +import pytest + + +class TestVoxtralModelingPatchIntegration: + """Test Voxtral modeling patch integration.""" + + @pytest.mark.integration + def test_voxtral_conditional_generation_patch(self): + """Test that Voxtral conditional generation patch can be applied.""" + try: + from transformers.models.voxtral.modeling_voxtral import ( + VoxtralForConditionalGeneration, + ) + except ImportError: + pytest.skip("VoxtralForConditionalGeneration not available") + + from axolotl.monkeypatch.models.voxtral.modeling import ( + patch_voxtral_conditional_generation_forward, + ) + + # Store original method + original_forward = VoxtralForConditionalGeneration.forward + + # Apply patch and get unpatch function + unpatch_fn = patch_voxtral_conditional_generation_forward() + + # Verify patch was applied + assert VoxtralForConditionalGeneration.forward != original_forward, ( + "forward method was not patched" + ) + + # Verify the method is still callable + assert callable(VoxtralForConditionalGeneration.forward), ( + "Patched method is not callable" + ) + + # Test unpatch function + unpatch_fn() + assert VoxtralForConditionalGeneration.forward == original_forward, ( + "unpatch function did not restore original method" + )