Compare commits

68 Commits

Author SHA1 Message Date
Dan Saunders
8564961423 fix compile 2025-09-19 13:59:57 -04:00
Dan Saunders
ce21da9177 fix compile 2025-09-19 13:55:54 -04:00
Dan Saunders
b5dc58373f fix compile 2025-09-19 13:52:42 -04:00
Dan Saunders
7327144344 compile 2025-09-19 13:41:12 -04:00
Dan Saunders
fb11f696e9 bench sweep 2025-09-19 13:24:40 -04:00
Dan Saunders
105c817b0b default fix 2025-09-19 16:59:20 +00:00
Dan Saunders
64345e7707 recurse fix 2025-09-19 12:58:58 -04:00
Dan Saunders
0f8b921399 contig 2025-09-19 12:47:53 -04:00
Dan Saunders
336616d659 defaults 2025-09-19 16:45:39 +00:00
Dan Saunders
d2f1e23bcd fix 2025-09-19 12:45:18 -04:00
Dan Saunders
42aadc5069 bench fix 2025-09-19 12:34:08 -04:00
Dan Saunders
1e7302d30a bench fix 2025-09-19 12:20:35 -04:00
Dan Saunders
63544ce709 fix 2025-09-19 11:34:27 -04:00
Dan Saunders
3bfed0aac8 shared expert detection 2025-09-19 11:24:26 -04:00
Dan Saunders
bfc848f81d bits and pieces 2025-09-19 02:12:57 +00:00
Dan Saunders
abe1cad6bc another bench 2025-09-18 13:45:19 -04:00
Dan Saunders
354389caef torchtitan bench 2025-09-18 13:29:20 -04:00
Dan Saunders
efcd032fce yet another refactor 2025-09-18 13:03:28 -04:00
Dan Saunders
7500641601 yet another refactor 2025-09-18 12:47:15 -04:00
Dan Saunders
0295df5bca precompute fuse 2025-09-18 12:10:46 -04:00
Dan Saunders
b39ef54833 combine mult 2025-09-18 12:08:03 -04:00
Dan Saunders
ad4cd39bcd remove contig 2025-09-18 11:55:15 -04:00
Dan Saunders
5c197275ad inplace 2025-09-18 11:51:17 -04:00
Dan Saunders
19c91e3675 refactor 2025-09-18 11:44:21 -04:00
Dan Saunders
2a176e4923 fix 2025-09-18 11:29:33 -04:00
Dan Saunders
7d867de9b2 refactor 2025-09-18 11:23:15 -04:00
Dan Saunders
01b6792c2e refactor 2025-09-18 11:20:08 -04:00
Dan Saunders
bbf1f14ca4 dtype issues 2025-09-17 23:52:18 +00:00
Dan Saunders
c6878beb7d simplify 2025-09-17 19:15:34 -04:00
Dan Saunders
e62979d11d fix 2025-09-17 18:53:07 -04:00
Dan Saunders
d57b9c67c2 log 2025-09-17 18:52:27 -04:00
Dan Saunders
eaaf16aa00 cumulative offsets 2025-09-17 18:45:15 -04:00
Dan Saunders
f3b953e222 fix? 2025-09-17 18:42:10 -04:00
Dan Saunders
7935dc0911 dtype fix 2025-09-17 18:36:22 -04:00
Dan Saunders
d2b49b2670 error msg 2025-09-17 18:29:30 -04:00
Dan Saunders
b5cb345ca4 fix test 2025-09-17 18:24:00 -04:00
Dan Saunders
03d4c2683e fix perf degradation 2025-09-17 18:20:37 -04:00
Dan Saunders
fd87eed501 minify 2025-09-17 16:42:35 -04:00
Dan Saunders
129db67705 fix 2025-09-17 16:24:29 -04:00
Dan Saunders
38b890a36b fix 2025-09-17 16:16:41 -04:00
Dan Saunders
180920c7bf simplify 2025-09-17 19:49:18 +00:00
Dan Saunders
d024048d74 logs + fix 2025-09-17 14:50:49 -04:00
Dan Saunders
98dc945838 fix 2025-09-17 14:42:53 -04:00
Dan Saunders
108600cd69 update config 2025-09-17 14:36:24 -04:00
Dan Saunders
0e9387c395 fix 2025-09-17 14:35:36 -04:00
Dan Saunders
db61e0d4ff fix 2025-09-17 14:26:25 -04:00
Dan Saunders
51e565f60a logs 2025-09-17 14:15:51 -04:00
Dan Saunders
c774dd0409 refactor + fix 2025-09-17 14:01:39 -04:00
Dan Saunders
7289e0cb55 more logs 2025-09-17 13:44:26 -04:00
Dan Saunders
8d483c11f7 more logs 2025-09-17 13:44:26 -04:00
Dan Saunders
9c1829cf57 more logs 2025-09-17 13:44:26 -04:00
Dan Saunders
135b09d1de logs, qwen2 support 2025-09-17 13:44:26 -04:00
Dan Saunders
de4344a56e patch 2025-09-17 13:44:26 -04:00
Dan Saunders
7d572b58d1 just grouped_mm for now 2025-09-17 13:44:26 -04:00
Dan Saunders
773d7e4291 update 2025-09-17 13:44:26 -04:00
Dan Saunders
fef47a5b7c hardening 2025-09-17 13:44:26 -04:00
Dan Saunders
f6ed8ddc01 fix 2025-09-17 13:44:26 -04:00
Dan Saunders
556d6448fe fix 2025-09-17 13:44:26 -04:00
Dan Saunders
5c2229721d diag 2025-09-17 13:44:26 -04:00
Dan Saunders
d7de6b0e96 grouped_mm 2025-09-17 13:44:26 -04:00
Dan Saunders
3c6648678f numerics 2025-09-17 13:44:26 -04:00
Dan Saunders
5b19a1ea9c improve 2025-09-17 13:44:26 -04:00
Dan Saunders
cfefad1eea fix 2025-09-17 13:44:26 -04:00
Dan Saunders
125e7b5fe6 fast path 2025-09-17 13:44:26 -04:00
Dan Saunders
479b6144df tflops 2025-09-17 13:44:26 -04:00
Dan Saunders
68da65cba2 update 2025-09-17 13:44:26 -04:00
Dan Saunders
0d689bb421 cache, example 2025-09-17 13:44:26 -04:00
Dan Saunders
43ada1278a moe kernels init scaffold 2025-09-17 13:44:26 -04:00
88 changed files with 8797 additions and 2009 deletions

View File

@@ -285,6 +285,7 @@ website:
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- docs/moe_backends.md
- docs/nd_parallelism.qmd
- section: "Troubleshooting"

18
docs/moe_backends.md Normal file
View File

@@ -0,0 +1,18 @@
MoE Backends in Axolotl
Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):
- Set `moe_backend: auto|torch_grouped|naive`
Behavior
- auto (default): prefers PyTorch 2.8+ grouped GEMM; otherwise naive.
- torch_grouped: targets PyTorch 2.8+ grouped GEMM (H100/SM90+ recommended).
- naive: keeps the reference per-expert loop.
Notes
- Current implementation wires the backend selector and routes Mixtral MoE through it. Torch grouped uses cuBLASLt grouped GEMM when available; otherwise, the code falls back to the naive per-expert loop.
- No changes to training scripts are required; selection happens inside the model forward.
Example
moe_backend: torch_grouped
accelerate launch -m axolotl.cli.train path/to/config.yaml
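The fallback behavior described above can be sketched in a few lines of Python. This is only an illustration of the documented rules: `resolve_moe_backend`, its parameters, and `VALID_BACKENDS` are hypothetical names, not Axolotl's actual API, and the caller is assumed to already know whether a grouped GEMM kernel is usable.

```python
from __future__ import annotations

# Hypothetical names for illustration; not Axolotl's real implementation.
VALID_BACKENDS = ("auto", "torch_grouped", "naive")


def _major_minor(version: str) -> tuple[int, int]:
    """Parse strings such as '2.8.0+cu126' into (2, 8)."""
    core = version.split("+", 1)[0]
    major, minor = core.split(".")[:2]
    return int(major), int(minor)


def resolve_moe_backend(requested: str, torch_version: str, grouped_available: bool) -> str:
    """Return the backend that would actually run for a requested setting."""
    if requested not in VALID_BACKENDS:
        raise ValueError(f"moe_backend must be one of {VALID_BACKENDS}, got {requested!r}")
    if requested == "naive":
        return "naive"
    # Both `auto` and `torch_grouped` fall back to the per-expert loop
    # when grouped GEMM cannot be used (assumed to need PyTorch 2.8+).
    supported = grouped_available and _major_minor(torch_version) >= (2, 8)
    return "torch_grouped" if supported else "naive"
```

With this sketch, `auto` on an older PyTorch resolves to `naive`, which matches the documented default behavior.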

View File

@@ -13,7 +13,6 @@ format:
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Magistral-Small-2509](#sec-magistral-small-2509)
- [Voxtral](#sec-voxtral)
- [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n)
@@ -42,6 +41,7 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
# (optional) if doing lora, only finetune the Language model,
# leave the vision model and vision tower frozen
@@ -94,22 +94,10 @@ chat_template: llava
### Mistral-Small-3.1 {#sec-mistral-small-31}
::: {.callout-tip}
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
:::
```yaml
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
```
### Magistral-Small-2509 {#sec-magistral-small-2509}
::: {.callout-tip}
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
:::
```yaml
base_model: mistralai/Magistral-Small-2509
chat_template: mistral_v7_tekken
```
### Voxtral {#sec-voxtral}

View File

@@ -1,110 +0,0 @@
# Finetune Swiss-AI's Apertus with Axolotl
[Apertus](https://huggingface.co/collections/swiss-ai/apertus-llm-68b699e65415c231ace3b059) is a family of open-source models trained by Swiss-AI.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). Apertus support is currently only available on nightly, so install from main or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. (Optional, highly recommended) Install XIELU CUDA
```bash
## Recommended for reduced VRAM and faster speeds
# Point to CUDA toolkit directory
# For those using our Docker image, use the below path.
export CUDA_HOME=/usr/local/cuda
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
```
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
3. Run the finetuning example:
```bash
axolotl train examples/apertus/apertus-8b-qlora.yaml
```
This config uses about 8.7 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### Tips
- For inference, the official Apertus team recommends `top_p=0.9` and `temperature=0.8`.
- You can instead use full parameter fine-tuning by removing `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
### XIELU Installation Issues
#### `ModuleNotFoundError: No module named 'torch'`
Please check these one by one:
- Running in correct environment
- Env has PyTorch installed
- CUDA toolkit is at `CUDA_HOME`
If those didn't help, please try the below solutions:
1. Pass env for CMAKE and try install again:
```bash
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
```
2. Git clone the repo and manually hardcode python path:
```bash
git clone https://github.com/nickjbrowning/XIELU
cd XIELU
git checkout 59d6031
cd xielu
nano CMakeLists.txt # or vi depending on your preference
```
```diff
execute_process(
- COMMAND ${Python_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
+ COMMAND /root/miniconda3/envs/py3.11/bin/python -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT
OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT
ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR
)
```
```bash
pip3 install . --no-build-isolation --no-deps
```
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [Apertus Tech Report](https://github.com/swiss-ai/apertus-tech-report/blob/main/Apertus_Tech_Report.pdf)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,64 +0,0 @@
base_model: swiss-ai/Apertus-8B-Instruct-2509
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -19,9 +19,6 @@ cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:

View File

@@ -9,6 +9,10 @@ strict: false
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_property_mappings:
role: role
content: content
dataset_prepared_path:
val_set_size: 0.05

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation \"axolotl[flash-attn]>=0.9.1\"\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5\""
]
},
{

View File

@@ -9,6 +9,10 @@ strict: false
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_property_mappings:
role: role
content: content
dataset_prepared_path:
val_set_size: 0.05

View File

@@ -9,6 +9,10 @@ strict: false
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_property_mappings:
role: role
content: content
dataset_prepared_path:
val_set_size: 0.05

View File

@@ -18,7 +18,7 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out

View File

@@ -23,15 +23,7 @@ pip3 install timm==1.0.17
pip3 install librosa==0.11.0
```
3. Download sample dataset files
```bash
# for text + vision + audio only
wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg
wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga
```
4. Run the finetuning example:
3. Run the finetuning example:
```bash
# text only

View File

@@ -12,6 +12,15 @@ chat_template: llama3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_property_mappings:
role: role
content: content
roles:
user:
- user
assistant:
- assistant
dataset_prepared_path:
val_set_size: 0.05

View File

@@ -46,6 +46,7 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0

View File

@@ -45,6 +45,7 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0

View File

@@ -1,10 +1,10 @@
# Finetune Magistral Small with Axolotl
Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506), [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)), and [2509](https://huggingface.co/mistralai/Magistral-Small-2509) (see [Vision](#vision)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at [2506](https://huggingface.co/mistralai/Magistral-Small-2506) and [2507](https://huggingface.co/mistralai/Magistral-Small-2507) (see [Thinking](#thinking)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
MistralAI has also released a proprietary medium-sized version called Magistral Medium.
Thanks to the team at MistralAI for giving us early access to prepare for these releases.
Thanks to the team at MistralAI for giving us early access to prepare for this release.
## Getting started
@@ -36,17 +36,29 @@ Let us know how it goes. Happy finetuning! 🚀
### Thinking
MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.
MistralAI has released their [2507](https://huggingface.co/mistralai/Magistral-Small-2507) model with thinking capabilities. The model requires the multi-content dataset format with support for an extra `type: thinking` content entry within system and assistant messages.
📚 **[See the Thinking fine-tuning guide →](./think/README.md)**
Example format:
### Vision
```json
{
"messages": [
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
{"role": "user", "content": [{ "type": "text", "text": "..."}]},
{"role": "assistant", "content": [{ "type": "thinking", "thinking": "..."}, { "type": "text", "text": "..." }]}
]
}
```
MistralAI has released their [2509](https://huggingface.co/mistralai/Magistral-Small-2509) model with vision capabilities.
Example config: `./magistral-small-think-qlora.yaml`.
📚 **[See the Vision fine-tuning guide →](./vision/README.md)**
The `thinking` section also supports an optional arg `closed: bool` (`True` default) which controls adding the closing `[/THINK]` tag.
### Tips
Limitations:
- You cannot mix `content: str` with `content: list[dict]`, as `datasets.load_dataset` may complain about different types for the `content` key.
- This mode does not work with custom `train_detail` and `training` at the moment.
### TIPS
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
@@ -77,5 +89,5 @@ In addition, we do not support overriding tokens yet.
## Future Work
- Add parity to Preference Tuning, RL, etc.
- Add parity to Preference Tuning, RL, Multi-modal, etc.
- Add parity to other tokenizer configs like overriding tokens.

View File

@@ -1,73 +0,0 @@
# Magistral Small Thinking Fine-tuning
This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mistralai/Magistral-Small-2507) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl (see [main README](../README.md))
## Getting Started
Run the thinking model fine-tuning:
```bash
axolotl train magistral-small-think-qlora.yaml
```
This config uses about 19.1 GiB VRAM.
### Tips
- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below.
- Do not mix `content: str` and `content: list[dict]` within a dataset; otherwise dataset loading will fail.
## Dataset Format
The thinking model requires the multi-content dataset format with support for an extra `type: thinking` content entry within system and assistant messages.
Example format:
```json
{
"messages": [
{
"role": "system",
"content": [
{ "type": "text", "text": "{SYSTEM_PROMPT}"}
]
},
{
"role": "user",
"content": [
{ "type": "text", "text": "Solve this step by step: What is 15% of 240?"}
]
},
{
"role": "assistant",
"content": [
{
"type": "thinking",
"thinking": "I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36."
},
{
"type": "text",
"text": "To find 15% of 240, I'll multiply 240 by 0.15:\n\n240 × 0.15 = 36\n\nTherefore, 15% of 240 is 36."
}
]
}
]
}
```
### Advanced Options
The `thinking` section supports an optional `closed` parameter:
```json
{
"type": "thinking",
"thinking": "Internal reasoning here...",
"closed": true // Default: true, controls adding the closing [/THINK] tag
}
```
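As a rough illustration of the format above, the following sketch builds one training sample in Python and checks that every message uses list-style content. The helper names (`make_thinking_sample`, `uses_list_content`, `to_jsonl_line`) are made up for this example and are not part of Axolotl.

```python
from __future__ import annotations

import json


def make_thinking_sample(
    system_prompt: str, question: str, thinking: str, answer: str, closed: bool = True
) -> dict:
    """Build one sample in the multi-content format with a thinking entry."""
    return {
        "messages": [
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
            {"role": "user", "content": [{"type": "text", "text": question}]},
            {
                "role": "assistant",
                "content": [
                    # `closed` is the optional flag described under Advanced Options.
                    {"type": "thinking", "thinking": thinking, "closed": closed},
                    {"type": "text", "text": answer},
                ],
            },
        ]
    }


def uses_list_content(sample: dict) -> bool:
    """True when no message falls back to a plain string for `content`."""
    return all(isinstance(message["content"], list) for message in sample["messages"])


def to_jsonl_line(sample: dict) -> str:
    """Serialize a sample as one line of a JSONL dataset file."""
    return json.dumps(sample, ensure_ascii=False)
```

Running `uses_list_content` over every sample before training is a cheap way to catch the mixed `str` and `list[dict]` case mentioned in the tips.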

View File

@@ -1,60 +0,0 @@
# Magistral Small Vision Fine-tuning
This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mistralai/Magistral-Small-2509) with vision capabilities using Axolotl.
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl from source (see [main README](../README.md#getting-started))
## Getting started
1. Install the required vision lib:
```bash
pip install 'mistral-common[opencv]==1.8.5'
```
2. Download the example dataset image:
```bash
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
```
3. Run the fine-tuning:
```bash
axolotl train magistral-small-vision-24B-qlora.yml
```
This config uses about 17 GiB VRAM.
WARNING: The loss and grad norm will be much higher than normal at first. We suspect this is inherent to the model at the moment. If anyone would like to submit a fix for this, we are happy to take a look.
### Tips
Key differences from text-only model:
- `max_tokens: 131072` for inference
- Multi-modal dataset format required
- Sample packing not supported
## Dataset Format
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
One exception: passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
Example:
```json
{
"messages": [
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
{"role": "user", "content": [
{ "type": "text", "text": "What's in this image?"},
{"type": "image", "path": "path/to/image.jpg" }
]},
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]}
]
}
```
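The image-source restriction above can be checked before training with a small script. This is only a sketch: `find_unsupported_images` is a hypothetical helper, not an Axolotl or mistral-common function, and it assumes the sample layout shown in the example.

```python
from __future__ import annotations

# Image sources the README lists as supported by MistralTokenizer.
SUPPORTED_IMAGE_KEYS = ("path", "url", "base64")


def find_unsupported_images(sample: dict) -> list[dict]:
    """Return image entries that carry none of the supported source keys."""
    unsupported = []
    for message in sample["messages"]:
        for part in message["content"]:
            if part.get("type") != "image":
                continue
            if not any(key in part for key in SUPPORTED_IMAGE_KEYS):
                unsupported.append(part)
    return unsupported
```

An entry such as `{"type": "image", "image": <PIL.Image>}` would be reported by this check, while entries using `path`, `url`, or `base64` pass.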
## Limitations
- Sample Packing is not supported for multi-modality training currently.

View File

@@ -1,64 +0,0 @@
base_model: mistralai/Magistral-Small-2509
processor_type: AutoProcessor
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# sample dataset below requires downloading image in advance
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
datasets:
- path: Nanobit/text-vision-2k-test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
lora_model_dir:
sequence_len: 2048
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,9 +1,6 @@
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
processor_type: AutoProcessor
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
load_in_8bit: true
# these 3 lines are needed for now to handle vision chat templates w images
@@ -11,12 +8,12 @@ skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# sample dataset below requires downloading image in advance
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
chat_template: mistral_v7_tekken
datasets:
- path: Nanobit/text-vision-2k-test
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
@@ -51,7 +48,8 @@ tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
# flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -0,0 +1,53 @@
base_model: Qwen/Qwen1.5-MoE-A2.7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
# Keep VRAM low
load_in_8bit: false
load_in_4bit: true
datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/qwen2-moe-qlora-10gb
# Train small to fit 10GB
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 5
flash_attention: true
warmup_ratio: 0.03
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
model_config:
  output_router_logits: true
special_tokens:

View File

@@ -12,6 +12,15 @@ chat_template: phi_3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_property_mappings:
role: role
content: content
roles:
user:
- user
assistant:
- assistant
dataset_prepared_path:
val_set_size: 0.05

View File

@@ -45,7 +45,8 @@ tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
# flash_attention: # PixtralVisionModel does not support Flash Attention 2.0 yet
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -11,7 +11,7 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

View File

@@ -11,7 +11,7 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

View File

@@ -1,64 +0,0 @@
# Finetune Qwen3-Next with Axolotl
[Qwen3-Next](https://huggingface.co/collections/Qwen/qwen3-next-68c25fd6838e585db8eeea9d) represents the next-generation foundation models optimized for extreme context length and large-scale parameter efficiency. The series introduces architectural innovations including Hybrid Attention (Gated DeltaNet + Gated Attention), High-Sparsity MoE with 1:50 activation ratio, and Multi-Token Prediction for enhanced performance and inference acceleration.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). Qwen3-Next support is currently only available on nightly, so install from main or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Install Qwen3-Next transformers commit
```bash
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
```
3. Install FLA for improved performance
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
```
4. Run the finetuning example:
```bash
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
```
This config uses about 41.7 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. See [Multi-GPU](#optimization-guides) section below.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [Qwen3-Next Blog](https://qwenlm.github.io/blog/qwen3_next/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,60 +0,0 @@
base_model: Qwen/Qwen3-Next-80B-A3B-Instruct
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 16
lora_alpha: 8
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -27,14 +27,7 @@ pip3 install 'mistral_common[audio]==1.8.3'
python scripts/cutcrossentropy_install.py | sh
```
3. Download sample dataset files
```bash
# for text + audio only
wget https://huggingface.co/datasets/Nanobit/text-audio-2k-test/resolve/main/En-us-African_elephant.oga
```
4. Run the finetuning example:
3. Run the finetuning example:
```bash
# text only

View File

@@ -70,4 +70,4 @@ schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.5
mistral-common==1.8.5
mistral-common==1.8.3

209
scripts/bench_moe.py Normal file
View File

@@ -0,0 +1,209 @@
#!/usr/bin/env python
"""Benchmark Hugging Face Qwen2 MoE block with and without grouped_mm."""
from __future__ import annotations
import argparse
import sys
import time
import weakref
from pathlib import Path
import torch
import torch._dynamo as dynamo
try:
    from axolotl.kernels.moe import torch_grouped as tg
except Exception:  # pragma: no cover
    tg = None
def bench(run, *, iters: int, warmup: int, sync: bool = True) -> float:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for _ in range(warmup):
        run()
    if sync and device.type == "cuda":
        torch.cuda.synchronize()
    times = []
    for _ in range(iters):
        if sync and device.type == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        run()
        if sync and device.type == "cuda":
            torch.cuda.synchronize()
        times.append((time.perf_counter() - start) * 1000.0)
    return sum(times) / len(times)
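def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
    # Forward-pass estimate for the routed experts only: three GEMMs per expert
    # (gate, up, down), each about 2 * hidden * inter FLOPs per token.
    return 6.0 * tokens * top_k * hidden * inter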
def load_hf_block(
    hidden: int,
    inter: int,
    experts: int,
    top_k: int,
    *,
    device: torch.device,
    dtype: torch.dtype,
):
    project_root = Path(__file__).resolve().parents[2]
    transformers_src = project_root / "transformers" / "src"
    if transformers_src.exists() and str(transformers_src) not in sys.path:
        sys.path.append(str(transformers_src))
    from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
    from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
    cfg = Qwen2MoeConfig(
        hidden_size=hidden,
        moe_intermediate_size=inter,
        shared_expert_intermediate_size=inter,
        num_experts=experts,
        num_experts_per_tok=top_k,
        norm_topk_prob=True,
        qkv_bias=True,
    )
    block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
    block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
    block_grouped.load_state_dict(block.state_dict())
    return block, block_grouped
def main() -> None:
p = argparse.ArgumentParser(description="Qwen2 MoE grouped_mm benchmark")
p.add_argument("--bsz", type=int, default=8)
p.add_argument("--seq", type=int, default=1024)
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--experts", type=int, default=32)
p.add_argument("--top_k", type=int, default=4)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10)
p.add_argument("--profile", action="store_true")
p.add_argument(
"--compile",
action="store_true",
help="Apply torch.compile to both paths before benchmarking",
)
args = p.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[args.dtype]
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
block_naive, block_grouped = load_hf_block(
args.hidden,
args.inter,
args.experts,
args.top_k,
device=device,
dtype=dtype,
)
tokens = args.bsz * args.seq
flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
f"experts={args.experts} top_k={args.top_k}"
)
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
# Optional torch.compile
run_grouped_impl = None
if args.compile:
dynamo.config.capture_scalar_outputs = True
dynamo.config.allow_unspec_int_on_nn_module = True
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc: # pragma: no cover
print(f"torch.compile naive failed ({exc}); using eager")
else:
def grouped_forward(inp, *, block=block_grouped):
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
inp, block.gate, block.experts, block.top_k
)
return y
try:
run_grouped_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
except Exception as exc: # pragma: no cover
print(f"torch.compile grouped failed ({exc}); using eager")
run_grouped_impl = None
def run_naive(block=block_naive, data=x):
y, _ = block(data)
return y
def run_grouped(block=block_grouped, data=x, impl=run_grouped_impl):
if impl is not None:
return impl(data)
if tg is None or not tg.available():
return torch.empty(0)
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(data, block.gate, block.experts, block.top_k)
return y if y is not None else torch.empty(0)
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
print(
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
)
with torch.no_grad():
y_ref = run_naive()
if tg is None or not tg.available():
print("torch_grouped\tN/A (unavailable)")
return
y_grouped = run_grouped()
if y_grouped.numel() == 0:
print("torch_grouped\tN/A (op not callable)")
return
t_grouped = bench(run_grouped, iters=args.iters, warmup=args.warmup)
tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
speedup = t_naive / t_grouped
print(
f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
)
diff = (y_ref.float() - y_grouped.float()).abs()
print(
"torch_grouped_check: "
f"max_abs={diff.max().item():.3e} mean_abs={diff.mean().item():.3e} "
f"rel_l2={(diff.pow(2).sum() / (y_ref.float().pow(2).sum() + 1e-12)).sqrt().item():.3e}"
)
if args.profile:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
) as prof:
run_naive()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
) as prof:
run_grouped()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
if __name__ == "__main__":
main()
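
As a rough sanity check on the TFLOP/s column, the FLOP estimate used by this benchmark can be evaluated by hand for the default arguments. The sketch below repeats the `6 * tokens * top_k * hidden * inter` formula from `estimate_moe_flops` and converts a latency into TFLOP/s. The `tflops` helper name and the 50 ms latency are illustrative assumptions, not measured results.

```python
def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
    # Three matmuls (gate, up, down) of 2 * hidden * inter FLOPs per routed token.
    return 6.0 * tokens * top_k * hidden * inter


def tflops(flops: float, latency_ms: float) -> float:
    # Convert a FLOP count and a latency in milliseconds to TFLOP/s.
    return flops / ((latency_ms / 1000.0) * 1e12)


# Defaults from the argument parser: bsz=8, seq=1024, hidden=4096, inter=14336, top_k=4.
tokens = 8 * 1024
flops = estimate_moe_flops(tokens, hidden=4096, inter=14336, top_k=4)
print(f"{flops:.3e} FLOPs per forward pass")
# Hypothetical latency, for illustration only.
print(f"{tflops(flops, latency_ms=50.0):.1f} TFLOP/s at 50 ms")
```

The estimate counts only the routed experts, so the reported TFLOP/s understates the work done by a block that also runs a shared expert.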

scripts/bench_moe_sweep.py (new file, 311 lines)

@@ -0,0 +1,311 @@
#!/usr/bin/env python
"""Sweep grouped_mm vs naive performance for Qwen2 MoE block."""
from __future__ import annotations
import argparse
import csv
import sys
import time
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import List
import torch
import torch._dynamo as dynamo
try:
from axolotl.kernels.moe import torch_grouped as tg
except Exception: # pragma: no cover
tg = None
def _parse_list(arg: str) -> List[int]:
return [int(v) for v in arg.split(",") if v]
def _bench(run, *, iters: int, warmup: int, device: torch.device) -> float:
for _ in range(warmup):
run()
if device.type == "cuda":
torch.cuda.synchronize()
times: List[float] = []
for _ in range(iters):
if device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
run()
if device.type == "cuda":
torch.cuda.synchronize()
times.append((time.perf_counter() - start) * 1000.0)
return sum(times) / len(times)
def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
return 6.0 * tokens * top_k * hidden * inter
def _load_block(
hidden: int,
inter: int,
experts: int,
top_k: int,
*,
device: torch.device,
dtype: torch.dtype,
):
project_root = Path(__file__).resolve().parents[2]
transformers_src = project_root / "transformers" / "src"
if transformers_src.exists() and str(transformers_src) not in sys.path:
sys.path.append(str(transformers_src))
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
cfg = Qwen2MoeConfig(
hidden_size=hidden,
moe_intermediate_size=inter,
shared_expert_intermediate_size=inter,
num_experts=experts,
num_experts_per_tok=top_k,
norm_topk_prob=True,
qkv_bias=True,
)
block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped.load_state_dict(block.state_dict())
return block, block_grouped
@dataclass
class Result:
bsz: int
seq: int
hidden: int
inter: int
experts: int
top_k: int
dtype: str
naive_ms: float
grouped_ms: float
speedup: float
naive_tflops: float
grouped_tflops: float
max_abs: float
mean_abs: float
rel_l2: float
def main() -> None:
p = argparse.ArgumentParser(description="Grouped MoE sweep")
p.add_argument("--batch-sizes", default="4,8,16")
p.add_argument("--seq-lens", default="512,1024,2048")
p.add_argument("--hidden", default="2048,4096")
p.add_argument("--inter", default="5632,8192,14336")
p.add_argument("--experts", default="8,16,32")
p.add_argument("--top-k", default="1,2,4")
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=25)
p.add_argument("--warmup", type=int, default=5)
p.add_argument("--csv", type=Path, default=None)
p.add_argument("--compile", action="store_true")
args = p.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[args.dtype]
if tg is None or not tg.available():
print("torch_grouped unavailable; sweep aborted")
return
bs_list = _parse_list(args.batch_sizes)
seq_list = _parse_list(args.seq_lens)
hidden_list = _parse_list(args.hidden)
inter_list = _parse_list(args.inter)
expert_list = _parse_list(args.experts)
topk_list = _parse_list(args.top_k)
results: List[Result] = []
print(
"bsz\tseq\thidden\tinter\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
)
for bsz in bs_list:
for seq in seq_list:
tokens = bsz * seq
for hidden in hidden_list:
for inter in inter_list:
for experts in expert_list:
for top_k in topk_list:
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
block_naive, block_grouped = _load_block(
hidden,
inter,
experts,
top_k,
device=device,
dtype=dtype,
)
x = torch.randn(
bsz, seq, hidden, device=device, dtype=dtype
)
compiled_impl = None
if args.compile:
dynamo.config.capture_scalar_outputs = True
dynamo.config.allow_unspec_int_on_nn_module = True
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc:
print(
f"torch.compile naive failed ({exc}); using eager"
)
else:
def grouped_forward(inp, *, block=block_grouped):
block.experts._ax_parent_block_ref = (
weakref.ref(block)
) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
inp,
block.gate,
block.experts,
block.top_k,
)
return y
try:
compiled_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
except Exception as exc:
print(
f"torch.compile grouped failed ({exc}); using eager"
)
compiled_impl = None
def run_naive(block=block_naive, data=x):
y, _ = block(data)
return y
def run_grouped(
block=block_grouped, data=x, impl=compiled_impl
):
if impl is not None:
return impl(data)
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
data,
block.gate,
block.experts,
block.top_k,
)
return y
naive_ms = _bench(
run_naive,
iters=args.iters,
warmup=args.warmup,
device=device,
)
y_naive = run_naive()
grouped_ms = _bench(
run_grouped,
iters=args.iters,
warmup=args.warmup,
device=device,
)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
res = Result(
bsz,
seq,
hidden,
inter,
experts,
top_k,
args.dtype,
naive_ms,
grouped_ms,
naive_ms / grouped_ms,
_estimate_flops(tokens, hidden, inter, top_k)
/ ((naive_ms / 1000.0) * 1e12),
_estimate_flops(tokens, hidden, inter, top_k)
/ ((grouped_ms / 1000.0) * 1e12),
diff.max().item(),
diff.mean().item(),
(
(
diff.pow(2).sum()
/ (y_naive.float().pow(2).sum() + 1e-12)
)
.sqrt()
.item()
),
)
results.append(res)
print(
f"{bsz}\t{seq}\t{hidden}\t{inter}\t{experts}\t{top_k}\t{res.naive_ms:.2f}\t"
f"{res.grouped_ms:.2f}\t{res.speedup:.2f}\t{res.naive_tflops:.2f}\t"
f"{res.grouped_tflops:.2f}\t{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
)
if args.csv:
fieldnames = [
"bsz",
"seq",
"hidden",
"inter",
"experts",
"top_k",
"dtype",
"naive_ms",
"grouped_ms",
"speedup",
"naive_tflops",
"grouped_tflops",
"max_abs",
"mean_abs",
"rel_l2",
]
with args.csv.open("w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for r in results:
writer.writerow(
{
"bsz": r.bsz,
"seq": r.seq,
"hidden": r.hidden,
"inter": r.inter,
"experts": r.experts,
"top_k": r.top_k,
"dtype": r.dtype,
"naive_ms": f"{r.naive_ms:.4f}",
"grouped_ms": f"{r.grouped_ms:.4f}",
"speedup": f"{r.speedup:.4f}",
"naive_tflops": f"{r.naive_tflops:.4f}",
"grouped_tflops": f"{r.grouped_tflops:.4f}",
"max_abs": f"{r.max_abs:.6e}",
"mean_abs": f"{r.mean_abs:.6e}",
"rel_l2": f"{r.rel_l2:.6e}",
}
)
if __name__ == "__main__":
main()
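
The three correctness columns (`max_abs`, `mean_abs`, `rel_l2`) are computed the same way for every sweep row. A minimal sketch of those formulas on plain Python lists, so it runs without a GPU; the helper name `error_metrics` and the sample values are illustrative and not part of the script.

```python
import math
from typing import Sequence, Tuple


def error_metrics(
    ref: Sequence[float], out: Sequence[float]
) -> Tuple[float, float, float]:
    """Return (max_abs, mean_abs, rel_l2) between a reference and a candidate output."""
    diff = [abs(r - o) for r, o in zip(ref, out)]
    max_abs = max(diff)
    mean_abs = sum(diff) / len(diff)
    # Relative L2 error; the small epsilon guards against an all-zero reference.
    rel_l2 = math.sqrt(sum(d * d for d in diff) / (sum(r * r for r in ref) + 1e-12))
    return max_abs, mean_abs, rel_l2


ref = [3.0, 4.0]
out = [3.0, 4.5]
print(error_metrics(ref, out))
```

`rel_l2` normalizes by the magnitude of the reference output, which makes it easier to compare across hidden sizes than the absolute errors.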


@@ -0,0 +1,205 @@
#!/usr/bin/env python
"""Benchmark Torchtitan MoE grouped vs naive expert execution."""
from __future__ import annotations
import argparse
import sys
import time
from pathlib import Path
import torch
# Ensure torchtitan is importable when running from the axolotl tree
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_TITAN_PATH = _PROJECT_ROOT / "torchtitan"
if str(_TITAN_PATH) not in sys.path:
sys.path.append(str(_TITAN_PATH))
from torchtitan.models.moe import MoE, MoEArgs
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Torchtitan MoE microbenchmark")
p.add_argument("--bsz", type=int, default=8)
p.add_argument("--seq", type=int, default=1024)
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--experts", type=int, default=8)
p.add_argument("--top_k", type=int, default=2)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10)
p.add_argument("--init-std", type=float, default=0.02)
p.add_argument(
"--score-before",
action="store_true",
help="Apply routing scores before expert computation (default: after)",
)
p.add_argument(
"--score-func",
choices=["softmax", "sigmoid"],
default="softmax",
)
p.add_argument(
"--route-norm",
action="store_true",
help="Enable Torchtitan router normalization when using sigmoid scores.",
)
return p.parse_args()
def _map_dtype(arg: str) -> torch.dtype:
return {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[arg]
def _estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
# Gate and up projections plus one down projection per (token, expert) pair; each matmul costs 2 * hidden * inter FLOPs.
return 6.0 * tokens * top_k * hidden * inter
def _prepare_module(
moe: MoE,
*,
device: torch.device,
dtype: torch.dtype,
) -> MoE:
moe = moe.to(device=device)
for param in moe.parameters():
param.data = param.data.to(dtype)
if param.grad is not None:
param.grad = None
buffers = dict(moe.named_buffers())
for name, buf in buffers.items():
if name == "tokens_per_expert":
moe._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
elif name == "expert_bias" and buf is not None:
moe._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
else:
moe._buffers[name] = buf.to(device=device, dtype=dtype)
moe.eval()
return moe
@torch.inference_mode()
def _forward_fn(module: MoE, x: torch.Tensor) -> torch.Tensor:
return module(x)
def _bench(fn, *, iters: int, warmup: int, sync: bool = True) -> float:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
for _ in range(warmup):
fn()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times = []
for _ in range(iters):
if sync and device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
fn()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times.append((time.perf_counter() - start) * 1000.0)
return sum(times) / len(times)
def main() -> None:
args = _parse_args()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = _map_dtype(args.dtype)
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
moe_args_grouped = MoEArgs(
num_experts=args.experts,
num_shared_experts=0,
score_func=args.score_func,
route_norm=args.route_norm,
top_k=args.top_k,
use_grouped_mm=True,
score_before_experts=args.score_before,
load_balance_coeff=None,
)
moe_grouped = MoE(moe_args_grouped, dim=args.hidden, hidden_dim=args.inter)
moe_grouped.init_weights(args.init_std, buffer_device=device)
moe_args_naive = MoEArgs(
num_experts=args.experts,
num_shared_experts=0,
score_func=args.score_func,
route_norm=args.route_norm,
top_k=args.top_k,
use_grouped_mm=False,
score_before_experts=args.score_before,
load_balance_coeff=None,
)
moe_naive = MoE(moe_args_naive, dim=args.hidden, hidden_dim=args.inter)
moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True)
moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype)
moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype)
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
tokens = args.bsz * args.seq
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} "
f"inter={args.inter} experts={args.experts} top_k={args.top_k}"
)
def run_naive():
return _forward_fn(moe_naive, x)
def run_grouped():
return _forward_fn(moe_grouped, x)
if hasattr(moe_naive, "tokens_per_expert"):
moe_naive.tokens_per_expert.zero_()
if hasattr(moe_grouped, "tokens_per_expert"):
moe_grouped.tokens_per_expert.zero_()
t_naive = _bench(run_naive, iters=args.iters, warmup=args.warmup)
flops = _estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
tflops_naive = flops / ((t_naive / 1000.0) * 1e12)
print(
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t"
f"{tflops_naive:.2f} TFLOP/s"
)
y_naive = run_naive()
if hasattr(moe_grouped, "tokens_per_expert"):
moe_grouped.tokens_per_expert.zero_()
t_grouped = _bench(run_grouped, iters=args.iters, warmup=args.warmup)
tflops_grouped = flops / ((t_grouped / 1000.0) * 1e12)
speedup = t_naive / t_grouped if t_grouped > 0 else float("nan")
print(
f"grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
print(
f"grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
)
if __name__ == "__main__":
main()
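
The `_bench` helper reports a mean over `iters` timed calls after `warmup` untimed ones. The sketch below shows the same warmup-then-measure pattern on CPU-only code and also reports the median, which is less sensitive to a single slow iteration. `bench_stats` is a hypothetical name, and the CUDA synchronization calls are left out because nothing here runs on a GPU.

```python
import statistics
import time
from typing import Callable, Dict


def bench_stats(
    fn: Callable[[], object], *, iters: int, warmup: int
) -> Dict[str, float]:
    """Time ``fn`` and return mean and median latency in milliseconds."""
    for _ in range(warmup):
        fn()  # untimed warmup calls
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        times.append((time.perf_counter() - start) * 1000.0)
    return {
        "mean_ms": sum(times) / len(times),
        "median_ms": statistics.median(times),
    }


stats = bench_stats(lambda: sum(range(10_000)), iters=20, warmup=3)
print(stats)
```

On a GPU the `torch.cuda.synchronize()` calls around the timed region are still required, as in the original helper, because CUDA kernels launch asynchronously.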


@@ -0,0 +1,328 @@
#!/usr/bin/env python
"""Sweep Torchtitan MoE grouped vs naive configurations and report performance."""
from __future__ import annotations
import argparse
import csv
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List
import torch
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_TITAN_PATH = _PROJECT_ROOT / "torchtitan"
if str(_TITAN_PATH) not in sys.path:
sys.path.append(str(_TITAN_PATH))
from torchtitan.models.moe import MoE, MoEArgs
def _parse_int_list(value: str) -> List[int]:
return [int(v) for v in value.split(",") if v]
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Torchtitan MoE grouped vs naive sweep")
p.add_argument(
"--batch-sizes", default="4,8,16", help="Comma separated batch sizes"
)
p.add_argument(
"--seq-lens", default="1024,2048", help="Comma separated sequence lengths"
)
p.add_argument(
"--experts", default="8,16,32,64", help="Comma separated expert counts"
)
p.add_argument("--top-ks", default="1,2,4", help="Comma separated top_k choices")
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=25)
p.add_argument("--warmup", type=int, default=5)
p.add_argument("--init-std", type=float, default=0.02)
p.add_argument("--score-before", action="store_true")
p.add_argument("--score-func", choices=["softmax", "sigmoid"], default="softmax")
p.add_argument("--route-norm", action="store_true")
p.add_argument("--csv", type=Path, default=None, help="Optional CSV output path")
return p.parse_args()
def _map_dtype(arg: str) -> torch.dtype:
return {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[arg]
def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
return 6.0 * tokens * top_k * hidden * inter
def _prepare_module(module: MoE, *, device: torch.device, dtype: torch.dtype) -> MoE:
module = module.to(device=device)
for param in module.parameters():
param.data = param.data.to(dtype)
if param.grad is not None:
param.grad = None
# Snapshot the buffers first because module._buffers is reassigned inside the loop.
for name, buf in list(module.named_buffers()):
if name == "tokens_per_expert":
module._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
elif name == "expert_bias" and buf is not None:
module._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
else:
module._buffers[name] = buf.to(device=device, dtype=dtype)
module.eval()
return module
@torch.inference_mode()
def _forward(module: MoE, x: torch.Tensor) -> torch.Tensor:
return module(x)
def _bench(callable_, *, iters: int, warmup: int, device: torch.device) -> float:
for _ in range(warmup):
callable_()
if device.type == "cuda":
torch.cuda.synchronize()
timings: List[float] = []
for _ in range(iters):
if device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
callable_()
if device.type == "cuda":
torch.cuda.synchronize()
timings.append((time.perf_counter() - start) * 1000.0)
return sum(timings) / len(timings)
@dataclass
class SweepResult:
bsz: int
seq: int
experts: int
top_k: int
dtype: str
naive_ms: float
grouped_ms: float
speedup: float
naive_tflops: float
grouped_tflops: float
max_abs: float
mean_abs: float
rel_l2: float
def _run_case(
*,
bsz: int,
seq: int,
experts: int,
top_k: int,
hidden: int,
inter: int,
dtype: torch.dtype,
device: torch.device,
iters: int,
warmup: int,
init_std: float,
score_before: bool,
score_func: str,
route_norm: bool,
) -> SweepResult:
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
moe_args_grouped = MoEArgs(
num_experts=experts,
num_shared_experts=0,
score_func=score_func,
route_norm=route_norm,
top_k=top_k,
use_grouped_mm=True,
score_before_experts=score_before,
load_balance_coeff=None,
)
moe_grouped = MoE(moe_args_grouped, dim=hidden, hidden_dim=inter)
moe_grouped.init_weights(init_std, buffer_device=device)
moe_args_naive = MoEArgs(
num_experts=experts,
num_shared_experts=0,
score_func=score_func,
route_norm=route_norm,
top_k=top_k,
use_grouped_mm=False,
score_before_experts=score_before,
load_balance_coeff=None,
)
moe_naive = MoE(moe_args_naive, dim=hidden, hidden_dim=inter)
moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True)
moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype)
moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype)
x = torch.randn(bsz, seq, hidden, device=device, dtype=dtype)
def run_naive():
if hasattr(moe_naive, "tokens_per_expert"):
moe_naive.tokens_per_expert.zero_()
return _forward(moe_naive, x)
def run_grouped():
if hasattr(moe_grouped, "tokens_per_expert"):
moe_grouped.tokens_per_expert.zero_()
return _forward(moe_grouped, x)
naive_ms = _bench(run_naive, iters=iters, warmup=warmup, device=device)
y_naive = run_naive()
grouped_ms = _bench(run_grouped, iters=iters, warmup=warmup, device=device)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
tokens = bsz * seq
flops = _estimate_flops(tokens, hidden, inter, top_k)
naive_tflops = flops / ((naive_ms / 1000.0) * 1e12)
grouped_tflops = flops / ((grouped_ms / 1000.0) * 1e12)
speedup = naive_ms / grouped_ms if grouped_ms > 0 else float("nan")
return SweepResult(
bsz=bsz,
seq=seq,
experts=experts,
top_k=top_k,
dtype=str(dtype),
naive_ms=naive_ms,
grouped_ms=grouped_ms,
speedup=speedup,
naive_tflops=naive_tflops,
grouped_tflops=grouped_tflops,
max_abs=max_abs,
mean_abs=mean_abs,
rel_l2=rel_l2,
)
def _print_header(
hidden: int, inter: int, dtype: torch.dtype, device: torch.device
) -> None:
print(f"Device={device} dtype={dtype} hidden={hidden} inter={inter}")
print(
"bsz\tseq\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
)
def _print_result(res: SweepResult) -> None:
print(
f"{res.bsz}\t{res.seq}\t{res.experts}\t{res.top_k}\t"
f"{res.naive_ms:.2f}\t{res.grouped_ms:.2f}\t{res.speedup:.2f}\t"
f"{res.naive_tflops:.2f}\t{res.grouped_tflops:.2f}\t"
f"{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
)
def _write_csv(path: Path, results: Iterable[SweepResult]) -> None:
fieldnames = [
"batch_size",
"seq_len",
"experts",
"top_k",
"dtype",
"naive_ms",
"grouped_ms",
"speedup",
"naive_tflops",
"grouped_tflops",
"max_abs",
"mean_abs",
"rel_l2",
]
with path.open("w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for r in results:
writer.writerow(
{
"batch_size": r.bsz,
"seq_len": r.seq,
"experts": r.experts,
"top_k": r.top_k,
"dtype": r.dtype,
"naive_ms": f"{r.naive_ms:.4f}",
"grouped_ms": f"{r.grouped_ms:.4f}",
"speedup": f"{r.speedup:.4f}",
"naive_tflops": f"{r.naive_tflops:.4f}",
"grouped_tflops": f"{r.grouped_tflops:.4f}",
"max_abs": f"{r.max_abs:.6e}",
"mean_abs": f"{r.mean_abs:.6e}",
"rel_l2": f"{r.rel_l2:.6e}",
}
)
def main() -> None:
args = _parse_args()
dtype = _map_dtype(args.dtype)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch_sizes = _parse_int_list(args.batch_sizes)
seq_lens = _parse_int_list(args.seq_lens)
experts_list = _parse_int_list(args.experts)
top_ks = _parse_int_list(args.top_ks)
results: List[SweepResult] = []
_print_header(args.hidden, args.inter, dtype, device)
for bsz in batch_sizes:
for seq in seq_lens:
for experts in experts_list:
for top_k in top_ks:
try:
res = _run_case(
bsz=bsz,
seq=seq,
experts=experts,
top_k=top_k,
hidden=args.hidden,
inter=args.inter,
dtype=dtype,
device=device,
iters=args.iters,
warmup=args.warmup,
init_std=args.init_std,
score_before=args.score_before,
score_func=args.score_func,
route_norm=args.route_norm,
)
except RuntimeError as err:
print(
f"{bsz}\t{seq}\t{experts}\t{top_k}\tERROR: {err}",
file=sys.stderr,
)
continue
results.append(res)
_print_result(res)
if args.csv and results:
_write_csv(args.csv, results)
print(f"Wrote {len(results)} rows to {args.csv}")
if __name__ == "__main__":
main()


@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
- + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"'
+ + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
)


@@ -0,0 +1,53 @@
#!/usr/bin/env python
"""Inspect Qwen2 MoE expert implementations for grouped-mm debugging."""
from __future__ import annotations
import sys
from pathlib import Path
import torch
ROOT = Path(__file__).resolve().parents[2]
sys.path.extend(
[
str(ROOT / "transformers" / "src"),
str(ROOT / "src"),
]
)
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from axolotl.kernels.moe.torch_grouped import _iter_expert_impls
def main() -> None:
cfg = Qwen2MoeConfig(
hidden_size=4096,
moe_intermediate_size=14336,
shared_expert_intermediate_size=14336,
num_experts=32,
num_experts_per_tok=4,
)
block = Qwen2MoeSparseMoeBlock(cfg).to("cuda", dtype=torch.bfloat16)
experts = block.experts
experts._ax_parent_block = block
impls = _iter_expert_impls(experts)
print(f"impl count: {len(impls)}")
for idx, impl in enumerate(impls[:8]):
has_gate = hasattr(impl, "gate_proj")
has_up = hasattr(impl, "up_proj")
print(
f"impl[{idx}] type={impl.__class__.__name__} has_gate={has_gate} has_up={has_up}"
)
if has_gate:
print(f" gate shape {tuple(impl.gate_proj.weight.shape)}")
print(f" up shape {tuple(impl.up_proj.weight.shape)}")
print(f" down shape {tuple(impl.down_proj.weight.shape)}")
if __name__ == "__main__":
main()


@@ -0,0 +1,47 @@
#!/usr/bin/env python
"""
Probe PyTorch for grouped GEMM operator names and namespaces.
Run: python scripts/probe_torch_grouped_ops.py
"""
import sys


def main():
    try:
        import torch
    except Exception as e:
        print("Failed to import torch:", e)
        sys.exit(1)
    print("torch version:", torch.__version__)
    namespaces = [n for n in dir(torch.ops) if not n.startswith("_")]
    print("ops namespaces:", namespaces)
    found_any = False
    for ns in namespaces:
        obj = getattr(torch.ops, ns, None)
        ops = []
        if obj is not None:
            try:
                ops = dir(obj)
            except Exception as e:
                print(f"warning: failed to list ops for namespace {ns}: {e}")
        # "group" also matches "grouped", "mm_grouped", and "matmul_grouped".
        cands = [o for o in ops if "group" in o.lower()]
        if cands:
            found_any = True
            print(f"namespace {ns} candidates:", cands)
    if not found_any:
        print("No grouped GEMM candidates found. PyTorch >= 2.8 is recommended.")


if __name__ == "__main__":
    main()


@@ -124,6 +124,7 @@ extras_require = {
"ring-flash-attn": [
"flash-attn==2.8.3",
"ring-flash-attn>=0.1.7",
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.17.5",


@@ -120,11 +120,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
- if self.cfg.max_prompt_len:
- training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
- else:
- training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl is RLType.SIMPO:
@@ -134,16 +129,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
- # Handle when max_prompt_length == max_length from defaults
- # CPOTrainer requires strictly less than
- if (
- training_args_kwargs["max_prompt_length"]
- == training_args_kwargs["max_length"]
- ):
- training_args_kwargs["max_prompt_length"] -= 1
elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig
+ if self.cfg.max_prompt_len:
+ training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig
@@ -155,6 +144,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.kto_undesirable_weight or 1.0
)
+ if self.cfg.max_prompt_len:
+ training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl is RLType.GRPO:
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))


@@ -8,7 +8,7 @@ from typing import Any, Mapping
def chat_message_transform_builder(
train_on_inputs=False,
- conversations_field: str = "messages",
+ conversations_field: str = "conversations",
message_field_role: str | list[str] | None = None, # commonly "role"
message_field_content: str | list[str] | None = None, # commonly "content"
message_field_training: str | list[str] | None = None, # commonly "weight"
@@ -20,13 +20,13 @@ def chat_message_transform_builder(
If True, the transform will train on the inputs. If False, the transform will train on the targets.
Defaults to False.
conversations_field (str, optional):
- The field name of the conversations. Defaults to "messages".
+ The field name of the conversations. Defaults to "conversations".
message_field_role (str | list[str], optional):
- The field name of the role.
+ The field name of the role. Defaults to "role".
message_field_content (str | list[str], optional):
- The field name of the message content.
+ The field name of the message content. Defaults to "content".
message_field_training (str | list[str], optional):
- The field name of the train/weight.
+ The field name of the train/weight. Defaults to "weight".
Returns:
Callable:


@@ -27,6 +27,7 @@ class DPOStrategy:
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_length"] = cfg.sequence_len
+ training_args_kwargs["max_prompt_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting


@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
- pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"
+ pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
```
## Usage
@@ -65,7 +65,6 @@ plugins:
- qwen2_5_vl
- qwen3
- qwen3_moe
- qwen3_next
- smollm3
- seed_oss
- voxtral


@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
- '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"`'
+ '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"`'
)


@@ -0,0 +1,3 @@
from .backends import MOEBackend, get_moe_backend_name
__all__ = ["get_moe_backend_name", "MOEBackend"]


@@ -0,0 +1,47 @@
import warnings
from enum import Enum
class MOEBackend(str, Enum):
AUTO = "auto"
TORCH_GROUPED = "torch_grouped"
NAIVE = "naive"
def _probe_torch_grouped() -> bool:
try:
import torch # noqa: F401
# Prefer a simple version check; exact APIs may vary across 2.8+.
ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
return ver >= (2, 8)
except Exception:
return False
def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
"""
Resolve the desired MoE backend using, in order of precedence:
- explicit preferred argument (e.g., from config)
- auto detection
"""
choice = (preferred or "auto").lower()
try:
selected = MOEBackend(choice)
except ValueError:
warnings.warn(
f"Unknown moe backend '{choice}', falling back to auto", stacklevel=2
)
selected = MOEBackend.AUTO
if selected == MOEBackend.AUTO:
if _probe_torch_grouped():
return MOEBackend.TORCH_GROUPED
return MOEBackend.NAIVE
if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
warnings.warn(
"torch_grouped requested but torch>=2.8 not detected; falling back to naive",
stacklevel=2,
)
return MOEBackend.NAIVE
return selected
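
The precedence described in the `get_moe_backend_name` docstring can be sketched without importing torch. This is an illustrative reduction rather than the module's code: `Backend`, `resolve_backend`, and the `grouped_available` flag are hypothetical names, the flag stands in for the result of `_probe_torch_grouped()`, and the sketch leaves out the warning the original emits when `torch_grouped` is requested but unavailable.

```python
import warnings
from enum import Enum


class Backend(str, Enum):
    """Hypothetical stand-in for MOEBackend."""

    AUTO = "auto"
    TORCH_GROUPED = "torch_grouped"
    NAIVE = "naive"


def resolve_backend(preferred, grouped_available):
    """Resolve a backend from an optional config string and a capability flag."""
    choice = (preferred or "auto").lower()
    try:
        selected = Backend(choice)
    except ValueError:
        # Unknown names are treated like "auto" after a warning.
        warnings.warn(f"Unknown moe backend '{choice}', falling back to auto", stacklevel=2)
        selected = Backend.AUTO
    if selected in (Backend.AUTO, Backend.TORCH_GROUPED):
        # Both cases end up on the grouped path only when it is usable.
        return Backend.TORCH_GROUPED if grouped_available else Backend.NAIVE
    return selected
```

Passing the capability flag as an argument keeps the resolution logic testable on machines without a GPU or a recent torch build.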

@@ -0,0 +1,371 @@
"""Minimal grouped GEMM fast path for MoE experts using PyTorch _grouped_mm."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
_LOGGER = logging.getLogger("axolotl.moe.grouped")
def available() -> bool:
try:
major, minor = map(int, torch.__version__.split("+")[0].split(".")[:2])
if (major, minor) < (2, 8):
return False
if not torch.cuda.is_available():
return False
sm, _ = torch.cuda.get_device_capability()
if sm < 9:
return False
return hasattr(torch, "_grouped_mm")  # torch.ops resolves arbitrary attribute names, so probe the callable used below
except Exception:
return False
def _iter_expert_impls(
experts_module, visited: Optional[set[int]] = None
) -> List[torch.nn.Module]:
if visited is None:
visited = set()
module_id = id(experts_module)
if module_id in visited:
return []
visited.add(module_id)
impls: List[torch.nn.Module] = []
for exp in experts_module:
candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp))
if hasattr(candidate, "gate_proj") and hasattr(candidate, "up_proj"):
impls.append(candidate)
continue
nested = getattr(candidate, "experts", None)
if nested is not None:
impls.extend(_iter_expert_impls(nested, visited))
continue
raise RuntimeError(
"torch_grouped: unable to resolve expert implementation for module"
)
return impls
@dataclass
class _GroupedWeightStorage:
pattern: str
gate: torch.Tensor
up: torch.Tensor
down: torch.Tensor
fused_gate_up: torch.Tensor
dtype: torch.dtype
device: torch.device
def _allocate_fused_gate_up(
num_experts: int,
gate_shape: torch.Size,
up_shape: torch.Size,
*,
device: torch.device,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if gate_shape[1] != up_shape[1]:
raise RuntimeError(
"torch_grouped: gate and up projections must share the hidden dimension"
)
fused = torch.empty(
(num_experts, gate_shape[0] + up_shape[0], gate_shape[1]),
device=device,
dtype=dtype,
)
gate_view = fused[:, : gate_shape[0]]
up_view = fused[:, gate_shape[0] : gate_shape[0] + up_shape[0]]
return fused, gate_view, up_view
def _ensure_grouped_weights(
experts_module, expert_impls: List[torch.nn.Module], sample_mod: torch.nn.Module
) -> _GroupedWeightStorage:
storage: Optional[_GroupedWeightStorage] = getattr(
experts_module, "_ax_grouped_storage", None
)
def _store(new_storage: _GroupedWeightStorage) -> _GroupedWeightStorage:
experts_module._ax_grouped_storage = new_storage
return new_storage
# Identify expert parameter layout
if (
hasattr(sample_mod, "w1")
and hasattr(sample_mod, "w3")
and hasattr(sample_mod, "w2")
):
pattern = "swi_glu"
num_experts = len(expert_impls)
w1_shape = sample_mod.w1.weight.shape
w3_shape = sample_mod.w3.weight.shape
w2_shape = sample_mod.w2.weight.shape
if (
storage is not None
and storage.pattern == pattern
and storage.dtype == sample_mod.w1.weight.dtype
and storage.device == sample_mod.w1.weight.device
and storage.gate.shape[1:] == w1_shape
):
return storage
fused, gate, up = _allocate_fused_gate_up(
num_experts,
w1_shape,
w3_shape,
device=sample_mod.w1.weight.device,
dtype=sample_mod.w1.weight.dtype,
)
down = torch.empty(
(num_experts, *w2_shape),
device=sample_mod.w2.weight.device,
dtype=sample_mod.w2.weight.dtype,
)
with torch.no_grad():
for idx, mod in enumerate(expert_impls):
gate[idx].copy_(mod.w1.weight.detach())
up[idx].copy_(mod.w3.weight.detach())
down[idx].copy_(mod.w2.weight.detach())
mod.w1.weight.detach_()
mod.w1.weight.set_(gate[idx])
mod.w3.weight.detach_()
mod.w3.weight.set_(up[idx])
mod.w2.weight.detach_()
mod.w2.weight.set_(down[idx])
return _store(
_GroupedWeightStorage(
pattern=pattern,
gate=gate,
up=up,
down=down,
fused_gate_up=fused,
dtype=gate.dtype,
device=gate.device,
)
)
if hasattr(sample_mod, "gate_up_proj") and hasattr(sample_mod, "down_proj"):
pattern = "fused_gate_up"
gate_weight = sample_mod.gate_up_proj.weight
down_weight = sample_mod.down_proj.weight
if (
storage is not None
and storage.pattern == pattern
and storage.dtype == gate_weight.dtype
and storage.device == gate_weight.device
and storage.gate.shape[1:]
== (gate_weight.shape[0] // 2, gate_weight.shape[1])
):
return storage
num_experts = len(expert_impls)
gate_full = torch.empty(
(num_experts, *gate_weight.shape),
device=gate_weight.device,
dtype=gate_weight.dtype,
)
down = torch.empty(
(num_experts, *down_weight.shape),
device=down_weight.device,
dtype=down_weight.dtype,
)
with torch.no_grad():
for idx, mod in enumerate(expert_impls):
gate_full[idx].copy_(mod.gate_up_proj.weight.detach())
down[idx].copy_(mod.down_proj.weight.detach())
mod.gate_up_proj.weight.detach_()
mod.gate_up_proj.weight.set_(gate_full[idx])
mod.down_proj.weight.detach_()
mod.down_proj.weight.set_(down[idx])
inter = gate_weight.shape[0] // 2
gate = gate_full[:, :inter]
up = gate_full[:, inter:]
return _store(
_GroupedWeightStorage(
pattern=pattern,
gate=gate,
up=up,
down=down,
fused_gate_up=gate_full,
dtype=gate.dtype,
device=gate.device,
)
)
if (
hasattr(sample_mod, "up_proj")
and hasattr(sample_mod, "gate_proj")
and hasattr(sample_mod, "down_proj")
):
pattern = "dual_proj"
up_weight = sample_mod.up_proj.weight
gate_weight = sample_mod.gate_proj.weight
down_weight = sample_mod.down_proj.weight
if (
storage is not None
and storage.pattern == pattern
and storage.dtype == sample_mod.up_proj.weight.dtype
and storage.device == sample_mod.up_proj.weight.device
and storage.gate.shape[1:] == gate_weight.shape
):
return storage
num_experts = len(expert_impls)
fused, gate, up = _allocate_fused_gate_up(
num_experts,
gate_weight.shape,
up_weight.shape,
device=gate_weight.device,
dtype=gate_weight.dtype,
)
down = torch.empty(
(num_experts, *down_weight.shape),
device=down_weight.device,
dtype=down_weight.dtype,
)
with torch.no_grad():
for idx, mod in enumerate(expert_impls):
gate[idx].copy_(mod.gate_proj.weight.detach())
up[idx].copy_(mod.up_proj.weight.detach())
down[idx].copy_(mod.down_proj.weight.detach())
mod.up_proj.weight.detach_()
mod.up_proj.weight.set_(up[idx])
mod.gate_proj.weight.detach_()
mod.gate_proj.weight.set_(gate[idx])
mod.down_proj.weight.detach_()
mod.down_proj.weight.set_(down[idx])
return _store(
_GroupedWeightStorage(
pattern=pattern,
gate=gate,
up=up,
down=down,
fused_gate_up=fused,
dtype=gate.dtype,
device=gate.device,
)
)
raise RuntimeError(
"torch_grouped: unsupported expert module layout for grouped weights"
)
def moe_ffn_forward_grouped(
hidden_states: torch.Tensor,
gate_linear: torch.nn.Linear,
experts_module,
top_k: int,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
if not available():
return None, None
bsz, seqlen, hdim = hidden_states.shape
tokens = bsz * seqlen
device = hidden_states.device
routing_dtype = gate_linear.weight.dtype
expert_dtype = hidden_states.dtype
if expert_dtype not in (torch.bfloat16, torch.float16):
_LOGGER.debug(
"torch_grouped: unsupported expert dtype %s; falling back to naive",
expert_dtype,
)
return None, None
parent_block = None
parent_ref = getattr(experts_module, "_ax_parent_block_ref", None)
if parent_ref is not None:
try:
parent_block = parent_ref()
except TypeError:
parent_block = None
expert_container = getattr(experts_module, "experts", experts_module)
expert_impls = _iter_expert_impls(expert_container)
sample_mod = expert_impls[0]
storage = _ensure_grouped_weights(expert_container, expert_impls, sample_mod)
w_gate = storage.gate
w_up = storage.up
w2 = storage.down
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
router_logits = gate_linear(x_flat.to(routing_dtype))
shared_out_flat: Optional[torch.Tensor] = None
shared_owner = parent_block if parent_block is not None else experts_module
if hasattr(shared_owner, "shared_expert"):
shared_expert = shared_owner.shared_expert
shared_out_flat = shared_expert(x_flat)
shared_out_flat = shared_out_flat.to(expert_dtype)
shared_gate = getattr(shared_owner, "shared_expert_gate", None)
if shared_gate is not None:
gate_input = shared_gate(x_flat.to(shared_gate.weight.dtype))
gate_vals = torch.sigmoid(gate_input)
shared_out_flat.mul_(gate_vals.to(expert_dtype))
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
flat_idx = topk_idx.view(-1)
num_experts = len(expert_impls)
if flat_idx.numel() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
sorted_experts, perm = torch.sort(flat_idx)
assignments = torch.bincount(sorted_experts, minlength=num_experts)
if assignments.sum() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor").contiguous()
scores_sorted = topk_weight.reshape(-1).index_select(0, perm)
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
routed_input = torch.gather(x_flat, 0, gather_index)
counts_i32 = assignments.to(device=device, dtype=torch.int32)
offsets = torch.cumsum(counts_i32, dim=0).to(dtype=torch.int32)
mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype
routed_in = routed_input.to(mm_dtype)
w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)
w_up_t = w_up.transpose(-2, -1).to(mm_dtype)
w2_t = w2.transpose(-2, -1).to(mm_dtype)
routed_in = routed_in.contiguous()
w_gate_t = w_gate_t.contiguous()
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
torch.ops.aten.silu_(gate_out)
w_up_t = w_up_t.contiguous()
up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
gate_out.mul_(up_out)
gate_out = gate_out.contiguous()
w2_t = w2_t.contiguous()
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets).to(expert_dtype)
weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
down_out.mul_(weights)
combined = torch.zeros_like(x_flat)
combined.scatter_add_(0, gather_index, down_out)
output = combined.view(bsz, seqlen, hdim)
if shared_out_flat is not None:
output = output + shared_out_flat.view(bsz, seqlen, hdim)
return output, router_logits
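
The bookkeeping that `moe_ffn_forward_grouped` does before calling `torch._grouped_mm` (sort the flat expert ids, count rows per expert, take a cumulative sum for the offsets, and recover the source token of each sorted row) can be illustrated in plain Python. This is a sketch under assumptions: `group_slots_by_expert` is a hypothetical helper, lists replace tensors, and Python's stable sort is used where the diff uses `torch.sort`, which only needs each expert's rows to end up contiguous.

```python
from itertools import accumulate


def group_slots_by_expert(flat_idx, top_k, num_experts):
    """Return (perm, counts, offsets, token_indices) for a flat list of expert ids.

    flat_idx has length tokens * top_k and is laid out token-major,
    like `topk_idx.view(-1)` in the diff.
    """
    # Order the flat positions so that rows routed to the same expert are contiguous.
    perm = sorted(range(len(flat_idx)), key=lambda pos: flat_idx[pos])
    # Number of rows assigned to each expert (the bincount step).
    counts = [0] * num_experts
    for expert in flat_idx:
        counts[expert] += 1
    # Cumulative end offset of each expert's group (the cumsum step).
    offsets = list(accumulate(counts))
    # Integer division by top_k maps a flat position back to its token.
    token_indices = [pos // top_k for pos in perm]
    return perm, counts, offsets, token_indices


# Three tokens, top_k=2: token 0 -> experts (2, 0), token 1 -> (1, 2), token 2 -> (0, 2).
perm, counts, offsets, token_indices = group_slots_by_expert(
    [2, 0, 1, 2, 0, 2], top_k=2, num_experts=4
)
```

An expert that receives no rows, such as expert 3 here, simply repeats the previous offset, which gives it an empty group.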

@@ -12,6 +12,7 @@ import transformers
from transformers import PretrainedConfig, PreTrainedModel
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.moe_grouped import apply_grouped_to_moe_blocks
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
@@ -57,6 +58,8 @@ class PatchManager:
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_model_specific_patches()
# Apply MoE grouped GEMM patches (cfg.moe_backend)
apply_grouped_to_moe_blocks(self.cfg)
self._apply_fp8_patches()
self._apply_flash_attention_peft_patches()
self._apply_gradient_checkpointing_patches()
@@ -68,12 +71,11 @@ class PatchManager:
self._apply_self_attention_lora_patch()
self._apply_fsdp2_bnb_patches()
self._apply_patch_deepspeed_zero3()
self._apply_voxtral_patches()
self._apply_apertus_patches()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply patches that run after plugin pre_model_load hooks, based on config."""
self._apply_tiled_mlp(self.cfg.model_config_type)
self._apply_voxtral_patches()
def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
@@ -84,15 +86,6 @@ class PatchManager:
patch_evaluation_loop()
patch_maybe_log_save_evaluate()
if self.cfg.context_parallel_size > 1 and getattr(
self.cfg, "flash_attention", False
):
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
patch_prepare_context_parallel_inputs,
)
patch_prepare_context_parallel_inputs()
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model)
@@ -178,20 +171,6 @@ class PatchManager:
patch_llama4_linearized_modeling()
if self.cfg.model_config_type == "qwen3_next" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_next.modeling import (
patch_qwen3_next_modeling_packing,
)
patch_qwen3_next_modeling_packing()
if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type:
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
apply_mistral_tokenizer_image_patch,
)
apply_mistral_tokenizer_image_patch()
def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:
@@ -293,6 +272,7 @@ class PatchManager:
self.cfg.model_config_type,
model_name=self.cfg.base_model,
has_remote_code=has_remote_code,
cfg=self.cfg,
)
if self.cfg.sample_packing:
@@ -358,13 +338,6 @@ class PatchManager:
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
if self.model_config.model_type in ("mistral3", "llava"):
from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (
apply_patch_is_packed_sequence,
)
apply_patch_is_packed_sequence()
def _patch_loss_llama(self):
"""Patch loss functions and other optimizations for LLaMA models."""
if not self.cfg.is_llama_derived_model:
@@ -510,12 +483,3 @@ class PatchManager:
apply_deepspeed_patches()
except ImportError as e:
LOG.warning(f"DeepSpeed patches not applied: {e}")
def _apply_apertus_patches(self):
"""Apply patches for Apertus model."""
if self.cfg.model_config_type == "apertus":
from axolotl.monkeypatch.models.apertus.activation import (
patch_apertus_xielu_activation,
)
patch_apertus_xielu_activation()

@@ -21,13 +21,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
if cfg.processor_type:
processor_cls = getattr(transformers, cfg.processor_type)
if cfg.tokenizer_use_mistral_common:
from axolotl.utils.mistral import Mistral3Processor
return Mistral3Processor(
tokenizer=tokenizer,
)
processor = processor_cls.from_pretrained(
cfg.processor_config,
trust_remote_code=cfg.trust_remote_code or False,

@@ -124,8 +124,13 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
def _load_mistral_common_tokenizer(cfg: DictDefault):
"""Load mistral-common tokenizer"""
from transformers import tokenization_mistral_common
from axolotl.utils.mistral import HFMistralTokenizer
# patch
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
# Load the HF-compatible wrapper around MistralTokenizer
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)

@@ -5,9 +5,14 @@ Patches to support multipack for mixtral
import torch
def patch_mixtral_moe_forward_zero3() -> None:
def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
import warnings
import torch.nn.functional as F
from axolotl.kernels.moe import backends as _moe_backends
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
def mlp_forward(self, hidden_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
hidden_states
@@ -21,21 +26,32 @@ def patch_mixtral_moe_forward_zero3() -> None:
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
backend = get_moe_backend_name(preferred)
if (
backend == MOEBackend.TORCH_GROUPED
and not _moe_backends._probe_torch_grouped()
):
warnings.warn(
"torch_grouped selected but not available; falling back to naive",
stacklevel=2,
)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(
routing_weights, self.top_k, dim=-1, sorted=False
)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
topk_weight = topk_weight.to(hidden_states.dtype)
hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
y = torch.empty_like(hidden_states)
hidden_states_rep = hidden_states.repeat_interleave(self.top_k, dim=0)
y = torch.empty_like(hidden_states_rep)
flat_topk_idx = topk_idx.view(-1)
for i in range(self.num_experts):
expert = self.experts[i]
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
sel = flat_topk_idx == i
if sel.any():
y[sel] = expert(hidden_states_rep[sel])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
@@ -46,4 +62,23 @@ def patch_mixtral_moe_forward_zero3() -> None:
)
MixtralBlockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward
# Wrap forward to support optional torch_grouped backend via config
from axolotl.kernels.moe import torch_grouped as _tg
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
backend = get_moe_backend_name(preferred)
if backend == MOEBackend.TORCH_GROUPED and _tg.available():
def moe_forward_grouped(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
bsz, seqlen, hdim = hidden_states.shape
y, router_logits = _tg.moe_ffn_forward_grouped(
hidden_states, self.gate, self.experts, self.top_k
)
if y is None:
return moe_forward(self, hidden_states)
return y, router_logits
MixtralSparseMoeBlock.forward = moe_forward_grouped
else:
MixtralSparseMoeBlock.forward = moe_forward
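
The routing step shared by `moe_forward` and the grouped path (softmax over router logits, keep the top-k probabilities, renormalize them to sum to one) is small enough to show for a single token. This is a minimal sketch with a hypothetical helper name, `route_token`, using plain floats instead of tensors. It returns the selected experts in descending order of probability, whereas the diff calls `torch.topk` with `sorted=False` and so makes no ordering promise.

```python
import math


def route_token(logits, top_k):
    """Return [(expert_index, normalized_weight), ...] for one token."""
    # Numerically stable softmax over the router logits.
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    probs = [value / total for value in exps]
    # Keep the top_k most probable experts.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    # Renormalize so the kept weights sum to one.
    norm = sum(probs[i] for i in top)
    return [(i, probs[i] / norm) for i in top]


routes = route_token([0.0, 0.0, math.log(2.0)], top_k=2)
```

With these logits the softmax probabilities are 0.25, 0.25, and 0.5, so expert 2 receives two thirds of the renormalized weight and the first of the tied experts receives one third.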

@@ -1,52 +0,0 @@
"""Monkeypatch for Apertus to fix a dtype mismatch in the XIELU activation"""
from torch import Tensor
def patch_apertus_xielu_activation():
try:
from transformers.activations import XIELUActivation
except ImportError as err:
raise ImportError(
"Cannot import XIELUActivation. "
"Please make sure to update your transformers version >= 4.56.1."
) from err
from transformers.activations import logger
# Store the original method
old_fn = XIELUActivation._xielu_cuda
def _xielu_cuda_fixed(self, x: Tensor) -> Tensor:
"""Firewall function to prevent torch.compile from seeing .item() calls"""
original_shape = x.shape
# CUDA kernel expects 3D tensors, reshape if needed
while x.dim() < 3:
x = x.unsqueeze(0)
if x.dim() > 3:
x = x.view(-1, 1, x.size(-1))
if original_shape != x.shape:
logger.warning_once(
"Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).",
original_shape,
x.shape,
)
result = self._xielu_cuda_obj.forward(
x,
self.alpha_p.to(x.dtype),
self.alpha_n.to(x.dtype),
# Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
self._beta_scalar,
self._eps_scalar,
self.with_vector_loads,
)
return result.view(original_shape)
# Apply the patch
XIELUActivation._xielu_cuda = _xielu_cuda_fixed
def unpatch():
"""Restore the original method"""
XIELUActivation._xielu_cuda = old_fn
return unpatch
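
The removed patch follows a pattern that recurs in this diff: save the original attribute, install a replacement, and return an `unpatch` closure that restores it. A generic sketch of that pattern is shown here. `patch_attribute`, `Activation`, and `doubled_forward` are hypothetical names used only for illustration.

```python
def patch_attribute(owner, name, replacement):
    """Replace owner.<name> and return a callable that restores the original."""
    original = getattr(owner, name)
    setattr(owner, name, replacement)

    def unpatch():
        # Put the saved attribute back on the owner.
        setattr(owner, name, original)

    return unpatch


class Activation:
    """Hypothetical class standing in for a library class being patched."""

    def forward(self, x):
        return x


def doubled_forward(self, x):
    return 2 * x


unpatch = patch_attribute(Activation, "forward", doubled_forward)
patched_result = Activation().forward(3)
unpatch()
restored_result = Activation().forward(3)
```

Returning the closure lets tests apply a patch and undo it without knowing which attribute was touched.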

@@ -1,85 +0,0 @@
"""
Monkeypatch to fix inefficient tensor conversion in MistralCommonTokenizer.apply_chat_template
"""
import importlib
import inspect
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def apply_mistral_tokenizer_image_patch():
"""Apply patch to MistralCommonTokenizer.apply_chat_template to fix image tensor conversion."""
from transformers.tokenization_mistral_common import MistralCommonTokenizer
# Get original source
original_source = inspect.getsource(MistralCommonTokenizer.apply_chat_template)
original_source, _ = detab_code(original_source)
# Define the replacement
original_tensor_conversion = (
" pixel_values = torch.tensor(images)"
)
patched_tensor_conversion = """ if isinstance(images, list) and len(images) > 0 and isinstance(images[0], np.ndarray):
pixel_values = torch.tensor(np.array(images))
else:
pixel_values = torch.tensor(images)"""
# Apply the replacement
if original_tensor_conversion in original_source:
patched_source = original_source.replace(
original_tensor_conversion, patched_tensor_conversion
)
patched_source = patched_source.replace(
"def apply_chat_template(",
"def patched_apply_chat_template(",
1,
)
# Load necessary imports from the module
module_name = MistralCommonTokenizer.__module__
module = importlib.import_module(module_name)
# Detect what needs to be imported
items_to_import = []
for item in dir(module):
if item in patched_source and not item.startswith("_"):
items_to_import.append(item)
# Execute imports in global scope
if items_to_import:
exec( # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
# Also need standard imports that might be used
exec("import numpy as np", globals()) # nosec B102
exec("import torch", globals()) # nosec B102
exec("from typing import Union, Optional, List, Dict, Any, Callable", globals()) # nosec B102
exec("from pathlib import Path", globals()) # nosec B102
# Import other dependencies that might be needed
try:
exec("from transformers.utils import is_torch_available", globals()) # nosec B102
exec(
"from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TensorType",
globals(),
) # nosec B102
exec("from transformers.utils import logging", globals()) # nosec B102
exec("logger = logging.get_logger(__name__)", globals()) # nosec B102
except ImportError as e:
LOG.warning(f"Could not import some dependencies: {e}")
# Execute the patched source
exec(patched_source, globals()) # nosec B102
# Replace the method
MistralCommonTokenizer.apply_chat_template = patched_apply_chat_template
LOG.info("Successfully applied MistralCommonTokenizer tensor conversion patch")
else:
LOG.warning("Could not find target code for MistralCommonTokenizer patching")

@@ -1,42 +0,0 @@
"""Monkeypatch for FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid"""
import torch
def apply_patch_is_packed_sequence():
"""Apply patch to FA utils to accept 1D position_ids from Pixtral's position_ids_in_meshgrid"""
from transformers import modeling_flash_attention_utils
def fixed_is_packed_sequence(position_ids, batch_size):
"""
Check whether the position ids indicate packed sequences:
1. Position ids exist
2. Flattened sequences only are supported
3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. we have multiple increasing sequences
"""
if position_ids is None:
return False
if position_ids.ndim == 1:
position_ids = position_ids.unsqueeze(0) # [N] -> [1, N]
increasing_position_sequences = (
torch.arange(position_ids.shape[1], device=position_ids.device)
+ position_ids.min()
)
return (
batch_size == 1
and (increasing_position_sequences - position_ids).abs().sum().bool().item()
)
# Store original method
old_fn = modeling_flash_attention_utils._is_packed_sequence
# Apply the patch
modeling_flash_attention_utils._is_packed_sequence = fixed_is_packed_sequence
def unpatch():
"""Restore the original method"""
modeling_flash_attention_utils._is_packed_sequence = old_fn
return unpatch
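
The check in `fixed_is_packed_sequence` compares the position ids against a single increasing run that starts at their minimum, and treats any mismatch as a sign of packing. A list-based sketch of the same comparison follows. `looks_packed` is a hypothetical name, the input is a flat Python list rather than a tensor, and the empty-input guard is an addition of this sketch.

```python
def looks_packed(position_ids, batch_size):
    """Return True when a flat list of position ids contains more than one run."""
    if position_ids is None or len(position_ids) == 0:
        return False
    start = min(position_ids)
    # A single unpacked sequence would be start, start + 1, start + 2, ...
    expected = [start + offset for offset in range(len(position_ids))]
    differs = any(e != p for e, p in zip(expected, position_ids))
    return batch_size == 1 and differs
```

For example, `[0, 1, 2, 0, 1]` restarts at index 3 and is reported as packed, while `[0, 1, 2, 3]` is not.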

@@ -1 +0,0 @@
"""Qwen3_Next model monkeypatches."""

@@ -1,317 +0,0 @@
"""Monkeypatch for Qwen3_Next model to pass position_ids to linear attention."""
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def get_cu_seqlens(position_ids):
"""
Adapted from transformers.modeling_flash_attention_utils.prepare_fa_kwargs_from_position_ids.
https://github.com/huggingface/transformers/blob/0f1b128d3359a26bd18be99c26d7f04fb3cba914/src/transformers/modeling_flash_attention_utils.py#L316
"""
tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
position_ids = position_ids.view(-1)
indices_q = (position_ids == 0).nonzero().view(-1)
cu_seq_lens_q = torch.cat(
(
indices_q.to(**tensor_kwargs),
torch.tensor(position_ids.size(), **tensor_kwargs),
)
)
return cu_seq_lens_q
def patch_qwen3_next_decoder_layer():
"""Patch Qwen3NextDecoderLayer to pass position_ids to linear attention."""
try:
from transformers.models.qwen3_next.modeling_qwen3_next import (
Qwen3NextDecoderLayer,
)
except ImportError:
LOG.warning("Qwen3Next model not found, skipping patch")
return
# Store original forward method
original_decoder_forward = Qwen3NextDecoderLayer.forward
def patched_decoder_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[torch.Tensor]] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Token Mixer
if self.layer_type == "linear_attention":
hidden_states = self.linear_attn(
hidden_states=hidden_states,
cache_params=past_key_values,
cache_position=cache_position,
attention_mask=attention_mask,
position_ids=position_ids,
)
elif self.layer_type == "full_attention":
# Self Attention
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
# For the MoE layers, we need to unpack
if isinstance(hidden_states, Tuple):
hidden_states, _ = hidden_states
hidden_states = residual + hidden_states
return hidden_states
# Apply the patches
Qwen3NextDecoderLayer.forward = patched_decoder_forward
def unpatch():
"""Restore the original forward method"""
Qwen3NextDecoderLayer.forward = original_decoder_forward
return unpatch
def patch_qwen3_next_gateddelta_layer():
"""Patch Qwen3NextGatedDeltaNet to parse cu_seqlens and pass to chunk_gated_delta_rule"""
try:
from transformers.models.qwen3_next.modeling_qwen3_next import (
Qwen3NextDynamicCache,
Qwen3NextGatedDeltaNet,
apply_mask_to_padding_states,
)
except ImportError:
LOG.warning("Qwen3Next model not found, skipping patch")
return
# Store original forward method
original_gated_delta_net_forward = Qwen3NextGatedDeltaNet.forward
def patched_gated_delta_net_forward(
self,
hidden_states: torch.Tensor,
cache_params: Optional[Qwen3NextDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
):
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
# Set up dimensions for reshapes later
batch_size, seq_len, _ = hidden_states.shape
use_precomputed_states = (
cache_params is not None
and cache_params.has_previous_state
and seq_len == 1
and cache_position is not None
)
# getting projected states from cache if it exists
if cache_params is not None:
conv_state = cache_params.conv_states[self.layer_idx]
recurrent_state = cache_params.recurrent_states[self.layer_idx]
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
projected_states_ba = self.in_proj_ba(hidden_states)
query, key, value, z, b, a = self.fix_query_key_value_ordering(
projected_states_qkvz, projected_states_ba
)
query, key, value = (
x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
)
mixed_qkv = torch.cat((query, key, value), dim=-1)
mixed_qkv = mixed_qkv.transpose(1, 2)
if use_precomputed_states:
# 2. Convolution sequence transformation
# NOTE: the conv state is updated in `causal_conv1d_update`
mixed_qkv = self.causal_conv1d_update(
mixed_qkv,
conv_state,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
)
else:
if cache_params is not None:
conv_state = F.pad(
mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
)
cache_params.conv_states[self.layer_idx] = conv_state
if self.causal_conv1d_fn is not None:
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=None,
)
else:
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
mixed_qkv = mixed_qkv.transpose(1, 2)
query, key, value = torch.split(
mixed_qkv,
[
self.key_dim,
self.key_dim,
self.value_dim,
],
dim=-1,
)
query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)
key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)
value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)
beta = b.sigmoid()
# If the model is loaded in fp16, without the .float() here, A might be -inf
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
if self.num_v_heads // self.num_k_heads > 1:
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
if not use_precomputed_states:
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
cu_seqlens=cu_seqlens,
)
else:
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=recurrent_state,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
)
# Update cache
if cache_params is not None:
cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
z_shape_og = z.shape
# reshape input data into 2D tensor
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(z_shape_og)
core_attn_out = core_attn_out.reshape(
core_attn_out.shape[0], core_attn_out.shape[1], -1
)
output = self.out_proj(core_attn_out)
return output
# Apply the patches
Qwen3NextGatedDeltaNet.forward = patched_gated_delta_net_forward
def unpatch():
"""Restore the original forward method"""
Qwen3NextGatedDeltaNet.forward = original_gated_delta_net_forward
return unpatch
def patch_qwen3_next_imports():
"""Patch Qwen3Next imports to use try/except instead of is_flash_linear_attention_available."""
try:
import transformers.models.qwen3_next.modeling_qwen3_next as qwen3_modeling
except ImportError:
LOG.warning("Qwen3Next model not found, skipping import patch")
return
# Save original values for unpatch
original_FusedRMSNormGated = getattr(qwen3_modeling, "FusedRMSNormGated", None)
original_chunk_gated_delta_rule = getattr(
qwen3_modeling, "chunk_gated_delta_rule", None
)
original_fused_recurrent_gated_delta_rule = getattr(
qwen3_modeling, "fused_recurrent_gated_delta_rule", None
)
original_is_fast_path_available = getattr(
qwen3_modeling, "is_fast_path_available", False
)
try:
from fla.modules import FusedRMSNormGated
from fla.ops.gated_delta_rule import (
chunk_gated_delta_rule,
fused_recurrent_gated_delta_rule,
)
qwen3_modeling.FusedRMSNormGated = FusedRMSNormGated
qwen3_modeling.chunk_gated_delta_rule = chunk_gated_delta_rule
qwen3_modeling.fused_recurrent_gated_delta_rule = (
fused_recurrent_gated_delta_rule
)
# Force is_fast_path_available to be True
# fla has triton kernels for causal_conv1d
qwen3_modeling.is_fast_path_available = True
except ImportError:
qwen3_modeling.chunk_gated_delta_rule = None
qwen3_modeling.fused_recurrent_gated_delta_rule = None
qwen3_modeling.FusedRMSNormGated = None
def unpatch():
"""Restore the original import values"""
qwen3_modeling.FusedRMSNormGated = original_FusedRMSNormGated
qwen3_modeling.chunk_gated_delta_rule = original_chunk_gated_delta_rule
qwen3_modeling.fused_recurrent_gated_delta_rule = (
original_fused_recurrent_gated_delta_rule
)
qwen3_modeling.is_fast_path_available = original_is_fast_path_available
return unpatch
def patch_qwen3_next_modeling_packing():
"""Apply all Qwen3Next model patches."""
patch_qwen3_next_imports()
patch_qwen3_next_decoder_layer()
patch_qwen3_next_gateddelta_layer()
LOG.info("Applied Qwen3Next patch for packing")
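
The patch functions above share one pattern: save the original attribute, overwrite it, and return a closure that restores it. A minimal sketch of that pattern follows. `Layer`, `patch_forward`, and `doubled_forward` are hypothetical stand-ins for illustration, not names from axolotl or transformers.

```python
class Layer:
    """Stand-in for a model class; purely illustrative."""

    def forward(self, x):
        return x + 1


def patch_forward(cls, new_forward):
    """Swap cls.forward and return a closure that restores the original."""
    original_forward = cls.forward
    cls.forward = new_forward

    def unpatch():
        cls.forward = original_forward

    return unpatch


def doubled_forward(self, x):
    return x * 2


unpatch = patch_forward(Layer, doubled_forward)
patched_result = Layer().forward(3)   # served by doubled_forward
unpatch()
restored_result = Layer().forward(3)  # served by the original forward
```

Note that `patch_qwen3_next_modeling_packing` calls the three patch functions without keeping their return values, so a caller who wants to undo the patches would need to collect those callables separately.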

View File

@@ -0,0 +1,133 @@
import logging
import weakref
from functools import wraps

import torch

from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name

_LOG = logging.getLogger("axolotl.moe.patch")


def _patch_block_forward(block_cls, grouped_fn):
    """Replace block_cls.forward with grouped_fn preserving signature."""
    block_cls.forward = grouped_fn


def apply_grouped_to_moe_blocks(cfg=None) -> None:
    """
    Attempt to patch all known MoE block classes to use the torch_grouped backend
    when cfg.moe_backend resolves to 'torch_grouped' and the op is available.
    Falls back to original forwards otherwise.
    """
    preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
    backend = get_moe_backend_name(preferred)
    if backend != MOEBackend.TORCH_GROUPED:
        _LOG.info(
            f"moe_backend is '{backend}', not 'torch_grouped'; skipping grouped patches"
        )
        return
    try:
        from axolotl.kernels.moe import torch_grouped as _tg
    except Exception:
        _LOG.warning("torch_grouped backend import failed; skipping grouped patches")
        return
    if not _tg.available():
        _LOG.warning(
            "torch_grouped requested but unavailable (op smoke test failed); skipping grouped patches"
        )
        return

    # Map of architecture key to (modeling module path, class name or list of class names)
    model_mods = {
        "mixtral": (
            "transformers.models.mixtral.modeling_mixtral",
            MOE_ARCH_BLOCK.get("mixtral"),
        ),
        "qwen2_moe": (
            "transformers.models.qwen2_moe.modeling_qwen2_moe",
            MOE_ARCH_BLOCK.get("qwen2_moe"),
        ),
        "qwen3_moe": (
            "transformers.models.qwen3_moe.modeling_qwen3_moe",
            MOE_ARCH_BLOCK.get("qwen3_moe"),
        ),
        "jamba": (
            "transformers.models.jamba.modeling_jamba",
            MOE_ARCH_BLOCK.get("jamba"),
        ),
        "deepseek_v2": (
            "transformers.models.deepseek_v2.modeling_deepseek_v2",
            MOE_ARCH_BLOCK.get("deepseek_v2"),
        ),
        # Others may not follow standard paths; best-effort import
        "dbrx": ("transformers.models.dbrx.modeling_dbrx", MOE_ARCH_BLOCK.get("dbrx")),
        "jetmoe": (
            "transformers.models.jetmoe.modeling_jetmoe",
            MOE_ARCH_BLOCK.get("jetmoe"),
        ),
        "gpt_oss": (
            "transformers.models.gpt_oss.modeling_gpt_oss",
            MOE_ARCH_BLOCK.get("gpt_oss"),
        ),
    }

    def make_grouped_forward(orig_forward):
        @wraps(orig_forward)
        def _grouped_forward(self, hidden_states: torch.Tensor, *args, **kwargs):
            bsz, seqlen, hdim = hidden_states.shape
            # expose parent block so grouped backend can access shared expert context
            try:
                self.experts._ax_parent_block_ref = weakref.ref(self)
            except Exception:
                pass
            y, router_logits = _tg.moe_ffn_forward_grouped(
                hidden_states, self.gate, self.experts, self.top_k
            )
            # One-time log per block instance indicating whether grouped engaged or fallback occurred
            if not getattr(self, "_ax_grouped_wrapper_logged", False):
                if y is None:
                    _LOG.warning(
                        "Grouped wrapper active but fell back to naive for %s",
                        self.__class__.__name__,
                    )
                else:
                    _LOG.info(
                        f"Grouped wrapper engaged for {self.__class__.__name__} (top_k={self.top_k})"
                    )
                self._ax_grouped_wrapper_logged = True
            if y is None:
                return orig_forward(self, hidden_states, *args, **kwargs)
            return y, router_logits

        return _grouped_forward

    patched = 0
    for key, (mod_path, cls_names) in model_mods.items():
        if not cls_names:
            continue
        try:
            import importlib

            modeling = importlib.import_module(mod_path)
            names = cls_names if isinstance(cls_names, list) else [cls_names]
            for name in names:
                if not hasattr(modeling, name):
                    continue
                block_cls = getattr(modeling, name)
                orig_forward = getattr(block_cls, "forward", None)
                if orig_forward is None:
                    continue
                _patch_block_forward(block_cls, make_grouped_forward(orig_forward))
                patched += 1
                _LOG.info(f"Patched MoE block for grouped GEMM: {mod_path}.{name}")
        except Exception as e:
            # Best effort; log and skip this entry
            _LOG.warning(f"Skipping MoE patch for arch '{key}' ({mod_path}): {e}")
    if patched == 0:
        _LOG.warning(
            "No MoE blocks patched for grouped GEMM; model may not use known MoE classes"
        )
    else:
        _LOG.info(f"Grouped GEMM patches applied to {patched} MoE block class(es)")
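
The wrapper in this file tries the grouped path first and falls back to the original `forward` when the grouped call returns `None`. The sketch below isolates that control flow without torch. `make_forward_with_fallback`, `Block`, and `fast_only_for_even` are hypothetical names used only for illustration.

```python
from functools import wraps


def make_forward_with_fallback(orig_forward, fast_fn):
    """Wrap orig_forward; use fast_fn unless it signals failure by returning None."""

    @wraps(orig_forward)
    def wrapper(self, x, *args, **kwargs):
        y = fast_fn(x)
        if y is None:
            # The fast path declined this input: defer to the original forward,
            # passing through any extra positional or keyword arguments.
            return orig_forward(self, x, *args, **kwargs)
        return y

    return wrapper


class Block:
    def forward(self, x, scale=1):
        return ("slow", x * scale)


def fast_only_for_even(x):
    """Toy fast path that only handles even inputs."""
    return ("fast", x) if x % 2 == 0 else None


Block.forward = make_forward_with_fallback(Block.forward, fast_only_for_even)
block = Block()
even_result = block.forward(4)
odd_result = block.forward(3, scale=2)
```

Forwarding `*args` and `**kwargs` to the original is what lets the fallback keep working for blocks whose `forward` takes more than `hidden_states`, which is the case the `test_apply_grouped_forward_fallback` test in this compare exercises.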

View File

@@ -11,7 +11,6 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"apertus",
"mllama_text_model",
"llama",
"llama4",
@@ -21,7 +20,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"qwen2_moe",
"qwen3",
"qwen3_moe",
"qwen3_next",
"falcon",
"phi",
"phi3",
@@ -48,7 +46,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
]
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
def patch_for_multipack(model_type, model_name=None, has_remote_code=False, cfg=None):
if has_remote_code:
patch_remote(model_name)
elif hasattr(transformers, "modeling_flash_attention_utils"):
@@ -59,7 +57,7 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
patch_mixtral_moe_forward_zero3(cfg)
def patch_remote(model_name):

View File

@@ -13,10 +13,21 @@ from typing import Callable
import torch
import torch.distributed as dist
import transformers
import transformers.modeling_flash_attention_utils as flash_utils
import transformers.modeling_flash_attention_utils
from ring_flash_attn import ring_flash_attn_func
from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
try:
from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
try:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
except ImportError:
_flash_supports_window = True
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.utils.schemas.enums import RingAttnFunc
@@ -107,7 +118,7 @@ def create_flash_attn_forward_varlen_llama3(
# Handle sliding window
use_sliding_windows = (
_flash_windows_supported()
_flash_supports_window
and sliding_window is not None
and key_states.shape[1] > sliding_window
)
@@ -183,18 +194,3 @@ def substitute_hf_flash_attn(
from ring_flash_attn.adapters.hf_adapter import flash_attention_forward
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
def _flash_windows_supported() -> bool:
"""Return whether current transformers build advertises sliding-window support."""
support = getattr(flash_utils, "_flash_supports_window", None)
if support is None:
support = getattr(flash_utils, "_flash_supports_window_size", None)
if support is None:
return True
if callable(support):
return True
return bool(support)
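
Both variants in this file answer the same question: does the installed transformers build expose a sliding-window flag, and under which name? The following sketch mirrors the `getattr`-based probe on stand-in modules so it runs without transformers installed. `window_support` and the three `types.ModuleType` objects are illustrative assumptions, not real transformers builds.

```python
import types


def window_support(
    module, names=("_flash_supports_window", "_flash_supports_window_size")
):
    """Return True unless the module explicitly reports no sliding-window support."""
    for name in names:
        value = getattr(module, name, None)
        if value is None:
            continue
        # A callable means the flag is a function in this build; assume support.
        return True if callable(value) else bool(value)
    return True  # neither name exists: assume support


# Stand-in modules that imitate three possible layouts.
newer = types.ModuleType("newer")
newer._flash_supports_window = False
older = types.ModuleType("older")
older._flash_supports_window_size = lambda: True
bare = types.ModuleType("bare")
```

The import-time `try`/`except ImportError` variant binds whatever object it finds to `_flash_supports_window`, so a callable is also treated as truthy inside the `use_sliding_windows` expression.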

View File

@@ -13,9 +13,18 @@ from typing import Optional
import torch
import torch.distributed as dist
import transformers.modeling_flash_attention_utils as flash_utils
from torch.distributed import DeviceMesh
try:
from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
try:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
except ImportError:
_flash_supports_window = True
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RingAttnFunc
@@ -74,7 +83,7 @@ def create_ring_flash_attention_forward(
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
use_sliding_windows = (
_flash_windows_supported()
_flash_supports_window
and sliding_window is not None
and key_states.shape[1] > sliding_window
)
@@ -216,19 +225,3 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
def _flash_windows_supported() -> bool:
"""Best-effort check for FlashAttention sliding-window support."""
support = getattr(flash_utils, "_flash_supports_window", None)
if support is None:
support = getattr(flash_utils, "_flash_supports_window_size", None)
if support is None:
return True
if callable(support):
# Signature differs across versions; assume support when callable.
return True
return bool(support)

View File

@@ -1,68 +0,0 @@
"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer."""
from __future__ import annotations
import importlib
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
PATCHED_GUARD = (
'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
)
def patch_prepare_context_parallel_inputs() -> None:
"""Relax the SDPA-only guard when running context parallelism with FlashAttention."""
if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False):
LOG.debug("Trainer._prepare_context_parallel_inputs already patched")
return
try:
original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)
except OSError as exc: # pragma: no cover - occurs when source is unavailable
LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc)
return
if GUARD_PATTERN not in original_source:
LOG.warning(
"Expected guard not found in Trainer._prepare_context_parallel_inputs; "
"skipping FlashAttention context parallelism patch"
)
return
patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)
patched_source, _ = detab_code(patched_source)
patched_source = patched_source.replace(
"def _prepare_context_parallel_inputs(",
"def axolotl_prepare_context_parallel_inputs(",
1,
)
module_name = Trainer.__module__
module = importlib.import_module(module_name)
# import symbols referenced in the method so exec can succeed
items_to_import = []
for item in dir(module):
if item in patched_source:
items_to_import.append(item)
exec(f"from {module_name} import ({', '.join(items_to_import)})", globals())
exec(patched_source, globals())
Trainer._original_prepare_context_parallel_inputs = (
Trainer._prepare_context_parallel_inputs
)
Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
Trainer._axolotl_prepare_context_parallel_inputs_patched = True
LOG.debug(
"Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
)
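
The module above patches a method by editing its source text: it checks that the expected guard is present, replaces it, and executes the result. The sketch below shows the same steps on a toy function. `ORIGINAL_SOURCE`, `prepare`, and `build_patched` are hypothetical, and the source is a string literal here rather than the output of `inspect.getsource`.

```python
import textwrap

# Toy stand-in for the method source that inspect.getsource would return.
ORIGINAL_SOURCE = '''
def prepare(attn_implementation):
    if attn_implementation != "sdpa":
        raise ValueError("unsupported attention implementation")
    return "ok"
'''

GUARD = 'if attn_implementation != "sdpa":'
PATCHED_GUARD = 'if attn_implementation not in ("sdpa", "flash_attention_2"):'


def build_patched(source):
    """Return a patched function, or None when the expected guard is missing."""
    if GUARD not in source:
        return None
    patched_source = source.replace(GUARD, PATCHED_GUARD)
    namespace = {}
    # Execute into a private namespace instead of globals() to limit side effects.
    exec(textwrap.dedent(patched_source), namespace)
    return namespace["prepare"]


prepare_patched = build_patched(ORIGINAL_SOURCE)
```

Checking for the guard before replacing matters because an upstream change to the method would otherwise produce a silently unpatched copy.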

View File

@@ -11,7 +11,6 @@ from transformers.image_utils import load_image
from axolotl.utils.dict import remove_none_values
from axolotl.utils.logging import get_logger
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
LOG = get_logger(__name__)
@@ -422,36 +421,6 @@ class SmolVLM2ProcessingStrategy(ProcessingStrategy):
]
class Mistral3ProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Mistral3"""
def __init__(
self,
processor: Mistral3Processor,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
special_ids = (
processor.tokenizer.tokenizer.instruct_tokenizer.image_encoder.special_ids
)
self.image_token = special_ids.img
self.image_break_token = special_ids.img_break
self.image_end_token = special_ids.img_end
def process_labels(self, input_ids):
labels = input_ids.clone()
labels[labels == self.processor.tokenizer.pad_token_id] = -100
labels[labels == self.image_token] = -100
labels[labels == self.image_break_token] = -100
labels[labels == self.image_end_token] = -100
return labels
def get_processing_strategy(
processor: ProcessorMixin,
chat_template,
@@ -494,11 +463,6 @@ def get_processing_strategy(
**processing_kwargs,
)
if isinstance(processor, Mistral3Processor):
return Mistral3ProcessingStrategy(
**processing_kwargs,
)
# llama3_2_vision, llama4, llava
# mistral_v7_tekken, pixtral, lfm2vl
return ProcessingStrategy(
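
`Mistral3ProcessingStrategy.process_labels` in the hunk above sets every padding and image-marker position to `-100` so the loss ignores it. A torch-free sketch of the same masking, using plain lists, is given below. `mask_labels` and the token ids are made up for illustration.

```python
IGNORE_INDEX = -100  # label value used for masked positions in the hunk above


def mask_labels(input_ids, ignored_token_ids):
    """Copy input_ids, replacing every ignored token id with IGNORE_INDEX."""
    ignored = set(ignored_token_ids)
    return [IGNORE_INDEX if token in ignored else token for token in input_ids]


# Hypothetical ids: 0 = pad, 10 = image, 12 = image break, 13 = image end.
labels = mask_labels(
    [5, 10, 10, 12, 13, 7, 0, 0],
    ignored_token_ids=[0, 10, 12, 13],
)
```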

View File

@@ -179,11 +179,7 @@ def execute_training(
)
)
use_flash_cp = cfg.context_parallel_size > 1 and bool(
getattr(cfg, "flash_attention", False)
)
if use_flash_cp:
if cfg.context_parallel_size > 1:
models = [trainer.model]
if hasattr(trainer, "ref_model") and trainer.ref_model:
models.append(trainer.ref_model)

View File

@@ -1,6 +1,5 @@
"""Init for `axolotl.utils.mistral` module."""
from axolotl.utils.mistral.mistral3_processor import Mistral3Processor
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer
__all__ = ["HFMistralTokenizer", "Mistral3Processor"]
__all__ = ["HFMistralTokenizer"]

View File

@@ -1,169 +0,0 @@
"""Processor for Mistral3 multimodal models with image support"""
from typing import Any, Dict, Optional, Union
import torch
from transformers import ProcessorMixin
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessingKwargs
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from axolotl.utils.mistral.mistral_tokenizer import HFMistralTokenizer
class Mistral3ProcessorKwargs(ProcessingKwargs):
_defaults: Dict[str, Dict[str, Any]] = {
"text_kwargs": {
"padding": True,
},
"common_kwargs": {
"return_tensors": "pt",
"return_dict": True,
"tokenize": True,
},
}
class Mistral3Processor(ProcessorMixin):
"""
Processor for Mistral3 multimodal models that handles text and images.
Wraps HFMistralTokenizer and adds image processing capabilities.
"""
attributes = ["tokenizer"]
tokenizer_class = "HFMistralTokenizer"
def __init__(self, tokenizer: HFMistralTokenizer):
# Don't call super().__init__ to avoid the class validation issue
self.tokenizer = tokenizer
@property
def chat_template(self) -> None:
"""Chat template is not supported. Dummy method to satisfy HuggingFace API."""
return None
@property
def audio_tokenizer(self) -> None:
"""Audio tokenizer is not supported. Dummy method to satisfy HuggingFace API."""
return None
def _merge_kwargs(
self, processor_kwargs_class: Any, **kwargs: Any
) -> Dict[str, Dict[str, Any]]:
"""Merge kwargs with defaults similar to ProcessorMixin"""
defaults = processor_kwargs_class._defaults
output_kwargs: Dict[str, Dict[str, Any]] = {}
for kwarg_type, default_values in defaults.items():
output_kwargs[kwarg_type] = {**default_values}
# Update with provided kwargs
for key, value in kwargs.items():
# Try to match key to appropriate kwarg type
if key in ["padding", "truncation", "max_length"]:
output_kwargs.setdefault("text_kwargs", {}).update({key: value})
elif key in ["return_tensors", "return_dict", "tokenize"]:
output_kwargs.setdefault("common_kwargs", {}).update({key: value})
else:
# Add to text_kwargs by default
output_kwargs.setdefault("text_kwargs", {}).update({key: value})
return output_kwargs
def apply_chat_template(
self,
conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
**kwargs: Any,
) -> Union[BatchFeature, str, list[str]]:
"""
Apply chat template with image support for Mistral3.
Similar to VoxtralProcessor, this method extracts images from the conversation,
calls the tokenizer's apply_chat_template, then adds pixel_values and image_sizes
to the result.
"""
output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs)
text_kwargs = output_kwargs["text_kwargs"]
common_kwargs = output_kwargs["common_kwargs"]
return_tensors = common_kwargs.pop("return_tensors", "pt")
if return_tensors != "pt":
raise ValueError(
f"{self.__class__.__name__} only supports `return_tensors='pt'`."
)
return_dict = common_kwargs.pop("return_dict", False)
tokenize = common_kwargs.pop("tokenize", False)
# Determine if batched
if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple))
or hasattr(conversation[0], "content")
):
is_batched = True
conversations = conversation
else:
is_batched = False
conversations = [conversation] # type: ignore
# Call tokenizer's apply_chat_template
tokenizer_kwargs = {**text_kwargs, **common_kwargs}
tokenizer_kwargs["return_tensors"] = return_tensors
tokenizer_kwargs["tokenize"] = tokenize
tokenizer_kwargs["return_dict"] = return_dict
encoded_instruct_inputs = self.tokenizer.apply_chat_template(
conversations,
**tokenizer_kwargs,
)
if tokenize:
if return_dict:
# The tokenizer already handles pixel_values, we just need to add image_sizes
if hasattr(encoded_instruct_inputs, "items"):
data: Dict[str, Any] = dict(encoded_instruct_inputs) # type: ignore
elif hasattr(encoded_instruct_inputs, "data"):
data = encoded_instruct_inputs.data # type: ignore
else:
raise ValueError("Unknown data type")
if "pixel_values" in data:
pixel_values = data["pixel_values"]
# MistralTokenizer returns a Double, so we convert to fp32
data["pixel_values"] = pixel_values.to(dtype=torch.float32)
# Always batched: [B, C, H, W] -> image_sizes: [B, 2]
# Since tensor is homogeneous, all images have same H, W
batch_size = pixel_values.shape[0]
image_sizes = torch.tensor([pixel_values.shape[-2:]] * batch_size)
data["image_sizes"] = image_sizes
return BatchFeature(data=data, tensor_type=return_tensors)
if not is_batched:
return encoded_instruct_inputs[0]
return encoded_instruct_inputs
def __call__(
self,
text: Optional[
Union[
TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]
]
],
**kwargs: Any,
) -> BatchFeature:
"""
Forward text processing to the tokenizer.
This method does not support images - use apply_chat_template instead.
"""
output_kwargs = self._merge_kwargs(Mistral3ProcessorKwargs, **kwargs)
text_kwargs = output_kwargs["text_kwargs"]
common_kwargs = output_kwargs["common_kwargs"]
out = self.tokenizer(text, **text_kwargs)
return BatchFeature(
data=out, tensor_type=common_kwargs.pop("return_tensors", None)
)
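
`_merge_kwargs` in the file above copies the class defaults and then routes each caller-supplied key into either `text_kwargs` or `common_kwargs`. The sketch below reproduces that routing as a standalone function. `merge_kwargs`, `DEFAULTS`, and `COMMON_KEYS` are illustrative names, not part of the transformers API.

```python
DEFAULTS = {
    "text_kwargs": {"padding": True},
    "common_kwargs": {"return_tensors": "pt", "return_dict": True, "tokenize": True},
}
COMMON_KEYS = {"return_tensors", "return_dict", "tokenize"}


def merge_kwargs(defaults, **kwargs):
    """Copy the defaults, then route each override into its kwarg group."""
    merged = {group: dict(values) for group, values in defaults.items()}
    for key, value in kwargs.items():
        # Anything that is not a known common key lands in text_kwargs.
        group = "common_kwargs" if key in COMMON_KEYS else "text_kwargs"
        merged.setdefault(group, {})[key] = value
    return merged


merged = merge_kwargs(DEFAULTS, max_length=128, tokenize=False)
```

Copying each default dict before updating it keeps the class-level defaults from being mutated across calls.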

View File

@@ -132,6 +132,14 @@ class AxolotlInputConfig(
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(),
)
moe_backend: Literal["auto", "torch_grouped", "naive"] | None = Field(
default=None,
json_schema_extra={
"description": "Mixture-of-Experts backend to use: 'auto', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
},
)
# Value is constrained by the Literal type; no normalization needed.
qat: QATConfig | None = None
quantization: PTQConfig | None = None
reward_model: bool | None = Field(
@@ -436,8 +444,8 @@ class AxolotlInputConfig(
},
)
min_sample_len: int | None = None
max_prompt_len: int | None = Field(
default=None,
max_prompt_len: int = Field(
default=512,
json_schema_extra={"description": "maximum prompt length for RL training"},
)
sample_packing: bool | None = Field(

View File

@@ -1,6 +1,7 @@
"""Module with validation methods for config pydantic model."""
import json
import sys
import tempfile
from pathlib import Path
@@ -1313,40 +1314,50 @@ class ComplexValidationMixin:
if not self.context_parallel_size:
self.context_parallel_size = 1
elif self.context_parallel_size > 1:
use_flash_attention = getattr(self, "flash_attention", False)
use_sdp_attention = getattr(self, "sdp_attention", False)
if not (use_flash_attention or use_sdp_attention):
if not self.flash_attention:
raise ValueError(
"context_parallel_size > 1 requires either flash_attention: true "
"or sdp_attention: true"
"flash_attention: true must be set with context_parallel_size > 1"
)
if use_flash_attention:
if self.sample_packing and self.micro_batch_size > 1:
raise ValueError(
"micro_batch_size must be set to 1 when sample_packing is enabled "
"due to a `ring-flash-attn` requirement"
)
try:
import ring_flash_attn # noqa: F401 # Required after monkey-patching
except ImportError as exception:
raise ImportError(
"context_parallel_size > 1 but ring_flash_attn is not installed. "
"Please install it with `pip install axolotl[ring-flash-attn]` "
"or `pip install ring-flash-attn>=0.1.4`."
) from exception
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"context_parallel_size={self.context_parallel_size}. "
"Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
"for more details."
if self.sample_packing and self.micro_batch_size > 1:
raise ValueError(
"micro_batch_size must be set to 1 when sample_packing is enabled "
"due to a `ring-flash-attn` requirement"
)
try:
import transformers.modeling_flash_attention_utils
from transformers.utils import is_flash_attn_greater_or_equal
transformers.modeling_flash_attention_utils._flash_supports_window = (
True
)
sys.modules[
"transformers.modeling_flash_attention_utils"
]._flash_supports_window = True
sys.modules[
"transformers.modeling_flash_attention_utils"
]._flash_supports_window_size = True
sys.modules[
"transformers.modeling_flash_attention_utils"
].is_flash_attn_greater_or_equal = is_flash_attn_greater_or_equal
import ring_flash_attn # noqa: F401 # Required after monkey-patching
except ImportError as exception:
raise ImportError(
"context_parallel_size > 1 but ring_flash_attn is not installed. "
"Please install it with `pip install axolotl[ring-flash-attn]` "
"or `pip install ring-flash-attn>=0.1.4`."
) from exception
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"context_parallel_size={self.context_parallel_size}. "
"Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
"for more details."
)
return self
@model_validator(mode="after")
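
The validator in this hunk normalizes `context_parallel_size` and then enforces two constraints when it is greater than 1. A standalone sketch of the stricter variant, the one that requires `flash_attention`, follows. `check_context_parallel` is a hypothetical free function, whereas the code in the diff is a pydantic validator method on the config model.

```python
def check_context_parallel(
    context_parallel_size, flash_attention, sample_packing, micro_batch_size
):
    """Return the normalized size, raising ValueError on unsupported combinations."""
    size = context_parallel_size or 1  # unset or 0 is treated as 1
    if size > 1:
        if not flash_attention:
            raise ValueError(
                "flash_attention: true must be set with context_parallel_size > 1"
            )
        if sample_packing and micro_batch_size > 1:
            raise ValueError(
                "micro_batch_size must be set to 1 when sample_packing is enabled"
            )
    return size


default_size = check_context_parallel(None, False, False, 4)
valid_size = check_context_parallel(2, True, True, 1)
```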

View File

@@ -23,8 +23,6 @@ class TestSequenceParallelism:
pad_to_sequence_len=True,
ring_attn_func=None,
threshold=2.0,
flash_attention=True,
sdp_attention=False,
):
"""Helper method to run sequence parallel tests with different configurations"""
cfg = DictDefault(
@@ -60,8 +58,7 @@ class TestSequenceParallelism:
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": flash_attention,
"sdp_attention": sdp_attention,
"flash_attention": True,
"loss_watchdog_threshold": 5.0,
"loss_watchdog_patience": 3,
"bf16": "auto",
@@ -135,16 +132,3 @@ class TestSequenceParallelism:
ring_attn_func=ring_attn_func,
threshold=threshold,
)
def test_sequence_parallel_training_sdpa(self, temp_dir):
"""Smoke test for SDPA-based context parallelism."""
self._run_sequence_parallel_test(
temp_dir,
sample_packing=False,
micro_batch_size=1,
pad_to_sequence_len=True,
ring_attn_func=None,
threshold=3.0,
flash_attention=False,
sdp_attention=True,
)

View File

@@ -1,74 +0,0 @@
"""Tests for PatchManager context parallel patch selection."""
import addict
from axolotl.loaders.patch_manager import PatchManager
from axolotl.utils.dict import DictDefault
def _stub_transformers_patches(monkeypatch):
"""Replace trainer loss patchers with no-ops for isolation."""
monkeypatch.setattr(
"axolotl.monkeypatch.transformers.trainer_loss_calc.patch_evaluation_loop",
lambda: None,
)
monkeypatch.setattr(
"axolotl.monkeypatch.transformers.trainer_loss_calc.patch_maybe_log_save_evaluate",
lambda: None,
)
def test_patch_manager_applies_flash_cp_patch(monkeypatch):
"""When flash attention is enabled, we patch Trainer for CP."""
_stub_transformers_patches(monkeypatch)
patch_calls = {"count": 0}
def stub_patch():
patch_calls["count"] += 1
monkeypatch.setattr(
"axolotl.monkeypatch.transformers.trainer_context_parallel.patch_prepare_context_parallel_inputs",
stub_patch,
)
cfg = DictDefault(
{
"context_parallel_size": 2,
"flash_attention": True,
"sdp_attention": False,
}
)
manager = PatchManager(cfg, addict.Dict())
manager._apply_transformers_patches()
assert patch_calls["count"] == 1
def test_patch_manager_skips_flash_patch_for_sdpa(monkeypatch):
"""When only SDPA is requested, we should not patch Trainer."""
_stub_transformers_patches(monkeypatch)
patch_calls = {"count": 0}
def stub_patch():
patch_calls["count"] += 1
monkeypatch.setattr(
"axolotl.monkeypatch.transformers.trainer_context_parallel.patch_prepare_context_parallel_inputs",
stub_patch,
)
cfg = DictDefault(
{
"context_parallel_size": 2,
"flash_attention": False,
"sdp_attention": True,
}
)
manager = PatchManager(cfg, addict.Dict())
manager._apply_transformers_patches()
assert patch_calls["count"] == 0

View File

@@ -1,35 +0,0 @@
"""Integration tests for MistralCommonTokenizer patches."""
import pytest
class TestMistralTokenizerPatchIntegration:
"""Test MistralCommonTokenizer patch integration."""
@pytest.mark.integration
def test_mistral_tokenizer_image_patch(self):
"""Test that MistralCommonTokenizer image patch can be applied."""
try:
from transformers.tokenization_mistral_common import MistralCommonTokenizer
except ImportError:
pytest.skip("MistralCommonTokenizer not available")
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
apply_mistral_tokenizer_image_patch,
)
# Store original method
original_apply_chat_template = MistralCommonTokenizer.apply_chat_template
# Apply patch
apply_mistral_tokenizer_image_patch()
# Verify patch was applied
assert (
MistralCommonTokenizer.apply_chat_template != original_apply_chat_template
), "apply_chat_template was not patched"
# Verify the method is still callable
assert callable(MistralCommonTokenizer.apply_chat_template), (
"Patched method is not callable"
)

View File

@@ -0,0 +1,258 @@
import sys
import types
import torch
import torch.nn as nn
from axolotl.kernels.moe import (
backends as moe_backends,
torch_grouped as torch_grouped_module,
)
from axolotl.monkeypatch import moe_grouped
class DummyExperts(nn.Module):
def __init__(self, layers):
super().__init__()
self.layers = nn.ModuleList(layers)
self.num_experts = len(layers)
def __getitem__(self, idx):
return self.layers[idx]
class DummyQwenMLP(nn.Module):
def __init__(self, idx: int, hidden: int, intermediate: int):
super().__init__()
self.gate_up_proj = nn.Linear(hidden, 2 * intermediate, bias=False)
self.down_proj = nn.Linear(intermediate, hidden, bias=False)
nn.init.constant_(self.gate_up_proj.weight, float(idx + 1))
nn.init.constant_(self.down_proj.weight, float((idx + 1) * 10))
class DummyQwenExpert(nn.Module):
def __init__(self, idx: int, hidden: int, intermediate: int):
super().__init__()
self.mlp = DummyQwenMLP(idx, hidden, intermediate)
def _make_transformers_stub(monkeypatch, block_cls):
# ensure we start from the original forward for each test
if block_cls is DummyMixtralBlock:
DummyMixtralBlock.forward = _DUMMY_MIXTRAL_ORIG_FORWARD
transformers_mod = types.ModuleType("transformers")
models_mod = types.ModuleType("transformers.models")
mixtral_mod = types.ModuleType("transformers.models.mixtral")
modeling_mixtral = types.ModuleType("transformers.models.mixtral.modeling_mixtral")
modeling_mixtral.MixtralSparseMoeBlock = block_cls
transformers_mod.models = models_mod
models_mod.mixtral = mixtral_mod
mixtral_mod.modeling_mixtral = modeling_mixtral
monkeypatch.setitem(sys.modules, "transformers", transformers_mod)
monkeypatch.setitem(sys.modules, "transformers.models", models_mod)
monkeypatch.setitem(sys.modules, "transformers.models.mixtral", mixtral_mod)
monkeypatch.setitem(
sys.modules,
"transformers.models.mixtral.modeling_mixtral",
modeling_mixtral,
)
def test_grouped_uses_per_expert_nested_modules(monkeypatch):
hidden = 4
intermediate = 2
num_experts = 2
experts = DummyExperts(
[DummyQwenExpert(i, hidden, intermediate) for i in range(num_experts)]
)
gate = nn.Linear(hidden, num_experts, bias=False)
nn.init.zeros_(gate.weight)
captured = []
def fake_grouped_mm(As, Bs, dtype):
captured.append([b.detach().clone() for b in Bs])
return [
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
for a, b in zip(As, Bs, strict=False)
]
monkeypatch.setattr(torch_grouped_module, "_call_grouped_mm", fake_grouped_mm)
hidden_states = torch.randn(1, 2, hidden)
y, router_logits = torch_grouped_module.moe_ffn_forward_grouped(
hidden_states, gate, experts, top_k=2
)
assert y is not None
assert router_logits is not None
assert captured, "Grouped GEMM path should have been invoked"
first_call = captured[0]
expected0 = experts[0].mlp.gate_up_proj.weight.t()
expected1 = experts[1].mlp.gate_up_proj.weight.t()
assert torch.equal(first_call[0], expected0)
assert torch.equal(first_call[1], expected1)
assert not torch.equal(first_call[0], first_call[1])
def test_grouped_accepts_module_list_experts(monkeypatch):
hidden = 4
intermediate = 2
experts = nn.ModuleList(
[DummyQwenExpert(i, hidden, intermediate) for i in range(2)]
)
gate = nn.Linear(hidden, len(experts), bias=False)
nn.init.zeros_(gate.weight)
calls = {"count": 0}
def fake_grouped_mm(As, Bs, dtype):
calls["count"] += 1
return [
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
for a, b in zip(As, Bs, strict=False)
]
monkeypatch.setattr(torch_grouped_module, "_call_grouped_mm", fake_grouped_mm)
hidden_states = torch.randn(1, 2, hidden)
y, router_logits = torch_grouped_module.moe_ffn_forward_grouped(
hidden_states, gate, experts, top_k=2
)
assert y is not None
assert router_logits is not None
assert calls["count"] > 0
class _DummyCfg:
moe_backend = "torch_grouped"
class DummyMixtralBlock(nn.Module):
def __init__(self):
super().__init__()
self.top_k = 1
self.gate = lambda x: x
self.experts = object()
self._calls = []
def forward(self, hidden_states: torch.Tensor, attention_mask=None):
self._calls.append((hidden_states, attention_mask))
tokens = hidden_states.shape[0] * hidden_states.shape[1]
router = torch.ones(
tokens, 2, device=hidden_states.device, dtype=hidden_states.dtype
)
return hidden_states + 5, router
_DUMMY_MIXTRAL_ORIG_FORWARD = DummyMixtralBlock.forward
def test_apply_grouped_forward_handles_args(monkeypatch):
    _make_transformers_stub(monkeypatch, DummyMixtralBlock)
    import axolotl.common.architectures as arch

    original_map = arch.MOE_ARCH_BLOCK.copy()
    monkeypatch.setitem(arch.MOE_ARCH_BLOCK, "mixtral", "MixtralSparseMoeBlock")
    for key in list(original_map.keys()):
        if key != "mixtral":
            monkeypatch.setitem(arch.MOE_ARCH_BLOCK, key, None)
    monkeypatch.setattr(
        moe_grouped,
        "get_moe_backend_name",
        lambda preferred=None: moe_backends.MOEBackend.TORCH_GROUPED,
    )
    results = {}

    def fake_grouped_forward(hidden_states, gate, experts, top_k):
        results["called"] = True
        router = torch.zeros(
            hidden_states.shape[0] * hidden_states.shape[1],
            2,
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        return hidden_states + 1, router

    monkeypatch.setattr(torch_grouped_module, "available", lambda: True)
    monkeypatch.setattr(
        torch_grouped_module,
        "moe_ffn_forward_grouped",
        fake_grouped_forward,
    )
    cfg = _DummyCfg()
    moe_grouped.apply_grouped_to_moe_blocks(cfg)
    block = DummyMixtralBlock()
    hidden_states = torch.ones(1, 2, 3)
    mask = torch.zeros(1, 2)
    out, router = block.forward(hidden_states, attention_mask=mask)
    assert results.get("called") is True
    assert torch.equal(out, hidden_states + 1)
    assert router.shape[0] == hidden_states.shape[0] * hidden_states.shape[1]


def test_apply_grouped_forward_fallback(monkeypatch):
    _make_transformers_stub(monkeypatch, DummyMixtralBlock)
    import axolotl.common.architectures as arch

    original_map = arch.MOE_ARCH_BLOCK.copy()
    monkeypatch.setitem(arch.MOE_ARCH_BLOCK, "mixtral", "MixtralSparseMoeBlock")
    for key in list(original_map.keys()):
        if key != "mixtral":
            monkeypatch.setitem(arch.MOE_ARCH_BLOCK, key, None)
    monkeypatch.setattr(
        moe_grouped,
        "get_moe_backend_name",
        lambda preferred=None: moe_backends.MOEBackend.TORCH_GROUPED,
    )
    monkeypatch.setattr(torch_grouped_module, "available", lambda: True)
    monkeypatch.setattr(
        torch_grouped_module,
        "moe_ffn_forward_grouped",
        lambda *args, **kwargs: (None, None),
    )
    cfg = _DummyCfg()
    moe_grouped.apply_grouped_to_moe_blocks(cfg)
    block = DummyMixtralBlock()
    hidden_states = torch.ones(1, 2, 3)
    mask = torch.zeros(1, 2)
    out, router = block.forward(hidden_states, attention_mask=mask)
    assert torch.equal(out, hidden_states + 5)
    assert router.shape[0] == hidden_states.shape[0] * hidden_states.shape[1]
    assert block._calls, "Original forward should have been invoked"
    call_hidden, call_mask = block._calls[-1]
    assert torch.equal(call_hidden, hidden_states)
    assert torch.equal(call_mask, mask)


def test_get_moe_backend_name_prefers_probe(monkeypatch):
    monkeypatch.setattr(moe_backends, "_probe_torch_grouped", lambda: True)
    assert moe_backends.get_moe_backend_name() == moe_backends.MOEBackend.TORCH_GROUPED


def test_get_moe_backend_name_falls_back(monkeypatch):
    warnings_captured = []

    def fake_warn(msg, *, stacklevel=None):  # noqa: ARG001
        warnings_captured.append(msg)

    monkeypatch.setattr(moe_backends, "_probe_torch_grouped", lambda: False)
    monkeypatch.setattr(moe_backends.warnings, "warn", fake_warn)
    backend = moe_backends.get_moe_backend_name("torch_grouped")
    assert backend == moe_backends.MOEBackend.NAIVE
    assert warnings_captured, "Expected warning when torch_grouped unavailable"

@@ -1,77 +0,0 @@
"""Integration tests for Pixtral Flash Attention patches."""
import pytest
import torch


class TestPixtralFlashAttentionPatchIntegration:
    """Test Pixtral Flash Attention patch integration."""

    @pytest.mark.integration
    def test_pixtral_flash_attention_patch(self):
        """Test that Pixtral Flash Attention patch can be applied and works correctly."""
        try:
            from transformers import modeling_flash_attention_utils
        except ImportError:
            pytest.skip("Flash Attention utils not available")

        from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (
            apply_patch_is_packed_sequence,
        )

        # Store original method
        original_is_packed_sequence = modeling_flash_attention_utils._is_packed_sequence

        # Apply patch and get unpatch function
        unpatch_fn = apply_patch_is_packed_sequence()

        # Verify patch was applied
        assert (
            modeling_flash_attention_utils._is_packed_sequence
            != original_is_packed_sequence
        ), "_is_packed_sequence was not patched"

        # Test the patched function with 1D position_ids
        patched_fn = modeling_flash_attention_utils._is_packed_sequence

        # Test 1D position_ids 1 sequence
        position_ids_1d = torch.tensor([0, 1, 2, 3])
        result = patched_fn(position_ids_1d, batch_size=1)
        assert isinstance(result, bool), "Function should return a boolean"
        assert result is False, "1D sequential position_ids should not be packed"

        # Test 1D packed 2 sequences
        position_ids_1d_packed = torch.tensor([0, 1, 2, 0, 1, 2])
        result = patched_fn(position_ids_1d_packed, batch_size=1)
        assert isinstance(result, bool), "Function should return a boolean"
        assert result is True, "1D packed position_ids should be detected as packed"

        # Test 2D packed 2 sequences
        position_ids_2d_packed = torch.tensor([[0, 1, 2, 3, 0, 1]])
        result = patched_fn(position_ids_2d_packed, batch_size=1)
        assert isinstance(result, bool), "Function should return a boolean"
        assert result is True, "2D packed position_ids should be detected as packed"

        # Test 2D 1 sequence
        position_ids_2d_normal = torch.tensor([[0, 1, 2, 3, 4, 5]])
        result = patched_fn(position_ids_2d_normal, batch_size=1)
        assert isinstance(result, bool), "Function should return a boolean"
        assert result is False, "2D sequential position_ids should not be packed"

        # Test 2D batch size 2
        position_ids_2d_normal = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]])
        result = patched_fn(position_ids_2d_normal, batch_size=2)
        assert isinstance(result, bool), "Function should return a boolean"
        assert result is False, "2D position_ids batch 2 should not be packed"

        # Test None case
        result = patched_fn(None, batch_size=1)
        assert isinstance(result, bool), "Function should return a boolean"
        assert result is False, "None position_ids should return False"

        # Test unpatch function
        unpatch_fn()
        assert (
            modeling_flash_attention_utils._is_packed_sequence
            == original_is_packed_sequence
        ), "unpatch function did not restore original method"

@@ -1,111 +0,0 @@
"""Integration tests for Qwen3 Next modeling patches."""
import pytest
import torch

# Skip entire module if qwen3_next not available
qwen3_next = pytest.importorskip("transformers.models.qwen3_next.modeling_qwen3_next")


class TestQwen3NextModelingPatchIntegration:
    """Test Qwen3 Next modeling patch integration."""

    @pytest.mark.integration
    def test_qwen3_next_decoder_layer_patch(self):
        """Test that Qwen3Next decoder layer patch can be applied."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_decoder_layer,
        )

        # Store original method
        original_forward = qwen3_next.Qwen3NextDecoderLayer.forward

        # Apply patch and get unpatch function
        unpatch_fn = patch_qwen3_next_decoder_layer()

        # Verify patch was applied
        assert qwen3_next.Qwen3NextDecoderLayer.forward != original_forward, (
            "decoder layer forward method was not patched"
        )

        # Verify the method is still callable
        assert callable(qwen3_next.Qwen3NextDecoderLayer.forward), (
            "Patched method is not callable"
        )

        # Test unpatch function
        if unpatch_fn:
            unpatch_fn()
            assert qwen3_next.Qwen3NextDecoderLayer.forward == original_forward, (
                "unpatch function did not restore original method"
            )

    @pytest.mark.integration
    def test_qwen3_next_gateddelta_layer_patch(self):
        """Test that Qwen3Next GatedDeltaNet patch can be applied."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_gateddelta_layer,
        )

        # Store original method
        original_forward = qwen3_next.Qwen3NextGatedDeltaNet.forward

        # Apply patch and get unpatch function
        unpatch_fn = patch_qwen3_next_gateddelta_layer()

        # Verify patch was applied
        assert qwen3_next.Qwen3NextGatedDeltaNet.forward != original_forward, (
            "GatedDeltaNet forward method was not patched"
        )

        # Verify the method is still callable
        assert callable(qwen3_next.Qwen3NextGatedDeltaNet.forward), (
            "Patched method is not callable"
        )

        # Test unpatch function
        if unpatch_fn:
            unpatch_fn()
            assert qwen3_next.Qwen3NextGatedDeltaNet.forward == original_forward, (
                "unpatch function did not restore original method"
            )

    @pytest.mark.integration
    def test_qwen3_next_imports_patch(self):
        """Test that Qwen3Next imports patch can be applied without errors."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_imports,
        )

        # Apply patch - should not raise any exceptions even if modules unavailable
        unpatch_fn = patch_qwen3_next_imports()

        # Test that unpatch function is returned (or None if skipped)
        assert unpatch_fn is None or callable(unpatch_fn), (
            "patch_qwen3_next_imports should return None or callable unpatch function"
        )

    @pytest.mark.integration
    def test_qwen3_next_modeling_packing_patch(self):
        """Test that all Qwen3Next modeling patches can be applied together."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_modeling_packing,
        )

        # This should not raise any exceptions
        patch_qwen3_next_modeling_packing()


@pytest.mark.integration
def test_get_cu_seqlens_utility():
    """Test the get_cu_seqlens utility function."""
    from axolotl.monkeypatch.models.qwen3_next.modeling import get_cu_seqlens

    # Test with simple position_ids
    position_ids = torch.tensor([[0, 1, 2, 0, 1]])
    cu_seqlens = get_cu_seqlens(position_ids)
    assert cu_seqlens.dtype == torch.int32, "Should be int32 dtype"

    # Should return tensor with start positions and total length
    expected = torch.tensor([0, 3, 5], dtype=torch.int32)
    assert torch.equal(cu_seqlens, expected), f"Expected {expected}, got {cu_seqlens}"

@@ -1,66 +0,0 @@
"""Tests for the HF Trainer context parallel patch."""
import pytest
from transformers import Trainer

from axolotl.monkeypatch.transformers.trainer_context_parallel import (
    GUARD_PATTERN,
    PATCHED_GUARD,
    patch_prepare_context_parallel_inputs,
)


@pytest.fixture
def restore_trainer_prepare_method():
    """Ensure Trainer._prepare_context_parallel_inputs is restored after a test."""
    original_method = getattr(
        Trainer,
        "_original_prepare_context_parallel_inputs",
        Trainer._prepare_context_parallel_inputs,
    )
    patched_attr_present = hasattr(
        Trainer, "_axolotl_prepare_context_parallel_inputs_patched"
    )

    yield

    Trainer._prepare_context_parallel_inputs = original_method
    if patched_attr_present:
        delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
    if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
        delattr(Trainer, "_original_prepare_context_parallel_inputs")
    if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source"):
        delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")


def test_patch_attention_guard(restore_trainer_prepare_method):
    """Patch should swap the guard to allow sdpa or flash attention."""
    # Ensure we start from the unpatched method
    if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
        Trainer._prepare_context_parallel_inputs = (
            Trainer._original_prepare_context_parallel_inputs
        )
        delattr(Trainer, "_original_prepare_context_parallel_inputs")
    if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched"):
        delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")

    patch_prepare_context_parallel_inputs()

    patched_method = Trainer._prepare_context_parallel_inputs
    assert patched_method is not None
    assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)

    source = Trainer._axolotl_prepare_context_parallel_inputs_source
    assert GUARD_PATTERN not in source
    assert PATCHED_GUARD in source


def test_patch_is_idempotent(restore_trainer_prepare_method):
    """Calling the patch twice should leave the same patched function in place."""
    patch_prepare_context_parallel_inputs()
    first_patched = Trainer._prepare_context_parallel_inputs

    patch_prepare_context_parallel_inputs()
    second_patched = Trainer._prepare_context_parallel_inputs

    assert first_patched is second_patched

@@ -1,43 +0,0 @@
"""Integration tests for Voxtral modeling patches."""
import pytest


class TestVoxtralModelingPatchIntegration:
    """Test Voxtral modeling patch integration."""

    @pytest.mark.integration
    def test_voxtral_conditional_generation_patch(self):
        """Test that Voxtral conditional generation patch can be applied."""
        try:
            from transformers.models.voxtral.modeling_voxtral import (
                VoxtralForConditionalGeneration,
            )
        except ImportError:
            pytest.skip("VoxtralForConditionalGeneration not available")

        from axolotl.monkeypatch.models.voxtral.modeling import (
            patch_voxtral_conditional_generation_forward,
        )

        # Store original method
        original_forward = VoxtralForConditionalGeneration.forward

        # Apply patch and get unpatch function
        unpatch_fn = patch_voxtral_conditional_generation_forward()

        # Verify patch was applied
        assert VoxtralForConditionalGeneration.forward != original_forward, (
            "forward method was not patched"
        )

        # Verify the method is still callable
        assert callable(VoxtralForConditionalGeneration.forward), (
            "Patched method is not callable"
        )

        # Test unpatch function
        unpatch_fn()
        assert VoxtralForConditionalGeneration.forward == original_forward, (
            "unpatch function did not restore original method"
        )

@@ -1,111 +0,0 @@
"""Unit tests for choosing the correct context parallel implementation."""
from types import SimpleNamespace

from axolotl.train import execute_training
from axolotl.utils.dict import DictDefault


class DummyTrainer:
    """Minimal trainer stub to exercise execute_training."""

    def __init__(self):
        self.model = object()
        self.ref_model = None
        self.accelerator = SimpleNamespace(torch_device_mesh=None)
        self.train_called = False

    def train(self, resume_from_checkpoint=None):  # pylint: disable=unused-argument
        self.train_called = True


class DummyPluginManager:
    """Minimal plugin manager stub."""

    @staticmethod
    def post_train(cfg, model):  # pylint: disable=unused-argument
        return None


class DummyContext:
    """Test context manager that records entries/exits."""

    def __init__(self, recorder, **kwargs):
        recorder.append({"kwargs": kwargs})
        self.recorder = recorder

    def __enter__(self):
        self.recorder[-1]["entered"] = True
        return self

    def __exit__(self, exc_type, exc, tb):  # pylint: disable=unused-argument
        self.recorder[-1]["exited"] = True
        return False


def _base_cfg(**overrides):
    base = {
        "context_parallel_size": 2,
        "gradient_accumulation_steps": 1,
        "ring_attn_func": None,
        "heads_k_stride": None,
        "rl": None,
        "flash_optimum": False,
    }
    base.update(overrides)
    return DictDefault(base)


def test_execute_training_uses_ring_when_flash(monkeypatch):
    """FlashAttention CP should engage the custom ring context manager."""
    recorder: list[dict] = []
    monkeypatch.setattr(
        "axolotl.train.SequenceParallelContextManager",
        lambda **kwargs: DummyContext(recorder, **kwargs),
    )
    monkeypatch.setattr(
        "axolotl.train.PluginManager.get_instance",
        lambda: DummyPluginManager(),
    )
    cfg = _base_cfg(flash_attention=True, sdp_attention=False)
    trainer = DummyTrainer()
    execute_training(cfg, trainer, resume_from_checkpoint=None)
    assert trainer.train_called
    assert len(recorder) == 1
    assert recorder[0]["kwargs"]["context_parallel_size"] == 2
    assert recorder[0].get("entered") is True
    assert recorder[0].get("exited") is True


def test_execute_training_uses_transformers_cp_for_sdpa(monkeypatch):
    """SDPA CP should bypass the ring context manager."""
    invoked = {"count": 0}

    class NoOpContext:
        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):  # pylint: disable=unused-argument
            return False

    monkeypatch.setattr(
        "axolotl.train.SequenceParallelContextManager",
        lambda **kwargs: invoked.__setitem__("count", invoked["count"] + 1)
        or NoOpContext(),
    )
    monkeypatch.setattr(
        "axolotl.train.PluginManager.get_instance",
        lambda: DummyPluginManager(),
    )
    cfg = _base_cfg(flash_attention=False, sdp_attention=True)
    trainer = DummyTrainer()
    execute_training(cfg, trainer, resume_from_checkpoint=None)
    assert trainer.train_called
    assert invoked["count"] == 0

6545
uv.lock generated Normal file

File diff suppressed because it is too large