feat: add Mistral Small 4 (#3502)
* feat: add mistral small 4 * fix: update mistral common * fix: deepcopy when passing in tokenizer * feat: add doc on reasoning and thinking section * fix: don't use custom tokenizer and quantize experts * chore: update docs and configs * chore: update doc to follow official name * feat: update cce to include mistral4 * chore: move * fix: naming * fix: test mock breaking get_text_config check * fix: enable CCE and add expert block targetting to configs * chore: docs * fix: use act checkpointing * chore: doc * chore: docs * chore: docs
This commit is contained in:
@@ -30,7 +30,7 @@
|
|||||||
## 🎉 Latest Updates
|
## 🎉 Latest Updates
|
||||||
|
|
||||||
- 2026/03:
|
- 2026/03:
|
||||||
- New model support has been added in Axolotl for [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
|
- New model support has been added in Axolotl for [[Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
|
||||||
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
|
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
|
||||||
- 2026/02:
|
- 2026/02:
|
||||||
- [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.
|
- [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ format:
|
|||||||
- [Pixtral](#sec-pixtral)
|
- [Pixtral](#sec-pixtral)
|
||||||
- [Llava-1.5](#sec-llava-15)
|
- [Llava-1.5](#sec-llava-15)
|
||||||
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
||||||
|
- [Mistral-Small-4](#sec-mistral-small-4)
|
||||||
- [Magistral-Small-2509](#sec-magistral-small-2509)
|
- [Magistral-Small-2509](#sec-magistral-small-2509)
|
||||||
- [Voxtral](#sec-voxtral)
|
- [Voxtral](#sec-voxtral)
|
||||||
- [Gemma-3](#sec-gemma-3)
|
- [Gemma-3](#sec-gemma-3)
|
||||||
@@ -108,6 +109,12 @@ Please make sure to install vision lib via `pip install 'mistral-common[opencv]=
|
|||||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Mistral-Small-4 {#sec-mistral-small-4}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||||
|
```
|
||||||
|
|
||||||
### Magistral-Small-2509 {#sec-magistral-small-2509}
|
### Magistral-Small-2509 {#sec-magistral-small-2509}
|
||||||
|
|
||||||
::: {.callout-tip}
|
::: {.callout-tip}
|
||||||
|
|||||||
@@ -40,7 +40,7 @@
|
|||||||
"%%capture\n",
|
"%%capture\n",
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129\""
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
85
examples/mistral4/README.md
Normal file
85
examples/mistral4/README.md
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
# Finetune Mistral Small 4 with Axolotl
|
||||||
|
|
||||||
|
Mistral Small 4 is a 119B parameter (6.5B active) multimodal MoE model from MistralAI that unifies instruct, reasoning, and coding capabilities into a single model. It is available on HuggingFace at [Mistral-Small-4-119B-2603](https://huggingface.co/mistralai/Mistral-Small-4-119B-2603).
|
||||||
|
|
||||||
|
Thanks to the team at MistralAI for giving us early access to prepare for this release.
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
Note: Training this model requires weights in BF16 which we will link to later.
|
||||||
|
Users interested in training can convert / descale the existing FP8 weights.
|
||||||
|
|
||||||
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|
||||||
|
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||||
|
|
||||||
|
3. Install transformers from main
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install git+https://github.com/huggingface/transformers.git
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Run one of the example configs:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# text-only
|
||||||
|
axolotl train examples/mistral4/qlora-text.yml # no experts ~69 GiB, experts ~93 GiB
|
||||||
|
axolotl train examples/mistral4/fft-text.yml
|
||||||
|
|
||||||
|
# text + vision
|
||||||
|
# run: wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||||
|
axolotl train examples/mistral4/qlora-vision.yml # no experts ~68 GiB
|
||||||
|
axolotl train examples/mistral4/fft-vision.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: FFT configs provided as reference. Please adjust hyperparameters as needed.
|
||||||
|
|
||||||
|
## Reasoning Effort
|
||||||
|
|
||||||
|
The chat template supports a `reasoning_effort` variable to control the model's reasoning depth:
|
||||||
|
|
||||||
|
- `"none"` — instruct mode (default)
|
||||||
|
- `"high"` — reasoning mode with explicit thinking steps
|
||||||
|
|
||||||
|
Pass it via `chat_template_kwargs` under your dataset config:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
datasets:
|
||||||
|
- path: your/dataset
|
||||||
|
type: chat_template
|
||||||
|
chat_template_kwargs:
|
||||||
|
reasoning_effort: high
|
||||||
|
```
|
||||||
|
|
||||||
|
## Thinking Support
|
||||||
|
|
||||||
|
The chat template supports a `thinking` content type in assistant messages for training on reasoning traces (rendered as `[THINK]...[/THINK]` blocks).
|
||||||
|
|
||||||
|
To use thinking datasets, add the `thinking` mapping via `message_property_mappings`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
datasets:
|
||||||
|
- path: your/thinking-dataset
|
||||||
|
type: chat_template
|
||||||
|
message_property_mappings:
|
||||||
|
role: role
|
||||||
|
content: content
|
||||||
|
thinking: thinking
|
||||||
|
chat_template_kwargs:
|
||||||
|
reasoning_effort: high
|
||||||
|
```
|
||||||
|
|
||||||
|
See the [Magistral thinking guide](../magistral/think/README.md) for dataset format details.
|
||||||
|
|
||||||
|
## Tips
|
||||||
|
|
||||||
|
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
|
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||||
|
- The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||||
|
|
||||||
|
## Related Resources
|
||||||
|
|
||||||
|
- [MistralAI Mistral Small 4 Blog](https://mistral.ai/news/mistral-small-4)
|
||||||
|
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||||
|
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
|
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||||
58
examples/mistral4/fft-text.yml
Normal file
58
examples/mistral4/fft-text.yml
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_sonicmoe: true
|
||||||
|
|
||||||
|
# only train language model layers, freeze vision tower
|
||||||
|
unfrozen_parameters:
|
||||||
|
- model.language_model.*
|
||||||
|
- lm_head
|
||||||
|
- embed_tokens
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 2e-5
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_config:
|
||||||
|
offload_params: false
|
||||||
|
cpu_ram_efficient_loading: false
|
||||||
|
state_dict_type: FULL_STATE_DICT
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: Mistral4DecoderLayer
|
||||||
|
reshard_after_forward: true
|
||||||
|
activation_checkpointing: true
|
||||||
57
examples/mistral4/fft-vision.yml
Normal file
57
examples/mistral4/fft-vision.yml
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_sonicmoe: true
|
||||||
|
|
||||||
|
# vision requirements
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: Nanobit/text-vision-2k-test
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 2e-5
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_config:
|
||||||
|
offload_params: false
|
||||||
|
cpu_ram_efficient_loading: false
|
||||||
|
state_dict_type: FULL_STATE_DICT
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: Mistral4DecoderLayer
|
||||||
|
reshard_after_forward: true
|
||||||
|
activation_checkpointing: true
|
||||||
58
examples/mistral4/qlora-text.yml
Normal file
58
examples/mistral4/qlora-text.yml
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
quantize_moe_experts: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
# uncomment to train on expert layers
|
||||||
|
# lora_target_parameters:
|
||||||
|
# - mlp.experts.gate_up_proj
|
||||||
|
# - mlp.experts.down_proj
|
||||||
|
# lora_mlp_kernel: false
|
||||||
|
# lora_qkv_kernel: false
|
||||||
|
# lora_o_kernel: false
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
63
examples/mistral4/qlora-vision.yml
Normal file
63
examples/mistral4/qlora-vision.yml
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
quantize_moe_experts: true
|
||||||
|
|
||||||
|
# vision chat template requirements
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: Nanobit/text-vision-2k-test
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
# uncomment to train on expert layers
|
||||||
|
# lora_target_parameters:
|
||||||
|
# - mlp.experts.gate_up_proj
|
||||||
|
# - mlp.experts.down_proj
|
||||||
|
# lora_mlp_kernel: false
|
||||||
|
# lora_qkv_kernel: false
|
||||||
|
# lora_o_kernel: false
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
@@ -75,4 +75,4 @@ axolotl-contribs-mit==0.0.6
|
|||||||
# telemetry
|
# telemetry
|
||||||
posthog==6.7.11
|
posthog==6.7.11
|
||||||
|
|
||||||
mistral-common==1.8.8
|
mistral-common==1.10.0
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ MOE_ARCH_BLOCK = {
|
|||||||
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
||||||
"deepseek_v2": "DeepseekV2MoE",
|
"deepseek_v2": "DeepseekV2MoE",
|
||||||
"deepseek_v3": "DeepseekV3MoE",
|
"deepseek_v3": "DeepseekV3MoE",
|
||||||
|
"mistral4": "Mistral4MoE",
|
||||||
"gpt_oss": "GptOssDecoderLayer",
|
"gpt_oss": "GptOssDecoderLayer",
|
||||||
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
||||||
"afmoe": "AfmoeMoE",
|
"afmoe": "AfmoeMoE",
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -73,8 +73,10 @@ plugins:
|
|||||||
- ministral3
|
- ministral3
|
||||||
- mistral
|
- mistral
|
||||||
- mistral3
|
- mistral3
|
||||||
|
- mistral4
|
||||||
- mixtral
|
- mixtral
|
||||||
- mllama
|
- mllama
|
||||||
|
- nemotron_h
|
||||||
- olmo
|
- olmo
|
||||||
- olmo2
|
- olmo2
|
||||||
- olmo3
|
- olmo3
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ SPARSE_MOE_BLOCK = {
|
|||||||
"olmoe": "OlmoeSparseMoeBlock",
|
"olmoe": "OlmoeSparseMoeBlock",
|
||||||
"mixtral": "MixtralSparseMoeBlock",
|
"mixtral": "MixtralSparseMoeBlock",
|
||||||
"minimax": "MiniMaxSparseMoeBlock",
|
"minimax": "MiniMaxSparseMoeBlock",
|
||||||
|
# softmax -> topk routing (with group-based expert selection)
|
||||||
|
"mistral4": "Mistral4MoE",
|
||||||
# sigmoid -> topk routing (with group-based expert selection)
|
# sigmoid -> topk routing (with group-based expert selection)
|
||||||
"glm_moe_dsa": "GlmMoeDsaMoE",
|
"glm_moe_dsa": "GlmMoeDsaMoE",
|
||||||
"deepseek_v3": "DeepseekV3MoE",
|
"deepseek_v3": "DeepseekV3MoE",
|
||||||
|
|||||||
@@ -61,9 +61,11 @@ class KernelsPlugin(BasePlugin):
|
|||||||
return "axolotl.integrations.kernels.KernelsArgs"
|
return "axolotl.integrations.kernels.KernelsArgs"
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
def pre_model_load(self, cfg):
|
||||||
|
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
|
||||||
|
|
||||||
if cfg.use_scattermoe:
|
if cfg.use_scattermoe:
|
||||||
self._register_kernels()
|
self._register_kernels()
|
||||||
self._kernelize_model(cfg.model_config_type)
|
self._kernelize_model(moe_model_type)
|
||||||
elif cfg.use_sonicmoe:
|
elif cfg.use_sonicmoe:
|
||||||
if not importlib.util.find_spec("sonicmoe"):
|
if not importlib.util.find_spec("sonicmoe"):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@@ -75,11 +77,9 @@ class KernelsPlugin(BasePlugin):
|
|||||||
|
|
||||||
from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
|
from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
|
||||||
f"Applying SonicMoE patches for model type: {cfg.model_config_type}"
|
|
||||||
)
|
|
||||||
patch_sonicmoe(
|
patch_sonicmoe(
|
||||||
cfg.model_config_type,
|
moe_model_type,
|
||||||
torch_compile=bool(getattr(cfg, "torch_compile", False)),
|
torch_compile=bool(getattr(cfg, "torch_compile", False)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ Different MoE architectures use different routing strategies:
|
|||||||
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
|
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
|
||||||
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
|
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
|
||||||
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
|
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
|
||||||
|
- mistral4: softmax -> group selection -> topk (with renormalization and scaling)
|
||||||
|
|
||||||
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
|
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
|
||||||
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
|
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
|
||||||
@@ -45,6 +46,8 @@ def get_model_moe_config(model_type: str):
|
|||||||
"minimax",
|
"minimax",
|
||||||
):
|
):
|
||||||
return softmax_topk_routing, ActivationType.SWIGLU, "gate"
|
return softmax_topk_routing, ActivationType.SWIGLU, "gate"
|
||||||
|
elif model_type in ("mistral4",):
|
||||||
|
return softmax_group_topk_routing, ActivationType.SWIGLU, "gate"
|
||||||
elif model_type in (
|
elif model_type in (
|
||||||
"glm_moe_dsa",
|
"glm_moe_dsa",
|
||||||
"deepseek_v3",
|
"deepseek_v3",
|
||||||
@@ -126,6 +129,62 @@ def softmax_topk_routing(
|
|||||||
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
||||||
|
|
||||||
|
|
||||||
|
def softmax_group_topk_routing(
|
||||||
|
hidden_states: torch.Tensor, moe_block
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
|
||||||
|
gate = moe_block.gate
|
||||||
|
T, H = hidden_states.shape
|
||||||
|
K = moe_block.top_k
|
||||||
|
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
|
||||||
|
n_group = getattr(moe_block, "n_group", 1)
|
||||||
|
|
||||||
|
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
|
||||||
|
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
|
||||||
|
|
||||||
|
scores_for_choice = router_probs
|
||||||
|
|
||||||
|
# Group selection: pick top groups, mask the rest
|
||||||
|
if n_group > 1:
|
||||||
|
group_scores = (
|
||||||
|
scores_for_choice.view(-1, n_group, E // n_group)
|
||||||
|
.topk(2, dim=-1)[0]
|
||||||
|
.sum(dim=-1)
|
||||||
|
)
|
||||||
|
group_idx = torch.topk(
|
||||||
|
group_scores, k=moe_block.topk_group, dim=-1, sorted=False
|
||||||
|
)[1]
|
||||||
|
group_mask = torch.zeros_like(group_scores)
|
||||||
|
group_mask.scatter_(1, group_idx, 1)
|
||||||
|
score_mask = (
|
||||||
|
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
|
||||||
|
)
|
||||||
|
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
|
||||||
|
|
||||||
|
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
|
||||||
|
topk_weights = router_probs.gather(1, topk_indices)
|
||||||
|
|
||||||
|
# Renormalization + scaling
|
||||||
|
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
|
||||||
|
if norm_topk_prob:
|
||||||
|
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
|
||||||
|
routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
|
||||||
|
topk_weights = topk_weights * routed_scaling_factor
|
||||||
|
|
||||||
|
# Flatten for moe_general_routing_inputs
|
||||||
|
token_indices = (
|
||||||
|
torch.arange(T, device=hidden_states.device, dtype=torch.int32)
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(T, K)
|
||||||
|
)
|
||||||
|
|
||||||
|
flat_scores = topk_weights.to(torch.float32).reshape(-1) # [T*K]
|
||||||
|
flat_token_idx = token_indices.reshape(-1) # [T*K]
|
||||||
|
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
|
||||||
|
|
||||||
|
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
||||||
|
|
||||||
|
|
||||||
def sigmoid_topk_routing(
|
def sigmoid_topk_routing(
|
||||||
hidden_states: torch.Tensor, moe_block
|
hidden_states: torch.Tensor, moe_block
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
|||||||
@@ -829,8 +829,9 @@ class ModelLoader:
|
|||||||
def _set_z3_leaf_modules(self):
|
def _set_z3_leaf_modules(self):
|
||||||
from deepspeed.utils import set_z3_leaf_modules
|
from deepspeed.utils import set_z3_leaf_modules
|
||||||
|
|
||||||
if self.cfg.model_config_type in MOE_ARCH_BLOCK:
|
moe_type = self.cfg.model_config_type_text or self.cfg.model_config_type
|
||||||
moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type]
|
if moe_type in MOE_ARCH_BLOCK:
|
||||||
|
moe_blocks = MOE_ARCH_BLOCK[moe_type]
|
||||||
moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks
|
moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks
|
||||||
set_z3_leaf_modules(
|
set_z3_leaf_modules(
|
||||||
self.model,
|
self.model,
|
||||||
|
|||||||
@@ -55,12 +55,12 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
|
processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
|
||||||
processor_kwargs["tokenizer"] = tokenizer
|
|
||||||
|
|
||||||
processor = processor_cls.from_pretrained(
|
processor = processor_cls.from_pretrained(
|
||||||
cfg.processor_config,
|
cfg.processor_config,
|
||||||
**processor_kwargs,
|
**processor_kwargs,
|
||||||
)
|
)
|
||||||
|
processor.tokenizer = tokenizer
|
||||||
|
|
||||||
# Attempt to load image size from processor if available
|
# Attempt to load image size from processor if available
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"olmo3",
|
"olmo3",
|
||||||
"ministral",
|
"ministral",
|
||||||
"ministral3",
|
"ministral3",
|
||||||
|
"mistral4",
|
||||||
"afmoe",
|
"afmoe",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -195,6 +195,15 @@ def normalize_config(cfg):
|
|||||||
|
|
||||||
cfg.model_config_type = model_config.model_type
|
cfg.model_config_type = model_config.model_type
|
||||||
|
|
||||||
|
# Resolve inner text backbone type for VLM wrappers (e.g. mistral3 -> mistral4)
|
||||||
|
if callable(getattr(model_config, "get_text_config", None)):
|
||||||
|
text_config = model_config.get_text_config()
|
||||||
|
if (
|
||||||
|
hasattr(text_config, "model_type")
|
||||||
|
and text_config.model_type != model_config.model_type
|
||||||
|
):
|
||||||
|
cfg.model_config_type_text = text_config.model_type
|
||||||
|
|
||||||
# figure out if the model is llama
|
# figure out if the model is llama
|
||||||
cfg.is_llama_derived_model = (
|
cfg.is_llama_derived_model = (
|
||||||
(
|
(
|
||||||
|
|||||||
Reference in New Issue
Block a user