feat: add sonicmoe fused lora support (#3519)
* feat: add sonicmoe fused lora support
* fix: forgot to add file
* feat: add test
* feat: add lora support for other routes
* fix: add int8 lora support
* fix: add qwen35_moe interleave support
* fix: qwen3_5_moe loss
* chore: lint
* address some pr comments
* fix test imports
* add support matrix for moe kernels [skip ci]

Co-authored-by: Wing Lian <wing@axolotl.ai>
@@ -52,26 +52,91 @@ The `KernelsPlugin` runs before model loading and:
### ScatterMoE

1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels).
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation via the HF `kernels` library.

### SonicMoE

1. Resolves the model's MoE block class(es) from `constants.py`.
2. Patches the forward method with SonicMoE's optimized CUTLASS kernels and registers a weight converter for the interleaved gate/up projection format.
3. Supports pluggable routing strategies (see routing table below).

Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution.

## Model Support Matrix

See `constants.py` for the full list of supported model types (Qwen2-MoE, Qwen3-MoE, OLMoE, Mixtral, DeepSeek-V3, GLM-MoE, MiniMax, etc.).

All models use the **SwiGLU** activation (`act_fn(gate) * up`). Neither kernel currently supports non-SwiGLU MoE architectures.
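As a point of reference, a single SwiGLU expert can be sketched in a few lines of plain PyTorch (an illustration only, not either kernel's implementation; the names and toy shapes here are made up for the example):

```python
import torch
import torch.nn.functional as F

def swiglu_expert(x, gate_w, up_w, down_w):
    """One SwiGLU expert MLP: down( silu(x @ gate^T) * (x @ up^T) )."""
    hidden = F.silu(x @ gate_w.T) * (x @ up_w.T)   # act_fn(gate) * up
    return hidden @ down_w.T

x = torch.randn(2, 4)       # T=2 tokens, hidden size H=4
gate_w = torch.randn(8, 4)  # intermediate size I=8
up_w = torch.randn(8, 4)
down_w = torch.randn(4, 8)
out = swiglu_expert(x, gate_w, up_w, down_w)
assert out.shape == (2, 4)
```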

### Routing strategies

| Routing Strategy | Description | ScatterMoE | SonicMoE |
|---|---|:---:|:---:|
| softmax → topk | Softmax over experts, select top-K, optional renormalization | Yes | Yes |
| softmax → group selection → topk | Softmax, select top groups (sum of top-2 per group), topk from selected groups, renorm + scaling | No | Yes |
| sigmoid → topk (with groups) | Sigmoid + bias correction, group-based masking, topk from masked scores, weights from original sigmoid | Yes | Yes |
| sigmoid → topk (no groups) | Sigmoid + bias correction, straight topk (n_group=1) | Yes | Yes |
| softmax → bias correction → topk | Softmax, bias via `gate.moe_statics`, topk, gather from original probs, clamp-based renorm | No | Yes |
| softmax → group_limited_greedy | Softmax, group selection (max per group), topk, scale only (no renorm) | No | Yes |
| softmax → topk via gate.wg | Softmax, gate weight at `gate.wg.weight` (not `gate.weight`), always renormalize | No | Yes |
| fused topk → softmax | Routing + expert computation fused in a single kernel | No | Planned |
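The most common row above, softmax → topk with optional renormalization, can be sketched as a plain PyTorch reference (illustration only; the kernels implement this on-device):

```python
import torch
import torch.nn.functional as F

def softmax_topk(router_logits: torch.Tensor, top_k: int, renormalize: bool = True):
    """softmax over experts -> select top-K per token -> optionally renormalize."""
    probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]
    top_vals, top_idx = torch.topk(probs, top_k, dim=-1)           # [T, K]
    if renormalize:
        top_vals = top_vals / top_vals.sum(dim=-1, keepdim=True)
    return top_vals, top_idx

weights, experts = softmax_topk(torch.randn(3, 8), top_k=2)  # 3 tokens, 8 experts
assert torch.allclose(weights.sum(-1), torch.ones(3))
```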

### Per-model support

| Model Type | Architecture | Routing | ScatterMoE | SonicMoE |
|---|---|---|:---:|:---:|
| `qwen2_moe` | Qwen2-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_moe` | Qwen3-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_5_moe` | Qwen3.5-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_5_moe_text` | Qwen3.5-MoE (VLM text) | softmax → topk | **Yes** | **Yes** |
| `qwen3_next` | Qwen3-Next | softmax → topk | **Yes** | **Yes** |
| `qwen3_vl_moe` | Qwen3-VL-MoE | softmax → topk | **Yes** | **Yes** |
| `qwen3_omni_moe` | Qwen3-Omni (Thinker + Talker) | softmax → topk | **Yes** | **Yes** |
| `olmoe` | OLMoE | softmax → topk | **Yes** | **Yes** |
| `mixtral` | Mixtral | softmax → topk | **Yes** | **Yes** |
| `minimax` | MiniMax | softmax → topk | **Yes** | **Yes** |
| `mistral4` | Mistral 4 | softmax → group → topk | No | **Yes** |
| `glm_moe_dsa` | GLM-MoE DSA (GLM 5) | sigmoid → topk (groups) | **Yes** | **Yes** |
| `deepseek_v3` | DeepSeek-V3 | sigmoid → topk (groups) | **Yes** | **Yes** |
| `glm4_moe` | GLM4-MoE | sigmoid → topk (groups) | **Yes** | **Yes** |
| `glm4_moe_lite` | GLM4-MoE Lite (GLM 4.7 Flash) | sigmoid → topk (groups) | **Yes**\* | **Yes** |
| `glm4v_moe` | GLM4v-MoE | sigmoid → topk (groups) | **Yes** | **Yes** |
| `minimax_m2` | MiniMax M2 | sigmoid → topk (no groups) | **Yes** | **Yes** |
| `ernie4_5_moe` | ERNIE 4.5 MoE | softmax → bias → topk | No | **Yes** |
| `deepseek_v2` | DeepSeek-V2 | softmax → group_limited_greedy | No | **Yes** |
| `hunyuan_v1_moe` | HunYuan V1 MoE | softmax → topk (gate.wg) | No | **Yes** |
| `gpt_oss` | GPT-OSS | fused topk → softmax | No | Planned |

\* `glm4_moe_lite` with ScatterMoE may have issues — see Limitations.

### Feature comparison

| Feature | ScatterMoE | SonicMoE |
|---|:---:|:---:|
| Kernel backend | Triton | CUTLASS |
| GPU requirement | Any CUDA | Hopper (H100/H200) or Blackwell (B200+) |
| LoRA approach | Fused in Triton kernel | Runtime materialization + custom autograd |
| LoRA overhead | Lower (fused computation) | Higher (per-forward materialization) |
| Gate/router LoRA | Yes | Yes |
| Expert LoRA | Yes (fused) | Yes (materialized) |
| Shared expert LoRA | Yes (standard PEFT) | Yes (standard PEFT) |
| Selective expert dequantization | Yes (~97% memory savings) | No |
| Weight format | Transposed `[E, hidden, 2*inter]` | Interleaved gate/up `[2*I, H, E]` |
| torch.compile routing | No | Yes (optional) |
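The weight-format row can be made concrete with a small shape check. The HF-style `[E, 2*I, H]` / `[E, H, I]` source layouts below are taken from the code comments later in this PR; the snippet itself is just an illustration:

```python
import torch

E, I, H = 4, 8, 16
gate_up = torch.randn(E, 2 * I, H)  # HF-style per-expert fused gate/up
down = torch.randn(E, H, I)

gate_up_sonic = gate_up.permute(1, 2, 0)  # -> [2*I, H, E] (expert-last)
down_sonic = down.permute(1, 2, 0)        # -> [H, I, E]

assert gate_up_sonic.shape == (2 * I, H, E)
assert down_sonic.shape == (H, I, E)
# permute is a view: element (e, i, h) maps to (i, h, e)
assert torch.equal(gate_up_sonic[0, 0], gate_up[:, 0, 0])
```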

## Shared Expert Handling

Both kernels handle shared experts identically. Shared expert attribute names are detected in order of priority:

1. `shared_expert` (Qwen2-MoE)
2. `shared_experts` (GLM-MoE, DeepSeek-V3)
3. `shared_mlp` (HunYuan V1 MoE)

||||
If `shared_expert_gate` exists, sigmoid gating is applied to the shared expert contribution before adding it to the routed output. PEFT wraps shared expert linear layers with standard LoRA — no special handling is needed.
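The gating described above amounts to the following sketch (a hypothetical helper, not the actual patched forward):

```python
import torch

def add_shared_expert(routed_out, shared_out, shared_gate_logits=None):
    """Add the shared expert output, sigmoid-gated when a shared_expert_gate exists."""
    if shared_gate_logits is not None:
        shared_out = torch.sigmoid(shared_gate_logits) * shared_out
    return routed_out + shared_out

routed = torch.ones(2, 4)
shared = torch.full((2, 4), 2.0)
# Without a gate: plain residual add
assert torch.equal(add_shared_expert(routed, shared), torch.full((2, 4), 3.0))
# Large positive gate logits -> sigmoid ~= 1 -> nearly the same add
gated = add_shared_expert(routed, shared, torch.full((2, 1), 50.0))
assert torch.allclose(gated, torch.full((2, 4), 3.0), atol=1e-4)
```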

## Limitations

- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`).
- **Non-SwiGLU activations**: Neither kernel supports MoE architectures with non-SwiGLU expert activations (e.g., GPT-OSS uses a custom GLU variant).
- **GPT-OSS**: Deferred — requires transposed weight layout `[E, H, 2*I]`, expert biases, and custom GLU activation. A dedicated forward path is needed.
- **FSDP + fused gate LoRA (SonicMoE)**: The fused topk→softmax path materializes a local tensor when a LoRA delta is present, to avoid mixing DTensor and Tensor under FSDP.

## Note on MegaBlocks
@@ -45,6 +45,28 @@ class KernelsArgs(BaseModel):
        return data

    @model_validator(mode="before")
    @classmethod
    def warn_sonicmoe_lora_overhead(cls, data):
        if data.get("use_sonicmoe") is True and data.get("adapter") in (
            "lora",
            "qlora",
        ):
            lora_target = data.get("lora_target_modules") or []
            lora_linear = data.get("lora_target_linear_modules") or []
            targets = (
                lora_target if isinstance(lora_target, list) else [lora_target]
            ) + (lora_linear if isinstance(lora_linear, list) else [lora_linear])
            expert_keywords = ("gate_up_proj", "down_proj", "experts")
            if any(kw in t for t in targets for kw in expert_keywords):
                LOG.info(
                    "SonicMoE + LoRA on expert modules uses runtime weight materialization "
                    "(W_eff = W + scaling*B@A per forward). This has slightly higher overhead "
                    "than ScatterMoE's fused Triton LoRA kernels but works with any CUTLASS kernel."
                )
        return data

    @model_validator(mode="before")
    @classmethod
    def disable_mlp_kernel(cls, data):
@@ -49,6 +49,11 @@ class ParallelLinear(torch.autograd.Function):
        grouped_in: bool = False,
        grouped_out: bool = False,
    ):
        # Cast weights to match input dtype (e.g. 8-bit LoRA)
        if expert_weights.dtype != x.dtype:
            expert_weights = expert_weights.to(x.dtype)
        if expert_biases is not None and expert_biases.dtype != x.dtype:
            expert_biases = expert_biases.to(x.dtype)
        with torch.device(x.device):
            output = kernels.ops.scatter2scatter(
                X=x,
@@ -65,6 +65,11 @@ class ScatterMoELoRA(torch.autograd.Function):
        use_fused_dX: bool = False,
        use_fused_gather: bool = False,
    ):
        # Cast weights to match input dtype (e.g. 8-bit LoRA)
        if expert_weights.dtype != x.dtype:
            expert_weights = expert_weights.to(x.dtype)
        if expert_biases is not None and expert_biases.dtype != x.dtype:
            expert_biases = expert_biases.to(x.dtype)
        with torch.device(x.device):
            # Fused forward: Y = X @ W + scaling * (X @ A^T) @ B^T
            output = scatter2scatter_lora(
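The `Fused forward` comment rests on the identity `(X @ A^T) @ B^T = X @ (B @ A)^T`, i.e. the fused path and a materialized `W_eff` compute the same output. A quick standalone numeric check with toy shapes:

```python
import torch

T, H, O, r = 5, 16, 8, 4
x = torch.randn(T, H)
W = torch.randn(H, O)   # toy base weight, applied as x @ W
A = torch.randn(r, H)   # LoRA A: [rank, in]
B = torch.randn(O, r)   # LoRA B: [out, rank]
scaling = 0.5

fused = x @ W + scaling * (x @ A.T) @ B.T      # fused-kernel form
materialized = x @ (W + scaling * (B @ A).T)   # materialized W_eff form
assert torch.allclose(fused, materialized, atol=1e-4)
```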
src/axolotl/integrations/kernels/libs/sonicmoe/lora.py (new file, 220 lines)
@@ -0,0 +1,220 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0

"""
SonicMoE LoRA support via runtime weight materialization.

SonicMoE uses opaque CUTLASS kernels that cannot be modified to fuse LoRA.
Instead, we materialize the effective weight W_eff = W + scaling * (B @ A)
before each CUTLASS call, and use a custom autograd.Function to route
gradients back to the LoRA A and B parameters.

PEFT unwrapping utilities are also provided to handle the ParamWrapper
chain that PEFT creates when targeting expert parameters.
"""

from typing import Optional

import torch

# =============================================================================
# PEFT unwrapping utilities
# =============================================================================


def has_lora(module) -> bool:
    """Check if a module is wrapped by PEFT with LoRA."""
    return hasattr(module, "base_layer") and hasattr(module, "lora_A")


def get_lora_params_from_wrapper(module) -> tuple:
    """Extract LoRA parameters from a PEFT ParamWrapper.

    Returns:
        (lora_A, lora_B, scaling) if LoRA is active, else (None, None, None)
    """
    if not hasattr(module, "lora_A") or not hasattr(module, "lora_B"):
        return None, None, None

    active_adapters = getattr(module, "active_adapters", ["default"])
    if not active_adapters:
        return None, None, None

    adapter_name = active_adapters[0]

    lora_A_dict = getattr(module, "lora_A", {})
    lora_B_dict = getattr(module, "lora_B", {})
    scaling_dict = getattr(module, "scaling", {})

    if (
        adapter_name not in lora_A_dict
        or adapter_name not in lora_B_dict
        or adapter_name not in scaling_dict
    ):
        return None, None, None

    lora_A = lora_A_dict[adapter_name].weight
    lora_B = lora_B_dict[adapter_name].weight
    scaling = scaling_dict[adapter_name]

    return lora_A, lora_B, scaling


def unwrap_gate_lora(gate_module):
    """Unwrap PEFT ParamWrapper on the router gate.

    When PEFT targets ``gate.weight``, ``self.gate`` becomes::

        ParamWrapper(weight)
          -> base_layer: Router (the real module)

    Returns:
        (base_gate, gate_weight, gate_lora_delta_or_None)

    ``base_gate`` is the original router module (with ``.top_k``, etc.).
    ``gate_weight`` is the base router weight tensor.
    ``gate_lora_delta_or_None`` is the LoRA delta if active, else None.
    Kept separate to avoid mixing DTensor + Tensor under FSDP.
    """
    if has_lora(gate_module):
        base_gate = gate_module.base_layer
        lora_A, lora_B, scaling = get_lora_params_from_wrapper(gate_module)
        if lora_A is not None:
            delta = scaling * (lora_B @ lora_A)
            return base_gate, base_gate.weight, delta
        return base_gate, base_gate.weight, None

    return gate_module, gate_module.weight, None


def unwrap_experts_lora(experts_module):
    """Walk a PEFT ParamWrapper chain on ``self.experts``.

    When PEFT targets ``experts.gate_up_proj`` and ``experts.down_proj``
    via ``target_parameters``, ``self.experts`` becomes::

        ParamWrapper(down_proj)
          -> base_layer: ParamWrapper(gate_up_proj)
            -> base_layer: Experts (the real module)

    Returns:
        (base_experts, lora_dict)

    ``lora_dict`` maps parameter names to ``(lora_A, lora_B, scaling)``
    tuples, or is empty if no LoRA is active.
    """
    wrappers = {}
    module = experts_module
    while hasattr(module, "base_layer") and hasattr(module, "lora_A"):
        param_name = getattr(module, "parameter_name", None)
        if param_name is not None:
            wrappers[param_name] = module
        module = module.base_layer

    base_experts = module
    lora_dict = {}

    for param_name, wrapper in wrappers.items():
        lora_A, lora_B, scaling = get_lora_params_from_wrapper(wrapper)
        if lora_A is not None:
            lora_dict[param_name] = (lora_A, lora_B, scaling)

    return base_experts, lora_dict


# =============================================================================
# LoRA weight materialization autograd function
# =============================================================================


class MoELoRAMaterialize(torch.autograd.Function):
    """Materialize effective weight W_eff = W + scaling * (B @ A) per expert.

    Inserts into the autograd graph between PEFT's LoRA parameters and
    SonicMoE's CUTLASS kernels. The CUTLASS backward computes dW_eff,
    which this function decomposes into dA and dB via the chain rule.

    Weight layouts (PEFT rank-major):
        base_weight: [E, dim1, dim2] (frozen expert parameter)
        lora_A: [r*E, dim2] (rows [e*r:(e+1)*r] = A_e)
        lora_B: [dim1, r*E] (strided cols [:, e::E] = B_e, per the reshape below)

    Per-expert: delta_e = B_e @ A_e = [dim1, r] @ [r, dim2] = [dim1, dim2]
    """

    @staticmethod
    def forward(
        ctx,
        base_weight: torch.Tensor,
        lora_A: torch.Tensor,
        lora_B: torch.Tensor,
        scaling: float,
    ) -> torch.Tensor:
        E, dim1, dim2 = base_weight.shape
        r = lora_A.shape[0] // E
        assert lora_A.shape[0] == r * E, (
            f"lora_A rows ({lora_A.shape[0]}) must be divisible by num_experts ({E})"
        )

        # Reshape PEFT rank-major to per-expert batched format
        A_3d = lora_A.reshape(E, r, dim2)
        B_3d = lora_B.reshape(dim1, r, E).permute(2, 0, 1).contiguous()  # [E, dim1, r]

        # Batched matmul: [E, dim1, r] @ [E, r, dim2] = [E, dim1, dim2]
        delta = torch.bmm(B_3d, A_3d)

        W_eff = base_weight + scaling * delta

        ctx.save_for_backward(lora_A, lora_B)
        ctx.scaling = scaling
        ctx.E = E
        ctx.r = r

        return W_eff

    @staticmethod
    def backward(ctx, grad_W_eff: torch.Tensor):
        lora_A, lora_B = ctx.saved_tensors
        scaling = ctx.scaling
        E = ctx.E
        r = ctx.r

        _, dim1, dim2 = grad_W_eff.shape

        # Reshape to per-expert (same as forward)
        A_3d = lora_A.reshape(E, r, dim2)
        B_3d = lora_B.reshape(dim1, r, E).permute(2, 0, 1).contiguous()  # [E, dim1, r]

        # dA_e = scaling * B_e^T @ dW_e
        # [E, r, dim1] @ [E, dim1, dim2] = [E, r, dim2]
        d_A_3d = scaling * torch.bmm(B_3d.transpose(1, 2), grad_W_eff)

        # dB_e = scaling * dW_e @ A_e^T
        # [E, dim1, dim2] @ [E, dim2, r] = [E, dim1, r]
        d_B_3d = scaling * torch.bmm(grad_W_eff, A_3d.transpose(1, 2))

        # Reshape back to PEFT rank-major layout
        d_lora_A = d_A_3d.reshape(E * r, dim2)
        d_lora_B = d_B_3d.permute(1, 2, 0).contiguous().reshape(dim1, E * r)

        return None, d_lora_A, d_lora_B, None


def materialize_expert_lora(
    base_weight: torch.Tensor,
    lora_params: Optional[tuple],
) -> torch.Tensor:
    """Materialize effective expert weight with optional LoRA delta.

    Args:
        base_weight: [E, dim1, dim2] frozen expert parameter
        lora_params: (lora_A, lora_B, scaling) or None

    Returns:
        W_eff if lora_params is not None, else base_weight unchanged.
    """
    if lora_params is None:
        return base_weight
    lora_A, lora_B, scaling = lora_params
    return MoELoRAMaterialize.apply(base_weight, lora_A, lora_B, scaling)
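The chain-rule decomposition in `MoELoRAMaterialize.backward` can be checked against autograd directly. The following standalone sketch re-implements the forward reshape on toy shapes and verifies that autograd's gradients for `lora_A`/`lora_B` match the closed-form `dA`/`dB` formulas used above:

```python
import torch

E, d1, d2, r = 2, 4, 6, 3
W = torch.randn(E, d1, d2)
lora_A = torch.randn(E * r, d2, requires_grad=True)
lora_B = torch.randn(d1, E * r, requires_grad=True)
scaling = 0.25

# Forward, mirroring MoELoRAMaterialize.forward
A_3d = lora_A.reshape(E, r, d2)
B_3d = lora_B.reshape(d1, r, E).permute(2, 0, 1)  # [E, d1, r]
W_eff = W + scaling * torch.bmm(B_3d, A_3d)       # [E, d1, d2]

grad = torch.randn_like(W_eff)
W_eff.backward(grad)

# Closed-form chain rule, mirroring MoELoRAMaterialize.backward
dA = (scaling * torch.bmm(B_3d.transpose(1, 2), grad)).reshape(E * r, d2)
dB = (
    (scaling * torch.bmm(grad, A_3d.transpose(1, 2)))
    .permute(1, 2, 0)
    .reshape(d1, E * r)
)
assert torch.allclose(lora_A.grad, dA, atol=1e-5)
assert torch.allclose(lora_B.grad, dB, atol=1e-5)
```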
@@ -28,20 +28,72 @@ import torch.nn.functional as F
from axolotl.integrations.kernels.constants import resolve_moe_block_classes
from axolotl.utils.logging import get_logger

from .lora import (
    has_lora,
    materialize_expert_lora,
    unwrap_experts_lora,
    unwrap_gate_lora,
)

LOG = get_logger(__name__)


def _get_expert_weights(experts_module):
    """Extract expert weights, applying LoRA materialization if PEFT is active.

    Returns:
        (gate_up_weight, down_weight) in SonicMoE layout [dim, dim, E].
    """
    if has_lora(experts_module):
        base_experts, lora_dict = unwrap_experts_lora(experts_module)
        gate_up = materialize_expert_lora(
            base_experts.gate_up_proj, lora_dict.get("gate_up_proj")
        )
        down = materialize_expert_lora(
            base_experts.down_proj, lora_dict.get("down_proj")
        )
    else:
        gate_up = experts_module.gate_up_proj
        down = experts_module.down_proj

    # Permute to SonicMoE layout:
    # gate_up: [E, 2*I, H] -> [2*I, H, E]
    # down: [E, H, I] -> [H, I, E]
    return gate_up.permute(1, 2, 0), down.permute(1, 2, 0)


def _fix_qwen3_5_moe_text_weight_renaming(model_type: str, base_model_type: str):
    """Strip qwen3_5_moe_text WeightRenaming in VLM mode to preserve custom loaders."""
    if model_type != "qwen3_5_moe_text" or base_model_type == "qwen3_5_moe_text":
        return

    try:
        from transformers.conversion_mapping import (
            get_checkpoint_conversion_mapping,
            register_checkpoint_conversion_mapping,
        )
        from transformers.core_model_loading import WeightRenaming
    except ImportError:
        return

    text_mapping = get_checkpoint_conversion_mapping(model_type)
    if text_mapping and isinstance(text_mapping[0], WeightRenaming):
        text_mapping.pop(0)
        register_checkpoint_conversion_mapping(model_type, text_mapping, overwrite=True)
        LOG.info("Stripped qwen3_5_moe_text WeightRenaming for VLM mode")


def patch_sonicmoe(
    model_type: str,
    torch_compile: bool = False,
    base_model_type: str | None = None,
):
    """Patch SparseMoeBlock for SonicMoE support.

    Args:
        model_type: The HuggingFace model type (e.g. "qwen3_moe").
        torch_compile: If True, wrap routing functions with torch.compile
            for kernel fusion (fuses softmax+topk+renorm into fewer launches).
    """
    from .routing import get_model_moe_config
    from .weight_converter import register_sonicmoe_weight_converter

    _fix_qwen3_5_moe_text_weight_renaming(model_type, base_model_type or model_type)

    routing_fn, activation, router_attr = get_model_moe_config(model_type)

    if torch_compile and routing_fn is not None:
@@ -113,11 +165,10 @@ def _make_general_forward(moe_cls, routing_fn, activation):
            hidden_states_flat, self
        )

        # Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout
        gate_up_weight, down_weight = _get_expert_weights(self.experts)
        gate_up_weight = gate_up_weight.to(hidden_states_flat.dtype)
        down_weight = down_weight.to(hidden_states_flat.dtype)
        E = gate_up_weight.shape[-1]

        output, _ = moe_general_routing_inputs(
@@ -161,22 +212,30 @@ def _make_fused_forward(moe_cls, activation, router_attr):
        # Shared expert (computed early, matching original model ordering)
        shared_expert_output = _compute_shared_expert(self, hidden_states_flat)

        # Unwrap router for attribute access + optional LoRA delta
        raw_router = getattr(self, router_attr)
        base_router, router_weight, router_lora_delta = unwrap_gate_lora(raw_router)
        if router_lora_delta is not None:
            # Materialize local tensor to avoid DTensor + Tensor add under FSDP
            if hasattr(router_weight, "to_local"):
                router_weight = router_weight.to_local()
            effective_router_weight = router_weight + router_lora_delta
        else:
            effective_router_weight = router_weight

        # Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout
        gate_up_weight, down_weight = _get_expert_weights(self.experts)
        gate_up_weight = gate_up_weight.to(hidden_states_flat.dtype)
        down_weight = down_weight.to(hidden_states_flat.dtype)

        output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer(
            hidden_states_flat,
            effective_router_weight,
            gate_up_weight,
            None,  # b1 (no gate/up bias)
            down_weight,
            None,  # b2 (no down bias)
            base_router.top_k,
            torch.cuda.current_stream().cuda_stream,
            activation,
            False,  # is_inference_mode
@@ -16,6 +16,8 @@ When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
import torch
import torch.nn.functional as F

from .lora import unwrap_gate_lora


def get_model_moe_config(model_type: str):
    """Returns (routing_fn, activation, router_attr) for a given model type.
@@ -40,6 +42,7 @@ def get_model_moe_config(model_type: str):
        "qwen2_moe",
        "qwen3_moe",
        "qwen3_5_moe",
        "qwen3_5_moe_text",
        "qwen3_next",
        "qwen3_vl_moe",
        "qwen3_omni_moe",
@@ -88,12 +91,18 @@ def softmax_topk_routing(
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
    T, H = hidden_states.shape
    K = base_gate.top_k

    # Compute router logits and softmax over all experts.
    # Two F.linear calls avoid mixing DTensor (gate_weight) + Tensor (delta) under FSDP.
    # Cast to float32 to match LoRA delta dtype (PEFT computes in fp32).
    router_logits = F.linear(hidden_states.float(), gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    # Select top-k experts per token
@@ -101,7 +110,7 @@ def softmax_topk_routing(

    # Renormalize if configured (default True for models without the attribute,
    # e.g. Mixtral/MiniMax which always normalize)
    if getattr(base_gate, "norm_topk_prob", True):
        top_values = top_values / top_values.sum(dim=-1, keepdim=True)

    # no-op: matches transformers which casts to softmax output dtype (float32).
@@ -128,13 +137,17 @@ def softmax_group_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
    base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
    T, _ = hidden_states.shape
    K = moe_block.top_k
    E = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])
    n_group = getattr(moe_block, "n_group", 1)

    router_logits = F.linear(hidden_states.float(), gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    scores_for_choice = router_probs
@@ -206,25 +219,29 @@ def sigmoid_topk_routing(
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
    T, _ = hidden_states.shape
    K = moe_block.top_k
    E = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])
    n_group = getattr(moe_block, "n_group", 1)

    # Compute router logits and sigmoid probabilities
    router_logits = F.linear(hidden_states.float(), gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = router_logits.sigmoid()  # [T, E]

    # Bias-corrected scores for expert selection (not used for final weights).
    # glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 stores it on the block.
    e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None)
    if e_score_correction_bias is None:
        e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
    if e_score_correction_bias is None:
        raise AttributeError(
            f"sigmoid_topk_routing requires e_score_correction_bias on "
            f"gate ({type(base_gate)}) or moe_block ({type(moe_block)}), but neither has it"
        )
    scores_for_choice = router_probs + e_score_correction_bias
@@ -296,16 +313,20 @@ def softmax_bias_topk_routing(
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
    T, H = hidden_states.shape
    K = base_gate.top_k

    # Compute router logits and softmax (force float32 for numerical stability)
    router_logits = F.linear(hidden_states.float(), gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    # Bias-corrected scores for expert selection (via moe_statics module)
    scores_for_choice = base_gate.moe_statics(router_probs)  # [T, E]

    # Select top-k experts using biased scores
    _, selected_experts = torch.topk(scores_for_choice, K, dim=-1)  # [T, K]
@@ -314,7 +335,7 @@ def softmax_bias_topk_routing(
    top_values = torch.gather(router_probs, dim=-1, index=selected_experts)  # [T, K]

    # Renormalize with clamp(min=norm_min) instead of sum+epsilon
    norm_min = getattr(base_gate, "norm_min", 1e-20)
    top_values = top_values / torch.clamp(
        top_values.sum(dim=-1, keepdim=True), min=norm_min
    )
@@ -358,15 +379,19 @@ def softmax_group_limited_topk_routing(
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate)
    T, H = hidden_states.shape
    K = moe_block.top_k
    num_group = getattr(moe_block, "num_group", 1)
    num_experts = gate_weight.shape[0]
    topk_method = getattr(moe_block, "topk_method", "greedy")

    # Compute logits in float32 and softmax
    router_logits = F.linear(hidden_states.float(), gate_weight.float())  # [T, E]
    if gate_lora_delta is not None:
        router_logits = router_logits + F.linear(
            hidden_states.float(), gate_lora_delta.float()
        )
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    if topk_method == "greedy" or num_group == 1:
@@ -445,12 +470,17 @@ def softmax_topk_wg_routing(
         router_logits: [T, E] original logits for aux loss
     """
     gate = moe_block.gate
-    T, _ = hidden_states.shape
+    T, H = hidden_states.shape
     K = moe_block.top_k

     # Gate computes logits via gate.wg (nn.Linear, float32)
-    wg = gate.wg
-    router_logits = F.linear(hidden_states.float(), wg.weight.float())  # [T, E]
+    # Unwrap at gate.wg level since PEFT targets the wg Linear, not the gate container
+    base_wg, wg_weight, wg_lora_delta = unwrap_gate_lora(gate.wg)
+    router_logits = F.linear(hidden_states.float(), wg_weight.float())  # [T, E]
+    if wg_lora_delta is not None:
+        router_logits = router_logits + F.linear(
+            hidden_states.float(), wg_lora_delta.float()
+        )
     router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

     # Select top-k experts
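The unwrap-then-add pattern used by both routing hunks computes the base logits from the frozen weight and adds the low-rank LoRA correction on top. A minimal pure-Python sketch of that idea (`lora_logits` and its nested-list arguments are illustrative, not the kernel's API):

```python
def lora_logits(x, w, lora_a=None, lora_b=None, scaling=1.0):
    """Router logits with an optional LoRA correction.

    Mirrors the unwrap pattern in the diff: compute x @ W^T from the frozen
    base weight, then add x @ (scaling * B @ A)^T when an adapter is present.
    x: [T][H], w: [E][H], lora_a: [r][H], lora_b: [E][r], all nested lists.
    """
    def matmul(a, b):  # a: [m][k], b: [k][n] -> [m][n]
        return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
                 for j in range(len(b[0]))] for i in range(len(a))]

    def transpose(m):
        return [list(row) for row in zip(*m)]

    logits = matmul(x, transpose(w))            # base path: x @ W^T
    if lora_a is not None and lora_b is not None:
        delta = matmul(lora_b, lora_a)          # low-rank update B @ A, [E][H]
        delta = [[scaling * v for v in row] for row in delta]
        extra = matmul(x, transpose(delta))     # x @ delta^T
        logits = [[p + q for p, q in zip(r1, r2)]
                  for r1, r2 in zip(logits, extra)]
    return logits
```

The key property is that the base weight is never modified; the adapter contribution is a separate additive term, which is what lets the frozen checkpoint weights stay byte-identical during LoRA training.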
@@ -129,15 +129,41 @@ class InterleavedToConcatenated(ConversionOps):
         return ConcatenatedToInterleaved(self.dim)


+def _make_same_key_interleave_converter():
+    """Create a WeightConverter that interleaves an already-fused gate_up_proj."""
+    from transformers.core_model_loading import WeightConverter
+
+    return WeightConverter(
+        source_patterns="mlp.experts.gate_up_proj",
+        target_patterns="mlp.experts.gate_up_proj",
+        operations=[ConcatenatedToInterleaved(dim=1)],
+    )
+
+
+def _has_same_key_interleave(mapping) -> bool:
+    """Check whether the mapping already has a same-key gate_up_proj interleave converter."""
+    for conv in mapping:
+        if (
+            hasattr(conv, "source_patterns")
+            and conv.source_patterns == ["mlp.experts.gate_up_proj"]
+            and conv.target_patterns == ["mlp.experts.gate_up_proj"]
+            and hasattr(conv, "operations")
+            and any(isinstance(op, ConcatenatedToInterleaved) for op in conv.operations)
+        ):
+            return True
+    return False
+
+
 def register_sonicmoe_weight_converter(model_type: str):
-    """Override the conversion mapping to add interleave step for gate_up_proj.
+    """Register weight converters to interleave gate_up_proj for SonicMoE.

-    Appends a ConcatenatedToInterleaved operation to the existing gate_up_proj
-    converter chain. For example, qwen3_moe's chain becomes:
-    MergeModulelist(dim=0) -> Concatenate(dim=1) -> ConcatenatedToInterleaved(dim=1)
+    Handles two checkpoint formats:
+    1. Separate per-expert weights (e.g. qwen3_moe): appends interleave to the
+       existing merge chain (MergeModulelist -> Concatenate -> Interleave).
+    2. Already-fused gate_up_proj (e.g. qwen3_5_moe_text): adds a same-key
+       converter (gate_up_proj -> gate_up_proj with Interleave).

-    The reverse is auto-generated for saving:
-    InterleavedToConcatenated(dim=1) -> Chunk(dim=1) -> SplitModulelist(dim=0)
+    The loader matches whichever source pattern exists in the checkpoint.
     """
     from transformers.conversion_mapping import (
         get_checkpoint_conversion_mapping,
@@ -145,37 +171,32 @@ def register_sonicmoe_weight_converter(model_type: str):
     )

     existing = get_checkpoint_conversion_mapping(model_type)

     if existing is None:
-        LOG.warning(
-            f"No conversion mapping found for model type '{model_type}'. "
-            "SonicMoE weight interleaving will not be applied during checkpoint loading."
-        )
+        # No mapping at all — create one with just the same-key converter
+        mapping = [_make_same_key_interleave_converter()]
+        register_checkpoint_conversion_mapping(model_type, mapping)
+        LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'")
         return

-    # Find the gate_up_proj converter and append ConcatenatedToInterleaved
-    patched = False
+    # Append interleave to any existing many-to-one merge chain
     for converter in existing:
         if hasattr(converter, "operations") and any(
             "gate_up_proj" in pat for pat in converter.target_patterns
         ):
-            # Guard against double registration (e.g. plugin reloaded)
-            if any(
+            has_separate_sources = any(
+                "gate_proj" in pat or "up_proj" in pat
+                for pat in converter.source_patterns
+            )
+            if has_separate_sources and not any(
                 isinstance(op, ConcatenatedToInterleaved) for op in converter.operations
             ):
-                LOG.info(
-                    f"SonicMoE weight converter already registered for '{model_type}'"
-                )
-                return
-            converter.operations.append(ConcatenatedToInterleaved(dim=1))
-            patched = True
+                converter.operations.append(ConcatenatedToInterleaved(dim=1))
             break

-    if not patched:
-        LOG.warning(
-            f"Could not find gate_up_proj converter for model type '{model_type}'. "
-            "SonicMoE weight interleaving will not be applied during checkpoint loading."
-        )
-        return
+    # Also add a same-key converter for already-fused checkpoints
+    if not _has_same_key_interleave(existing):
+        existing.append(_make_same_key_interleave_converter())

     register_checkpoint_conversion_mapping(model_type, existing, overwrite=True)
     LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'")
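The interleave the converter performs along dim=1 can be illustrated on a single row of a fused gate/up weight (a sketch of the layout transform only; the hypothetical helpers below operate on Python lists, not the real `ConversionOps` tensors):

```python
def concatenated_to_interleaved(row):
    """[g0..g_{I-1}, u0..u_{I-1}] -> [g0, u0, g1, u1, ...] along the last dim."""
    half = len(row) // 2
    out = []
    for g, u in zip(row[:half], row[half:]):
        out.extend([g, u])
    return out


def interleaved_to_concatenated(row):
    """Inverse transform: even positions are gate columns, odd positions are up."""
    return list(row[0::2]) + list(row[1::2])


fused = ["g0", "g1", "g2", "u0", "u1", "u2"]
inter = concatenated_to_interleaved(fused)
print(inter)  # ['g0', 'u0', 'g1', 'u1', 'g2', 'u2']
assert interleaved_to_concatenated(inter) == fused  # lossless round trip
```

The round trip being lossless is what makes it safe to auto-generate the reverse converter for saving: loading interleaves, saving de-interleaves, and the checkpoint on disk keeps the standard concatenated layout.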
@@ -84,12 +84,13 @@ class KernelsPlugin(BasePlugin):

         _check_sonicmoe_gpu_compat()

-        from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe
+        from axolotl.integrations.kernels.libs.sonicmoe import patch_sonicmoe

         LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}")
         patch_sonicmoe(
             moe_model_type,
             torch_compile=bool(getattr(cfg, "torch_compile", False)),
+            base_model_type=cfg.model_config_type,
         )

     def _register_kernels(self):
@@ -51,7 +51,7 @@ def _create_tiny_qwen3_config():

 def _interleave_gate_up_weights(model):
     """Interleave all gate_up_proj parameters in-place for SonicMoE."""
-    from axolotl.integrations.kernels.sonicmoe.weight_converter import (
+    from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
         interleave_gate_up,
     )

@@ -80,7 +80,7 @@ class TestSonicMoEForwardCorrectness:
     def test_forward_output_matches(self):
         from transformers import AutoModelForCausalLM

-        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
+        from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe

         config = _create_tiny_qwen3_config()
         input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
@@ -117,8 +117,8 @@ class TestSonicMoEGradientCorrectness:
         """Verify all parameter gradients match between original and patched."""
         from transformers import AutoModelForCausalLM

-        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
-        from axolotl.integrations.kernels.sonicmoe.weight_converter import (
+        from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
+        from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
             deinterleave_gate_up,
         )

@@ -191,7 +191,7 @@ class TestSonicMoEGradientCorrectness:
         """Verify that router (gate) weights get non-zero gradients."""
         from transformers import AutoModelForCausalLM

-        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
+        from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe

         config = _create_tiny_qwen3_config()
         input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
@@ -223,7 +223,7 @@ class TestSonicMoETrainingConvergence:
         """Run 30 training steps, verify loss decreases and no NaN/Inf."""
         from transformers import AutoModelForCausalLM

-        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
+        from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe

         config = _create_tiny_qwen3_config()
         input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
@@ -254,7 +254,7 @@ class TestSonicMoETrainingConvergence:
         """Verify expert weights change during training (not frozen)."""
         from transformers import AutoModelForCausalLM

-        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
+        from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe

         config = _create_tiny_qwen3_config()
         input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
318  tests/e2e/integrations/test_sonicmoe_lora.py  Normal file
@@ -0,0 +1,318 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) Axolotl AI
+# Licensed under the Apache License, Version 2.0
+
+"""
+End-to-end tests for SonicMoE + LoRA integration.
+
+Verifies that PEFT-wrapped MoE models work correctly with SonicMoE's
+runtime LoRA materialization: gradients flow to adapters, base weights
+stay frozen, and loss converges.
+
+Requires:
+- H100/H200 GPU (SonicMoE CUTLASS kernels target sm_90)
+- sonicmoe package installed
+- peft package installed
+- transformers with Qwen3MoE support
+
+Usage:
+    pytest tests/e2e/integrations/test_sonicmoe_lora.py -v -s
+"""
+
+import importlib.util
+import math
+
+import pytest
+import torch
+
+_sonicmoe_available = importlib.util.find_spec("sonicmoe") is not None
+_peft_available = importlib.util.find_spec("peft") is not None
+_is_hopper = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
+
+pytestmark = [
+    pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA GPU"),
+    pytest.mark.skipif(
+        not _is_hopper, reason="SonicMoE CUTLASS kernels require Hopper (sm_90)"
+    ),
+    pytest.mark.skipif(not _sonicmoe_available, reason="SonicMoE not installed"),
+    pytest.mark.skipif(not _peft_available, reason="PEFT not installed"),
+]
+
+
+def _create_tiny_qwen3_config():
+    """Create a minimal Qwen3MoE config for fast testing."""
+    from transformers import AutoConfig
+
+    config = AutoConfig.for_model("qwen3_moe")
+    config.hidden_size = 512
+    config.intermediate_size = 1024
+    config.moe_intermediate_size = 64
+    config.num_attention_heads = 16
+    config.num_key_value_heads = 2
+    config.head_dim = 32
+    config.num_hidden_layers = 2
+    config.num_experts = 8
+    config.num_experts_per_tok = 2
+    config.vocab_size = 1000
+    config.max_position_embeddings = 128
+    config.norm_topk_prob = True
+    config.torch_dtype = torch.bfloat16
+    return config
+
+
+def _interleave_gate_up_weights(model):
+    """Interleave all gate_up_proj parameters in-place for SonicMoE."""
+    from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
+        interleave_gate_up,
+    )
+
+    with torch.no_grad():
+        for name, param in model.named_parameters():
+            if "gate_up_proj" in name:
+                param.copy_(interleave_gate_up(param))
+
+
+def _unpatch_sonicmoe():
+    """Restore original forward on the MoE block class if it was patched."""
+    from axolotl.integrations.kernels.constants import resolve_moe_block_classes
+
+    for moe_cls in resolve_moe_block_classes("qwen3_moe"):
+        if hasattr(moe_cls, "_original_forward"):
+            moe_cls.forward = moe_cls._original_forward
+            del moe_cls._original_forward
+
+
+def _apply_lora(model, target_modules):
+    """Apply PEFT LoRA to the model."""
+    from peft import LoraConfig, get_peft_model
+
+    lora_config = LoraConfig(
+        r=8,
+        lora_alpha=16,
+        target_modules=target_modules,
+        lora_dropout=0.0,
+        bias="none",
+    )
+    return get_peft_model(model, lora_config)
+
+
+class TestSonicMoELoRATraining:
+    """Verify SonicMoE + LoRA training works end-to-end."""
+
+    def teardown_method(self):
+        _unpatch_sonicmoe()
+
+    def test_loss_decreases(self):
+        """Run 30 training steps with LoRA on experts, verify loss decreases."""
+        from transformers import AutoModelForCausalLM
+
+        from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
+
+        config = _create_tiny_qwen3_config()
+        input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
+
+        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
+        patch_sonicmoe("qwen3_moe")
+        _interleave_gate_up_weights(model)
+        model = _apply_lora(model, ["gate_up_proj", "down_proj"])
+
+        optimizer = torch.optim.AdamW(
+            [p for p in model.parameters() if p.requires_grad], lr=1e-3
+        )
+        losses = []
+
+        for step in range(30):
+            out = model(input_ids, labels=input_ids)
+            loss = out.loss
+            assert not math.isnan(loss.item()), f"NaN loss at step {step}"
+            assert not math.isinf(loss.item()), f"Inf loss at step {step}"
+            losses.append(loss.item())
+
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+        assert losses[-1] < losses[0], (
+            f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}"
+        )
+
+    def test_base_weights_frozen(self):
+        """Verify base (non-LoRA) weights don't change during training."""
+        from transformers import AutoModelForCausalLM
+
+        from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
+
+        config = _create_tiny_qwen3_config()
+        input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
+
+        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
+        patch_sonicmoe("qwen3_moe")
+        _interleave_gate_up_weights(model)
+        model = _apply_lora(model, ["gate_up_proj", "down_proj"])
+
+        # Snapshot frozen weights
+        frozen_before = {}
+        for name, param in model.named_parameters():
+            if not param.requires_grad:
+                frozen_before[name] = param.data.clone()
+
+        optimizer = torch.optim.AdamW(
+            [p for p in model.parameters() if p.requires_grad], lr=1e-3
+        )
+        for _ in range(5):
+            out = model(input_ids, labels=input_ids)
+            out.loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+        for name, param in model.named_parameters():
+            if name in frozen_before:
+                assert torch.equal(param.data, frozen_before[name]), (
+                    f"Frozen weight changed: {name}"
+                )
+
+    def test_lora_adapters_receive_gradients(self):
+        """Verify LoRA A and B matrices get non-zero gradients."""
+        from transformers import AutoModelForCausalLM
+
+        from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
+
+        config = _create_tiny_qwen3_config()
+        input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
+
+        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
+        patch_sonicmoe("qwen3_moe")
+        _interleave_gate_up_weights(model)
+        model = _apply_lora(model, ["gate_up_proj", "down_proj"])
+
+        out = model(input_ids, labels=input_ids)
+        out.loss.backward()
+
+        lora_grads_found = 0
+        for name, param in model.named_parameters():
+            if "lora_" in name and param.requires_grad:
+                assert param.grad is not None, f"No gradient for LoRA param: {name}"
+                assert param.grad.abs().max() > 0, (
+                    f"Zero gradient for LoRA param: {name}"
+                )
+                lora_grads_found += 1
+
+        assert lora_grads_found > 0, "No LoRA parameters found with gradients"
+
+    def test_lora_adapters_update(self):
+        """Verify LoRA adapter weights change during training."""
+        from transformers import AutoModelForCausalLM
+
+        from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
+
+        config = _create_tiny_qwen3_config()
+        input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
+
+        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
+        patch_sonicmoe("qwen3_moe")
+        _interleave_gate_up_weights(model)
+        model = _apply_lora(model, ["gate_up_proj", "down_proj"])
+
+        # Snapshot LoRA weights
+        lora_before = {}
+        for name, param in model.named_parameters():
+            if "lora_" in name and param.requires_grad:
+                lora_before[name] = param.data.clone()
+
+        assert lora_before, "No LoRA parameters found"
+
+        optimizer = torch.optim.AdamW(
+            [p for p in model.parameters() if p.requires_grad], lr=1e-3
+        )
+        for _ in range(5):
+            out = model(input_ids, labels=input_ids)
+            out.loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+        changed = sum(
+            1
+            for name, param in model.named_parameters()
+            if name in lora_before and not torch.equal(param.data, lora_before[name])
+        )
+        assert changed > 0, "No LoRA weights changed after 5 training steps"
+
+
+class TestSonicMoEGateOnlyLoRA:
+    """Verify LoRA targeting only the gate (router) works with SonicMoE."""
+
+    def teardown_method(self):
+        _unpatch_sonicmoe()
+
+    def test_gate_only_lora_loss_decreases(self):
+        """LoRA only on gate — expert path should have zero materialization overhead."""
+        from transformers import AutoModelForCausalLM
+
+        from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
+
+        config = _create_tiny_qwen3_config()
+        input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
+
+        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
+        patch_sonicmoe("qwen3_moe")
+        _interleave_gate_up_weights(model)
+        # Only target the gate (router), not expert projections
+        model = _apply_lora(model, ["gate"])
+
+        optimizer = torch.optim.AdamW(
+            [p for p in model.parameters() if p.requires_grad], lr=1e-3
+        )
+        losses = []
+
+        for step in range(20):
+            out = model(input_ids, labels=input_ids)
+            loss = out.loss
+            assert not math.isnan(loss.item()), f"NaN loss at step {step}"
+            assert not math.isinf(loss.item()), f"Inf loss at step {step}"
+            losses.append(loss.item())
+
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+        assert losses[-1] < losses[0], (
+            f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}"
+        )
+
+
+class TestSonicMoENoLoRARegression:
+    """Verify SonicMoE without LoRA still works after LoRA code was added."""
+
+    def teardown_method(self):
+        _unpatch_sonicmoe()
+
+    def test_no_lora_loss_decreases(self):
+        """Full fine-tuning (no PEFT) with SonicMoE — regression test."""
+        from transformers import AutoModelForCausalLM
+
+        from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
+
+        config = _create_tiny_qwen3_config()
+        input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
+
+        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
+        patch_sonicmoe("qwen3_moe")
+        _interleave_gate_up_weights(model)
+
+        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
+        losses = []
+
+        for step in range(20):
+            out = model(input_ids, labels=input_ids)
+            loss = out.loss
+            assert not math.isnan(loss.item()), f"NaN loss at step {step}"
+            assert not math.isinf(loss.item()), f"Inf loss at step {step}"
+            losses.append(loss.item())
+
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+        assert losses[-1] < losses[0], (
+            f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}"
+        )
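The patch/unpatch lifecycle the test file relies on (`patch_sonicmoe` swaps the block's `forward` and `_unpatch_sonicmoe` restores it via a saved `_original_forward` attribute) can be sketched generically (a minimal sketch with hypothetical `patch`/`unpatch` helpers, not the plugin's code):

```python
class Block:
    """Stand-in for an MoE block class whose forward gets monkey-patched."""

    def forward(self):
        return "original"


def patch(cls, new_forward):
    """Save the original forward once (idempotent), then swap in the replacement."""
    if not hasattr(cls, "_original_forward"):
        cls._original_forward = cls.forward
    cls.forward = new_forward


def unpatch(cls):
    """Restore the saved forward, mirroring the test's _unpatch_sonicmoe helper."""
    if hasattr(cls, "_original_forward"):
        cls.forward = cls._original_forward
        del cls._original_forward


patch(Block, lambda self: "patched")
print(Block().forward())  # patched
unpatch(Block)
print(Block().forward())  # original
```

Guarding the save with `hasattr` is what makes repeated patching safe, and deleting the attribute on restore is what lets the `teardown_method` hook leave the class exactly as it found it between tests.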
@@ -93,7 +93,9 @@ class TestSoftmaxRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_softmax_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
|
||||
|
||||
@@ -120,7 +122,9 @@ class TestSoftmaxRoutingParity:
|
||||
|
||||
def test_logits_not_returned_by_scattermoe(self):
|
||||
"""ScatterMoE doesn't return logits; SonicMoE does — verify SonicMoE logits shape."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
|
||||
_, _, _, logits = softmax_topk_routing(hidden, moe_block)
|
||||
@@ -131,7 +135,9 @@ class TestSoftmaxRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_softmax_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
|
||||
gate.norm_topk_prob = False
|
||||
@@ -152,7 +158,9 @@ class TestSoftmaxRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_softmax_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_routing,
|
||||
)
|
||||
|
||||
for E, K in [(2, 1), (8, 2), (16, 4), (32, 8)]:
|
||||
moe_block, gate, hidden, T, H, _, _ = _make_softmax_block(E=E, K=K)
|
||||
@@ -190,7 +198,9 @@ class TestSigmoidRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
sigmoid_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
|
||||
@@ -226,7 +236,9 @@ class TestSigmoidRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
sigmoid_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
E=16, K=4, n_group=1, topk_group=1, bias_on_gate=True
|
||||
@@ -254,7 +266,9 @@ class TestSigmoidRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
sigmoid_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
E=16, K=4, n_group=1, bias_on_gate=False
|
||||
@@ -281,7 +295,9 @@ class TestSigmoidRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
sigmoid_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
n_group=1, bias_on_gate=True
|
||||
@@ -309,7 +325,9 @@ class TestSigmoidRoutingParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
sigmoid_topk_routing,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
n_group=1, bias_on_gate=True
|
||||
@@ -349,7 +367,7 @@ class TestSharedExpertParity:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_compute_shared_expert as scatter_compute,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.patch import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.patch import (
|
||||
_compute_shared_expert as sonic_compute,
|
||||
)
|
||||
|
||||
|
||||
@@ -6,11 +6,11 @@ import pytest
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.kernels.args import KernelsArgs
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
sigmoid_topk_routing,
|
||||
softmax_topk_routing,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
|
||||
ConcatenatedToInterleaved,
|
||||
InterleavedToConcatenated,
|
||||
register_sonicmoe_weight_converter,
|
||||
@@ -212,9 +212,40 @@ class TestWeightConverterRegistration:
|
||||
)
|
||||
break
|
||||
|
||||
def test_register_unsupported_model_type_warns(self):
|
||||
# A model type with no conversion mapping should warn but not raise
|
||||
register_sonicmoe_weight_converter("nonexistent_model_type_xyz")
|
||||
def test_register_adds_same_key_converter(self):
|
||||
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
|
||||
|
||||
register_sonicmoe_weight_converter("qwen3_moe")
|
||||
|
||||
modified = get_checkpoint_conversion_mapping("qwen3_moe")
|
||||
# Should have a same-key converter for already-fused checkpoints
|
||||
same_key = [
|
||||
c
|
||||
for c in modified
|
||||
if hasattr(c, "source_patterns")
|
||||
and c.source_patterns == ["mlp.experts.gate_up_proj"]
|
||||
and c.target_patterns == ["mlp.experts.gate_up_proj"]
|
||||
]
|
||||
assert len(same_key) == 1
|
||||
assert isinstance(same_key[0].operations[0], ConcatenatedToInterleaved)
|
||||
|
||||
def test_register_creates_mapping_when_none(self):
|
||||
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
|
||||
|
||||
# qwen3_5_moe has no conversion mapping in transformers
|
||||
register_sonicmoe_weight_converter("qwen3_5_moe")
|
||||
|
||||
mapping = get_checkpoint_conversion_mapping("qwen3_5_moe")
|
||||
assert mapping is not None
|
||||
same_key = [
|
||||
c
|
||||
for c in mapping
|
||||
if hasattr(c, "source_patterns")
|
||||
and c.source_patterns == ["mlp.experts.gate_up_proj"]
|
||||
and c.target_patterns == ["mlp.experts.gate_up_proj"]
|
||||
]
|
||||
assert len(same_key) == 1
|
||||
assert isinstance(same_key[0].operations[0], ConcatenatedToInterleaved)
|
||||
|
||||
|
||||
def _make_qwen_moe_block(T=8, H=16, E=4, K=2):
|
||||
@@ -462,7 +493,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
"""Tests for Ernie 4.5 MoE routing (softmax_bias_topk_routing)."""
|
||||
|
||||
def test_output_shapes(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -479,7 +510,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
assert logits.shape == (T, E)
|
||||
|
||||
def test_scores_are_float32(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -490,7 +521,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
assert scores.dtype == torch.float32
|
||||
|
||||
def test_token_indices_sorted_ascending(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -502,7 +533,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
assert (diffs >= 0).all()
|
||||
|
||||
def test_expert_indices_in_range(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -514,7 +545,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
assert (expert_idx < E).all()
|
||||
|
||||
def test_renormalized_scores_sum_to_one(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -527,7 +558,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
|
||||
def test_bias_affects_expert_selection(self):
|
||||
"""Large positive bias on expert 0 should make it always selected."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -570,7 +601,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
"""Tests for DeepSeek V2 routing (softmax_group_limited_topk_routing)."""
|
||||
|
||||
def test_output_shapes_group_limited(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -589,7 +620,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
assert logits.shape == (T, E)
|
||||
|
||||
def test_output_shapes_greedy(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

@@ -604,7 +635,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

@@ -615,7 +646,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

@@ -627,7 +658,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
        assert (diffs >= 0).all()

    def test_expert_indices_in_range(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

@@ -639,7 +670,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
        assert (expert_idx < E).all()

    def test_scaling_factor_applied(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

@@ -655,7 +686,7 @@ class TestSoftmaxGroupLimitedTopkRouting:

    def test_group_selection_restricts_experts(self):
        """With num_group=4 and topk_group=1, experts should come from selected groups."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

@@ -674,7 +705,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
        assert (groups == groups[0]).all()

    def test_unsupported_topk_method_raises(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )
@@ -706,7 +737,7 @@ class TestSoftmaxTopkWgRouting:
    """Tests for HunYuan V1 MoE routing (softmax_topk_wg_routing)."""

    def test_output_shapes(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

@@ -723,7 +754,7 @@ class TestSoftmaxTopkWgRouting:
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

@@ -734,7 +765,7 @@ class TestSoftmaxTopkWgRouting:
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

@@ -746,7 +777,7 @@ class TestSoftmaxTopkWgRouting:
        assert (diffs >= 0).all()

    def test_expert_indices_in_range(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

@@ -759,7 +790,7 @@ class TestSoftmaxTopkWgRouting:

    def test_renormalized_scores_sum_to_one(self):
        """HunYuan V1 always renormalizes (no norm_topk_prob flag)."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

@@ -772,7 +803,7 @@ class TestSoftmaxTopkWgRouting:

    def test_uses_gate_wg_weight(self):
        """Verify that modifying gate.wg.weight changes the routing output."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )
@@ -7,7 +7,7 @@ code path where routing happens in float32.

import torch

from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
    sigmoid_topk_routing,
    softmax_topk_routing,
)
328 tests/integrations/test_sonicmoe_lora.py Normal file
@@ -0,0 +1,328 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0

"""Unit tests for SonicMoE LoRA support."""

from unittest.mock import MagicMock

import pytest
import torch

from axolotl.integrations.kernels.libs.sonicmoe.lora import (
    MoELoRAMaterialize,
    get_lora_params_from_wrapper,
    has_lora,
    materialize_expert_lora,
    unwrap_experts_lora,
    unwrap_gate_lora,
)

# =============================================================================
# Helpers: mock PEFT modules
# =============================================================================

def _make_mock_lora_module(weight_A, weight_B, scaling_val, param_name=None):
    """Create a mock PEFT-wrapped module with LoRA attributes."""
    mock = MagicMock()

    lora_A_linear = MagicMock()
    lora_A_linear.weight = weight_A

    lora_B_linear = MagicMock()
    lora_B_linear.weight = weight_B

    mock.lora_A = {"default": lora_A_linear}
    mock.lora_B = {"default": lora_B_linear}
    mock.scaling = {"default": scaling_val}
    mock.active_adapters = ["default"]

    if param_name is not None:
        mock.parameter_name = param_name

    return mock

def _make_peft_gate(hidden_size, num_experts, rank, scaling=0.5):
    """Create a mock PEFT-wrapped gate module."""
    base_gate = MagicMock()
    base_gate.weight = torch.randn(num_experts, hidden_size)
    base_gate.top_k = 2
    base_gate.norm_topk_prob = True

    lora_A = torch.randn(rank, hidden_size)
    lora_B = torch.randn(num_experts, rank)

    wrapper = _make_mock_lora_module(lora_A, lora_B, scaling)
    wrapper.base_layer = base_gate
    return wrapper, base_gate

def _make_peft_experts(
    num_experts, gate_up_dim, down_dim, hidden_size, rank, scaling=0.5
):
    """Create a mock PEFT-wrapped experts chain.

    Simulates: ParamWrapper(down_proj) -> ParamWrapper(gate_up_proj) -> Experts
    """
    base_experts = MagicMock()
    base_experts.gate_up_proj = torch.randn(num_experts, gate_up_dim, hidden_size)
    base_experts.down_proj = torch.randn(num_experts, hidden_size, down_dim)
    # Remove base_layer and lora_A from base_experts so the chain walk stops
    del base_experts.base_layer
    del base_experts.lora_A

    # gate_up_proj wrapper
    gup_A = torch.randn(rank * num_experts, hidden_size)
    gup_B = torch.randn(gate_up_dim, rank * num_experts)
    gup_wrapper = _make_mock_lora_module(gup_A, gup_B, scaling, "gate_up_proj")
    gup_wrapper.base_layer = base_experts

    # down_proj wrapper (outermost)
    down_A = torch.randn(rank * num_experts, down_dim)
    down_B = torch.randn(hidden_size, rank * num_experts)
    down_wrapper = _make_mock_lora_module(down_A, down_B, scaling, "down_proj")
    down_wrapper.base_layer = gup_wrapper

    return down_wrapper, base_experts, (gup_A, gup_B), (down_A, down_B)

# =============================================================================
# Tests: has_lora
# =============================================================================


class TestHasLora:
    def test_plain_module(self):
        m = MagicMock(spec=["weight"])
        del m.base_layer
        del m.lora_A
        assert not has_lora(m)

    def test_wrapped_module(self):
        m = MagicMock()
        m.base_layer = MagicMock()
        m.lora_A = {"default": MagicMock()}
        assert has_lora(m)


# =============================================================================
# Tests: get_lora_params_from_wrapper
# =============================================================================


class TestGetLoraParams:
    def test_no_lora_attrs(self):
        m = MagicMock(spec=["weight"])
        del m.lora_A
        del m.lora_B
        assert get_lora_params_from_wrapper(m) == (None, None, None)

    def test_extracts_params(self):
        A = torch.randn(4, 8)
        B = torch.randn(16, 4)
        wrapper = _make_mock_lora_module(A, B, 0.5)
        lora_A, lora_B, scaling = get_lora_params_from_wrapper(wrapper)
        assert torch.equal(lora_A, A)
        assert torch.equal(lora_B, B)
        assert scaling == 0.5

    def test_no_active_adapters(self):
        wrapper = _make_mock_lora_module(torch.randn(4, 8), torch.randn(16, 4), 0.5)
        wrapper.active_adapters = []
        assert get_lora_params_from_wrapper(wrapper) == (None, None, None)

# =============================================================================
# Tests: unwrap_gate_lora
# =============================================================================


class TestUnwrapGateLora:
    def test_plain_gate(self):
        gate = MagicMock(spec=["weight", "top_k"])
        del gate.base_layer
        del gate.lora_A
        gate.weight = torch.randn(8, 64)
        base, weight, delta = unwrap_gate_lora(gate)
        assert base is gate
        assert torch.equal(weight, gate.weight)
        assert delta is None

    def test_wrapped_gate(self):
        wrapper, base_gate = _make_peft_gate(
            hidden_size=64, num_experts=8, rank=4, scaling=0.5
        )
        base, weight, delta = unwrap_gate_lora(wrapper)
        assert base is base_gate
        assert torch.equal(weight, base_gate.weight)
        assert delta is not None
        assert delta.shape == base_gate.weight.shape

        # Verify delta = scaling * B @ A
        lora_A = wrapper.lora_A["default"].weight
        lora_B = wrapper.lora_B["default"].weight
        expected = 0.5 * (lora_B @ lora_A)
        assert torch.allclose(delta, expected)

# =============================================================================
# Tests: unwrap_experts_lora
# =============================================================================


class TestUnwrapExpertsLora:
    def test_plain_experts(self):
        experts = MagicMock(spec=["gate_up_proj", "down_proj"])
        del experts.base_layer
        del experts.lora_A
        base, lora_dict = unwrap_experts_lora(experts)
        assert base is experts
        assert lora_dict == {}

    def test_wrapped_experts(self):
        E, I2, I, H, r = 4, 256, 128, 64, 8  # noqa: E741
        wrapper, base_experts, (gup_A, gup_B), (down_A, down_B) = _make_peft_experts(
            E, I2, I, H, r, scaling=0.25
        )
        base, lora_dict = unwrap_experts_lora(wrapper)
        assert base is base_experts
        assert "gate_up_proj" in lora_dict
        assert "down_proj" in lora_dict

        gup_lA, gup_lB, gup_s = lora_dict["gate_up_proj"]
        assert torch.equal(gup_lA, gup_A)
        assert torch.equal(gup_lB, gup_B)
        assert gup_s == 0.25

        down_lA, down_lB, down_s = lora_dict["down_proj"]
        assert torch.equal(down_lA, down_A)
        assert torch.equal(down_lB, down_B)
        assert down_s == 0.25

    def test_partial_lora(self):
        """Only gate_up_proj has LoRA, down_proj does not."""
        base_experts = MagicMock(spec=["gate_up_proj", "down_proj"])
        del base_experts.base_layer
        del base_experts.lora_A

        gup_A = torch.randn(16, 64)
        gup_B = torch.randn(256, 16)
        gup_wrapper = _make_mock_lora_module(gup_A, gup_B, 0.5, "gate_up_proj")
        gup_wrapper.base_layer = base_experts

        base, lora_dict = unwrap_experts_lora(gup_wrapper)
        assert base is base_experts
        assert "gate_up_proj" in lora_dict
        assert "down_proj" not in lora_dict

# =============================================================================
# Tests: MoELoRAMaterialize
# =============================================================================


class TestMoELoRAMaterialize:
    @pytest.fixture()
    def setup(self):
        E, dim1, dim2, r = 4, 32, 16, 4
        scaling = 0.5
        W = torch.randn(E, dim1, dim2, dtype=torch.float64, requires_grad=False)
        A = torch.randn(r * E, dim2, dtype=torch.float64, requires_grad=True)
        B = torch.randn(dim1, r * E, dtype=torch.float64, requires_grad=True)
        return W, A, B, scaling, E, r

    def test_forward_shape(self, setup):
        W, A, B, scaling, E, r = setup
        W_eff = MoELoRAMaterialize.apply(W, A, B, scaling)
        assert W_eff.shape == W.shape

    def test_forward_correctness(self, setup):
        W, A, B, scaling, E, r = setup
        W_eff = MoELoRAMaterialize.apply(W, A, B, scaling)

        # Manual per-expert computation.
        # lora_A is expert-major: [r*E, dim2] -> rows [e*r:(e+1)*r] = expert e
        # lora_B is rank-major: [dim1, r*E] -> reshape [dim1, r, E], slice [:, :, e]
        _, dim1, dim2 = W.shape
        expected = W.clone()
        B_3d = B.reshape(dim1, r, E)
        for e in range(E):
            A_e = A[e * r : (e + 1) * r, :]  # [r, dim2]
            B_e = B_3d[:, :, e]  # [dim1, r]
            expected[e] += scaling * (B_e @ A_e)

        assert torch.allclose(W_eff, expected, atol=1e-10)

    def test_backward_gradcheck(self, setup):
        W, A, B, scaling, E, r = setup
        # gradcheck requires float64
        assert torch.autograd.gradcheck(
            lambda a, b: MoELoRAMaterialize.apply(W, a, b, scaling),
            (A, B),
            eps=1e-6,
            atol=1e-4,
        )

    def test_no_grad_for_base_weight(self, setup):
        W, A, B, scaling, E, r = setup
        W.requires_grad_(True)
        W_eff = MoELoRAMaterialize.apply(W, A, B, scaling)
        loss = W_eff.sum()
        loss.backward()
        assert W.grad is None
        assert A.grad is not None
        assert B.grad is not None

    def test_scaling_zero(self, setup):
        W, A, B, _, E, r = setup
        W_eff = MoELoRAMaterialize.apply(W, A, B, 0.0)
        assert torch.allclose(W_eff, W)

    def test_gate_up_proj_shapes(self):
        """Test with realistic gate_up_proj shapes [E, 2*I, H]."""
        E, I2, H, r = 8, 512, 256, 16
        W = torch.randn(E, I2, H, dtype=torch.float64)
        A = torch.randn(r * E, H, dtype=torch.float64, requires_grad=True)
        B = torch.randn(I2, r * E, dtype=torch.float64, requires_grad=True)
        W_eff = MoELoRAMaterialize.apply(W, A, B, 1.0)
        assert W_eff.shape == (E, I2, H)
        loss = W_eff.sum()
        loss.backward()
        assert A.grad.shape == A.shape
        assert B.grad.shape == B.shape

    def test_down_proj_shapes(self):
        """Test with realistic down_proj shapes [E, H, I]."""
        E, H, I, r = 8, 256, 512, 16  # noqa: E741
        W = torch.randn(E, H, I, dtype=torch.float64)
        A = torch.randn(r * E, I, dtype=torch.float64, requires_grad=True)
        B = torch.randn(H, r * E, dtype=torch.float64, requires_grad=True)
        W_eff = MoELoRAMaterialize.apply(W, A, B, 1.0)
        assert W_eff.shape == (E, H, I)
        loss = W_eff.sum()
        loss.backward()
        assert A.grad.shape == A.shape
        assert B.grad.shape == B.shape

# =============================================================================
# Tests: materialize_expert_lora
# =============================================================================


class TestMaterializeExpertLora:
    def test_none_passthrough(self):
        W = torch.randn(4, 32, 16)
        result = materialize_expert_lora(W, None)
        assert result is W

    def test_with_lora(self):
        E, dim1, dim2, r = 4, 32, 16, 4
        W = torch.randn(E, dim1, dim2)
        A = torch.randn(r * E, dim2, requires_grad=True)
        B = torch.randn(dim1, r * E, requires_grad=True)
        result = materialize_expert_lora(W, (A, B, 0.5))
        assert result.shape == W.shape
        assert not torch.equal(result, W)
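The expert-major/rank-major layout that `test_forward_correctness` describes in its comments (rows `e*r:(e+1)*r` of `lora_A` belong to expert `e`; `lora_B` reshapes to `[dim1, r, E]`) can be sketched standalone. This is a minimal loop-based reference in plain PyTorch, not the fused Triton/CUTLASS path; `materialize_expert_lora_ref` is a hypothetical name used only for illustration:

```python
import torch


def materialize_expert_lora_ref(W, A, B, scaling):
    """Loop-based reference for per-expert LoRA materialization (illustrative).

    W: [E, dim1, dim2] stacked expert weights.
    A: [r*E, dim2], expert-major -- rows e*r:(e+1)*r belong to expert e.
    B: [dim1, r*E], rank-major -- reshape to [dim1, r, E], slice [:, :, e].
    Returns W_eff with W_eff[e] = W[e] + scaling * (B_e @ A_e).
    """
    E, dim1, dim2 = W.shape
    r = A.shape[0] // E
    B_3d = B.reshape(dim1, r, E)
    W_eff = W.clone()
    for e in range(E):
        A_e = A[e * r : (e + 1) * r, :]  # [r, dim2]
        B_e = B_3d[:, :, e]  # [dim1, r]
        W_eff[e] = W_eff[e] + scaling * (B_e @ A_e)
    return W_eff


# Tiny smoke check: all-ones factors of rank 2 with scaling 0.5 add exactly
# 1.0 everywhere (B_e @ A_e is 2*ones, halved by the scaling factor).
W = torch.zeros(2, 3, 4)
A = torch.ones(2 * 2, 4)  # r=2, E=2, expert-major rows
B = torch.ones(3, 2 * 2)  # rank-major columns
print(materialize_expert_lora_ref(W, A, B, 0.5))  # all ones, shape [2, 3, 4]
```

The fused kernels never build this tensor expert-by-expert in Python; the sketch only pins down the layout convention the tests above assert against.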