feat: add sonicmoe fused lora support (#3519)

* feat: add sonicmoe fused lora support

* fix: forgot to add file

* feat: add test

* feat: add lora support for other routes

* fix: add int8 lora support

* fix: add qwen35_moe interleave support

* fix: qwen3_5_moe loss

* chore: lint

* address some pr comments

* fix test imports

* add support matrix for moe kernels [skip ci]

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Author: NanoCode012
Date: 2026-04-02 19:53:48 +07:00
Committed by: GitHub
Parent: 16e32232fb
Commit: 842fa039dd
16 changed files with 1249 additions and 126 deletions


@@ -52,26 +52,91 @@ The `KernelsPlugin` runs before model loading and:
### ScatterMoE
1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels).
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation via the HF `kernels` library.
### SonicMoE
1. Resolves the model's MoE block class(es) from `constants.py`.
2. Patches the forward method with SonicMoE's optimized CUTLASS kernels and registers a weight converter for the interleaved gate/up projection format.
3. Supports pluggable routing strategies (see routing table below).
Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution.
## Model Support Matrix
See `constants.py` for the full list of supported model types (Qwen2-MoE, Qwen3-MoE, OLMoE, Mixtral, DeepSeek-V3, GLM-MoE, MiniMax, etc.).
All models use the **SwiGLU** activation (`act_fn(gate) * up`). Neither kernel currently supports non-SwiGLU MoE architectures.
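For reference, a single SwiGLU expert MLP can be sketched in a few lines of eager PyTorch (an illustrative sketch only, not the kernel implementation; shapes follow the fused `[2*I, H]` gate/up convention):

```python
import torch
import torch.nn.functional as F

def swiglu_expert(x, gate_up_proj, down_proj):
    """One SwiGLU expert: act_fn(gate) * up, with act_fn = SiLU.

    x:            [T, H]    tokens routed to this expert
    gate_up_proj: [2*I, H]  fused gate and up projections
    down_proj:    [H, I]    down projection
    """
    gate, up = (x @ gate_up_proj.T).chunk(2, dim=-1)  # [T, I] each
    return (F.silu(gate) * up) @ down_proj.T          # [T, H]
```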
### Routing strategies
| Routing Strategy | Description | ScatterMoE | SonicMoE |
|---|---|:---:|:---:|
| softmax → topk | Softmax over experts, select top-K, optional renormalization | Yes | Yes |
| softmax → group selection → topk | Softmax, select top groups (sum of top-2 per group), topk from selected groups, renorm + scaling | No | Yes |
| sigmoid → topk (with groups) | Sigmoid + bias correction, group-based masking, topk from masked scores, weights from original sigmoid | Yes | Yes |
| sigmoid → topk (no groups) | Sigmoid + bias correction, straight topk (n_group=1) | Yes | Yes |
| softmax → bias correction → topk | Softmax, bias via `gate.moe_statics`, topk, gather from original probs, clamp-based renorm | No | Yes |
| softmax → group_limited_greedy | Softmax, group selection (max per group), topk, scale only (no renorm) | No | Yes |
| softmax → topk via gate.wg | Softmax, gate weight at `gate.wg.weight` (not `gate.weight`), always renormalize | No | Yes |
| fused topk → softmax | Routing + expert computation fused in a single kernel | No | Planned |
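The plain softmax → topk strategy (first row of the table) can be illustrated as follows; the actual kernels fuse these steps, so this is a reference sketch only:

```python
import torch
import torch.nn.functional as F

def softmax_topk_route(hidden_states, gate_weight, top_k, renormalize=True):
    """Reference softmax -> topk routing.

    hidden_states: [T, H], gate_weight: [E, H]
    Returns routing weights [T, K], expert indices [T, K], raw logits [T, E].
    """
    logits = F.linear(hidden_states, gate_weight)           # [T, E]
    probs = F.softmax(logits, dim=-1, dtype=torch.float32)  # [T, E]
    top_vals, top_idx = torch.topk(probs, top_k, dim=-1)    # [T, K]
    if renormalize:
        # Optional renormalization over the selected experts only
        top_vals = top_vals / top_vals.sum(dim=-1, keepdim=True)
    return top_vals, top_idx, logits
```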
### Per-model support
| Model Type | Architecture | Routing | ScatterMoE | SonicMoE |
|---|---|---|:---:|:---:|
| `qwen2_moe` | Qwen2-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_moe` | Qwen3-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_5_moe` | Qwen3.5-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_5_moe_text` | Qwen3.5-MoE (VLM text) | softmax → topk | **Yes** | **Yes** |
| `qwen3_next` | Qwen3-Next | softmax → topk | **Yes** | **Yes** |
| `qwen3_vl_moe` | Qwen3-VL-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_omni_moe` | Qwen3-Omni (Thinker + Talker) | softmax → topk | **Yes** | **Yes** |
| `olmoe` | OLMoE | softmax → topk | **Yes** | **Yes** |
| `mixtral` | Mixtral | softmax → topk | **Yes** | **Yes** |
| `minimax` | MiniMax | softmax → topk | **Yes** | **Yes** |
| `mistral4` | Mistral 4 | softmax → group → topk | No | **Yes** |
| `glm_moe_dsa` | GLM-MoE DSA (GLM 5) | sigmoid → topk (groups) | **Yes** | **Yes** |
| `deepseek_v3` | DeepSeek-V3 | sigmoid → topk (groups) | **Yes** | **Yes** |
| `glm4_moe` | GLM4-MoE | sigmoid → topk (groups) | **Yes** | **Yes** |
| `glm4_moe_lite` | GLM4-MoE Lite (GLM 4.7 Flash) | sigmoid → topk (groups) | **Yes**\* | **Yes** |
| `glm4v_moe` | GLM4v-MoE | sigmoid → topk (groups) | **Yes** | **Yes** |
| `minimax_m2` | MiniMax M2 | sigmoid → topk (no groups) | **Yes** | **Yes** |
| `ernie4_5_moe` | ERNIE 4.5 MoE | softmax → bias → topk | No | **Yes** |
| `deepseek_v2` | DeepSeek-V2 | softmax → group_limited_greedy | No | **Yes** |
| `hunyuan_v1_moe` | HunYuan V1 MoE | softmax → topk (gate.wg) | No | **Yes** |
| `gpt_oss` | GPT-OSS | fused topk → softmax | No | Planned |
\* `glm4_moe_lite` with ScatterMoE may have issues — see Limitations.
### Feature comparison
| Feature | ScatterMoE | SonicMoE |
|---|:---:|:---:|
| Kernel backend | Triton | CUTLASS |
| GPU requirement | Any CUDA | Hopper (H100/H200) or Blackwell (B200+) |
| LoRA approach | Fused in Triton kernel | Runtime materialization + custom autograd |
| LoRA overhead | Lower (fused computation) | Higher (per-forward materialization) |
| Gate/router LoRA | Yes | Yes |
| Expert LoRA | Yes (fused) | Yes (materialized) |
| Shared expert LoRA | Yes (standard PEFT) | Yes (standard PEFT) |
| Selective expert dequantization | Yes (~97% memory savings) | No |
| Weight format | Transposed `[E, hidden, 2*inter]` | Interleaved gate/up `[2*I, H, E]` |
| torch.compile routing | No | Yes (optional) |
## Shared Expert Handling
Both kernels handle shared experts identically. Shared expert attribute names are detected in order of priority:
1. `shared_expert` (Qwen2-MoE)
2. `shared_experts` (GLM-MoE, DeepSeek-V3)
3. `shared_mlp` (HunYuan V1 MoE)
If `shared_expert_gate` exists, sigmoid gating is applied to the shared expert contribution before adding it to the routed output. PEFT wraps shared expert linear layers with standard LoRA — no special handling is needed.
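The shared-expert handling described above can be sketched as follows (illustrative; `shared_expert_gate` is assumed to be a linear layer producing one scalar gate logit per token, which matches the Qwen2-MoE style):

```python
import torch

def apply_shared_expert(routed_output, hidden_states, shared_expert, shared_expert_gate=None):
    """Add the shared expert contribution to the routed MoE output."""
    shared_out = shared_expert(hidden_states)                    # [T, H]
    if shared_expert_gate is not None:
        # Sigmoid gating on the shared expert contribution
        gate = torch.sigmoid(shared_expert_gate(hidden_states))  # [T, 1]
        shared_out = gate * shared_out
    return routed_output + shared_out
```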
## Limitations
- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).
- **Non-SwiGLU activations**: Neither kernel supports MoE architectures with non-SwiGLU expert activations (e.g., GPT-OSS uses a custom GLU variant).
- **GPT-OSS**: Deferred — requires transposed weight layout `[E, H, 2*I]`, expert biases, and custom GLU activation. A dedicated forward path is needed.
- **FSDP + fused gate LoRA (SonicMoE)**: The fused topk→softmax path materializes a local tensor when LoRA delta is present to avoid DTensor + Tensor mixing under FSDP.
## Note on MegaBlocks


@@ -45,6 +45,28 @@ class KernelsArgs(BaseModel):
return data
@model_validator(mode="before")
@classmethod
def warn_sonicmoe_lora_overhead(cls, data):
if data.get("use_sonicmoe") is True and data.get("adapter") in (
"lora",
"qlora",
):
lora_target = data.get("lora_target_modules") or []
lora_linear = data.get("lora_target_linear_modules") or []
targets = (
lora_target if isinstance(lora_target, list) else [lora_target]
) + (lora_linear if isinstance(lora_linear, list) else [lora_linear])
expert_keywords = ("gate_up_proj", "down_proj", "experts")
if any(kw in t for t in targets for kw in expert_keywords):
LOG.info(
"SonicMoE + LoRA on expert modules uses runtime weight materialization "
"(W_eff = W + scaling*B@A per forward). This has slightly higher overhead "
"than ScatterMoE's fused Triton LoRA kernels but works with any CUTLASS kernel."
)
return data
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel(cls, data):


@@ -49,6 +49,11 @@ class ParallelLinear(torch.autograd.Function):
grouped_in: bool = False,
grouped_out: bool = False,
):
# Cast weights to match input dtype (e.g. 8-bit LoRA)
if expert_weights.dtype != x.dtype:
expert_weights = expert_weights.to(x.dtype)
if expert_biases is not None and expert_biases.dtype != x.dtype:
expert_biases = expert_biases.to(x.dtype)
with torch.device(x.device):
output = kernels.ops.scatter2scatter(
X=x,


@@ -65,6 +65,11 @@ class ScatterMoELoRA(torch.autograd.Function):
use_fused_dX: bool = False,
use_fused_gather: bool = False,
):
# Cast weights to match input dtype (e.g. 8-bit LoRA)
if expert_weights.dtype != x.dtype:
expert_weights = expert_weights.to(x.dtype)
if expert_biases is not None and expert_biases.dtype != x.dtype:
expert_biases = expert_biases.to(x.dtype)
with torch.device(x.device):
# Fused forward: Y = X @ W + scaling * (X @ A^T) @ B^T
output = scatter2scatter_lora(


@@ -0,0 +1,220 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
SonicMoE LoRA support via runtime weight materialization.
SonicMoE uses opaque CUTLASS kernels that cannot be modified to fuse LoRA.
Instead, we materialize the effective weight W_eff = W + scaling * (B @ A)
before each CUTLASS call, and use a custom autograd.Function to route
gradients back to the LoRA A and B parameters.
PEFT unwrapping utilities are also provided to handle the ParamWrapper
chain that PEFT creates when targeting expert parameters.
"""
from typing import Optional
import torch
# =============================================================================
# PEFT unwrapping utilities
# =============================================================================
def has_lora(module) -> bool:
"""Check if a module is wrapped by PEFT with LoRA."""
return hasattr(module, "base_layer") and hasattr(module, "lora_A")
def get_lora_params_from_wrapper(module) -> tuple:
"""Extract LoRA parameters from a PEFT ParamWrapper.
Returns:
(lora_A, lora_B, scaling) if LoRA is active, else (None, None, None)
"""
if not hasattr(module, "lora_A") or not hasattr(module, "lora_B"):
return None, None, None
active_adapters = getattr(module, "active_adapters", ["default"])
if not active_adapters:
return None, None, None
adapter_name = active_adapters[0]
lora_A_dict = getattr(module, "lora_A", {})
lora_B_dict = getattr(module, "lora_B", {})
scaling_dict = getattr(module, "scaling", {})
if (
adapter_name not in lora_A_dict
or adapter_name not in lora_B_dict
or adapter_name not in scaling_dict
):
return None, None, None
lora_A = lora_A_dict[adapter_name].weight
lora_B = lora_B_dict[adapter_name].weight
scaling = scaling_dict[adapter_name]
return lora_A, lora_B, scaling
def unwrap_gate_lora(gate_module):
"""Unwrap PEFT ParamWrapper on the router gate.
When PEFT targets ``gate.weight``, ``self.gate`` becomes::
ParamWrapper(weight)
-> base_layer: Router (the real module)
Returns:
(base_gate, gate_weight, gate_lora_delta_or_None)
``base_gate`` is the original router module (with ``.top_k``, etc.).
``gate_weight`` is the base router weight tensor.
``gate_lora_delta_or_None`` is the LoRA delta if active, else None.
Kept separate to avoid mixing DTensor + Tensor under FSDP.
"""
if has_lora(gate_module):
base_gate = gate_module.base_layer
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gate_module)
if lora_A is not None:
delta = scaling * (lora_B @ lora_A)
return base_gate, base_gate.weight, delta
return base_gate, base_gate.weight, None
return gate_module, gate_module.weight, None
def unwrap_experts_lora(experts_module):
"""Walk a PEFT ParamWrapper chain on ``self.experts``.
When PEFT targets ``experts.gate_up_proj`` and ``experts.down_proj``
via ``target_parameters``, ``self.experts`` becomes::
ParamWrapper(down_proj)
-> base_layer: ParamWrapper(gate_up_proj)
-> base_layer: Experts (the real module)
Returns:
(base_experts, lora_dict)
``lora_dict`` maps parameter names to ``(lora_A, lora_B, scaling)``
tuples, or is empty if no LoRA is active.
"""
wrappers = {}
module = experts_module
while hasattr(module, "base_layer") and hasattr(module, "lora_A"):
param_name = getattr(module, "parameter_name", None)
if param_name is not None:
wrappers[param_name] = module
module = module.base_layer
base_experts = module
lora_dict = {}
for param_name, wrapper in wrappers.items():
lora_A, lora_B, scaling = get_lora_params_from_wrapper(wrapper)
if lora_A is not None:
lora_dict[param_name] = (lora_A, lora_B, scaling)
return base_experts, lora_dict
# =============================================================================
# LoRA weight materialization autograd function
# =============================================================================
class MoELoRAMaterialize(torch.autograd.Function):
"""Materialize effective weight W_eff = W + scaling * (B @ A) per expert.
Inserts into the autograd graph between PEFT's LoRA parameters and
SonicMoE's CUTLASS kernels. The CUTLASS backward computes dW_eff,
which this function decomposes into dA and dB via the chain rule.
Weight layouts (PEFT rank-major):
base_weight: [E, dim1, dim2] (frozen expert parameter)
lora_A: [r*E, dim2] (rows [e*r:(e+1)*r] = A_e)
lora_B: [dim1, r*E] (cols [:, e::E] = B_e, rank-major)
Per-expert: delta_e = B_e @ A_e = [dim1, r] @ [r, dim2] = [dim1, dim2]
"""
@staticmethod
def forward(
ctx,
base_weight: torch.Tensor,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
scaling: float,
) -> torch.Tensor:
E, dim1, dim2 = base_weight.shape
r = lora_A.shape[0] // E
assert lora_A.shape[0] == r * E, (
f"lora_A rows ({lora_A.shape[0]}) must be divisible by num_experts ({E})"
)
# Reshape PEFT rank-major to per-expert batched format
A_3d = lora_A.reshape(E, r, dim2)
B_3d = lora_B.reshape(dim1, r, E).permute(2, 0, 1).contiguous() # [E, dim1, r]
# Batched matmul: [E, dim1, r] @ [E, r, dim2] = [E, dim1, dim2]
delta = torch.bmm(B_3d, A_3d)
W_eff = base_weight + scaling * delta
ctx.save_for_backward(lora_A, lora_B)
ctx.scaling = scaling
ctx.E = E
ctx.r = r
return W_eff
@staticmethod
def backward(ctx, grad_W_eff: torch.Tensor):
lora_A, lora_B = ctx.saved_tensors
scaling = ctx.scaling
E = ctx.E
r = ctx.r
_, dim1, dim2 = grad_W_eff.shape
# Reshape to per-expert (same as forward)
A_3d = lora_A.reshape(E, r, dim2)
B_3d = lora_B.reshape(dim1, r, E).permute(2, 0, 1).contiguous() # [E, dim1, r]
# dA_e = scaling * B_e^T @ dW_e
# [E, r, dim1] @ [E, dim1, dim2] = [E, r, dim2]
d_A_3d = scaling * torch.bmm(B_3d.transpose(1, 2), grad_W_eff)
# dB_e = scaling * dW_e @ A_e^T
# [E, dim1, dim2] @ [E, dim2, r] = [E, dim1, r]
d_B_3d = scaling * torch.bmm(grad_W_eff, A_3d.transpose(1, 2))
# Reshape back to PEFT rank-major layout
d_lora_A = d_A_3d.reshape(E * r, dim2)
d_lora_B = d_B_3d.permute(1, 2, 0).contiguous().reshape(dim1, E * r)
return None, d_lora_A, d_lora_B, None
def materialize_expert_lora(
base_weight: torch.Tensor,
lora_params: Optional[tuple],
) -> torch.Tensor:
"""Materialize effective expert weight with optional LoRA delta.
Args:
base_weight: [E, dim1, dim2] frozen expert parameter
lora_params: (lora_A, lora_B, scaling) or None
Returns:
W_eff if lora_params is not None, else base_weight unchanged.
"""
if lora_params is None:
return base_weight
lora_A, lora_B, scaling = lora_params
return MoELoRAMaterialize.apply(base_weight, lora_A, lora_B, scaling)
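The batched materialization and its chain-rule backward can be sanity-checked against a per-expert loop. This is a standalone sketch using simple 3-D `[E, dim1, r]` / `[E, r, dim2]` layouts rather than PEFT's flat rank-major parameters:

```python
import torch

E, dim1, dim2, r, scaling = 4, 6, 5, 2, 0.5
W = torch.randn(E, dim1, dim2)                   # frozen base weight
A = torch.randn(E, r, dim2, requires_grad=True)  # per-expert LoRA A
B = torch.randn(E, dim1, r, requires_grad=True)  # per-expert LoRA B

# Batched: W_eff = W + scaling * (B_e @ A_e) for all experts at once
W_eff = W + scaling * torch.bmm(B, A)

# Per-expert reference loop gives the same result
ref = torch.stack([W[e] + scaling * B[e] @ A[e] for e in range(E)])
assert torch.allclose(W_eff, ref, atol=1e-6)

# Chain rule check: with dW_eff = ones,
#   dA_e = scaling * B_e^T @ dW_e   and   dB_e = scaling * dW_e @ A_e^T
W_eff.sum().backward()
ones = torch.ones(E, dim1, dim2)
assert torch.allclose(A.grad, scaling * torch.bmm(B.detach().transpose(1, 2), ones), atol=1e-5)
assert torch.allclose(B.grad, scaling * torch.bmm(ones, A.detach().transpose(1, 2)), atol=1e-5)
```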


@@ -28,20 +28,72 @@ import torch.nn.functional as F
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
from axolotl.utils.logging import get_logger
from .lora import (
has_lora,
materialize_expert_lora,
unwrap_experts_lora,
unwrap_gate_lora,
)
LOG = get_logger(__name__)
def patch_sonicmoe(model_type: str, torch_compile: bool = False):
"""Main entry point: patch SparseMoeBlock for SonicMoE support.
def _get_expert_weights(experts_module):
"""Extract expert weights, applying LoRA materialization if PEFT is active.
Args:
model_type: The HuggingFace model type (e.g. "qwen3_moe").
torch_compile: If True, wrap routing functions with torch.compile
for kernel fusion (fuses softmax+topk+renorm into fewer launches).
Returns:
(gate_up_weight, down_weight) in SonicMoE layout [dim, dim, E].
"""
if has_lora(experts_module):
base_experts, lora_dict = unwrap_experts_lora(experts_module)
gate_up = materialize_expert_lora(
base_experts.gate_up_proj, lora_dict.get("gate_up_proj")
)
down = materialize_expert_lora(
base_experts.down_proj, lora_dict.get("down_proj")
)
else:
gate_up = experts_module.gate_up_proj
down = experts_module.down_proj
# Permute to SonicMoE layout:
# gate_up: [E, 2*I, H] -> [2*I, H, E]
# down: [E, H, I] -> [H, I, E]
return gate_up.permute(1, 2, 0), down.permute(1, 2, 0)
def _fix_qwen3_5_moe_text_weight_renaming(model_type: str, base_model_type: str):
"""Strip qwen3_5_moe_text WeightRenaming in VLM mode to preserve custom loaders."""
if model_type != "qwen3_5_moe_text" or base_model_type == "qwen3_5_moe_text":
return
try:
from transformers.conversion_mapping import (
get_checkpoint_conversion_mapping,
register_checkpoint_conversion_mapping,
)
from transformers.core_model_loading import WeightRenaming
except ImportError:
return
text_mapping = get_checkpoint_conversion_mapping(model_type)
if text_mapping and isinstance(text_mapping[0], WeightRenaming):
text_mapping.pop(0)
register_checkpoint_conversion_mapping(model_type, text_mapping, overwrite=True)
LOG.info("Stripped qwen3_5_moe_text WeightRenaming for VLM mode")
def patch_sonicmoe(
model_type: str,
torch_compile: bool = False,
base_model_type: str | None = None,
):
"""Patch SparseMoeBlock for SonicMoE support."""
from .routing import get_model_moe_config
from .weight_converter import register_sonicmoe_weight_converter
_fix_qwen3_5_moe_text_weight_renaming(model_type, base_model_type or model_type)
routing_fn, activation, router_attr = get_model_moe_config(model_type)
if torch_compile and routing_fn is not None:
@@ -113,11 +165,10 @@ def _make_general_forward(moe_cls, routing_fn, activation):
hidden_states_flat, self
)
# Permute weights to SonicMoE layout:
# gate_up: [E, 2*I, H] -> [2*I, H, E]
# down: [E, H, I] -> [H, I, E]
gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
down_weight = self.experts.down_proj.permute(1, 2, 0)
# Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout
gate_up_weight, down_weight = _get_expert_weights(self.experts)
gate_up_weight = gate_up_weight.to(hidden_states_flat.dtype)
down_weight = down_weight.to(hidden_states_flat.dtype)
E = gate_up_weight.shape[-1]
output, _ = moe_general_routing_inputs(
@@ -161,22 +212,30 @@ def _make_fused_forward(moe_cls, activation, router_attr):
# Shared expert (computed early, matching original model ordering)
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
router = getattr(self, router_attr)
# Unwrap router for attribute access + optional LoRA delta
raw_router = getattr(self, router_attr)
base_router, router_weight, router_lora_delta = unwrap_gate_lora(raw_router)
if router_lora_delta is not None:
# Materialize local tensor to avoid DTensor + Tensor add under FSDP
if hasattr(router_weight, "to_local"):
router_weight = router_weight.to_local()
effective_router_weight = router_weight + router_lora_delta
else:
effective_router_weight = router_weight
# Permute weights to SonicMoE layout:
# gate_up: [E, 2*I, H] -> [2*I, H, E]
# down: [E, H, I] -> [H, I, E]
gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0)
down_weight = self.experts.down_proj.permute(1, 2, 0)
# Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout
gate_up_weight, down_weight = _get_expert_weights(self.experts)
gate_up_weight = gate_up_weight.to(hidden_states_flat.dtype)
down_weight = down_weight.to(hidden_states_flat.dtype)
output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer(
hidden_states_flat,
router.weight,
effective_router_weight,
gate_up_weight,
None, # b1 (no gate/up bias)
down_weight,
None, # b2 (no down bias)
router.top_k,
base_router.top_k,
torch.cuda.current_stream().cuda_stream,
activation,
False, # is_inference_mode


@@ -16,6 +16,8 @@ When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
import torch
import torch.nn.functional as F
from .lora import unwrap_gate_lora
def get_model_moe_config(model_type: str):
"""Returns (routing_fn, activation, router_attr) for a given model type.
@@ -40,6 +42,7 @@ def get_model_moe_config(model_type: str):
"qwen2_moe",
"qwen3_moe",
"qwen3_5_moe",
"qwen3_5_moe_text",
"qwen3_next",
"qwen3_vl_moe",
"qwen3_omni_moe",
@@ -88,12 +91,18 @@ def softmax_topk_routing(
expert_indices: [T*K] which expert (int32)
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
T, _ = hidden_states.shape
K = gate.top_k
base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
T, H = hidden_states.shape
K = base_gate.top_k
# Compute router logits and softmax over all experts
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
# Compute router logits and softmax over all experts.
# Two F.linear calls avoid mixing DTensor (gate_weight) + Tensor (delta) under FSDP.
# Cast to float32 to match LoRA delta dtype (PEFT computes in fp32).
router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E]
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), gate_lora_delta.float()
)
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
# Select top-k experts per token
@@ -101,7 +110,7 @@ def softmax_topk_routing(
# Renormalize if configured (default True for models without the attribute,
# e.g. Mixtral/MiniMax which always normalize)
if getattr(gate, "norm_topk_prob", True):
if getattr(base_gate, "norm_topk_prob", True):
top_values = top_values / top_values.sum(dim=-1, keepdim=True)
# no-op: matches transformers which casts to softmax output dtype (float32).
@@ -128,13 +137,17 @@ def softmax_group_topk_routing(
hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
gate = moe_block.gate
base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
T, _ = hidden_states.shape
K = moe_block.top_k
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
E = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])
n_group = getattr(moe_block, "n_group", 1)
router_logits = F.linear(hidden_states, gate.weight) # [T, E]
router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E]
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), gate_lora_delta.float()
)
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
scores_for_choice = router_probs
@@ -206,25 +219,29 @@ def sigmoid_topk_routing(
expert_indices: [T*K] which expert (int32)
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
T, _ = hidden_states.shape
K = moe_block.top_k
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
E = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])
n_group = getattr(moe_block, "n_group", 1)
# Compute router logits and sigmoid probabilities
router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E]
router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E]
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), gate_lora_delta.float()
)
router_probs = router_logits.sigmoid() # [T, E]
# Bias-corrected scores for expert selection (not used for final weights).
# glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 stores it on the block.
e_score_correction_bias = getattr(gate, "e_score_correction_bias", None)
e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None)
if e_score_correction_bias is None:
e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
if e_score_correction_bias is None:
raise AttributeError(
f"sigmoid_topk_routing requires e_score_correction_bias on "
f"gate ({type(gate)}) or moe_block ({type(moe_block)}), but neither has it"
f"gate ({type(base_gate)}) or moe_block ({type(moe_block)}), but neither has it"
)
scores_for_choice = router_probs + e_score_correction_bias
@@ -296,16 +313,20 @@ def softmax_bias_topk_routing(
expert_indices: [T*K] which expert (int32)
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
T, _ = hidden_states.shape
K = gate.top_k
base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
T, H = hidden_states.shape
K = base_gate.top_k
# Compute router logits and softmax (force float32 for numerical stability)
router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E]
router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E]
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), gate_lora_delta.float()
)
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
# Bias-corrected scores for expert selection (via moe_statics module)
scores_for_choice = gate.moe_statics(router_probs) # [T, E]
scores_for_choice = base_gate.moe_statics(router_probs) # [T, E]
# Select top-k experts using biased scores
_, selected_experts = torch.topk(scores_for_choice, K, dim=-1) # [T, K]
@@ -314,7 +335,7 @@ def softmax_bias_topk_routing(
top_values = torch.gather(router_probs, dim=-1, index=selected_experts) # [T, K]
# Renormalize with clamp(min=norm_min) instead of sum+epsilon
norm_min = getattr(gate, "norm_min", 1e-20)
norm_min = getattr(base_gate, "norm_min", 1e-20)
top_values = top_values / torch.clamp(
top_values.sum(dim=-1, keepdim=True), min=norm_min
)
@@ -358,15 +379,19 @@ def softmax_group_limited_topk_routing(
expert_indices: [T*K] which expert (int32)
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
T, _ = hidden_states.shape
base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
T, H = hidden_states.shape
K = moe_block.top_k
num_group = getattr(moe_block, "num_group", 1)
num_experts = gate.weight.shape[0]
num_experts = gate_weight.shape[0]
topk_method = getattr(moe_block, "topk_method", "greedy")
# Compute logits in float32 and softmax
router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E]
router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E]
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), gate_lora_delta.float()
)
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
if topk_method == "greedy" or num_group == 1:
@@ -445,12 +470,17 @@ def softmax_topk_wg_routing(
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
T, _ = hidden_states.shape
T, H = hidden_states.shape
K = moe_block.top_k
# Gate computes logits via gate.wg (nn.Linear, float32)
wg = gate.wg
router_logits = F.linear(hidden_states.float(), wg.weight.float()) # [T, E]
# Unwrap at gate.wg level since PEFT targets the wg Linear, not the gate container
base_wg, wg_weight, wg_lora_delta = unwrap_gate_lora(gate.wg)
router_logits = F.linear(hidden_states.float(), wg_weight.float()) # [T, E]
if wg_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), wg_lora_delta.float()
)
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]
# Select top-k experts


@@ -129,15 +129,41 @@ class InterleavedToConcatenated(ConversionOps):
return ConcatenatedToInterleaved(self.dim)
def _make_same_key_interleave_converter():
"""Create a WeightConverter that interleaves an already-fused gate_up_proj."""
from transformers.core_model_loading import WeightConverter
return WeightConverter(
source_patterns="mlp.experts.gate_up_proj",
target_patterns="mlp.experts.gate_up_proj",
operations=[ConcatenatedToInterleaved(dim=1)],
)
def _has_same_key_interleave(mapping) -> bool:
"""Check whether the mapping already has a same-key gate_up_proj interleave converter."""
for conv in mapping:
if (
hasattr(conv, "source_patterns")
and conv.source_patterns == ["mlp.experts.gate_up_proj"]
and conv.target_patterns == ["mlp.experts.gate_up_proj"]
and hasattr(conv, "operations")
and any(isinstance(op, ConcatenatedToInterleaved) for op in conv.operations)
):
return True
return False
def register_sonicmoe_weight_converter(model_type: str):
"""Override the conversion mapping to add interleave step for gate_up_proj.
"""Register weight converters to interleave gate_up_proj for SonicMoE.
Appends a ConcatenatedToInterleaved operation to the existing gate_up_proj
converter chain. For example, qwen3_moe's chain becomes:
MergeModulelist(dim=0) -> Concatenate(dim=1) -> ConcatenatedToInterleaved(dim=1)
Handles two checkpoint formats:
1. Separate per-expert weights (e.g. qwen3_moe): appends interleave to the
existing merge chain (MergeModulelist -> Concatenate -> Interleave).
2. Already-fused gate_up_proj (e.g. qwen3_5_moe_text): adds a same-key
converter (gate_up_proj -> gate_up_proj with Interleave).
The reverse is auto-generated for saving:
InterleavedToConcatenated(dim=1) -> Chunk(dim=1) -> SplitModulelist(dim=0)
The loader matches whichever source pattern exists in the checkpoint.
"""
from transformers.conversion_mapping import (
get_checkpoint_conversion_mapping,
@@ -145,37 +171,32 @@ def register_sonicmoe_weight_converter(model_type: str):
)
existing = get_checkpoint_conversion_mapping(model_type)
if existing is None:
LOG.warning(
f"No conversion mapping found for model type '{model_type}'. "
"SonicMoE weight interleaving will not be applied during checkpoint loading."
)
# No mapping at all — create one with just the same-key converter
mapping = [_make_same_key_interleave_converter()]
register_checkpoint_conversion_mapping(model_type, mapping)
LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'")
return
# Find the gate_up_proj converter and append ConcatenatedToInterleaved
patched = False
# Append interleave to any existing many-to-one merge chain
for converter in existing:
if hasattr(converter, "operations") and any(
"gate_up_proj" in pat for pat in converter.target_patterns
):
# Guard against double registration (e.g. plugin reloaded)
if any(
has_separate_sources = any(
"gate_proj" in pat or "up_proj" in pat
for pat in converter.source_patterns
)
if has_separate_sources and not any(
isinstance(op, ConcatenatedToInterleaved) for op in converter.operations
):
LOG.info(
f"SonicMoE weight converter already registered for '{model_type}'"
)
return
converter.operations.append(ConcatenatedToInterleaved(dim=1))
patched = True
converter.operations.append(ConcatenatedToInterleaved(dim=1))
break
if not patched:
LOG.warning(
f"Could not find gate_up_proj converter for model type '{model_type}'. "
"SonicMoE weight interleaving will not be applied during checkpoint loading."
)
return
# Also add a same-key converter for already-fused checkpoints
if not _has_same_key_interleave(existing):
existing.append(_make_same_key_interleave_converter())
register_checkpoint_conversion_mapping(model_type, existing, overwrite=True)
LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'")


@@ -84,12 +84,13 @@ class KernelsPlugin(BasePlugin):
_check_sonicmoe_gpu_compat()
from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
from axolotl.integrations.kernels.libs.sonicmoe import patch_sonicmoe
LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
patch_sonicmoe(
moe_model_type,
torch_compile=bool(getattr(cfg, "torch_compile", False)),
base_model_type=cfg.model_config_type,
)
def _register_kernels(self):


@@ -51,7 +51,7 @@ def _create_tiny_qwen3_config():
def _interleave_gate_up_weights(model):
"""Interleave all gate_up_proj parameters in-place for SonicMoE."""
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
interleave_gate_up,
)
@@ -80,7 +80,7 @@ class TestSonicMoEForwardCorrectness:
def test_forward_output_matches(self):
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
@@ -117,8 +117,8 @@ class TestSonicMoEGradientCorrectness:
"""Verify all parameter gradients match between original and patched."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
deinterleave_gate_up,
)
@@ -191,7 +191,7 @@ class TestSonicMoEGradientCorrectness:
"""Verify that router (gate) weights get non-zero gradients."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
@@ -223,7 +223,7 @@ class TestSonicMoETrainingConvergence:
"""Run 30 training steps, verify loss decreases and no NaN/Inf."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
@@ -254,7 +254,7 @@ class TestSonicMoETrainingConvergence:
"""Verify expert weights change during training (not frozen)."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")

View File

@@ -0,0 +1,318 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
End-to-end tests for SonicMoE + LoRA integration.
Verifies that PEFT-wrapped MoE models work correctly with SonicMoE's
runtime LoRA materialization: gradients flow to adapters, base weights
stay frozen, and loss converges.
Requires:
- H100/H200 GPU (SonicMoE CUTLASS kernels target sm_90)
- sonicmoe package installed
- peft package installed
- transformers with Qwen3MoE support
Usage:
pytest tests/e2e/integrations/test_sonicmoe_lora.py -v -s
"""
import importlib.util
import math
import pytest
import torch
_sonicmoe_available = importlib.util.find_spec("sonicmoe") is not None
_peft_available = importlib.util.find_spec("peft") is not None
_is_hopper = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
pytestmark = [
pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA GPU"),
pytest.mark.skipif(
not _is_hopper, reason="SonicMoE CUTLASS kernels require Hopper (sm_90)"
),
pytest.mark.skipif(not _sonicmoe_available, reason="SonicMoE not installed"),
pytest.mark.skipif(not _peft_available, reason="PEFT not installed"),
]
def _create_tiny_qwen3_config():
"""Create a minimal Qwen3MoE config for fast testing."""
from transformers import AutoConfig
config = AutoConfig.for_model("qwen3_moe")
config.hidden_size = 512
config.intermediate_size = 1024
config.moe_intermediate_size = 64
config.num_attention_heads = 16
config.num_key_value_heads = 2
config.head_dim = 32
config.num_hidden_layers = 2
config.num_experts = 8
config.num_experts_per_tok = 2
config.vocab_size = 1000
config.max_position_embeddings = 128
config.norm_topk_prob = True
config.torch_dtype = torch.bfloat16
return config
def _interleave_gate_up_weights(model):
"""Interleave all gate_up_proj parameters in-place for SonicMoE."""
from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
interleave_gate_up,
)
with torch.no_grad():
for name, param in model.named_parameters():
if "gate_up_proj" in name:
param.copy_(interleave_gate_up(param))
def _unpatch_sonicmoe():
"""Restore original forward on the MoE block class if it was patched."""
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
for moe_cls in resolve_moe_block_classes("qwen3_moe"):
if hasattr(moe_cls, "_original_forward"):
moe_cls.forward = moe_cls._original_forward
del moe_cls._original_forward
def _apply_lora(model, target_modules):
"""Apply PEFT LoRA to the model."""
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=target_modules,
lora_dropout=0.0,
bias="none",
)
return get_peft_model(model, lora_config)
class TestSonicMoELoRATraining:
"""Verify SonicMoE + LoRA training works end-to-end."""
def teardown_method(self):
_unpatch_sonicmoe()
def test_loss_decreases(self):
"""Run 30 training steps with LoRA on experts, verify loss decreases."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
patch_sonicmoe("qwen3_moe")
_interleave_gate_up_weights(model)
model = _apply_lora(model, ["gate_up_proj", "down_proj"])
optimizer = torch.optim.AdamW(
[p for p in model.parameters() if p.requires_grad], lr=1e-3
)
losses = []
for step in range(30):
out = model(input_ids, labels=input_ids)
loss = out.loss
assert not math.isnan(loss.item()), f"NaN loss at step {step}"
assert not math.isinf(loss.item()), f"Inf loss at step {step}"
losses.append(loss.item())
loss.backward()
optimizer.step()
optimizer.zero_grad()
assert losses[-1] < losses[0], (
f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}"
)
def test_base_weights_frozen(self):
"""Verify base (non-LoRA) weights don't change during training."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
patch_sonicmoe("qwen3_moe")
_interleave_gate_up_weights(model)
model = _apply_lora(model, ["gate_up_proj", "down_proj"])
# Snapshot frozen weights
frozen_before = {}
for name, param in model.named_parameters():
if not param.requires_grad:
frozen_before[name] = param.data.clone()
optimizer = torch.optim.AdamW(
[p for p in model.parameters() if p.requires_grad], lr=1e-3
)
for _ in range(5):
out = model(input_ids, labels=input_ids)
out.loss.backward()
optimizer.step()
optimizer.zero_grad()
for name, param in model.named_parameters():
if name in frozen_before:
assert torch.equal(param.data, frozen_before[name]), (
f"Frozen weight changed: {name}"
)
def test_lora_adapters_receive_gradients(self):
"""Verify LoRA A and B matrices get non-zero gradients."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
patch_sonicmoe("qwen3_moe")
_interleave_gate_up_weights(model)
model = _apply_lora(model, ["gate_up_proj", "down_proj"])
out = model(input_ids, labels=input_ids)
out.loss.backward()
lora_grads_found = 0
for name, param in model.named_parameters():
if "lora_" in name and param.requires_grad:
assert param.grad is not None, f"No gradient for LoRA param: {name}"
assert param.grad.abs().max() > 0, (
f"Zero gradient for LoRA param: {name}"
)
lora_grads_found += 1
assert lora_grads_found > 0, "No LoRA parameters found with gradients"
def test_lora_adapters_update(self):
"""Verify LoRA adapter weights change during training."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
patch_sonicmoe("qwen3_moe")
_interleave_gate_up_weights(model)
model = _apply_lora(model, ["gate_up_proj", "down_proj"])
# Snapshot LoRA weights
lora_before = {}
for name, param in model.named_parameters():
if "lora_" in name and param.requires_grad:
lora_before[name] = param.data.clone()
assert lora_before, "No LoRA parameters found"
optimizer = torch.optim.AdamW(
[p for p in model.parameters() if p.requires_grad], lr=1e-3
)
for _ in range(5):
out = model(input_ids, labels=input_ids)
out.loss.backward()
optimizer.step()
optimizer.zero_grad()
changed = sum(
1
for name, param in model.named_parameters()
if name in lora_before and not torch.equal(param.data, lora_before[name])
)
assert changed > 0, "No LoRA weights changed after 5 training steps"
class TestSonicMoEGateOnlyLoRA:
"""Verify LoRA targeting only the gate (router) works with SonicMoE."""
def teardown_method(self):
_unpatch_sonicmoe()
def test_gate_only_lora_loss_decreases(self):
"""LoRA only on gate — expert path should have zero materialization overhead."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
patch_sonicmoe("qwen3_moe")
_interleave_gate_up_weights(model)
# Only target the gate (router), not expert projections
model = _apply_lora(model, ["gate"])
optimizer = torch.optim.AdamW(
[p for p in model.parameters() if p.requires_grad], lr=1e-3
)
losses = []
for step in range(20):
out = model(input_ids, labels=input_ids)
loss = out.loss
assert not math.isnan(loss.item()), f"NaN loss at step {step}"
assert not math.isinf(loss.item()), f"Inf loss at step {step}"
losses.append(loss.item())
loss.backward()
optimizer.step()
optimizer.zero_grad()
assert losses[-1] < losses[0], (
f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}"
)
class TestSonicMoENoLoRARegression:
"""Verify SonicMoE without LoRA still works after LoRA code was added."""
def teardown_method(self):
_unpatch_sonicmoe()
def test_no_lora_loss_decreases(self):
"""Full fine-tuning (no PEFT) with SonicMoE — regression test."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
patch_sonicmoe("qwen3_moe")
_interleave_gate_up_weights(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
losses = []
for step in range(20):
out = model(input_ids, labels=input_ids)
loss = out.loss
assert not math.isnan(loss.item()), f"NaN loss at step {step}"
assert not math.isinf(loss.item()), f"Inf loss at step {step}"
losses.append(loss.item())
loss.backward()
optimizer.step()
optimizer.zero_grad()
assert losses[-1] < losses[0], (
f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}"
)
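The "runtime LoRA materialization" these end-to-end tests exercise amounts to rebuilding an effective weight per expert on each forward pass instead of merging adapters into the checkpoint. A minimal per-expert sketch in plain Python (hypothetical helper; assumes the delta is `scaling * B_e @ A_e`, matching the shapes asserted in the unit tests below):

```python
def materialize_expert_weight(W_e, A_e, B_e, scaling):
    """Sketch: W_eff[e] = W[e] + scaling * (B_e @ A_e).

    W_e: [dim1][dim2] frozen base expert weight
    A_e: [r][dim2] and B_e: [dim1][r], this expert's LoRA factors
    """
    dim1, dim2, r = len(B_e), len(A_e[0]), len(A_e)
    return [
        [
            # base weight plus the scaled low-rank update
            W_e[i][j] + scaling * sum(B_e[i][k] * A_e[k][j] for k in range(r))
            for j in range(dim2)
        ]
        for i in range(dim1)
    ]
```

Because only `A_e` and `B_e` require gradients, backprop through this expression updates the adapters while the base weight stays frozen, which is exactly what `test_base_weights_frozen` and `test_lora_adapters_update` verify.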

View File

@@ -93,7 +93,9 @@ class TestSoftmaxRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_routing,
)
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
@@ -120,7 +122,9 @@ class TestSoftmaxRoutingParity:
def test_logits_not_returned_by_scattermoe(self):
"""ScatterMoE doesn't return logits; SonicMoE does — verify SonicMoE logits shape."""
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_routing,
)
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
_, _, _, logits = softmax_topk_routing(hidden, moe_block)
@@ -131,7 +135,9 @@ class TestSoftmaxRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_routing,
)
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
gate.norm_topk_prob = False
@@ -152,7 +158,9 @@ class TestSoftmaxRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_routing,
)
for E, K in [(2, 1), (8, 2), (16, 4), (32, 8)]:
moe_block, gate, hidden, T, H, _, _ = _make_softmax_block(E=E, K=K)
@@ -190,7 +198,9 @@ class TestSigmoidRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
sigmoid_topk_routing,
)
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
@@ -226,7 +236,9 @@ class TestSigmoidRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
sigmoid_topk_routing,
)
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
E=16, K=4, n_group=1, topk_group=1, bias_on_gate=True
@@ -254,7 +266,9 @@ class TestSigmoidRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
sigmoid_topk_routing,
)
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
E=16, K=4, n_group=1, bias_on_gate=False
@@ -281,7 +295,9 @@ class TestSigmoidRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
sigmoid_topk_routing,
)
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
n_group=1, bias_on_gate=True
@@ -309,7 +325,9 @@ class TestSigmoidRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
sigmoid_topk_routing,
)
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
n_group=1, bias_on_gate=True
@@ -349,7 +367,7 @@ class TestSharedExpertParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert as scatter_compute,
)
from axolotl.integrations.kernels.sonicmoe.patch import (
from axolotl.integrations.kernels.libs.sonicmoe.patch import (
_compute_shared_expert as sonic_compute,
)

View File

@@ -6,11 +6,11 @@ import pytest
import torch
from axolotl.integrations.kernels.args import KernelsArgs
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
sigmoid_topk_routing,
softmax_topk_routing,
)
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
ConcatenatedToInterleaved,
InterleavedToConcatenated,
register_sonicmoe_weight_converter,
@@ -212,9 +212,40 @@ class TestWeightConverterRegistration:
)
break
def test_register_unsupported_model_type_warns(self):
# A model type with no conversion mapping should warn but not raise
register_sonicmoe_weight_converter("nonexistent_model_type_xyz")
def test_register_adds_same_key_converter(self):
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
register_sonicmoe_weight_converter("qwen3_moe")
modified = get_checkpoint_conversion_mapping("qwen3_moe")
# Should have a same-key converter for already-fused checkpoints
same_key = [
c
for c in modified
if hasattr(c, "source_patterns")
and c.source_patterns == ["mlp.experts.gate_up_proj"]
and c.target_patterns == ["mlp.experts.gate_up_proj"]
]
assert len(same_key) == 1
assert isinstance(same_key[0].operations[0], ConcatenatedToInterleaved)
def test_register_creates_mapping_when_none(self):
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
# qwen3_5_moe has no conversion mapping in transformers
register_sonicmoe_weight_converter("qwen3_5_moe")
mapping = get_checkpoint_conversion_mapping("qwen3_5_moe")
assert mapping is not None
same_key = [
c
for c in mapping
if hasattr(c, "source_patterns")
and c.source_patterns == ["mlp.experts.gate_up_proj"]
and c.target_patterns == ["mlp.experts.gate_up_proj"]
]
assert len(same_key) == 1
assert isinstance(same_key[0].operations[0], ConcatenatedToInterleaved)
def _make_qwen_moe_block(T=8, H=16, E=4, K=2):
@@ -462,7 +493,7 @@ class TestSoftmaxBiasTopkRouting:
"""Tests for Ernie 4.5 MoE routing (softmax_bias_topk_routing)."""
def test_output_shapes(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_bias_topk_routing,
)
@@ -479,7 +510,7 @@ class TestSoftmaxBiasTopkRouting:
assert logits.shape == (T, E)
def test_scores_are_float32(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_bias_topk_routing,
)
@@ -490,7 +521,7 @@ class TestSoftmaxBiasTopkRouting:
assert scores.dtype == torch.float32
def test_token_indices_sorted_ascending(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_bias_topk_routing,
)
@@ -502,7 +533,7 @@ class TestSoftmaxBiasTopkRouting:
assert (diffs >= 0).all()
def test_expert_indices_in_range(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_bias_topk_routing,
)
@@ -514,7 +545,7 @@ class TestSoftmaxBiasTopkRouting:
assert (expert_idx < E).all()
def test_renormalized_scores_sum_to_one(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_bias_topk_routing,
)
@@ -527,7 +558,7 @@ class TestSoftmaxBiasTopkRouting:
def test_bias_affects_expert_selection(self):
"""Large positive bias on expert 0 should make it always selected."""
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_bias_topk_routing,
)
@@ -570,7 +601,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
"""Tests for DeepSeek V2 routing (softmax_group_limited_topk_routing)."""
def test_output_shapes_group_limited(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -589,7 +620,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
assert logits.shape == (T, E)
def test_output_shapes_greedy(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -604,7 +635,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
assert logits.shape == (T, E)
def test_scores_are_float32(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -615,7 +646,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
assert scores.dtype == torch.float32
def test_token_indices_sorted_ascending(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -627,7 +658,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
assert (diffs >= 0).all()
def test_expert_indices_in_range(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -639,7 +670,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
assert (expert_idx < E).all()
def test_scaling_factor_applied(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -655,7 +686,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
def test_group_selection_restricts_experts(self):
"""With num_group=4 and topk_group=1, experts should come from selected groups."""
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -674,7 +705,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
assert (groups == groups[0]).all()
def test_unsupported_topk_method_raises(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -706,7 +737,7 @@ class TestSoftmaxTopkWgRouting:
"""Tests for HunYuan V1 MoE routing (softmax_topk_wg_routing)."""
def test_output_shapes(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_wg_routing,
)
@@ -723,7 +754,7 @@ class TestSoftmaxTopkWgRouting:
assert logits.shape == (T, E)
def test_scores_are_float32(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_wg_routing,
)
@@ -734,7 +765,7 @@ class TestSoftmaxTopkWgRouting:
assert scores.dtype == torch.float32
def test_token_indices_sorted_ascending(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_wg_routing,
)
@@ -746,7 +777,7 @@ class TestSoftmaxTopkWgRouting:
assert (diffs >= 0).all()
def test_expert_indices_in_range(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_wg_routing,
)
@@ -759,7 +790,7 @@ class TestSoftmaxTopkWgRouting:
def test_renormalized_scores_sum_to_one(self):
"""HunYuan V1 always renormalizes (no norm_topk_prob flag)."""
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_wg_routing,
)
@@ -772,7 +803,7 @@ class TestSoftmaxTopkWgRouting:
def test_uses_gate_wg_weight(self):
"""Verify that modifying gate.wg.weight changes the routing output."""
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_wg_routing,
)
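The softmax->topk strategy these parity tests compare can be sketched per token. This is a plain-Python sketch of the assumed semantics only (softmax over expert logits, keep the top-k experts, renormalize the kept probabilities when `norm_topk_prob` is set), not the fused kernel:

```python
import math

def softmax_topk(logits, k, norm_topk_prob=True):
    """Per-token routing sketch: returns (expert_indices, routing_scores)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    experts = sorted(range(len(probs)), key=lambda i: -probs[i])[:k]
    scores = [probs[i] for i in experts]
    if norm_topk_prob:
        s = sum(scores)
        scores = [x / s for x in scores]  # renormalize kept probabilities
    return experts, scores
```

With `norm_topk_prob=True` the kept scores sum to one, which is the invariant `test_renormalized_scores_sum_to_one` checks; with it disabled, the raw softmax probabilities of the selected experts are used as-is.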

View File

@@ -7,7 +7,7 @@ code path where routing happens in float32.
import torch
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
sigmoid_topk_routing,
softmax_topk_routing,
)

View File

@@ -0,0 +1,328 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""Unit tests for SonicMoE LoRA support."""
from unittest.mock import MagicMock
import pytest
import torch
from axolotl.integrations.kernels.libs.sonicmoe.lora import (
MoELoRAMaterialize,
get_lora_params_from_wrapper,
has_lora,
materialize_expert_lora,
unwrap_experts_lora,
unwrap_gate_lora,
)
# =============================================================================
# Helpers: mock PEFT modules
# =============================================================================
def _make_mock_lora_module(weight_A, weight_B, scaling_val, param_name=None):
"""Create a mock PEFT-wrapped module with LoRA attributes."""
mock = MagicMock()
lora_A_linear = MagicMock()
lora_A_linear.weight = weight_A
lora_B_linear = MagicMock()
lora_B_linear.weight = weight_B
mock.lora_A = {"default": lora_A_linear}
mock.lora_B = {"default": lora_B_linear}
mock.scaling = {"default": scaling_val}
mock.active_adapters = ["default"]
if param_name is not None:
mock.parameter_name = param_name
return mock
def _make_peft_gate(hidden_size, num_experts, rank, scaling=0.5):
"""Create a mock PEFT-wrapped gate module."""
base_gate = MagicMock()
base_gate.weight = torch.randn(num_experts, hidden_size)
base_gate.top_k = 2
base_gate.norm_topk_prob = True
lora_A = torch.randn(rank, hidden_size)
lora_B = torch.randn(num_experts, rank)
wrapper = _make_mock_lora_module(lora_A, lora_B, scaling)
wrapper.base_layer = base_gate
return wrapper, base_gate
def _make_peft_experts(
num_experts, gate_up_dim, down_dim, hidden_size, rank, scaling=0.5
):
"""Create a mock PEFT-wrapped experts chain.
Simulates: ParamWrapper(down_proj) -> ParamWrapper(gate_up_proj) -> Experts
"""
base_experts = MagicMock()
base_experts.gate_up_proj = torch.randn(num_experts, gate_up_dim, hidden_size)
base_experts.down_proj = torch.randn(num_experts, hidden_size, down_dim)
# Remove base_layer and lora_A from base_experts so the chain walk stops
del base_experts.base_layer
del base_experts.lora_A
# gate_up_proj wrapper
gup_A = torch.randn(rank * num_experts, hidden_size)
gup_B = torch.randn(gate_up_dim, rank * num_experts)
gup_wrapper = _make_mock_lora_module(gup_A, gup_B, scaling, "gate_up_proj")
gup_wrapper.base_layer = base_experts
# down_proj wrapper (outermost)
down_A = torch.randn(rank * num_experts, down_dim)
down_B = torch.randn(hidden_size, rank * num_experts)
down_wrapper = _make_mock_lora_module(down_A, down_B, scaling, "down_proj")
down_wrapper.base_layer = gup_wrapper
return down_wrapper, base_experts, (gup_A, gup_B), (down_A, down_B)
# =============================================================================
# Tests: has_lora
# =============================================================================
class TestHasLora:
def test_plain_module(self):
m = MagicMock(spec=["weight"])
del m.base_layer
del m.lora_A
assert not has_lora(m)
def test_wrapped_module(self):
m = MagicMock()
m.base_layer = MagicMock()
m.lora_A = {"default": MagicMock()}
assert has_lora(m)
# =============================================================================
# Tests: get_lora_params_from_wrapper
# =============================================================================
class TestGetLoraParams:
def test_no_lora_attrs(self):
m = MagicMock(spec=["weight"])
del m.lora_A
del m.lora_B
assert get_lora_params_from_wrapper(m) == (None, None, None)
def test_extracts_params(self):
A = torch.randn(4, 8)
B = torch.randn(16, 4)
wrapper = _make_mock_lora_module(A, B, 0.5)
lora_A, lora_B, scaling = get_lora_params_from_wrapper(wrapper)
assert torch.equal(lora_A, A)
assert torch.equal(lora_B, B)
assert scaling == 0.5
def test_no_active_adapters(self):
wrapper = _make_mock_lora_module(torch.randn(4, 8), torch.randn(16, 4), 0.5)
wrapper.active_adapters = []
assert get_lora_params_from_wrapper(wrapper) == (None, None, None)
# =============================================================================
# Tests: unwrap_gate_lora
# =============================================================================
class TestUnwrapGateLora:
def test_plain_gate(self):
gate = MagicMock(spec=["weight", "top_k"])
del gate.base_layer
del gate.lora_A
gate.weight = torch.randn(8, 64)
base, weight, delta = unwrap_gate_lora(gate)
assert base is gate
assert torch.equal(weight, gate.weight)
assert delta is None
def test_wrapped_gate(self):
wrapper, base_gate = _make_peft_gate(
hidden_size=64, num_experts=8, rank=4, scaling=0.5
)
base, weight, delta = unwrap_gate_lora(wrapper)
assert base is base_gate
assert torch.equal(weight, base_gate.weight)
assert delta is not None
assert delta.shape == base_gate.weight.shape
# Verify delta = scaling * B @ A
lora_A = wrapper.lora_A["default"].weight
lora_B = wrapper.lora_B["default"].weight
expected = 0.5 * (lora_B @ lora_A)
assert torch.allclose(delta, expected)
# =============================================================================
# Tests: unwrap_experts_lora
# =============================================================================
class TestUnwrapExpertsLora:
def test_plain_experts(self):
experts = MagicMock(spec=["gate_up_proj", "down_proj"])
del experts.base_layer
del experts.lora_A
base, lora_dict = unwrap_experts_lora(experts)
assert base is experts
assert lora_dict == {}
def test_wrapped_experts(self):
E, I2, I, H, r = 4, 256, 128, 64, 8 # noqa: E741
wrapper, base_experts, (gup_A, gup_B), (down_A, down_B) = _make_peft_experts(
E, I2, I, H, r, scaling=0.25
)
base, lora_dict = unwrap_experts_lora(wrapper)
assert base is base_experts
assert "gate_up_proj" in lora_dict
assert "down_proj" in lora_dict
gup_lA, gup_lB, gup_s = lora_dict["gate_up_proj"]
assert torch.equal(gup_lA, gup_A)
assert torch.equal(gup_lB, gup_B)
assert gup_s == 0.25
down_lA, down_lB, down_s = lora_dict["down_proj"]
assert torch.equal(down_lA, down_A)
assert torch.equal(down_lB, down_B)
assert down_s == 0.25
def test_partial_lora(self):
"""Only gate_up_proj has LoRA, down_proj does not."""
base_experts = MagicMock(spec=["gate_up_proj", "down_proj"])
del base_experts.base_layer
del base_experts.lora_A
gup_A = torch.randn(16, 64)
gup_B = torch.randn(256, 16)
gup_wrapper = _make_mock_lora_module(gup_A, gup_B, 0.5, "gate_up_proj")
gup_wrapper.base_layer = base_experts
base, lora_dict = unwrap_experts_lora(gup_wrapper)
assert base is base_experts
assert "gate_up_proj" in lora_dict
assert "down_proj" not in lora_dict
# =============================================================================
# Tests: MoELoRAMaterialize
# =============================================================================
class TestMoELoRAMaterialize:
@pytest.fixture()
def setup(self):
E, dim1, dim2, r = 4, 32, 16, 4
scaling = 0.5
W = torch.randn(E, dim1, dim2, dtype=torch.float64, requires_grad=False)
A = torch.randn(r * E, dim2, dtype=torch.float64, requires_grad=True)
B = torch.randn(dim1, r * E, dtype=torch.float64, requires_grad=True)
return W, A, B, scaling, E, r
def test_forward_shape(self, setup):
W, A, B, scaling, E, r = setup
W_eff = MoELoRAMaterialize.apply(W, A, B, scaling)
assert W_eff.shape == W.shape
def test_forward_correctness(self, setup):
W, A, B, scaling, E, r = setup
W_eff = MoELoRAMaterialize.apply(W, A, B, scaling)
# Manual per-expert computation.
# lora_A is expert-major: [r*E, dim2] -> rows [e*r:(e+1)*r] = expert e
# lora_B is rank-major: [dim1, r*E] -> reshape [dim1, r, E], slice [:, :, e]
_, dim1, dim2 = W.shape
expected = W.clone()
B_3d = B.reshape(dim1, r, E)
for e in range(E):
A_e = A[e * r : (e + 1) * r, :] # [r, dim2]
B_e = B_3d[:, :, e] # [dim1, r]
expected[e] += scaling * (B_e @ A_e)
assert torch.allclose(W_eff, expected, atol=1e-10)
def test_backward_gradcheck(self, setup):
W, A, B, scaling, E, r = setup
# gradcheck requires float64
assert torch.autograd.gradcheck(
lambda a, b: MoELoRAMaterialize.apply(W, a, b, scaling),
(A, B),
eps=1e-6,
atol=1e-4,
)
def test_no_grad_for_base_weight(self, setup):
W, A, B, scaling, E, r = setup
W.requires_grad_(True)
W_eff = MoELoRAMaterialize.apply(W, A, B, scaling)
loss = W_eff.sum()
loss.backward()
assert W.grad is None
assert A.grad is not None
assert B.grad is not None
def test_scaling_zero(self, setup):
W, A, B, _, E, r = setup
W_eff = MoELoRAMaterialize.apply(W, A, B, 0.0)
assert torch.allclose(W_eff, W)
def test_gate_up_proj_shapes(self):
"""Test with realistic gate_up_proj shapes [E, 2*I, H]."""
E, I2, H, r = 8, 512, 256, 16
W = torch.randn(E, I2, H, dtype=torch.float64)
A = torch.randn(r * E, H, dtype=torch.float64, requires_grad=True)
B = torch.randn(I2, r * E, dtype=torch.float64, requires_grad=True)
W_eff = MoELoRAMaterialize.apply(W, A, B, 1.0)
assert W_eff.shape == (E, I2, H)
loss = W_eff.sum()
loss.backward()
assert A.grad.shape == A.shape
assert B.grad.shape == B.shape
def test_down_proj_shapes(self):
"""Test with realistic down_proj shapes [E, H, I]."""
E, H, I, r = 8, 256, 512, 16 # noqa: E741
W = torch.randn(E, H, I, dtype=torch.float64)
A = torch.randn(r * E, I, dtype=torch.float64, requires_grad=True)
B = torch.randn(H, r * E, dtype=torch.float64, requires_grad=True)
W_eff = MoELoRAMaterialize.apply(W, A, B, 1.0)
assert W_eff.shape == (E, H, I)
loss = W_eff.sum()
loss.backward()
assert A.grad.shape == A.shape
assert B.grad.shape == B.shape
# =============================================================================
# Tests: materialize_expert_lora
# =============================================================================
class TestMaterializeExpertLora:
def test_none_passthrough(self):
W = torch.randn(4, 32, 16)
result = materialize_expert_lora(W, None)
assert result is W
def test_with_lora(self):
E, dim1, dim2, r = 4, 32, 16, 4
W = torch.randn(E, dim1, dim2)
A = torch.randn(r * E, dim2, requires_grad=True)
B = torch.randn(dim1, r * E, requires_grad=True)
result = materialize_expert_lora(W, (A, B, 0.5))
assert result.shape == W.shape
assert not torch.equal(result, W)
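The slicing convention asserted in `test_forward_correctness` above (expert-major `lora_A`, rank-major `lora_B`) reduces to simple index math. A plain-Python sketch of which flat rows/columns belong to expert `e` (helper names are illustrative, not from the library):

```python
def lora_rows_for_expert(r, E, e):
    """Expert-major lora_A [r*E, dim2]: rows e*r:(e+1)*r belong to expert e."""
    return list(range(e * r, (e + 1) * r))

def lora_columns_for_expert(r, E, e):
    """Rank-major lora_B [dim1, r*E]: reshape(dim1, r, E)[:, :, e] selects
    flat columns c with c % E == e (row-major reshape, last axis fastest)."""
    return [k * E + e for k in range(r)]
```

So expert `e`'s update is `scaling * B[:, cols] @ A[rows, :]` under these two index sets, which is the per-expert computation the correctness test reproduces with `reshape` and slicing.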