Compare commits

..

1 Commit

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Wing Lian | f94ec0434c | include tool in default message_property_mappings | 2026-03-03 09:03:25 -05:00 |
96 changed files with 244 additions and 4410 deletions

View File

@@ -70,11 +70,6 @@ You can skip certain CI checks by including specific keywords in your commit mes
axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.
Use the pre-commit linter to ensure that your code is formatted consistently.
```bash
pre-commit run --all-files
```
### Commit Messages
Write clear and concise commit messages that briefly describe the changes made in each commit. Use the imperative mood and start with a capitalized verb, e.g., "Add new feature" or "Fix bug in function".

View File

@@ -51,14 +51,6 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -181,14 +173,6 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""

View File

@@ -35,6 +35,12 @@ jobs:
pytorch: 2.8.0
axolotl_extras: fbgemm-gpu
num_gpus: 2
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
@@ -49,13 +55,6 @@ jobs:
axolotl_extras:
# axolotl_extras: fbgemm-gpu
num_gpus: 2
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.10.0
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:

View File

@@ -18,27 +18,15 @@ jobs:
env:
SKIP: no-commit-to-branch
prime-cdn-s3-cache:
name: Prefetch S3 once to prime the CDN cache
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
timeout-minutes: 10
steps:
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
pytest:
name: PyTest
runs-on: ubuntu-latest
needs: [prime-cdn-s3-cache]
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
python_version: ["3.11"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
timeout-minutes: 20
steps:
@@ -114,23 +102,16 @@ jobs:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.8.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.10.0
num_gpus: 1
axolotl_extras:
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
nightly_build: "true"
steps:
- name: Checkout
@@ -151,7 +132,6 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal

View File

@@ -46,32 +46,21 @@ jobs:
env:
SKIP: no-commit-to-branch
prime-cdn-s3-cache:
name: Prefetch S3 once to prime the CDN cache
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
timeout-minutes: 10
steps:
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
pytest:
name: PyTest
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
needs: [prime-cdn-s3-cache]
# needs: [preload-cache]
strategy:
fail-fast: false
matrix:
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude:
# - python_version: "3.14"
# pytorch_version: "2.8.0"
# - python_version: "3.14"
# pytorch_version: "2.9.1"
python_version: ["3.11", "3.12"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
timeout-minutes: 20
steps:
@@ -157,18 +146,17 @@ jobs:
name: PyTest from Source Dist
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
needs: [prime-cdn-s3-cache]
strategy:
fail-fast: false
matrix:
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude:
# - python_version: "3.14"
# pytorch_version: "2.8.0"
# - python_version: "3.14"
# pytorch_version: "2.9.1"
timeout-minutes: 30
python_version: ["3.11", "3.12"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
timeout-minutes: 20
steps:
- name: cleanup node
@@ -338,12 +326,6 @@ jobs:
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.10.0
num_gpus: 1
axolotl_extras:
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
@@ -387,9 +369,9 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:

View File

@@ -11,7 +11,7 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.4
rev: v0.14.10
hooks:
- id: ruff
args: [--fix]
@@ -26,7 +26,7 @@ repos:
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.9.4
rev: 1.9.2
hooks:
- id: bandit
args: [

View File

@@ -29,23 +29,8 @@
## 🎉 Latest Updates
- 2026/03:
- New model support has been added in Axolotl for [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
- 2026/02:
- [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.
- Axolotl now has support for [SageAttention](https://github.com/axolotl-ai-cloud/axolotl/pull/2823) and [GDPO](https://github.com/axolotl-ai-cloud/axolotl/pull/3353) (Generalized DPO).
- 2026/01:
- New integrations for [EAFT](https://github.com/axolotl-ai-cloud/axolotl/pull/3366) (Entropy-Aware Focal Training), which weights the loss by the entropy of the top-k logit distribution, and [Scalable Softmax](https://github.com/axolotl-ai-cloud/axolotl/pull/3338), which improves long-context attention.
- 2025/12:
- Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
- [Distributed Muon Optimizer](https://github.com/axolotl-ai-cloud/axolotl/pull/3264) support has been added for FSDP2 pretraining.
- 2025/12: Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html).
<details>
<summary>Expand older updates</summary>
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
- 2025/07:
@@ -54,10 +39,15 @@
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
- [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support have been integrated into Axolotl!
- TiledMLP support for single-GPU and multi-GPU training (DDP, DeepSpeed, and FSDP) has been added to enable Arctic Long Sequence Training (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
<details>
<summary>Expand older updates</summary>
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
- 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
@@ -72,10 +62,10 @@ Axolotl is a free and open-source tool designed to streamline post-training and
Features:
- **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, and audio models like Voxtral with image, video, and audio support.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [SageAttention](https://github.com/thu-ml/SageAttention), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.

View File

@@ -1 +1 @@
0.15.0
0.15.0.dev0

View File

@@ -331,7 +331,6 @@ website:
- docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- docs/nd_parallelism.qmd
- docs/expert_quantization.qmd
- section: "Troubleshooting"
contents:

View File

@@ -33,7 +33,6 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
RUN uv pip install packaging==26.0 setuptools==75.8.0
RUN uv pip install torchvision
RUN uv pip uninstall causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -33,7 +33,6 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi
RUN pip install packaging==26.0 setuptools==75.8.0 psutil
RUN pip uninstall -y causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -3,12 +3,6 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
# curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
hf download "NousResearch/Meta-Llama-3-8B"
hf download "NousResearch/Meta-Llama-3-8B-Instruct"
hf download "microsoft/Phi-4-reasoning"
hf download "microsoft/Phi-3.5-mini-instruct"
# Run unit tests with initial coverage report
pytest -v --durations=10 -n8 \
--ignore=tests/e2e/ \

View File

@@ -22,7 +22,6 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN pip uninstall -y causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \

View File

@@ -22,7 +22,6 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN uv pip uninstall causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \

View File

@@ -1,67 +0,0 @@
---
title: "MoE Expert Quantization"
description: "Reduce VRAM usage when training MoE model adapters by quantizing expert weights on load"
---
Transformers v5 changed MoE expert layers from `nn.Linear` to fused `nn.Parameter` (3D+ tensors).
This means `bitsandbytes` can no longer quantize them during model loading, resulting in all expert
weights being loaded in full bf16 precision and causing massive VRAM usage.
`quantize_moe_experts` solves this by quantizing expert weights during model loading.
It intercepts the weight loading process, quantizes each expert tensor on the fly, and
immediately frees the original bf16 tensor from VRAM. This dramatically reduces peak memory.
For example, GLM-4.7-Flash QLoRA drops from ~127GiB to ~23GiB reserved memory.
## Usage
Enable expert quantization in your Axolotl config:
```yaml
quantize_moe_experts: true
```
This works with both 4-bit (QLoRA) and 8-bit (LoRA) quantization.
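For example, a minimal sketch of the two supported combinations (mirroring the requirements below and the example configs elsewhere in this repo; the rest of the config is up to you):
```yaml
# QLoRA: 4-bit base weights + on-load expert quantization
adapter: qlora
load_in_4bit: true
quantize_moe_experts: true

# Or LoRA with 8-bit base weights:
# adapter: lora
# load_in_8bit: true
# quantize_moe_experts: true
```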
### Expert LoRA targeting
You can optionally apply LoRA adapters directly to expert weights using `lora_target_parameters`:
```yaml
lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj
# - mlp.gate.weight # router
```
::: {.callout-note}
`lora_dropout` must be `0` when using `lora_target_parameters`.
:::
## Requirements
- Requires (`adapter: lora` and `load_in_8bit: true`) or (`adapter: qlora` and `load_in_4bit: true`)
- CUDA GPUs only (not tested with ROCm or other backends)
- FSDP2 compatible for distributed training
## Limitations
- `lora_target_linear` is not compatible with `quantize_moe_experts`. See [Expert LoRA targeting](#expert-lora-targeting) instead.
- `cpu_ram_efficient_loading` hangs or takes a very long time with FSDP2 + QLoRA.
- Total model parameter count may display incorrectly (trainable param count is correct).
- FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps, which then drops. QLoRA does not exhibit this.
- FSDP2 may use more VRAM per GPU than single-GPU training because not all layers are properly sharded across ranks.
- Model loading takes longer due to on-demand quantization, even on consecutive runs.
- DeepSpeed has not been tested.
## Implementation details
The quantization is applied by patching transformers to intercept weight loading.
When a 3D+ CUDA tensor with "expert" in its name is detected:
- **4-bit mode:** Uses bitsandbytes NF4 parametrization (configurable via `bnb_4bit_quant_type`).
- **8-bit mode:** Uses a custom row-wise int8 parametrization with bitsandbytes dequantization.
The original bf16 tensor is freed immediately after quantization. Multiple sub-patches are applied to
transformers, PEFT and accelerate FSDP2 to support these parametrized expert modules.
For full implementation details, see [PR #3439](https://github.com/axolotl-ai-cloud/axolotl/pull/3439).

View File

@@ -66,15 +66,6 @@ Provides efficient Triton kernels to improve training speed and reduce memory us
- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)
### Expert Kernels
Optimized kernel implementations for Mixture of Experts (MoE) model training.
- **ScatterMoE**: Triton-based MoE kernels with fused LoRA support.
- **SonicMoE**: CUTLASS-based MoE kernels for NVIDIA Hopper and Blackwell GPUs.
- **Learn more:** [Custom Integrations - Kernels Integration](custom_integrations.qmd#kernels-integration); a minimal enablement sketch follows below.
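Mirroring the usage block in the kernels integration doc further down (option names may shift between Axolotl versions), enabling one of these kernels looks roughly like:
```yaml
plugins:
  - axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
# Choose one (mutually exclusive):
use_scattermoe: true
# use_sonicmoe: true
```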
## Long Context Models
Techniques to train models on sequences longer than their original context window.
@@ -140,10 +131,3 @@ Simulates quantization effects during training, helping the model adapt and pote
Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.
- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)
### MoE Expert Quantization
Quantizes MoE expert weights on load to reduce VRAM when training MoE models with adapters. Required for Transformers v5+ MoE models where experts use fused `nn.Parameter` tensors.
- **Config:** `quantize_moe_experts: true`
- **Learn more:** [MoE Expert Quantization](expert_quantization.qmd)

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572\""
]
},
{

View File

@@ -1,72 +0,0 @@
# Finetune Z.ai's GLM-4.5-Air with Axolotl
[GLM-4.5-Air](https://huggingface.co/zai-org/GLM-4.5-Air) is a MoE model by Z.ai.
This guide shows how to fine-tune it with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
# QLoRA (1x80GB @ ~63.4GiB/GPU)
axolotl train examples/glm45/glm-45-air-qlora.yaml
```
### Dataset
In addition to the standard OpenAI Messages format, GLM-4.5 supports an extra parameter for thinking in the assistant section.
```json
{
"role": "assistant",
"reasoning_content": "...", // or have </think>...</think> in `content`
"content": "..."
}
```
Make sure you set the below extra attributes if needed:
```yaml
datasets:
- path: ...
type: chat_template
message_property_mappings:
role: role
content: content
# tool_calls: tool_calls # uncomment if using tools
# reasoning_content: reasoning_content # uncomment if you have reasoning content
# Uncomment if training on tool role (you would rarely if ever need this)
# eot_tokens:
# - <|observation|>
```
### Tips
- The role name for tools in this template is `tool`.
- You will see this Axolotl WARNING — this is expected as the template does not use EOS:
```
EOS token '<|endoftext|>' not found in chat_template. Please check if your template/EOS token is correct.
```
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config.
- **LoRA kernels**: Incompatible with this model. Must be explicitly disabled (`lora_*_kernel: false`).
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [GLM-4.5-Air on HuggingFace](https://huggingface.co/zai-org/GLM-4.5-Air)
- [GLM-4.5 Blog](https://z.ai/blog/glm-4.5)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,64 +0,0 @@
base_model: zai-org/GLM-4.5-Air
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
quantize_moe_experts: true # important
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 16
lora_alpha: 8
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,65 +0,0 @@
# Finetune Z.ai's GLM-4.7-Flash with Axolotl
[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.
This guide shows how to fine-tune it with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
# QLoRA
# - no target experts (1x48GB @ ~24GiB/GPU)
# - target experts (1x48GB @ ~34GiB/GPU)
axolotl train examples/glm47-flash/qlora.yaml
# QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)
axolotl train examples/glm47-flash/qlora_fsdp.yaml
```
```bash
# LoRA
# - no target experts (1x48GB @ ~35GiB/GPU)
# - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)
axolotl train examples/glm47-flash/lora.yaml
# LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)
axolotl train examples/glm47-flash/lora_fsdp.yaml
```
### MoE Expert Quantization & Expert LoRA
This model quantizes expert weights on load. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.
## Limitations
- **lora_target_linear**: Incompatible with this model.
- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`); see the config sketch below.
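A hedged sketch of the relevant config lines, mirroring the example configs in `examples/glm47-flash/`: target the attention projections explicitly instead of using `lora_target_linear`, and disable the LoRA kernels.
```yaml
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
# LoRA kernels are incompatible with DSA attention, so they must be disabled
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
```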
### TIPS
- For inference, the official Z.ai team recommends these default settings (most tasks):
- `temperature: 1.0`
- `top_p: 0.95`
- `max_new_tokens: 131072`
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy, so we have not tested this.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,65 +0,0 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -1,75 +0,0 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,65 +0,0 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -1,75 +0,0 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,65 +0,0 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
split: train[:95%]
output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/dataset_prepared
sequence_len: 2048
flash_attention: true
qat:
activation_dtype: mxfp4
weight_dtype: mxfp4
group_size: 32
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_checkpointing: true
activation_offloading: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -6,13 +6,30 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main, as Qwen3-Next support is only available on nightly, or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Install Qwen3-Next transformers commit
```bash
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
```
3. Install FLA for improved performance
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
```
4. Run the finetuning example:
@@ -21,7 +38,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
```
This config uses roughly 47 GiB (no target experts) and 71 GiB (target experts) of VRAM.
This config uses about 45.62 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀

View File

@@ -9,8 +9,6 @@ plugins:
load_in_8bit: false
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
@@ -27,7 +25,7 @@ sample_packing: true
lora_r: 16
lora_alpha: 8
lora_dropout: 0
lora_dropout: 0.05
lora_target_modules:
- linear_attn.in_proj_ba
- linear_attn.in_proj_qkvz
@@ -36,19 +34,12 @@ lora_target_modules:
- shared_expert.down_proj
- shared_expert.gate_proj
- shared_expert_gate
- mlp.gate
- q_proj
- v_proj
- k_proj
- o_proj
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:

View File

@@ -1,71 +0,0 @@
base_model: Qwen/Qwen3.5-122B-A10B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
#lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,72 +0,0 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Note: Qwen3.5 is an early-fusion VLM (image+text). This config fine-tunes
# the text-only path. For multimodal (image+text) fine-tuning, add image
# columns to your dataset following axolotl's multimodal dataset format.
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
# Uncomment below to also target the linear attention projections.
# These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).
# - linear_attn.in_proj_qkv
# - linear_attn.in_proj_z
# - linear_attn.out_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,70 +0,0 @@
base_model: Qwen/Qwen3.5-35B-A3B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
#lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,72 +0,0 @@
base_model: Qwen/Qwen3.5-7B
processor_type: AutoProcessor
# Qwen3.5-7B and above are early-fusion VLMs (Qwen3_5ForConditionalGeneration).
# Vision and text tokens are processed together by the same transformer layers.
# Note: Qwen3.5-2B is a text-only model — the smallest VLM is Qwen3.5-7B.
# These 3 lines are required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen3_5
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
# Targets the language model attention and MLP layers.
# Qwen3.5 is early-fusion: all layers (including those seeing vision tokens) share
# the same transformer stack, so standard attention targets work for both modalities.
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
# Uncomment to also target the linear attention (GatedDeltaNet) projections:
# - linear_attn.in_proj_qkv
# - linear_attn.in_proj_z
# - linear_attn.out_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -1,61 +0,0 @@
# Finetune Qwen3.5 with Axolotl
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-68452f3bc6e4b7cfb4e1c803) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. Models from 7B onwards are early-fusion vision-language models (`Qwen3_5ForConditionalGeneration`), meaning vision and text tokens are processed through the same transformer stack. The 2B variant is text-only.
Available configs:
| Config | Model | Type |
|---|---|---|
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only path |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only path |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only path |
| `7b-lora-vision.yaml` | Qwen3.5-7B | Vision+text (multimodal) |
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
4. Run a finetuning example:
```bash
# Dense 27B text-only (QLoRA, ~47 GiB VRAM with sample packing)
axolotl train examples/qwen3.5/27b-qlora.yaml
# MoE 35B-A3B text-only (QLoRA)
axolotl train examples/qwen3.5/35b-a3b-moe-qlora.yaml
# MoE 122B-A10B text-only (QLoRA)
axolotl train examples/qwen3.5/122b-a10b-moe-qlora.yaml
# 7B vision+text (LoRA, multimodal dataset)
axolotl train examples/qwen3.5/7b-lora-vision.yaml
```
### TIPS
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `7b-lora-vision.yaml` and in the sketch after this list.
- The Gated DeltaNet linear attention layers (`linear_attn.*`) can optionally be added to `lora_target_modules` — they are commented out by default.
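A minimal sketch of those multimodal settings plus the optional linear attention targets, mirroring `7b-lora-vision.yaml` (a starting point, not a complete config):
```yaml
processor_type: AutoProcessor
# These 3 lines are required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  # Optional: also target the Gated DeltaNet linear attention projections
  # - linear_attn.in_proj_qkv
  # - linear_attn.in_proj_z
  # - linear_attn.out_proj
```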
## Optimization Guides
- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html)
## Related Resources
- [Qwen3.5 Blog](https://qwenlm.github.io/blog/qwen3.5/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -8,15 +8,13 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
2. Run the finetuning example:
```bash
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
```
This config uses about 24.9 GiB VRAM (w/o CCE).
This config uses about 24.9 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
@@ -31,6 +29,10 @@ Let us know how it goes. Happy finetuning! 🚀
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
## Related Resources
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)

View File

@@ -1,4 +1,5 @@
base_model: arcee-ai/Trinity-Nano-Preview
trust_remote_code: true
revision_of_model: 2ee94b0
# Automatically upload checkpoint and final model to HF

View File

@@ -12,16 +12,13 @@ packaging==26.0
huggingface_hub>=1.1.7
peft>=0.18.1
tokenizers>=0.22.1
transformers==5.3.0
accelerate==1.13.0
transformers==5.2.0
accelerate==1.12.0
datasets==4.5.0
deepspeed>=0.18.6,<0.19.0
trl==0.29.0
hf_xet==1.3.2
kernels==0.12.2
fla-core==0.4.1
flash-linear-attention==0.4.1
deepspeed>=0.18.3
trl==0.28.0
hf_xet==1.2.0
kernels==0.12.1
trackio>=0.16.1
typing-extensions>=4.15.0

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"'
)

View File

@@ -27,16 +27,9 @@ def parse_requirements(extras_require_map):
xformers_version = [req for req in _install_requires if "xformers" in req][0]
install_xformers = platform.machine() != "aarch64"
if platform.machine() == "aarch64":
# skip on ARM64
skip_packages = [
"torchao",
"fla-core",
"flash-linear-attention",
]
# skip torchao on ARM64
_install_requires = [
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
req for req in _install_requires if "torchao" not in req
]
if "Darwin" in platform.system():
# skip packages not compatible with OSX

View File

@@ -6,6 +6,5 @@ from axolotl.logging_config import configure_logging
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")
os.environ.setdefault("TRL_EXPERIMENTAL_SILENCE", "1")
configure_logging()

View File

@@ -71,7 +71,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
merge_lora=True,
load_in_8bit=False,
load_in_4bit=False,
quantize_moe_experts=False,
flash_attention=False,
context_parallel_size=None,
deepspeed=None,

View File

@@ -12,14 +12,10 @@ MOE_ARCH_BLOCK = {
"mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"deepseek_v3": "DeepseekV3MoE",
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
"afmoe": "AfmoeMoE",
"glm4_moe": "Glm4MoeDecoderLayer",
"glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
"glm_moe_dsa": "GlmMoeDsaDecoderLayer",
}

View File

@@ -120,6 +120,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
else:
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl is RLType.SIMPO:

View File

@@ -26,7 +26,7 @@ from transformers import PreTrainedModel, Trainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
from trl.experimental.utils import pad_to_length
from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.mixins import (
@@ -720,16 +720,12 @@ class AxolotlTrainer(
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")
# fix for Context Parallel save: CP eval invalidates tensor storage
# pointers, so clone to CPU to get fresh valid storage for safetensors
if (
state_dict is not None
and self.axolotl_cfg
and self.axolotl_cfg.context_parallel_size
and self.axolotl_cfg.context_parallel_size > 1
):
# fix for Context Parallel save
if state_dict is None:
state_dict = self.accelerator.get_state_dict(self.model)
if state_dict is not None:
state_dict = {
k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in state_dict.items()
}
@@ -765,11 +761,7 @@ class AxolotlTrainer(
metadata={"format": "pt"},
)
else:
self.model.save_pretrained(
output_dir,
state_dict=state_dict,
is_main_process=self.accelerator.is_main_process,
)
self.model.save_pretrained(output_dir, state_dict=state_dict)
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)

View File

@@ -25,13 +25,17 @@ class DPOStrategy:
# Label smoothing is not compatible with IPO
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
if cfg.dpo_padding_free is not None:
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
if cfg.dpo_norm_loss is not None:
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
if cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
if cfg.dpo_use_liger_kernel is not None:
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
return training_args_kwargs

View File

@@ -103,10 +103,10 @@ class AxolotlDPOTrainer(
) -> dict[str, torch.Tensor]:
if self.args.dpo_norm_loss:
# fmt: off
loss_type: list[str] = self.loss_type # type: ignore[has-type]
loss_type: str = self.loss_type # type: ignore[has-type]
# fmt: on
# concatenated_forward handles avg token logprob for ipo case already
self.loss_type = ["ipo"]
self.loss_type = "ipo"
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
self.loss_type = loss_type
return res

View File

@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
return optimizer_grouped_parameters
def create_optimizer(self, model=None):
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
):
return super().create_optimizer(model=model)
return super().create_optimizer()
opt_model = self.model if model is None else model
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if (
not self.optimizer

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"
```
## Usage
@@ -88,9 +88,9 @@ plugins:
- qwen2_vl
- qwen3
- qwen3_5
- qwen3_5_text
- qwen3_5_moe
- qwen3_5_moe_text
- qwen3_5_moe_vl
- qwen3_5_vl
- qwen3_moe
- qwen3_next
- qwen3_vl

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e8ad129"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"`'
)

View File

@@ -10,7 +10,7 @@ class ExpertsInterface(GeneralInterface):
}
```
In our custom integration, we add support for **ScatterMoE** and **SonicMoE**, which are more efficient and faster than `grouped_mm`.
In our custom integration, we add support for **ScatterMoE**, which is even more efficient and faster than `grouped_mm`.
## Usage
@@ -21,57 +21,23 @@ plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
# Choose one (mutually exclusive):
use_scattermoe: true
# OR
use_sonicmoe: true
```
**Important:** Setting `experts_implementation` is incompatible with custom kernel options.
### SonicMoE installation
**Prerequisites:**
- NVIDIA Hopper (H100, H200) or Blackwell (B200, GB200) GPU
- CUDA 12.9+ (13.0+ for B300)
- PyTorch 2.7+ (2.9.1 recommended)
- For B300: Triton 3.6.0
```bash
pip install --ignore-requires-python --no-deps "sonic-moe @ git+https://github.com/Dao-AILab/sonic-moe.git@116e2df0a41874f77fa0ad269ce7df3f0cfcb956" && pip install nvidia-cutlass-dsl==4.4.0 quack-kernels==0.2.5
```
See the [SonicMoE installation guide](https://github.com/Dao-AILab/sonic-moe?tab=readme-ov-file#-installation) for the latest prerequisite details.
**Note:** Blackwell support is in upstream beta. On Blackwell GPUs, Axolotl automatically sets `USE_QUACK_GEMM=1` to enable the Blackwell kernels.
**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
## How It Works
The `KernelsPlugin` runs before model loading and:
### ScatterMoE
1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels).
1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
### SonicMoE
1. Resolves the model's MoE block class(es) from `constants.py`.
2. Patches the forward method with SonicMoE's optimized kernels and registers a weight converter for the interleaved gate/up projection format.
3. Supports both softmax->topk and sigmoid->topk routing strategies.
Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution.
#### Supported Models
See `constants.py` for the full list of supported model types (Qwen2-MoE, Qwen3-MoE, OLMoE, Mixtral, DeepSeek-V3, GLM-MoE, MiniMax, etc.).
This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
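As a rough sketch of what this kernelization step amounts to (class and kernel names are taken from the plugin code shown further below; the exact call pattern here is illustrative, not the plugin itself):
```python
# Illustrative sketch only -- the KernelsPlugin performs this automatically.
from kernels import replace_kernel_forward_from_hub
from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock

# Swap the block's forward for the "HFScatterMoEParallelExperts" layer that the
# plugin registers from the axolotl-ai-co/scattermoe Hub repo.
replace_kernel_forward_from_hub(OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts")
```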
## Limitations
ScatterMoE uses softmax -> topk routing, so results may differ from the baseline for some model architectures (GPT-OSS, etc.). It is currently incompatible with `GLM_MOE_DSA` (GLM 5) and `GLM4_MOE_LITE` (GLM 4.7 Flash).
SonicMoE supports both softmax->topk and sigmoid->topk routing, covering a wider range of architectures.
ScatterMoE does not currently work with GLM 4.7 Flash (glm4_moe_lite).
ScatterMoE uses softmax -> topk routing, so results may differ from the baseline for some model architectures (GPT-OSS, GLM_MOE_DSA).
## Note on MegaBlocks

View File

@@ -6,18 +6,7 @@ LOG = get_logger(__name__)
class KernelsArgs(BaseModel):
use_scattermoe: bool | None = None
use_sonicmoe: bool | None = None
@model_validator(mode="before")
@classmethod
def check_mutually_exclusive(cls, data):
if data.get("use_scattermoe") and data.get("use_sonicmoe"):
raise ValueError(
"Cannot use both ScatterMoE and SonicMoE simultaneously. "
"Please set only one of `use_scattermoe` or `use_sonicmoe` to true."
)
return data
use_scattermoe: bool | None = True
@model_validator(mode="before")
@classmethod
@@ -47,11 +36,11 @@ class KernelsArgs(BaseModel):
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel(cls, data):
if data.get("use_scattermoe") is True or data.get("use_sonicmoe") is True:
def disable_mlp_kernel_scattermoe(cls, data):
if data.get("use_scattermoe") is True:
if data.get("lora_mlp_kernel") is True:
LOG.warning(
"Disabling lora_mlp_kernel when using custom MoE kernels due to compatibility issues."
"Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
)
data["lora_mlp_kernel"] = False
data["mlp_kernel"] = False

View File

@@ -1,68 +0,0 @@
"""
Supported MoE block mappings for kernel integrations.
Maps model_type to the SparseMoeBlock class name(s) in transformers.
Used by both ScatterMoE and SonicMoE kernel paths.
Values can be a single class name (str) or a list of class names for models
with multiple MoE block types (e.g. qwen3_omni_moe has Thinker + Talker).
"""
import importlib
SPARSE_MOE_BLOCK = {
# softmax -> topk routing
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
"qwen3_next": "Qwen3NextSparseMoeBlock",
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
# qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
"qwen3_omni_moe": [
"Qwen3OmniMoeThinkerTextSparseMoeBlock",
"Qwen3OmniMoeTalkerTextSparseMoeBlock",
],
"olmoe": "OlmoeSparseMoeBlock",
"mixtral": "MixtralSparseMoeBlock",
"minimax": "MiniMaxSparseMoeBlock",
# sigmoid -> topk routing (with group-based expert selection)
"glm_moe_dsa": "GlmMoeDsaMoE",
"deepseek_v3": "DeepseekV3MoE",
"glm4_moe": "Glm4MoeMoE",
"glm4_moe_lite": "Glm4MoeLiteMoE",
"glm4v_moe": "Glm4vMoeTextMoE",
# sigmoid -> topk routing (no group selection)
"minimax_m2": "MiniMaxM2SparseMoeBlock",
# Models below need custom routing (not yet implemented):
# "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk
# "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group)
# "hunyuan_v1_moe": "HunYuanMoEV1Moe", # softmax->topk, gate.wg (not gate.weight), scatter routing
# "gpt_oss": "GptOssMLP", # topk->softmax, transposed layout [E,H,2*I], custom GLU, expert biases
}
def resolve_moe_block_classes(model_type: str):
"""Resolve all MoE block classes from transformers for the given model type.
Returns a list of classes (one for most models, multiple for models with
distinct MoE block types like qwen3_omni_moe).
"""
entry = SPARSE_MOE_BLOCK.get(model_type)
if entry is None:
raise ValueError(
f"Unsupported MoE model type '{model_type}'. "
f"Supported types: {list(SPARSE_MOE_BLOCK.keys())}"
)
cls_names = entry if isinstance(entry, list) else [entry]
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = importlib.import_module(module_path)
classes = []
for cls_name in cls_names:
moe_cls = getattr(module, cls_name, None)
if moe_cls is None:
raise ValueError(f"Could not find class '{cls_name}' in '{module_path}'")
classes.append(moe_cls)
return classes
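# Doctest-style illustration (requires transformers with qwen3_moe available;
# illustrative only, not part of the original module):
#
#   >>> [cls.__name__ for cls in resolve_moe_block_classes("qwen3_moe")]
#   ['Qwen3MoeSparseMoeBlock']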

View File

@@ -1,59 +1,14 @@
import importlib
import os
from pathlib import Path
import torch
from kernels import (
LocalLayerRepository,
Mode,
register_kernel_mapping,
replace_kernel_forward_from_hub,
)
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _check_sonicmoe_gpu_compat():
"""Validate GPU compute capability for SonicMoE and configure env.
Supported: Hopper (sm_90), Blackwell (sm_100 - sm_103).
B300 (sm_103) additionally requires Triton 3.6.0.
"""
if not torch.cuda.is_available():
return
cc = torch.cuda.get_device_capability()
if cc < (9, 0):
raise RuntimeError(
f"SonicMoE requires Hopper (sm_90) or Blackwell (sm_100+) GPU, "
f"but detected sm_{cc[0]}{cc[1]}."
)
if cc > (10, 3):
raise RuntimeError(
f"SonicMoE does not yet support sm_{cc[0]}{cc[1]}. "
f"Supported: Hopper (sm_90) and Blackwell (sm_100 - sm_103)."
)
# Blackwell (sm_100+): enable QuACK GEMM kernels
if cc >= (10, 0):
os.environ.setdefault("USE_QUACK_GEMM", "1")
LOG.info(
f"Blackwell GPU (sm_{cc[0]}{cc[1]}) detected, enabling USE_QUACK_GEMM=1"
)
# B300 (sm_103): requires Triton 3.6.0
if cc == (10, 3):
triton_spec = importlib.util.find_spec("triton")
if triton_spec is None:
raise RuntimeError(
"B300 (sm_103) requires Triton 3.6.0, but Triton is not installed."
)
import triton
triton_version = tuple(int(x) for x in triton.__version__.split(".")[:2])
if triton_version != (3, 6):
raise RuntimeError(
f"B300 (sm_103) requires Triton 3.6.x, but found {triton.__version__}."
)
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
class KernelsPlugin(BasePlugin):
@@ -64,32 +19,8 @@ class KernelsPlugin(BasePlugin):
if cfg.use_scattermoe:
self._register_kernels()
self._kernelize_model(cfg.model_config_type)
elif cfg.use_sonicmoe:
if not importlib.util.find_spec("sonicmoe"):
raise RuntimeError(
"SonicMoE is not installed. See installation instructions at "
"https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/kernels/README.md#sonicmoe-installation"
)
_check_sonicmoe_gpu_compat()
from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
LOG.info(
f"Applying SonicMoE patches for model type: {cfg.model_config_type}"
)
patch_sonicmoe(
cfg.model_config_type,
torch_compile=bool(getattr(cfg, "torch_compile", False)),
)
def _register_kernels(self):
from kernels import (
LocalLayerRepository,
Mode,
register_kernel_mapping,
)
plugin_root = Path(__file__).parent
register_kernel_mapping(
{
@@ -111,11 +42,25 @@ class KernelsPlugin(BasePlugin):
)
def _kernelize_model(self, model_type: str):
from kernels import replace_kernel_forward_from_hub
if model_type == "olmoe":
from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
for model_moe_cls in resolve_moe_block_classes(model_type):
replace_kernel_forward_from_hub(
model_moe_cls, "HFScatterMoEParallelExperts"
OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts"
)
else:
try:
model_moe_cls = get_model_moe_block(model_type)
replace_kernel_forward_from_hub(
model_moe_cls, "HFScatterMoEParallelExperts"
)
except Exception as err:
raise ValueError(f"Unsupported model type: {model_type}") from err
def get_model_moe_block(model_type: str):
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}SparseMoeBlock"])
model_cls = getattr(module, f"{model_cls_prefix}SparseMoeBlock")
return model_cls

View File

@@ -1,3 +0,0 @@
from .patch import patch_sonicmoe
__all__ = ["patch_sonicmoe"]

View File

@@ -1,213 +0,0 @@
"""
SonicMoE patching for SparseMoeBlock forward pass.
Monkeypatches the SparseMoeBlock class for a given model type to use
SonicMoE's optimized kernels. Two forward paths are supported:
1. **General routing path** (routing_fn is not None):
Uses a custom routing function + ``moe_general_routing_inputs``.
Suitable for models with non-standard routing (softmax->topk, sigmoid->topk).
2. **Fused topk->softmax path** (routing_fn is None):
Uses ``moe_TC_softmax_topk_layer`` which fuses routing + expert computation.
Suitable for models with simple topk->softmax routing.
Weight format conversion (interleave/deinterleave) is handled by the
WeightConverter system, so the forward assumes weights are already in
interleaved format.
Shared experts are handled generically: if the block has a ``shared_expert``
or ``shared_experts`` attribute, its output is computed alongside the routed
experts and added to the final output. An optional ``shared_expert_gate``
applies sigmoid gating to the shared expert contribution.
"""
import torch
import torch.nn.functional as F
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def patch_sonicmoe(model_type: str, torch_compile: bool = False):
"""Main entry point: patch SparseMoeBlock for SonicMoE support.
Args:
model_type: The HuggingFace model type (e.g. "qwen3_moe").
torch_compile: If True, wrap routing functions with torch.compile
for kernel fusion (fuses softmax+topk+renorm into fewer launches).
"""
from .routing import get_model_moe_config
from .weight_converter import register_sonicmoe_weight_converter
routing_fn, activation, router_attr = get_model_moe_config(model_type)
if torch_compile and routing_fn is not None:
routing_fn = _try_compile_routing(routing_fn)
for moe_cls in resolve_moe_block_classes(model_type):
_patch_forward(moe_cls, routing_fn, activation, router_attr)
register_sonicmoe_weight_converter(model_type)
def _try_compile_routing(routing_fn):
"""Attempt to torch.compile the routing function, fall back to eager on failure."""
try:
compiled_fn = torch.compile(routing_fn, mode="reduce-overhead", dynamic=False)
LOG.info(f"torch.compile enabled for routing function: {routing_fn.__name__}")
return compiled_fn
except Exception as exc: # pylint: disable=broad-except
LOG.warning(
f"torch.compile failed for routing function {routing_fn.__name__}, "
f"falling back to eager: {exc}"
)
return routing_fn
def _patch_forward(moe_cls, routing_fn, activation, router_attr):
"""Monkeypatch the SparseMoeBlock class with a SonicMoE forward.
The patched forward handles shared experts generically: if
``self.shared_expert`` or ``self.shared_experts`` exists, it is computed
and added to the routed output. If ``self.shared_expert_gate`` also exists,
it applies sigmoid gating to the shared expert contribution (as in qwen2_moe).
Args:
moe_cls: The SparseMoeBlock class to patch.
routing_fn: Routing function (e.g. softmax_topk_routing), or None
for the fused moe_TC_softmax_topk_layer path.
activation: SonicMoE ActivationType enum value.
router_attr: Name of the router module attribute on the MoE block.
"""
if hasattr(moe_cls, "_original_forward"):
LOG.info(f"{moe_cls.__name__}.forward already patched with SonicMoE, skipping")
return
original_forward = moe_cls.forward
if routing_fn is not None:
_make_general_forward(moe_cls, routing_fn, activation)
else:
_make_fused_forward(moe_cls, activation, router_attr)
moe_cls._original_forward = original_forward
LOG.info(f"Patched {moe_cls.__name__}.forward with SonicMoE implementation")
def _make_general_forward(moe_cls, routing_fn, activation):
"""Create forward using routing_fn + moe_general_routing_inputs."""
def sonicmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
from sonicmoe import moe_general_routing_inputs
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_flat = hidden_states.view(-1, hidden_dim)
# Shared expert (computed early, matching original model ordering)
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
# Routing
router_scores, token_indices, expert_indices, _router_logits = routing_fn(
hidden_states_flat, self
)
# Permute weights to SonicMoE layout:
# gate_up: [E, 2*I, H] -> [2*I, H, E]
# down: [E, H, I] -> [H, I, E]
gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
down_weight = self.experts.down_proj.permute(1, 2, 0)
E = gate_up_weight.shape[-1]
output, _ = moe_general_routing_inputs(
hidden_states_flat,
router_scores,
token_indices,
expert_indices,
gate_up_weight,
None, # b1 (no gate/up bias)
down_weight,
None, # b2 (no down bias)
E,
torch.cuda.current_stream().cuda_stream,
activation,
False, # is_inference_mode
)
# Add shared expert contribution if present
if shared_expert_output is not None:
if hasattr(self, "shared_expert_gate"):
shared_expert_output = (
F.sigmoid(self.shared_expert_gate(hidden_states_flat))
* shared_expert_output
)
output = output + shared_expert_output
return output.view(batch_size, sequence_length, hidden_dim)
moe_cls.forward = sonicmoe_forward
def _make_fused_forward(moe_cls, activation, router_attr):
"""Create forward using moe_TC_softmax_topk_layer (topk -> softmax)."""
def sonicmoe_fused_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
from sonicmoe import moe_TC_softmax_topk_layer
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_flat = hidden_states.view(-1, hidden_dim)
# Shared expert (computed early, matching original model ordering)
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
router = getattr(self, router_attr)
# Permute weights to SonicMoE layout:
# gate_up: [E, 2*I, H] -> [2*I, H, E]
# down: [E, H, I] -> [H, I, E]
gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
down_weight = self.experts.down_proj.permute(1, 2, 0)
output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer(
hidden_states_flat,
router.weight,
gate_up_weight,
None, # b1 (no gate/up bias)
down_weight,
None, # b2 (no down bias)
router.top_k,
torch.cuda.current_stream().cuda_stream,
activation,
False, # is_inference_mode
)
# Add shared expert contribution if present
if shared_expert_output is not None:
if hasattr(self, "shared_expert_gate"):
shared_expert_output = (
F.sigmoid(self.shared_expert_gate(hidden_states_flat))
* shared_expert_output
)
output = output + shared_expert_output
return output.view(batch_size, sequence_length, hidden_dim)
moe_cls.forward = sonicmoe_fused_forward
def _compute_shared_expert(moe_block, hidden_states_flat):
"""Compute shared expert output if the block has one.
Handles singular (qwen2_moe: ``shared_expert``), plural
(glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP
(hunyuan_v1_moe: ``shared_mlp``) attribute names.
"""
shared_expert = (
getattr(moe_block, "shared_expert", None)
or getattr(moe_block, "shared_experts", None)
or getattr(moe_block, "shared_mlp", None)
)
if shared_expert is not None:
return shared_expert(hidden_states_flat)
return None

View File

@@ -1,219 +0,0 @@
"""
Routing functions for SonicMoE integration.
Different MoE architectures use different routing strategies:
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
"""
import torch
import torch.nn.functional as F
def get_model_moe_config(model_type: str):
"""Returns (routing_fn, activation, router_attr) for a given model type.
Args:
model_type: HuggingFace model type string.
Returns:
routing_fn: Callable or None. None signals the fused
moe_TC_softmax_topk_layer path (topk -> softmax models).
activation: SonicMoE ActivationType enum value.
router_attr: Name of the router module attribute on the MoE block
(e.g. "gate" or "router").
The activation type cannot be derived from config.hidden_act because
e.g. qwen3_moe reports "silu" but architecturally uses SwiGLU
(act_fn(gate) * up pattern). So we specify it per model type.
"""
from sonicmoe.enums import ActivationType
if model_type in (
"qwen2_moe",
"qwen3_moe",
"qwen3_5_moe",
"qwen3_next",
"qwen3_vl_moe",
"qwen3_omni_moe",
"olmoe",
"mixtral",
"minimax",
):
return softmax_topk_routing, ActivationType.SWIGLU, "gate"
elif model_type in (
"glm_moe_dsa",
"deepseek_v3",
"glm4_moe",
"glm4_moe_lite",
"glm4v_moe",
"minimax_m2",
):
return sigmoid_topk_routing, ActivationType.SWIGLU, "gate"
# elif model_type in ("ernie4_5_moe",):
# # Softmax→topk with e_score_correction_bias applied between softmax and topk.
# return ..., ActivationType.SWIGLU, "gate"
# elif model_type in ("deepseek_v2",):
# # Softmax→topk with group_limited_greedy. Different attr names: num_group
# # (not n_group), gate is nn.Linear (not a router class).
# return ..., ActivationType.SWIGLU, "gate"
# elif model_type in ("hunyuan_v1_moe",):
# # Softmax→topk but gate structure differs: gate.wg (not gate.weight),
# # top_k on block not gate, creates scatter routing matrix.
# return ..., ActivationType.SWIGLU, "gate"
# Fused topk -> softmax path (routing_fn=None):
# elif model_type in ("gpt_oss",):
# # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer
# # ignores (it only takes router_w, not bias). Also has transposed
# # weight layout [E, H, 2*I] and custom GLU activation.
# return None, ActivationType.SWIGLU, "router"
else:
raise ValueError(f"SonicMoE: unsupported model type '{model_type}'")
def softmax_topk_routing(
hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Qwen3/Qwen2-style routing: softmax -> topk -> optional renorm.
Args:
hidden_states: [T, H] flattened token representations
moe_block: MoE block module (accesses moe_block.gate.*)
Returns:
router_scores: [T*K] flattened scores (float32)
token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
expert_indices: [T*K] which expert (int32)
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
T, H = hidden_states.shape
K = gate.top_k
# Compute router logits and softmax over all experts
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
# Select top-k experts per token
top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each
# Renormalize if configured (default True for models without the attribute,
# e.g. Mixtral/MiniMax which always normalize)
if getattr(gate, "norm_topk_prob", True):
top_values = top_values / top_values.sum(dim=-1, keepdim=True)
# no-op: matches transformers which casts to softmax output dtype (float32).
# top_values = top_values.to(router_probs.dtype)
# Flatten for moe_general_routing_inputs.
# Token indices are naturally sorted ascending from the [T, K] layout:
# [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE.
# Expert sorting is handled internally by general_routing_router_metadata.
token_indices = (
torch.arange(T, device=hidden_states.device, dtype=torch.int32)
.unsqueeze(1)
.expand(T, K)
)
flat_scores = top_values.reshape(-1) # [T*K]
flat_token_idx = token_indices.reshape(-1) # [T*K]
flat_expert_idx = top_indices.to(torch.int32).reshape(-1) # [T*K]
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
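# Doctest-style sketch (toy sizes, random gate weights; illustrative only, not
# part of the original module): shapes and the token-major ordering of
# token_indices are what moe_general_routing_inputs expects.
#
#   >>> from types import SimpleNamespace
#   >>> gate = SimpleNamespace(weight=torch.randn(4, 8), top_k=2, norm_topk_prob=True)
#   >>> block = SimpleNamespace(gate=gate)
#   >>> scores, tok_idx, exp_idx, logits = softmax_topk_routing(torch.randn(3, 8), block)
#   >>> tok_idx.tolist()          # token-major, ascending -- required by SonicMoE
#   [0, 0, 1, 1, 2, 2]
#   >>> scores.shape, logits.shape
#   (torch.Size([6]), torch.Size([3, 4]))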
def sigmoid_topk_routing(
hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Sigmoid-based routing: sigmoid -> optional group selection -> topk.
Supports two variants:
- **Group selection** (glm_moe_dsa, deepseek_v3, etc.): n_group > 1,
bias on gate, group-based masking before topk.
- **No group selection** (minimax_m2): n_group == 1 (or absent),
bias on moe_block, straight topk from all experts.
Final routing weights come from the original sigmoid scores (not
bias-corrected), with optional renormalization and scaling.
Args:
hidden_states: [T, H] flattened token representations
moe_block: MoE block module (accesses moe_block.gate.* and
optional moe_block.n_group, .topk_group, .top_k, .norm_topk_prob,
.routed_scaling_factor, .n_routed_experts)
Returns:
router_scores: [T*K] flattened scores (float32)
token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
expert_indices: [T*K] which expert (int32)
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
T, H = hidden_states.shape
K = moe_block.top_k
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
n_group = getattr(moe_block, "n_group", 1)
# Compute router logits and sigmoid probabilities
router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E]
router_probs = router_logits.sigmoid() # [T, E]
# Bias-corrected scores for expert selection (not used for final weights).
# glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 stores it on the block.
e_score_correction_bias = getattr(gate, "e_score_correction_bias", None)
if e_score_correction_bias is None:
e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
if e_score_correction_bias is None:
raise AttributeError(
f"sigmoid_topk_routing requires e_score_correction_bias on "
f"gate ({type(gate)}) or moe_block ({type(moe_block)}), but neither has it"
)
scores_for_choice = router_probs + e_score_correction_bias
# Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
if n_group > 1:
group_scores = (
scores_for_choice.view(-1, n_group, E // n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
) # [T, n_group]
group_idx = torch.topk(
group_scores, k=moe_block.topk_group, dim=-1, sorted=False
)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
# Final topk from (possibly masked) scores
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
# Gather weights from original sigmoid scores (not bias-corrected)
topk_weights = router_probs.gather(1, topk_indices)
# Optional renormalization + scaling
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
if norm_topk_prob:
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
topk_weights = topk_weights * routed_scaling_factor
# Flatten for moe_general_routing_inputs.
# Token indices are naturally sorted ascending from the [T, K] layout.
token_indices = (
torch.arange(T, device=hidden_states.device, dtype=torch.int32)
.unsqueeze(1)
.expand(T, K)
)
flat_scores = topk_weights.to(torch.float32).reshape(-1) # [T*K]
flat_token_idx = token_indices.reshape(-1) # [T*K]
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
return flat_scores, flat_token_idx, flat_expert_idx, router_logits

View File

@@ -1,181 +0,0 @@
"""
Custom WeightConverter operations for SonicMoE weight format conversion.
SonicMoE requires gate_up_proj weights in interleaved format:
- Standard (concatenated): [E, 2*I, H] where first I rows are gate, last I rows are up
- SonicMoE (interleaved): [E, 2*I, H] where rows alternate [g0, u0, g1, u1, ...]
These ConversionOps integrate with transformers' WeightConverter system so that
weights are transparently converted during loading and reverted during saving.
"""
from typing import Any
import torch
from einops import rearrange
from transformers.core_model_loading import ConversionOps
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def interleave_gate_up(tensor: torch.Tensor) -> torch.Tensor:
"""[gate..., up...] -> [g0, u0, g1, u1, ...] along the 2*I dimension."""
return rearrange(tensor, "... (two out) h -> ... (out two) h", two=2)
def deinterleave_gate_up(tensor: torch.Tensor) -> torch.Tensor:
"""[g0, u0, g1, u1, ...] -> [gate..., up...] along the 2*I dimension."""
return rearrange(tensor, "... (out two) h -> ... (two out) h", two=2)
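# Worked example (illustrative only, not part of the original module): with E=1
# expert, I=2 intermediate rows and H=1, the concatenated row order
# [g0, g1, u0, u1] becomes [g0, u0, g1, u1] after interleave_gate_up, and
# deinterleave_gate_up restores the original layout.
#
#   >>> t = torch.arange(4).reshape(1, 4, 1)              # rows: g0, g1, u0, u1
#   >>> interleave_gate_up(t).flatten().tolist()
#   [0, 2, 1, 3]
#   >>> torch.equal(deinterleave_gate_up(interleave_gate_up(t)), t)
#   True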
class ConcatenatedToInterleaved(ConversionOps):
"""Convert concatenated gate/up projections to interleaved format.
Input: [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H]
Output: [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...]
This operation is applied along ``dim`` (default 1, the 2*I dimension).
"""
def __init__(self, dim: int = 1):
self.dim = dim
@torch.no_grad()
def convert(
self,
input_dict: dict[str, Any],
source_patterns: list[str],
target_patterns: list[str],
**kwargs,
) -> dict[str, torch.Tensor]:
target_pattern = self._get_target_pattern(
input_dict, source_patterns, target_patterns
)
tensors = next(iter(input_dict.values()))
tensor = tensors[0] if isinstance(tensors, list) else tensors
interleaved = interleave_gate_up(tensor)
return {target_pattern: interleaved}
def _get_target_pattern(
self,
input_dict: dict[str, Any],
source_patterns: list[str],
target_patterns: list[str],
) -> str:
# Follow the same logic as Transpose.get_target_pattern
if len(input_dict) != 1:
raise ValueError("Undefined Operation encountered!")
if len(target_patterns) > 1:
if len(source_patterns) == 1:
return source_patterns[0]
raise ValueError("Undefined Operation encountered!")
return target_patterns[0]
@property
def reverse_op(self) -> ConversionOps:
return InterleavedToConcatenated(self.dim)
class InterleavedToConcatenated(ConversionOps):
"""Convert interleaved gate/up projections back to concatenated format.
Input: [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...]
Output: [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H]
This is the reverse of ``ConcatenatedToInterleaved``.
"""
def __init__(self, dim: int = 1):
self.dim = dim
@torch.no_grad()
def convert(
self,
input_dict: dict[str, Any],
source_patterns: list[str],
target_patterns: list[str],
**kwargs,
) -> dict[str, torch.Tensor]:
target_pattern = self._get_target_pattern(
input_dict, source_patterns, target_patterns
)
tensors = next(iter(input_dict.values()))
tensor = tensors[0] if isinstance(tensors, list) else tensors
concatenated = deinterleave_gate_up(tensor)
return {target_pattern: concatenated}
def _get_target_pattern(
self,
input_dict: dict[str, Any],
source_patterns: list[str],
target_patterns: list[str],
) -> str:
if len(input_dict) != 1:
raise ValueError("Undefined Operation encountered!")
if len(target_patterns) > 1:
if len(source_patterns) == 1:
return source_patterns[0]
raise ValueError("Undefined Operation encountered!")
return target_patterns[0]
@property
def reverse_op(self) -> ConversionOps:
return ConcatenatedToInterleaved(self.dim)
def register_sonicmoe_weight_converter(model_type: str):
"""Override the conversion mapping to add interleave step for gate_up_proj.
Appends a ConcatenatedToInterleaved operation to the existing gate_up_proj
converter chain. For example, qwen3_moe's chain becomes:
MergeModulelist(dim=0) -> Concatenate(dim=1) -> ConcatenatedToInterleaved(dim=1)
The reverse is auto-generated for saving:
InterleavedToConcatenated(dim=1) -> Chunk(dim=1) -> SplitModulelist(dim=0)
"""
from transformers.conversion_mapping import (
get_checkpoint_conversion_mapping,
register_checkpoint_conversion_mapping,
)
existing = get_checkpoint_conversion_mapping(model_type)
if existing is None:
LOG.warning(
f"No conversion mapping found for model type '{model_type}'. "
"SonicMoE weight interleaving will not be applied during checkpoint loading."
)
return
# Find the gate_up_proj converter and append ConcatenatedToInterleaved
patched = False
for converter in existing:
if hasattr(converter, "operations") and any(
"gate_up_proj" in pat for pat in converter.target_patterns
):
# Guard against double registration (e.g. plugin reloaded)
if any(
isinstance(op, ConcatenatedToInterleaved) for op in converter.operations
):
LOG.info(
f"SonicMoE weight converter already registered for '{model_type}'"
)
return
converter.operations.append(ConcatenatedToInterleaved(dim=1))
patched = True
break
if not patched:
LOG.warning(
f"Could not find gate_up_proj converter for model type '{model_type}'. "
"SonicMoE weight interleaving will not be applied during checkpoint loading."
)
return
register_checkpoint_conversion_mapping(model_type, existing, overwrite=True)
LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'")

View File

@@ -8,6 +8,9 @@ import sys
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
from .models.base import patch_lce_forward
from .utils import patch_with_compile_disable
LOG = get_logger(__name__)
@@ -20,18 +23,10 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
# shim: liger-kernel 0.7.0 imports ORPOTrainer from old trl path
import trl.trainer
from trl.experimental.orpo import ORPOTrainer
trl.trainer.ORPOTrainer = ORPOTrainer
if cfg.torch_compile:
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
import liger_kernel.ops.fused_linear_cross_entropy
from .utils import patch_with_compile_disable
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_forward",
@@ -40,7 +35,6 @@ class LigerPlugin(BasePlugin):
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_backward",
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
@@ -198,8 +192,6 @@ class LigerPlugin(BasePlugin):
)
elif cfg.liger_fused_linear_cross_entropy:
try:
from .models.base import patch_lce_forward
patch_lce_forward(cfg.model_config_type)
LOG.warning_once(
f"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}"

View File

@@ -34,7 +34,7 @@ def setup_quantized_meta_for_peft(model: torch.nn.Module):
return self
for param in model.parameters():
if isinstance(param, Params4bit) and param.quant_state is not None:
if isinstance(param, Params4bit):
param.quant_state._orig_to = param.quant_state.to
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)

View File

@@ -172,10 +172,7 @@ class ModelLoader:
# Build the model
PLUGIN_MANAGER.pre_model_load(self.cfg)
self.patch_manager.apply_post_plugin_pre_model_load_patches()
skip_move_to_device = self._build_model()
self.patch_manager.apply_post_model_build_patches(self.model)
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
# Post-build model configuration
@@ -674,8 +671,8 @@ class ModelLoader:
del self.model_kwargs["device_map"]
transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True
transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = lambda: (
True
transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = (
lambda: True
)
return hf_ds_cfg
@@ -863,10 +860,6 @@ class ModelLoader:
# Make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True
if getattr(self.model, "_moe_experts_quantized", False):
# Parametrized expert tensors dequantize on access — would OOM.
skip_prepare_model_for_kbit_training = True
if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]

View File

@@ -118,7 +118,6 @@ class PatchManager:
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
self._apply_tiled_mlp(self.cfg.model_config_type)
self._apply_moe_expert_quantization_patch()
def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
@@ -136,10 +135,6 @@ class PatchManager:
patch_prepare_context_parallel_inputs()
def apply_post_model_build_patches(self, model: PreTrainedModel):
"""Apply patches right after model build, before post-load setup."""
self._finalize_moe_expert_quantization(model)
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model)
@@ -166,13 +161,6 @@ class PatchManager:
def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations."""
if self.cfg.fsdp_config:
from axolotl.monkeypatch.accelerate.fsdp2 import (
patch_initialize_missing_keys_for_fsdp,
)
patch_initialize_missing_keys_for_fsdp()
if self.cfg.context_parallel_size > 1 or (
self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
):
@@ -182,14 +170,9 @@ class PatchManager:
patch_parallelism_config()
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import (
patch_accelerate_fsdp2,
patch_tied_keys_for_meta_device,
)
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
patch_accelerate_fsdp2()
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
patch_tied_keys_for_meta_device()
if self.cfg.rl:
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
@@ -246,31 +229,6 @@ class PatchManager:
patch_qwen3_next_modeling_packing()
if self.cfg.model_config_type == "qwen3_5" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_modeling_packing,
)
patch_qwen3_5_modeling_packing()
if self.cfg.model_config_type == "qwen3_5_moe" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_moe_modeling_packing,
)
patch_qwen3_5_moe_modeling_packing()
if (
self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"]
and self.cfg.is_multimodal
and self.cfg.flash_attention
):
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_vlm_flash_attention,
)
patch_qwen3_5_vlm_flash_attention()
if self.cfg.model_config_type == "kimi_linear":
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_model,
@@ -394,54 +352,15 @@ class PatchManager:
if (
self.cfg.fsdp_config
and str(self.cfg.fsdp_version) == "2"
and (self.cfg.load_in_4bit or self.cfg.load_in_8bit)
and self.cfg.adapter == "qlora"
):
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_dtype_attrs_patch,
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
apply_linear8bitlt_save_patch,
)
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()
apply_init_dtype_attrs_patch()
if self.cfg.load_in_8bit:
apply_linear8bitlt_save_patch()
def _apply_moe_expert_quantization_patch(self):
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
if not self.cfg.quantize_moe_experts:
return
from axolotl.monkeypatch.moe_quant import (
patch_moe_quantization_on_load,
patch_peft_target_parameters_matching,
)
patch_moe_quantization_on_load(self.cfg)
patch_peft_target_parameters_matching()
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
"""Log quantization results and set model flag for downstream use."""
import torch
model._moe_experts_quantized = False
if self.cfg.quantize_moe_experts:
from axolotl.monkeypatch.moe_quant import get_moe_quantized_count
count = get_moe_quantized_count()
if count > 0:
import gc
model._moe_experts_quantized = True
LOG.info(
"Quantized %d MoE expert parameter(s) to %s during model loading",
count,
"4-bit" if self.cfg.load_in_4bit else "8-bit",
)
gc.collect()
torch.cuda.empty_cache()
def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp:

View File

@@ -201,7 +201,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # nosec B105
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Mistral's official FA implementation requires left padding

View File

@@ -111,7 +111,6 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
**kwargs,
):
if state_dict is None:
state_dict = self.state_dict()

View File

@@ -150,17 +150,13 @@ def get_state_dict(self, model, unwrap=True):
)
elif self.is_fsdp2:
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
from torch.distributed.tensor import DTensor
state_dict = {}
sharded_state_dict = model.state_dict()
for param_name, param in sharded_state_dict.items():
if param.is_cpu:
param = param.to(torch.device("cuda"))
if isinstance(param, DTensor):
param = param.full_tensor()
param = param.full_tensor()
if torch.distributed.get_rank() == 0:
state_dict[param_name] = param.cpu()
torch.distributed.barrier()
@@ -186,56 +182,10 @@ def get_state_dict(self, model, unwrap=True):
return state_dict
def patch_peft_param_wrapper_for_fsdp2():
"""Patch PEFT's _LoraParameterProxy.forward for FSDP2 DTensor compatibility.
PEFT's ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds
delta_weight to the base weight W inside _LoraParameterProxy.forward().
Under FSDP2, W may be a DTensor (from FSDP unshard) while delta_weight is a
regular Tensor (or vice versa), causing a RuntimeError on mixed types.
This patch promotes the non-DTensor operand to match the DTensor's spec
using DTensor.from_local(), which is free for Replicate placement (just
metadata wrapping, no communication).
"""
from peft.tuners.lora.layer import _LoraParameterProxy
if getattr(_LoraParameterProxy, "_axolotl_fsdp2_patched", False):
return
_original_forward = _LoraParameterProxy.forward
# NOTE: Replaces (not wraps) forward; assumes original is just `W + self.delta_weight`.
def _patched_forward(self, W):
from torch.distributed.tensor import DTensor
delta = self.delta_weight
w_is_dt = isinstance(W, DTensor)
d_is_dt = isinstance(delta, DTensor)
with torch.nn.utils.parametrize.cached():
if w_is_dt == d_is_dt:
return W + delta
if w_is_dt:
return W + DTensor.from_local(delta, W.device_mesh, W.placements)
return DTensor.from_local(W, delta.device_mesh, delta.placements) + delta
_LoraParameterProxy.forward = _patched_forward
_LoraParameterProxy._axolotl_fsdp2_patched = True
LOG.info("Patched PEFT _LoraParameterProxy.forward for FSDP2 DTensor compatibility")
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
"""Helper function to process LoRA modules for FSDP2."""
from peft.tuners.lora.layer import ParamWrapper
from torch.distributed.fsdp import fully_shard
# Skip ParamWrapper — its lora_A/B must not be independently sharded.
# The parent decoder layer's FSDP wrapper handles unsharding them.
# TODO: review if we even need to shard them separately in the first place.
if isinstance(module, ParamWrapper):
return False
log_bias_dtype_mismatch = False
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
@@ -252,20 +202,12 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
if module.lora_B:
fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
if module.lora_embedding_A:
fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
if module.lora_embedding_B:
fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
if module.lora_magnitude_vector:
fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
# lora_embedding_A/B are ParameterDicts containing nn.Parameter (Tensors),
# not nn.Module. fully_shard() only accepts nn.Module, so we cannot shard
# individual embedding Parameters. Instead, shard the entire LoraLayer module. fully_shard() can be used hierarchically because it does not
# override groups already assigned by fully_shard(), so modules
# where fully_shard() was already called are not affected [see https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html]
if module.lora_embedding_A or module.lora_embedding_B:
from torch.distributed.fsdp import FSDPModule
if not isinstance(module, FSDPModule):
fully_shard(module, **fsdp2_kwargs)
return log_bias_dtype_mismatch
@@ -385,14 +327,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
is_peft_model = isinstance(model, PeftModel)
# Patch PEFT's _LoraParameterProxy for DTensor compatibility if any
# ParamWrapper modules exist (used for target_parameters / 3D expert params).
if is_peft_model:
from peft.tuners.lora.layer import ParamWrapper
if any(isinstance(m, ParamWrapper) for m in model.modules()):
patch_peft_param_wrapper_for_fsdp2()
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
log_bias_dtype_mismatch = False
if auto_wrap_policy is not None:
@@ -442,83 +376,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
return model
def patch_tied_keys_for_meta_device():
"""Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors.
Meta tensors all share data_ptr()==0, causing every parameter to be incorrectly
grouped as "tied". Skipping them is safe since they have no real storage.
"""
from collections import defaultdict
from transformers import PreTrainedModel
def _patched_adjust_tied_keys_with_tied_pointers(self, missing_keys):
param_pointers = defaultdict(list)
for param_name, param_value in self.state_dict().items():
if param_value.is_meta:
continue
param_pointers[param_value.data_ptr()].append(param_name)
tied_param_names = [
names
for names in param_pointers.values()
if len(names) > 1
and not any(name in self.all_tied_weights_keys.keys() for name in names)
and not all(name in missing_keys for name in names)
]
tied_weights_keys_by_pointers = {
param_name: group[0]
for group in tied_param_names
for param_name in group[1:]
}
self.all_tied_weights_keys.update(tied_weights_keys_by_pointers)
PreTrainedModel._adjust_tied_keys_with_tied_pointers = (
_patched_adjust_tied_keys_with_tied_pointers
)
def patch_initialize_missing_keys_for_fsdp():
"""Patch _initialize_missing_keys to skip re-initialization on FSDP non-rank-0.
When using cpu_ram_efficient_loading, non-rank-0 processes load weights on
meta device and move them to CPU as empty tensors. Without this patch,
initialize_weights() re-initializes ALL parameters (via guarded init
functions), which is slow and uses extra RAM per process.
The fix marks all params/buffers with _is_hf_initialized=True before calling
the original method, so guarded init functions (init.normal_, init.zeros_,
etc.) become no-ops on non-rank-0 processes. The real weights arrive later
via FSDP broadcast from rank 0.
Upstream fix: https://github.com/huggingface/transformers/pull/44473
Remove this patch once transformers includes the fix in a stable release.
"""
from transformers import PreTrainedModel
from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0
if getattr(PreTrainedModel._initialize_missing_keys, "_axolotl_patched", False):
return
_original_initialize_missing_keys = PreTrainedModel._initialize_missing_keys
def _patched_initialize_missing_keys(self, is_quantized: bool) -> None:
if is_fsdp_enabled() and not is_local_dist_rank_0():
for key in self.state_dict():
try:
param_or_buffer = self.get_parameter_or_buffer(key)
param_or_buffer._is_hf_initialized = True
except AttributeError:
pass # may happen when handling pre-quantized weights
self._is_hf_initialized = True
_original_initialize_missing_keys(self, is_quantized)
PreTrainedModel._initialize_missing_keys = _patched_initialize_missing_keys
PreTrainedModel._initialize_missing_keys._axolotl_patched = True
def patch_accelerate_fsdp2():
import accelerate

View File

@@ -1,10 +1,9 @@
"""
Monkeypatch to add Params4bit and Int8Params support to FSDP2. This enables QLoRA + FSDP2
and 8-bit LoRA + FSDP2, as well as our LoRA / QLoRA Triton kernels to work with FSDP2.
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
our LoRA / QLoRA Triton kernels to work with FSDP2.
This patch modifies the _init_sharded_param and init_unsharded_param methods in FSDPParam
to handle bitsandbytes Params4bit and Int8Params parameters, preserving their quantization
metadata through the FSDP2 shard/unshard cycle.
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
Params4bit parameters.
"""
import importlib
@@ -18,8 +17,6 @@ LOG = get_logger(__name__)
def apply_init_sharded_param_patch():
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
if getattr(apply_init_sharded_param_patch, "_axolotl_patched", False):
return
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
@@ -44,20 +41,9 @@ def apply_init_sharded_param_patch():
bnb_quantized=param.bnb_quantized,
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
elif isinstance(param, bnb.nn.modules.Int8Params):
self.sharded_param = bnb.nn.modules.Int8Params(
data=sharded_param,
requires_grad=param.requires_grad,
has_fp16_weights=param.has_fp16_weights,
CB=None,
SCB=param.SCB,
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
else:
self.sharded_param = nn.Parameter(
self.to_sharded_dtensor(sharded_param),
requires_grad=param.requires_grad,
)"""
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)"""
# Apply the replacement
if original_param_creation in original_source:
@@ -87,7 +73,6 @@ def apply_init_sharded_param_patch():
# Replace the method
FSDPParam._init_sharded_param = patched_init_sharded_param
apply_init_sharded_param_patch._axolotl_patched = True
LOG.info("Successfully applied FSDP _init_sharded_param patch")
else:
LOG.warning("Could not find target code for _init_sharded_param patching")
@@ -95,8 +80,6 @@ def apply_init_sharded_param_patch():
def apply_init_unsharded_param_patch():
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
if getattr(apply_init_unsharded_param_patch, "_axolotl_patched", False):
return
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
@@ -122,14 +105,6 @@ def apply_init_unsharded_param_patch():
module=local_tensor.module,
bnb_quantized=local_tensor.bnb_quantized,
)
elif isinstance(local_tensor, bnb.nn.modules.Int8Params):
self._unsharded_param = bnb.nn.modules.Int8Params(
data=unsharded_param,
requires_grad=self.sharded_param.requires_grad,
has_fp16_weights=local_tensor.has_fp16_weights,
CB=unsharded_param,
SCB=local_tensor.SCB,
)
else:
self._unsharded_param = nn.Parameter(
unsharded_param, requires_grad=self.sharded_param.requires_grad
@@ -163,74 +138,6 @@ def apply_init_unsharded_param_patch():
# Replace the method
FSDPParam.init_unsharded_param = patched_init_unsharded_param
apply_init_unsharded_param_patch._axolotl_patched = True
LOG.info("Successfully applied FSDP init_unsharded_param patch")
else:
LOG.warning("Could not find target code for patching")
def apply_linear8bitlt_save_patch():
"""Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params.
After FSDP2 sharding, Linear8bitLt.weight is a DTensor wrapping Int8Params.
BnB's _save_to_state_dict accesses self.weight.SCB directly, but DTensor
doesn't proxy custom attribute access to its _local_tensor. This patch
temporarily unwraps the DTensor during saving so BnB can find the SCB attribute.
"""
if getattr(apply_linear8bitlt_save_patch, "_axolotl_patched", False):
return
import bitsandbytes as bnb
from torch.distributed.tensor import DTensor
original_save = bnb.nn.Linear8bitLt._save_to_state_dict
def _patched_save_to_state_dict(self, destination, prefix, keep_vars):
# Use _parameters dict directly to bypass nn.Module.__setattr__ type check.
weight = self._parameters["weight"]
unwrapped = False
if isinstance(weight, DTensor) and hasattr(weight, "_local_tensor"):
self._parameters["weight"] = weight._local_tensor
unwrapped = True
try:
original_save(self, destination, prefix, keep_vars)
finally:
if unwrapped:
self._parameters["weight"] = weight
bnb.nn.Linear8bitLt._save_to_state_dict = _patched_save_to_state_dict
apply_linear8bitlt_save_patch._axolotl_patched = True
LOG.info("Patched Linear8bitLt._save_to_state_dict for DTensor compatibility")
def apply_init_dtype_attrs_patch():
"""Prevent FSDP2 mixed precision from casting non-float quantized params.
When mixed precision is enabled (e.g., bf16), FSDP2's init_dtype_attrs sets
param_dtype=bf16 for ALL params. During all-gather, _to_dtype_if_needed casts
the sharded param to param_dtype. For non-float params (uint8 packed 4-bit,
int8 quantized) without FSDP2 extensions, this destroys the quantized data.
Params4bit handles this via fsdp_pre/post_all_gather extensions, but our
parametrize-based expert quantization uses plain nn.Parameter(uint8/int8)
without extensions.
"""
if getattr(apply_init_dtype_attrs_patch, "_axolotl_patched", False):
return
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
original_init_dtype_attrs = FSDPParam.init_dtype_attrs
def patched_init_dtype_attrs(self, mp_policy):
original_init_dtype_attrs(self, mp_policy)
# Skip casting non-float quantized params (uint8/int8) without FSDP2
# extensions — the parametrization chain handles dequantization.
if self.param_dtype is not None and not self.sharded_param.is_floating_point():
local = self.sharded_param
if hasattr(local, "_local_tensor"):
local = local._local_tensor
if not hasattr(local, "fsdp_pre_all_gather"):
self.param_dtype = None
FSDPParam.init_dtype_attrs = patched_init_dtype_attrs
apply_init_dtype_attrs_patch._axolotl_patched = True
LOG.info("Patched FSDPParam.init_dtype_attrs for non-float quantized params")

View File

@@ -1,291 +0,0 @@
"""Monkeypatch for Qwen3_5 and Qwen3_5Moe models to pass position_ids to linear attention."""
import importlib
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
try:
from fla.modules.convolution import (
causal_conv1d as fla_causal_conv1d, # FLA >= 0.4.1
)
except ImportError:
try:
from fla.modules.conv import causal_conv1d as fla_causal_conv1d # FLA < 0.4.1
except ImportError:
fla_causal_conv1d = None
def get_cu_seqlens(position_ids):
"""
Compute cumulative sequence lengths from position_ids for FLA varlen kernels.
Adapted from transformers.modeling_flash_attention_utils.prepare_fa_kwargs_from_position_ids.
https://github.com/huggingface/transformers/blob/0f1b128d3359a26bd18be99c26d7f04fb3cba914/src/transformers/modeling_flash_attention_utils.py#L316
Qwen3.5 uses MRoPE: position_ids arrive as [axes, B, T]. All axes carry the
same temporal positions, so axis 0 is used to recover the [B, T] layout.
See: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
"""
if position_ids.ndim == 3:
position_ids = position_ids[0]
tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
position_ids = position_ids.view(-1)
indices_q = (position_ids == 0).nonzero().view(-1)
return torch.cat(
(
indices_q.to(**tensor_kwargs),
torch.tensor(position_ids.size(), **tensor_kwargs),
)
)
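# Doctest-style illustration (toy packed batch; illustrative only, not part of
# the original module): two sequences of lengths 3 and 2 packed into one row
# yield cumulative boundary offsets [0, 3, 5].
#
#   >>> get_cu_seqlens(torch.tensor([[0, 1, 2, 0, 1]]))
#   tensor([0, 3, 5], dtype=torch.int32)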
def _inject_fla_kernels(module) -> None:
"""Inject FLA kernels into a modeling module, bypassing is_flash_linear_attention_available."""
try:
from fla.modules import FusedRMSNormGated
from fla.ops.gated_delta_rule import (
chunk_gated_delta_rule,
fused_recurrent_gated_delta_rule,
)
module.FusedRMSNormGated = FusedRMSNormGated
module.chunk_gated_delta_rule = chunk_gated_delta_rule
module.fused_recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule
module.is_fast_path_available = True
except ImportError:
module.chunk_gated_delta_rule = None
module.fused_recurrent_gated_delta_rule = None
module.FusedRMSNormGated = None
def _patched_decoder_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values=None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> torch.FloatTensor:
"""Decoder layer forward that passes position_ids through to linear attention."""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if self.layer_type == "linear_attention":
hidden_states = self.linear_attn(
hidden_states=hidden_states,
cache_params=past_key_values,
cache_position=cache_position,
attention_mask=attention_mask,
position_ids=position_ids,
)
elif self.layer_type == "full_attention":
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if isinstance(hidden_states, tuple): # MoE returns (hidden_states, router_logits)
hidden_states, _ = hidden_states
hidden_states = residual + hidden_states
return hidden_states
def _make_qwen3_5_gated_delta_forward(apply_mask_fn):
"""Factory for patched Qwen3_5/Qwen3_5Moe GatedDeltaNet forward with packing support."""
def patched_forward(
self,
hidden_states: torch.Tensor,
cache_params=None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
):
hidden_states = apply_mask_fn(hidden_states, attention_mask)
batch_size, seq_len, _ = hidden_states.shape
use_precomputed_states = (
cache_params is not None
and cache_params.has_previous_state
and seq_len == 1
and cache_position is not None
)
cu_seqlens = None
if not use_precomputed_states and position_ids is not None:
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
if cache_params is not None:
conv_state = cache_params.conv_states[self.layer_idx]
recurrent_state = cache_params.recurrent_states[self.layer_idx]
# mixed_qkv stays [B, T, D]; only transposed inside paths that require [B, D, T]
mixed_qkv = self.in_proj_qkv(hidden_states) # [B, T, D]
z = self.in_proj_z(hidden_states)
z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
b = self.in_proj_b(hidden_states)
a = self.in_proj_a(hidden_states)
if use_precomputed_states:
mixed_qkv = self.causal_conv1d_update(
mixed_qkv.transpose(1, 2),
conv_state,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
).transpose(1, 2)
else:
if cache_params is not None:
mixed_qkv_t = mixed_qkv.transpose(1, 2)
cache_params.conv_states[self.layer_idx] = F.pad(
mixed_qkv_t,
(self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),
)
if fla_causal_conv1d is not None and cu_seqlens is not None:
# FLA varlen kernel for packed sequences; input must be contiguous [B, T, D]
mixed_qkv, _ = fla_causal_conv1d(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
cu_seqlens=cu_seqlens,
)
else:
if cu_seqlens is not None and fla_causal_conv1d is None:
raise RuntimeError(
"Packed sequences require fla.modules.convolution.causal_conv1d "
"(cu_seqlens support). Install flash-linear-attention or disable packing."
)
mixed_qkv = F.silu(
self.conv1d(mixed_qkv.transpose(1, 2))[:, :, :seq_len]
).transpose(1, 2)
query, key, value = torch.split(
mixed_qkv,
[self.key_dim, self.key_dim, self.value_dim],
dim=-1,
)
query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)
beta = b.sigmoid()
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
if self.num_v_heads // self.num_k_heads > 1:
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
if not use_precomputed_states:
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
query,
key,
value,
g=g.to(dtype=query.dtype),
beta=beta,
initial_state=None,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
# torch_chunk_gated_delta_rule fallback does not accept cu_seqlens
**({"cu_seqlens": cu_seqlens} if cu_seqlens is not None else {}),
)
else:
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
query,
key,
value,
g=g.to(dtype=query.dtype),
beta=beta,
initial_state=recurrent_state,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
)
if cache_params is not None:
cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
z = z.reshape(-1, self.head_v_dim)
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
return self.out_proj(core_attn_out)
return patched_forward
def _apply_packing_patches(model_type: str, cls_prefix: str, forward_factory) -> None:
module_name = f"transformers.models.{model_type}.modeling_{model_type}"
try:
module = importlib.import_module(module_name)
except ImportError:
LOG.warning(f"{model_type} not found in transformers, skipping packing patches")
return
_inject_fla_kernels(module)
getattr(module, f"{cls_prefix}DecoderLayer").forward = _patched_decoder_forward
gated_cls = getattr(module, f"{cls_prefix}GatedDeltaNet")
gated_cls.forward = forward_factory(module.apply_mask_to_padding_states)
LOG.info(
f"Applied {cls_prefix} packing patch "
f"(fla_causal_conv1d={'available' if fla_causal_conv1d else 'unavailable'})"
)
def patch_qwen3_5_modeling_packing():
_apply_packing_patches("qwen3_5", "Qwen3_5", _make_qwen3_5_gated_delta_forward)
def patch_qwen3_5_moe_modeling_packing():
_apply_packing_patches(
"qwen3_5_moe", "Qwen3_5Moe", _make_qwen3_5_gated_delta_forward
)
def patch_qwen3_5_vlm_flash_attention():
"""
Patch _is_packed_sequence to handle Qwen3.5's 3-D MRoPE position_ids.
transformers passes position_ids as [axes, B, T] to decoder layers, but
_is_packed_sequence only handles 2-D tensors and mis-classifies the 3-D
shape as a packed-sequence indicator, causing CUDA errors in the varlen path.
"""
try:
import transformers.modeling_flash_attention_utils as fa_utils
_original = fa_utils._is_packed_sequence
def _patched(position_ids, batch_size):
if position_ids is not None and position_ids.ndim != 2:
return False
return _original(position_ids, batch_size)
fa_utils._is_packed_sequence = _patched
LOG.info("Applied Qwen3.5 VLM flash-attention patch (3-D MRoPE position_ids)")
except Exception as exc: # pragma: no cover
LOG.warning(f"Failed to apply Qwen3.5 VLM flash-attention patch: {exc}")

View File

@@ -9,11 +9,6 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
try:
from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
except ImportError:
fla_causal_conv1d = None
def get_cu_seqlens(position_ids):
"""
@@ -142,11 +137,6 @@ def patch_qwen3_next_gateddelta_layer():
and cache_position is not None
)
# Compute cu_seqlens early for use by both causal_conv1d and chunk_gated_delta_rule
cu_seqlens = None
if not use_precomputed_states and position_ids is not None:
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
# getting projected states from cache if it exists
if cache_params is not None:
conv_state = cache_params.conv_states[self.layer_idx]
@@ -161,11 +151,12 @@ def patch_qwen3_next_gateddelta_layer():
x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
)
mixed_qkv = torch.cat((query, key, value), dim=-1) # [B, T, D]
mixed_qkv = torch.cat((query, key, value), dim=-1)
mixed_qkv = mixed_qkv.transpose(1, 2)
if use_precomputed_states:
# Inference single-token path: causal_conv1d_update expects [B, D, T]
mixed_qkv = mixed_qkv.transpose(1, 2)
# 2. Convolution sequence transformation
# NOTE: the conv state is updated in `causal_conv1d_update`
mixed_qkv = self.causal_conv1d_update(
mixed_qkv,
conv_state,
@@ -173,41 +164,24 @@ def patch_qwen3_next_gateddelta_layer():
self.conv1d.bias,
self.activation,
)
mixed_qkv = mixed_qkv.transpose(1, 2)
else:
if cache_params is not None:
# Cache state expects [B, D, T] for the inference update path
mixed_qkv_t = mixed_qkv.transpose(1, 2)
conv_state = F.pad(
mixed_qkv_t,
(self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),
mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
)
cache_params.conv_states[self.layer_idx] = conv_state
if fla_causal_conv1d is not None:
# FLA Triton causal_conv1d: [B, T, D] in/out, with cu_seqlens support
mixed_qkv, _ = fla_causal_conv1d(
if self.causal_conv1d_fn is not None:
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
cu_seqlens=cu_seqlens,
seq_idx=None,
)
else:
# PyTorch fallback (no cu_seqlens support)
if cu_seqlens is not None and cu_seqlens.shape[0] > batch_size + 1:
raise RuntimeError(
"Packed sequences require fla.modules.convolution.causal_conv1d "
"(cu_seqlens support). Install flash-linear-attention or disable packing."
)
LOG.warning_once(
"FLA causal_conv1d not available. Falling back to PyTorch conv1d."
)
mixed_qkv = mixed_qkv.transpose(1, 2)
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
mixed_qkv = mixed_qkv.transpose(1, 2)
# mixed_qkv is [B, T, D] in all paths
mixed_qkv = mixed_qkv.transpose(1, 2)
query, key, value = torch.split(
mixed_qkv,
[
@@ -229,6 +203,7 @@ def patch_qwen3_next_gateddelta_layer():
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
if not use_precomputed_states:
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
query,
key,

View File

@@ -1,188 +0,0 @@
"""
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
skips (only targets nn.Linear). This module patches weight loading to quantize them
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
"""
import bitsandbytes as bnb
import torch
import torch.nn.utils.parametrize as P
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# Module-level state for the loading-time quantization patch.
_moe_load_state = {
"count": 0,
"mode": "4bit",
"quant_type": "nf4",
"compress_statistics": True,
"patched": False,
}
class Bnb8bitParametrization(torch.nn.Module):
"""Parametrization that dequantizes int8 row-wise quantized data on access."""
def __init__(self, row_stats: torch.Tensor):
super().__init__()
self.register_buffer("row_stats", row_stats)
@torch.no_grad()
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
# Flatten 3D+ to 2D for BnB's dequant, then reshape back.
orig_shape = quantized_param.shape
if quantized_param.ndim > 2:
quantized_param = quantized_param.reshape(-1, orig_shape[-1])
result = bnb.functional.int8_vectorwise_dequant(quantized_param, self.row_stats)
return result.reshape(orig_shape)
def _enable_parametrization_cache(module, inputs):
P._cache_enabled += 1
def _disable_parametrization_cache(module, inputs, output):
P._cache_enabled -= 1
if not P._cache_enabled:
P._cache = {}
def replace_parameter_8bit(module, param_name):
"""Replace a module parameter with an 8-bit quantized version using parametrization."""
original_param = getattr(module, param_name)
int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant(
original_param.data.to(torch.float16)
)
setattr(module, param_name, torch.nn.Parameter(int8_data, requires_grad=False))
del original_param
P.register_parametrization(
module, param_name, Bnb8bitParametrization(row_stats), unsafe=True
)
# Cache dequantized values during forward to avoid redundant dequantization.
if not getattr(module, "_axolotl_8bit_hooks_registered", False):
module.register_forward_pre_hook(_enable_parametrization_cache)
module.register_forward_hook(_disable_parametrization_cache)
module._axolotl_8bit_hooks_registered = True
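# Usage sketch (illustrative module/parameter names; requires CUDA + bitsandbytes):
#   replace_parameter_8bit(moe_block, "gate_up_proj")   # storage becomes int8 plus per-row stats
#   w = moe_block.gate_up_proj                           # the parametrization dequantizes on access
#   # during forward, the pre/post hooks above cache the dequantized tensor for reuse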
def patch_moe_quantization_on_load(cfg):
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly.
Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
"""
mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
_moe_load_state["mode"] = mode
_moe_load_state["count"] = 0
if _moe_load_state["patched"]:
LOG.debug("MoE loading-time quantization patch already active")
return
import transformers.core_model_loading
import transformers.modeling_utils
if mode == "4bit":
from bitsandbytes.nn.parametrize import replace_parameter_4bit
quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
if compress_statistics is None:
compress_statistics = True
_moe_load_state["quant_type"] = quant_type
_moe_load_state["compress_statistics"] = compress_statistics
# Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16
# size for all params, defeating our on-load quantization VRAM savings.
def _noop_warmup(*args, **kwargs):
pass
transformers.modeling_utils.caching_allocator_warmup = _noop_warmup
original_set_param = transformers.core_model_loading.set_param_for_module
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
original_set_param(model, target_name, param_value, *args, **kwargs)
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
if param_value.ndim >= 3 and param_value.is_cuda:
mod_path, _, pname = target_name.rpartition(".")
mod = model.get_submodule(mod_path) if mod_path else model
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
if "expert" not in target_name.lower():
LOG.debug(
"Skipping non-expert 3D param: %s (shape=%s)",
target_name,
list(param_value.shape),
)
return
if _moe_load_state["mode"] == "4bit":
replace_parameter_4bit(
mod,
pname,
compress_statistics=_moe_load_state["compress_statistics"],
quant_type=_moe_load_state["quant_type"],
)
else:
replace_parameter_8bit(mod, pname)
_moe_load_state["count"] += 1
# Release the bf16 tensor so CUDA memory is freed immediately.
param_value.data = torch.empty(0, device="cpu")
torch.cuda.empty_cache()
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
_moe_load_state["patched"] = True
def get_moe_quantized_count():
"""Return the number of expert parameters quantized during loading."""
return _moe_load_state["count"]
def patch_peft_target_parameters_matching():
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
return
from peft.tuners.tuners_utils import BaseTuner
original_inject = BaseTuner._inject_parameters
def _patched_inject_parameters(
self, peft_config, model, adapter_name, low_cpu_mem_usage
):
# Patch target_parameters to use full paths for parametrized modules
original_targets = list(peft_config.target_parameters)
expanded = set(original_targets)
for module_name, module in model.named_modules():
if not hasattr(module, "parametrizations"):
continue
for target in original_targets:
mod_path, _, param_name = target.rpartition(".")
if (
module_name == mod_path or module_name.endswith("." + mod_path)
) and hasattr(module, param_name):
expanded.add(f"{module_name}.{param_name}")
peft_config.target_parameters = sorted(expanded)
try:
return original_inject(
self, peft_config, model, adapter_name, low_cpu_mem_usage
)
finally:
peft_config.target_parameters = original_targets
BaseTuner._inject_parameters = _patched_inject_parameters
patch_peft_target_parameters_matching._axolotl_patched = True
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")

View File

@@ -22,8 +22,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"qwen3",
"qwen3_moe",
"qwen3_next",
"qwen3_5",
"qwen3_5_moe",
"falcon",
"phi",
"phi3",
@@ -39,7 +37,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"deepseek_v3",
"glm",
"glm4",
"glm4_moe",
"smollm3",
"granite",
"granitemoe",

View File

@@ -258,32 +258,6 @@ class Qwen2VLProcessingStrategy(ProcessingStrategy):
)
class Qwen3_5ProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Qwen3.5 (early-fusion VLM)"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.image_token = "<|image_pad|>" # nosec
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
self.image_token
)
self.video_token = "<|video_pad|>" # nosec
self.video_token_id = processor.tokenizer.convert_tokens_to_ids(
self.video_token
)
def process_labels(self, input_ids):
labels = super().process_labels(input_ids)
labels[labels == self.video_token_id] = -100
return labels
class Gemma3ProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Gemma3"""
@@ -588,10 +562,6 @@ def get_processing_strategy(
return Qwen2VLProcessingStrategy(
**processing_kwargs,
)
if chat_template_type in ["qwen3_5", "qwen3_5_moe"]:
return Qwen3_5ProcessingStrategy(
**processing_kwargs,
)
if chat_template_type == "gemma3":
return Gemma3ProcessingStrategy(
**processing_kwargs,

View File

@@ -48,9 +48,9 @@ class ChatTemplatePrompter(Prompter):
):
# check if message_property_mappings is None or empty dict
if message_property_mappings is None or (not message_property_mappings):
default_message_property_mappings_keys = ["role", "content", "tool"]
message_property_mappings = {
"role": "role",
"content": "content",
prop: prop for prop in default_message_property_mappings_keys
}
if template_thinking_key and field_thinking:
message_property_mappings[template_thinking_key] = field_thinking
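# With no explicit mapping configured, the defaults now expand to
#   {"role": "role", "content": "content", "tool": "tool"}
# (plus the thinking-key entry above when template_thinking_key/field_thinking are set).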

View File

@@ -86,21 +86,9 @@ def setup_model_and_tokenizer(
if model.generation_config is not None:
model.generation_config.do_sample = True
model_properties = model.config.to_dict()
try:
model_properties["num_parameters"] = model.num_parameters()
except Exception: # pylint: disable=broad-exception-caught
model_properties["num_parameters"] = sum(p.numel() for p in model.parameters())
# If num_parameters is less than 2B, round to the nearest 100M; otherwise round to the nearest 1B.
if model_properties["num_parameters"] < 2e9:
model_properties["num_parameters_est"] = (
f"{round(model_properties['num_parameters'] / 1e8) * 100}M"
)
else:
model_properties["num_parameters_est"] = (
f"{round(model_properties['num_parameters'] / 1e9)}B"
)
TELEMETRY_MANAGER.send_event(event_type="model-load", properties=model_properties)
TELEMETRY_MANAGER.send_event(
event_type="model-load", properties=model.config.to_dict()
)
if peft_config:
TELEMETRY_MANAGER.send_event(
event_type="peft-config-load", properties=peft_config.to_dict()

View File

@@ -1,123 +0,0 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{#- Determine the real last index: use provided value or default to messages length - 1 #}
{%- if real_last_index is defined and real_last_index is not none %}
{%- set ns.real_last_index = real_last_index %}
{%- else %}
{%- set ns.real_last_index = messages|length - 1 %}
{%- endif %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if message['content'] is string %}
{%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- else %}
{%- if ns.multi_step_tool and message.role == "user" %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- for message in messages %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' }}
{%- if message['content'] is string %}
{{- message.content }}
{%- else %}
{%- for content in message['content'] %}
{%- if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
{{- '<|vision_start|><|image_pad|><|vision_end|>' }}
{%- elif content['type'] == 'video' or 'video' in content %}
{{- '<|vision_start|><|video_pad|><|vision_end|>' }}
{%- elif 'text' in content %}
{{- content['text'] }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "assistant" %}
{%- if message['content'] is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- for item in message['content'] %}
{%- if 'text' in item %}
{%- set content = content + item['text'] %}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is defined and message.reasoning_content is not none %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{%- if loop.index0 == ns.real_last_index or (loop.index0 != ns.real_last_index and reasoning_content) %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{%- if tool_call.arguments is string %}
{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}
{{- '}\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- message.content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- else %}
{{- '<think>\n\n' }}
{%- endif %}
{%- endif %}

View File

@@ -6,10 +6,7 @@ from typing import Optional
import torch
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import (
is_torch_greater_or_equal,
is_torch_npu_available,
)
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.integrations.base import PluginManager
from axolotl.integrations.config import merge_input_args
@@ -84,15 +81,8 @@ def resolve_dtype(cfg):
cfg.fp16 = True
cfg.bf16 = False
else:
if cfg.tf32:
torch.set_float32_matmul_precision("high")
if is_torch_greater_or_equal("2.9.0"):
torch.backends.fp32_precision = "tf32"
torch.backends.cuda.matmul.fp32_precision = "tf32"
torch.backends.cudnn.fp32_precision = "tf32"
else:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
torch.backends.cudnn.allow_tf32 = cfg.tf32 or False
if cfg.bf16:
cfg.fp16 = False
@@ -129,12 +119,7 @@ def normalize_config(cfg):
if cfg.world_size != 1:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
if cfg.fsdp or cfg.fsdp_config or cfg.ddp:
effective_world_size = (
cfg.world_size
// (cfg.context_parallel_size or 1)
// (cfg.tensor_parallel_size or 1)
)
cfg.batch_size = cfg.batch_size * effective_world_size
cfg.batch_size = cfg.batch_size * cfg.world_size
if not cfg.use_ray:
# delay resolving dtype until on worker node when launching with ray

View File

@@ -189,7 +189,7 @@ def _get_remote_filesystem(
try:
import gcsfs
storage_options = {"token": None} # type: ignore # nosec B105
storage_options = {"token": None} # type: ignore
return gcsfs.GCSFileSystem(**storage_options), storage_options
except ImportError as exc:
raise ImportError(

View File

@@ -5,7 +5,6 @@ Utilities for quantization including QAT and PTQ using torchao.
import torch
from packaging import version
from torchao.core.config import AOBaseConfig
from torchao.prototype.qat import MXFakeQuantizeConfig
from torchao.quantization import quantize_
from torchao.quantization.qat import (
QATConfig,
@@ -41,13 +40,6 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"):
except:
pass
try:
from torchao.prototype.qat import MXFakeQuantizeConfig
quantization_config_to_str[MXFakeQuantizeConfig] = "mxfp4"
except ImportError:
pass
def get_quantization_config(
weight_dtype: TorchAOQuantDType,
@@ -117,19 +109,6 @@ def get_quantization_config(
if group_size is not None and group_size != 16:
raise ValueError("NVFP4 quantization must use a group_size of 16")
return NVFP4InferenceConfig()
if weight_dtype == TorchAOQuantDType.mxfp4:
from torchao.prototype.qat import MXFakeQuantizeConfig
# MXFP4 uses block_size=32 by default (vs NVFP4's 16)
block_size = group_size if group_size is not None else 32
if block_size != 32:
raise ValueError(
"MXFP4 quantization must use a block_size (group_size) of 32"
)
return MXFakeQuantizeConfig(dtype=torch.float4_e2m1fn_x2, block_size=block_size)
raise ValueError(
f"Invalid activation/weight dtype combination: {activation_dtype}/{weight_dtype}"
)
@@ -200,13 +179,7 @@ def prepare_model_for_qat(
activation_dtype=activation_dtype,
group_size=group_size,
)
if isinstance(base_config, MXFakeQuantizeConfig):
qat_config = QATConfig(
activation_config=base_config,
weight_config=base_config,
)
else:
qat_config = QATConfig(base_config)
qat_config = QATConfig(base_config)
quantize_(model, qat_config)
if quantize_embedding:
# activation fake quantization is not supported for embedding layers
@@ -215,12 +188,7 @@ def prepare_model_for_qat(
activation_dtype=None,
group_size=group_size,
)
if isinstance(embedding_base_config, MXFakeQuantizeConfig):
embedding_qat_config = QATConfig(
weight_config=embedding_base_config,
)
else:
embedding_qat_config = QATConfig(embedding_base_config)
embedding_qat_config = QATConfig(embedding_base_config)
quantize_(
model,
embedding_qat_config,

View File

@@ -173,6 +173,7 @@ class AxolotlInputConfig(
"description": "Whether to perform weighting in DPO trainer"
},
)
dpo_use_logits_to_keep: bool | None = None
dpo_label_smoothing: float | None = None
dpo_norm_loss: bool | None = None
@@ -182,6 +183,7 @@ class AxolotlInputConfig(
)
dpo_padding_free: bool | None = None
dpo_generate_during_eval: bool | None = None
datasets: (
Annotated[
@@ -627,17 +629,6 @@ class AxolotlInputConfig(
},
)
quantize_moe_experts: bool = Field(
default=False,
json_schema_extra={
"description": "Quantize MoE expert weights on load to reduce VRAM. "
"Requires adapter (lora/qlora) with load_in_4bit or load_in_8bit. "
"Requires CUDA (not compatible with ROCm or other backends). "
"Note: total parameter count may be reported incorrectly when enabled "
"(trainable param count is correct)."
},
)
scaling_softmax: bool | None = Field(
default=None,
json_schema_extra={
@@ -1298,31 +1289,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_quantize_moe_experts(cls, data):
if data.get("quantize_moe_experts"):
if data.get("lora_target_linear"):
raise ValueError(
"lora_target_linear is not compatible with quantize_moe_experts. "
"Use lora_target_parameters to target expert weights instead."
)
if data.get("adapter") not in ("lora", "qlora"):
raise ValueError("quantize_moe_experts requires adapter: lora or qlora")
if not (data.get("load_in_4bit") or data.get("load_in_8bit")):
raise ValueError(
"quantize_moe_experts requires load_in_4bit or load_in_8bit"
)
if (
data.get("capabilities")
and data["capabilities"].get("compute_capability")
and not data["capabilities"]["compute_capability"].startswith("sm_")
):
raise ValueError(
"quantize_moe_experts requires CUDA (not compatible with ROCm or other backends)"
)
return data
@model_validator(mode="before")
@classmethod
def check_auto_enable_lora_kernels(cls, data):

View File

@@ -19,8 +19,6 @@ class DeprecatedParameters(BaseModel):
evaluation_strategy: str | None = None
eval_table_size: int | None = None
eval_max_new_tokens: int | None = None
dpo_use_logits_to_keep: bool | None = None
dpo_generate_during_eval: bool | None = None
@field_validator("max_packed_sequence_len")
@classmethod
@@ -80,26 +78,6 @@ class DeprecatedParameters(BaseModel):
)
return eval_max_new_tokens
@field_validator("dpo_use_logits_to_keep")
@classmethod
def validate_dpo_use_logits_to_keep(cls, dpo_use_logits_to_keep):
if dpo_use_logits_to_keep is not None:
raise DeprecationWarning(
"`dpo_use_logits_to_keep` is no longer supported, "
"it has been removed in TRL >= 0.29.0"
)
return dpo_use_logits_to_keep
@field_validator("dpo_generate_during_eval")
@classmethod
def validate_dpo_generate_during_eval(cls, dpo_generate_during_eval):
if dpo_generate_during_eval is not None:
raise DeprecationWarning(
"`dpo_generate_during_eval` is no longer supported, "
"it has been removed in TRL >= 0.29.0"
)
return dpo_generate_during_eval
class RemappedParameters(BaseModel):
"""Parameters that have been remapped to other names"""

View File

@@ -10,7 +10,6 @@ class TorchAOQuantDType(Enum):
int8 = torch.int8
float8_e4m3fn = torch.float8_e4m3fn
nvfp4 = "nvfp4"
mxfp4 = "mxfp4"
def from_string(str):
if str == "int4":
@@ -21,8 +20,6 @@ class TorchAOQuantDType(Enum):
return TorchAOQuantDType.float8_e4m3fn
if str == "nvfp4":
return TorchAOQuantDType.nvfp4
if str == "mxfp4":
return TorchAOQuantDType.mxfp4
class RLType(str, Enum):
@@ -59,7 +56,6 @@ class ChatTemplate(str, Enum):
jinja = "jinja"
qwen_25 = "qwen_25"
qwen3 = "qwen3"
qwen3_5 = "qwen3_5"
falcon_h1 = "falcon_h1"
tokenizer_default = "tokenizer_default"
exaone = "exaone"

View File

@@ -209,19 +209,6 @@ class LoraConfig(BaseModel):
data["lora_dropout"] = 0.0
return data
@model_validator(mode="after")
def validate_lora_target_parameters_dropout(self):
if (
self.lora_target_parameters
and self.lora_dropout
and self.lora_dropout != 0.0
):
raise ValueError(
"lora_dropout must be 0 when lora_target_parameters is set. "
"PEFT's ParamWrapper does not support lora_dropout != 0."
)
return self
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""

View File

@@ -20,9 +20,6 @@ def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
return TorchAOQuantDType.float8_e4m3fn
if v == "nvfp4":
return TorchAOQuantDType.nvfp4
if v == "mxfp4":
return TorchAOQuantDType.mxfp4
raise ValueError(
f"Invalid dtype: '{v}'. Must be one of: {[e.name for e in TorchAOQuantDType] + ['fp8', 'float8']}"
)

View File

@@ -986,6 +986,23 @@ class OptimizationValidationMixin:
return self
@model_validator(mode="after")
def lr_groups_ao_optimizer(self):
if (
self.loraplus_lr_ratio is not None
or self.embedding_lr_scale is not None
or self.embedding_lr is not None
or self.lr_groups is not None
) and self.optimizer.value in ["adamw_torch_8bit", "adamw_torch_4bit"]:
# TODO(wing): remove this once ao>0.12.0
# requires https://github.com/pytorch/ao/pull/2606 in an ao release
raise ValueError(
"lr groups (`loraplus_lr_ratio`, `embedding_lr_scale`, `embedding_lr`, `lr_groups`) are not "
"supported with ao low-bit optimizers until ao>0.12.0. "
"Please refer to https://github.com/pytorch/ao/pull/2606."
)
return self
@model_validator(mode="before")
@classmethod
def check_tensor_parallel_size_update_ds_json(cls, data):

View File

@@ -457,6 +457,8 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
- 1
)
* cfg.num_epochs
* cfg.context_parallel_size
* cfg.tensor_parallel_size
)
LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
@@ -495,7 +497,14 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
LOG.debug(f"data_loader_len: {data_loader_len}")
# FIXME: is there a bug here somewhere? the total num steps depends
# on the agreed on value for sample_packing_eff_est
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
total_num_steps = int(
math.floor(
data_loader_len
* cfg.num_epochs
* cfg.context_parallel_size
* cfg.tensor_parallel_size
)
)
if cfg.dataloader_drop_last:
# drop the last batch for each epoch
total_num_steps -= int(math.ceil(cfg.num_epochs))
@@ -516,7 +525,13 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
else:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
math.ceil(
len(train_dataset)
* cfg.num_epochs
* cfg.context_parallel_size
* cfg.tensor_parallel_size
/ cfg.batch_size
)
)
LOG.debug(f"total_num_steps: {total_num_steps}")
return total_num_steps
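# Worked example (illustrative numbers): len(train_dataset)=10_000, num_epochs=1,
# context_parallel_size=2, tensor_parallel_size=1, batch_size=32
#   -> total_num_steps = ceil(10_000 * 1 * 2 * 1 / 32) = 625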

View File

@@ -2,7 +2,7 @@
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
@@ -94,6 +94,7 @@ def fixture_dpo_cfg(base_cfg):
{
"rl": RLType.DPO,
"dpo_use_weighting": True,
"dpo_use_logits_to_keep": True,
"dpo_label_smoothing": 0.1,
"beta": 0.1, # DPO beta
}
@@ -147,16 +148,9 @@ def fixture_grpo_cfg(base_cfg):
),
# Must be evenly divisible by num_generations
"micro_batch_size": 4,
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"split": "train[:1%]",
}
],
}
)
return DictDefault(cfg)
return cfg
@pytest.fixture(name="ipo_cfg")
@@ -340,7 +334,6 @@ def rand_reward_func(prompts, completions) -> list[float]:
try:
builder = HFRLTrainerBuilder(grpo_cfg, model, tokenizer)
training_arguments, _ = builder._build_training_arguments(100)
builder.train_dataset = MagicMock()
self._test_common_training_arguments(training_arguments, rl=grpo_cfg.rl)
# GRPO specific
@@ -370,7 +363,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
self._test_common_training_arguments(training_arguments, rl=ipo_cfg.rl)
# IPO specific
assert training_arguments.beta == 0.1
assert training_arguments.loss_type == ["ipo"]
assert training_arguments.loss_type == "ipo"
assert training_arguments.label_smoothing == 0
def test_simpo_training_arguments(self, simpo_cfg, model, tokenizer):
@@ -536,11 +529,13 @@ class TestHFCausalTrainerBuilder:
"cfg_string",
[
"sft_cfg",
# "rm_cfg", # TODO fix for num_labels = 2 vs 1
"rm_cfg",
"prm_cfg",
],
)
def test_builder_w_rm_trainers(self, request, cfg_string, model, tokenizer):
def test_custom_optimizer_cls_and_kwargs(
self, request, cfg_string, model, tokenizer
):
cfg = request.getfixturevalue(cfg_string)
builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
cfg["optimizer"] = "muon"

View File

@@ -1,288 +0,0 @@
"""
End-to-end gradient and convergence tests for SonicMoE integration.
Requires:
- H100/H200 GPU (SonicMoE CUTLASS kernels target sm_90)
- sonicmoe package installed
- transformers with Qwen3MoE support
Usage:
pytest tests/e2e/integrations/test_sonicmoe.py -v -s
"""
import importlib.util
import math
import pytest
import torch
_sonicmoe_available = importlib.util.find_spec("sonicmoe") is not None
_is_hopper = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
pytestmark = [
pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA GPU"),
pytest.mark.skipif(
not _is_hopper, reason="SonicMoE CUTLASS kernels require Hopper (sm_90)"
),
pytest.mark.skipif(not _sonicmoe_available, reason="SonicMoE not installed"),
]
def _create_tiny_qwen3_config():
"""Create a minimal Qwen3MoE config for fast testing."""
from transformers import AutoConfig
config = AutoConfig.for_model("qwen3_moe")
config.hidden_size = 512
config.intermediate_size = 1024
config.moe_intermediate_size = 64
config.num_attention_heads = 16
config.num_key_value_heads = 2
config.head_dim = 32
config.num_hidden_layers = 2
config.num_experts = 8
config.num_experts_per_tok = 2
config.vocab_size = 1000
config.max_position_embeddings = 128
config.norm_topk_prob = True
config.torch_dtype = torch.bfloat16
return config
def _interleave_gate_up_weights(model):
"""Interleave all gate_up_proj parameters in-place for SonicMoE."""
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
interleave_gate_up,
)
with torch.no_grad():
for name, param in model.named_parameters():
if "gate_up_proj" in name:
param.copy_(interleave_gate_up(param))
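# Layout illustration (rows along dim=1, with I=2 rows per projection): the fused
# expert weight goes from concatenated [g0, g1, u0, u1] to interleaved [g0, u0, g1, u1],
# the layout interleave_gate_up produces and deinterleave_gate_up reverses.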
def _unpatch_sonicmoe():
"""Restore original forward on the MoE block class if it was patched."""
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
for moe_cls in resolve_moe_block_classes("qwen3_moe"):
if hasattr(moe_cls, "_original_forward"):
moe_cls.forward = moe_cls._original_forward
del moe_cls._original_forward
class TestSonicMoEForwardCorrectness:
"""Verify SonicMoE-patched model produces same output as original."""
def teardown_method(self):
_unpatch_sonicmoe()
def test_forward_output_matches(self):
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
# Original model
model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
with torch.no_grad():
out_orig = model_orig(input_ids)
# Patched model (same weights, interleaved for SonicMoE)
model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
model_patched.load_state_dict(model_orig.state_dict())
patch_sonicmoe("qwen3_moe")
_interleave_gate_up_weights(model_patched)
with torch.no_grad():
out_patched = model_patched(input_ids)
max_diff = (out_orig.logits - out_patched.logits).abs().max().item()
assert torch.allclose(
out_orig.logits, out_patched.logits, atol=1e-1, rtol=1e-1
), f"Output mismatch: max diff={max_diff:.6f}"
class TestSonicMoEGradientCorrectness:
"""Compare gradients between original HuggingFace and SonicMoE-patched forward."""
def teardown_method(self):
_unpatch_sonicmoe()
def test_gradients_match(self):
"""Verify all parameter gradients match between original and patched."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
deinterleave_gate_up,
)
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
# ---------- Original model ----------
model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
out_orig = model_orig(input_ids, labels=input_ids)
out_orig.loss.backward()
grads_orig = {
n: p.grad.float().clone()
for n, p in model_orig.named_parameters()
if p.grad is not None
}
loss_orig = out_orig.loss.item()
# ---------- SonicMoE-patched model (same weights, interleaved) ----------
model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
model_patched.load_state_dict(model_orig.state_dict())
patch_sonicmoe("qwen3_moe")
_interleave_gate_up_weights(model_patched)
out_patched = model_patched(input_ids, labels=input_ids)
out_patched.loss.backward()
grads_patched = {}
for n, p in model_patched.named_parameters():
if p.grad is None:
continue
g = p.grad.float().clone()
# gate_up_proj grads are in interleaved layout, de-interleave to match orig
if "gate_up_proj" in n:
g = deinterleave_gate_up(g)
grads_patched[n] = g
loss_patched = out_patched.loss.item()
# ---------- Compare ----------
assert abs(loss_orig - loss_patched) < 0.5, (
f"Loss mismatch: orig={loss_orig:.4f}, patched={loss_patched:.4f}"
)
# All parameters with gradients in original should have them in patched
missing = set(grads_orig.keys()) - set(grads_patched.keys())
assert not missing, f"Missing gradients in patched model: {missing}"
# Compare gradient values
# bf16 with different GEMM impls (cuBLAS vs CUTLASS) can diverge,
# so use generous tolerance: flag only if both rel >10% AND abs >1e-2
mismatches = []
for name in grads_orig:
if name not in grads_patched:
continue
g_orig = grads_orig[name]
g_patched = grads_patched[name]
max_diff = (g_orig - g_patched).abs().max().item()
rel_diff = max_diff / (g_orig.abs().max().item() + 1e-8)
if rel_diff > 0.1 and max_diff > 1e-2:
mismatches.append(
f" {name}: max_abs_diff={max_diff:.6f}, rel_diff={rel_diff:.4f}"
)
assert not mismatches, (
"Gradient mismatches (rel_diff > 10% and abs_diff > 1e-2):\n"
+ "\n".join(mismatches)
)
def test_router_weights_receive_gradients(self):
"""Verify that router (gate) weights get non-zero gradients."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
patch_sonicmoe("qwen3_moe")
_interleave_gate_up_weights(model)
out = model(input_ids, labels=input_ids)
out.loss.backward()
gate_grads_found = False
for name, param in model.named_parameters():
if "gate" in name and "weight" in name:
gate_grads_found = True
assert param.grad is not None, f"No gradient for router: {name}"
assert param.grad.abs().max() > 0, f"Zero gradient for router: {name}"
assert gate_grads_found, "No gate.weight parameters found in model"
class TestSonicMoETrainingConvergence:
"""Verify loss decreases during training with SonicMoE."""
def teardown_method(self):
_unpatch_sonicmoe()
def test_loss_decreases(self):
"""Run 30 training steps, verify loss decreases and no NaN/Inf."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
patch_sonicmoe("qwen3_moe")
_interleave_gate_up_weights(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
losses = []
for step in range(30):
out = model(input_ids, labels=input_ids)
loss = out.loss
assert not math.isnan(loss.item()), f"NaN loss at step {step}"
assert not math.isinf(loss.item()), f"Inf loss at step {step}"
losses.append(loss.item())
loss.backward()
optimizer.step()
optimizer.zero_grad()
assert losses[-1] < losses[0], (
f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}"
)
def test_expert_weights_update(self):
"""Verify expert weights change during training (not frozen)."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
patch_sonicmoe("qwen3_moe")
_interleave_gate_up_weights(model)
# Snapshot expert weights before training
expert_weights_before = {}
for name, param in model.named_parameters():
if "experts" in name:
expert_weights_before[name] = param.data.clone()
assert expert_weights_before, "No expert parameters found"
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
for _ in range(5):
out = model(input_ids, labels=input_ids)
out.loss.backward()
optimizer.step()
optimizer.zero_grad()
# Check that expert weights changed
changed = 0
for name, param in model.named_parameters():
if name in expert_weights_before:
if not torch.equal(param.data, expert_weights_before[name]):
changed += 1
assert changed > 0, "No expert weights changed after 5 training steps"

View File

@@ -8,8 +8,6 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import TorchAOQuantDType
from axolotl.utils.schemas.quantization import QATConfig, validate_ao_dtype
from .utils import check_model_output_exists, check_tensorboard
@@ -132,32 +130,3 @@ class TestQATLlama:
loss_threshold,
"Train Loss (%s) is too high",
)
class TestMXFP4Schema:
"""Test MXFP4 schema validation"""
def test_validate_mxfp4_dtype(self):
result = validate_ao_dtype("mxfp4")
assert result == TorchAOQuantDType.mxfp4
def test_qat_config_with_mxfp4(self):
"""Test QATConfig accepts mxfp4 weight_dtype"""
config = QATConfig(
weight_dtype="mxfp4",
group_size=32,
quantize_embedding=False,
)
assert config.weight_dtype == TorchAOQuantDType.mxfp4
assert config.group_size == 32
def test_qat_config_mxfp4_invalid_group_size(self):
"""Test that invalid group_size raises appropriate error during quantization"""
# Note: Schema validation doesn't check group_size compatibility,
# that happens in get_quantization_config
config = QATConfig(
weight_dtype="mxfp4",
group_size=16, # Invalid for mxfp4, but schema allows it
)
assert config.group_size == 16 # Schema accepts it
# Actual validation happens at runtime in get_quantization_config

View File

@@ -5,7 +5,6 @@ Tests for axolotl.utils.quantization
import pytest
import torch
from torch import nn
from torchao.prototype.qat import MXFakeQuantizeConfig
from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
@@ -118,21 +117,6 @@ class TestQuantization:
config = get_quantization_config(weight_dtype, activation_dtype, group_size)
assert isinstance(config, expected_type)
@require_torch_2_8_0
@requires_sm_ge_100
def test_get_ptq_config_mxfp4(self):
config = get_quantization_config(TorchAOQuantDType.mxfp4, None, 32)
assert isinstance(config, MXFakeQuantizeConfig)
assert config.block_size == 32
@require_torch_2_8_0
@requires_sm_ge_100
def test_get_ptq_config_mxfp4_invalid_group_size(self):
with pytest.raises(
ValueError, match="MXFP4 quantization must use a block_size"
):
get_quantization_config(TorchAOQuantDType.mxfp4, None, 16)
@requires_cuda_ge_8_9
@require_torch_2_8_0
def test_get_ptq_config_int4_weight_only(self):
@@ -278,35 +262,6 @@ class TestQuantization:
else:
assert child.activation_fake_quantizer is None
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,quantize_embedding",
[
(TorchAOQuantDType.mxfp4, None, 32, False),
(TorchAOQuantDType.mxfp4, None, 32, True),
],
)
@require_torch_2_8_0
@requires_sm_ge_100
def test_prepare_model_for_qat_mxfp4(
self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
):
prepare_model_for_qat(
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
)
if quantize_embedding:
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child, FakeQuantizedLinear)
assert hasattr(child, "weight_fake_quantizer")
@require_torch_2_8_0
@requires_cuda_ge_8_9
def test_convert_qat_model(self, model):

View File

@@ -180,7 +180,6 @@ def check_tensorboard(
lt_val: float,
assertion_err: str,
rtol: float = 0.02,
gt_zero: bool = True,
) -> None:
"""
helper function to parse and check tensorboard logs
@@ -195,8 +194,6 @@ def check_tensorboard(
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
else:
assert df.value.values[-1] < lt_val, assertion_err
if gt_zero:
assert df.value.values[-1] > 1e-5, "Expected loss to be greater than zero"
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:

View File

@@ -6,7 +6,7 @@
Unit tests for scattermoe-lora code-review fixes.
Tests cover:
- KernelsArgs validator: disable_mlp_kernel
- KernelsArgs validator: disable_mlp_kernel_scattermoe
- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
- ParallelExperts: scaling=0.0 not treated as falsy
- single2scatter: non-aligned K/N dimensions
@@ -20,12 +20,12 @@ import pytest
import torch
# ============================================================================
# 1. KernelsArgs: disable_mlp_kernel validator
# 1. KernelsArgs: disable_mlp_kernel_scattermoe validator
# ============================================================================
class TestKernelsArgsValidator:
"""Test that disable_mlp_kernel sets both flags correctly.
"""Test that disable_mlp_kernel_scattermoe sets both flags correctly.
These tests call the validator classmethod directly on raw dicts,
since lora_mlp_kernel / mlp_kernel are not declared model fields.
@@ -40,7 +40,7 @@ class TestKernelsArgsValidator:
"use_scattermoe": True,
"lora_mlp_kernel": True,
}
result = KernelsArgs.disable_mlp_kernel(data)
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
assert result["lora_mlp_kernel"] is False
assert result["mlp_kernel"] is False
@@ -52,7 +52,7 @@ class TestKernelsArgsValidator:
"use_kernels": True,
"use_scattermoe": True,
}
result = KernelsArgs.disable_mlp_kernel(data)
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
assert result["mlp_kernel"] is False
# lora_mlp_kernel was not in data, should not be added
assert "lora_mlp_kernel" not in result
@@ -66,7 +66,7 @@ class TestKernelsArgsValidator:
"use_scattermoe": True,
"lora_mlp_kernel": False,
}
result = KernelsArgs.disable_mlp_kernel(data)
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
assert result["lora_mlp_kernel"] is False
def test_no_change_when_scattermoe_disabled(self):
@@ -78,7 +78,7 @@ class TestKernelsArgsValidator:
"use_scattermoe": False,
"lora_mlp_kernel": True,
}
result = KernelsArgs.disable_mlp_kernel(data)
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
assert result["lora_mlp_kernel"] is True

View File

@@ -1,428 +0,0 @@
"""Unit tests for the SonicMoE integration."""
from types import SimpleNamespace
import pytest
import torch
from axolotl.integrations.kernels.args import KernelsArgs
from axolotl.integrations.kernels.sonicmoe.routing import (
sigmoid_topk_routing,
softmax_topk_routing,
)
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
ConcatenatedToInterleaved,
InterleavedToConcatenated,
register_sonicmoe_weight_converter,
)
class TestKernelsArgs:
def test_mutual_exclusivity_raises(self):
with pytest.raises(ValueError, match="Cannot use both"):
KernelsArgs.model_validate({"use_scattermoe": True, "use_sonicmoe": True})
def test_sonicmoe_only(self):
result = KernelsArgs.model_validate({"use_sonicmoe": True})
assert result.use_sonicmoe is True
assert result.use_scattermoe is None
def test_scattermoe_only(self):
result = KernelsArgs.model_validate({"use_scattermoe": True})
assert result.use_scattermoe is True
assert result.use_sonicmoe is None
def test_neither_set(self):
result = KernelsArgs.model_validate({})
assert result.use_scattermoe is None
assert result.use_sonicmoe is None
def test_disables_mlp_kernel_when_sonicmoe(self):
data = {"use_sonicmoe": True, "lora_mlp_kernel": True}
result = KernelsArgs.disable_mlp_kernel(data)
assert result["lora_mlp_kernel"] is False
assert result["mlp_kernel"] is False
class TestConcatenatedToInterleaved:
@pytest.fixture
def sample_tensor(self):
"""Create a test tensor [E=2, 2*I=4, H=3] with distinct gate/up values."""
E, I, H = 2, 2, 3 # noqa: E741
gate = torch.arange(1, E * I * H + 1, dtype=torch.float32).reshape(E, I, H)
up = torch.arange(100, 100 + E * I * H, dtype=torch.float32).reshape(E, I, H)
return torch.cat([gate, up], dim=1)
def test_interleave_rows_alternate(self, sample_tensor):
op = ConcatenatedToInterleaved(dim=1)
result = op.convert(
{"test": sample_tensor},
source_patterns=["test"],
target_patterns=["test"],
)
interleaved = result["test"]
# For expert 0: even rows should be gate, odd rows should be up
E, two_I, H = sample_tensor.shape
I = two_I // 2 # noqa: E741
gate_orig = sample_tensor[:, :I, :]
up_orig = sample_tensor[:, I:, :]
assert torch.equal(interleaved[:, 0::2, :], gate_orig)
assert torch.equal(interleaved[:, 1::2, :], up_orig)
def test_interleave_handles_list_input(self, sample_tensor):
op = ConcatenatedToInterleaved(dim=1)
result = op.convert(
{"test": [sample_tensor]},
source_patterns=["test"],
target_patterns=["test"],
)
assert result["test"].shape == sample_tensor.shape
def test_reverse_op_type(self):
op = ConcatenatedToInterleaved(dim=1)
assert isinstance(op.reverse_op, InterleavedToConcatenated)
assert op.reverse_op.dim == 1
class TestInterleavedToConcatenated:
@pytest.fixture
def interleaved_tensor(self):
"""Create an interleaved tensor [E=2, 2*I=4, H=3]."""
E, I, H = 2, 2, 3 # noqa: E741
gate = torch.arange(1, E * I * H + 1, dtype=torch.float32).reshape(E, I, H)
up = torch.arange(100, 100 + E * I * H, dtype=torch.float32).reshape(E, I, H)
interleaved = torch.empty(E, 2 * I, H)
interleaved[:, 0::2, :] = gate
interleaved[:, 1::2, :] = up
return interleaved
def test_deinterleave_gate_up_separated(self, interleaved_tensor):
op = InterleavedToConcatenated(dim=1)
result = op.convert(
{"test": interleaved_tensor},
source_patterns=["test"],
target_patterns=["test"],
)
concatenated = result["test"]
E, two_I, H = concatenated.shape
I = two_I // 2 # noqa: E741
# First half should be gate (even rows from interleaved)
assert torch.equal(concatenated[:, :I, :], interleaved_tensor[:, 0::2, :])
# Second half should be up (odd rows from interleaved)
assert torch.equal(concatenated[:, I:, :], interleaved_tensor[:, 1::2, :])
def test_reverse_op_type(self):
op = InterleavedToConcatenated(dim=1)
assert isinstance(op.reverse_op, ConcatenatedToInterleaved)
assert op.reverse_op.dim == 1
class TestRoundTrip:
@pytest.fixture
def concat_tensor(self):
E, I, H = 4, 8, 16 # noqa: E741
gate = torch.randn(E, I, H)
up = torch.randn(E, I, H)
return torch.cat([gate, up], dim=1)
def test_interleave_then_deinterleave_is_identity(self, concat_tensor):
fwd = ConcatenatedToInterleaved(dim=1)
rev = InterleavedToConcatenated(dim=1)
interleaved = fwd.convert(
{"k": concat_tensor}, source_patterns=["k"], target_patterns=["k"]
)["k"]
recovered = rev.convert(
{"k": interleaved}, source_patterns=["k"], target_patterns=["k"]
)["k"]
assert torch.equal(concat_tensor, recovered)
def test_reverse_op_chain_is_identity(self, concat_tensor):
"""Verify that op.reverse_op produces an exact inverse."""
op = ConcatenatedToInterleaved(dim=1)
rev = op.reverse_op
interleaved = op.convert(
{"k": concat_tensor}, source_patterns=["k"], target_patterns=["k"]
)["k"]
recovered = rev.convert(
{"k": interleaved}, source_patterns=["k"], target_patterns=["k"]
)["k"]
assert torch.equal(concat_tensor, recovered)
def test_various_shapes(self):
"""Test with different expert counts and dimensions."""
fwd = ConcatenatedToInterleaved(dim=1)
rev = InterleavedToConcatenated(dim=1)
for E, I, H in [(1, 4, 8), (8, 16, 32), (16, 128, 256)]: # noqa: E741
concat = torch.randn(E, 2 * I, H)
interleaved = fwd.convert(
{"k": concat}, source_patterns=["k"], target_patterns=["k"]
)["k"]
recovered = rev.convert(
{"k": interleaved}, source_patterns=["k"], target_patterns=["k"]
)["k"]
assert torch.equal(concat, recovered), (
f"Failed for shape ({E}, {2 * I}, {H})"
)
class TestWeightConverterRegistration:
def test_register_appends_interleave_op(self):
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
register_sonicmoe_weight_converter("qwen3_moe")
modified = get_checkpoint_conversion_mapping("qwen3_moe")
# Find the gate_up_proj converter
gate_up_converter = None
for conv in modified:
if hasattr(conv, "operations") and any(
"gate_up_proj" in pat for pat in conv.target_patterns
):
gate_up_converter = conv
break
assert gate_up_converter is not None
assert isinstance(gate_up_converter.operations[-1], ConcatenatedToInterleaved)
def test_double_registration_is_idempotent(self):
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
register_sonicmoe_weight_converter("qwen3_moe")
register_sonicmoe_weight_converter("qwen3_moe")
modified = get_checkpoint_conversion_mapping("qwen3_moe")
for conv in modified:
if hasattr(conv, "operations") and any(
"gate_up_proj" in pat for pat in conv.target_patterns
):
interleave_count = sum(
isinstance(op, ConcatenatedToInterleaved) for op in conv.operations
)
assert interleave_count == 1, (
f"Expected 1 ConcatenatedToInterleaved op, got {interleave_count}"
)
break
def test_register_unsupported_model_type_warns(self):
# A model type with no conversion mapping should warn but not raise
register_sonicmoe_weight_converter("nonexistent_model_type_xyz")
def _make_qwen_moe_block(T=8, H=16, E=4, K=2):
"""Create a mock qwen-style MoE block for routing tests."""
gate = SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
num_experts=E,
norm_topk_prob=True,
)
return SimpleNamespace(gate=gate), T, H, E, K
def _make_glm_moe_block(T=8, H=16, E=16, K=4, n_group=2, topk_group=1):
"""Create a mock GLM5-style MoE block for routing tests."""
gate = SimpleNamespace(
weight=torch.randn(E, H),
e_score_correction_bias=torch.zeros(E),
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
n_routed_experts=E,
n_group=n_group,
topk_group=topk_group,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
return moe_block, T, H, E, K
def _make_minimax_m2_moe_block(T=8, H=16, E=16, K=4):
"""Create a mock minimax_m2-style MoE block for routing tests.
minimax_m2 uses sigmoid->topk WITHOUT group selection:
- e_score_correction_bias is on the moe_block (not on gate)
- No n_group / topk_group attributes
- Always normalizes (norm_topk_prob defaults to True)
- No routed_scaling_factor (defaults to 1.0)
"""
gate = SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
e_score_correction_bias=torch.zeros(E),
)
return moe_block, T, H, E, K
class TestSoftmaxTopkRouting:
def test_output_shapes(self):
moe_block, T, H, E, K = _make_qwen_moe_block()
hidden = torch.randn(T, H)
scores, token_idx, expert_idx, logits = softmax_topk_routing(hidden, moe_block)
assert scores.shape == (T * K,)
assert token_idx.shape == (T * K,)
assert expert_idx.shape == (T * K,)
assert logits.shape == (T, E)
def test_scores_are_float32(self):
moe_block, T, H, E, K = _make_qwen_moe_block()
hidden = torch.randn(T, H)
scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
assert scores.dtype == torch.float32
def test_token_indices_sorted_ascending(self):
moe_block, T, H, E, K = _make_qwen_moe_block()
hidden = torch.randn(T, H)
_, token_idx, _, _ = softmax_topk_routing(hidden, moe_block)
# Token indices must be sorted ascending (SonicMoE requirement)
diffs = token_idx[1:] - token_idx[:-1]
assert (diffs >= 0).all()
def test_expert_indices_in_range(self):
moe_block, T, H, E, K = _make_qwen_moe_block()
hidden = torch.randn(T, H)
_, _, expert_idx, _ = softmax_topk_routing(hidden, moe_block)
assert (expert_idx >= 0).all()
assert (expert_idx < E).all()
def test_renormalized_scores_sum_to_one(self):
moe_block, T, H, E, K = _make_qwen_moe_block()
hidden = torch.randn(T, H)
scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
per_token_sums = scores.reshape(T, K).sum(dim=-1)
assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5)
class TestSigmoidTopkRouting:
def test_output_shapes(self):
moe_block, T, H, E, K = _make_glm_moe_block()
hidden = torch.randn(T, H)
scores, token_idx, expert_idx, logits = sigmoid_topk_routing(hidden, moe_block)
assert scores.shape == (T * K,)
assert token_idx.shape == (T * K,)
assert expert_idx.shape == (T * K,)
assert logits.shape == (T, E)
def test_scores_are_float32(self):
moe_block, T, H, E, K = _make_glm_moe_block()
hidden = torch.randn(T, H)
scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
assert scores.dtype == torch.float32
def test_token_indices_sorted_ascending(self):
moe_block, T, H, E, K = _make_glm_moe_block()
hidden = torch.randn(T, H)
_, token_idx, _, _ = sigmoid_topk_routing(hidden, moe_block)
diffs = token_idx[1:] - token_idx[:-1]
assert (diffs >= 0).all()
def test_expert_indices_in_range(self):
moe_block, T, H, E, K = _make_glm_moe_block()
hidden = torch.randn(T, H)
_, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block)
assert (expert_idx >= 0).all()
assert (expert_idx < E).all()
def test_scores_are_nonnegative(self):
"""Sigmoid outputs are in [0, 1], so scores should be non-negative."""
moe_block, T, H, E, K = _make_glm_moe_block()
hidden = torch.randn(T, H)
scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
assert (scores >= 0).all()
def test_scaling_factor_applied(self):
moe_block, T, H, E, K = _make_glm_moe_block()
hidden = torch.randn(T, H)
# Get scores with scaling_factor=1.0
scores_1x, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
# Get scores with scaling_factor=2.0
moe_block.routed_scaling_factor = 2.0
scores_2x, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
assert torch.allclose(scores_2x, scores_1x * 2.0, atol=1e-5)
def test_group_selection_restricts_experts(self):
"""With n_group=4 and topk_group=1, only 1/4 of experts should be selectable."""
moe_block, T, H, E, K = _make_glm_moe_block(E=16, K=2, n_group=4, topk_group=1)
hidden = torch.randn(T, H)
_, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block)
# Each token's experts should all fall within a single group (size E//n_group=4)
expert_idx_2d = expert_idx.reshape(T, K)
for t in range(T):
experts = expert_idx_2d[t]
groups = experts // (E // moe_block.n_group)
# All selected experts should be from the same group
assert (groups == groups[0]).all()
class TestMiniMaxM2SigmoidRouting:
"""Tests for minimax_m2 routing: sigmoid->topk without group selection."""
def test_output_shapes(self):
"""Validates getattr defaults work: n_group=1, E from gate.weight.shape[0]."""
moe_block, T, H, E, K = _make_minimax_m2_moe_block()
hidden = torch.randn(T, H)
scores, token_idx, expert_idx, logits = sigmoid_topk_routing(hidden, moe_block)
assert scores.shape == (T * K,)
assert token_idx.shape == (T * K,)
assert expert_idx.shape == (T * K,)
assert logits.shape == (T, E)
def test_bias_on_block_not_gate(self):
"""Verify that e_score_correction_bias on the block (not gate) is used."""
T, H, E, K = 8, 16, 8, 2
gate = SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
)
# Large positive bias on expert 0 should make it selected more often
bias = torch.zeros(E)
bias[0] = 100.0
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
e_score_correction_bias=bias,
)
hidden = torch.randn(T, H)
_, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block)
# Expert 0 should appear for every token due to the large bias
expert_idx_2d = expert_idx.reshape(T, K)
for t in range(T):
assert 0 in expert_idx_2d[t]
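
For readers skimming these tests, here is a minimal sketch of the two routing behaviours the assertions above encode. It is an assumption reconstructed from the test expectations, not the SonicMoE implementation: the function names, the max-based group reduction, and the exact ordering of bias, normalisation, and scaling are all illustrative, and the real functions additionally flatten the outputs to shape (T*K,) with token indices sorted ascending.

```python
import torch

def softmax_topk_routing_sketch(hidden, weight, top_k, norm_topk_prob=True):
    # Router logits in float32, softmax over experts, then top-k selection.
    logits = hidden.float() @ weight.float().t()              # (T, E)
    probs = torch.softmax(logits, dim=-1)
    scores, expert_idx = probs.topk(top_k, dim=-1)            # (T, K)
    if norm_topk_prob:
        scores = scores / scores.sum(dim=-1, keepdim=True)    # rows sum to 1
    return scores, expert_idx, logits

def sigmoid_topk_routing_sketch(
    hidden, weight, bias, top_k,
    n_group=1, topk_group=1, norm_topk_prob=True, routed_scaling_factor=1.0,
):
    # Sigmoid gate plus a per-expert correction bias; optional group selection
    # masks out every expert outside the best `topk_group` groups.
    logits = hidden.float() @ weight.float().t()              # (T, E)
    probs = torch.sigmoid(logits)
    biased = probs + bias.float()
    if n_group > 1:
        tokens, experts = biased.shape
        grouped = biased.view(tokens, n_group, experts // n_group)
        group_scores = grouped.max(dim=-1).values             # (T, n_group)
        keep = group_scores.topk(topk_group, dim=-1).indices
        mask = torch.zeros_like(group_scores).scatter_(1, keep, 1.0)
        biased = (grouped * mask.unsqueeze(-1)).reshape(tokens, experts)
    _, expert_idx = biased.topk(top_k, dim=-1)                 # select on biased scores
    scores = probs.gather(-1, expert_idx)                      # weight with raw sigmoid probs
    if norm_topk_prob:
        scores = scores / scores.sum(dim=-1, keepdim=True)
    return scores * routed_scaling_factor, expert_idx, logits
```

Under this reading, selection happens on the bias-corrected scores while the returned weights come from the unbiased sigmoid probabilities, which is what test_bias_on_block_not_gate and test_scaling_factor_applied exercise.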

View File

@@ -1,158 +0,0 @@
"""
Gradient correctness tests for SonicMoE routing functions (CPU-only).
Uses torch.autograd.gradcheck with float32 inputs to match the production
code path where routing happens in float32.
"""
import torch
from axolotl.integrations.kernels.sonicmoe.routing import (
sigmoid_topk_routing,
softmax_topk_routing,
)
_GC_EPS = 1e-3
_GC_ATOL = 1e-3
_GC_RTOL = 1e-3
def _make_softmax_moe_block(weight):
gate = torch.nn.Module()
gate.weight = weight
gate.top_k = 2
gate.norm_topk_prob = True
moe_block = torch.nn.Module()
moe_block.gate = gate
return moe_block
def _make_sigmoid_moe_block(weight, bias):
gate = torch.nn.Module()
gate.weight = weight
gate.e_score_correction_bias = bias
moe_block = torch.nn.Module()
moe_block.gate = gate
moe_block.top_k = 2
moe_block.n_routed_experts = weight.shape[0]
moe_block.n_group = 1
moe_block.norm_topk_prob = True
moe_block.routed_scaling_factor = 1.0
return moe_block
class TestSoftmaxTopkRoutingGradcheck:
"""Numerical gradient verification for softmax_topk_routing."""
def test_gradcheck_wrt_gate_weight(self):
T, H, E = 4, 8, 4
hidden = torch.randn(T, H, dtype=torch.float32)
def fn(weight):
moe_block = _make_softmax_moe_block(weight)
scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
return scores
weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
torch.autograd.gradcheck(
fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
)
def test_gradcheck_wrt_hidden_states(self):
T, H, E = 4, 8, 4
weight = torch.randn(E, H, dtype=torch.float32)
moe_block = _make_softmax_moe_block(weight)
def fn(hidden):
scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
return scores
hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True)
torch.autograd.gradcheck(
fn, (hidden,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
)
def test_gradcheck_wrt_router_logits(self):
T, H, E = 4, 8, 4
hidden = torch.randn(T, H, dtype=torch.float32)
def fn(weight):
moe_block = _make_softmax_moe_block(weight)
_, _, _, router_logits = softmax_topk_routing(hidden, moe_block)
return router_logits
weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
torch.autograd.gradcheck(
fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
)
def test_no_norm_variant(self):
T, H, E = 4, 8, 4
hidden = torch.randn(T, H, dtype=torch.float32)
def fn(weight):
moe_block = _make_softmax_moe_block(weight)
moe_block.gate.norm_topk_prob = False
scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
return scores
weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
torch.autograd.gradcheck(
fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
)
class TestSigmoidTopkRoutingGradcheck:
"""Numerical gradient verification for sigmoid_topk_routing."""
def test_gradcheck_wrt_gate_weight(self):
T, H, E = 4, 8, 4
hidden = torch.randn(T, H, dtype=torch.float32)
bias = torch.zeros(E, dtype=torch.float32)
def fn(weight):
moe_block = _make_sigmoid_moe_block(weight, bias)
scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
return scores
weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
torch.autograd.gradcheck(
fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
)
def test_gradcheck_wrt_hidden_states(self):
T, H, E = 4, 8, 4
weight = torch.randn(E, H, dtype=torch.float32)
bias = torch.zeros(E, dtype=torch.float32)
moe_block = _make_sigmoid_moe_block(weight, bias)
def fn(hidden):
scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
return scores
hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True)
torch.autograd.gradcheck(
fn, (hidden,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
)
def test_gradcheck_wrt_bias(self):
T, H, E = 4, 8, 4
hidden = torch.randn(T, H, dtype=torch.float32)
weight = torch.randn(E, H, dtype=torch.float32)
def fn(bias):
moe_block = _make_sigmoid_moe_block(weight, bias)
scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
return scores
bias = torch.zeros(E, dtype=torch.float32, requires_grad=True)
torch.autograd.gradcheck(fn, (bias,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL)

View File

@@ -18,7 +18,6 @@ Unit tests for SwanLab Integration Plugin.
Tests conflict detection, configuration validation, and multi-logger warnings.
"""
import importlib.util
import logging
import os
import time
@@ -26,11 +25,12 @@ from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from transformers.utils.import_utils import _is_package_available
from axolotl.integrations.swanlab.args import SwanLabConfig
from axolotl.integrations.swanlab.plugins import SwanLabPlugin
SWANLAB_INSTALLED = importlib.util.find_spec("swanlab") is not None
SWANLAB_INSTALLED = _is_package_available("swanlab")
@pytest.mark.skipif(not SWANLAB_INSTALLED, reason="swanlab package not installed")

View File

@@ -52,8 +52,8 @@ def mock_torch():
mock_torch.cuda.device_count.return_value = 2
# Mock memory allocated per device (1GB for device 0, 2GB for device 1)
mock_torch.cuda.memory_allocated.side_effect = lambda device: (
(device + 1) * 1024 * 1024 * 1024
mock_torch.cuda.memory_allocated.side_effect = (
lambda device: (device + 1) * 1024 * 1024 * 1024
)
yield mock_torch
@@ -292,8 +292,8 @@ class TestRuntimeMetricsTracker:
mock_memory_info = mock_process.memory_info.return_value
mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024 # 0.5GB
mock_torch.cuda.memory_allocated.side_effect = lambda device: (
(device + 0.5) * 1024 * 1024 * 1024
mock_torch.cuda.memory_allocated.side_effect = (
lambda device: (device + 0.5) * 1024 * 1024 * 1024
)
# Update memory metrics again
@@ -307,8 +307,8 @@ class TestRuntimeMetricsTracker:
# Change mocked memory values to be higher
mock_memory_info.rss = 2 * 1024 * 1024 * 1024 # 2GB
mock_torch.cuda.memory_allocated.side_effect = lambda device: (
(device + 2) * 1024 * 1024 * 1024
mock_torch.cuda.memory_allocated.side_effect = (
lambda device: (device + 2) * 1024 * 1024 * 1024
)
# Update memory metrics again

View File

@@ -1,56 +0,0 @@
"""Tests for batch_size calculation with context parallelism."""
import sys
import types
import pytest
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="cp_base_cfg")
def fixture_cp_base_cfg(min_base_cfg):
return (
DictDefault(
micro_batch_size=2,
gradient_accumulation_steps=4,
sequence_len=2048,
num_epochs=1,
flash_attention=True,
)
| min_base_cfg
)
class TestContextParallelBatchSize:
"""Verify batch_size scales by effective dp world_size when using context parallelism."""
@pytest.mark.parametrize(
"world_size, context_parallel_size, expected_batch_size",
[
(4, 1, 32), # no CP: 2*4*4 = 32
(4, 2, 16), # CP=2: 2*4*(4//2) = 16
(4, 4, 8), # CP=4: 2*4*(4//4) = 8
(2, 2, 8), # CP=ws: 2*4*(2//2) = 8 (no scaling)
],
)
def test_batch_size_with_context_parallelism(
self,
cp_base_cfg,
monkeypatch,
world_size,
context_parallel_size,
expected_batch_size,
):
monkeypatch.setenv("WORLD_SIZE", str(world_size))
# Mock ring_flash_attn since it's not installable on CPU,
# but required by schema validation when context_parallel_size > 1.
if "ring_flash_attn" not in sys.modules:
monkeypatch.setitem(
sys.modules, "ring_flash_attn", types.ModuleType("ring_flash_attn")
)
cp_base_cfg["context_parallel_size"] = context_parallel_size
cfg = validate_config(cp_base_cfg)
normalize_config(cfg)
assert cfg.batch_size == expected_batch_size
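
As a quick gloss on the parametrized expectations above: batch_size appears to be micro_batch_size times gradient_accumulation_steps times the number of data-parallel ranks, where context parallelism (and, in the next file, tensor parallelism) shrinks the data-parallel group. A worked example under that assumption:

```python
micro_batch_size = 2
gradient_accumulation_steps = 4
world_size = 4
context_parallel_size = 2          # tensor_parallel_size plays the same role below

dp_world_size = world_size // context_parallel_size   # ranks left for data parallelism
batch_size = micro_batch_size * gradient_accumulation_steps * dp_world_size
assert batch_size == 16            # matches the (4, 2, 16) case in the parametrize table
```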

View File

@@ -1,55 +0,0 @@
"""Tests for batch_size calculation with tensor parallelism."""
from unittest.mock import patch
import addict
import pytest
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="tp_base_cfg")
def fixture_tp_base_cfg(min_base_cfg):
return (
DictDefault(
micro_batch_size=2,
gradient_accumulation_steps=4,
sequence_len=2048,
num_epochs=1,
)
| min_base_cfg
)
class TestTensorParallelBatchSize:
"""Verify batch_size scales by effective dp world_size when using tensor parallelism."""
@pytest.mark.parametrize(
"world_size, tensor_parallel_size, expected_batch_size",
[
(4, 1, 32), # no TP: 2*4*4 = 32
(4, 2, 16), # TP=2: 2*4*(4//2) = 16
(4, 4, 8), # TP=4: 2*4*(4//4) = 8
(2, 2, 8), # TP=ws: 2*4*(2//2) = 8 (no scaling)
],
)
def test_batch_size_with_tensor_parallelism(
self,
tp_base_cfg,
monkeypatch,
world_size,
tensor_parallel_size,
expected_batch_size,
):
monkeypatch.setenv("WORLD_SIZE", str(world_size))
tp_base_cfg["tensor_parallel_size"] = tensor_parallel_size
cfg = validate_config(tp_base_cfg)
# Mock load_model_config to avoid downloading the model and to bypass
# the tie_word_embeddings validation that blocks TP > 1.
with patch(
"axolotl.utils.config.load_model_config",
return_value=addict.Dict({"model_type": "llama"}),
):
normalize_config(cfg)
assert cfg.batch_size == expected_batch_size

View File

@@ -84,8 +84,7 @@ class TestTokenizers:
}
)
tokenizer = load_tokenizer(cfg)
assert "LlamaTokenizer" in tokenizer.__class__.__name__
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1792]
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
assert len(tokenizer) == 32001
# ensure reloading the tokenizer again from cfg results in same vocab length

View File

@@ -1,156 +0,0 @@
"""Tests for MoE expert quantization config validation and PEFT patch idempotency."""
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture()
def gpu_caps():
return {"compute_capability": "sm_89", "bf16": True, "n_gpu": 1, "n_node": 1}
@pytest.fixture()
def env_caps():
return {"torch_version": "2.7.0"}
class TestQuantizeMoeExpertsValidation:
"""Test suite for quantize_moe_experts config validator."""
def test_requires_adapter(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts without adapter should fail."""
cfg = (
DictDefault(
quantize_moe_experts=True,
)
| min_base_cfg
)
with pytest.raises(ValueError, match="requires adapter"):
validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
def test_requires_quantization(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts without load_in_4bit/8bit should fail."""
cfg = (
DictDefault(
quantize_moe_experts=True,
adapter="lora",
)
| min_base_cfg
)
with pytest.raises(ValueError, match="requires load_in_4bit or load_in_8bit"):
validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
def test_valid_qlora_4bit(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts with qlora + 4bit should pass."""
cfg = (
DictDefault(
quantize_moe_experts=True,
adapter="qlora",
load_in_4bit=True,
)
| min_base_cfg
)
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
assert result["quantize_moe_experts"] is True
def test_valid_lora_8bit(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts with lora + 8bit should pass."""
cfg = (
DictDefault(
quantize_moe_experts=True,
adapter="lora",
load_in_8bit=True,
)
| min_base_cfg
)
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
assert result["quantize_moe_experts"] is True
def test_false_skips_validation(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts=false should not check adapter/quantization."""
cfg = (
DictDefault(
quantize_moe_experts=False,
)
| min_base_cfg
)
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
assert result["quantize_moe_experts"] is False
def test_rejects_lora_target_linear(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts with lora_target_linear should fail."""
cfg = (
DictDefault(
quantize_moe_experts=True,
adapter="qlora",
load_in_4bit=True,
lora_target_linear=True,
)
| min_base_cfg
)
with pytest.raises(ValueError, match="lora_target_linear is not compatible"):
validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
def test_default_is_false(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts should default to false."""
cfg = DictDefault({}) | min_base_cfg
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
assert result["quantize_moe_experts"] is False
class TestLoraTargetParametersDropout:
"""Test that lora_dropout must be 0 when lora_target_parameters is set."""
def test_rejects_nonzero_dropout(self, min_base_cfg):
"""lora_dropout > 0 with lora_target_parameters should fail."""
cfg = (
DictDefault(
adapter="lora",
lora_target_parameters=["mlp.experts.gate_up_proj"],
lora_dropout=0.1,
load_in_8bit=True,
)
| min_base_cfg
)
with pytest.raises(ValueError, match="lora_dropout must be 0"):
validate_config(cfg)
def test_zero_dropout_passes(self, min_base_cfg):
"""lora_dropout=0 with lora_target_parameters should pass."""
cfg = (
DictDefault(
adapter="lora",
lora_target_parameters=["mlp.experts.gate_up_proj"],
lora_dropout=0.0,
load_in_8bit=True,
)
| min_base_cfg
)
result = validate_config(cfg)
assert result["lora_dropout"] == 0.0
class TestPeftPatchIdempotency:
"""Test that patch_peft_target_parameters_matching is idempotent."""
def test_double_call_does_not_stack_wrappers(self):
"""Calling patch twice should not double-wrap _inject_parameters."""
from peft.tuners.tuners_utils import BaseTuner
from axolotl.monkeypatch.moe_quant import (
patch_peft_target_parameters_matching,
)
original = BaseTuner._inject_parameters
try:
patch_peft_target_parameters_matching()
first_patched = BaseTuner._inject_parameters
patch_peft_target_parameters_matching()
second_patched = BaseTuner._inject_parameters
# Should be same function, not double-wrapped
assert first_patched is second_patched
finally:
BaseTuner._inject_parameters = original
patch_peft_target_parameters_matching._axolotl_patched = False
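
The idempotency test and its finally-block reset of `_axolotl_patched` suggest a sentinel-flag pattern. Below is a generic sketch of that pattern, under the assumption that this is roughly how the monkeypatch guards itself; the real patch_peft_target_parameters_matching may differ in detail.

```python
import functools

def make_idempotent_patch(target_cls, method_name):
    """Return a patcher that wraps target_cls.method_name at most once."""

    def apply_patch():
        if getattr(apply_patch, "_axolotl_patched", False):
            return  # already applied: calling again must not stack wrappers
        original = getattr(target_cls, method_name)

        @functools.wraps(original)
        def wrapper(*args, **kwargs):
            # extra parameter-matching logic would be inserted here
            return original(*args, **kwargs)

        setattr(target_cls, method_name, wrapper)
        apply_patch._axolotl_patched = True

    return apply_patch
```

Calling the returned apply_patch twice leaves a single wrapper installed, which is exactly what test_double_call_does_not_stack_wrappers asserts; the test's cleanup then restores the original method and clears the flag.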