feat: add sonicmoe (#3411)

* feat: add sonicmoe

* feat: add torch compile for routing

* feat: add routing smoke test

* feat: add qwen3_5_moe, qwen3_vl_moe, qwen3_omni_moe

* fix: disable mlp kernel for sonicmoe too

* feat: update to sonicmoe release

* chore: update import following new sonicmoe changes

* feat: update handling for blackwell

* feat: add sonicmoe e2e test

* fix: installation for updated sonicmoe

* fix: git commit

* fix: ignore py req and fix metadata

* fix: increase min hidden size to match sonicmoe kernel min

* fix: attempt properly interleave and handle unpatch mid-test

* chore: refactor teardown better

* chore: refactor to re-use rearrange

* fix: add idempotency guard

* fix: address comments on CI memory and interleave

* fix: tests grad, param doublewrapped
This commit is contained in:
NanoCode012
2026-03-06 01:43:31 +07:00
committed by GitHub
parent 1eaf4d7418
commit 6a8baf8fa7
12 changed files with 1698 additions and 42 deletions

View File

@@ -10,7 +10,7 @@ class ExpertsInterface(GeneralInterface):
}
```
In our custom integration, we add support for **ScatterMoE**, which is even more efficient and faster than `grouped_mm`.
In our custom integration, we add support for **ScatterMoE** and **SonicMoE**, which are more efficient and faster than `grouped_mm`.
## Usage
@@ -21,23 +21,55 @@ plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
# Choose one (mutually exclusive):
use_scattermoe: true
# OR
use_sonicmoe: true
```
**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
**Important:** Setting `experts_implementation` is incompatible with custom kernel options.
### SonicMoE installation
**Prerequisites:**
- NVIDIA Hopper (H100, H200) or Blackwell (B200, GB200) GPU
- CUDA 12.9+ (13.0+ for B300)
- PyTorch 2.7+ (2.9.1 recommended)
- For B300: Triton 3.6.0
```bash
pip install --ignore-requires-python --no-deps "sonic-moe @ git+https://github.com/Dao-AILab/sonic-moe.git@116e2df0a41874f77fa0ad269ce7df3f0cfcb956" && pip install nvidia-cutlass-dsl==4.4.0 quack-kernels==0.2.5
```
See the [SonicMoE installation guide](https://github.com/Dao-AILab/sonic-moe?tab=readme-ov-file#-installation) for the latest prerequisite details.
**Note:** Blackwell support is in upstream beta. On Blackwell GPUs, Axolotl automatically sets `USE_QUACK_GEMM=1` to enable the Blackwell kernels.
## How It Works
The `KernelsPlugin` runs before model loading and:
1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
### ScatterMoE
1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels).
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
### SonicMoE
1. Resolves the model's MoE block class(es) from `constants.py`.
2. Patches the forward method with SonicMoE's optimized kernels and registers a weight converter for the interleaved gate/up projection format.
3. Supports both softmax->topk and sigmoid->topk routing strategies.
Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution.
#### Supported Models
See `constants.py` for the full list of supported model types (Qwen2-MoE, Qwen3-MoE, OLMoE, Mixtral, DeepSeek-V3, GLM-MoE, MiniMax, etc.).
## Limitations
ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
ScatterMoE uses softmax -> topk routing, so results may differ from the baseline for some model architectures (GPT-OSS, etc.). It is currently incompatible with `GLM_MOE_DSA` (GLM 5) and `GLM4_MOE_LITE` (GLM 4.7 Flash).
SonicMoE supports both softmax->topk and sigmoid->topk routing, covering a wider range of architectures.
ScatterMoE does not work for GLM 4.7 Flash (`glm4_moe_lite`) at the moment.

View File

@@ -6,7 +6,18 @@ LOG = get_logger(__name__)
class KernelsArgs(BaseModel):
use_scattermoe: bool | None = True
use_scattermoe: bool | None = None
use_sonicmoe: bool | None = None
@model_validator(mode="before")
@classmethod
def check_mutually_exclusive(cls, data):
    """Reject configs that enable both ScatterMoE and SonicMoE.

    Runs in pydantic "before" mode, so ``data`` is the raw input mapping.
    Returns the mapping unchanged when at most one of the two flags is set.
    """
    # Both kernels patch the same MoE block forward, so they cannot coexist.
    if data.get("use_scattermoe") and data.get("use_sonicmoe"):
        raise ValueError(
            "Cannot use both ScatterMoE and SonicMoE simultaneously. "
            "Please set only one of `use_scattermoe` or `use_sonicmoe` to true."
        )
    return data
@model_validator(mode="before")
@classmethod
@@ -36,11 +47,11 @@ class KernelsArgs(BaseModel):
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel_scattermoe(cls, data):
if data.get("use_scattermoe") is True:
def disable_mlp_kernel(cls, data):
if data.get("use_scattermoe") is True or data.get("use_sonicmoe") is True:
if data.get("lora_mlp_kernel") is True:
LOG.warning(
"Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
"Disabling lora_mlp_kernel when using custom MoE kernels due to compatibility issues."
)
data["lora_mlp_kernel"] = False
data["mlp_kernel"] = False

View File

@@ -0,0 +1,68 @@
"""
Supported MoE block mappings for kernel integrations.
Maps model_type to the SparseMoeBlock class name(s) in transformers.
Used by both ScatterMoE and SonicMoE kernel paths.
Values can be a single class name (str) or a list of class names for models
with multiple MoE block types (e.g. qwen3_omni_moe has Thinker + Talker).
"""
import importlib
SPARSE_MOE_BLOCK = {
    # softmax -> topk routing
    "qwen2_moe": "Qwen2MoeSparseMoeBlock",
    "qwen3_moe": "Qwen3MoeSparseMoeBlock",
    "qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
    "qwen3_next": "Qwen3NextSparseMoeBlock",
    "qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
    # qwen3_omni_moe: Thinker (standard) + Talker (shared experts + shared_expert_gate)
    "qwen3_omni_moe": [
        "Qwen3OmniMoeThinkerTextSparseMoeBlock",
        "Qwen3OmniMoeTalkerTextSparseMoeBlock",
    ],
    "olmoe": "OlmoeSparseMoeBlock",
    "mixtral": "MixtralSparseMoeBlock",
    "minimax": "MiniMaxSparseMoeBlock",
    # sigmoid -> topk routing (with group-based expert selection)
    "glm_moe_dsa": "GlmMoeDsaMoE",
    "deepseek_v3": "DeepseekV3MoE",
    "glm4_moe": "Glm4MoeMoE",
    "glm4_moe_lite": "Glm4MoeLiteMoE",
    "glm4v_moe": "Glm4vMoeTextMoE",
    # sigmoid -> topk routing (no group selection)
    "minimax_m2": "MiniMaxM2SparseMoeBlock",
    # Models below need custom routing (not yet implemented):
    # "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock",  # softmax->topk, e_score_correction_bias between softmax and topk
    # "deepseek_v2": "DeepseekV2Moe",  # softmax->topk, group_limited_greedy, different attr names (num_group)
    # "hunyuan_v1_moe": "HunYuanMoEV1Moe",  # softmax->topk, gate.wg (not gate.weight), scatter routing
    # "gpt_oss": "GptOssMLP",  # topk->softmax, transposed layout [E,H,2*I], custom GLU, expert biases
}


def resolve_moe_block_classes(model_type: str):
    """Resolve all MoE block classes from transformers for the given model type.

    Returns a list of classes (one for most models, multiple for models with
    distinct MoE block types like qwen3_omni_moe).

    Raises:
        ValueError: if ``model_type`` is unsupported or a listed class name is
            missing from the resolved transformers module.
    """
    entry = SPARSE_MOE_BLOCK.get(model_type)
    if entry is None:
        raise ValueError(
            f"Unsupported MoE model type '{model_type}'. "
            f"Supported types: {list(SPARSE_MOE_BLOCK.keys())}"
        )

    # Normalize the single-name shorthand to a list of class names.
    class_names = [entry] if isinstance(entry, str) else list(entry)
    module_path = f"transformers.models.{model_type}.modeling_{model_type}"
    module = importlib.import_module(module_path)

    resolved = []
    for class_name in class_names:
        if not hasattr(module, class_name):
            raise ValueError(f"Could not find class '{class_name}' in '{module_path}'")
        resolved.append(getattr(module, class_name))
    return resolved

View File

@@ -1,14 +1,59 @@
import importlib
import os
from pathlib import Path
from kernels import (
LocalLayerRepository,
Mode,
register_kernel_mapping,
replace_kernel_forward_from_hub,
)
import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _check_sonicmoe_gpu_compat():
"""Validate GPU compute capability for SonicMoE and configure env.
Supported: Hopper (sm_90), Blackwell (sm_100 - sm_103).
B300 (sm_103) additionally requires Triton 3.6.0.
"""
if not torch.cuda.is_available():
return
cc = torch.cuda.get_device_capability()
if cc < (9, 0):
raise RuntimeError(
f"SonicMoE requires Hopper (sm_90) or Blackwell (sm_100+) GPU, "
f"but detected sm_{cc[0]}{cc[1]}."
)
if cc > (10, 3):
raise RuntimeError(
f"SonicMoE does not yet support sm_{cc[0]}{cc[1]}. "
f"Supported: Hopper (sm_90) and Blackwell (sm_100 - sm_103)."
)
# Blackwell (sm_100+): enable QuACK GEMM kernels
if cc >= (10, 0):
os.environ.setdefault("USE_QUACK_GEMM", "1")
LOG.info(
f"Blackwell GPU (sm_{cc[0]}{cc[1]}) detected, enabling USE_QUACK_GEMM=1"
)
# B300 (sm_103): requires Triton 3.6.0
if cc == (10, 3):
triton_spec = importlib.util.find_spec("triton")
if triton_spec is None:
raise RuntimeError(
"B300 (sm_103) requires Triton 3.6.0, but Triton is not installed."
)
import triton
triton_version = tuple(int(x) for x in triton.__version__.split(".")[:2])
if triton_version != (3, 6):
raise RuntimeError(
f"B300 (sm_103) requires Triton 3.6.x, but found {triton.__version__}."
)
class KernelsPlugin(BasePlugin):
@@ -19,8 +64,32 @@ class KernelsPlugin(BasePlugin):
if cfg.use_scattermoe:
self._register_kernels()
self._kernelize_model(cfg.model_config_type)
elif cfg.use_sonicmoe:
if not importlib.util.find_spec("sonicmoe"):
raise RuntimeError(
"SonicMoE is not installed. See installation instructions at "
"https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/kernels/README.md#sonicmoe-installation"
)
_check_sonicmoe_gpu_compat()
from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
LOG.info(
f"Applying SonicMoE patches for model type: {cfg.model_config_type}"
)
patch_sonicmoe(
cfg.model_config_type,
torch_compile=bool(getattr(cfg, "torch_compile", False)),
)
def _register_kernels(self):
from kernels import (
LocalLayerRepository,
Mode,
register_kernel_mapping,
)
plugin_root = Path(__file__).parent
register_kernel_mapping(
{
@@ -42,25 +111,11 @@ class KernelsPlugin(BasePlugin):
)
def _kernelize_model(self, model_type: str):
if model_type == "olmoe":
from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
from kernels import replace_kernel_forward_from_hub
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
for model_moe_cls in resolve_moe_block_classes(model_type):
replace_kernel_forward_from_hub(
OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts"
model_moe_cls, "HFScatterMoEParallelExperts"
)
else:
try:
model_moe_cls = get_model_moe_block(model_type)
replace_kernel_forward_from_hub(
model_moe_cls, "HFScatterMoEParallelExperts"
)
except Exception as err:
raise ValueError(f"Unsupported model type: {model_type}") from err
def get_model_moe_block(model_type: str):
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}SparseMoeBlock"])
model_cls = getattr(module, f"{model_cls_prefix}SparseMoeBlock")
return model_cls

View File

@@ -0,0 +1,3 @@
from .patch import patch_sonicmoe
__all__ = ["patch_sonicmoe"]

View File

@@ -0,0 +1,213 @@
"""
SonicMoE patching for SparseMoeBlock forward pass.
Monkeypatches the SparseMoeBlock class for a given model type to use
SonicMoE's optimized kernels. Two forward paths are supported:
1. **General routing path** (routing_fn is not None):
Uses a custom routing function + ``moe_general_routing_inputs``.
Suitable for models with non-standard routing (softmax->topk, sigmoid->topk).
2. **Fused topk->softmax path** (routing_fn is None):
Uses ``moe_TC_softmax_topk_layer`` which fuses routing + expert computation.
Suitable for models with simple topk->softmax routing.
Weight format conversion (interleave/deinterleave) is handled by the
WeightConverter system, so the forward assumes weights are already in
interleaved format.
Shared experts are handled generically: if the block has a ``shared_expert``
or ``shared_experts`` attribute, its output is computed alongside the routed
experts and added to the final output. An optional ``shared_expert_gate``
applies sigmoid gating to the shared expert contribution.
"""
import torch
import torch.nn.functional as F
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def patch_sonicmoe(model_type: str, torch_compile: bool = False):
    """Patch the model's SparseMoeBlock class(es) to use SonicMoE kernels.

    Args:
        model_type: The HuggingFace model type (e.g. "qwen3_moe").
        torch_compile: If True, wrap the routing function with torch.compile
            so softmax/topk/renorm fuse into fewer kernel launches.
    """
    from .routing import get_model_moe_config
    from .weight_converter import register_sonicmoe_weight_converter

    routing_fn, activation, router_attr = get_model_moe_config(model_type)

    # Only the general-routing path has a standalone routing fn to compile;
    # the fused path (routing_fn is None) performs routing inside the kernel.
    if routing_fn is not None and torch_compile:
        routing_fn = _try_compile_routing(routing_fn)

    for block_cls in resolve_moe_block_classes(model_type):
        _patch_forward(block_cls, routing_fn, activation, router_attr)

    register_sonicmoe_weight_converter(model_type)
def _try_compile_routing(routing_fn):
    """Attempt to torch.compile the routing function, fall back to eager on failure."""
    fn_name = routing_fn.__name__
    try:
        compiled = torch.compile(routing_fn, mode="reduce-overhead", dynamic=False)
    except Exception as exc:  # pylint: disable=broad-except
        # Best-effort optimization: compilation problems must never abort training.
        LOG.warning(
            f"torch.compile failed for routing function {fn_name}, "
            f"falling back to eager: {exc}"
        )
        return routing_fn
    LOG.info(f"torch.compile enabled for routing function: {fn_name}")
    return compiled
def _patch_forward(moe_cls, routing_fn, activation, router_attr):
    """Monkeypatch a SparseMoeBlock class with a SonicMoE forward.

    Idempotent: a class that already carries ``_original_forward`` is left
    untouched. The pre-patch forward is stashed on the class so it can be
    restored later (e.g. by tests).

    The patched forward handles shared experts generically: if
    ``self.shared_expert`` or ``self.shared_experts`` exists, it is computed
    and added to the routed output. If ``self.shared_expert_gate`` also
    exists, sigmoid gating is applied to that contribution (as in qwen2_moe).

    Args:
        moe_cls: The SparseMoeBlock class to patch.
        routing_fn: Routing function (e.g. softmax_topk_routing), or None
            for the fused moe_TC_softmax_topk_layer path.
        activation: SonicMoE ActivationType enum value.
        router_attr: Name of the router module attribute on the MoE block.
    """
    if hasattr(moe_cls, "_original_forward"):
        LOG.info(f"{moe_cls.__name__}.forward already patched with SonicMoE, skipping")
        return

    previous_forward = moe_cls.forward
    if routing_fn is None:
        _make_fused_forward(moe_cls, activation, router_attr)
    else:
        _make_general_forward(moe_cls, routing_fn, activation)
    moe_cls._original_forward = previous_forward
    LOG.info(f"Patched {moe_cls.__name__}.forward with SonicMoE implementation")
def _make_general_forward(moe_cls, routing_fn, activation):
    """Create forward using routing_fn + moe_general_routing_inputs.

    Installs a closure on ``moe_cls`` that flattens tokens, routes them with
    ``routing_fn``, and dispatches to SonicMoE's general-routing kernel.

    NOTE(review): the router logits returned by ``routing_fn`` are discarded
    here and the patched forward returns a plain tensor — assumes callers of
    the original forward accept that return shape and that no aux-loss
    consumer needs the logits; confirm per patched model.
    """

    def sonicmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Imported lazily so this module imports cleanly without sonicmoe installed.
        from sonicmoe import moe_general_routing_inputs

        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states_flat = hidden_states.view(-1, hidden_dim)
        # Shared expert (computed early, matching original model ordering)
        shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
        # Routing
        router_scores, token_indices, expert_indices, _router_logits = routing_fn(
            hidden_states_flat, self
        )
        # Permute weights to SonicMoE layout:
        # gate_up: [E, 2*I, H] -> [2*I, H, E]
        # down: [E, H, I] -> [H, I, E]
        gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
        down_weight = self.experts.down_proj.permute(1, 2, 0)
        E = gate_up_weight.shape[-1]
        # Positional argument order is part of the SonicMoE kernel ABI — do
        # not reorder.
        output, _ = moe_general_routing_inputs(
            hidden_states_flat,
            router_scores,
            token_indices,
            expert_indices,
            gate_up_weight,
            None,  # b1 (no gate/up bias)
            down_weight,
            None,  # b2 (no down bias)
            E,
            torch.cuda.current_stream().cuda_stream,
            activation,
            False,  # is_inference_mode
        )
        # Add shared expert contribution if present
        if shared_expert_output is not None:
            if hasattr(self, "shared_expert_gate"):
                # Sigmoid-gated shared expert (qwen2_moe-style blocks).
                shared_expert_output = (
                    F.sigmoid(self.shared_expert_gate(hidden_states_flat))
                    * shared_expert_output
                )
            output = output + shared_expert_output
        return output.view(batch_size, sequence_length, hidden_dim)

    moe_cls.forward = sonicmoe_forward
def _make_fused_forward(moe_cls, activation, router_attr):
    """Create forward using moe_TC_softmax_topk_layer (topk -> softmax).

    The fused kernel computes routing internally from the router's weight
    matrix, so no separate routing function is needed.

    NOTE(review): the kernel receives only ``router.weight`` and
    ``router.top_k`` — any router bias would be ignored; confirm the target
    model's router is bias-free.
    """

    def sonicmoe_fused_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Imported lazily so this module imports cleanly without sonicmoe installed.
        from sonicmoe import moe_TC_softmax_topk_layer

        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states_flat = hidden_states.view(-1, hidden_dim)
        # Shared expert (computed early, matching original model ordering)
        shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
        router = getattr(self, router_attr)
        # Permute weights to SonicMoE layout:
        # gate_up: [E, 2*I, H] -> [2*I, H, E]
        # down: [E, H, I] -> [H, I, E]
        gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
        down_weight = self.experts.down_proj.permute(1, 2, 0)
        # Positional argument order is part of the SonicMoE kernel ABI — do
        # not reorder.
        output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer(
            hidden_states_flat,
            router.weight,
            gate_up_weight,
            None,  # b1 (no gate/up bias)
            down_weight,
            None,  # b2 (no down bias)
            router.top_k,
            torch.cuda.current_stream().cuda_stream,
            activation,
            False,  # is_inference_mode
        )
        # Add shared expert contribution if present
        if shared_expert_output is not None:
            if hasattr(self, "shared_expert_gate"):
                # Sigmoid-gated shared expert (qwen2_moe-style blocks).
                shared_expert_output = (
                    F.sigmoid(self.shared_expert_gate(hidden_states_flat))
                    * shared_expert_output
                )
            output = output + shared_expert_output
        return output.view(batch_size, sequence_length, hidden_dim)

    moe_cls.forward = sonicmoe_fused_forward
def _compute_shared_expert(moe_block, hidden_states_flat):
"""Compute shared expert output if the block has one.
Handles singular (qwen2_moe: ``shared_expert``), plural
(glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP
(hunyuan_v1_moe: ``shared_mlp``) attribute names.
"""
shared_expert = (
getattr(moe_block, "shared_expert", None)
or getattr(moe_block, "shared_experts", None)
or getattr(moe_block, "shared_mlp", None)
)
if shared_expert is not None:
return shared_expert(hidden_states_flat)
return None

View File

@@ -0,0 +1,219 @@
"""
Routing functions for SonicMoE integration.
Different MoE architectures use different routing strategies:
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
"""
import torch
import torch.nn.functional as F
def get_model_moe_config(model_type: str):
    """Returns (routing_fn, activation, router_attr) for a given model type.

    Args:
        model_type: HuggingFace model type string.

    Returns:
        routing_fn: Callable or None. None signals the fused
            moe_TC_softmax_topk_layer path (topk -> softmax models).
        activation: SonicMoE ActivationType enum value.
        router_attr: Name of the router module attribute on the MoE block
            (e.g. "gate" or "router").

    The activation type cannot be derived from config.hidden_act because
    e.g. qwen3_moe reports "silu" but architecturally uses SwiGLU
    (act_fn(gate) * up pattern). So we specify it per model type.
    """
    from sonicmoe.enums import ActivationType

    softmax_topk_models = frozenset((
        "qwen2_moe",
        "qwen3_moe",
        "qwen3_5_moe",
        "qwen3_next",
        "qwen3_vl_moe",
        "qwen3_omni_moe",
        "olmoe",
        "mixtral",
        "minimax",
    ))
    sigmoid_topk_models = frozenset((
        "glm_moe_dsa",
        "deepseek_v3",
        "glm4_moe",
        "glm4_moe_lite",
        "glm4v_moe",
        "minimax_m2",
    ))

    if model_type in softmax_topk_models:
        return softmax_topk_routing, ActivationType.SWIGLU, "gate"
    if model_type in sigmoid_topk_models:
        return sigmoid_topk_routing, ActivationType.SWIGLU, "gate"

    # Not yet implemented (custom routing required):
    # - ernie4_5_moe: softmax->topk with e_score_correction_bias applied
    #   between softmax and topk.
    # - deepseek_v2: softmax->topk with group_limited_greedy; different attr
    #   names (num_group, nn.Linear gate).
    # - hunyuan_v1_moe: gate.wg (not gate.weight), top_k on the block,
    #   scatter routing matrix.
    # - gpt_oss (would use the fused topk->softmax path, routing_fn=None):
    #   router bias is ignored by moe_TC_softmax_topk_layer; transposed
    #   [E, H, 2*I] layout and custom GLU.
    raise ValueError(f"SonicMoE: unsupported model type '{model_type}'")
def softmax_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Qwen3/Qwen2-style routing: softmax -> topk -> optional renorm.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate.*)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] token index per entry (int32), sorted ascending
        expert_indices: [T*K] expert index per entry (int32)
        router_logits: [T, E] original logits for aux loss
    """
    gate = moe_block.gate
    num_tokens = hidden_states.shape[0]
    top_k = gate.top_k

    # Router logits and full softmax over experts (float32 for stability,
    # matching the transformers reference implementation).
    router_logits = F.linear(hidden_states, gate.weight)  # [T, E]
    probs = router_logits.softmax(dim=-1, dtype=torch.float32)  # [T, E]

    # Per-token top-k expert selection.
    top_values, top_indices = probs.topk(top_k, dim=-1)  # [T, K] each

    # Renormalize if configured; models lacking the attribute (Mixtral,
    # MiniMax) always normalize, hence the True default.
    if getattr(gate, "norm_topk_prob", True):
        top_values = top_values / top_values.sum(dim=-1, keepdim=True)

    # SonicMoE requires token indices sorted ascending ([0,0,...,1,1,...]);
    # repeat_interleave over arange yields exactly that layout. Expert-side
    # sorting happens inside SonicMoE's router metadata step.
    token_indices = torch.arange(
        num_tokens, device=hidden_states.device, dtype=torch.int32
    ).repeat_interleave(top_k)

    return (
        top_values.reshape(-1),
        token_indices,
        top_indices.to(torch.int32).reshape(-1),
        router_logits,
    )
def sigmoid_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sigmoid-based routing: sigmoid -> optional group selection -> topk.

    Supports two variants:
    - **Group selection** (glm_moe_dsa, deepseek_v3, etc.): n_group > 1,
      bias on gate, group-based masking before topk.
    - **No group selection** (minimax_m2): n_group == 1 (or absent),
      bias on moe_block, straight topk over all experts.

    Final routing weights come from the original sigmoid scores (not the
    bias-corrected ones), with optional renormalization and scaling.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate.* and optional
            moe_block.n_group, .topk_group, .top_k, .norm_topk_prob,
            .routed_scaling_factor, .n_routed_experts)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] token index per entry (int32), sorted ascending
        expert_indices: [T*K] expert index per entry (int32)
        router_logits: [T, E] original logits for aux loss

    Raises:
        AttributeError: if neither the gate nor the block carries
            ``e_score_correction_bias``.
    """
    gate = moe_block.gate
    num_tokens = hidden_states.shape[0]
    top_k = moe_block.top_k
    num_experts = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
    num_groups = getattr(moe_block, "n_group", 1)

    # Logits and sigmoid probabilities in float32.
    logits = F.linear(hidden_states.float(), gate.weight.float())  # [T, E]
    probs = logits.sigmoid()  # [T, E]

    # Bias-corrected scores drive expert *selection* only, never the final
    # weights. glm_moe_dsa/deepseek_v3 keep the bias on the gate; minimax_m2
    # keeps it on the block.
    bias = getattr(gate, "e_score_correction_bias", None)
    if bias is None:
        bias = getattr(moe_block, "e_score_correction_bias", None)
    if bias is None:
        raise AttributeError(
            f"sigmoid_topk_routing requires e_score_correction_bias on "
            f"gate ({type(gate)}) or moe_block ({type(moe_block)}), but neither has it"
        )
    selection_scores = probs + bias

    # Group-based selection: keep the best topk_group groups (each scored by
    # the sum of its two strongest experts) and zero out the rest.
    if num_groups > 1:
        experts_per_group = num_experts // num_groups
        group_scores = (
            selection_scores.view(-1, num_groups, experts_per_group)
            .topk(2, dim=-1)
            .values.sum(dim=-1)
        )  # [T, n_group]
        keep_groups = group_scores.topk(
            moe_block.topk_group, dim=-1, sorted=False
        ).indices
        group_mask = torch.zeros_like(group_scores).scatter_(1, keep_groups, 1)
        expert_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, num_groups, experts_per_group)
            .reshape(-1, num_experts)
        )
        selection_scores = selection_scores.masked_fill(~expert_mask.bool(), 0.0)

    # Final topk over the (possibly masked) selection scores; weights are
    # gathered from the original sigmoid probabilities.
    chosen = torch.topk(selection_scores, k=top_k, dim=-1, sorted=False).indices
    weights = probs.gather(1, chosen)

    # Optional renormalization + scaling.
    if getattr(moe_block, "norm_topk_prob", True):
        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20)
    weights = weights * getattr(moe_block, "routed_scaling_factor", 1.0)

    # Token indices sorted ascending, as SonicMoE requires.
    token_indices = torch.arange(
        num_tokens, device=hidden_states.device, dtype=torch.int32
    ).repeat_interleave(top_k)

    return (
        weights.to(torch.float32).reshape(-1),
        token_indices,
        chosen.to(torch.int32).reshape(-1),
        logits,
    )

View File

@@ -0,0 +1,181 @@
"""
Custom WeightConverter operations for SonicMoE weight format conversion.
SonicMoE requires gate_up_proj weights in interleaved format:
- Standard (concatenated): [E, 2*I, H] where first I rows are gate, last I rows are up
- SonicMoE (interleaved): [E, 2*I, H] where rows alternate [g0, u0, g1, u1, ...]
These ConversionOps integrate with transformers' WeightConverter system so that
weights are transparently converted during loading and reverted during saving.
"""
from typing import Any
import torch
from einops import rearrange
from transformers.core_model_loading import ConversionOps
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def interleave_gate_up(tensor: torch.Tensor) -> torch.Tensor:
    """[gate..., up...] -> [g0, u0, g1, u1, ...] along the 2*I dimension.

    Pure-torch equivalent of einops' "... (two out) h -> ... (out two) h"
    with two=2: split the second-to-last axis into (2, I), swap the factors
    to (I, 2), and merge them back.
    """
    *lead, two_i, h = tensor.shape
    half = two_i // 2
    return (
        tensor.reshape(*lead, 2, half, h)
        .transpose(-3, -2)
        .reshape(*lead, two_i, h)
    )
def deinterleave_gate_up(tensor: torch.Tensor) -> torch.Tensor:
    """[g0, u0, g1, u1, ...] -> [gate..., up...] along the 2*I dimension.

    Pure-torch equivalent of einops' "... (out two) h -> ... (two out) h"
    with two=2: split the second-to-last axis into (I, 2), swap the factors
    to (2, I), and merge them back.
    """
    *lead, two_i, h = tensor.shape
    half = two_i // 2
    return (
        tensor.reshape(*lead, half, 2, h)
        .transpose(-3, -2)
        .reshape(*lead, two_i, h)
    )
class ConcatenatedToInterleaved(ConversionOps):
    """Convert concatenated gate/up projections to interleaved format.

    Input: [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H]
    Output: [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...]

    This operation is applied along ``dim`` (default 1, the 2*I dimension).
    """

    def __init__(self, dim: int = 1):
        # Stored for symmetry with other ConversionOps and for reverse_op;
        # interleave_gate_up itself always rearranges the second-to-last axis.
        self.dim = dim

    @torch.no_grad()
    def convert(
        self,
        input_dict: dict[str, Any],
        source_patterns: list[str],
        target_patterns: list[str],
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Interleave the single tensor in ``input_dict``, keyed by the target pattern."""
        target_pattern = self._get_target_pattern(
            input_dict, source_patterns, target_patterns
        )
        # Exactly one entry is expected (validated in _get_target_pattern);
        # callers may pass the tensor wrapped in a one-element list.
        tensors = next(iter(input_dict.values()))
        tensor = tensors[0] if isinstance(tensors, list) else tensors
        interleaved = interleave_gate_up(tensor)
        return {target_pattern: interleaved}

    def _get_target_pattern(
        self,
        input_dict: dict[str, Any],
        source_patterns: list[str],
        target_patterns: list[str],
    ) -> str:
        """Pick the output key for the converted tensor.

        # Follow the same logic as Transpose.get_target_pattern
        """
        if len(input_dict) != 1:
            raise ValueError("Undefined Operation encountered!")
        if len(target_patterns) > 1:
            # Ambiguous targets: fall back to a unique source pattern if any.
            if len(source_patterns) == 1:
                return source_patterns[0]
            raise ValueError("Undefined Operation encountered!")
        return target_patterns[0]

    @property
    def reverse_op(self) -> ConversionOps:
        # Used by transformers when saving: undo the interleave.
        return InterleavedToConcatenated(self.dim)
class InterleavedToConcatenated(ConversionOps):
    """Convert interleaved gate/up projections back to concatenated format.

    Input: [E, 2*I, H] with rows alternating [g0, u0, g1, u1, ...]
    Output: [E, 2*I, H] with gate=[E, :I, H] and up=[E, I:, H]

    This is the reverse of ``ConcatenatedToInterleaved``.
    """

    def __init__(self, dim: int = 1):
        # Stored for symmetry with other ConversionOps and for reverse_op;
        # deinterleave_gate_up itself always rearranges the second-to-last axis.
        self.dim = dim

    @torch.no_grad()
    def convert(
        self,
        input_dict: dict[str, Any],
        source_patterns: list[str],
        target_patterns: list[str],
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """De-interleave the single tensor in ``input_dict``, keyed by the target pattern."""
        target_pattern = self._get_target_pattern(
            input_dict, source_patterns, target_patterns
        )
        # Exactly one entry is expected (validated in _get_target_pattern);
        # callers may pass the tensor wrapped in a one-element list.
        tensors = next(iter(input_dict.values()))
        tensor = tensors[0] if isinstance(tensors, list) else tensors
        concatenated = deinterleave_gate_up(tensor)
        return {target_pattern: concatenated}

    def _get_target_pattern(
        self,
        input_dict: dict[str, Any],
        source_patterns: list[str],
        target_patterns: list[str],
    ) -> str:
        """Pick the output key for the converted tensor (mirrors the forward op)."""
        if len(input_dict) != 1:
            raise ValueError("Undefined Operation encountered!")
        if len(target_patterns) > 1:
            # Ambiguous targets: fall back to a unique source pattern if any.
            if len(source_patterns) == 1:
                return source_patterns[0]
            raise ValueError("Undefined Operation encountered!")
        return target_patterns[0]

    @property
    def reverse_op(self) -> ConversionOps:
        # Used by transformers when loading an already-interleaved checkpoint.
        return ConcatenatedToInterleaved(self.dim)
def register_sonicmoe_weight_converter(model_type: str):
    """Override the conversion mapping to add interleave step for gate_up_proj.

    Appends a ConcatenatedToInterleaved operation to the existing gate_up_proj
    converter chain. For example, qwen3_moe's chain becomes:

        MergeModulelist(dim=0) -> Concatenate(dim=1) -> ConcatenatedToInterleaved(dim=1)

    The reverse is auto-generated for saving:

        InterleavedToConcatenated(dim=1) -> Chunk(dim=1) -> SplitModulelist(dim=0)

    Args:
        model_type: HuggingFace model type whose checkpoint conversion
            mapping should gain the interleave step. Logs a warning and
            returns (no-op) when no suitable mapping/converter exists.
    """
    from transformers.conversion_mapping import (
        get_checkpoint_conversion_mapping,
        register_checkpoint_conversion_mapping,
    )

    existing = get_checkpoint_conversion_mapping(model_type)
    if existing is None:
        LOG.warning(
            f"No conversion mapping found for model type '{model_type}'. "
            "SonicMoE weight interleaving will not be applied during checkpoint loading."
        )
        return

    # Find the gate_up_proj converter and append ConcatenatedToInterleaved.
    # Guard on both attributes: a converter exposing `operations` but not
    # `target_patterns` would otherwise raise AttributeError here.
    patched = False
    for converter in existing:
        if (
            hasattr(converter, "operations")
            and hasattr(converter, "target_patterns")
            and any("gate_up_proj" in pat for pat in converter.target_patterns)
        ):
            # Guard against double registration (e.g. plugin reloaded)
            if any(
                isinstance(op, ConcatenatedToInterleaved) for op in converter.operations
            ):
                LOG.info(
                    f"SonicMoE weight converter already registered for '{model_type}'"
                )
                return
            converter.operations.append(ConcatenatedToInterleaved(dim=1))
            patched = True
            break
    if not patched:
        LOG.warning(
            f"Could not find gate_up_proj converter for model type '{model_type}'. "
            "SonicMoE weight interleaving will not be applied during checkpoint loading."
        )
        return

    register_checkpoint_conversion_mapping(model_type, existing, overwrite=True)
    LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'")

View File

@@ -0,0 +1,288 @@
"""
End-to-end gradient and convergence tests for SonicMoE integration.
Requires:
- H100/H200 GPU (SonicMoE CUTLASS kernels target sm_90)
- sonicmoe package installed
- transformers with Qwen3MoE support
Usage:
pytest tests/e2e/integrations/test_sonicmoe.py -v -s
"""
import importlib.util
import math
import pytest
import torch
_sonicmoe_available = importlib.util.find_spec("sonicmoe") is not None
_has_cuda = torch.cuda.is_available()
_is_hopper = _has_cuda and torch.cuda.get_device_capability() == (9, 0)

# Every test in this module requires CUDA, a Hopper GPU, and the sonicmoe package.
pytestmark = [
    pytest.mark.skipif(not _has_cuda, reason="Requires CUDA GPU"),
    pytest.mark.skipif(
        not _is_hopper, reason="SonicMoE CUTLASS kernels require Hopper (sm_90)"
    ),
    pytest.mark.skipif(not _sonicmoe_available, reason="SonicMoE not installed"),
]
def _create_tiny_qwen3_config():
    """Create a minimal Qwen3MoE config for fast testing."""
    from transformers import AutoConfig

    # Small enough to fit comfortably in CI memory while still exercising MoE.
    overrides = {
        "hidden_size": 512,
        "intermediate_size": 1024,
        "moe_intermediate_size": 64,
        "num_attention_heads": 16,
        "num_key_value_heads": 2,
        "head_dim": 32,
        "num_hidden_layers": 2,
        "num_experts": 8,
        "num_experts_per_tok": 2,
        "vocab_size": 1000,
        "max_position_embeddings": 128,
        "norm_topk_prob": True,
        "torch_dtype": torch.bfloat16,
    }
    config = AutoConfig.for_model("qwen3_moe")
    for attr, value in overrides.items():
        setattr(config, attr, value)
    return config
def _interleave_gate_up_weights(model):
    """Interleave all gate_up_proj parameters in-place for SonicMoE."""
    from axolotl.integrations.kernels.sonicmoe.weight_converter import (
        interleave_gate_up,
    )

    gate_up_params = (
        param for name, param in model.named_parameters() if "gate_up_proj" in name
    )
    with torch.no_grad():
        for param in gate_up_params:
            param.copy_(interleave_gate_up(param))
def _unpatch_sonicmoe():
    """Restore original forward on the MoE block class if it was patched."""
    from axolotl.integrations.kernels.constants import resolve_moe_block_classes

    for moe_cls in resolve_moe_block_classes("qwen3_moe"):
        # Only touch classes that were actually patched (idempotent teardown).
        if not hasattr(moe_cls, "_original_forward"):
            continue
        moe_cls.forward = moe_cls._original_forward
        del moe_cls._original_forward
class TestSonicMoEForwardCorrectness:
    """Verify SonicMoE-patched model produces same output as original."""

    def teardown_method(self):
        # Always restore the stock forward, even if the test body failed mid-way.
        _unpatch_sonicmoe()

    def test_forward_output_matches(self):
        from transformers import AutoModelForCausalLM
        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        tokens = torch.randint(0, config.vocab_size, (1, 16), device="cuda")

        # Reference forward pass through the unmodified implementation.
        reference = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        with torch.no_grad():
            ref_out = reference(tokens)

        # Same weights, routed through the SonicMoE kernel path
        # (patch first, then interleave gate_up for the kernel layout).
        candidate = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        candidate.load_state_dict(reference.state_dict())
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(candidate)
        with torch.no_grad():
            sonic_out = candidate(tokens)

        worst = (ref_out.logits - sonic_out.logits).abs().max().item()
        assert torch.allclose(
            ref_out.logits, sonic_out.logits, atol=1e-1, rtol=1e-1
        ), f"Output mismatch: max diff={worst:.6f}"
class TestSonicMoEGradientCorrectness:
    """Compare gradients between original HuggingFace and SonicMoE-patched forward."""

    def teardown_method(self):
        # Unpatch even if an assertion fired mid-test, so later tests see the
        # stock forward.
        _unpatch_sonicmoe()

    def test_gradients_match(self):
        """Verify all parameter gradients match between original and patched."""
        from transformers import AutoModelForCausalLM
        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
        from axolotl.integrations.kernels.sonicmoe.weight_converter import (
            deinterleave_gate_up,
        )

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
        # ---------- Original model ----------
        model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        out_orig = model_orig(input_ids, labels=input_ids)
        out_orig.loss.backward()
        # Snapshot grads in float32 so comparisons aren't further degraded by bf16.
        grads_orig = {
            n: p.grad.float().clone()
            for n, p in model_orig.named_parameters()
            if p.grad is not None
        }
        loss_orig = out_orig.loss.item()
        # ---------- SonicMoE-patched model (same weights, interleaved) ----------
        # Order matters: copy weights, patch the forward, then interleave
        # gate_up_proj into the layout the SonicMoE kernel expects.
        model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        model_patched.load_state_dict(model_orig.state_dict())
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model_patched)
        out_patched = model_patched(input_ids, labels=input_ids)
        out_patched.loss.backward()
        grads_patched = {}
        for n, p in model_patched.named_parameters():
            if p.grad is None:
                continue
            g = p.grad.float().clone()
            # gate_up_proj grads are in interleaved layout, de-interleave to match orig
            if "gate_up_proj" in n:
                g = deinterleave_gate_up(g)
            grads_patched[n] = g
        loss_patched = out_patched.loss.item()
        # ---------- Compare ----------
        assert abs(loss_orig - loss_patched) < 0.5, (
            f"Loss mismatch: orig={loss_orig:.4f}, patched={loss_patched:.4f}"
        )
        # All parameters with gradients in original should have them in patched
        missing = set(grads_orig.keys()) - set(grads_patched.keys())
        assert not missing, f"Missing gradients in patched model: {missing}"
        # Compare gradient values
        # bf16 with different GEMM impls (cuBLAS vs CUTLASS) can diverge,
        # so use generous tolerance: flag only if both rel >10% AND abs >1e-2
        mismatches = []
        for name in grads_orig:
            if name not in grads_patched:
                continue
            g_orig = grads_orig[name]
            g_patched = grads_patched[name]
            max_diff = (g_orig - g_patched).abs().max().item()
            rel_diff = max_diff / (g_orig.abs().max().item() + 1e-8)
            if rel_diff > 0.1 and max_diff > 1e-2:
                mismatches.append(
                    f" {name}: max_abs_diff={max_diff:.6f}, rel_diff={rel_diff:.4f}"
                )
        assert not mismatches, (
            "Gradient mismatches (rel_diff > 10% and abs_diff > 1e-2):\n"
            + "\n".join(mismatches)
        )

    def test_router_weights_receive_gradients(self):
        """Verify that router (gate) weights get non-zero gradients."""
        from transformers import AutoModelForCausalLM
        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model)
        out = model(input_ids, labels=input_ids)
        out.loss.backward()
        # Router grads are easy to silently lose when the forward is replaced,
        # so check every gate weight explicitly.
        gate_grads_found = False
        for name, param in model.named_parameters():
            if "gate" in name and "weight" in name:
                gate_grads_found = True
                assert param.grad is not None, f"No gradient for router: {name}"
                assert param.grad.abs().max() > 0, f"Zero gradient for router: {name}"
        assert gate_grads_found, "No gate.weight parameters found in model"
class TestSonicMoETrainingConvergence:
    """Verify loss decreases during training with SonicMoE."""

    def teardown_method(self):
        _unpatch_sonicmoe()

    def test_loss_decreases(self):
        """Run 30 training steps, verify loss decreases and no NaN/Inf."""
        from transformers import AutoModelForCausalLM
        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        batch = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

        history = []
        for step in range(30):
            loss = model(batch, labels=batch).loss
            value = loss.item()
            assert not math.isnan(value), f"NaN loss at step {step}"
            assert not math.isinf(value), f"Inf loss at step {step}"
            history.append(value)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        assert history[-1] < history[0], (
            f"Loss did not decrease: first={history[0]:.4f}, last={history[-1]:.4f}"
        )

    def test_expert_weights_update(self):
        """Verify expert weights change during training (not frozen)."""
        from transformers import AutoModelForCausalLM
        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        batch = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model)

        # Snapshot expert weights before training
        snapshot = {
            name: param.data.clone()
            for name, param in model.named_parameters()
            if "experts" in name
        }
        assert snapshot, "No expert parameters found"

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
        for _ in range(5):
            model(batch, labels=batch).loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Check that expert weights changed
        changed = sum(
            1
            for name, param in model.named_parameters()
            if name in snapshot and not torch.equal(param.data, snapshot[name])
        )
        assert changed > 0, "No expert weights changed after 5 training steps"

View File

@@ -6,7 +6,7 @@
Unit tests for scattermoe-lora code-review fixes.
Tests cover:
- KernelsArgs validator: disable_mlp_kernel_scattermoe
- KernelsArgs validator: disable_mlp_kernel
- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
- ParallelExperts: scaling=0.0 not treated as falsy
- single2scatter: non-aligned K/N dimensions
@@ -20,12 +20,12 @@ import pytest
import torch
# ============================================================================
# 1. KernelsArgs: disable_mlp_kernel_scattermoe validator
# 1. KernelsArgs: disable_mlp_kernel validator
# ============================================================================
class TestKernelsArgsValidator:
"""Test that disable_mlp_kernel_scattermoe sets both flags correctly.
"""Test that disable_mlp_kernel sets both flags correctly.
These tests call the validator classmethod directly on raw dicts,
since lora_mlp_kernel / mlp_kernel are not declared model fields.
@@ -40,7 +40,7 @@ class TestKernelsArgsValidator:
"use_scattermoe": True,
"lora_mlp_kernel": True,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
result = KernelsArgs.disable_mlp_kernel(data)
assert result["lora_mlp_kernel"] is False
assert result["mlp_kernel"] is False
@@ -52,7 +52,7 @@ class TestKernelsArgsValidator:
"use_kernels": True,
"use_scattermoe": True,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
result = KernelsArgs.disable_mlp_kernel(data)
assert result["mlp_kernel"] is False
# lora_mlp_kernel was not in data, should not be added
assert "lora_mlp_kernel" not in result
@@ -66,7 +66,7 @@ class TestKernelsArgsValidator:
"use_scattermoe": True,
"lora_mlp_kernel": False,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
result = KernelsArgs.disable_mlp_kernel(data)
assert result["lora_mlp_kernel"] is False
def test_no_change_when_scattermoe_disabled(self):
@@ -78,7 +78,7 @@ class TestKernelsArgsValidator:
"use_scattermoe": False,
"lora_mlp_kernel": True,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
result = KernelsArgs.disable_mlp_kernel(data)
assert result["lora_mlp_kernel"] is True

View File

@@ -0,0 +1,428 @@
"""Unit tests for the SonicMoE integration."""
from types import SimpleNamespace
import pytest
import torch
from axolotl.integrations.kernels.args import KernelsArgs
from axolotl.integrations.kernels.sonicmoe.routing import (
sigmoid_topk_routing,
softmax_topk_routing,
)
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
ConcatenatedToInterleaved,
InterleavedToConcatenated,
register_sonicmoe_weight_converter,
)
class TestKernelsArgs:
    """Validate mutual exclusivity and the MLP-kernel side effect of the flags."""

    def test_mutual_exclusivity_raises(self):
        payload = {"use_scattermoe": True, "use_sonicmoe": True}
        with pytest.raises(ValueError, match="Cannot use both"):
            KernelsArgs.model_validate(payload)

    def test_sonicmoe_only(self):
        args = KernelsArgs.model_validate({"use_sonicmoe": True})
        assert args.use_sonicmoe is True
        assert args.use_scattermoe is None

    def test_scattermoe_only(self):
        args = KernelsArgs.model_validate({"use_scattermoe": True})
        assert args.use_scattermoe is True
        assert args.use_sonicmoe is None

    def test_neither_set(self):
        args = KernelsArgs.model_validate({})
        assert args.use_scattermoe is None
        assert args.use_sonicmoe is None

    def test_disables_mlp_kernel_when_sonicmoe(self):
        out = KernelsArgs.disable_mlp_kernel(
            {"use_sonicmoe": True, "lora_mlp_kernel": True}
        )
        assert out["lora_mlp_kernel"] is False
        assert out["mlp_kernel"] is False
class TestConcatenatedToInterleaved:
    @pytest.fixture
    def sample_tensor(self):
        """Create a test tensor [E=2, 2*I=4, H=3] with distinct gate/up values."""
        n_exp, inner, hidden = 2, 2, 3
        total = n_exp * inner * hidden
        gate = torch.arange(1, total + 1, dtype=torch.float32).reshape(
            n_exp, inner, hidden
        )
        up = torch.arange(100, 100 + total, dtype=torch.float32).reshape(
            n_exp, inner, hidden
        )
        return torch.cat([gate, up], dim=1)

    def test_interleave_rows_alternate(self, sample_tensor):
        op = ConcatenatedToInterleaved(dim=1)
        interleaved = op.convert(
            {"test": sample_tensor},
            source_patterns=["test"],
            target_patterns=["test"],
        )["test"]
        # For expert 0: even rows should be gate, odd rows should be up
        inner = sample_tensor.shape[1] // 2
        assert torch.equal(interleaved[:, 0::2, :], sample_tensor[:, :inner, :])
        assert torch.equal(interleaved[:, 1::2, :], sample_tensor[:, inner:, :])

    def test_interleave_handles_list_input(self, sample_tensor):
        op = ConcatenatedToInterleaved(dim=1)
        converted = op.convert(
            {"test": [sample_tensor]},
            source_patterns=["test"],
            target_patterns=["test"],
        )
        assert converted["test"].shape == sample_tensor.shape

    def test_reverse_op_type(self):
        reverse = ConcatenatedToInterleaved(dim=1).reverse_op
        assert isinstance(reverse, InterleavedToConcatenated)
        assert reverse.dim == 1
class TestInterleavedToConcatenated:
    @pytest.fixture
    def interleaved_tensor(self):
        """Create an interleaved tensor [E=2, 2*I=4, H=3]."""
        n_exp, inner, hidden = 2, 2, 3
        total = n_exp * inner * hidden
        gate = torch.arange(1, total + 1, dtype=torch.float32).reshape(
            n_exp, inner, hidden
        )
        up = torch.arange(100, 100 + total, dtype=torch.float32).reshape(
            n_exp, inner, hidden
        )
        out = torch.empty(n_exp, 2 * inner, hidden)
        out[:, 0::2, :] = gate
        out[:, 1::2, :] = up
        return out

    def test_deinterleave_gate_up_separated(self, interleaved_tensor):
        op = InterleavedToConcatenated(dim=1)
        concatenated = op.convert(
            {"test": interleaved_tensor},
            source_patterns=["test"],
            target_patterns=["test"],
        )["test"]
        inner = concatenated.shape[1] // 2
        # First half should be gate (even rows from interleaved)
        assert torch.equal(concatenated[:, :inner, :], interleaved_tensor[:, 0::2, :])
        # Second half should be up (odd rows from interleaved)
        assert torch.equal(concatenated[:, inner:, :], interleaved_tensor[:, 1::2, :])

    def test_reverse_op_type(self):
        reverse = InterleavedToConcatenated(dim=1).reverse_op
        assert isinstance(reverse, ConcatenatedToInterleaved)
        assert reverse.dim == 1
class TestRoundTrip:
    @pytest.fixture
    def concat_tensor(self):
        n_exp, inner, hidden = 4, 8, 16
        gate = torch.randn(n_exp, inner, hidden)
        up = torch.randn(n_exp, inner, hidden)
        return torch.cat([gate, up], dim=1)

    @staticmethod
    def _apply(op, tensor):
        # Run a converter op on a single keyed tensor and unwrap the result.
        return op.convert(
            {"k": tensor}, source_patterns=["k"], target_patterns=["k"]
        )["k"]

    def test_interleave_then_deinterleave_is_identity(self, concat_tensor):
        interleaved = self._apply(ConcatenatedToInterleaved(dim=1), concat_tensor)
        recovered = self._apply(InterleavedToConcatenated(dim=1), interleaved)
        assert torch.equal(concat_tensor, recovered)

    def test_reverse_op_chain_is_identity(self, concat_tensor):
        """Verify that op.reverse_op produces an exact inverse."""
        op = ConcatenatedToInterleaved(dim=1)
        recovered = self._apply(op.reverse_op, self._apply(op, concat_tensor))
        assert torch.equal(concat_tensor, recovered)

    def test_various_shapes(self):
        """Test with different expert counts and dimensions."""
        fwd = ConcatenatedToInterleaved(dim=1)
        rev = InterleavedToConcatenated(dim=1)
        for n_exp, inner, hidden in [(1, 4, 8), (8, 16, 32), (16, 128, 256)]:
            concat = torch.randn(n_exp, 2 * inner, hidden)
            recovered = self._apply(rev, self._apply(fwd, concat))
            assert torch.equal(concat, recovered), (
                f"Failed for shape ({n_exp}, {2 * inner}, {hidden})"
            )
class TestWeightConverterRegistration:
    @staticmethod
    def _find_gate_up_converter(mapping):
        # Locate the converter whose target patterns cover gate_up_proj.
        for conv in mapping:
            if hasattr(conv, "operations") and any(
                "gate_up_proj" in pat for pat in conv.target_patterns
            ):
                return conv
        return None

    def test_register_appends_interleave_op(self):
        from transformers.conversion_mapping import get_checkpoint_conversion_mapping

        register_sonicmoe_weight_converter("qwen3_moe")
        converter = self._find_gate_up_converter(
            get_checkpoint_conversion_mapping("qwen3_moe")
        )
        assert converter is not None
        assert isinstance(converter.operations[-1], ConcatenatedToInterleaved)

    def test_double_registration_is_idempotent(self):
        from transformers.conversion_mapping import get_checkpoint_conversion_mapping

        register_sonicmoe_weight_converter("qwen3_moe")
        register_sonicmoe_weight_converter("qwen3_moe")
        converter = self._find_gate_up_converter(
            get_checkpoint_conversion_mapping("qwen3_moe")
        )
        if converter is not None:
            interleave_count = sum(
                isinstance(op, ConcatenatedToInterleaved)
                for op in converter.operations
            )
            assert interleave_count == 1, (
                f"Expected 1 ConcatenatedToInterleaved op, got {interleave_count}"
            )

    def test_register_unsupported_model_type_warns(self):
        # A model type with no conversion mapping should warn but not raise
        register_sonicmoe_weight_converter("nonexistent_model_type_xyz")
def _make_qwen_moe_block(T=8, H=16, E=4, K=2):
"""Create a mock qwen-style MoE block for routing tests."""
gate = SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
num_experts=E,
norm_topk_prob=True,
)
return SimpleNamespace(gate=gate), T, H, E, K
def _make_glm_moe_block(T=8, H=16, E=16, K=4, n_group=2, topk_group=1):
"""Create a mock GLM5-style MoE block for routing tests."""
gate = SimpleNamespace(
weight=torch.randn(E, H),
e_score_correction_bias=torch.zeros(E),
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
n_routed_experts=E,
n_group=n_group,
topk_group=topk_group,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
return moe_block, T, H, E, K
def _make_minimax_m2_moe_block(T=8, H=16, E=16, K=4):
"""Create a mock minimax_m2-style MoE block for routing tests.
minimax_m2 uses sigmoid->topk WITHOUT group selection:
- e_score_correction_bias is on the moe_block (not on gate)
- No n_group / topk_group attributes
- Always normalizes (norm_topk_prob defaults to True)
- No routed_scaling_factor (defaults to 1.0)
"""
gate = SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
e_score_correction_bias=torch.zeros(E),
)
return moe_block, T, H, E, K
class TestSoftmaxTopkRouting:
    def test_output_shapes(self):
        moe_block, T, H, E, K = _make_qwen_moe_block()
        scores, token_idx, expert_idx, logits = softmax_topk_routing(
            torch.randn(T, H), moe_block
        )
        assert scores.shape == (T * K,)
        assert token_idx.shape == (T * K,)
        assert expert_idx.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        moe_block, T, H, _, _ = _make_qwen_moe_block()
        scores = softmax_topk_routing(torch.randn(T, H), moe_block)[0]
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        moe_block, T, H, _, _ = _make_qwen_moe_block()
        token_idx = softmax_topk_routing(torch.randn(T, H), moe_block)[1]
        # Token indices must be sorted ascending (SonicMoE requirement)
        assert (token_idx[1:] >= token_idx[:-1]).all()

    def test_expert_indices_in_range(self):
        moe_block, T, H, E, _ = _make_qwen_moe_block()
        expert_idx = softmax_topk_routing(torch.randn(T, H), moe_block)[2]
        assert (expert_idx >= 0).all()
        assert (expert_idx < E).all()

    def test_renormalized_scores_sum_to_one(self):
        moe_block, T, H, _, K = _make_qwen_moe_block()
        scores = softmax_topk_routing(torch.randn(T, H), moe_block)[0]
        per_token_sums = scores.reshape(T, K).sum(dim=-1)
        assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5)
class TestSigmoidTopkRouting:
    def test_output_shapes(self):
        moe_block, T, H, E, K = _make_glm_moe_block()
        scores, token_idx, expert_idx, logits = sigmoid_topk_routing(
            torch.randn(T, H), moe_block
        )
        assert scores.shape == (T * K,)
        assert token_idx.shape == (T * K,)
        assert expert_idx.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        moe_block, T, H, _, _ = _make_glm_moe_block()
        scores = sigmoid_topk_routing(torch.randn(T, H), moe_block)[0]
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        moe_block, T, H, _, _ = _make_glm_moe_block()
        token_idx = sigmoid_topk_routing(torch.randn(T, H), moe_block)[1]
        assert (token_idx[1:] >= token_idx[:-1]).all()

    def test_expert_indices_in_range(self):
        moe_block, T, H, E, _ = _make_glm_moe_block()
        expert_idx = sigmoid_topk_routing(torch.randn(T, H), moe_block)[2]
        assert (expert_idx >= 0).all()
        assert (expert_idx < E).all()

    def test_scores_are_nonnegative(self):
        """Sigmoid outputs are in [0, 1], so scores should be non-negative."""
        moe_block, T, H, _, _ = _make_glm_moe_block()
        scores = sigmoid_topk_routing(torch.randn(T, H), moe_block)[0]
        assert (scores >= 0).all()

    def test_scaling_factor_applied(self):
        moe_block, T, H, _, _ = _make_glm_moe_block()
        hidden = torch.randn(T, H)
        # Get scores with scaling_factor=1.0
        baseline = sigmoid_topk_routing(hidden, moe_block)[0]
        # Get scores with scaling_factor=2.0
        moe_block.routed_scaling_factor = 2.0
        doubled = sigmoid_topk_routing(hidden, moe_block)[0]
        assert torch.allclose(doubled, baseline * 2.0, atol=1e-5)

    def test_group_selection_restricts_experts(self):
        """With n_group=4 and topk_group=1, only 1/4 of experts should be selectable."""
        moe_block, T, H, E, K = _make_glm_moe_block(E=16, K=2, n_group=4, topk_group=1)
        expert_idx = sigmoid_topk_routing(torch.randn(T, H), moe_block)[2]
        # Each token's experts should all fall within a single group (size E//n_group=4)
        group_size = E // moe_block.n_group
        for per_token in expert_idx.reshape(T, K):
            groups = per_token // group_size
            # All selected experts should be from the same group
            assert (groups == groups[0]).all()
class TestMiniMaxM2SigmoidRouting:
    """Tests for minimax_m2 routing: sigmoid->topk without group selection."""

    def test_output_shapes(self):
        """Validates getattr defaults work: n_group=1, E from gate.weight.shape[0]."""
        moe_block, T, H, E, K = _make_minimax_m2_moe_block()
        scores, token_idx, expert_idx, logits = sigmoid_topk_routing(
            torch.randn(T, H), moe_block
        )
        assert scores.shape == (T * K,)
        assert token_idx.shape == (T * K,)
        assert expert_idx.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_bias_on_block_not_gate(self):
        """Verify that e_score_correction_bias on the block (not gate) is used."""
        T, H, E, K = 8, 16, 8, 2
        # Large positive bias on expert 0 should make it selected more often
        bias = torch.zeros(E)
        bias[0] = 100.0
        moe_block = SimpleNamespace(
            gate=SimpleNamespace(weight=torch.randn(E, H), top_k=K),
            top_k=K,
            e_score_correction_bias=bias,
        )
        expert_idx = sigmoid_topk_routing(torch.randn(T, H), moe_block)[2]
        # Expert 0 should appear for every token due to the large bias
        for chosen in expert_idx.reshape(T, K):
            assert 0 in chosen

View File

@@ -0,0 +1,158 @@
"""
Gradient correctness tests for SonicMoE routing functions (CPU-only).
Uses torch.autograd.gradcheck with float32 inputs to match the production
code path where routing happens in float32.
"""
import torch
from axolotl.integrations.kernels.sonicmoe.routing import (
sigmoid_topk_routing,
softmax_topk_routing,
)
# Shared torch.autograd.gradcheck settings: finite-difference step size and
# the absolute/relative tolerances used by every gradcheck in this module.
_GC_EPS = 1e-3
_GC_ATOL = 1e-3
_GC_RTOL = 1e-3
def _make_softmax_moe_block(weight):
gate = torch.nn.Module()
gate.weight = weight
gate.top_k = 2
gate.norm_topk_prob = True
moe_block = torch.nn.Module()
moe_block.gate = gate
return moe_block
def _make_sigmoid_moe_block(weight, bias):
gate = torch.nn.Module()
gate.weight = weight
gate.e_score_correction_bias = bias
moe_block = torch.nn.Module()
moe_block.gate = gate
moe_block.top_k = 2
moe_block.n_routed_experts = weight.shape[0]
moe_block.n_group = 1
moe_block.norm_topk_prob = True
moe_block.routed_scaling_factor = 1.0
return moe_block
class TestSoftmaxTopkRoutingGradcheck:
    """Numerical gradient verification for softmax_topk_routing."""

    @staticmethod
    def _check(fn, arg):
        # Shared gradcheck invocation using the module-wide tolerances.
        torch.autograd.gradcheck(fn, (arg,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL)

    def test_gradcheck_wrt_gate_weight(self):
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)

        def fn(weight):
            # Rebuild the block per call so gradcheck's perturbed weight is used.
            scores = softmax_topk_routing(hidden, _make_softmax_moe_block(weight))[0]
            return scores

        self._check(fn, torch.randn(E, H, dtype=torch.float32, requires_grad=True))

    def test_gradcheck_wrt_hidden_states(self):
        T, H, E = 4, 8, 4
        moe_block = _make_softmax_moe_block(torch.randn(E, H, dtype=torch.float32))

        def fn(hidden):
            return softmax_topk_routing(hidden, moe_block)[0]

        self._check(fn, torch.randn(T, H, dtype=torch.float32, requires_grad=True))

    def test_gradcheck_wrt_router_logits(self):
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)

        def fn(weight):
            # Fourth output is the dense router-logit matrix.
            return softmax_topk_routing(hidden, _make_softmax_moe_block(weight))[3]

        self._check(fn, torch.randn(E, H, dtype=torch.float32, requires_grad=True))

    def test_no_norm_variant(self):
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)

        def fn(weight):
            moe_block = _make_softmax_moe_block(weight)
            moe_block.gate.norm_topk_prob = False
            return softmax_topk_routing(hidden, moe_block)[0]

        self._check(fn, torch.randn(E, H, dtype=torch.float32, requires_grad=True))
class TestSigmoidTopkRoutingGradcheck:
    """Numerical gradient verification for sigmoid_topk_routing."""

    def test_gradcheck_wrt_gate_weight(self):
        # Finite-difference check of d(scores)/d(gate.weight).
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)
        bias = torch.zeros(E, dtype=torch.float32)

        def fn(weight):
            # Rebuild the block per call so gradcheck's perturbed weight is used.
            moe_block = _make_sigmoid_moe_block(weight, bias)
            scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
            return scores

        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(
            fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_gradcheck_wrt_hidden_states(self):
        # Block is fixed here; only the hidden-state input is perturbed.
        T, H, E = 4, 8, 4
        weight = torch.randn(E, H, dtype=torch.float32)
        bias = torch.zeros(E, dtype=torch.float32)
        moe_block = _make_sigmoid_moe_block(weight, bias)

        def fn(hidden):
            scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
            return scores

        hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(
            fn, (hidden,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_gradcheck_wrt_bias(self):
        # Perturb the correction bias; scores must remain differentiable
        # with respect to it for gradcheck to pass.
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)
        weight = torch.randn(E, H, dtype=torch.float32)

        def fn(bias):
            moe_block = _make_sigmoid_moe_block(weight, bias)
            scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
            return scores

        bias = torch.zeros(E, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(fn, (bias,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL)