feat: add sonicmoe fused lora support (#3519)

* feat: add sonicmoe fused lora support

* fix: forgot to add file

* feat: add test

* feat: add lora support for other routes

* fix: add int8 lora support

* fix: add qwen35_moe interleave support

* fix: qwen3_5_moe loss

* chore: lint

* address some pr comments

* fix test imports

* add support matrix for moe kernels [skip ci]

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Author: NanoCode012
Date: 2026-04-02 19:53:48 +07:00
Committed by: GitHub
parent 16e32232fb
commit 842fa039dd
16 changed files with 1249 additions and 126 deletions

View File

@@ -52,26 +52,91 @@ The `KernelsPlugin` runs before model loading and:
### ScatterMoE
1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels).
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation via the HF `kernels` library.
### SonicMoE
1. Resolves the model's MoE block class(es) from `constants.py`.
2. Patches the forward method with SonicMoE's optimized CUTLASS kernels and registers a weight converter for the interleaved gate/up projection format.
3. Supports pluggable routing strategies (see routing table below).
Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution.
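As a rough illustration, that utility maps a HF `model_type` string to the MoE block class name(s) to patch. The sketch below is illustrative only and does not reflect the actual contents or signature in `constants.py`:

```python
# Illustrative sketch only; the real mapping lives in constants.py and differs in detail.
MOE_BLOCK_CLASSES = {
    "qwen3_moe": ("Qwen3MoeSparseMoeBlock",),
    "mixtral": ("MixtralSparseMoeBlock",),
    "deepseek_v3": ("DeepseekV3MoE",),
}

def resolve_moe_block_classes(model_type: str) -> tuple[str, ...]:
    """Return the MoE block class name(s) to patch for a given model type."""
    try:
        return MOE_BLOCK_CLASSES[model_type]
    except KeyError as exc:
        raise ValueError(f"No MoE block registered for model_type={model_type!r}") from exc
```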
## Model Support Matrix
See `constants.py` for the full list of supported model types (Qwen2-MoE, Qwen3-MoE, OLMoE, Mixtral, DeepSeek-V3, GLM-MoE, MiniMax, etc.).
All models use the **SwiGLU** activation (`act_fn(gate) * up`). Neither kernel currently supports non-SwiGLU MoE architectures.
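For reference, a plain (unfused) SwiGLU expert looks roughly like the sketch below; `act_fn` is SiLU in these models, and the fused kernels compute the same math without materializing the intermediates:

```python
import torch
import torch.nn.functional as F

def swiglu_expert(x: torch.Tensor, gate_up_w: torch.Tensor, down_w: torch.Tensor) -> torch.Tensor:
    """Reference SwiGLU expert MLP (sketch, not the kernel path).

    x:         [T, H] tokens routed to this expert
    gate_up_w: [2*I, H] fused gate/up projection weight
    down_w:    [H, I] down projection weight
    """
    gate, up = F.linear(x, gate_up_w).chunk(2, dim=-1)  # [T, I] each
    return F.linear(F.silu(gate) * up, down_w)          # [T, H]
```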
### Routing strategies
| Routing Strategy | Description | ScatterMoE | SonicMoE |
|---|---|:---:|:---:|
| softmax → topk | Softmax over experts, select top-K, optional renormalization | Yes | Yes |
| softmax → group selection → topk | Softmax, select top groups (sum of top-2 per group), topk from selected groups, renorm + scaling | No | Yes |
| sigmoid → topk (with groups) | Sigmoid + bias correction, group-based masking, topk from masked scores, weights from original sigmoid | Yes | Yes |
| sigmoid → topk (no groups) | Sigmoid + bias correction, straight topk (n_group=1) | Yes | Yes |
| softmax → bias correction → topk | Softmax, bias via `gate.moe_statics`, topk, gather from original probs, clamp-based renorm | No | Yes |
| softmax → group_limited_greedy | Softmax, group selection (max per group), topk, scale only (no renorm) | No | Yes |
| softmax → topk via gate.wg | Softmax, gate weight at `gate.wg.weight` (not `gate.weight`), always renormalize | No | Yes |
| fused topk → softmax | Routing + expert computation fused in a single kernel | No | Planned |
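A simplified sketch of the first (and most common) strategy, softmax → topk with optional renormalization; the per-strategy implementations in `routing.py` in the diff below are the authoritative versions:

```python
import torch
import torch.nn.functional as F

def softmax_topk_routing_sketch(hidden: torch.Tensor, gate_weight: torch.Tensor,
                                top_k: int, renormalize: bool = True):
    """hidden: [T, H], gate_weight: [E, H] -> (routing_weights [T, K], expert_indices [T, K])."""
    router_logits = F.linear(hidden, gate_weight)                        # [T, E]
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]
    top_values, top_indices = torch.topk(router_probs, top_k, dim=-1)     # [T, K] each
    if renormalize:
        top_values = top_values / top_values.sum(dim=-1, keepdim=True)
    return top_values, top_indices
```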
### Per-model support
| Model Type | Architecture | Routing | ScatterMoE | SonicMoE |
|---|---|---|:---:|:---:|
| `qwen2_moe` | Qwen2-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_moe` | Qwen3-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_5_moe` | Qwen3.5-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_5_moe_text` | Qwen3.5-MoE (VLM text) | softmax → topk | **Yes** | **Yes** |
| `qwen3_next` | Qwen3-Next | softmax → topk | **Yes** | **Yes** |
| `qwen3_vl_moe` | Qwen3-VL-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_omni_moe` | Qwen3-Omni (Thinker + Talker) | softmax → topk | **Yes** | **Yes** |
| `olmoe` | OLMoE | softmax → topk | **Yes** | **Yes** |
| `mixtral` | Mixtral | softmax → topk | **Yes** | **Yes** |
| `minimax` | MiniMax | softmax → topk | **Yes** | **Yes** |
| `mistral4` | Mistral 4 | softmax → group → topk | No | **Yes** |
| `glm_moe_dsa` | GLM-MoE DSA (GLM 5) | sigmoid → topk (groups) | **Yes** | **Yes** |
| `deepseek_v3` | DeepSeek-V3 | sigmoid → topk (groups) | **Yes** | **Yes** |
| `glm4_moe` | GLM4-MoE | sigmoid → topk (groups) | **Yes** | **Yes** |
| `glm4_moe_lite` | GLM4-MoE Lite (GLM 4.7 Flash) | sigmoid → topk (groups) | **Yes**\* | **Yes** |
| `glm4v_moe` | GLM4v-MoE | sigmoid → topk (groups) | **Yes** | **Yes** |
| `minimax_m2` | MiniMax M2 | sigmoid → topk (no groups) | **Yes** | **Yes** |
| `ernie4_5_moe` | ERNIE 4.5 MoE | softmax → bias → topk | No | **Yes** |
| `deepseek_v2` | DeepSeek-V2 | softmax → group_limited_greedy | No | **Yes** |
| `hunyuan_v1_moe` | HunYuan V1 MoE | softmax → topk (gate.wg) | No | **Yes** |
| `gpt_oss` | GPT-OSS | fused topk → softmax | No | Planned |
\* `glm4_moe_lite` with ScatterMoE may have issues — see Limitations.
### Feature comparison
| Feature | ScatterMoE | SonicMoE |
|---|:---:|:---:|
| Kernel backend | Triton | CUTLASS |
| GPU requirement | Any CUDA | Hopper (H100/H200) or Blackwell (B200+) |
| LoRA approach | Fused in Triton kernel | Runtime materialization + custom autograd |
| LoRA overhead | Lower (fused computation) | Higher (per-forward materialization) |
| Gate/router LoRA | Yes | Yes |
| Expert LoRA | Yes (fused) | Yes (materialized) |
| Shared expert LoRA | Yes (standard PEFT) | Yes (standard PEFT) |
| Selective expert dequantization | Yes (~97% memory savings) | No |
| Weight format | Transposed `[E, hidden, 2*inter]` | Interleaved gate/up `[2*I, H, E]` |
| torch.compile routing | No | Yes (optional) |
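The weight-format row refers to how the fused gate/up projection is laid out: "concatenated" stacks all gate rows followed by all up rows, while "interleaved" alternates them per output channel. A sketch of the conversion is shown below; the exact axis conventions used by SonicMoE are an assumption here, and the `ConcatenatedToInterleaved` op in the weight-converter diff below is authoritative:

```python
import torch

def concatenated_to_interleaved_sketch(gate_up: torch.Tensor) -> torch.Tensor:
    """Reorder a fused [gate; up] weight of shape [E, 2*I, H] so rows alternate
    g0, u0, g1, u1, ... along dim 1 (the shape is unchanged)."""
    num_experts, two_i, hidden = gate_up.shape
    gate, up = gate_up.chunk(2, dim=1)                     # [E, I, H] each
    return torch.stack((gate, up), dim=2).reshape(num_experts, two_i, hidden)
```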
## Shared Expert Handling
Both kernels handle shared experts identically. Shared expert attribute names are detected in order of priority:
1. `shared_expert` (Qwen2-MoE)
2. `shared_experts` (GLM-MoE, DeepSeek-V3)
3. `shared_mlp` (HunYuan V1 MoE)
If `shared_expert_gate` exists, sigmoid gating is applied to the shared expert contribution before adding it to the routed output. PEFT wraps shared expert linear layers with standard LoRA — no special handling is needed.
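In pseudocode, the shared-expert step described above amounts to something like the following (a sketch; the actual helper in the patch may differ in details):

```python
import torch

def add_shared_expert(moe_block, hidden_flat: torch.Tensor, routed_out: torch.Tensor) -> torch.Tensor:
    """Add the shared-expert contribution, with optional sigmoid gating."""
    shared = None
    for name in ("shared_expert", "shared_experts", "shared_mlp"):
        shared = getattr(moe_block, name, None)
        if shared is not None:
            break
    if shared is None:
        return routed_out
    shared_out = shared(hidden_flat)
    shared_gate = getattr(moe_block, "shared_expert_gate", None)
    if shared_gate is not None:
        shared_out = torch.sigmoid(shared_gate(hidden_flat)) * shared_out
    return routed_out + shared_out
```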
## Limitations
- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).
- **Non-SwiGLU activations**: Neither kernel supports MoE architectures with non-SwiGLU expert activations (e.g., GPT-OSS uses a custom GLU variant).
- **GPT-OSS**: Deferred — requires transposed weight layout `[E, H, 2*I]`, expert biases, and custom GLU activation. A dedicated forward path is needed.
- **FSDP + fused gate LoRA (SonicMoE)**: The fused softmax→topk forward path materializes a local router weight tensor when a LoRA delta is present, to avoid mixing DTensor and plain Tensor under FSDP.
## Note on MegaBlocks

View File

@@ -45,6 +45,28 @@ class KernelsArgs(BaseModel):
return data
@model_validator(mode="before")
@classmethod
def warn_sonicmoe_lora_overhead(cls, data):
if data.get("use_sonicmoe") is True and data.get("adapter") in (
"lora",
"qlora",
):
lora_target = data.get("lora_target_modules") or []
lora_linear = data.get("lora_target_linear_modules") or []
targets = (
lora_target if isinstance(lora_target, list) else [lora_target]
) + (lora_linear if isinstance(lora_linear, list) else [lora_linear])
expert_keywords = ("gate_up_proj", "down_proj", "experts")
if any(kw in t for t in targets for kw in expert_keywords):
LOG.info(
"SonicMoE + LoRA on expert modules uses runtime weight materialization "
"(W_eff = W + scaling*B@A per forward). This has slightly higher overhead "
"than ScatterMoE's fused Triton LoRA kernels but works with any CUTLASS kernel."
)
return data
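# Example (illustrative; not part of this validator): a config such as
#   {"use_sonicmoe": True, "adapter": "lora",
#    "lora_target_modules": ["experts.gate_up_proj", "experts.down_proj"]}
# matches the expert keywords above and would emit this notice.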
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel(cls, data):

View File

@@ -49,6 +49,11 @@ class ParallelLinear(torch.autograd.Function):
grouped_in: bool = False,
grouped_out: bool = False,
):
# Cast weights to match input dtype (e.g. 8-bit LoRA)
if expert_weights.dtype != x.dtype:
expert_weights = expert_weights.to(x.dtype)
if expert_biases is not None and expert_biases.dtype != x.dtype:
expert_biases = expert_biases.to(x.dtype)
with torch.device(x.device):
output = kernels.ops.scatter2scatter(
X=x,

View File

@@ -65,6 +65,11 @@ class ScatterMoELoRA(torch.autograd.Function):
use_fused_dX: bool = False,
use_fused_gather: bool = False,
):
# Cast weights to match input dtype (e.g. 8-bit LoRA)
if expert_weights.dtype != x.dtype:
expert_weights = expert_weights.to(x.dtype)
if expert_biases is not None and expert_biases.dtype != x.dtype:
expert_biases = expert_biases.to(x.dtype)
with torch.device(x.device):
# Fused forward: Y = X @ W + scaling * (X @ A^T) @ B^T
output = scatter2scatter_lora(

View File

@@ -0,0 +1,220 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
SonicMoE LoRA support via runtime weight materialization.
SonicMoE uses opaque CUTLASS kernels that cannot be modified to fuse LoRA.
Instead, we materialize the effective weight W_eff = W + scaling * (B @ A)
before each CUTLASS call, and use a custom autograd.Function to route
gradients back to the LoRA A and B parameters.
PEFT unwrapping utilities are also provided to handle the ParamWrapper
chain that PEFT creates when targeting expert parameters.
"""
from typing import Optional
import torch
# =============================================================================
# PEFT unwrapping utilities
# =============================================================================
def has_lora(module) -> bool:
"""Check if a module is wrapped by PEFT with LoRA."""
return hasattr(module, "base_layer") and hasattr(module, "lora_A")
def get_lora_params_from_wrapper(module) -> tuple:
"""Extract LoRA parameters from a PEFT ParamWrapper.
Returns:
(lora_A, lora_B, scaling) if LoRA is active, else (None, None, None)
"""
if not hasattr(module, "lora_A") or not hasattr(module, "lora_B"):
return None, None, None
active_adapters = getattr(module, "active_adapters", ["default"])
if not active_adapters:
return None, None, None
adapter_name = active_adapters[0]
lora_A_dict = getattr(module, "lora_A", {})
lora_B_dict = getattr(module, "lora_B", {})
scaling_dict = getattr(module, "scaling", {})
if (
adapter_name not in lora_A_dict
or adapter_name not in lora_B_dict
or adapter_name not in scaling_dict
):
return None, None, None
lora_A = lora_A_dict[adapter_name].weight
lora_B = lora_B_dict[adapter_name].weight
scaling = scaling_dict[adapter_name]
return lora_A, lora_B, scaling
def unwrap_gate_lora(gate_module):
"""Unwrap PEFT ParamWrapper on the router gate.
When PEFT targets ``gate.weight``, ``self.gate`` becomes::
ParamWrapper(weight)
-> base_layer: Router (the real module)
Returns:
(base_gate, gate_weight, gate_lora_delta_or_None)
``base_gate`` is the original router module (with ``.top_k``, etc.).
``gate_weight`` is the base router weight tensor.
``gate_lora_delta_or_None`` is the LoRA delta if active, else None.
Kept separate to avoid mixing DTensor + Tensor under FSDP.
"""
if has_lora(gate_module):
base_gate = gate_module.base_layer
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gate_module)
if lora_A is not None:
delta = scaling * (lora_B @ lora_A)
return base_gate, base_gate.weight, delta
return base_gate, base_gate.weight, None
return gate_module, gate_module.weight, None
def unwrap_experts_lora(experts_module):
"""Walk a PEFT ParamWrapper chain on ``self.experts``.
When PEFT targets ``experts.gate_up_proj`` and ``experts.down_proj``
via ``target_parameters``, ``self.experts`` becomes::
ParamWrapper(down_proj)
-> base_layer: ParamWrapper(gate_up_proj)
-> base_layer: Experts (the real module)
Returns:
(base_experts, lora_dict)
``lora_dict`` maps parameter names to ``(lora_A, lora_B, scaling)``
tuples, or is empty if no LoRA is active.
"""
wrappers = {}
module = experts_module
while hasattr(module, "base_layer") and hasattr(module, "lora_A"):
param_name = getattr(module, "parameter_name", None)
if param_name is not None:
wrappers[param_name] = module
module = module.base_layer
base_experts = module
lora_dict = {}
for param_name, wrapper in wrappers.items():
lora_A, lora_B, scaling = get_lora_params_from_wrapper(wrapper)
if lora_A is not None:
lora_dict[param_name] = (lora_A, lora_B, scaling)
return base_experts, lora_dict
# =============================================================================
# LoRA weight materialization autograd function
# =============================================================================
class MoELoRAMaterialize(torch.autograd.Function):
"""Materialize effective weight W_eff = W + scaling * (B @ A) per expert.
Inserts into the autograd graph between PEFT's LoRA parameters and
SonicMoE's CUTLASS kernels. The CUTLASS backward computes dW_eff,
which this function decomposes into dA and dB via the chain rule.
Weight layouts (PEFT rank-major):
base_weight: [E, dim1, dim2] (frozen expert parameter)
lora_A: [r*E, dim2] (rows [e*r:(e+1)*r] = A_e)
lora_B: [dim1, r*E] (cols [:, e*r:(e+1)*r] = B_e)
Per-expert: delta_e = B_e @ A_e = [dim1, r] @ [r, dim2] = [dim1, dim2]
"""
@staticmethod
def forward(
ctx,
base_weight: torch.Tensor,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
scaling: float,
) -> torch.Tensor:
E, dim1, dim2 = base_weight.shape
r = lora_A.shape[0] // E
assert lora_A.shape[0] == r * E, (
f"lora_A rows ({lora_A.shape[0]}) must be divisible by num_experts ({E})"
)
# Reshape PEFT rank-major to per-expert batched format
A_3d = lora_A.reshape(E, r, dim2)
B_3d = lora_B.reshape(dim1, r, E).permute(2, 0, 1).contiguous() # [E, dim1, r]
# Batched matmul: [E, dim1, r] @ [E, r, dim2] = [E, dim1, dim2]
delta = torch.bmm(B_3d, A_3d)
W_eff = base_weight + scaling * delta
ctx.save_for_backward(lora_A, lora_B)
ctx.scaling = scaling
ctx.E = E
ctx.r = r
return W_eff
@staticmethod
def backward(ctx, grad_W_eff: torch.Tensor):
lora_A, lora_B = ctx.saved_tensors
scaling = ctx.scaling
E = ctx.E
r = ctx.r
_, dim1, dim2 = grad_W_eff.shape
# Reshape to per-expert (same as forward)
A_3d = lora_A.reshape(E, r, dim2)
B_3d = lora_B.reshape(dim1, r, E).permute(2, 0, 1).contiguous() # [E, dim1, r]
# dA_e = scaling * B_e^T @ dW_e
# [E, r, dim1] @ [E, dim1, dim2] = [E, r, dim2]
d_A_3d = scaling * torch.bmm(B_3d.transpose(1, 2), grad_W_eff)
# dB_e = scaling * dW_e @ A_e^T
# [E, dim1, dim2] @ [E, dim2, r] = [E, dim1, r]
d_B_3d = scaling * torch.bmm(grad_W_eff, A_3d.transpose(1, 2))
# Reshape back to PEFT rank-major layout
d_lora_A = d_A_3d.reshape(E * r, dim2)
d_lora_B = d_B_3d.permute(1, 2, 0).contiguous().reshape(dim1, E * r)
return None, d_lora_A, d_lora_B, None
def materialize_expert_lora(
base_weight: torch.Tensor,
lora_params: Optional[tuple],
) -> torch.Tensor:
"""Materialize effective expert weight with optional LoRA delta.
Args:
base_weight: [E, dim1, dim2] frozen expert parameter
lora_params: (lora_A, lora_B, scaling) or None
Returns:
W_eff if lora_params is not None, else base_weight unchanged.
"""
if lora_params is None:
return base_weight
lora_A, lora_B, scaling = lora_params
return MoELoRAMaterialize.apply(base_weight, lora_A, lora_B, scaling)
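# Illustrative usage (not part of this module), assuming PEFT's rank-major layout:
#   gate_up_proj: [E, 2*I, H] frozen base weight, with E=4 experts and rank r=8
#   lora_A = torch.zeros(4 * 8, H)       # [r*E, dim2]
#   lora_B = torch.zeros(2 * I, 4 * 8)   # [dim1, r*E]
#   w_eff = materialize_expert_lora(gate_up_proj, (lora_A, lora_B, 2.0))
#   assert w_eff.shape == gate_up_proj.shape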

View File

@@ -28,20 +28,72 @@ import torch.nn.functional as F
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
from axolotl.utils.logging import get_logger
from .lora import (
has_lora,
materialize_expert_lora,
unwrap_experts_lora,
unwrap_gate_lora,
)
LOG = get_logger(__name__)
def patch_sonicmoe(model_type: str, torch_compile: bool = False):
"""Main entry point: patch SparseMoeBlock for SonicMoE support.
def _get_expert_weights(experts_module):
"""Extract expert weights, applying LoRA materialization if PEFT is active.
Args:
model_type: The HuggingFace model type (e.g. "qwen3_moe").
torch_compile: If True, wrap routing functions with torch.compile
for kernel fusion (fuses softmax+topk+renorm into fewer launches).
Returns:
(gate_up_weight, down_weight) in SonicMoE layout [dim, dim, E].
"""
if has_lora(experts_module):
base_experts, lora_dict = unwrap_experts_lora(experts_module)
gate_up = materialize_expert_lora(
base_experts.gate_up_proj, lora_dict.get("gate_up_proj")
)
down = materialize_expert_lora(
base_experts.down_proj, lora_dict.get("down_proj")
)
else:
gate_up = experts_module.gate_up_proj
down = experts_module.down_proj
# Permute to SonicMoE layout:
# gate_up: [E, 2*I, H] -> [2*I, H, E]
# down: [E, H, I] -> [H, I, E]
return gate_up.permute(1, 2, 0), down.permute(1, 2, 0)
def _fix_qwen3_5_moe_text_weight_renaming(model_type: str, base_model_type: str):
"""Strip qwen3_5_moe_text WeightRenaming in VLM mode to preserve custom loaders."""
if model_type != "qwen3_5_moe_text" or base_model_type == "qwen3_5_moe_text":
return
try:
from transformers.conversion_mapping import (
get_checkpoint_conversion_mapping,
register_checkpoint_conversion_mapping,
)
from transformers.core_model_loading import WeightRenaming
except ImportError:
return
text_mapping = get_checkpoint_conversion_mapping(model_type)
if text_mapping and isinstance(text_mapping[0], WeightRenaming):
text_mapping.pop(0)
register_checkpoint_conversion_mapping(model_type, text_mapping, overwrite=True)
LOG.info("Stripped qwen3_5_moe_text WeightRenaming for VLM mode")
def patch_sonicmoe(
model_type: str,
torch_compile: bool = False,
base_model_type: str | None = None,
):
"""Patch SparseMoeBlock for SonicMoE support."""
from .routing import get_model_moe_config
from .weight_converter import register_sonicmoe_weight_converter
_fix_qwen3_5_moe_text_weight_renaming(model_type, base_model_type or model_type)
routing_fn, activation, router_attr = get_model_moe_config(model_type)
if torch_compile and routing_fn is not None:
@@ -113,11 +165,10 @@ def _make_general_forward(moe_cls, routing_fn, activation):
hidden_states_flat, self
)
# Permute weights to SonicMoE layout:
# gate_up: [E, 2*I, H] -> [2*I, H, E]
# down: [E, H, I] -> [H, I, E]
gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
down_weight = self.experts.down_proj.permute(1, 2, 0)
# Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout
gate_up_weight, down_weight = _get_expert_weights(self.experts)
gate_up_weight = gate_up_weight.to(hidden_states_flat.dtype)
down_weight = down_weight.to(hidden_states_flat.dtype)
E = gate_up_weight.shape[-1]
output, _ = moe_general_routing_inputs(
@@ -161,22 +212,30 @@ def _make_fused_forward(moe_cls, activation, router_attr):
# Shared expert (computed early, matching original model ordering)
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
router = getattr(self, router_attr)
# Unwrap router for attribute access + optional LoRA delta
raw_router = getattr(self, router_attr)
base_router, router_weight, router_lora_delta = unwrap_gate_lora(raw_router)
if router_lora_delta is not None:
# Materialize local tensor to avoid DTensor + Tensor add under FSDP
if hasattr(router_weight, "to_local"):
router_weight = router_weight.to_local()
effective_router_weight = router_weight + router_lora_delta
else:
effective_router_weight = router_weight
# Permute weights to SonicMoE layout:
# gate_up: [E, 2*I, H] -> [2*I, H, E]
# down: [E, H, I] -> [H, I, E]
gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
down_weight = self.experts.down_proj.permute(1, 2, 0)
# Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout
gate_up_weight, down_weight = _get_expert_weights(self.experts)
gate_up_weight = gate_up_weight.to(hidden_states_flat.dtype)
down_weight = down_weight.to(hidden_states_flat.dtype)
output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer(
hidden_states_flat,
router.weight,
effective_router_weight,
gate_up_weight,
None, # b1 (no gate/up bias)
down_weight,
None, # b2 (no down bias)
router.top_k,
base_router.top_k,
torch.cuda.current_stream().cuda_stream,
activation,
False, # is_inference_mode

View File

@@ -16,6 +16,8 @@ When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
import torch
import torch.nn.functional as F
from .lora import unwrap_gate_lora
def get_model_moe_config(model_type: str):
"""Returns (routing_fn, activation, router_attr) for a given model type.
@@ -40,6 +42,7 @@ def get_model_moe_config(model_type: str):
"qwen2_moe",
"qwen3_moe",
"qwen3_5_moe",
"qwen3_5_moe_text",
"qwen3_next",
"qwen3_vl_moe",
"qwen3_omni_moe",
@@ -88,12 +91,18 @@ def softmax_topk_routing(
expert_indices: [T*K] which expert (int32)
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
T, _ = hidden_states.shape
K = gate.top_k
base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
T, H = hidden_states.shape
K = base_gate.top_k
# Compute router logits and softmax over all experts
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
# Compute router logits and softmax over all experts.
# Two F.linear calls avoid mixing DTensor (gate_weight) + Tensor (delta) under FSDP.
# Cast to float32 to match LoRA delta dtype (PEFT computes in fp32).
router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E]
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), gate_lora_delta.float()
)
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
# Select top-k experts per token
@@ -101,7 +110,7 @@ def softmax_topk_routing(
# Renormalize if configured (default True for models without the attribute,
# e.g. Mixtral/MiniMax which always normalize)
if getattr(gate, "norm_topk_prob", True):
if getattr(base_gate, "norm_topk_prob", True):
top_values = top_values / top_values.sum(dim=-1, keepdim=True)
# no-op: matches transformers which casts to softmax output dtype (float32).
@@ -128,13 +137,17 @@ def softmax_group_topk_routing(
hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
gate = moe_block.gate
base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
T, _ = hidden_states.shape
K = moe_block.top_k
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
E = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])
n_group = getattr(moe_block, "n_group", 1)
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E]
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), gate_lora_delta.float()
)
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
scores_for_choice = router_probs
@@ -206,25 +219,29 @@ def sigmoid_topk_routing(
expert_indices: [T*K] which expert (int32)
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
T, _ = hidden_states.shape
K = moe_block.top_k
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
E = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])
n_group = getattr(moe_block, "n_group", 1)
# Compute router logits and sigmoid probabilities
router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E]
router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E]
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), gate_lora_delta.float()
)
router_probs = router_logits.sigmoid() # [T, E]
# Bias-corrected scores for expert selection (not used for final weights).
# glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 stores it on the block.
e_score_correction_bias = getattr(gate, "e_score_correction_bias", None)
e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None)
if e_score_correction_bias is None:
e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
if e_score_correction_bias is None:
raise AttributeError(
f"sigmoid_topk_routing requires e_score_correction_bias on "
f"gate ({type(gate)}) or moe_block ({type(moe_block)}), but neither has it"
f"gate ({type(base_gate)}) or moe_block ({type(moe_block)}), but neither has it"
)
scores_for_choice = router_probs + e_score_correction_bias
@@ -296,16 +313,20 @@ def softmax_bias_topk_routing(
expert_indices: [T*K] which expert (int32)
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
T, _ = hidden_states.shape
K = gate.top_k
base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
T, H = hidden_states.shape
K = base_gate.top_k
# Compute router logits and softmax (force float32 for numerical stability)
router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E]
router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E]
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), gate_lora_delta.float()
)
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
# Bias-corrected scores for expert selection (via moe_statics module)
scores_for_choice = gate.moe_statics(router_probs) # [T, E]
scores_for_choice = base_gate.moe_statics(router_probs) # [T, E]
# Select top-k experts using biased scores
_, selected_experts = torch.topk(scores_for_choice, K, dim=-1) # [T, K]
@@ -314,7 +335,7 @@ def softmax_bias_topk_routing(
top_values = torch.gather(router_probs, dim=-1, index=selected_experts) # [T, K]
# Renormalize with clamp(min=norm_min) instead of sum+epsilon
norm_min = getattr(gate, "norm_min", 1e-20)
norm_min = getattr(base_gate, "norm_min", 1e-20)
top_values = top_values / torch.clamp(
top_values.sum(dim=-1, keepdim=True), min=norm_min
)
@@ -358,15 +379,19 @@ def softmax_group_limited_topk_routing(
expert_indices: [T*K] which expert (int32)
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
T, _ = hidden_states.shape
base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
T, H = hidden_states.shape
K = moe_block.top_k
num_group = getattr(moe_block, "num_group", 1)
num_experts = gate.weight.shape[0]
num_experts = gate_weight.shape[0]
topk_method = getattr(moe_block, "topk_method", "greedy")
# Compute logits in float32 and softmax
router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E]
router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E]
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), gate_lora_delta.float()
)
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
if topk_method == "greedy" or num_group == 1:
@@ -445,12 +470,17 @@ def softmax_topk_wg_routing(
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
T, _ = hidden_states.shape
T, H = hidden_states.shape
K = moe_block.top_k
# Gate computes logits via gate.wg (nn.Linear, float32)
wg = gate.wg
router_logits = F.linear(hidden_states.float(), wg.weight.float()) # [T, E]
# Unwrap at gate.wg level since PEFT targets the wg Linear, not the gate container
base_wg, wg_weight, wg_lora_delta = unwrap_gate_lora(gate.wg)
router_logits = F.linear(hidden_states.float(), wg_weight.float()) # [T, E]
if wg_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), wg_lora_delta.float()
)
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
# Select top-k experts

View File

@@ -129,15 +129,41 @@ class InterleavedToConcatenated(ConversionOps):
return ConcatenatedToInterleaved(self.dim)
def _make_same_key_interleave_converter():
"""Create a WeightConverter that interleaves an already-fused gate_up_proj."""
from transformers.core_model_loading import WeightConverter
return WeightConverter(
source_patterns="mlp.experts.gate_up_proj",
target_patterns="mlp.experts.gate_up_proj",
operations=[ConcatenatedToInterleaved(dim=1)],
)
def _has_same_key_interleave(mapping) -> bool:
"""Check whether the mapping already has a same-key gate_up_proj interleave converter."""
for conv in mapping:
if (
hasattr(conv, "source_patterns")
and conv.source_patterns == ["mlp.experts.gate_up_proj"]
and conv.target_patterns == ["mlp.experts.gate_up_proj"]
and hasattr(conv, "operations")
and any(isinstance(op, ConcatenatedToInterleaved) for op in conv.operations)
):
return True
return False
def register_sonicmoe_weight_converter(model_type: str):
"""Override the conversion mapping to add interleave step for gate_up_proj.
"""Register weight converters to interleave gate_up_proj for SonicMoE.
Appends a ConcatenatedToInterleaved operation to the existing gate_up_proj
converter chain. For example, qwen3_moe's chain becomes:
MergeModulelist(dim=0) -> Concatenate(dim=1) -> ConcatenatedToInterleaved(dim=1)
Handles two checkpoint formats:
1. Separate per-expert weights (e.g. qwen3_moe): appends interleave to the
existing merge chain (MergeModulelist -> Concatenate -> Interleave).
2. Already-fused gate_up_proj (e.g. qwen3_5_moe_text): adds a same-key
converter (gate_up_proj -> gate_up_proj with Interleave).
The reverse is auto-generated for saving:
InterleavedToConcatenated(dim=1) -> Chunk(dim=1) -> SplitModulelist(dim=0)
The loader matches whichever source pattern exists in the checkpoint.
"""
from transformers.conversion_mapping import (
get_checkpoint_conversion_mapping,
@@ -145,37 +171,32 @@ def register_sonicmoe_weight_converter(model_type: str):
)
existing = get_checkpoint_conversion_mapping(model_type)
if existing is None:
LOG.warning(
f"No conversion mapping found for model type '{model_type}'. "
"SonicMoE weight interleaving will not be applied during checkpoint loading."
)
# No mapping at all — create one with just the same-key converter
mapping = [_make_same_key_interleave_converter()]
register_checkpoint_conversion_mapping(model_type, mapping)
LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'")
return
# Find the gate_up_proj converter and append ConcatenatedToInterleaved
patched = False
# Append interleave to any existing many-to-one merge chain
for converter in existing:
if hasattr(converter, "operations") and any(
"gate_up_proj" in pat for pat in converter.target_patterns
):
# Guard against double registration (e.g. plugin reloaded)
if any(
has_separate_sources = any(
"gate_proj" in pat or "up_proj" in pat
for pat in converter.source_patterns
)
if has_separate_sources and not any(
isinstance(op, ConcatenatedToInterleaved) for op in converter.operations
):
LOG.info(
f"SonicMoE weight converter already registered for '{model_type}'"
)
return
converter.operations.append(ConcatenatedToInterleaved(dim=1))
patched = True
converter.operations.append(ConcatenatedToInterleaved(dim=1))
break
if not patched:
LOG.warning(
f"Could not find gate_up_proj converter for model type '{model_type}'. "
"SonicMoE weight interleaving will not be applied during checkpoint loading."
)
return
# Also add a same-key converter for already-fused checkpoints
if not _has_same_key_interleave(existing):
existing.append(_make_same_key_interleave_converter())
register_checkpoint_conversion_mapping(model_type, existing, overwrite=True)
LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'")

View File

@@ -84,12 +84,13 @@ class KernelsPlugin(BasePlugin):
_check_sonicmoe_gpu_compat()
from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
from axolotl.integrations.kernels.libs.sonicmoe import patch_sonicmoe
LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
patch_sonicmoe(
moe_model_type,
torch_compile=bool(getattr(cfg, "torch_compile", False)),
base_model_type=cfg.model_config_type,
)
def _register_kernels(self):