feat: add Mistral Small 4 (#3502)

* feat: add mistral small 4

* fix: update mistral common

* fix: deepcopy when passing in tokenizer

* feat: add doc on reasoning and thinking section

* fix: don't use custom tokenizer and quantize experts

* chore: update docs and configs

* chore: update doc to follow official name

* feat: update cce to include mistral4

* chore: move

* fix: naming

* fix: test mock breaking get_text_config check

* fix: enable CCE and add expert block targetting to configs

* chore: docs

* fix: use act checkpointing

* chore: doc

* chore: docs

* chore: docs
This commit is contained in:
NanoCode012
2026-03-17 09:39:05 +07:00
committed by GitHub
parent 7da5f94379
commit a098df527b
20 changed files with 417 additions and 14 deletions

View File

@@ -30,7 +30,7 @@
## 🎉 Latest Updates
- 2026/03:
- New model support has been added in Axolotl for [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
- New model support has been added in Axolotl for [[Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
- 2026/02:
- [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.

View File

@@ -13,6 +13,7 @@ format:
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Mistral-Small-4](#sec-mistral-small-4)
- [Magistral-Small-2509](#sec-magistral-small-2509)
- [Voxtral](#sec-voxtral)
- [Gemma-3](#sec-gemma-3)
@@ -108,6 +109,12 @@ Please make sure to install vision lib via `pip install 'mistral-common[opencv]=
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
```
### Mistral-Small-4 {#sec-mistral-small-4}
```yaml
base_model: mistralai/Mistral-Small-4-119B-2603
```
### Magistral-Small-2509 {#sec-magistral-small-2509}
::: {.callout-tip}

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe\""
]
},
{

View File

@@ -0,0 +1,85 @@
# Finetune Mistral Small 4 with Axolotl
Mistral Small 4 is a 119B parameter (6.5B active) multimodal MoE model from MistralAI that unifies instruct, reasoning, and coding capabilities into a single model. It is available on HuggingFace at [Mistral-Small-4-119B-2603](https://huggingface.co/mistralai/Mistral-Small-4-119B-2603).
Thanks to the team at MistralAI for giving us early access to prepare for this release.
## Getting started
Note: Training this model requires weights in BF16 which we will link to later.
Users interested in training can convert / descale the existing FP8 weights.
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
3. Install transformers from main
```bash
pip install git+https://github.com/huggingface/transformers.git
```
4. Run one of the example configs:
```bash
# text-only
axolotl train examples/mistral4/qlora-text.yml # no experts ~69 GiB, experts ~93 GiB
axolotl train examples/mistral4/fft-text.yml
# text + vision
# run: wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
axolotl train examples/mistral4/qlora-vision.yml # no experts ~68 GiB
axolotl train examples/mistral4/fft-vision.yml
```
Note: FFT configs provided as reference. Please adjust hyperparameters as needed.
## Reasoning Effort
The chat template supports a `reasoning_effort` variable to control the model's reasoning depth:
- `"none"` — instruct mode (default)
- `"high"` — reasoning mode with explicit thinking steps
Pass it via `chat_template_kwargs` under your dataset config:
```yaml
datasets:
- path: your/dataset
type: chat_template
chat_template_kwargs:
reasoning_effort: high
```
## Thinking Support
The chat template supports a `thinking` content type in assistant messages for training on reasoning traces (rendered as `[THINK]...[/THINK]` blocks).
To use thinking datasets, add the `thinking` mapping via `message_property_mappings`:
```yaml
datasets:
- path: your/thinking-dataset
type: chat_template
message_property_mappings:
role: role
content: content
thinking: thinking
chat_template_kwargs:
reasoning_effort: high
```
See the [Magistral thinking guide](../magistral/think/README.md) for dataset format details.
## Tips
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
## Related Resources
- [MistralAI Mistral Small 4 Blog](https://mistral.ai/news/mistral-small-4)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,58 @@
base_model: mistralai/Mistral-Small-4-119B-2603
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_sonicmoe: true
# only train language model layers, freeze vision tower
unfrozen_parameters:
- model.language_model.*
- lm_head
- embed_tokens
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Mistral4DecoderLayer
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -0,0 +1,57 @@
base_model: mistralai/Mistral-Small-4-119B-2603
processor_type: AutoProcessor
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_sonicmoe: true
# vision requirements
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: Nanobit/text-vision-2k-test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
sequence_len: 2048
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Mistral4DecoderLayer
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -0,0 +1,58 @@
base_model: mistralai/Mistral-Small-4-119B-2603
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
# uncomment to train on expert layers
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# lora_mlp_kernel: false
# lora_qkv_kernel: false
# lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -0,0 +1,63 @@
base_model: mistralai/Mistral-Small-4-119B-2603
processor_type: AutoProcessor
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
# vision chat template requirements
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: Nanobit/text-vision-2k-test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
sequence_len: 2048
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
# uncomment to train on expert layers
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# lora_mlp_kernel: false
# lora_qkv_kernel: false
# lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -75,4 +75,4 @@ axolotl-contribs-mit==0.0.6
# telemetry
posthog==6.7.11
mistral-common==1.8.8
mistral-common==1.10.0

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"'
)

View File

@@ -16,6 +16,7 @@ MOE_ARCH_BLOCK = {
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"deepseek_v3": "DeepseekV3MoE",
"mistral4": "Mistral4MoE",
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
"afmoe": "AfmoeMoE",

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"
```
## Usage
@@ -73,8 +73,10 @@ plugins:
- ministral3
- mistral
- mistral3
- mistral4
- mixtral
- mllama
- nemotron_h
- olmo
- olmo2
- olmo3

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"`'
)

View File

@@ -25,6 +25,8 @@ SPARSE_MOE_BLOCK = {
"olmoe": "OlmoeSparseMoeBlock",
"mixtral": "MixtralSparseMoeBlock",
"minimax": "MiniMaxSparseMoeBlock",
# softmax -> topk routing (with group-based expert selection)
"mistral4": "Mistral4MoE",
# sigmoid -> topk routing (with group-based expert selection)
"glm_moe_dsa": "GlmMoeDsaMoE",
"deepseek_v3": "DeepseekV3MoE",

View File

@@ -61,9 +61,11 @@ class KernelsPlugin(BasePlugin):
return "axolotl.integrations.kernels.KernelsArgs"
def pre_model_load(self, cfg):
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
if cfg.use_scattermoe:
self._register_kernels()
self._kernelize_model(cfg.model_config_type)
self._kernelize_model(moe_model_type)
elif cfg.use_sonicmoe:
if not importlib.util.find_spec("sonicmoe"):
raise RuntimeError(
@@ -75,11 +77,9 @@ class KernelsPlugin(BasePlugin):
from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
LOG.info(
f"Applying SonicMoE patches for model type: {cfg.model_config_type}"
)
LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
patch_sonicmoe(
cfg.model_config_type,
moe_model_type,
torch_compile=bool(getattr(cfg, "torch_compile", False)),
)

View File

@@ -5,6 +5,7 @@ Different MoE architectures use different routing strategies:
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
- mistral4: softmax -> group selection -> topk (with renormalization and scaling)
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
@@ -45,6 +46,8 @@ def get_model_moe_config(model_type: str):
"minimax",
):
return softmax_topk_routing, ActivationType.SWIGLU, "gate"
elif model_type in ("mistral4",):
return softmax_group_topk_routing, ActivationType.SWIGLU, "gate"
elif model_type in (
"glm_moe_dsa",
"deepseek_v3",
@@ -126,6 +129,62 @@ def softmax_topk_routing(
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
def softmax_group_topk_routing(
hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
gate = moe_block.gate
T, H = hidden_states.shape
K = moe_block.top_k
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
n_group = getattr(moe_block, "n_group", 1)
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
scores_for_choice = router_probs
# Group selection: pick top groups, mask the rest
if n_group > 1:
group_scores = (
scores_for_choice.view(-1, n_group, E // n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
)
group_idx = torch.topk(
group_scores, k=moe_block.topk_group, dim=-1, sorted=False
)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
topk_weights = router_probs.gather(1, topk_indices)
# Renormalization + scaling
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
if norm_topk_prob:
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
topk_weights = topk_weights * routed_scaling_factor
# Flatten for moe_general_routing_inputs
token_indices = (
torch.arange(T, device=hidden_states.device, dtype=torch.int32)
.unsqueeze(1)
.expand(T, K)
)
flat_scores = topk_weights.to(torch.float32).reshape(-1) # [T*K]
flat_token_idx = token_indices.reshape(-1) # [T*K]
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
def sigmoid_topk_routing(
hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

View File

@@ -829,8 +829,9 @@ class ModelLoader:
def _set_z3_leaf_modules(self):
from deepspeed.utils import set_z3_leaf_modules
if self.cfg.model_config_type in MOE_ARCH_BLOCK:
moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type]
moe_type = self.cfg.model_config_type_text or self.cfg.model_config_type
if moe_type in MOE_ARCH_BLOCK:
moe_blocks = MOE_ARCH_BLOCK[moe_type]
moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks
set_z3_leaf_modules(
self.model,

View File

@@ -55,12 +55,12 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
)
processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
processor_kwargs["tokenizer"] = tokenizer
processor = processor_cls.from_pretrained(
cfg.processor_config,
**processor_kwargs,
)
processor.tokenizer = tokenizer
# Attempt to load image size from processor if available
if (

View File

@@ -57,6 +57,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"olmo3",
"ministral",
"ministral3",
"mistral4",
"afmoe",
]

View File

@@ -195,6 +195,15 @@ def normalize_config(cfg):
cfg.model_config_type = model_config.model_type
# Resolve inner text backbone type for VLM wrappers (e.g. mistral3 -> mistral4)
if callable(getattr(model_config, "get_text_config", None)):
text_config = model_config.get_text_config()
if (
hasattr(text_config, "model_type")
and text_config.model_type != model_config.model_type
):
cfg.model_config_type_text = text_config.model_type
# figure out if the model is llama
cfg.is_llama_derived_model = (
(