From a098df527bd10ebaa9d54a14f793bd460485a739 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 17 Mar 2026 09:39:05 +0700 Subject: [PATCH] feat: add Mistral Small 4 (#3502) * feat: add mistral small 4 * fix: update mistral common * fix: deepcopy when passing in tokenizer * feat: add doc on reasoning and thinking section * fix: don't use custom tokenizer and quantize experts * chore: update docs and configs * chore: update doc to follow official name * feat: update cce to include mistral4 * chore: move * fix: naming * fix: test mock breaking get_text_config check * fix: enable CCE and add expert block targetting to configs * chore: docs * fix: use act checkpointing * chore: doc * chore: docs * chore: docs --- README.md | 2 +- docs/multimodal.qmd | 7 ++ .../colab-axolotl-example.ipynb | 2 +- examples/mistral4/README.md | 85 +++++++++++++++++++ examples/mistral4/fft-text.yml | 58 +++++++++++++ examples/mistral4/fft-vision.yml | 57 +++++++++++++ examples/mistral4/qlora-text.yml | 58 +++++++++++++ examples/mistral4/qlora-vision.yml | 63 ++++++++++++++ requirements.txt | 2 +- scripts/cutcrossentropy_install.py | 2 +- src/axolotl/common/architectures.py | 1 + .../integrations/cut_cross_entropy/README.md | 4 +- .../cut_cross_entropy/__init__.py | 2 +- src/axolotl/integrations/kernels/constants.py | 2 + src/axolotl/integrations/kernels/plugin.py | 10 +-- .../integrations/kernels/sonicmoe/routing.py | 59 +++++++++++++ src/axolotl/loaders/model.py | 5 +- src/axolotl/loaders/processor.py | 2 +- src/axolotl/monkeypatch/multipack.py | 1 + src/axolotl/utils/config/__init__.py | 9 ++ 20 files changed, 417 insertions(+), 14 deletions(-) create mode 100644 examples/mistral4/README.md create mode 100644 examples/mistral4/fft-text.yml create mode 100644 examples/mistral4/fft-vision.yml create mode 100644 examples/mistral4/qlora-text.yml create mode 100644 examples/mistral4/qlora-vision.yml diff --git a/README.md b/README.md index f10e08b42..a70dc8edf 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ ## 🎉 Latest Updates - 2026/03: - - New model support has been added in Axolotl for [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45). + - New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45). - [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat). - 2026/02: - [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.
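The `quantize_moe_experts: true` flag called out in the README update above is exercised by the QLoRA example configs later in this patch, where it is paired with 4-bit loading. A minimal sketch of the pairing (flags copied from qlora-text.yml below; all other settings elided):

```yaml
base_model: mistralai/Mistral-Small-4-119B-2603
adapter: qlora
load_in_4bit: true          # quantize the frozen base weights to 4-bit
quantize_moe_experts: true  # quantize the MoE expert blocks as well to cut VRAM
```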
diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd index 54793c6e3..e5753732d 100644 --- a/docs/multimodal.qmd +++ b/docs/multimodal.qmd @@ -13,6 +13,7 @@ format: - [Pixtral](#sec-pixtral) - [Llava-1.5](#sec-llava-15) - [Mistral-Small-3.1](#sec-mistral-small-31) +- [Mistral-Small-4](#sec-mistral-small-4) - [Magistral-Small-2509](#sec-magistral-small-2509) - [Voxtral](#sec-voxtral) - [Gemma-3](#sec-gemma-3) @@ -108,6 +109,12 @@ Please make sure to install vision lib via `pip install 'mistral-common[opencv]= base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503 ``` +### Mistral-Small-4 {#sec-mistral-small-4} + +```yaml +base_model: mistralai/Mistral-Small-4-119B-2603 +``` + ### Magistral-Small-2509 {#sec-magistral-small-2509} ::: {.callout-tip} diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index 7be9800be..49a45cdc6 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -40,7 +40,7 @@ "%%capture\n", "# This step can take ~5-10 minutes to install dependencies\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", - "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129\"" + "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe\"" ] }, { diff --git a/examples/mistral4/README.md b/examples/mistral4/README.md new file mode 100644 index 000000000..651384791 --- /dev/null +++ b/examples/mistral4/README.md @@ -0,0 +1,85 @@ +# Finetune Mistral Small 4 with Axolotl + +Mistral Small 4 is a 119B parameter (6.5B active) multimodal MoE model from MistralAI that unifies instruct, reasoning, and coding capabilities into a single model. It is available on HuggingFace at [Mistral-Small-4-119B-2603](https://huggingface.co/mistralai/Mistral-Small-4-119B-2603). + +Thanks to the team at MistralAI for giving us early access to prepare for this release. + +## Getting started + +Note: Training this model requires weights in BF16, which we will link to later. +Users interested in training can convert / descale the existing FP8 weights. + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + +2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage (see the install command at the end of this section). + +3. Install transformers from main: + + ```bash + pip install git+https://github.com/huggingface/transformers.git + ``` + +4. Run one of the example configs: + + ```bash + # text-only + axolotl train examples/mistral4/qlora-text.yml # no experts ~69 GiB, experts ~93 GiB + axolotl train examples/mistral4/fft-text.yml + + # text + vision + # run: wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg + axolotl train examples/mistral4/qlora-vision.yml # no experts ~68 GiB + axolotl train examples/mistral4/fft-vision.yml + ``` + +Note: The FFT configs are provided for reference. Please adjust hyperparameters as needed.
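+For step 2, the pinned Cut Cross Entropy fork can be installed directly; this is the same command and commit used by the updated install script and CCE README in this PR:
+
+```bash
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"
+```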
+
## Reasoning Effort

The chat template supports a `reasoning_effort` variable to control the model's reasoning depth:

- `"none"` — instruct mode (default)
- `"high"` — reasoning mode with explicit thinking steps

Pass it via `chat_template_kwargs` under your dataset config:

```yaml
datasets:
  - path: your/dataset
    type: chat_template
    chat_template_kwargs:
      reasoning_effort: high
```

## Thinking Support

The chat template supports a `thinking` content type in assistant messages for training on reasoning traces (rendered as `[THINK]...[/THINK]` blocks).

To use thinking datasets, add the `thinking` mapping via `message_property_mappings`:

```yaml
datasets:
  - path: your/thinking-dataset
    type: chat_template
    message_property_mappings:
      role: role
      content: content
      thinking: thinking
    chat_template_kwargs:
      reasoning_effort: high
```

See the [Magistral thinking guide](../magistral/think/README.md) for dataset format details.

## Tips

- Read more on loading your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- The vision model requires the multimodal dataset format documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).

## Related Resources

- [MistralAI Mistral Small 4 Blog](https://mistral.ai/news/mistral-small-4)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/mistral4/fft-text.yml b/examples/mistral4/fft-text.yml new file mode 100644 index 000000000..e01d96dad --- /dev/null +++ b/examples/mistral4/fft-text.yml @@ -0,0 +1,58 @@ +base_model: mistralai/Mistral-Small-4-119B-2603 + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + - axolotl.integrations.kernels.KernelsPlugin +use_kernels: true +use_sonicmoe: true + +# only train language model layers, freeze vision tower +unfrozen_parameters: + - model.language_model.* + - lm_head + - embed_tokens + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./outputs/out + +sequence_len: 2048 +sample_packing: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 2e-5 + +bf16: true +tf32: true + +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 + +fsdp_version: 2 +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: false + state_dict_type: FULL_STATE_DICT + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Mistral4DecoderLayer + reshard_after_forward: true + activation_checkpointing: true diff --git a/examples/mistral4/fft-vision.yml b/examples/mistral4/fft-vision.yml new file mode 100644 index 000000000..aa65dfa6d --- /dev/null +++ b/examples/mistral4/fft-vision.yml @@ -0,0 +1,57 @@ +base_model: mistralai/Mistral-Small-4-119B-2603 +processor_type: AutoProcessor + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + - axolotl.integrations.kernels.KernelsPlugin +use_kernels: true +use_sonicmoe: true + +# vision requirements
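+# (note: dataset preparation is skipped and unused columns kept so the
+# processor can build the multimodal inputs at collation time; sample
+# packing is not supported with vision inputs)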
+skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +datasets: + - path: Nanobit/text-vision-2k-test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./outputs/out + +sequence_len: 2048 + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 2e-5 + +bf16: true +tf32: true + +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 + +fsdp_version: 2 +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: false + state_dict_type: FULL_STATE_DICT + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Mistral4DecoderLayer + reshard_after_forward: true + activation_checkpointing: true diff --git a/examples/mistral4/qlora-text.yml b/examples/mistral4/qlora-text.yml new file mode 100644 index 000000000..ed38053f6 --- /dev/null +++ b/examples/mistral4/qlora-text.yml @@ -0,0 +1,58 @@ +base_model: mistralai/Mistral-Small-4-119B-2603 + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_4bit: true +quantize_moe_experts: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./outputs/out + +adapter: qlora + +sequence_len: 2048 +sample_packing: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +# uncomment to train on expert layers +# lora_target_parameters: +# - mlp.experts.gate_up_proj +# - mlp.experts.down_proj +# lora_mlp_kernel: false +# lora_qkv_kernel: false +# lora_o_kernel: false + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +tf32: true + +gradient_checkpointing: true +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 diff --git a/examples/mistral4/qlora-vision.yml b/examples/mistral4/qlora-vision.yml new file mode 100644 index 000000000..95b8138ce --- /dev/null +++ b/examples/mistral4/qlora-vision.yml @@ -0,0 +1,63 @@ +base_model: mistralai/Mistral-Small-4-119B-2603 +processor_type: AutoProcessor + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_4bit: true +quantize_moe_experts: true + +# vision chat template requirements +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +datasets: + - path: Nanobit/text-vision-2k-test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./outputs/out + +adapter: qlora + +sequence_len: 2048 + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +# uncomment to train on expert layers +# lora_target_parameters: +# - mlp.experts.gate_up_proj +# - mlp.experts.down_proj +# lora_mlp_kernel: false +# lora_qkv_kernel: false +# lora_o_kernel: false + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: 
adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +tf32: true + +gradient_checkpointing: true +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 diff --git a/requirements.txt b/requirements.txt index c918d30aa..3fd75c3fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -75,4 +75,4 @@ axolotl-contribs-mit==0.0.6 # telemetry posthog==6.7.11 -mistral-common==1.8.8 +mistral-common==1.10.0 diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index d506fa87e..771a5adb2 100644 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else "" print( UNINSTALL_PREFIX - + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"' + + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"' ) diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py index 0e1de3017..181667cb9 100644 --- a/src/axolotl/common/architectures.py +++ b/src/axolotl/common/architectures.py @@ -16,6 +16,7 @@ MOE_ARCH_BLOCK = { "qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock", "deepseek_v2": "DeepseekV2MoE", "deepseek_v3": "DeepseekV3MoE", + "mistral4": "Mistral4MoE", "gpt_oss": "GptOssDecoderLayer", "lfm2_moe": "Lfm2MoeSparseMoeBlock", "afmoe": "AfmoeMoE", diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index 9520dd48c..5a3a73d34 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh - If you are installing from pip ```bash -pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129" +pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe" ``` ## Usage @@ -73,8 +73,10 @@ plugins: - ministral3 - mistral - mistral3 +- mistral4 - mixtral - mllama +- nemotron_h - olmo - olmo2 - olmo3 diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index d8aa075b9..808aff662 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -35,7 +35,7 @@ LOG = get_logger(__name__) _CCE_INSTALL_MESSAGE = ( "Please install Axolotl's fork of cut_cross_entropy with transformers support using " - '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"`' + '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"`' ) diff --git a/src/axolotl/integrations/kernels/constants.py b/src/axolotl/integrations/kernels/constants.py index 529ed4ad6..a7d513b5e 100644 --- a/src/axolotl/integrations/kernels/constants.py +++ b/src/axolotl/integrations/kernels/constants.py @@ -25,6 +25,8 @@ SPARSE_MOE_BLOCK = { "olmoe": "OlmoeSparseMoeBlock", "mixtral": "MixtralSparseMoeBlock", "minimax": "MiniMaxSparseMoeBlock", + # softmax -> topk routing (with group-based expert selection) + "mistral4": "Mistral4MoE", # sigmoid -> topk routing (with 
group-based expert selection) "glm_moe_dsa": "GlmMoeDsaMoE", "deepseek_v3": "DeepseekV3MoE", diff --git a/src/axolotl/integrations/kernels/plugin.py b/src/axolotl/integrations/kernels/plugin.py index ad14dd148..f085e481c 100644 --- a/src/axolotl/integrations/kernels/plugin.py +++ b/src/axolotl/integrations/kernels/plugin.py @@ -61,9 +61,11 @@ class KernelsPlugin(BasePlugin): return "axolotl.integrations.kernels.KernelsArgs" def pre_model_load(self, cfg): + moe_model_type = cfg.model_config_type_text or cfg.model_config_type + if cfg.use_scattermoe: self._register_kernels() - self._kernelize_model(cfg.model_config_type) + self._kernelize_model(moe_model_type) elif cfg.use_sonicmoe: if not importlib.util.find_spec("sonicmoe"): raise RuntimeError( @@ -75,11 +77,9 @@ class KernelsPlugin(BasePlugin): from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe - LOG.info( - f"Applying SonicMoE patches for model type: {cfg.model_config_type}" - ) + LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}") patch_sonicmoe( - cfg.model_config_type, + moe_model_type, torch_compile=bool(getattr(cfg, "torch_compile", False)), ) diff --git a/src/axolotl/integrations/kernels/sonicmoe/routing.py b/src/axolotl/integrations/kernels/sonicmoe/routing.py index 3f93c1596..fe2d12092 100644 --- a/src/axolotl/integrations/kernels/sonicmoe/routing.py +++ b/src/axolotl/integrations/kernels/sonicmoe/routing.py @@ -5,6 +5,7 @@ Different MoE architectures use different routing strategies: - qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization) - gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None) - glm_moe_dsa: sigmoid -> topk (with group-based expert selection) +- mistral4: softmax -> group selection -> topk (with renormalization and scaling) Each model type maps to a (routing_fn, activation_type, router_attr) triple. When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used. 
@@ -45,6 +46,8 @@ def get_model_moe_config(model_type: str): "minimax", ): return softmax_topk_routing, ActivationType.SWIGLU, "gate" + elif model_type in ("mistral4",): + return softmax_group_topk_routing, ActivationType.SWIGLU, "gate" elif model_type in ( "glm_moe_dsa", "deepseek_v3", @@ -126,6 +129,62 @@ def softmax_topk_routing( return flat_scores, flat_token_idx, flat_expert_idx, router_logits +def softmax_group_topk_routing( + hidden_states: torch.Tensor, moe_block +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale.""" + gate = moe_block.gate + T, H = hidden_states.shape + K = moe_block.top_k + E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0]) + n_group = getattr(moe_block, "n_group", 1) + + router_logits = F.linear(hidden_states, gate.weight) # [T, E] + router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] + + scores_for_choice = router_probs + + # Group selection: pick top groups, mask the rest + if n_group > 1: + group_scores = ( + scores_for_choice.view(-1, n_group, E // n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk( + group_scores, k=moe_block.topk_group, dim=-1, sorted=False + )[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + + topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1] + topk_weights = router_probs.gather(1, topk_indices) + + # Renormalization + scaling + norm_topk_prob = getattr(moe_block, "norm_topk_prob", True) + if norm_topk_prob: + topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20) + routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0) + topk_weights = topk_weights * routed_scaling_factor + + # Flatten for moe_general_routing_inputs + token_indices = ( + torch.arange(T, device=hidden_states.device, dtype=torch.int32) + .unsqueeze(1) + .expand(T, K) + ) + + flat_scores = topk_weights.to(torch.float32).reshape(-1) # [T*K] + flat_token_idx = token_indices.reshape(-1) # [T*K] + flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K] + + return flat_scores, flat_token_idx, flat_expert_idx, router_logits + + def sigmoid_topk_routing( hidden_states: torch.Tensor, moe_block ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 03c1f35bc..2662d0b86 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -829,8 +829,9 @@ class ModelLoader: def _set_z3_leaf_modules(self): from deepspeed.utils import set_z3_leaf_modules - if self.cfg.model_config_type in MOE_ARCH_BLOCK: - moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type] + moe_type = self.cfg.model_config_type_text or self.cfg.model_config_type + if moe_type in MOE_ARCH_BLOCK: + moe_blocks = MOE_ARCH_BLOCK[moe_type] moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks set_z3_leaf_modules( self.model, diff --git a/src/axolotl/loaders/processor.py b/src/axolotl/loaders/processor.py index e07e324d6..211c26060 100644 --- a/src/axolotl/loaders/processor.py +++ b/src/axolotl/loaders/processor.py @@ -55,12 +55,12 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): ) processor_kwargs["trust_remote_code"] = 
cfg.trust_remote_code or False - processor_kwargs["tokenizer"] = tokenizer processor = processor_cls.from_pretrained( cfg.processor_config, **processor_kwargs, ) + processor.tokenizer = tokenizer # Attempt to load image size from processor if available if ( diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index cad6039bd..9e48e73eb 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -57,6 +57,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "olmo3", "ministral", "ministral3", + "mistral4", "afmoe", ] diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index e8ca72aa1..b779abaa6 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -195,6 +195,15 @@ def normalize_config(cfg): cfg.model_config_type = model_config.model_type + # Resolve inner text backbone type for VLM wrappers (e.g. mistral3 -> mistral4) + if callable(getattr(model_config, "get_text_config", None)): + text_config = model_config.get_text_config() + if ( + hasattr(text_config, "model_type") + and text_config.model_type != model_config.model_type + ): + cfg.model_config_type_text = text_config.model_type + # figure out if the model is llama cfg.is_llama_derived_model = ( (