From 842fa039ddde343401ff4199ca347bbb0c99419c Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 2 Apr 2026 19:53:48 +0700 Subject: [PATCH] feat: add sonicmoe fused lora support (#3519) * feat: add sonicmoe fused lora support * fix: forgot to add file * feat: add test * feat: add lora support for other routes * fix: add int8 lora support * fix: add qwen35_moe interleave support * fix: qwen3_5_moe loss * chore: lint * address some pr comments * fix test imports * add support matrix for moe kernels [skip ci] --------- Co-authored-by: Wing Lian --- src/axolotl/integrations/kernels/README.md | 85 ++++- src/axolotl/integrations/kernels/args.py | 22 ++ .../libs/scattermoe_lora/parallel_experts.py | 5 + .../scattermoe_lora/parallel_linear_lora.py | 5 + .../kernels/{ => libs}/sonicmoe/__init__.py | 0 .../kernels/libs/sonicmoe/lora.py | 220 ++++++++++++ .../kernels/{ => libs}/sonicmoe/patch.py | 97 +++++- .../kernels/{ => libs}/sonicmoe/routing.py | 84 +++-- .../{ => libs}/sonicmoe/weight_converter.py | 73 ++-- src/axolotl/integrations/kernels/plugin.py | 3 +- tests/e2e/integrations/test_sonicmoe.py | 14 +- tests/e2e/integrations/test_sonicmoe_lora.py | 318 +++++++++++++++++ tests/integrations/test_routing_parity.py | 38 +- tests/integrations/test_sonicmoe.py | 81 +++-- tests/integrations/test_sonicmoe_gradients.py | 2 +- tests/integrations/test_sonicmoe_lora.py | 328 ++++++++++++++++++ 16 files changed, 1249 insertions(+), 126 deletions(-) rename src/axolotl/integrations/kernels/{ => libs}/sonicmoe/__init__.py (100%) create mode 100644 src/axolotl/integrations/kernels/libs/sonicmoe/lora.py rename src/axolotl/integrations/kernels/{ => libs}/sonicmoe/patch.py (68%) rename src/axolotl/integrations/kernels/{ => libs}/sonicmoe/routing.py (86%) rename src/axolotl/integrations/kernels/{ => libs}/sonicmoe/weight_converter.py (69%) create mode 100644 tests/e2e/integrations/test_sonicmoe_lora.py create mode 100644 tests/integrations/test_sonicmoe_lora.py diff --git a/src/axolotl/integrations/kernels/README.md b/src/axolotl/integrations/kernels/README.md index 7c40720af..a852cd6cf 100644 --- a/src/axolotl/integrations/kernels/README.md +++ b/src/axolotl/integrations/kernels/README.md @@ -52,26 +52,91 @@ The `KernelsPlugin` runs before model loading and: ### ScatterMoE 1. Registers the ScatterMoE kernel from the local `libs/scattermoe_lora` package (includes fused LoRA support via Triton kernels). -2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation. +2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation via the HF `kernels` library. ### SonicMoE 1. Resolves the model's MoE block class(es) from `constants.py`. -2. Patches the forward method with SonicMoE's optimized kernels and registers a weight converter for the interleaved gate/up projection format. -3. Supports both softmax->topk and sigmoid->topk routing strategies. +2. Patches the forward method with SonicMoE's optimized CUTLASS kernels and registers a weight converter for the interleaved gate/up projection format. +3. Supports pluggable routing strategies (see routing table below). Both paths use the shared `resolve_moe_block_classes` utility in `constants.py` for model-type-to-class resolution. -#### Supported Models +## Model Support Matrix -See `constants.py` for the full list of supported model types (Qwen2-MoE, Qwen3-MoE, OLMoE, Mixtral, DeepSeek-V3, GLM-MoE, MiniMax, etc.). +All models use the **SwiGLU** activation (`act_fn(gate) * up`). 
Neither kernel currently supports non-SwiGLU MoE architectures. + +### Routing strategies + +| Routing Strategy | Description | ScatterMoE | SonicMoE | +|---|---|:---:|:---:| +| softmax → topk | Softmax over experts, select top-K, optional renormalization | Yes | Yes | +| softmax → group selection → topk | Softmax, select top groups (sum of top-2 per group), topk from selected groups, renorm + scaling | No | Yes | +| sigmoid → topk (with groups) | Sigmoid + bias correction, group-based masking, topk from masked scores, weights from original sigmoid | Yes | Yes | +| sigmoid → topk (no groups) | Sigmoid + bias correction, straight topk (n_group=1) | Yes | Yes | +| softmax → bias correction → topk | Softmax, bias via `gate.moe_statics`, topk, gather from original probs, clamp-based renorm | No | Yes | +| softmax → group_limited_greedy | Softmax, group selection (max per group), topk, scale only (no renorm) | No | Yes | +| softmax → topk via gate.wg | Softmax, gate weight at `gate.wg.weight` (not `gate.weight`), always renormalize | No | Yes | +| fused topk → softmax | Routing + expert computation fused in a single kernel | No | Planned | + +### Per-model support + +| Model Type | Architecture | Routing | ScatterMoE | SonicMoE | +|---|---|---|:---:|:---:| +| `qwen2_moe` | Qwen2-MoE | softmax → topk | **Yes** | **Yes** | +| `qwen3_moe` | Qwen3-MoE | softmax → topk | **Yes** | **Yes** | +| `qwen3_5_moe` | Qwen3.5-MoE | softmax → topk | **Yes** | **Yes** | +| `qwen3_5_moe_text` | Qwen3.5-MoE (VLM text) | softmax → topk | **Yes** | **Yes** | +| `qwen3_next` | Qwen3-Next | softmax → topk | **Yes** | **Yes** | +| `qwen3_vl_moe` | Qwen3-VL-MoE | softmax → topk | **Yes** | **Yes** | +| `qwen3_omni_moe` | Qwen3-Omni (Thinker + Talker) | softmax → topk | **Yes** | **Yes** | +| `olmoe` | OLMoE | softmax → topk | **Yes** | **Yes** | +| `mixtral` | Mixtral | softmax → topk | **Yes** | **Yes** | +| `minimax` | MiniMax | softmax → topk | **Yes** | **Yes** | +| `mistral4` | Mistral 4 | softmax → group → topk | No | **Yes** | +| `glm_moe_dsa` | GLM-MoE DSA (GLM 5) | sigmoid → topk (groups) | **Yes** | **Yes** | +| `deepseek_v3` | DeepSeek-V3 | sigmoid → topk (groups) | **Yes** | **Yes** | +| `glm4_moe` | GLM4-MoE | sigmoid → topk (groups) | **Yes** | **Yes** | +| `glm4_moe_lite` | GLM4-MoE Lite (GLM 4.7 Flash) | sigmoid → topk (groups) | **Yes**\* | **Yes** | +| `glm4v_moe` | GLM4v-MoE | sigmoid → topk (groups) | **Yes** | **Yes** | +| `minimax_m2` | MiniMax M2 | sigmoid → topk (no groups) | **Yes** | **Yes** | +| `ernie4_5_moe` | ERNIE 4.5 MoE | softmax → bias → topk | No | **Yes** | +| `deepseek_v2` | DeepSeek-V2 | softmax → group_limited_greedy | No | **Yes** | +| `hunyuan_v1_moe` | HunYuan V1 MoE | softmax → topk (gate.wg) | No | **Yes** | +| `gpt_oss` | GPT-OSS | fused topk → softmax | No | Planned | + +\* `glm4_moe_lite` with ScatterMoE may have issues — see Limitations. 
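
For reference, the simplest strategy above (softmax → topk, used by the Qwen-MoE, Mixtral, OLMoE, and MiniMax families) reduces to the sketch below. This is a simplified illustration of the routing math only; the function name and signature are illustrative, and the real implementations in `libs/sonicmoe/routing.py` and `libs/scattermoe_lora` additionally handle gate LoRA deltas, token sorting, and kernel-specific output layouts.

```python
import torch
import torch.nn.functional as F


def softmax_topk_route_sketch(hidden_states, gate_weight, top_k, norm_topk_prob=True):
    """Simplified softmax -> topk routing (illustrative only)."""
    # [T, H] x [E, H]^T -> [T, E] router logits; softmax in float32 for stability
    router_logits = F.linear(hidden_states.float(), gate_weight.float())
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)

    # Select the top-K experts per token
    top_values, selected_experts = torch.topk(router_probs, top_k, dim=-1)

    # Optionally renormalize so the selected weights sum to 1 per token
    if norm_topk_prob:
        top_values = top_values / top_values.sum(dim=-1, keepdim=True)

    return top_values, selected_experts, router_logits
```

The remaining strategies differ mainly in how the scores are produced (sigmoid vs. softmax, optional bias correction) and in how candidate experts are masked by group before the topk.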
+ +### Feature comparison + +| Feature | ScatterMoE | SonicMoE | +|---|:---:|:---:| +| Kernel backend | Triton | CUTLASS | +| GPU requirement | Any CUDA | Hopper (H100/H200) or Blackwell (B200+) | +| LoRA approach | Fused in Triton kernel | Runtime materialization + custom autograd | +| LoRA overhead | Lower (fused computation) | Higher (per-forward materialization) | +| Gate/router LoRA | Yes | Yes | +| Expert LoRA | Yes (fused) | Yes (materialized) | +| Shared expert LoRA | Yes (standard PEFT) | Yes (standard PEFT) | +| Selective expert dequantization | Yes (~97% memory savings) | No | +| Weight format | Transposed `[E, hidden, 2*inter]` | Interleaved gate/up `[2*I, H, E]` | +| torch.compile routing | No | Yes (optional) | + +## Shared Expert Handling + +Both kernels handle shared experts identically. Shared expert attribute names are detected in order of priority: + +1. `shared_expert` (Qwen2-MoE) +2. `shared_experts` (GLM-MoE, DeepSeek-V3) +3. `shared_mlp` (HunYuan V1 MoE) + +If `shared_expert_gate` exists, sigmoid gating is applied to the shared expert contribution before adding it to the routed output. PEFT wraps shared expert linear layers with standard LoRA — no special handling is needed. ## Limitations -ScatterMoE uses a softmax -> topk routing, so results may be different for some model architectures as baseline (GPT-OSS, etc). Incompatible with `GLM_MOE_DSA` (GLM 5) and `GLM4_MOE_LITE` (GLM 4.7 Flash) at the moment. - -SonicMoE supports both softmax->topk and sigmoid->topk routing, covering a wider range of architectures. - -ScatterMoE does not work for GLM4.7 Flash (glm4_moe_lite) atm. +- **ScatterMoE + GLM4-MoE Lite**: ScatterMoE does not work reliably for GLM 4.7 Flash (`glm4_moe_lite`). +- **Non-SwiGLU activations**: Neither kernel supports MoE architectures with non-SwiGLU expert activations (e.g., GPT-OSS uses a custom GLU variant). +- **GPT-OSS**: Deferred — requires transposed weight layout `[E, H, 2*I]`, expert biases, and custom GLU activation. A dedicated forward path is needed. +- **FSDP + fused gate LoRA (SonicMoE)**: The fused topk→softmax path materializes a local tensor when LoRA delta is present to avoid DTensor + Tensor mixing under FSDP. ## Note on MegaBlocks diff --git a/src/axolotl/integrations/kernels/args.py b/src/axolotl/integrations/kernels/args.py index d9b261e72..7c9e23b6c 100644 --- a/src/axolotl/integrations/kernels/args.py +++ b/src/axolotl/integrations/kernels/args.py @@ -45,6 +45,28 @@ class KernelsArgs(BaseModel): return data + @model_validator(mode="before") + @classmethod + def warn_sonicmoe_lora_overhead(cls, data): + if data.get("use_sonicmoe") is True and data.get("adapter") in ( + "lora", + "qlora", + ): + lora_target = data.get("lora_target_modules") or [] + lora_linear = data.get("lora_target_linear_modules") or [] + targets = ( + lora_target if isinstance(lora_target, list) else [lora_target] + ) + (lora_linear if isinstance(lora_linear, list) else [lora_linear]) + expert_keywords = ("gate_up_proj", "down_proj", "experts") + if any(kw in t for t in targets for kw in expert_keywords): + LOG.info( + "SonicMoE + LoRA on expert modules uses runtime weight materialization " + "(W_eff = W + scaling*B@A per forward). This has slightly higher overhead " + "than ScatterMoE's fused Triton LoRA kernels but works with any CUTLASS kernel." 
+ ) + + return data + @model_validator(mode="before") @classmethod def disable_mlp_kernel(cls, data): diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py index 7a1eef472..5180587aa 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py @@ -49,6 +49,11 @@ class ParallelLinear(torch.autograd.Function): grouped_in: bool = False, grouped_out: bool = False, ): + # Cast weights to match input dtype (e.g. 8-bit LoRA) + if expert_weights.dtype != x.dtype: + expert_weights = expert_weights.to(x.dtype) + if expert_biases is not None and expert_biases.dtype != x.dtype: + expert_biases = expert_biases.to(x.dtype) with torch.device(x.device): output = kernels.ops.scatter2scatter( X=x, diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py index 5d00e1230..17dfd420c 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py @@ -65,6 +65,11 @@ class ScatterMoELoRA(torch.autograd.Function): use_fused_dX: bool = False, use_fused_gather: bool = False, ): + # Cast weights to match input dtype (e.g. 8-bit LoRA) + if expert_weights.dtype != x.dtype: + expert_weights = expert_weights.to(x.dtype) + if expert_biases is not None and expert_biases.dtype != x.dtype: + expert_biases = expert_biases.to(x.dtype) with torch.device(x.device): # Fused forward: Y = X @ W + scaling * (X @ A^T) @ B^T output = scatter2scatter_lora( diff --git a/src/axolotl/integrations/kernels/sonicmoe/__init__.py b/src/axolotl/integrations/kernels/libs/sonicmoe/__init__.py similarity index 100% rename from src/axolotl/integrations/kernels/sonicmoe/__init__.py rename to src/axolotl/integrations/kernels/libs/sonicmoe/__init__.py diff --git a/src/axolotl/integrations/kernels/libs/sonicmoe/lora.py b/src/axolotl/integrations/kernels/libs/sonicmoe/lora.py new file mode 100644 index 000000000..4d7a21925 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/sonicmoe/lora.py @@ -0,0 +1,220 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +""" +SonicMoE LoRA support via runtime weight materialization. + +SonicMoE uses opaque CUTLASS kernels that cannot be modified to fuse LoRA. +Instead, we materialize the effective weight W_eff = W + scaling * (B @ A) +before each CUTLASS call, and use a custom autograd.Function to route +gradients back to the LoRA A and B parameters. + +PEFT unwrapping utilities are also provided to handle the ParamWrapper +chain that PEFT creates when targeting expert parameters. +""" + +from typing import Optional + +import torch + +# ============================================================================= +# PEFT unwrapping utilities +# ============================================================================= + + +def has_lora(module) -> bool: + """Check if a module is wrapped by PEFT with LoRA.""" + return hasattr(module, "base_layer") and hasattr(module, "lora_A") + + +def get_lora_params_from_wrapper(module) -> tuple: + """Extract LoRA parameters from a PEFT ParamWrapper. 
+ + Returns: + (lora_A, lora_B, scaling) if LoRA is active, else (None, None, None) + """ + if not hasattr(module, "lora_A") or not hasattr(module, "lora_B"): + return None, None, None + + active_adapters = getattr(module, "active_adapters", ["default"]) + if not active_adapters: + return None, None, None + + adapter_name = active_adapters[0] + + lora_A_dict = getattr(module, "lora_A", {}) + lora_B_dict = getattr(module, "lora_B", {}) + scaling_dict = getattr(module, "scaling", {}) + + if ( + adapter_name not in lora_A_dict + or adapter_name not in lora_B_dict + or adapter_name not in scaling_dict + ): + return None, None, None + + lora_A = lora_A_dict[adapter_name].weight + lora_B = lora_B_dict[adapter_name].weight + scaling = scaling_dict[adapter_name] + + return lora_A, lora_B, scaling + + +def unwrap_gate_lora(gate_module): + """Unwrap PEFT ParamWrapper on the router gate. + + When PEFT targets ``gate.weight``, ``self.gate`` becomes:: + + ParamWrapper(weight) + -> base_layer: Router (the real module) + + Returns: + (base_gate, gate_weight, gate_lora_delta_or_None) + + ``base_gate`` is the original router module (with ``.top_k``, etc.). + ``gate_weight`` is the base router weight tensor. + ``gate_lora_delta_or_None`` is the LoRA delta if active, else None. + Kept separate to avoid mixing DTensor + Tensor under FSDP. + """ + if has_lora(gate_module): + base_gate = gate_module.base_layer + lora_A, lora_B, scaling = get_lora_params_from_wrapper(gate_module) + if lora_A is not None: + delta = scaling * (lora_B @ lora_A) + return base_gate, base_gate.weight, delta + return base_gate, base_gate.weight, None + + return gate_module, gate_module.weight, None + + +def unwrap_experts_lora(experts_module): + """Walk a PEFT ParamWrapper chain on ``self.experts``. + + When PEFT targets ``experts.gate_up_proj`` and ``experts.down_proj`` + via ``target_parameters``, ``self.experts`` becomes:: + + ParamWrapper(down_proj) + -> base_layer: ParamWrapper(gate_up_proj) + -> base_layer: Experts (the real module) + + Returns: + (base_experts, lora_dict) + + ``lora_dict`` maps parameter names to ``(lora_A, lora_B, scaling)`` + tuples, or is empty if no LoRA is active. + """ + wrappers = {} + module = experts_module + while hasattr(module, "base_layer") and hasattr(module, "lora_A"): + param_name = getattr(module, "parameter_name", None) + if param_name is not None: + wrappers[param_name] = module + module = module.base_layer + + base_experts = module + lora_dict = {} + + for param_name, wrapper in wrappers.items(): + lora_A, lora_B, scaling = get_lora_params_from_wrapper(wrapper) + if lora_A is not None: + lora_dict[param_name] = (lora_A, lora_B, scaling) + + return base_experts, lora_dict + + +# ============================================================================= +# LoRA weight materialization autograd function +# ============================================================================= + + +class MoELoRAMaterialize(torch.autograd.Function): + """Materialize effective weight W_eff = W + scaling * (B @ A) per expert. + + Inserts into the autograd graph between PEFT's LoRA parameters and + SonicMoE's CUTLASS kernels. The CUTLASS backward computes dW_eff, + which this function decomposes into dA and dB via the chain rule. 
+ + Weight layouts (PEFT rank-major): + base_weight: [E, dim1, dim2] (frozen expert parameter) + lora_A: [r*E, dim2] (rows [e*r:(e+1)*r] = A_e) + lora_B: [dim1, r*E] (cols [:, e*r:(e+1)*r] = B_e) + + Per-expert: delta_e = B_e @ A_e = [dim1, r] @ [r, dim2] = [dim1, dim2] + """ + + @staticmethod + def forward( + ctx, + base_weight: torch.Tensor, + lora_A: torch.Tensor, + lora_B: torch.Tensor, + scaling: float, + ) -> torch.Tensor: + E, dim1, dim2 = base_weight.shape + r = lora_A.shape[0] // E + assert lora_A.shape[0] == r * E, ( + f"lora_A rows ({lora_A.shape[0]}) must be divisible by num_experts ({E})" + ) + + # Reshape PEFT rank-major to per-expert batched format + A_3d = lora_A.reshape(E, r, dim2) + B_3d = lora_B.reshape(dim1, r, E).permute(2, 0, 1).contiguous() # [E, dim1, r] + + # Batched matmul: [E, dim1, r] @ [E, r, dim2] = [E, dim1, dim2] + delta = torch.bmm(B_3d, A_3d) + + W_eff = base_weight + scaling * delta + + ctx.save_for_backward(lora_A, lora_B) + ctx.scaling = scaling + ctx.E = E + ctx.r = r + + return W_eff + + @staticmethod + def backward(ctx, grad_W_eff: torch.Tensor): + lora_A, lora_B = ctx.saved_tensors + scaling = ctx.scaling + E = ctx.E + r = ctx.r + + _, dim1, dim2 = grad_W_eff.shape + + # Reshape to per-expert (same as forward) + A_3d = lora_A.reshape(E, r, dim2) + B_3d = lora_B.reshape(dim1, r, E).permute(2, 0, 1).contiguous() # [E, dim1, r] + + # dA_e = scaling * B_e^T @ dW_e + # [E, r, dim1] @ [E, dim1, dim2] = [E, r, dim2] + d_A_3d = scaling * torch.bmm(B_3d.transpose(1, 2), grad_W_eff) + + # dB_e = scaling * dW_e @ A_e^T + # [E, dim1, dim2] @ [E, dim2, r] = [E, dim1, r] + d_B_3d = scaling * torch.bmm(grad_W_eff, A_3d.transpose(1, 2)) + + # Reshape back to PEFT rank-major layout + d_lora_A = d_A_3d.reshape(E * r, dim2) + d_lora_B = d_B_3d.permute(1, 2, 0).contiguous().reshape(dim1, E * r) + + return None, d_lora_A, d_lora_B, None + + +def materialize_expert_lora( + base_weight: torch.Tensor, + lora_params: Optional[tuple], +) -> torch.Tensor: + """Materialize effective expert weight with optional LoRA delta. + + Args: + base_weight: [E, dim1, dim2] frozen expert parameter + lora_params: (lora_A, lora_B, scaling) or None + + Returns: + W_eff if lora_params is not None, else base_weight unchanged. + """ + if lora_params is None: + return base_weight + lora_A, lora_B, scaling = lora_params + return MoELoRAMaterialize.apply(base_weight, lora_A, lora_B, scaling) diff --git a/src/axolotl/integrations/kernels/sonicmoe/patch.py b/src/axolotl/integrations/kernels/libs/sonicmoe/patch.py similarity index 68% rename from src/axolotl/integrations/kernels/sonicmoe/patch.py rename to src/axolotl/integrations/kernels/libs/sonicmoe/patch.py index a3b96f12a..65095a987 100644 --- a/src/axolotl/integrations/kernels/sonicmoe/patch.py +++ b/src/axolotl/integrations/kernels/libs/sonicmoe/patch.py @@ -28,20 +28,72 @@ import torch.nn.functional as F from axolotl.integrations.kernels.constants import resolve_moe_block_classes from axolotl.utils.logging import get_logger +from .lora import ( + has_lora, + materialize_expert_lora, + unwrap_experts_lora, + unwrap_gate_lora, +) + LOG = get_logger(__name__) -def patch_sonicmoe(model_type: str, torch_compile: bool = False): - """Main entry point: patch SparseMoeBlock for SonicMoE support. +def _get_expert_weights(experts_module): + """Extract expert weights, applying LoRA materialization if PEFT is active. - Args: - model_type: The HuggingFace model type (e.g. "qwen3_moe"). 
- torch_compile: If True, wrap routing functions with torch.compile - for kernel fusion (fuses softmax+topk+renorm into fewer launches). + Returns: + (gate_up_weight, down_weight) in SonicMoE layout [dim, dim, E]. """ + if has_lora(experts_module): + base_experts, lora_dict = unwrap_experts_lora(experts_module) + gate_up = materialize_expert_lora( + base_experts.gate_up_proj, lora_dict.get("gate_up_proj") + ) + down = materialize_expert_lora( + base_experts.down_proj, lora_dict.get("down_proj") + ) + else: + gate_up = experts_module.gate_up_proj + down = experts_module.down_proj + + # Permute to SonicMoE layout: + # gate_up: [E, 2*I, H] -> [2*I, H, E] + # down: [E, H, I] -> [H, I, E] + return gate_up.permute(1, 2, 0), down.permute(1, 2, 0) + + +def _fix_qwen3_5_moe_text_weight_renaming(model_type: str, base_model_type: str): + """Strip qwen3_5_moe_text WeightRenaming in VLM mode to preserve custom loaders.""" + if model_type != "qwen3_5_moe_text" or base_model_type == "qwen3_5_moe_text": + return + + try: + from transformers.conversion_mapping import ( + get_checkpoint_conversion_mapping, + register_checkpoint_conversion_mapping, + ) + from transformers.core_model_loading import WeightRenaming + except ImportError: + return + + text_mapping = get_checkpoint_conversion_mapping(model_type) + if text_mapping and isinstance(text_mapping[0], WeightRenaming): + text_mapping.pop(0) + register_checkpoint_conversion_mapping(model_type, text_mapping, overwrite=True) + LOG.info("Stripped qwen3_5_moe_text WeightRenaming for VLM mode") + + +def patch_sonicmoe( + model_type: str, + torch_compile: bool = False, + base_model_type: str | None = None, +): + """Patch SparseMoeBlock for SonicMoE support.""" from .routing import get_model_moe_config from .weight_converter import register_sonicmoe_weight_converter + _fix_qwen3_5_moe_text_weight_renaming(model_type, base_model_type or model_type) + routing_fn, activation, router_attr = get_model_moe_config(model_type) if torch_compile and routing_fn is not None: @@ -113,11 +165,10 @@ def _make_general_forward(moe_cls, routing_fn, activation): hidden_states_flat, self ) - # Permute weights to SonicMoE layout: - # gate_up: [E, 2*I, H] -> [2*I, H, E] - # down: [E, H, I] -> [H, I, E] - gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0) - down_weight = self.experts.down_proj.permute(1, 2, 0) + # Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout + gate_up_weight, down_weight = _get_expert_weights(self.experts) + gate_up_weight = gate_up_weight.to(hidden_states_flat.dtype) + down_weight = down_weight.to(hidden_states_flat.dtype) E = gate_up_weight.shape[-1] output, _ = moe_general_routing_inputs( @@ -161,22 +212,30 @@ def _make_fused_forward(moe_cls, activation, router_attr): # Shared expert (computed early, matching original model ordering) shared_expert_output = _compute_shared_expert(self, hidden_states_flat) - router = getattr(self, router_attr) + # Unwrap router for attribute access + optional LoRA delta + raw_router = getattr(self, router_attr) + base_router, router_weight, router_lora_delta = unwrap_gate_lora(raw_router) + if router_lora_delta is not None: + # Materialize local tensor to avoid DTensor + Tensor add under FSDP + if hasattr(router_weight, "to_local"): + router_weight = router_weight.to_local() + effective_router_weight = router_weight + router_lora_delta + else: + effective_router_weight = router_weight - # Permute weights to SonicMoE layout: - # gate_up: [E, 2*I, H] -> [2*I, H, E] - # down: [E, H, I] -> [H, I, 
E] - gate_up_weight = self.experts.gate_up_proj.permute(1, 2, 0) - down_weight = self.experts.down_proj.permute(1, 2, 0) + # Unwrap PEFT + optional LoRA materialization, then permute to SonicMoE layout + gate_up_weight, down_weight = _get_expert_weights(self.experts) + gate_up_weight = gate_up_weight.to(hidden_states_flat.dtype) + down_weight = down_weight.to(hidden_states_flat.dtype) output, _router_logits, _expert_freq = moe_TC_softmax_topk_layer( hidden_states_flat, - router.weight, + effective_router_weight, gate_up_weight, None, # b1 (no gate/up bias) down_weight, None, # b2 (no down bias) - router.top_k, + base_router.top_k, torch.cuda.current_stream().cuda_stream, activation, False, # is_inference_mode diff --git a/src/axolotl/integrations/kernels/sonicmoe/routing.py b/src/axolotl/integrations/kernels/libs/sonicmoe/routing.py similarity index 86% rename from src/axolotl/integrations/kernels/sonicmoe/routing.py rename to src/axolotl/integrations/kernels/libs/sonicmoe/routing.py index 09bffc742..4bdb37890 100644 --- a/src/axolotl/integrations/kernels/sonicmoe/routing.py +++ b/src/axolotl/integrations/kernels/libs/sonicmoe/routing.py @@ -16,6 +16,8 @@ When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used. import torch import torch.nn.functional as F +from .lora import unwrap_gate_lora + def get_model_moe_config(model_type: str): """Returns (routing_fn, activation, router_attr) for a given model type. @@ -40,6 +42,7 @@ def get_model_moe_config(model_type: str): "qwen2_moe", "qwen3_moe", "qwen3_5_moe", + "qwen3_5_moe_text", "qwen3_next", "qwen3_vl_moe", "qwen3_omni_moe", @@ -88,12 +91,18 @@ def softmax_topk_routing( expert_indices: [T*K] which expert (int32) router_logits: [T, E] original logits for aux loss """ - gate = moe_block.gate - T, _ = hidden_states.shape - K = gate.top_k + base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate) + T, H = hidden_states.shape + K = base_gate.top_k - # Compute router logits and softmax over all experts - router_logits = F.linear(hidden_states, gate.weight) # [T, E] + # Compute router logits and softmax over all experts. + # Two F.linear calls avoid mixing DTensor (gate_weight) + Tensor (delta) under FSDP. + # Cast to float32 to match LoRA delta dtype (PEFT computes in fp32). + router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E] + if gate_lora_delta is not None: + router_logits = router_logits + F.linear( + hidden_states.float(), gate_lora_delta.float() + ) router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] # Select top-k experts per token @@ -101,7 +110,7 @@ def softmax_topk_routing( # Renormalize if configured (default True for models without the attribute, # e.g. Mixtral/MiniMax which always normalize) - if getattr(gate, "norm_topk_prob", True): + if getattr(base_gate, "norm_topk_prob", True): top_values = top_values / top_values.sum(dim=-1, keepdim=True) # no-op: matches transformers which casts to softmax output dtype (float32). 
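# Note on the gate LoRA delta handling above: folding the delta into the weight
# and applying the gate once would be mathematically identical, since
#
#     x @ (W + s * B @ A).T == x @ W.T + x @ (s * B @ A).T
#
# but under FSDP the base gate weight is a DTensor while the materialized delta
# is a plain Tensor, and mixing the two in a single add is what the split
# F.linear calls avoid; summing the two activation outputs (both plain Tensors)
# yields the same logits at the cost of one extra small matmul.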
@@ -128,13 +137,17 @@ def softmax_group_topk_routing( hidden_states: torch.Tensor, moe_block ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale.""" - gate = moe_block.gate + base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate) T, _ = hidden_states.shape K = moe_block.top_k - E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0]) + E = getattr(moe_block, "n_routed_experts", gate_weight.shape[0]) n_group = getattr(moe_block, "n_group", 1) - router_logits = F.linear(hidden_states, gate.weight) # [T, E] + router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E] + if gate_lora_delta is not None: + router_logits = router_logits + F.linear( + hidden_states.float(), gate_lora_delta.float() + ) router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] scores_for_choice = router_probs @@ -206,25 +219,29 @@ def sigmoid_topk_routing( expert_indices: [T*K] which expert (int32) router_logits: [T, E] original logits for aux loss """ - gate = moe_block.gate + base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate) T, _ = hidden_states.shape K = moe_block.top_k - E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0]) + E = getattr(moe_block, "n_routed_experts", gate_weight.shape[0]) n_group = getattr(moe_block, "n_group", 1) # Compute router logits and sigmoid probabilities - router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E] + router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E] + if gate_lora_delta is not None: + router_logits = router_logits + F.linear( + hidden_states.float(), gate_lora_delta.float() + ) router_probs = router_logits.sigmoid() # [T, E] # Bias-corrected scores for expert selection (not used for final weights). # glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 stores it on the block. 
- e_score_correction_bias = getattr(gate, "e_score_correction_bias", None) + e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None) if e_score_correction_bias is None: e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None) if e_score_correction_bias is None: raise AttributeError( f"sigmoid_topk_routing requires e_score_correction_bias on " - f"gate ({type(gate)}) or moe_block ({type(moe_block)}), but neither has it" + f"gate ({type(base_gate)}) or moe_block ({type(moe_block)}), but neither has it" ) scores_for_choice = router_probs + e_score_correction_bias @@ -296,16 +313,20 @@ def softmax_bias_topk_routing( expert_indices: [T*K] which expert (int32) router_logits: [T, E] original logits for aux loss """ - gate = moe_block.gate - T, _ = hidden_states.shape - K = gate.top_k + base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate) + T, H = hidden_states.shape + K = base_gate.top_k # Compute router logits and softmax (force float32 for numerical stability) - router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E] + router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E] + if gate_lora_delta is not None: + router_logits = router_logits + F.linear( + hidden_states.float(), gate_lora_delta.float() + ) router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] # Bias-corrected scores for expert selection (via moe_statics module) - scores_for_choice = gate.moe_statics(router_probs) # [T, E] + scores_for_choice = base_gate.moe_statics(router_probs) # [T, E] # Select top-k experts using biased scores _, selected_experts = torch.topk(scores_for_choice, K, dim=-1) # [T, K] @@ -314,7 +335,7 @@ def softmax_bias_topk_routing( top_values = torch.gather(router_probs, dim=-1, index=selected_experts) # [T, K] # Renormalize with clamp(min=norm_min) instead of sum+epsilon - norm_min = getattr(gate, "norm_min", 1e-20) + norm_min = getattr(base_gate, "norm_min", 1e-20) top_values = top_values / torch.clamp( top_values.sum(dim=-1, keepdim=True), min=norm_min ) @@ -358,15 +379,19 @@ def softmax_group_limited_topk_routing( expert_indices: [T*K] which expert (int32) router_logits: [T, E] original logits for aux loss """ - gate = moe_block.gate - T, _ = hidden_states.shape + base_gate, gate_weight, gate_lora_delta = unwrap_gate_lora(moe_block.gate) + T, H = hidden_states.shape K = moe_block.top_k num_group = getattr(moe_block, "num_group", 1) - num_experts = gate.weight.shape[0] + num_experts = gate_weight.shape[0] topk_method = getattr(moe_block, "topk_method", "greedy") # Compute logits in float32 and softmax - router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E] + router_logits = F.linear(hidden_states.float(), gate_weight.float()) # [T, E] + if gate_lora_delta is not None: + router_logits = router_logits + F.linear( + hidden_states.float(), gate_lora_delta.float() + ) router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] if topk_method == "greedy" or num_group == 1: @@ -445,12 +470,17 @@ def softmax_topk_wg_routing( router_logits: [T, E] original logits for aux loss """ gate = moe_block.gate - T, _ = hidden_states.shape + T, H = hidden_states.shape K = moe_block.top_k # Gate computes logits via gate.wg (nn.Linear, float32) - wg = gate.wg - router_logits = F.linear(hidden_states.float(), wg.weight.float()) # [T, E] + # Unwrap at gate.wg level since PEFT targets the wg Linear, not the gate container + base_wg, wg_weight, wg_lora_delta = 
unwrap_gate_lora(gate.wg) + router_logits = F.linear(hidden_states.float(), wg_weight.float()) # [T, E] + if wg_lora_delta is not None: + router_logits = router_logits + F.linear( + hidden_states.float(), wg_lora_delta.float() + ) router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] # Select top-k experts diff --git a/src/axolotl/integrations/kernels/sonicmoe/weight_converter.py b/src/axolotl/integrations/kernels/libs/sonicmoe/weight_converter.py similarity index 69% rename from src/axolotl/integrations/kernels/sonicmoe/weight_converter.py rename to src/axolotl/integrations/kernels/libs/sonicmoe/weight_converter.py index 172864ac6..20da27ff0 100644 --- a/src/axolotl/integrations/kernels/sonicmoe/weight_converter.py +++ b/src/axolotl/integrations/kernels/libs/sonicmoe/weight_converter.py @@ -129,15 +129,41 @@ class InterleavedToConcatenated(ConversionOps): return ConcatenatedToInterleaved(self.dim) +def _make_same_key_interleave_converter(): + """Create a WeightConverter that interleaves an already-fused gate_up_proj.""" + from transformers.core_model_loading import WeightConverter + + return WeightConverter( + source_patterns="mlp.experts.gate_up_proj", + target_patterns="mlp.experts.gate_up_proj", + operations=[ConcatenatedToInterleaved(dim=1)], + ) + + +def _has_same_key_interleave(mapping) -> bool: + """Check whether the mapping already has a same-key gate_up_proj interleave converter.""" + for conv in mapping: + if ( + hasattr(conv, "source_patterns") + and conv.source_patterns == ["mlp.experts.gate_up_proj"] + and conv.target_patterns == ["mlp.experts.gate_up_proj"] + and hasattr(conv, "operations") + and any(isinstance(op, ConcatenatedToInterleaved) for op in conv.operations) + ): + return True + return False + + def register_sonicmoe_weight_converter(model_type: str): - """Override the conversion mapping to add interleave step for gate_up_proj. + """Register weight converters to interleave gate_up_proj for SonicMoE. - Appends a ConcatenatedToInterleaved operation to the existing gate_up_proj - converter chain. For example, qwen3_moe's chain becomes: - MergeModulelist(dim=0) -> Concatenate(dim=1) -> ConcatenatedToInterleaved(dim=1) + Handles two checkpoint formats: + 1. Separate per-expert weights (e.g. qwen3_moe): appends interleave to the + existing merge chain (MergeModulelist -> Concatenate -> Interleave). + 2. Already-fused gate_up_proj (e.g. qwen3_5_moe_text): adds a same-key + converter (gate_up_proj -> gate_up_proj with Interleave). - The reverse is auto-generated for saving: - InterleavedToConcatenated(dim=1) -> Chunk(dim=1) -> SplitModulelist(dim=0) + The loader matches whichever source pattern exists in the checkpoint. """ from transformers.conversion_mapping import ( get_checkpoint_conversion_mapping, @@ -145,37 +171,32 @@ def register_sonicmoe_weight_converter(model_type: str): ) existing = get_checkpoint_conversion_mapping(model_type) + if existing is None: - LOG.warning( - f"No conversion mapping found for model type '{model_type}'. " - "SonicMoE weight interleaving will not be applied during checkpoint loading." 
- ) + # No mapping at all — create one with just the same-key converter + mapping = [_make_same_key_interleave_converter()] + register_checkpoint_conversion_mapping(model_type, mapping) + LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'") return - # Find the gate_up_proj converter and append ConcatenatedToInterleaved - patched = False + # Append interleave to any existing many-to-one merge chain for converter in existing: if hasattr(converter, "operations") and any( "gate_up_proj" in pat for pat in converter.target_patterns ): - # Guard against double registration (e.g. plugin reloaded) - if any( + has_separate_sources = any( + "gate_proj" in pat or "up_proj" in pat + for pat in converter.source_patterns + ) + if has_separate_sources and not any( isinstance(op, ConcatenatedToInterleaved) for op in converter.operations ): - LOG.info( - f"SonicMoE weight converter already registered for '{model_type}'" - ) - return - converter.operations.append(ConcatenatedToInterleaved(dim=1)) - patched = True + converter.operations.append(ConcatenatedToInterleaved(dim=1)) break - if not patched: - LOG.warning( - f"Could not find gate_up_proj converter for model type '{model_type}'. " - "SonicMoE weight interleaving will not be applied during checkpoint loading." - ) - return + # Also add a same-key converter for already-fused checkpoints + if not _has_same_key_interleave(existing): + existing.append(_make_same_key_interleave_converter()) register_checkpoint_conversion_mapping(model_type, existing, overwrite=True) LOG.info(f"Registered SonicMoE weight converter for model type '{model_type}'") diff --git a/src/axolotl/integrations/kernels/plugin.py b/src/axolotl/integrations/kernels/plugin.py index 939bdb790..4ab22bfce 100644 --- a/src/axolotl/integrations/kernels/plugin.py +++ b/src/axolotl/integrations/kernels/plugin.py @@ -84,12 +84,13 @@ class KernelsPlugin(BasePlugin): _check_sonicmoe_gpu_compat() - from axolotl.integrations.kernels.sonicmoe import patch_sonicmoe + from axolotl.integrations.kernels.libs.sonicmoe import patch_sonicmoe LOG.info(f"Applying SonicMoE patches for model type: {moe_model_type}") patch_sonicmoe( moe_model_type, torch_compile=bool(getattr(cfg, "torch_compile", False)), + base_model_type=cfg.model_config_type, ) def _register_kernels(self): diff --git a/tests/e2e/integrations/test_sonicmoe.py b/tests/e2e/integrations/test_sonicmoe.py index 2152e94c7..ff8620b2f 100644 --- a/tests/e2e/integrations/test_sonicmoe.py +++ b/tests/e2e/integrations/test_sonicmoe.py @@ -51,7 +51,7 @@ def _create_tiny_qwen3_config(): def _interleave_gate_up_weights(model): """Interleave all gate_up_proj parameters in-place for SonicMoE.""" - from axolotl.integrations.kernels.sonicmoe.weight_converter import ( + from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import ( interleave_gate_up, ) @@ -80,7 +80,7 @@ class TestSonicMoEForwardCorrectness: def test_forward_output_matches(self): from transformers import AutoModelForCausalLM - from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe + from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe config = _create_tiny_qwen3_config() input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda") @@ -117,8 +117,8 @@ class TestSonicMoEGradientCorrectness: """Verify all parameter gradients match between original and patched.""" from transformers import AutoModelForCausalLM - from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe - from 
axolotl.integrations.kernels.sonicmoe.weight_converter import ( + from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe + from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import ( deinterleave_gate_up, ) @@ -191,7 +191,7 @@ class TestSonicMoEGradientCorrectness: """Verify that router (gate) weights get non-zero gradients.""" from transformers import AutoModelForCausalLM - from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe + from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe config = _create_tiny_qwen3_config() input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda") @@ -223,7 +223,7 @@ class TestSonicMoETrainingConvergence: """Run 30 training steps, verify loss decreases and no NaN/Inf.""" from transformers import AutoModelForCausalLM - from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe + from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe config = _create_tiny_qwen3_config() input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") @@ -254,7 +254,7 @@ class TestSonicMoETrainingConvergence: """Verify expert weights change during training (not frozen).""" from transformers import AutoModelForCausalLM - from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe + from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe config = _create_tiny_qwen3_config() input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") diff --git a/tests/e2e/integrations/test_sonicmoe_lora.py b/tests/e2e/integrations/test_sonicmoe_lora.py new file mode 100644 index 000000000..74721ee57 --- /dev/null +++ b/tests/e2e/integrations/test_sonicmoe_lora.py @@ -0,0 +1,318 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +""" +End-to-end tests for SonicMoE + LoRA integration. + +Verifies that PEFT-wrapped MoE models work correctly with SonicMoE's +runtime LoRA materialization: gradients flow to adapters, base weights +stay frozen, and loss converges. 
+ +Requires: + - H100/H200 GPU (SonicMoE CUTLASS kernels target sm_90) + - sonicmoe package installed + - peft package installed + - transformers with Qwen3MoE support + +Usage: + pytest tests/e2e/integrations/test_sonicmoe_lora.py -v -s +""" + +import importlib.util +import math + +import pytest +import torch + +_sonicmoe_available = importlib.util.find_spec("sonicmoe") is not None +_peft_available = importlib.util.find_spec("peft") is not None +_is_hopper = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0) + +pytestmark = [ + pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA GPU"), + pytest.mark.skipif( + not _is_hopper, reason="SonicMoE CUTLASS kernels require Hopper (sm_90)" + ), + pytest.mark.skipif(not _sonicmoe_available, reason="SonicMoE not installed"), + pytest.mark.skipif(not _peft_available, reason="PEFT not installed"), +] + + +def _create_tiny_qwen3_config(): + """Create a minimal Qwen3MoE config for fast testing.""" + from transformers import AutoConfig + + config = AutoConfig.for_model("qwen3_moe") + config.hidden_size = 512 + config.intermediate_size = 1024 + config.moe_intermediate_size = 64 + config.num_attention_heads = 16 + config.num_key_value_heads = 2 + config.head_dim = 32 + config.num_hidden_layers = 2 + config.num_experts = 8 + config.num_experts_per_tok = 2 + config.vocab_size = 1000 + config.max_position_embeddings = 128 + config.norm_topk_prob = True + config.torch_dtype = torch.bfloat16 + return config + + +def _interleave_gate_up_weights(model): + """Interleave all gate_up_proj parameters in-place for SonicMoE.""" + from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import ( + interleave_gate_up, + ) + + with torch.no_grad(): + for name, param in model.named_parameters(): + if "gate_up_proj" in name: + param.copy_(interleave_gate_up(param)) + + +def _unpatch_sonicmoe(): + """Restore original forward on the MoE block class if it was patched.""" + from axolotl.integrations.kernels.constants import resolve_moe_block_classes + + for moe_cls in resolve_moe_block_classes("qwen3_moe"): + if hasattr(moe_cls, "_original_forward"): + moe_cls.forward = moe_cls._original_forward + del moe_cls._original_forward + + +def _apply_lora(model, target_modules): + """Apply PEFT LoRA to the model.""" + from peft import LoraConfig, get_peft_model + + lora_config = LoraConfig( + r=8, + lora_alpha=16, + target_modules=target_modules, + lora_dropout=0.0, + bias="none", + ) + return get_peft_model(model, lora_config) + + +class TestSonicMoELoRATraining: + """Verify SonicMoE + LoRA training works end-to-end.""" + + def teardown_method(self): + _unpatch_sonicmoe() + + def test_loss_decreases(self): + """Run 30 training steps with LoRA on experts, verify loss decreases.""" + from transformers import AutoModelForCausalLM + + from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe + + config = _create_tiny_qwen3_config() + input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") + + model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() + patch_sonicmoe("qwen3_moe") + _interleave_gate_up_weights(model) + model = _apply_lora(model, ["gate_up_proj", "down_proj"]) + + optimizer = torch.optim.AdamW( + [p for p in model.parameters() if p.requires_grad], lr=1e-3 + ) + losses = [] + + for step in range(30): + out = model(input_ids, labels=input_ids) + loss = out.loss + assert not math.isnan(loss.item()), f"NaN loss at step {step}" + assert not math.isinf(loss.item()), f"Inf loss at step 
{step}" + losses.append(loss.item()) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + assert losses[-1] < losses[0], ( + f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}" + ) + + def test_base_weights_frozen(self): + """Verify base (non-LoRA) weights don't change during training.""" + from transformers import AutoModelForCausalLM + + from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe + + config = _create_tiny_qwen3_config() + input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") + + model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() + patch_sonicmoe("qwen3_moe") + _interleave_gate_up_weights(model) + model = _apply_lora(model, ["gate_up_proj", "down_proj"]) + + # Snapshot frozen weights + frozen_before = {} + for name, param in model.named_parameters(): + if not param.requires_grad: + frozen_before[name] = param.data.clone() + + optimizer = torch.optim.AdamW( + [p for p in model.parameters() if p.requires_grad], lr=1e-3 + ) + for _ in range(5): + out = model(input_ids, labels=input_ids) + out.loss.backward() + optimizer.step() + optimizer.zero_grad() + + for name, param in model.named_parameters(): + if name in frozen_before: + assert torch.equal(param.data, frozen_before[name]), ( + f"Frozen weight changed: {name}" + ) + + def test_lora_adapters_receive_gradients(self): + """Verify LoRA A and B matrices get non-zero gradients.""" + from transformers import AutoModelForCausalLM + + from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe + + config = _create_tiny_qwen3_config() + input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda") + + model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() + patch_sonicmoe("qwen3_moe") + _interleave_gate_up_weights(model) + model = _apply_lora(model, ["gate_up_proj", "down_proj"]) + + out = model(input_ids, labels=input_ids) + out.loss.backward() + + lora_grads_found = 0 + for name, param in model.named_parameters(): + if "lora_" in name and param.requires_grad: + assert param.grad is not None, f"No gradient for LoRA param: {name}" + assert param.grad.abs().max() > 0, ( + f"Zero gradient for LoRA param: {name}" + ) + lora_grads_found += 1 + + assert lora_grads_found > 0, "No LoRA parameters found with gradients" + + def test_lora_adapters_update(self): + """Verify LoRA adapter weights change during training.""" + from transformers import AutoModelForCausalLM + + from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe + + config = _create_tiny_qwen3_config() + input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") + + model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() + patch_sonicmoe("qwen3_moe") + _interleave_gate_up_weights(model) + model = _apply_lora(model, ["gate_up_proj", "down_proj"]) + + # Snapshot LoRA weights + lora_before = {} + for name, param in model.named_parameters(): + if "lora_" in name and param.requires_grad: + lora_before[name] = param.data.clone() + + assert lora_before, "No LoRA parameters found" + + optimizer = torch.optim.AdamW( + [p for p in model.parameters() if p.requires_grad], lr=1e-3 + ) + for _ in range(5): + out = model(input_ids, labels=input_ids) + out.loss.backward() + optimizer.step() + optimizer.zero_grad() + + changed = sum( + 1 + for name, param in model.named_parameters() + if name in lora_before and not torch.equal(param.data, lora_before[name]) + ) + assert changed > 0, "No LoRA weights changed after 5 training 
steps" + + +class TestSonicMoEGateOnlyLoRA: + """Verify LoRA targeting only the gate (router) works with SonicMoE.""" + + def teardown_method(self): + _unpatch_sonicmoe() + + def test_gate_only_lora_loss_decreases(self): + """LoRA only on gate — expert path should have zero materialization overhead.""" + from transformers import AutoModelForCausalLM + + from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe + + config = _create_tiny_qwen3_config() + input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") + + model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() + patch_sonicmoe("qwen3_moe") + _interleave_gate_up_weights(model) + # Only target the gate (router), not expert projections + model = _apply_lora(model, ["gate"]) + + optimizer = torch.optim.AdamW( + [p for p in model.parameters() if p.requires_grad], lr=1e-3 + ) + losses = [] + + for step in range(20): + out = model(input_ids, labels=input_ids) + loss = out.loss + assert not math.isnan(loss.item()), f"NaN loss at step {step}" + assert not math.isinf(loss.item()), f"Inf loss at step {step}" + losses.append(loss.item()) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + assert losses[-1] < losses[0], ( + f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}" + ) + + +class TestSonicMoENoLoRARegression: + """Verify SonicMoE without LoRA still works after LoRA code was added.""" + + def teardown_method(self): + _unpatch_sonicmoe() + + def test_no_lora_loss_decreases(self): + """Full fine-tuning (no PEFT) with SonicMoE — regression test.""" + from transformers import AutoModelForCausalLM + + from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe + + config = _create_tiny_qwen3_config() + input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda") + + model = AutoModelForCausalLM.from_config(config).cuda().bfloat16() + patch_sonicmoe("qwen3_moe") + _interleave_gate_up_weights(model) + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + losses = [] + + for step in range(20): + out = model(input_ids, labels=input_ids) + loss = out.loss + assert not math.isnan(loss.item()), f"NaN loss at step {step}" + assert not math.isinf(loss.item()), f"Inf loss at step {step}" + losses.append(loss.item()) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + assert losses[-1] < losses[0], ( + f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}" + ) diff --git a/tests/integrations/test_routing_parity.py b/tests/integrations/test_routing_parity.py index cc668671c..885206809 100644 --- a/tests/integrations/test_routing_parity.py +++ b/tests/integrations/test_routing_parity.py @@ -93,7 +93,9 @@ class TestSoftmaxRoutingParity: from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( _softmax_topk_route, ) - from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( + softmax_topk_routing, + ) moe_block, gate, hidden, T, H, E, K = _make_softmax_block() @@ -120,7 +122,9 @@ class TestSoftmaxRoutingParity: def test_logits_not_returned_by_scattermoe(self): """ScatterMoE doesn't return logits; SonicMoE does — verify SonicMoE logits shape.""" - from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( + softmax_topk_routing, + ) moe_block, gate, hidden, T, H, E, K = _make_softmax_block() _, _, _, logits = 
softmax_topk_routing(hidden, moe_block) @@ -131,7 +135,9 @@ class TestSoftmaxRoutingParity: from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( _softmax_topk_route, ) - from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( + softmax_topk_routing, + ) moe_block, gate, hidden, T, H, E, K = _make_softmax_block() gate.norm_topk_prob = False @@ -152,7 +158,9 @@ class TestSoftmaxRoutingParity: from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( _softmax_topk_route, ) - from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( + softmax_topk_routing, + ) for E, K in [(2, 1), (8, 2), (16, 4), (32, 8)]: moe_block, gate, hidden, T, H, _, _ = _make_softmax_block(E=E, K=K) @@ -190,7 +198,9 @@ class TestSigmoidRoutingParity: from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( _sigmoid_topk_route, ) - from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( + sigmoid_topk_routing, + ) moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True @@ -226,7 +236,9 @@ class TestSigmoidRoutingParity: from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( _sigmoid_topk_route, ) - from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( + sigmoid_topk_routing, + ) moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( E=16, K=4, n_group=1, topk_group=1, bias_on_gate=True @@ -254,7 +266,9 @@ class TestSigmoidRoutingParity: from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( _sigmoid_topk_route, ) - from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( + sigmoid_topk_routing, + ) moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( E=16, K=4, n_group=1, bias_on_gate=False @@ -281,7 +295,9 @@ class TestSigmoidRoutingParity: from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( _sigmoid_topk_route, ) - from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( + sigmoid_topk_routing, + ) moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( n_group=1, bias_on_gate=True @@ -309,7 +325,9 @@ class TestSigmoidRoutingParity: from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( _sigmoid_topk_route, ) - from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( + sigmoid_topk_routing, + ) moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( n_group=1, bias_on_gate=True @@ -349,7 +367,7 @@ class TestSharedExpertParity: from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( _compute_shared_expert as scatter_compute, ) - from axolotl.integrations.kernels.sonicmoe.patch import ( + from axolotl.integrations.kernels.libs.sonicmoe.patch import ( _compute_shared_expert as sonic_compute, ) diff --git a/tests/integrations/test_sonicmoe.py b/tests/integrations/test_sonicmoe.py index 7d26d9d93..864abca36 100644 --- a/tests/integrations/test_sonicmoe.py +++ b/tests/integrations/test_sonicmoe.py @@ -6,11 +6,11 
@@ import pytest import torch from axolotl.integrations.kernels.args import KernelsArgs -from axolotl.integrations.kernels.sonicmoe.routing import ( +from axolotl.integrations.kernels.libs.sonicmoe.routing import ( sigmoid_topk_routing, softmax_topk_routing, ) -from axolotl.integrations.kernels.sonicmoe.weight_converter import ( +from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import ( ConcatenatedToInterleaved, InterleavedToConcatenated, register_sonicmoe_weight_converter, @@ -212,9 +212,40 @@ class TestWeightConverterRegistration: ) break - def test_register_unsupported_model_type_warns(self): - # A model type with no conversion mapping should warn but not raise - register_sonicmoe_weight_converter("nonexistent_model_type_xyz") + def test_register_adds_same_key_converter(self): + from transformers.conversion_mapping import get_checkpoint_conversion_mapping + + register_sonicmoe_weight_converter("qwen3_moe") + + modified = get_checkpoint_conversion_mapping("qwen3_moe") + # Should have a same-key converter for already-fused checkpoints + same_key = [ + c + for c in modified + if hasattr(c, "source_patterns") + and c.source_patterns == ["mlp.experts.gate_up_proj"] + and c.target_patterns == ["mlp.experts.gate_up_proj"] + ] + assert len(same_key) == 1 + assert isinstance(same_key[0].operations[0], ConcatenatedToInterleaved) + + def test_register_creates_mapping_when_none(self): + from transformers.conversion_mapping import get_checkpoint_conversion_mapping + + # qwen3_5_moe has no conversion mapping in transformers + register_sonicmoe_weight_converter("qwen3_5_moe") + + mapping = get_checkpoint_conversion_mapping("qwen3_5_moe") + assert mapping is not None + same_key = [ + c + for c in mapping + if hasattr(c, "source_patterns") + and c.source_patterns == ["mlp.experts.gate_up_proj"] + and c.target_patterns == ["mlp.experts.gate_up_proj"] + ] + assert len(same_key) == 1 + assert isinstance(same_key[0].operations[0], ConcatenatedToInterleaved) def _make_qwen_moe_block(T=8, H=16, E=4, K=2): @@ -462,7 +493,7 @@ class TestSoftmaxBiasTopkRouting: """Tests for Ernie 4.5 MoE routing (softmax_bias_topk_routing).""" def test_output_shapes(self): - from axolotl.integrations.kernels.sonicmoe.routing import ( + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( softmax_bias_topk_routing, ) @@ -479,7 +510,7 @@ class TestSoftmaxBiasTopkRouting: assert logits.shape == (T, E) def test_scores_are_float32(self): - from axolotl.integrations.kernels.sonicmoe.routing import ( + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( softmax_bias_topk_routing, ) @@ -490,7 +521,7 @@ class TestSoftmaxBiasTopkRouting: assert scores.dtype == torch.float32 def test_token_indices_sorted_ascending(self): - from axolotl.integrations.kernels.sonicmoe.routing import ( + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( softmax_bias_topk_routing, ) @@ -502,7 +533,7 @@ class TestSoftmaxBiasTopkRouting: assert (diffs >= 0).all() def test_expert_indices_in_range(self): - from axolotl.integrations.kernels.sonicmoe.routing import ( + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( softmax_bias_topk_routing, ) @@ -514,7 +545,7 @@ class TestSoftmaxBiasTopkRouting: assert (expert_idx < E).all() def test_renormalized_scores_sum_to_one(self): - from axolotl.integrations.kernels.sonicmoe.routing import ( + from axolotl.integrations.kernels.libs.sonicmoe.routing import ( softmax_bias_topk_routing, ) @@ -527,7 +558,7 @@ class TestSoftmaxBiasTopkRouting: def 
     def test_bias_affects_expert_selection(self):
         """Large positive bias on expert 0 should make it always selected."""
-        from axolotl.integrations.kernels.sonicmoe.routing import (
+        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
             softmax_bias_topk_routing,
         )
@@ -570,7 +601,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
     """Tests for DeepSeek V2 routing (softmax_group_limited_topk_routing)."""

     def test_output_shapes_group_limited(self):
-        from axolotl.integrations.kernels.sonicmoe.routing import (
+        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
             softmax_group_limited_topk_routing,
         )
@@ -589,7 +620,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
         assert logits.shape == (T, E)

     def test_output_shapes_greedy(self):
-        from axolotl.integrations.kernels.sonicmoe.routing import (
+        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
             softmax_group_limited_topk_routing,
         )
@@ -604,7 +635,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
         assert logits.shape == (T, E)

     def test_scores_are_float32(self):
-        from axolotl.integrations.kernels.sonicmoe.routing import (
+        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
             softmax_group_limited_topk_routing,
         )
@@ -615,7 +646,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
         assert scores.dtype == torch.float32

     def test_token_indices_sorted_ascending(self):
-        from axolotl.integrations.kernels.sonicmoe.routing import (
+        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
             softmax_group_limited_topk_routing,
         )
@@ -627,7 +658,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
         assert (diffs >= 0).all()

     def test_expert_indices_in_range(self):
-        from axolotl.integrations.kernels.sonicmoe.routing import (
+        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
             softmax_group_limited_topk_routing,
         )
@@ -639,7 +670,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
         assert (expert_idx < E).all()

     def test_scaling_factor_applied(self):
-        from axolotl.integrations.kernels.sonicmoe.routing import (
+        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
             softmax_group_limited_topk_routing,
         )
@@ -655,7 +686,7 @@ class TestSoftmaxGroupLimitedTopkRouting:

     def test_group_selection_restricts_experts(self):
         """With num_group=4 and topk_group=1, experts should come from selected groups."""
-        from axolotl.integrations.kernels.sonicmoe.routing import (
+        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
             softmax_group_limited_topk_routing,
         )
@@ -674,7 +705,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
         assert (groups == groups[0]).all()

     def test_unsupported_topk_method_raises(self):
-        from axolotl.integrations.kernels.sonicmoe.routing import (
+        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
             softmax_group_limited_topk_routing,
         )
@@ -706,7 +737,7 @@ class TestSoftmaxTopkWgRouting:
     """Tests for HunYuan V1 MoE routing (softmax_topk_wg_routing)."""

     def test_output_shapes(self):
-        from axolotl.integrations.kernels.sonicmoe.routing import (
+        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
             softmax_topk_wg_routing,
         )
@@ -723,7 +754,7 @@ class TestSoftmaxTopkWgRouting:
         assert logits.shape == (T, E)

     def test_scores_are_float32(self):
-        from axolotl.integrations.kernels.sonicmoe.routing import (
+        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
             softmax_topk_wg_routing,
         )
@@ -734,7 +765,7 @@ class TestSoftmaxTopkWgRouting:
         assert scores.dtype == torch.float32

     def test_token_indices_sorted_ascending(self):
-        from axolotl.integrations.kernels.sonicmoe.routing import (
+        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
             softmax_topk_wg_routing,
         )
@@ -746,7 +777,7 @@ class TestSoftmaxTopkWgRouting:
         assert (diffs >= 0).all()

     def test_expert_indices_in_range(self):
-        from axolotl.integrations.kernels.sonicmoe.routing import (
+        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
             softmax_topk_wg_routing,
         )
@@ -759,7 +790,7 @@ class TestSoftmaxTopkWgRouting:
         assert (expert_idx < E).all()

     def test_renormalized_scores_sum_to_one(self):
         """HunYuan V1 always renormalizes (no norm_topk_prob flag)."""
-        from axolotl.integrations.kernels.sonicmoe.routing import (
+        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
             softmax_topk_wg_routing,
         )
@@ -772,7 +803,7 @@ class TestSoftmaxTopkWgRouting:

     def test_uses_gate_wg_weight(self):
         """Verify that modifying gate.wg.weight changes the routing output."""
-        from axolotl.integrations.kernels.sonicmoe.routing import (
+        from axolotl.integrations.kernels.libs.sonicmoe.routing import (
             softmax_topk_wg_routing,
         )
diff --git a/tests/integrations/test_sonicmoe_gradients.py b/tests/integrations/test_sonicmoe_gradients.py
index e76bdd480..cb5ef7663 100644
--- a/tests/integrations/test_sonicmoe_gradients.py
+++ b/tests/integrations/test_sonicmoe_gradients.py
@@ -7,7 +7,7 @@ code path where routing happens in float32.

 import torch

-from axolotl.integrations.kernels.sonicmoe.routing import (
+from axolotl.integrations.kernels.libs.sonicmoe.routing import (
     sigmoid_topk_routing,
     softmax_topk_routing,
 )
diff --git a/tests/integrations/test_sonicmoe_lora.py b/tests/integrations/test_sonicmoe_lora.py
new file mode 100644
index 000000000..4b25843fe
--- /dev/null
+++ b/tests/integrations/test_sonicmoe_lora.py
@@ -0,0 +1,328 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) Axolotl AI
+# Licensed under the Apache License, Version 2.0
+
+"""Unit tests for SonicMoE LoRA support."""
+
+from unittest.mock import MagicMock
+
+import pytest
+import torch
+
+from axolotl.integrations.kernels.libs.sonicmoe.lora import (
+    MoELoRAMaterialize,
+    get_lora_params_from_wrapper,
+    has_lora,
+    materialize_expert_lora,
+    unwrap_experts_lora,
+    unwrap_gate_lora,
+)
+
+# =============================================================================
+# Helpers: mock PEFT modules
+# =============================================================================
+
+
+def _make_mock_lora_module(weight_A, weight_B, scaling_val, param_name=None):
+    """Create a mock PEFT-wrapped module with LoRA attributes."""
+    mock = MagicMock()
+
+    lora_A_linear = MagicMock()
+    lora_A_linear.weight = weight_A
+
+    lora_B_linear = MagicMock()
+    lora_B_linear.weight = weight_B
+
+    mock.lora_A = {"default": lora_A_linear}
+    mock.lora_B = {"default": lora_B_linear}
+    mock.scaling = {"default": scaling_val}
+    mock.active_adapters = ["default"]
+
+    if param_name is not None:
+        mock.parameter_name = param_name
+
+    return mock
+
+
+def _make_peft_gate(hidden_size, num_experts, rank, scaling=0.5):
+    """Create a mock PEFT-wrapped gate module."""
+    base_gate = MagicMock()
+    base_gate.weight = torch.randn(num_experts, hidden_size)
+    base_gate.top_k = 2
+    base_gate.norm_topk_prob = True
+
+    lora_A = torch.randn(rank, hidden_size)
+    lora_B = torch.randn(num_experts, rank)
+
+    wrapper = _make_mock_lora_module(lora_A, lora_B, scaling)
+    wrapper.base_layer = base_gate
+    return wrapper, base_gate
+
+
+def _make_peft_experts(
+    num_experts, gate_up_dim, down_dim, hidden_size, rank, scaling=0.5
+):
+    """Create a mock PEFT-wrapped experts chain.
+
+    Simulates: ParamWrapper(down_proj) -> ParamWrapper(gate_up_proj) -> Experts
+    """
+    base_experts = MagicMock()
+    base_experts.gate_up_proj = torch.randn(num_experts, gate_up_dim, hidden_size)
+    base_experts.down_proj = torch.randn(num_experts, hidden_size, down_dim)
+    # Remove base_layer and lora_A from base_experts so the chain walk stops
+    del base_experts.base_layer
+    del base_experts.lora_A
+
+    # gate_up_proj wrapper
+    gup_A = torch.randn(rank * num_experts, hidden_size)
+    gup_B = torch.randn(gate_up_dim, rank * num_experts)
+    gup_wrapper = _make_mock_lora_module(gup_A, gup_B, scaling, "gate_up_proj")
+    gup_wrapper.base_layer = base_experts
+
+    # down_proj wrapper (outermost)
+    down_A = torch.randn(rank * num_experts, down_dim)
+    down_B = torch.randn(hidden_size, rank * num_experts)
+    down_wrapper = _make_mock_lora_module(down_A, down_B, scaling, "down_proj")
+    down_wrapper.base_layer = gup_wrapper
+
+    return down_wrapper, base_experts, (gup_A, gup_B), (down_A, down_B)
+
+
+# =============================================================================
+# Tests: has_lora
+# =============================================================================
+
+
+class TestHasLora:
+    def test_plain_module(self):
+        m = MagicMock(spec=["weight"])
+        del m.base_layer
+        del m.lora_A
+        assert not has_lora(m)
+
+    def test_wrapped_module(self):
+        m = MagicMock()
+        m.base_layer = MagicMock()
+        m.lora_A = {"default": MagicMock()}
+        assert has_lora(m)
+
+
+# =============================================================================
+# Tests: get_lora_params_from_wrapper
+# =============================================================================
+
+
+class TestGetLoraParams:
+    def test_no_lora_attrs(self):
+        m = MagicMock(spec=["weight"])
+        del m.lora_A
+        del m.lora_B
+        assert get_lora_params_from_wrapper(m) == (None, None, None)
+
+    def test_extracts_params(self):
+        A = torch.randn(4, 8)
+        B = torch.randn(16, 4)
+        wrapper = _make_mock_lora_module(A, B, 0.5)
+        lora_A, lora_B, scaling = get_lora_params_from_wrapper(wrapper)
+        assert torch.equal(lora_A, A)
+        assert torch.equal(lora_B, B)
+        assert scaling == 0.5
+
+    def test_no_active_adapters(self):
+        wrapper = _make_mock_lora_module(torch.randn(4, 8), torch.randn(16, 4), 0.5)
+        wrapper.active_adapters = []
+        assert get_lora_params_from_wrapper(wrapper) == (None, None, None)
+
+
+# =============================================================================
+# Tests: unwrap_gate_lora
+# =============================================================================
+
+
+class TestUnwrapGateLora:
+    def test_plain_gate(self):
+        gate = MagicMock(spec=["weight", "top_k"])
+        del gate.base_layer
+        del gate.lora_A
+        gate.weight = torch.randn(8, 64)
+        base, weight, delta = unwrap_gate_lora(gate)
+        assert base is gate
+        assert torch.equal(weight, gate.weight)
+        assert delta is None
+
+    def test_wrapped_gate(self):
+        wrapper, base_gate = _make_peft_gate(
+            hidden_size=64, num_experts=8, rank=4, scaling=0.5
+        )
+        base, weight, delta = unwrap_gate_lora(wrapper)
+        assert base is base_gate
+        assert torch.equal(weight, base_gate.weight)
+        assert delta is not None
+        assert delta.shape == base_gate.weight.shape
+
+        # Verify delta = scaling * B @ A
+        lora_A = wrapper.lora_A["default"].weight
+        lora_B = wrapper.lora_B["default"].weight
+        expected = 0.5 * (lora_B @ lora_A)
+        assert torch.allclose(delta, expected)
+
+
+# =============================================================================
+# Tests: unwrap_experts_lora
+# =============================================================================
+
+
+class TestUnwrapExpertsLora:
+    def test_plain_experts(self):
+        experts = MagicMock(spec=["gate_up_proj", "down_proj"])
+        del experts.base_layer
+        del experts.lora_A
+        base, lora_dict = unwrap_experts_lora(experts)
+        assert base is experts
+        assert lora_dict == {}
+
+    def test_wrapped_experts(self):
+        E, I2, I, H, r = 4, 256, 128, 64, 8  # noqa: E741
+        wrapper, base_experts, (gup_A, gup_B), (down_A, down_B) = _make_peft_experts(
+            E, I2, I, H, r, scaling=0.25
+        )
+        base, lora_dict = unwrap_experts_lora(wrapper)
+        assert base is base_experts
+        assert "gate_up_proj" in lora_dict
+        assert "down_proj" in lora_dict
+
+        gup_lA, gup_lB, gup_s = lora_dict["gate_up_proj"]
+        assert torch.equal(gup_lA, gup_A)
+        assert torch.equal(gup_lB, gup_B)
+        assert gup_s == 0.25
+
+        down_lA, down_lB, down_s = lora_dict["down_proj"]
+        assert torch.equal(down_lA, down_A)
+        assert torch.equal(down_lB, down_B)
+        assert down_s == 0.25
+
+    def test_partial_lora(self):
+        """Only gate_up_proj has LoRA, down_proj does not."""
+        base_experts = MagicMock(spec=["gate_up_proj", "down_proj"])
+        del base_experts.base_layer
+        del base_experts.lora_A
+
+        gup_A = torch.randn(16, 64)
+        gup_B = torch.randn(256, 16)
+        gup_wrapper = _make_mock_lora_module(gup_A, gup_B, 0.5, "gate_up_proj")
+        gup_wrapper.base_layer = base_experts
+
+        base, lora_dict = unwrap_experts_lora(gup_wrapper)
+        assert base is base_experts
+        assert "gate_up_proj" in lora_dict
+        assert "down_proj" not in lora_dict
+
+
+# =============================================================================
+# Tests: MoELoRAMaterialize
+# =============================================================================
+
+
+class TestMoELoRAMaterialize:
+    @pytest.fixture()
+    def setup(self):
+        E, dim1, dim2, r = 4, 32, 16, 4
+        scaling = 0.5
+        W = torch.randn(E, dim1, dim2, dtype=torch.float64, requires_grad=False)
+        A = torch.randn(r * E, dim2, dtype=torch.float64, requires_grad=True)
+        B = torch.randn(dim1, r * E, dtype=torch.float64, requires_grad=True)
+        return W, A, B, scaling, E, r
+
+    def test_forward_shape(self, setup):
+        W, A, B, scaling, E, r = setup
+        W_eff = MoELoRAMaterialize.apply(W, A, B, scaling)
+        assert W_eff.shape == W.shape
+
+    def test_forward_correctness(self, setup):
+        W, A, B, scaling, E, r = setup
+        W_eff = MoELoRAMaterialize.apply(W, A, B, scaling)
+
+        # Manual per-expert computation.
+        # lora_A is expert-major: [r*E, dim2] -> rows [e*r:(e+1)*r] = expert e
+        # lora_B is rank-major: [dim1, r*E] -> reshape [dim1, r, E], slice [:, :, e]
+        _, dim1, dim2 = W.shape
+        expected = W.clone()
+        B_3d = B.reshape(dim1, r, E)
+        for e in range(E):
+            A_e = A[e * r : (e + 1) * r, :]  # [r, dim2]
+            B_e = B_3d[:, :, e]  # [dim1, r]
+            expected[e] += scaling * (B_e @ A_e)
+
+        assert torch.allclose(W_eff, expected, atol=1e-10)
+
+    def test_backward_gradcheck(self, setup):
+        W, A, B, scaling, E, r = setup
+        # gradcheck requires float64
+        assert torch.autograd.gradcheck(
+            lambda a, b: MoELoRAMaterialize.apply(W, a, b, scaling),
+            (A, B),
+            eps=1e-6,
+            atol=1e-4,
+        )
+
+    def test_no_grad_for_base_weight(self, setup):
+        W, A, B, scaling, E, r = setup
+        W.requires_grad_(True)
+        W_eff = MoELoRAMaterialize.apply(W, A, B, scaling)
+        loss = W_eff.sum()
+        loss.backward()
+        assert W.grad is None
+        assert A.grad is not None
+        assert B.grad is not None
+
+    def test_scaling_zero(self, setup):
+        W, A, B, _, E, r = setup
+        W_eff = MoELoRAMaterialize.apply(W, A, B, 0.0)
+        assert torch.allclose(W_eff, W)
+
+    def test_gate_up_proj_shapes(self):
+        """Test with realistic gate_up_proj shapes [E, 2*I, H]."""
+        E, I2, H, r = 8, 512, 256, 16
+        W = torch.randn(E, I2, H, dtype=torch.float64)
+        A = torch.randn(r * E, H, dtype=torch.float64, requires_grad=True)
+        B = torch.randn(I2, r * E, dtype=torch.float64, requires_grad=True)
+        W_eff = MoELoRAMaterialize.apply(W, A, B, 1.0)
+        assert W_eff.shape == (E, I2, H)
+        loss = W_eff.sum()
+        loss.backward()
+        assert A.grad.shape == A.shape
+        assert B.grad.shape == B.shape
+
+    def test_down_proj_shapes(self):
+        """Test with realistic down_proj shapes [E, H, I]."""
+        E, H, I, r = 8, 256, 512, 16  # noqa: E741
+        W = torch.randn(E, H, I, dtype=torch.float64)
+        A = torch.randn(r * E, I, dtype=torch.float64, requires_grad=True)
+        B = torch.randn(H, r * E, dtype=torch.float64, requires_grad=True)
+        W_eff = MoELoRAMaterialize.apply(W, A, B, 1.0)
+        assert W_eff.shape == (E, H, I)
+        loss = W_eff.sum()
+        loss.backward()
+        assert A.grad.shape == A.shape
+        assert B.grad.shape == B.shape
+
+
+# =============================================================================
+# Tests: materialize_expert_lora
+# =============================================================================
+
+
+class TestMaterializeExpertLora:
+    def test_none_passthrough(self):
+        W = torch.randn(4, 32, 16)
+        result = materialize_expert_lora(W, None)
+        assert result is W
+
+    def test_with_lora(self):
+        E, dim1, dim2, r = 4, 32, 16, 4
+        W = torch.randn(E, dim1, dim2)
+        A = torch.randn(r * E, dim2, requires_grad=True)
+        B = torch.randn(dim1, r * E, requires_grad=True)
+        result = materialize_expert_lora(W, (A, B, 0.5))
+        assert result.shape == W.shape
+        assert not torch.equal(result, W)
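For reviewers skimming the new tests, the layout convention that `TestMoELoRAMaterialize.test_forward_correctness` encodes can be written as a small standalone sketch: `lora_A` is expert-major (`[r*E, dim2]`, rows `e*r:(e+1)*r` belong to expert `e`) and `lora_B` is viewed as `[dim1, r, E]`, so the effective weight is `W[e] + scaling * B_e @ A_e` per expert. The snippet below is illustrative only and is not part of this patch; `reference_materialize` is a hypothetical helper name.

```python
import torch


def reference_materialize(W, lora_A, lora_B, scaling):
    """Per-expert LoRA materialization: W[e] + scaling * (B_e @ A_e).

    W:      [E, dim1, dim2] stacked expert weights
    lora_A: [r*E, dim2], expert-major rows
    lora_B: [dim1, r*E], viewed as [dim1, r, E]
    """
    E, dim1, dim2 = W.shape
    r = lora_A.shape[0] // E
    B_3d = lora_B.reshape(dim1, r, E)
    out = W.clone()
    for e in range(E):
        A_e = lora_A[e * r : (e + 1) * r, :]  # [r, dim2]
        B_e = B_3d[:, :, e]  # [dim1, r]
        out[e] = out[e] + scaling * (B_e @ A_e)
    return out


if __name__ == "__main__":
    E, dim1, dim2, r = 4, 32, 16, 4
    W = torch.randn(E, dim1, dim2)
    A = torch.randn(r * E, dim2)
    B = torch.randn(dim1, r * E)
    W_eff = reference_materialize(W, A, B, scaling=0.5)
    assert W_eff.shape == W.shape
```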