feat: add sonicmoe (#3411)

* feat: add sonicmoe

* feat: add torch compile for routing

* feat: add routing smoke test

* feat: add qwen3_5_moe, qwen3_vl_moe, qwen3_omni_moe

* fix: disable mlp kernel for sonicmoe too

* feat: update to sonicmoe release

* chore: update import following new sonicmoe changes

* feat: update handling for blackwell

* feat: add sonicmoe e2e test

* fix: installation for updated sonicmoe

* fix: git commit

* fix: ignore py req and fix metadata

* fix: increase min hidden size to match sonicmoe kernel min

* fix: attempt properly interleave and handle unpatch mid-test

* chore: refactor teardown better

* chore: refactor to re-use rearrange

* fix: add idempotency guard

* fix: address comments on CI memory and interleave

* fix: tests grad, param doublewrapped
NanoCode012
2026-03-06 01:43:31 +07:00
committed by GitHub
parent 1eaf4d7418
commit 6a8baf8fa7
12 changed files with 1698 additions and 42 deletions

View File

@@ -10,7 +10,7 @@ class ExpertsInterface(GeneralInterface):
}
```
-In our custom integration, we add support for **ScatterMoE**, which is even more efficient and faster than `grouped_mm`.
+In our custom integration, we add support for **ScatterMoE** and **SonicMoE**, which are more efficient and faster than `grouped_mm`.
## Usage
@@ -21,23 +21,55 @@ plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
# Choose one (mutually exclusive):
use_scattermoe: true
# OR
use_sonicmoe: true
```
-**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
+**Important:** Setting `experts_implementation` is incompatible with custom kernel options.
### SonicMoE installation
**Prerequisites:**
- NVIDIA Hopper (H100, H200) or Blackwell (B200, GB200) GPU
- CUDA 12.9+ (13.0+ for B300)
- PyTorch 2.7+ (2.9.1 recommended)
- For B300: Triton 3.6.0
```bash
pip install --ignore-requires-python --no-deps "sonic-moe @ git+https://github.com/Dao-AILab/sonic-moe.git@116e2df0a41874f77fa0ad269ce7df3f0cfcb956" && pip install nvidia-cutlass-dsl==4.4.0 quack-kernels==0.2.5
```
See the [SonicMoE installation guide](https://github.com/Dao-AILab/sonic-moe?tab=readme-ov-file#-installation) for the latest prerequisite details.
**Note:** Blackwell support is in upstream beta. On Blackwell GPUs, Axolotl automatically sets `USE_QUACK_GEMM=1` to enable the Blackwell kernels.
## How It Works
The `KernelsPlugin` runs before model loading and:
-1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
+### ScatterMoE
+1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels).
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
### SonicMoE
1. Resolves the model's MoE block class(es) from `constants.py`.
2. Patches the forward method with SonicMoE's optimized kernels and registers a weight converter for the interleaved gate/up projection format.
3. Supports both softmax->topk and sigmoid->topk routing strategies.
Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution.
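For illustration, the SonicMoE path reduces to roughly the following at model-load time (a minimal sketch of what `KernelsPlugin.pre_model_load` does when `use_sonicmoe: true`; the installation and GPU checks are omitted here):
```python
from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe

# Patches every resolved MoE block class for the model type and registers
# the interleaved gate/up weight converter (both happen inside patch_sonicmoe).
patch_sonicmoe("qwen3_moe", torch_compile=False)
```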
#### Supported Models
See `constants.py` for the full list of supported model types (Qwen2-MoE, Qwen3-MoE, OLMoE, Mixtral, DeepSeek-V3, GLM-MoE, MiniMax, etc.).
## Limitations
-ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
-ScatterMoE does not work for GLM4.7 Flash (glm4_moe_lite) atm.
+ScatterMoE uses softmax -> topk routing, so results may differ from the baseline for some model architectures (GPT-OSS, etc.). It is currently incompatible with `GLM_MOE_DSA` (GLM 5) and `GLM4_MOE_LITE` (GLM 4.7 Flash).
+SonicMoE supports both softmax->topk and sigmoid->topk routing, covering a wider range of architectures.

View File

@@ -6,7 +6,18 @@ LOG = get_logger(__name__)
class KernelsArgs(BaseModel):
use_scattermoe: bool | None = True
use_scattermoe: bool | None = None
use_sonicmoe: bool | None = None
@model_validator(mode="before")
@classmethod
def check_mutually_exclusive(cls, data):
if data.get("use_scattermoe") and data.get("use_sonicmoe"):
raise ValueError(
"Cannot use both ScatterMoE and SonicMoE simultaneously. "
"Please set only one of `use_scattermoe` or `use_sonicmoe` to true."
)
return data
@model_validator(mode="before")
@classmethod
@@ -36,11 +47,11 @@ class KernelsArgs(BaseModel):
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel_scattermoe(cls, data):
if data.get("use_scattermoe") is True:
def disable_mlp_kernel(cls, data):
if data.get("use_scattermoe") is True or data.get("use_sonicmoe") is True:
if data.get("lora_mlp_kernel") is True:
LOG.warning(
"Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
"Disabling lora_mlp_kernel when using custom MoE kernels due to compatibility issues."
)
data["lora_mlp_kernel"] = False
data["mlp_kernel"] = False
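For example, the mutual-exclusion validator rejects a config that enables both kernels (a sketch; it assumes `KernelsArgs` is imported from this args module):
```python
# Hypothetical direct use of the pydantic model defined above.
KernelsArgs(use_scattermoe=True, use_sonicmoe=True)
# -> ValueError: Cannot use both ScatterMoE and SonicMoE simultaneously. ...

# Enabling either kernel also forces lora_mlp_kernel / mlp_kernel off via
# disable_mlp_kernel, with a warning in the logs.
```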

View File

@@ -0,0 +1,68 @@
"""
Supported MoE block mappings for kernel integrations.
Maps model_type to the SparseMoeBlock class name(s) in transformers.
Used by both ScatterMoE and SonicMoE kernel paths.
Values can be a single class name (str) or a list of class names for models
with multiple MoE block types (e.g. qwen3_omni_moe has Thinker + Talker).
"""
import importlib
SPARSE_MOE_BLOCK = {
# softmax -> topk routing
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
"qwen3_next": "Qwen3NextSparseMoeBlock",
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
# qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
"qwen3_omni_moe": [
"Qwen3OmniMoeThinkerTextSparseMoeBlock",
"Qwen3OmniMoeTalkerTextSparseMoeBlock",
],
"olmoe": "OlmoeSparseMoeBlock",
"mixtral": "MixtralSparseMoeBlock",
"minimax": "MiniMaxSparseMoeBlock",
# sigmoid -> topk routing (with group-based expert selection)
"glm_moe_dsa": "GlmMoeDsaMoE",
"deepseek_v3": "DeepseekV3MoE",
"glm4_moe": "Glm4MoeMoE",
"glm4_moe_lite": "Glm4MoeLiteMoE",
"glm4v_moe": "Glm4vMoeTextMoE",
# sigmoid -> topk routing (no group selection)
"minimax_m2": "MiniMaxM2SparseMoeBlock",
# Models below need custom routing (not yet implemented):
# "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk
# "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group)
# "hunyuan_v1_moe": "HunYuanMoEV1Moe", # softmax->topk, gate.wg (not gate.weight), scatter routing
# "gpt_oss": "GptOssMLP", # topk->softmax, transposed layout [E,H,2*I], custom GLU, expert biases
}
def resolve_moe_block_classes(model_type: str):
"""Resolve all MoE block classes from transformers for the given model type.
Returns a list of classes (one for most models, multiple for models with
distinct MoE block types like qwen3_omni_moe).
"""
entry = SPARSE_MOE_BLOCK.get(model_type)
if entry is None:
raise ValueError(
f"Unsupported MoE model type '{model_type}'. "
f"Supported types: {list(SPARSE_MOE_BLOCK.keys())}"
)
cls_names = entry if isinstance(entry, list) else [entry]
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = importlib.import_module(module_path)
classes = []
for cls_name in cls_names:
moe_cls = getattr(module, cls_name, None)
if moe_cls is None:
raise ValueError(f"Could not find class '{cls_name}' in '{module_path}'")
classes.append(moe_cls)
return classes
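For example (a quick sketch; requires a transformers version that ships these model types), the resolver returns one class for most models and two for `qwen3_omni_moe`:
```python
from axolotl.integrations.kernels.constants import resolve_moe_block_classes

print([cls.__name__ for cls in resolve_moe_block_classes("mixtral")])
# ['MixtralSparseMoeBlock']

print([cls.__name__ for cls in resolve_moe_block_classes("qwen3_omni_moe")])
# ['Qwen3OmniMoeThinkerTextSparseMoeBlock', 'Qwen3OmniMoeTalkerTextSparseMoeBlock']

resolve_moe_block_classes("gpt_oss")  # raises ValueError: unsupported MoE model type
```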

View File

@@ -1,14 +1,59 @@
import importlib
import os
from pathlib import Path
from kernels import (
LocalLayerRepository,
Mode,
register_kernel_mapping,
replace_kernel_forward_from_hub,
)
import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _check_sonicmoe_gpu_compat():
"""Validate GPU compute capability for SonicMoE and configure env.
Supported: Hopper (sm_90), Blackwell (sm_100 - sm_103).
B300 (sm_103) additionally requires Triton 3.6.0.
"""
if not torch.cuda.is_available():
return
cc = torch.cuda.get_device_capability()
if cc < (9, 0):
raise RuntimeError(
f"SonicMoE requires Hopper (sm_90) or Blackwell (sm_100+) GPU, "
f"but detected sm_{cc[0]}{cc[1]}."
)
if cc > (10, 3):
raise RuntimeError(
f"SonicMoE does not yet support sm_{cc[0]}{cc[1]}. "
f"Supported: Hopper (sm_90) and Blackwell (sm_100 - sm_103)."
)
# Blackwell (sm_100+): enable QuACK GEMM kernels
if cc >= (10, 0):
os.environ.setdefault("USE_QUACK_GEMM", "1")
LOG.info(
f"Blackwell GPU (sm_{cc[0]}{cc[1]}) detected, enabling USE_QUACK_GEMM=1"
)
# B300 (sm_103): requires Triton 3.6.0
if cc == (10, 3):
triton_spec = importlib.util.find_spec("triton")
if triton_spec is None:
raise RuntimeError(
"B300 (sm_103) requires Triton 3.6.0, but Triton is not installed."
)
import triton
triton_version = tuple(int(x) for x in triton.__version__.split(".")[:2])
if triton_version != (3, 6):
raise RuntimeError(
f"B300 (sm_103) requires Triton 3.6.x, but found {triton.__version__}."
)
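For reference, the window this guard enforces can be checked by hand with plain PyTorch (a small sketch; the values follow the docstring above):
```python
import torch

if torch.cuda.is_available():
    cc = torch.cuda.get_device_capability()  # e.g. (9, 0) on H100, (10, 0) on B200
    print(f"sm_{cc[0]}{cc[1]} in SonicMoE's supported range: {(9, 0) <= cc <= (10, 3)}")
```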
class KernelsPlugin(BasePlugin):
@@ -19,8 +64,32 @@ class KernelsPlugin(BasePlugin):
if cfg.use_scattermoe:
self._register_kernels()
self._kernelize_model(cfg.model_config_type)
elif cfg.use_sonicmoe:
if not importlib.util.find_spec("sonicmoe"):
raise RuntimeError(
"SonicMoE is not installed. See installation instructions at "
"https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/kernels/README.md#sonicmoe-installation"
)
_check_sonicmoe_gpu_compat()
from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
LOG.info(
f"Applying SonicMoE patches for model type: {cfg.model_config_type}"
)
patch_sonicmoe(
cfg.model_config_type,
torch_compile=bool(getattr(cfg, "torch_compile", False)),
)
def _register_kernels(self):
from kernels import (
LocalLayerRepository,
Mode,
register_kernel_mapping,
)
plugin_root = Path(__file__).parent
register_kernel_mapping(
{
@@ -42,25 +111,11 @@ class KernelsPlugin(BasePlugin):
)
def _kernelize_model(self, model_type: str):
if model_type == "olmoe":
from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
from kernels import replace_kernel_forward_from_hub
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
for model_moe_cls in resolve_moe_block_classes(model_type):
replace_kernel_forward_from_hub(
OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts"
model_moe_cls, "HFScatterMoEParallelExperts"
)
else:
try:
model_moe_cls = get_model_moe_block(model_type)
replace_kernel_forward_from_hub(
model_moe_cls, "HFScatterMoEParallelExperts"
)
except Exception as err:
raise ValueError(f"Unsupported model type: {model_type}") from err
def get_model_moe_block(model_type: str):
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}SparseMoeBlock"])
model_cls = getattr(module, f"{model_cls_prefix}SparseMoeBlock")
return model_cls

View File

@@ -0,0 +1,3 @@
from .patch import patch_sonicmoe
__all__ = ["patch_sonicmoe"]

View File

@@ -0,0 +1,213 @@
"""
SonicMoE patching for SparseMoeBlock forward pass.
Monkeypatches the SparseMoeBlock class for a given model type to use
SonicMoE's optimized kernels. Two forward paths are supported:
1. **General routing path** (routing_fn is not None):
Uses a custom routing function + ``moe_general_routing_inputs``.
Suitable for models with non-standard routing (softmax->topk, sigmoid->topk).
2. **Fused topk->softmax path** (routing_fn is None):
Uses ``moe_TC_softmax_topk_layer`` which fuses routing + expert computation.
Suitable for models with simple topk->softmax routing.
Weight format conversion (interleave/deinterleave) is handled by the
WeightConverter system, so the forward assumes weights are already in
interleaved format.
Shared experts are handled generically: if the block has a ``shared_expert``
or ``shared_experts`` attribute, its output is computed alongside the routed
experts and added to the final output. An optional ``shared_expert_gate``
applies sigmoid gating to the shared expert contribution.
"""
import torch
import torch.nn.functional as F
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def patch_sonicmoe(model_type: str, torch_compile: bool = False):
"""Main entry point: patch SparseMoeBlock for SonicMoE support.
Args:
model_type: The HuggingFace model type (e.g. "qwen3_moe").
torch_compile: If True, wrap routing functions with torch.compile
for kernel fusion (fuses softmax+topk+renorm into fewer launches).
"""
from .routing import get_model_moe_config
from .weight_converter import register_sonicmoe_weight_converter
routing_fn, activation, router_attr = get_model_moe_config(model_type)
if torch_compile and routing_fn is not None:
routing_fn = _try_compile_routing(routing_fn)
for moe_cls in resolve_moe_block_classes(model_type):
_patch_forward(moe_cls, routing_fn, activation, router_attr)
register_sonicmoe_weight_converter(model_type)
def _try_compile_routing(routing_fn):
"""Attempt to torch.compile the routing function, fall back to eager on failure."""
try:
compiled_fn = torch.compile(routing_fn, mode="reduce-overhead", dynamic=False)
LOG.info(f"torch.compile enabled for routing function: {routing_fn.__name__}")
return compiled_fn
except Exception as exc: # pylint: disable=broad-except
LOG.warning(
f"torch.compile failed for routing function {routing_fn.__name__}, "
f"falling back to eager: {exc}"
)
return routing_fn
def _patch_forward(moe_cls, routing_fn, activation, router_attr):
"""Monkeypatch the SparseMoeBlock class with a SonicMoE forward.
The patched forward handles shared experts generically: if
``self.shared_expert`` or ``self.shared_experts`` exists, it is computed
and added to the routed output. If ``self.shared_expert_gate`` also exists,
it applies sigmoid gating to the shared expert contribution (as in qwen2_moe).
Args:
moe_cls: The SparseMoeBlock class to patch.
routing_fn: Routing function (e.g. softmax_topk_routing), or None
for the fused moe_TC_softmax_topk_layer path.
activation: SonicMoE ActivationType enum value.
router_attr: Name of the router module attribute on the MoE block.
"""
if hasattr(moe_cls, "_original_forward"):
LOG.info(f"{moe_cls.__name__}.forward already patched with SonicMoE, skipping")
return
original_forward = moe_cls.forward
if routing_fn is not None:
_make_general_forward(moe_cls, routing_fn, activation)
else:
_make_fused_forward(moe_cls, activation, router_attr)
moe_cls._original_forward = original_forward
LOG.info(f"Patched {moe_cls.__name__}.forward with SonicMoE implementation")
def _make_general_forward(moe_cls, routing_fn, activation):
"""Create forward using routing_fn + moe_general_routing_inputs."""
def sonicmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
from sonicmoe import moe_general_routing_inputs
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_flat = hidden_states.view(-1, hidden_dim)
# Shared expert (computed early, matching original model ordering)
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
# Routing
router_scores, token_indices, expert_indices, _router_logits = routing_fn(
hidden_states_flat, self
)
# Permute weights to SonicMoE layout:
# gate_up: [E, 2*I, H] -> [2*I, H, E]
# down: [E, H, I] -> [H, I, E]
gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
down_weight = self.experts.down_proj.permute(1, 2, 0)
E = gate_up_weight.shape[-1]
output, _ = moe_general_routing_inputs(
hidden_states_flat,
router_scores,
token_indices,
expert_indices,
gate_up_weight,
None, # b1 (no gate/up bias)
down_weight,
None, # b2 (no down bias)
E,
torch.cuda.current_stream().cuda_stream,
activation,
False, # is_inference_mode
)
# Add shared expert contribution if present
if shared_expert_output is not None:
if hasattr(self, "shared_expert_gate"):
shared_expert_output = (
F.sigmoid(self.shared_expert_gate(hidden_states_flat))
* shared_expert_output
)
output = output + shared_expert_output
return output.view(batch_size, sequence_length, hidden_dim)
moe_cls.forward = sonicmoe_forward
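To make the layout conversion concrete, here is a small shape walkthrough with hypothetical sizes (4 experts, intermediate size 8, hidden size 32); it exercises only the permutes, not the SonicMoE kernel itself:
```python
import torch

num_experts, inter, hidden = 4, 8, 32
gate_up_proj = torch.randn(num_experts, 2 * inter, hidden)  # transformers layout [E, 2*I, H]
down_proj = torch.randn(num_experts, hidden, inter)         # transformers layout [E, H, I]

# SonicMoE expects the expert dimension last:
gate_up_weight = gate_up_proj.permute(1, 2, 0)  # [2*I, H, E] -> torch.Size([16, 32, 4])
down_weight = down_proj.permute(1, 2, 0)        # [H, I, E]   -> torch.Size([32, 8, 4])
```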
def _make_fused_forward(moe_cls, activation, router_attr):
"""Create forward using moe_TC_softmax_topk_layer (topk -> softmax)."""
def sonicmoe_fused_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
from sonicmoe import moe_TC_softmax_topk_layer
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_flat = hidden_states.view(-1, hidden_dim)
# Shared expert (computed early, matching original model ordering)
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
router = getattr(self, router_attr)
# Permute weights to SonicMoE layout:
# gate_up: [E, 2*I, H] -> [2*I, H, E]
# down: [E, H, I] -> [H, I, E]
gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
down_weight = self.experts.down_proj.permute(1, 2, 0)
output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer(
hidden_states_flat,
router.weight,
gate_up_weight,
None, # b1 (no gate/up bias)
down_weight,
None, # b2 (no down bias)
router.top_k,
torch.cuda.current_stream().cuda_stream,
activation,
False, # is_inference_mode
)
# Add shared expert contribution if present
if shared_expert_output is not None:
if hasattr(self, "shared_expert_gate"):
shared_expert_output = (
F.sigmoid(self.shared_expert_gate(hidden_states_flat))
* shared_expert_output
)
output = output + shared_expert_output
return output.view(batch_size, sequence_length, hidden_dim)
moe_cls.forward = sonicmoe_fused_forward
def _compute_shared_expert(moe_block, hidden_states_flat):
"""Compute shared expert output if the block has one.
Handles singular (qwen2_moe: ``shared_expert``), plural
(glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP
(hunyuan_v1_moe: ``shared_mlp``) attribute names.
"""
shared_expert = (
getattr(moe_block, "shared_expert", None)
or getattr(moe_block, "shared_experts", None)
or getattr(moe_block, "shared_mlp", None)
)
if shared_expert is not None:
return shared_expert(hidden_states_flat)
return None
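The shared-expert handling amounts to the following (a toy sketch with `nn.Linear` modules standing in for the real shared expert and gate):
```python
import torch
import torch.nn.functional as F

hidden = 32
x = torch.randn(5, hidden)                       # [T, H] flattened tokens
routed_out = torch.randn(5, hidden)              # stand-in for the SonicMoE routed output
shared_expert = torch.nn.Linear(hidden, hidden)  # stand-in for shared_expert / shared_experts / shared_mlp
shared_expert_gate = torch.nn.Linear(hidden, 1)  # qwen2_moe-style sigmoid gate

output = routed_out + F.sigmoid(shared_expert_gate(x)) * shared_expert(x)
```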

View File

@@ -0,0 +1,219 @@
"""
Routing functions for SonicMoE integration.
Different MoE architectures use different routing strategies:
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
"""
import torch
import torch.nn.functional as F
def get_model_moe_config(model_type: str):
"""Returns (routing_fn, activation, router_attr) for a given model type.
Args:
model_type: HuggingFace model type string.
Returns:
routing_fn: Callable or None. None signals the fused
moe_TC_softmax_topk_layer path (topk -> softmax models).
activation: SonicMoE ActivationType enum value.
router_attr: Name of the router module attribute on the MoE block
(e.g. "gate" or "router").
The activation type cannot be derived from config.hidden_act because
e.g. qwen3_moe reports "silu" but architecturally uses SwiGLU
(act_fn(gate) * up pattern). So we specify it per model type.
"""
from sonicmoe.enums import ActivationType
if model_type in (
"qwen2_moe",
"qwen3_moe",
"qwen3_5_moe",
"qwen3_next",
"qwen3_vl_moe",
"qwen3_omni_moe",
"olmoe",
"mixtral",
"minimax",
):
return softmax_topk_routing, ActivationType.SWIGLU, "gate"
elif model_type in (
"glm_moe_dsa",
"deepseek_v3",
"glm4_moe",
"glm4_moe_lite",
"glm4v_moe",
"minimax_m2",
):
return sigmoid_topk_routing, ActivationType.SWIGLU, "gate"
# elif model_type in ("ernie4_5_moe",):
# # Softmax→topk with e_score_correction_bias applied between softmax and topk.
# return ..., ActivationType.SWIGLU, "gate"
# elif model_type in ("deepseek_v2",):
# # Softmax→topk with group_limited_greedy. Different attr names: num_group
# # (not n_group), gate is nn.Linear (not a router class).
# return ..., ActivationType.SWIGLU, "gate"
# elif model_type in ("hunyuan_v1_moe",):
# # Softmax→topk but gate structure differs: gate.wg (not gate.weight),
# # top_k on block not gate, creates scatter routing matrix.
# return ..., ActivationType.SWIGLU, "gate"
# Fused topk -> softmax path (routing_fn=None):
# elif model_type in ("gpt_oss",):
# # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer
# # ignores (it only takes router_w, not bias). Also has transposed
# # weight layout [E, H, 2*I] and custom GLU activation.
# return None, ActivationType.SWIGLU, "router"
else:
raise ValueError(f"SonicMoE: unsupported model type '{model_type}'")
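Example lookups (a sketch; importing this module requires sonicmoe for the `ActivationType` enum):
```python
routing_fn, activation, router_attr = get_model_moe_config("qwen3_moe")
# -> (softmax_topk_routing, ActivationType.SWIGLU, "gate")

routing_fn, activation, router_attr = get_model_moe_config("deepseek_v3")
# -> (sigmoid_topk_routing, ActivationType.SWIGLU, "gate")

get_model_moe_config("gpt_oss")  # raises ValueError: SonicMoE: unsupported model type 'gpt_oss'
```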
def softmax_topk_routing(
hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Qwen3/Qwen2-style routing: softmax -> topk -> optional renorm.
Args:
hidden_states: [T, H] flattened token representations
moe_block: MoE block module (accesses moe_block.gate.*)
Returns:
router_scores: [T*K] flattened scores (float32)
token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
expert_indices: [T*K] which expert (int32)
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
T, H = hidden_states.shape
K = gate.top_k
# Compute router logits and softmax over all experts
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
# Select top-k experts per token
top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each
# Renormalize if configured (default True for models without the attribute,
# e.g. Mixtral/MiniMax which always normalize)
if getattr(gate, "norm_topk_prob", True):
top_values = top_values / top_values.sum(dim=-1, keepdim=True)
# no-op: matches transformers which casts to softmax output dtype (float32).
# top_values = top_values.to(router_probs.dtype)
# Flatten for moe_general_routing_inputs.
# Token indices are naturally sorted ascending from the [T, K] layout:
# [0, 0, ..., 1, 1, ..., T-1, T-1, ...] — this is required by SonicMoE.
# Expert sorting is handled internally by general_routing_router_metadata.
token_indices = (
torch.arange(T, device=hidden_states.device, dtype=torch.int32)
.unsqueeze(1)
.expand(T, K)
)
flat_scores = top_values.reshape(-1) # [T*K]
flat_token_idx = token_indices.reshape(-1) # [T*K]
flat_expert_idx = top_indices.to(torch.int32).reshape(-1) # [T*K]
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
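A toy invocation with a stand-in gate (hypothetical sizes: 3 tokens, hidden size 8, 4 experts, top-k 2) illustrates the flattened output shapes:
```python
import types
import torch

gate = torch.nn.Linear(8, 4, bias=False)  # router weight [E=4, H=8]
gate.top_k = 2
gate.norm_topk_prob = True
moe_block = types.SimpleNamespace(gate=gate)

scores, tok_idx, exp_idx, logits = softmax_topk_routing(torch.randn(3, 8), moe_block)
print(scores.shape, tok_idx.shape, exp_idx.shape, logits.shape)
# torch.Size([6]) torch.Size([6]) torch.Size([6]) torch.Size([3, 4])
print(tok_idx.tolist())  # [0, 0, 1, 1, 2, 2] -- ascending, as SonicMoE requires
```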
def sigmoid_topk_routing(
hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Sigmoid-based routing: sigmoid -> optional group selection -> topk.
Supports two variants:
- **Group selection** (glm_moe_dsa, deepseek_v3, etc.): n_group > 1,
bias on gate, group-based masking before topk.
- **No group selection** (minimax_m2): n_group == 1 (or absent),
bias on moe_block, straight topk from all experts.
Final routing weights come from the original sigmoid scores (not
bias-corrected), with optional renormalization and scaling.
Args:
hidden_states: [T, H] flattened token representations
moe_block: MoE block module (accesses moe_block.gate.* and
optional moe_block.n_group, .topk_group, .top_k, .norm_topk_prob,
.routed_scaling_factor, .n_routed_experts)
Returns:
router_scores: [T*K] flattened scores (float32)
token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
expert_indices: [T*K] which expert (int32)
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
T, H = hidden_states.shape
K = moe_block.top_k
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
n_group = getattr(moe_block, "n_group", 1)
# Compute router logits and sigmoid probabilities
router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E]
router_probs = router_logits.sigmoid() # [T, E]
# Bias-corrected scores for expert selection (not used for final weights).
# glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 stores it on the block.
e_score_correction_bias = getattr(gate, "e_score_correction_bias", None)
if e_score_correction_bias is None:
e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
if e_score_correction_bias is None:
raise AttributeError(
f"sigmoid_topk_routing requires e_score_correction_bias on "
f"gate ({type(gate)}) or moe_block ({type(moe_block)}), but neither has it"
)
scores_for_choice = router_probs + e_score_correction_bias
# Group-based selection: pick top groups, mask the rest (skip when n_group == 1)
if n_group > 1:
group_scores = (
scores_for_choice.view(-1, n_group, E // n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
) # [T, n_group]
group_idx = torch.topk(
group_scores, k=moe_block.topk_group, dim=-1, sorted=False
)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
# Final topk from (possibly masked) scores
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
# Gather weights from original sigmoid scores (not bias-corrected)
topk_weights = router_probs.gather(1, topk_indices)
# Optional renormalization + scaling
norm_topk_prob = getattr(moe_block, "norm_topk_prob", True)
if norm_topk_prob:
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
topk_weights = topk_weights * routed_scaling_factor
# Flatten for moe_general_routing_inputs.
# Token indices are naturally sorted ascending from the [T, K] layout.
token_indices = (
torch.arange(T, device=hidden_states.device, dtype=torch.int32)
.unsqueeze(1)
.expand(T, K)
)
flat_scores = topk_weights.to(torch.float32).reshape(-1) # [T*K]
flat_token_idx = token_indices.reshape(-1) # [T*K]
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
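A toy invocation with group selection enabled (8 experts split into 2 groups, routing only within the best group; the bias placed on the gate as in deepseek_v3) shows the expected shapes:
```python
import types
import torch

gate = torch.nn.Linear(8, 8, bias=False)       # router weight [E=8, H=8]
gate.e_score_correction_bias = torch.zeros(8)  # selection bias on the gate (deepseek_v3-style)
moe_block = types.SimpleNamespace(
    gate=gate,
    top_k=2,
    n_routed_experts=8,
    n_group=2,            # 2 groups of 4 experts
    topk_group=1,         # keep only the best group before the final topk
    norm_topk_prob=True,
    routed_scaling_factor=1.0,
)

scores, tok_idx, exp_idx, logits = sigmoid_topk_routing(torch.randn(3, 8), moe_block)
print(scores.shape, exp_idx.shape, logits.shape)
# torch.Size([6]) torch.Size([6]) torch.Size([3, 8])
```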

View File

@@ -0,0 +1,181 @@
"""
Custom WeightConverter operations for SonicMoE weight format conversion.
SonicMoE requires gate_up_proj weights in interleaved format:
- Standard (concatenated): [E, 2*I, H] where first I rows are gate, last I rows are up
- SonicMoE (interleaved): [E, 2*I, H] where rows alternate [g0, u0, g1, u1, ...]
These ConversionOps integrate with transformers' WeightConverter system so that
weights are transparently converted during loading and reverted during saving.
"""
from typing import Any
import torch
from einops import rearrange
from transformers.core_model_loading import ConversionOps
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def interleave_gate_up(tensor: torch.Tensor) -> torch.Tensor:
"""[gate..., up...] -> [g0, u0, g1, u1, ...] along the 2*I dimension."""
return rearrange(tensor, "... (two out) h -> ... (out two) h", two=2)
def deinterleave_gate_up(tensor: torch.Tensor) -> torch.Tensor:
"""[g0, u0, g1, u1, ...] -> [gate..., up...] along the 2*I dimension."""
return rearrange(tensor, "... (out two) h -> ... (two out) h", two=2)
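As a quick illustration of the interleaving on a toy tensor (1 expert, intermediate size 2, hidden size 1; gate rows first, then up rows):
```python
import torch

w = torch.tensor([[[1.0], [2.0], [3.0], [4.0]]])  # [E=1, 2*I=4, H=1]: gate rows (1, 2), up rows (3, 4)
print(interleave_gate_up(w).flatten().tolist())    # [1.0, 3.0, 2.0, 4.0]  i.e. g0, u0, g1, u1
print(deinterleave_gate_up(interleave_gate_up(w)).flatten().tolist())  # [1.0, 2.0, 3.0, 4.0] (roundtrip)
```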
class ConcatenatedToInterleaved(ConversionOps):
"""Convert concatenated gate/up projections to interleaved format.
Input: [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H]
Output: [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...]
This operation is applied along ``dim`` (default 1, the 2*I dimension).
"""
def __init__(self, dim: int = 1):
self.dim = dim
@torch.no_grad()
def convert(
self,
input_dict: dict[str, Any],
source_patterns: list[str],
target_patterns: list[str],
**kwargs,
) -> dict[str, torch.Tensor]:
target_pattern = self._get_target_pattern(
input_dict, source_patterns, target_patterns
)
tensors = next(iter(input_dict.values()))
tensor = tensors[0] if isinstance(tensors, list) else tensors
interleaved = interleave_gate_up(tensor)
return {target_pattern: interleaved}
def _get_target_pattern(
self,
input_dict: dict[str, Any],
source_patterns: list[str],
target_patterns: list[str],
) -> str:
# Follow the same logic as Transpose.get_target_pattern
if len(input_dict) != 1:
raise ValueError("Undefined Operation encountered!")
if len(target_patterns) > 1:
if len(source_patterns) == 1:
return source_patterns[0]
raise ValueError("Undefined Operation encountered!")
return target_patterns[0]
@property
def reverse_op(self) -> ConversionOps:
return InterleavedToConcatenated(self.dim)
class InterleavedToConcatenated(ConversionOps):
"""Convert interleaved gate/up projections back to concatenated format.
Input: [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...]
Output: [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H]
This is the reverse of ``ConcatenatedToInterleaved``.
"""
def __init__(self, dim: int = 1):
self.dim = dim
@torch.no_grad()
def convert(
self,
input_dict: dict[str, Any],
source_patterns: list[str],
target_patterns: list[str],
**kwargs,
) -> dict[str, torch.Tensor]:
target_pattern = self._get_target_pattern(
input_dict, source_patterns, target_patterns
)
tensors = next(iter(input_dict.values()))
tensor = tensors[0] if isinstance(tensors, list) else tensors
concatenated = deinterleave_gate_up(tensor)
return {target_pattern: concatenated}
def _get_target_pattern(
self,
input_dict: dict[str, Any],
source_patterns: list[str],
target_patterns: list[str],
) -> str:
if len(input_dict) != 1:
raise ValueError("Undefined Operation encountered!")
if len(target_patterns) > 1:
if len(source_patterns) == 1:
return source_patterns[0]
raise ValueError("Undefined Operation encountered!")
return target_patterns[0]
@property
def reverse_op(self) -> ConversionOps:
return ConcatenatedToInterleaved(self.dim)
def register_sonicmoe_weight_converter(model_type: str):
"""Override the conversion mapping to add interleave step for gate_up_proj.
Appends a ConcatenatedToInterleaved operation to the existing gate_up_proj
converter chain. For example, qwen3_moe's chain becomes:
MergeModulelist(dim=0) -> Concatenate(dim=1) -> ConcatenatedToInterleaved(dim=1)
The reverse is auto-generated for saving:
InterleavedToConcatenated(dim=1) -> Chunk(dim=1) -> SplitModulelist(dim=0)
"""
from transformers.conversion_mapping import (
get_checkpoint_conversion_mapping,
register_checkpoint_conversion_mapping,
)
existing = get_checkpoint_conversion_mapping(model_type)
if existing is None:
LOG.warning(
f"No conversion mapping found for model type '{model_type}'. "
"SonicMoE weight interleaving will not be applied during checkpoint loading."
)
return
# Find the gate_up_proj converter and append ConcatenatedToInterleaved
patched = False
for converter in existing:
if hasattr(converter, "operations") and any(
"gate_up_proj" in pat for pat in converter.target_patterns
):
# Guard against double registration (e.g. plugin reloaded)
if any(
isinstance(op, ConcatenatedToInterleaved) for op in converter.operations
):
LOG.info(
f"SonicMoE weight converter already registered for '{model_type}'"
)
return
converter.operations.append(ConcatenatedToInterleaved(dim=1))
patched = True
break
if not patched:
LOG.warning(
f"Could not find gate_up_proj converter for model type '{model_type}'. "
"SonicMoE weight interleaving will not be applied during checkpoint loading."
)
return
register_checkpoint_conversion_mapping(model_type, existing, overwrite=True)
LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'")
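Registration is invoked from `patch_sonicmoe`, one call per model type; the resulting load chain matches the docstring above (a sketch; requires a transformers build with the conversion-mapping API):
```python
register_sonicmoe_weight_converter("qwen3_moe")
# gate_up_proj load chain: MergeModulelist(dim=0) -> Concatenate(dim=1) -> ConcatenatedToInterleaved(dim=1)

register_sonicmoe_weight_converter("qwen3_moe")
# Second call is a no-op: the idempotency guard logs that the converter is already registered.
```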