feat: add sonicmoe fused lora support (#3519)
* feat: add sonicmoe fused lora support
* fix: forgot to add file
* feat: add test
* feat: add lora support for other routes
* fix: add int8 lora support
* fix: add qwen35_moe interleave support
* fix: qwen3_5_moe loss
* chore: lint
* address some pr comments
* fix test imports
* add support matrix for moe kernels [skip ci]

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
@@ -52,26 +52,91 @@ The `KernelsPlugin` runs before model loading and:
### ScatterMoE

1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels).
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation via the HF `kernels` library.

### SonicMoE

1. Resolves the model's MoE block class(es) from `constants.py`.
2. Patches the forward method with SonicMoE's optimized CUTLASS kernels and registers a weight converter for the interleaved gate/up projection format.
3. Supports pluggable routing strategies (see routing table below).

Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution.
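For illustration only, here is a minimal sketch of what such a resolution utility could look like; the mapping contents and exact signature below are assumptions, not the actual `constants.py` code:

```python
# Hypothetical sketch of a model-type -> MoE-block-class resolution table.
# The real mapping in constants.py covers many more model types.
MOE_BLOCK_CLASSES: dict[str, tuple[str, ...]] = {
    "qwen3_moe": ("Qwen3MoeSparseMoeBlock",),
    "deepseek_v3": ("DeepseekV3MoE",),
}

def resolve_moe_block_classes(model_type: str) -> tuple[str, ...]:
    """Return the MoE block class name(s) to patch for a model type."""
    try:
        return MOE_BLOCK_CLASSES[model_type]
    except KeyError as err:
        raise ValueError(f"No MoE block mapping for model type: {model_type}") from err
```
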

## Model Support Matrix

See `constants.py` for the full list of supported model types (Qwen2-MoE, Qwen3-MoE, OLMoE, Mixtral, DeepSeek-V3, GLM-MoE, MiniMax, etc.).

All models use the **SwiGLU** activation (`act_fn(gate) * up`). Neither kernel currently supports non-SwiGLU MoE architectures.

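For reference, a minimal SwiGLU expert MLP in plain PyTorch (shapes illustrative; the kernels operate on batched expert layouts instead):

```python
import torch.nn.functional as F

def swiglu_expert(x, gate_w, up_w, down_w):
    """One SwiGLU expert MLP: down( silu(x @ gate_w.T) * (x @ up_w.T) ).

    x:      [T, H] tokens routed to this expert
    gate_w: [I, H], up_w: [I, H], down_w: [H, I]
    """
    return (F.silu(x @ gate_w.T) * (x @ up_w.T)) @ down_w.T  # act_fn(gate) * up
```
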
### Routing strategies

| Routing Strategy | Description | ScatterMoE | SonicMoE |
|---|---|:---:|:---:|
| softmax → topk | Softmax over experts, select top-K, optional renormalization | Yes | Yes |
| softmax → group selection → topk | Softmax, select top groups (sum of top-2 per group), topk from selected groups, renorm + scaling | No | Yes |
| sigmoid → topk (with groups) | Sigmoid + bias correction, group-based masking, topk from masked scores, weights from original sigmoid | Yes | Yes |
| sigmoid → topk (no groups) | Sigmoid + bias correction, straight topk (n_group=1) | Yes | Yes |
| softmax → bias correction → topk | Softmax, bias via `gate.moe_statics`, topk, gather from original probs, clamp-based renorm | No | Yes |
| softmax → group_limited_greedy | Softmax, group selection (max per group), topk, scale only (no renorm) | No | Yes |
| softmax → topk via gate.wg | Softmax, gate weight at `gate.wg.weight` (not `gate.weight`), always renormalize | No | Yes |
| fused topk → softmax | Routing + expert computation fused in a single kernel | No | Planned |

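To make the first row concrete, here is a minimal sketch of plain softmax → topk routing (illustrative only; the actual routing implementations appear in the diffs below):

```python
import torch
import torch.nn.functional as F

def softmax_topk(hidden, gate_weight, top_k, renormalize=True):
    """Plain softmax -> topk routing sketch.

    hidden:      [T, H] token activations
    gate_weight: [E, H] router weight
    Returns per-token expert weights [T, K] and expert indices [T, K].
    """
    logits = F.linear(hidden.float(), gate_weight.float())   # [T, E]
    probs = F.softmax(logits, dim=-1, dtype=torch.float32)   # [T, E]
    top_values, top_idx = torch.topk(probs, top_k, dim=-1)   # [T, K]
    if renormalize:  # norm_topk_prob in the HF configs
        top_values = top_values / top_values.sum(dim=-1, keepdim=True)
    return top_values, top_idx
```
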
### Per-model support

| Model Type | Architecture | Routing | ScatterMoE | SonicMoE |
|---|---|---|:---:|:---:|
| `qwen2_moe` | Qwen2-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_moe` | Qwen3-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_5_moe` | Qwen3.5-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_5_moe_text` | Qwen3.5-MoE (VLM text) | softmax → topk | **Yes** | **Yes** |
| `qwen3_next` | Qwen3-Next | softmax → topk | **Yes** | **Yes** |
| `qwen3_vl_moe` | Qwen3-VL-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_omni_moe` | Qwen3-Omni (Thinker + Talker) | softmax → topk | **Yes** | **Yes** |
| `olmoe` | OLMoE | softmax → topk | **Yes** | **Yes** |
| `mixtral` | Mixtral | softmax → topk | **Yes** | **Yes** |
| `minimax` | MiniMax | softmax → topk | **Yes** | **Yes** |
| `mistral4` | Mistral 4 | softmax → group → topk | No | **Yes** |
| `glm_moe_dsa` | GLM-MoE DSA (GLM 5) | sigmoid → topk (groups) | **Yes** | **Yes** |
| `deepseek_v3` | DeepSeek-V3 | sigmoid → topk (groups) | **Yes** | **Yes** |
| `glm4_moe` | GLM4-MoE | sigmoid → topk (groups) | **Yes** | **Yes** |
| `glm4_moe_lite` | GLM4-MoE Lite (GLM 4.7 Flash) | sigmoid → topk (groups) | **Yes**\* | **Yes** |
| `glm4v_moe` | GLM4v-MoE | sigmoid → topk (groups) | **Yes** | **Yes** |
| `minimax_m2` | MiniMax M2 | sigmoid → topk (no groups) | **Yes** | **Yes** |
| `ernie4_5_moe` | ERNIE 4.5 MoE | softmax → bias → topk | No | **Yes** |
| `deepseek_v2` | DeepSeek-V2 | softmax → group_limited_greedy | No | **Yes** |
| `hunyuan_v1_moe` | HunYuan V1 MoE | softmax → topk (gate.wg) | No | **Yes** |
| `gpt_oss` | GPT-OSS | fused topk → softmax | No | Planned |

\* `glm4_moe_lite` with ScatterMoE may have issues — see Limitations.

### Feature comparison

| Feature | ScatterMoE | SonicMoE |
|---|:---:|:---:|
| Kernel backend | Triton | CUTLASS |
| GPU requirement | Any CUDA | Hopper (H100/H200) or Blackwell (B200+) |
| LoRA approach | Fused in Triton kernel | Runtime materialization + custom autograd |
| LoRA overhead | Lower (fused computation) | Higher (per-forward materialization) |
| Gate/router LoRA | Yes | Yes |
| Expert LoRA | Yes (fused) | Yes (materialized) |
| Shared expert LoRA | Yes (standard PEFT) | Yes (standard PEFT) |
| Selective expert dequantization | Yes (~97% memory savings) | No |
| Weight format | Transposed `[E, hidden, 2*inter]` | Interleaved gate/up `[2*I, H, E]` |
| torch.compile routing | No | Yes (optional) |

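To make the weight formats concrete, here is a small sketch of converting a fused gate/up weight from the concatenated row order (all gate rows, then all up rows) to the interleaved order (gate/up rows alternating). That this matches SonicMoE's exact convention is an assumption based on the `ConcatenatedToInterleaved(dim=1)` op in the diff below:

```python
import torch

def concatenated_to_interleaved(gate_up: torch.Tensor) -> torch.Tensor:
    """[E, 2*I, H] concatenated ([gate; up]) -> [E, 2*I, H] interleaved."""
    E, two_i, H = gate_up.shape
    i = two_i // 2
    gate, up = gate_up[:, :i, :], gate_up[:, i:, :]
    # Stack then flatten so rows alternate g_0, u_0, g_1, u_1, ...
    return torch.stack((gate, up), dim=2).reshape(E, two_i, H)
```
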
## Shared Expert Handling

Both kernels handle shared experts identically. Shared expert attribute names are detected in order of priority:

1. `shared_expert` (Qwen2-MoE)
2. `shared_experts` (GLM-MoE, DeepSeek-V3)
3. `shared_mlp` (HunYuan V1 MoE)

If `shared_expert_gate` exists, sigmoid gating is applied to the shared expert contribution before adding it to the routed output. PEFT wraps shared expert linear layers with standard LoRA — no special handling is needed.

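A minimal sketch of this behavior (attribute names follow the description above; the actual `_compute_shared_expert` helper in the diff may differ):

```python
import torch

SHARED_EXPERT_ATTRS = ("shared_expert", "shared_experts", "shared_mlp")

def compute_shared_expert(moe_block, hidden_states):
    """Illustrative shared-expert pass with optional sigmoid gating."""
    for attr in SHARED_EXPERT_ATTRS:  # detection priority order
        shared = getattr(moe_block, attr, None)
        if shared is not None:
            break
    else:
        return None  # model has no shared expert
    out = shared(hidden_states)
    gate = getattr(moe_block, "shared_expert_gate", None)
    if gate is not None:
        # sigmoid gating on the shared expert contribution
        out = torch.sigmoid(gate(hidden_states)) * out
    return out
```
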
## Limitations

- **Routing coverage (ScatterMoE)**: ScatterMoE implements only the softmax → topk and sigmoid → topk strategies, so models whose reference routing differs (GPT-OSS, etc.) can diverge from baseline results. SonicMoE's pluggable routing covers a wider range of architectures (see the routing table above).
- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).
- **Non-SwiGLU activations**: Neither kernel supports MoE architectures with non-SwiGLU expert activations (e.g., GPT-OSS uses a custom GLU variant).
- **GPT-OSS**: Deferred — requires transposed weight layout `[E, H, 2*I]`, expert biases, and custom GLU activation. A dedicated forward path is needed.
- **FSDP + fused gate LoRA (SonicMoE)**: The fused topk→softmax path materializes a local tensor when LoRA delta is present to avoid DTensor + Tensor mixing under FSDP.

## Note on MegaBlocks
@@ -45,6 +45,28 @@ class KernelsArgs(BaseModel):
        return data

    @model_validator(mode="before")
    @classmethod
    def warn_sonicmoe_lora_overhead(cls, data):
        if data.get("use_sonicmoe") is True and data.get("adapter") in (
            "lora",
            "qlora",
        ):
            lora_target = data.get("lora_target_modules") or []
            lora_linear = data.get("lora_target_linear_modules") or []
            targets = (
                lora_target if isinstance(lora_target, list) else [lora_target]
            ) + (lora_linear if isinstance(lora_linear, list) else [lora_linear])
            expert_keywords = ("gate_up_proj", "down_proj", "experts")
            if any(kw in t for t in targets for kw in expert_keywords):
                LOG.info(
                    "SonicMoE + LoRA on expert modules uses runtime weight materialization "
                    "(W_eff = W + scaling*B@A per forward). This has slightly higher overhead "
                    "than ScatterMoE's fused Triton LoRA kernels but works with any CUTLASS kernel."
                )

        return data

    @model_validator(mode="before")
    @classmethod
    def disable_mlp_kernel(cls, data):

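For instance, a config dict along these lines (keys taken from the validator above; the target module names are illustrative) would emit the info message:

```python
cfg = {
    "use_sonicmoe": True,
    "adapter": "lora",
    "lora_target_modules": ["experts.gate_up_proj", "experts.down_proj"],
}
# warn_sonicmoe_lora_overhead matches "gate_up_proj"/"down_proj" in the targets
# and logs the runtime-materialization overhead note.
```
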
@@ -49,6 +49,11 @@ class ParallelLinear(torch.autograd.Function):
        grouped_in: bool = False,
        grouped_out: bool = False,
    ):
        # Cast weights to match input dtype (e.g. 8-bit LoRA)
        if expert_weights.dtype != x.dtype:
            expert_weights = expert_weights.to(x.dtype)
        if expert_biases is not None and expert_biases.dtype != x.dtype:
            expert_biases = expert_biases.to(x.dtype)
        with torch.device(x.device):
            output = kernels.ops.scatter2scatter(
                X=x,

@@ -65,6 +65,11 @@ class ScatterMoELoRA(torch.autograd.Function):
        use_fused_dX: bool = False,
        use_fused_gather: bool = False,
    ):
        # Cast weights to match input dtype (e.g. 8-bit LoRA)
        if expert_weights.dtype != x.dtype:
            expert_weights = expert_weights.to(x.dtype)
        if expert_biases is not None and expert_biases.dtype != x.dtype:
            expert_biases = expert_biases.to(x.dtype)
        with torch.device(x.device):
            # Fused forward: Y = X @ W + scaling * (X @ A^T) @ B^T
            output = scatter2scatter_lora(

src/axolotl/integrations/kernels/libs/sonicmoe/lora.py (new file, 220 lines)
@@ -0,0 +1,220 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0

"""
SonicMoE LoRA support via runtime weight materialization.

SonicMoE uses opaque CUTLASS kernels that cannot be modified to fuse LoRA.
Instead, we materialize the effective weight W_eff = W + scaling * (B @ A)
before each CUTLASS call, and use a custom autograd.Function to route
gradients back to the LoRA A and B parameters.

PEFT unwrapping utilities are also provided to handle the ParamWrapper
chain that PEFT creates when targeting expert parameters.
"""

from typing import Optional

import torch

# =============================================================================
# PEFT unwrapping utilities
# =============================================================================


def has_lora(module) -> bool:
    """Check if a module is wrapped by PEFT with LoRA."""
    return hasattr(module, "base_layer") and hasattr(module, "lora_A")


def get_lora_params_from_wrapper(module) -> tuple:
    """Extract LoRA parameters from a PEFT ParamWrapper.

    Returns:
        (lora_A, lora_B, scaling) if LoRA is active, else (None, None, None)
    """
    if not hasattr(module, "lora_A") or not hasattr(module, "lora_B"):
        return None, None, None

    active_adapters = getattr(module, "active_adapters", ["default"])
    if not active_adapters:
        return None, None, None

    adapter_name = active_adapters[0]

    lora_A_dict = getattr(module, "lora_A", {})
    lora_B_dict = getattr(module, "lora_B", {})
    scaling_dict = getattr(module, "scaling", {})

    if (
        adapter_name not in lora_A_dict
        or adapter_name not in lora_B_dict
        or adapter_name not in scaling_dict
    ):
        return None, None, None

    lora_A = lora_A_dict[adapter_name].weight
    lora_B = lora_B_dict[adapter_name].weight
    scaling = scaling_dict[adapter_name]

    return lora_A, lora_B, scaling


def unwrap_gate_lora(gate_module):
    """Unwrap PEFT ParamWrapper on the router gate.

    When PEFT targets ``gate.weight``, ``self.gate`` becomes::

        ParamWrapper(weight)
          -> base_layer: Router (the real module)

    Returns:
        (base_gate, gate_weight, gate_lora_delta_or_None)

    ``base_gate`` is the original router module (with ``.top_k``, etc.).
    ``gate_weight`` is the base router weight tensor.
    ``gate_lora_delta_or_None`` is the LoRA delta if active, else None.
    Kept separate to avoid mixing DTensor + Tensor under FSDP.
    """
    if has_lora(gate_module):
        base_gate = gate_module.base_layer
        lora_A, lora_B, scaling = get_lora_params_from_wrapper(gate_module)
        if lora_A is not None:
            delta = scaling * (lora_B @ lora_A)
            return base_gate, base_gate.weight, delta
        return base_gate, base_gate.weight, None

    return gate_module, gate_module.weight, None


def unwrap_experts_lora(experts_module):
    """Walk a PEFT ParamWrapper chain on ``self.experts``.

    When PEFT targets ``experts.gate_up_proj`` and ``experts.down_proj``
    via ``target_parameters``, ``self.experts`` becomes::

        ParamWrapper(down_proj)
          -> base_layer: ParamWrapper(gate_up_proj)
            -> base_layer: Experts (the real module)

    Returns:
        (base_experts, lora_dict)

    ``lora_dict`` maps parameter names to ``(lora_A, lora_B, scaling)``
    tuples, or is empty if no LoRA is active.
    """
    wrappers = {}
    module = experts_module
    while hasattr(module, "base_layer") and hasattr(module, "lora_A"):
        param_name = getattr(module, "parameter_name", None)
        if param_name is not None:
            wrappers[param_name] = module
        module = module.base_layer

    base_experts = module
    lora_dict = {}

    for param_name, wrapper in wrappers.items():
        lora_A, lora_B, scaling = get_lora_params_from_wrapper(wrapper)
        if lora_A is not None:
            lora_dict[param_name] = (lora_A, lora_B, scaling)

    return base_experts, lora_dict


# =============================================================================
# LoRA weight materialization autograd function
# =============================================================================


class MoELoRAMaterialize(torch.autograd.Function):
    """Materialize effective weight W_eff = W + scaling * (B @ A) per expert.

    Inserts into the autograd graph between PEFT's LoRA parameters and
    SonicMoE's CUTLASS kernels. The CUTLASS backward computes dW_eff,
    which this function decomposes into dA and dB via the chain rule.

    Weight layouts (PEFT rank-major):
        base_weight: [E, dim1, dim2] (frozen expert parameter)
        lora_A: [r*E, dim2] (rows [e*r:(e+1)*r] = A_e)
        lora_B: [dim1, r*E] (cols [:, e*r:(e+1)*r] = B_e)

    Per-expert: delta_e = B_e @ A_e = [dim1, r] @ [r, dim2] = [dim1, dim2]
    """

    @staticmethod
    def forward(
        ctx,
        base_weight: torch.Tensor,
        lora_A: torch.Tensor,
        lora_B: torch.Tensor,
        scaling: float,
    ) -> torch.Tensor:
        E, dim1, dim2 = base_weight.shape
        r = lora_A.shape[0] // E
        assert lora_A.shape[0] == r * E, (
            f"lora_A rows ({lora_A.shape[0]}) must be divisible by num_experts ({E})"
        )

        # Reshape PEFT rank-major to per-expert batched format
        A_3d = lora_A.reshape(E, r, dim2)
        B_3d = lora_B.reshape(dim1, r, E).permute(2, 0, 1).contiguous()  # [E, dim1, r]

        # Batched matmul: [E, dim1, r] @ [E, r, dim2] = [E, dim1, dim2]
        delta = torch.bmm(B_3d, A_3d)

        W_eff = base_weight + scaling * delta

        ctx.save_for_backward(lora_A, lora_B)
        ctx.scaling = scaling
        ctx.E = E
        ctx.r = r

        return W_eff

    @staticmethod
    def backward(ctx, grad_W_eff: torch.Tensor):
        lora_A, lora_B = ctx.saved_tensors
        scaling = ctx.scaling
        E = ctx.E
        r = ctx.r

        _, dim1, dim2 = grad_W_eff.shape

        # Reshape to per-expert (same as forward)
        A_3d = lora_A.reshape(E, r, dim2)
        B_3d = lora_B.reshape(dim1, r, E).permute(2, 0, 1).contiguous()  # [E, dim1, r]

        # dA_e = scaling * B_e^T @ dW_e
        # [E, r, dim1] @ [E, dim1, dim2] = [E, r, dim2]
        d_A_3d = scaling * torch.bmm(B_3d.transpose(1, 2), grad_W_eff)

        # dB_e = scaling * dW_e @ A_e^T
        # [E, dim1, dim2] @ [E, dim2, r] = [E, dim1, r]
        d_B_3d = scaling * torch.bmm(grad_W_eff, A_3d.transpose(1, 2))

        # Reshape back to PEFT rank-major layout
        d_lora_A = d_A_3d.reshape(E * r, dim2)
        d_lora_B = d_B_3d.permute(1, 2, 0).contiguous().reshape(dim1, E * r)

        return None, d_lora_A, d_lora_B, None


def materialize_expert_lora(
    base_weight: torch.Tensor,
    lora_params: Optional[tuple],
) -> torch.Tensor:
    """Materialize effective expert weight with optional LoRA delta.

    Args:
        base_weight: [E, dim1, dim2] frozen expert parameter
        lora_params: (lora_A, lora_B, scaling) or None

    Returns:
        W_eff if lora_params is not None, else base_weight unchanged.
    """
    if lora_params is None:
        return base_weight
    lora_A, lora_B, scaling = lora_params
    return MoELoRAMaterialize.apply(base_weight, lora_A, lora_B, scaling)

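As a quick sanity check of the custom backward (an illustrative test, not part of the PR), the decomposed gradients should match plain autograd through an explicit materialization using the same reshapes:

```python
import torch

E, d1, d2, r, s = 2, 4, 3, 2, 0.5
W = torch.randn(E, d1, d2)
A = torch.randn(r * E, d2, requires_grad=True)
B = torch.randn(d1, r * E, requires_grad=True)

# Custom autograd path (MoELoRAMaterialize defined in lora.py above)
MoELoRAMaterialize.apply(W, A, B, s).sum().backward()

# Reference path: same reshapes, but let autograd derive the gradients
A2 = A.detach().clone().requires_grad_()
B2 = B.detach().clone().requires_grad_()
delta = torch.bmm(B2.reshape(d1, r, E).permute(2, 0, 1), A2.reshape(E, r, d2))
(W + s * delta).sum().backward()

assert torch.allclose(A.grad, A2.grad, atol=1e-6)
assert torch.allclose(B.grad, B2.grad, atol=1e-6)
```
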
@@ -28,20 +28,72 @@ import torch.nn.functional as F
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
from axolotl.utils.logging import get_logger

from .lora import (
    has_lora,
    materialize_expert_lora,
    unwrap_experts_lora,
    unwrap_gate_lora,
)

LOG = get_logger(__name__)


def _get_expert_weights(experts_module):
    """Extract expert weights, applying LoRA materialization if PEFT is active.

    Returns:
        (gate_up_weight, down_weight) in SonicMoE layout [dim, dim, E].
    """
    if has_lora(experts_module):
        base_experts, lora_dict = unwrap_experts_lora(experts_module)
        gate_up = materialize_expert_lora(
            base_experts.gate_up_proj, lora_dict.get("gate_up_proj")
        )
        down = materialize_expert_lora(
            base_experts.down_proj, lora_dict.get("down_proj")
        )
    else:
        gate_up = experts_module.gate_up_proj
        down = experts_module.down_proj

    # Permute to SonicMoE layout:
    #   gate_up: [E, 2*I, H] -> [2*I, H, E]
    #   down:    [E, H, I]   -> [H, I, E]
    return gate_up.permute(1, 2, 0), down.permute(1, 2, 0)


def _fix_qwen3_5_moe_text_weight_renaming(model_type: str, base_model_type: str):
    """Strip qwen3_5_moe_text WeightRenaming in VLM mode to preserve custom loaders."""
    if model_type != "qwen3_5_moe_text" or base_model_type == "qwen3_5_moe_text":
        return

    try:
        from transformers.conversion_mapping import (
            get_checkpoint_conversion_mapping,
            register_checkpoint_conversion_mapping,
        )
        from transformers.core_model_loading import WeightRenaming
    except ImportError:
        return

    text_mapping = get_checkpoint_conversion_mapping(model_type)
    if text_mapping and isinstance(text_mapping[0], WeightRenaming):
        text_mapping.pop(0)
        register_checkpoint_conversion_mapping(model_type, text_mapping, overwrite=True)
        LOG.info("Stripped qwen3_5_moe_text WeightRenaming for VLM mode")


def patch_sonicmoe(
    model_type: str,
    torch_compile: bool = False,
    base_model_type: str | None = None,
):
    """Patch SparseMoeBlock for SonicMoE support."""
    from .routing import get_model_moe_config
    from .weight_converter import register_sonicmoe_weight_converter

    _fix_qwen3_5_moe_text_weight_renaming(model_type, base_model_type or model_type)

    routing_fn, activation, router_attr = get_model_moe_config(model_type)

    if torch_compile and routing_fn is not None:

@@ -113,11 +165,10 @@ def _make_general_forward(moe_cls, routing_fn, activation):
            hidden_states_flat, self
        )

        # Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout
        gate_up_weight, down_weight = _get_expert_weights(self.experts)
        gate_up_weight = gate_up_weight.to(hidden_states_flat.dtype)
        down_weight = down_weight.to(hidden_states_flat.dtype)
        E = gate_up_weight.shape[-1]

        output, _ = moe_general_routing_inputs(

@@ -161,22 +212,30 @@ def _make_fused_forward(moe_cls, activation, router_attr):
        # Shared expert (computed early, matching original model ordering)
        shared_expert_output = _compute_shared_expert(self, hidden_states_flat)

        # Unwrap router for attribute access + optional LoRA delta
        raw_router = getattr(self, router_attr)
        base_router, router_weight, router_lora_delta = unwrap_gate_lora(raw_router)
        if router_lora_delta is not None:
            # Materialize local tensor to avoid DTensor + Tensor add under FSDP
            if hasattr(router_weight, "to_local"):
                router_weight = router_weight.to_local()
            effective_router_weight = router_weight + router_lora_delta
        else:
            effective_router_weight = router_weight

        # Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout
        gate_up_weight, down_weight = _get_expert_weights(self.experts)
        gate_up_weight = gate_up_weight.to(hidden_states_flat.dtype)
        down_weight = down_weight.to(hidden_states_flat.dtype)

        output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer(
            hidden_states_flat,
            effective_router_weight,
            gate_up_weight,
            None,  # b1 (no gate/up bias)
            down_weight,
            None,  # b2 (no down bias)
            base_router.top_k,
            torch.cuda.current_stream().cuda_stream,
            activation,
            False,  # is_inference_mode

@@ -16,6 +16,8 @@ When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
import torch
import torch.nn.functional as F

from .lora import unwrap_gate_lora


def get_model_moe_config(model_type: str):
    """Returns (routing_fn, activation, router_attr) for a given model type.

@@ -40,6 +42,7 @@ def get_model_moe_config(model_type: str):
        "qwen2_moe",
        "qwen3_moe",
        "qwen3_5_moe",
        "qwen3_5_moe_text",
        "qwen3_next",
        "qwen3_vl_moe",
        "qwen3_omni_moe",

@@ -88,12 +91,18 @@ def softmax_topk_routing(
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
    T, H = hidden_states.shape
    K = base_gate.top_k

    # Compute router logits and softmax over all experts.
    # Two F.linear calls avoid mixing DTensor (gate_weight) + Tensor (delta) under FSDP.
    # Cast to float32 to match LoRA delta dtype (PEFT computes in fp32).
    router_logits = F.linear(hidden_states.float(), gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    # Select top-k experts per token

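The two-call split above is mathematically equivalent to a single call with the merged weight, by linearity of `F.linear` in the weight argument. A quick illustrative check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
w = torch.randn(16, 8)
delta = torch.randn(16, 8)

merged = F.linear(x, w + delta)           # one call, merged weight
split = F.linear(x, w) + F.linear(x, delta)  # two calls, summed outputs
assert torch.allclose(merged, split, atol=1e-5)
```
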
@@ -101,7 +110,7 @@ def softmax_topk_routing(
    # Renormalize if configured (default True for models without the attribute,
    # e.g. Mixtral/MiniMax which always normalize)
    if getattr(base_gate, "norm_topk_prob", True):
        top_values = top_values / top_values.sum(dim=-1, keepdim=True)

    # no-op: matches transformers which casts to softmax output dtype (float32).

@@ -128,13 +137,17 @@ def softmax_group_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
    base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
    T, _ = hidden_states.shape
    K = moe_block.top_k
    E = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])
    n_group = getattr(moe_block, "n_group", 1)

    router_logits = F.linear(hidden_states.float(), gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    scores_for_choice = router_probs

@@ -206,25 +219,29 @@ def sigmoid_topk_routing(
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
    T, _ = hidden_states.shape
    K = moe_block.top_k
    E = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])
    n_group = getattr(moe_block, "n_group", 1)

    # Compute router logits and sigmoid probabilities
    router_logits = F.linear(hidden_states.float(), gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = router_logits.sigmoid()  # [T, E]

    # Bias-corrected scores for expert selection (not used for final weights).
    # glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 stores it on the block.
    e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None)
    if e_score_correction_bias is None:
        e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
    if e_score_correction_bias is None:
        raise AttributeError(
            f"sigmoid_topk_routing requires e_score_correction_bias on "
            f"gate ({type(base_gate)}) or moe_block ({type(moe_block)}), but neither has it"
        )
    scores_for_choice = router_probs + e_score_correction_bias

@@ -296,16 +313,20 @@ def softmax_bias_topk_routing(
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
    T, H = hidden_states.shape
    K = base_gate.top_k

    # Compute router logits and softmax (force float32 for numerical stability)
    router_logits = F.linear(hidden_states.float(), gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    # Bias-corrected scores for expert selection (via moe_statics module)
    scores_for_choice = base_gate.moe_statics(router_probs)  # [T, E]

    # Select top-k experts using biased scores
    _, selected_experts = torch.topk(scores_for_choice, K, dim=-1)  # [T, K]

@@ -314,7 +335,7 @@ def softmax_bias_topk_routing(
    top_values = torch.gather(router_probs, dim=-1, index=selected_experts)  # [T, K]

    # Renormalize with clamp(min=norm_min) instead of sum+epsilon
    norm_min = getattr(base_gate, "norm_min", 1e-20)
    top_values = top_values / torch.clamp(
        top_values.sum(dim=-1, keepdim=True), min=norm_min
    )

@@ -358,15 +379,19 @@ def softmax_group_limited_topk_routing(
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
    T, H = hidden_states.shape
    K = moe_block.top_k
    num_group = getattr(moe_block, "num_group", 1)
    num_experts = gate_weight.shape[0]
    topk_method = getattr(moe_block, "topk_method", "greedy")

    # Compute logits in float32 and softmax
    router_logits = F.linear(hidden_states.float(), gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    if topk_method == "greedy" or num_group == 1:

@@ -445,12 +470,17 @@ def softmax_topk_wg_routing(
        router_logits: [T, E] original logits for aux loss
    """
    gate = moe_block.gate
    T, H = hidden_states.shape
    K = moe_block.top_k

    # Gate computes logits via gate.wg (nn.Linear, float32)
    # Unwrap at gate.wg level since PEFT targets the wg Linear, not the gate container
    base_wg, wg_weight, wg_lora_delta = unwrap_gate_lora(gate.wg)
    router_logits = F.linear(hidden_states.float(), wg_weight.float())  # [T, E]
    if wg_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), wg_lora_delta.float()
        )
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    # Select top-k experts

@@ -129,15 +129,41 @@ class InterleavedToConcatenated(ConversionOps):
        return ConcatenatedToInterleaved(self.dim)


def _make_same_key_interleave_converter():
    """Create a WeightConverter that interleaves an already-fused gate_up_proj."""
    from transformers.core_model_loading import WeightConverter

    return WeightConverter(
        source_patterns="mlp.experts.gate_up_proj",
        target_patterns="mlp.experts.gate_up_proj",
        operations=[ConcatenatedToInterleaved(dim=1)],
    )


def _has_same_key_interleave(mapping) -> bool:
    """Check whether the mapping already has a same-key gate_up_proj interleave converter."""
    for conv in mapping:
        if (
            hasattr(conv, "source_patterns")
            and conv.source_patterns == ["mlp.experts.gate_up_proj"]
            and conv.target_patterns == ["mlp.experts.gate_up_proj"]
            and hasattr(conv, "operations")
            and any(isinstance(op, ConcatenatedToInterleaved) for op in conv.operations)
        ):
            return True
    return False


def register_sonicmoe_weight_converter(model_type: str):
    """Register weight converters to interleave gate_up_proj for SonicMoE.

    Handles two checkpoint formats:
    1. Separate per-expert weights (e.g. qwen3_moe): appends interleave to the
       existing merge chain (MergeModulelist -> Concatenate -> Interleave).
    2. Already-fused gate_up_proj (e.g. qwen3_5_moe_text): adds a same-key
       converter (gate_up_proj -> gate_up_proj with Interleave).

    The reverse is auto-generated for saving:
    InterleavedToConcatenated(dim=1) -> Chunk(dim=1) -> SplitModulelist(dim=0)
    The loader matches whichever source pattern exists in the checkpoint.
    """
    from transformers.conversion_mapping import (
        get_checkpoint_conversion_mapping,

@@ -145,37 +171,32 @@ def register_sonicmoe_weight_converter(model_type: str):
    )

    existing = get_checkpoint_conversion_mapping(model_type)

    if existing is None:
        # No mapping at all — create one with just the same-key converter
        mapping = [_make_same_key_interleave_converter()]
        register_checkpoint_conversion_mapping(model_type, mapping)
        LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'")
        return

    # Append interleave to any existing many-to-one merge chain
    for converter in existing:
        if hasattr(converter, "operations") and any(
            "gate_up_proj" in pat for pat in converter.target_patterns
        ):
            has_separate_sources = any(
                "gate_proj" in pat or "up_proj" in pat
                for pat in converter.source_patterns
            )
            # Guard against double registration (e.g. plugin reloaded)
            if has_separate_sources and not any(
                isinstance(op, ConcatenatedToInterleaved) for op in converter.operations
            ):
                converter.operations.append(ConcatenatedToInterleaved(dim=1))
                break

    # Also add a same-key converter for already-fused checkpoints
    if not _has_same_key_interleave(existing):
        existing.append(_make_same_key_interleave_converter())

    register_checkpoint_conversion_mapping(model_type, existing, overwrite=True)
    LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'")

@@ -84,12 +84,13 @@ class KernelsPlugin(BasePlugin):
        _check_sonicmoe_gpu_compat()

        from axolotl.integrations.kernels.libs.sonicmoe import patch_sonicmoe

        LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
        patch_sonicmoe(
            moe_model_type,
            torch_compile=bool(getattr(cfg, "torch_compile", False)),
            base_model_type=cfg.model_config_type,
        )

    def _register_kernels(self):