From 8f3fb517b3abc0d42523f7c7e7c752bf84a07676 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 16 Mar 2026 23:47:40 -0400
Subject: [PATCH] consolidate behavior of routing in scattermoe kernels (#3475)

* consolidate behavior of routing in scattermoe kernels
* collect telemetry on best chosen autotuned kernel
* properly collect data
* Fix property name and get smem too
* handle issues raised by coderabbit
* add tests for parity before refactoring
---
 .../integrations/kernels/autotune_callback.py |  120 +++++
 .../kernels/autotune_collector.py             |  114 +++++
 .../kernels/libs/scattermoe_lora/layers.py    |  199 ++++++--
 src/axolotl/integrations/kernels/plugin.py    |   10 +
 .../test_scattermoe_lora_kernels.py           |  347 +++++++++++++
 tests/integrations/test_routing_parity.py     |  474 ++++++++++++++++++
 .../test_scattermoe_autotune_telemetry.py     |  367 ++++++++++++++
 tests/integrations/test_scattermoe_lora.py    |  392 ++++++++++++++-
 8 files changed, 1988 insertions(+), 35 deletions(-)
 create mode 100644 src/axolotl/integrations/kernels/autotune_callback.py
 create mode 100644 src/axolotl/integrations/kernels/autotune_collector.py
 create mode 100644 tests/integrations/test_routing_parity.py
 create mode 100644 tests/integrations/test_scattermoe_autotune_telemetry.py

diff --git a/src/axolotl/integrations/kernels/autotune_callback.py b/src/axolotl/integrations/kernels/autotune_callback.py
new file mode 100644
index 000000000..aa4cbbab1
--- /dev/null
+++ b/src/axolotl/integrations/kernels/autotune_callback.py
@@ -0,0 +1,120 @@
+"""Trainer callback for reporting Triton autotune results from scattermoe-lora kernels."""
+
+import logging
+
+import torch
+from transformers import (
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+    TrainingArguments,
+)
+
+LOG = logging.getLogger(__name__)
+
+# Give up looking for autotune data after this many training steps.
+_MAX_POLL_STEP = 5
+
+
+def _get_gpu_info() -> dict:
+    """Return basic GPU identification for the current device."""
+    if not torch.cuda.is_available():
+        return {}
+    try:
+        idx = torch.cuda.current_device()
+        props = torch.cuda.get_device_properties(idx)
+        return {
+            "gpu_name": props.name,
+            "gpu_compute_capability": f"{props.major}.{props.minor}",
+            "gpu_memory_bytes": props.total_memory,
+        }
+    except Exception:  # pylint: disable=broad-exception-caught
+        return {}
+
+
+def _get_smem_capacity() -> dict:
+    """Return shared memory capacity from the runtime lora_ops module."""
+    try:
+        from axolotl.integrations.kernels.autotune_collector import (
+            _find_lora_ops_module,
+        )
+
+        lora_ops = _find_lora_ops_module()
+        if lora_ops is None:
+            return {}
+        fn = getattr(lora_ops, "_get_smem_capacity", None)
+        if fn is None:
+            return {}
+        return {"smem_capacity_bytes": fn()}
+    except Exception:  # pylint: disable=broad-exception-caught
+        return {}
+
+
+class AutotuneReportCallback(TrainerCallback):
+    """Reports Triton kernel autotune selections via telemetry.
+
+    Fires **once** after the first training step completes (step 1), at
+    which point the forward and backward passes have both run and the
+    autotuned kernels have populated their caches. If for some reason
+    the caches are still empty (e.g. the kernel was never invoked), the
+    callback retries on subsequent steps up to ``_MAX_POLL_STEP`` and
+    then stops polling.
+
+    After reporting (or giving up), every subsequent ``on_step_end``
+    call short-circuits on the ``_reported`` flag — zero hot-path cost.
+ """ + + def __init__(self): + self._reported = False + + # pylint: disable=unused-argument + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if self._reported: + return + + # Lazy import — Triton / scattermoe kernels may not be installed. + from axolotl.integrations.kernels.autotune_collector import ( + collect_autotune_configs, + ) + + configs = collect_autotune_configs() + + if not configs: + if state.global_step >= _MAX_POLL_STEP: + LOG.debug( + "No autotune data found after %d steps; giving up.", + state.global_step, + ) + self._reported = True + return + + self._reported = True + + from axolotl.telemetry.manager import TelemetryManager + + telemetry_manager = TelemetryManager.get_instance() + if not telemetry_manager.enabled: + return + + properties = { + "kernel_count": len(configs), + "kernels": configs, + } + properties.update(_get_gpu_info()) + properties.update(_get_smem_capacity()) + + telemetry_manager.send_event( + event_type="scattermoe-autotune", + properties=properties, + ) + + LOG.info( + "Reported %d scattermoe kernel autotune config(s) to telemetry.", + len(configs), + ) diff --git a/src/axolotl/integrations/kernels/autotune_collector.py b/src/axolotl/integrations/kernels/autotune_collector.py new file mode 100644 index 000000000..ef4111dcf --- /dev/null +++ b/src/axolotl/integrations/kernels/autotune_collector.py @@ -0,0 +1,114 @@ +"""Collect Triton autotune results from scattermoe-lora kernels. + +This module reads the ``.cache`` attribute from Triton ``@triton.autotune`` +decorated kernel objects and returns structured dicts describing the selected +configurations. It has **no** telemetry dependency — callers decide what to +do with the data. +""" + +import logging +import sys +from types import ModuleType +from typing import Any + +LOG = logging.getLogger(__name__) + +# (human-readable name, attribute on the lora_ops module) +_KERNEL_REGISTRY: list[tuple[str, str]] = [ + ("scatter2scatter_lora_fwd", "_scatter2scatter_lora"), + ("scatter2scatter_lora_dX", "_scatter2scatter_lora_dX"), + ("group_bwd_lora", "_group_bwd_lora"), + ("group_bwd_lora_fused", "_group_bwd_lora_fused"), +] + +# The autotune key declared on every kernel: key=["M", "N", "K"] +_KEY_NAMES: list[str] = ["M", "N", "K"] + + +def _parse_key_tuple(key_tuple: tuple) -> dict[str, Any]: + """Turn the autotune cache key tuple into a labelled dict. + + Triton builds the cache key from the values of the declared ``key`` + args (``M``, ``N``, ``K``) followed by dtype signature elements. + We label the first three and store the rest under ``_extra``. + """ + result: dict[str, Any] = {} + for i, name in enumerate(_KEY_NAMES): + if i < len(key_tuple): + result[name] = key_tuple[i] + if len(key_tuple) > len(_KEY_NAMES): + result["_extra"] = [str(v) for v in key_tuple[len(_KEY_NAMES) :]] + return result + + +def _find_lora_ops_module() -> ModuleType | None: + """Locate the *runtime* ``lora_ops`` module in ``sys.modules``. + + The HF ``kernels`` package loads ``scattermoe_lora`` via + ``import_from_path`` which registers it in ``sys.modules`` under a + hash-suffixed name (e.g. ``scattermoe_lora_a1b2c3d4``). A normal + import (``from axolotl.integrations.kernels...``) would create a + *separate* module instance whose kernel objects have empty + ``.cache`` dicts because autotuning ran on the runtime copy. 
+ + We search ``sys.modules`` for any module whose name contains + ``lora_ops`` and that has the ``_scatter2scatter_lora`` kernel + attribute — that is the runtime copy with populated caches. + """ + for name, module in sys.modules.items(): + if ( + module is not None + and "lora_ops" in name + and hasattr(module, "_scatter2scatter_lora") + ): + return module + return None + + +def collect_autotune_configs() -> list[dict[str, Any]]: + """Read autotune caches from the four scattermoe-lora kernels. + + Returns a (possibly empty) list of dicts, each containing: + + * ``kernel`` – human-readable kernel name + * ``key`` – dict with the ``M``/``N``/``K`` problem dimensions + * ``config`` – dict with the selected tile sizes, ``num_warps``, + and ``num_stages`` + + Returns ``[]`` if the kernel module cannot be found or if no + autotune cache entries exist yet. + """ + lora_ops = _find_lora_ops_module() + if lora_ops is None: + LOG.debug( + "lora_ops module not found in sys.modules; skipping autotune collection" + ) + return [] + + results: list[dict[str, Any]] = [] + + for friendly_name, attr_name in _KERNEL_REGISTRY: + kernel_fn = getattr(lora_ops, attr_name, None) + if kernel_fn is None: + continue + + cache = getattr(kernel_fn, "cache", None) + if not cache: + continue + + for key_tuple, config in cache.items(): + config_dict = dict(config.kwargs) + config_dict["num_warps"] = config.num_warps + config_dict["num_stages"] = config.num_stages + if getattr(config, "num_ctas", None) is not None: + config_dict["num_ctas"] = config.num_ctas + + results.append( + { + "kernel": friendly_name, + "key": _parse_key_tuple(key_tuple), + "config": config_dict, + } + ) + + return results diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py index a42577483..5125e8801 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py @@ -220,6 +220,158 @@ def _unwrap_experts_lora(experts_module): return base_experts, gup_lora, down_lora +# ============================================================================= +# Routing helpers +# ============================================================================= + + +def _softmax_topk_route( + moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta +): + """Softmax→topk routing (Qwen, OLMoE, Mixtral, MiniMax). + + Returns: + (routing_weights [T, K], selected_experts [T, K], top_k, num_experts) + """ + router_logits = F.linear(hidden_states, gate_weight) + if gate_lora_delta is not None: + router_logits = router_logits + F.linear(hidden_states, gate_lora_delta) + routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32) + + top_k = base_gate.top_k + num_experts = base_gate.num_experts + routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + if getattr(base_gate, "norm_topk_prob", True): + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + + return routing_weights, selected_experts, top_k, num_experts + + +def _sigmoid_topk_route( + moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta +): + """Sigmoid→topk routing (GLM, DeepSeek V3, MiniMax M2). 
+ + Supports: + - ``e_score_correction_bias`` on gate or moe_block + - Group-based expert selection when ``n_group > 1`` + - ``routed_scaling_factor`` applied to final weights + - Final weights gathered from original sigmoid probs (not bias-corrected) + + Returns: + (routing_weights [T, K], selected_experts [T, K], top_k, num_experts) + """ + router_logits = F.linear(hidden_states.float(), gate_weight.float()) + if gate_lora_delta is not None: + router_logits = router_logits + F.linear( + hidden_states.float(), gate_lora_delta.float() + ) + router_probs = router_logits.sigmoid() # [T, E] + + top_k = getattr(moe_block, "top_k", getattr(base_gate, "top_k", None)) + num_experts = getattr(moe_block, "n_routed_experts", gate_weight.shape[0]) + + # Bias-corrected scores for expert selection (not used for final weights). + # glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 on the block. + e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None) + if e_score_correction_bias is None: + e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None) + if e_score_correction_bias is not None: + scores_for_choice = router_probs + e_score_correction_bias + else: + scores_for_choice = router_probs + + # Group-based selection: pick top groups, mask the rest + n_group = getattr(moe_block, "n_group", 1) + if n_group > 1: + group_scores = ( + scores_for_choice.view(-1, n_group, num_experts // n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) # [T, n_group] + topk_group = getattr(moe_block, "topk_group", n_group) + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, n_group, num_experts // n_group) + .reshape(-1, num_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + + # Final topk from (possibly masked) scores + topk_indices = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False)[1] + + # Gather weights from original sigmoid scores (not bias-corrected) + topk_weights = router_probs.gather(1, topk_indices) + + # Optional renormalization + scaling + if getattr(moe_block, "norm_topk_prob", True): + topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20) + routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0) + topk_weights = topk_weights * routed_scaling_factor + + return topk_weights, topk_indices, top_k, num_experts + + +def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta): + """Dispatch to the correct routing strategy based on block attributes. + + Detects sigmoid routing by the presence of ``e_score_correction_bias`` + on either the gate or the moe_block. + """ + has_sigmoid = ( + getattr(base_gate, "e_score_correction_bias", None) is not None + or getattr(moe_block, "e_score_correction_bias", None) is not None + ) + if has_sigmoid: + return _sigmoid_topk_route( + moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta + ) + return _softmax_topk_route( + moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta + ) + + +# ============================================================================= +# Shared expert helpers +# ============================================================================= + + +def _compute_shared_expert(moe_block, hidden_states_flat): + """Compute shared expert output if the block has one. 
+ + Handles singular (qwen2_moe: ``shared_expert``), plural + (glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP + (hunyuan_v1_moe: ``shared_mlp``) attribute names. + + peft wraps individual linear layers inside the shared expert with + standard LoRA — calling forward() handles this transparently. + """ + shared_expert = ( + getattr(moe_block, "shared_expert", None) + or getattr(moe_block, "shared_experts", None) + or getattr(moe_block, "shared_mlp", None) + ) + if shared_expert is None: + return None + + shared_expert_output = shared_expert(hidden_states_flat) + + # Optional sigmoid gate (Qwen2MoE pattern). + # shared_expert_gate may also be peft-wrapped (standard LoRA + # on nn.Linear), its forward() applies LoRA automatically. + shared_expert_gate = getattr(moe_block, "shared_expert_gate", None) + if shared_expert_gate is not None: + shared_expert_output = ( + F.sigmoid(shared_expert_gate(hidden_states_flat)) * shared_expert_output + ) + + return shared_expert_output + + # ============================================================================= # Layer classes # ============================================================================= @@ -281,16 +433,18 @@ class ScatterMoEGatedMLP(nn.Module): class HFScatterMoEGatedMLP(nn.Module): """ - ScatterMoE-accelerated forward pass for HF MoEs (OLMoE / Qwen2MoE). + ScatterMoE-accelerated forward pass for HF MoEs. Used as a kernel layer via the HF ``kernels`` library. The ``forward`` - method replaces the original ``OlmoeSparseMoeBlock.forward``. + method replaces the original SparseMoeBlock.forward. - Supports both full-parameter training and LoRA fine-tuning: + Supports: - * **Full-param**: uses ``parallel_linear`` (base ScatterMoE kernel) - * **LoRA**: detects peft ``ParamWrapper`` on ``self.experts``, extracts - adapter weights, and uses ``parallel_linear_lora`` (fused kernel) + * **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax + * **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2 + * **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE) + * **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``, + extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel) """ @staticmethod @@ -302,7 +456,7 @@ class HFScatterMoEGatedMLP(nn.Module): self: The MoeSparseMoeBlock module containing: - self.gate: Router (or peft ParamWrapper wrapping it) - self.experts: Experts module (or peft ParamWrapper chain) - - self.shared_expert: Optional shared expert (e.g. Qwen2MoE) + - self.shared_expert(s): Optional shared expert - self.shared_expert_gate: Optional shared expert gate layer_input: Input tensor [batch_size, seq_len, hidden_size] @@ -313,38 +467,17 @@ class HFScatterMoEGatedMLP(nn.Module): hidden_states_flat = layer_input.view(-1, hidden_dim) # ==================================================================== - # Shared Expert (if present, e.g. Qwen2MoE) + # Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3) # ==================================================================== - # peft wraps individual linear layers inside shared_expert with - # standard LoRA — calling forward() handles this transparently. - if hasattr(self, "shared_expert") and self.shared_expert is not None: - shared_expert_output = self.shared_expert(hidden_states_flat) - # shared_expert_gate may also be peft-wrapped (standard LoRA - # on nn.Linear), its forward() applies LoRA automatically. 
- shared_expert_gate_output = F.sigmoid( - self.shared_expert_gate(hidden_states_flat) - ) - shared_expert_output = shared_expert_output * shared_expert_gate_output - else: - shared_expert_output = None + shared_expert_output = _compute_shared_expert(self, hidden_states_flat) # ==================================================================== # Router Computation (with optional gate LoRA) # ==================================================================== base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate) - router_logits = F.linear(hidden_states_flat, gate_weight) - if gate_lora_delta is not None: - router_logits = router_logits + F.linear( - hidden_states_flat, gate_lora_delta - ) - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - - top_k = base_gate.top_k - num_experts = base_gate.num_experts - routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) - - if base_gate.norm_topk_prob: - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights, selected_experts, top_k, num_experts = _route( + self, base_gate, hidden_states_flat, gate_weight, gate_lora_delta + ) routing_weights = routing_weights.to(hidden_states_flat.dtype) sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count( diff --git a/src/axolotl/integrations/kernels/plugin.py b/src/axolotl/integrations/kernels/plugin.py index f085e481c..351db5ef2 100644 --- a/src/axolotl/integrations/kernels/plugin.py +++ b/src/axolotl/integrations/kernels/plugin.py @@ -110,6 +110,16 @@ class KernelsPlugin(BasePlugin): } ) + def add_callbacks_pre_trainer(self, cfg, model): + callbacks = [] + if cfg.use_scattermoe: + from axolotl.integrations.kernels.autotune_callback import ( + AutotuneReportCallback, + ) + + callbacks.append(AutotuneReportCallback()) + return callbacks + def _kernelize_model(self, model_type: str): from kernels import replace_kernel_forward_from_hub diff --git a/tests/e2e/integrations/test_scattermoe_lora_kernels.py b/tests/e2e/integrations/test_scattermoe_lora_kernels.py index d11272c8f..6f7f65b80 100644 --- a/tests/e2e/integrations/test_scattermoe_lora_kernels.py +++ b/tests/e2e/integrations/test_scattermoe_lora_kernels.py @@ -12,6 +12,7 @@ Tests verify correctness of: 3. Frozen weights: expert weight gradients are correctly skipped 4. Various configurations: top-k, grouped_in/out, with/without bias 5. Numerical stability: bf16/fp16 outputs within tolerance of fp32 reference +6. HFScatterMoEGatedMLP with sigmoid routing (GLM/DeepSeek/MiniMax M2) Test strategy: - Reference implementation uses pure PyTorch ops (no Triton) @@ -19,6 +20,8 @@ Test strategy: - Tolerances account for tf32 accumulation in Triton kernels """ +from types import SimpleNamespace + import pytest import torch @@ -1476,3 +1479,347 @@ class TestCombinedOptimizations: torch.testing.assert_close(X1.grad, X2.grad, atol=5e-2, rtol=5e-2) torch.testing.assert_close(A1.grad, A2.grad, atol=5e-2, rtol=5e-2) torch.testing.assert_close(B1.grad, B2.grad, atol=5e-2, rtol=5e-2) + + +# ============================================================================= +# Test: HFScatterMoEGatedMLP with Sigmoid Routing +# ============================================================================= + + +def _reference_moe_forward( + hidden_states, + gate_weight, + gate_up_proj, + down_proj, + act_fn, + routing_weights, + selected_experts, + num_experts, +): + """Pure PyTorch reference for a full MoE forward pass. 
+ + Args: + hidden_states: [T, H] + gate_weight: [E, H] + gate_up_proj: [E, 2*FF, H] + down_proj: [E, H, FF] + act_fn: activation function (e.g. torch.nn.SiLU()) + routing_weights: [T, K] routing weights + selected_experts: [T, K] expert indices + num_experts: int + + Returns: + output: [T, H] + """ + T, H = hidden_states.shape + K = selected_experts.shape[1] + output = torch.zeros(T, H, device=hidden_states.device, dtype=hidden_states.dtype) + + for t in range(T): + for j in range(K): + e = selected_experts[t, j].item() + w = routing_weights[t, j].item() + + # gate_up projection + gup = hidden_states[t] @ gate_up_proj[e].T # [2*I] + I_dim = gup.shape[0] // 2 + gates = gup[:I_dim] + up = gup[I_dim:] + + # activation + h = act_fn(gates) * up + + # down projection + out = h @ down_proj[e].T # [H] + + output[t] += w * out + + return output + + +def _make_mock_sigmoid_moe_block( + T=16, H=64, FF=32, E=8, K=2, n_group=2, topk_group=1, bias_on_gate=True +): + """Create a mock MoE block with sigmoid routing for GPU testing.""" + gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02 + down_proj = torch.randn(E, H, FF, device="cuda") * 0.02 + act_fn = torch.nn.SiLU() + + experts = SimpleNamespace( + gate_up_proj=gate_up_proj, + down_proj=down_proj, + act_fn=act_fn, + num_experts=E, + ) + + if bias_on_gate: + gate = SimpleNamespace( + weight=torch.randn(E, H, device="cuda") * 0.1, + e_score_correction_bias=torch.zeros(E, device="cuda"), + ) + moe_block = SimpleNamespace( + gate=gate, + experts=experts, + top_k=K, + n_routed_experts=E, + n_group=n_group, + topk_group=topk_group, + norm_topk_prob=True, + routed_scaling_factor=1.0, + ) + else: + # minimax_m2 style + gate = SimpleNamespace( + weight=torch.randn(E, H, device="cuda") * 0.1, + top_k=K, + ) + moe_block = SimpleNamespace( + gate=gate, + experts=experts, + top_k=K, + e_score_correction_bias=torch.zeros(E, device="cuda"), + ) + + return moe_block, T, H, FF, E, K + + +class TestHFScatterMoESigmoidRouting: + """Test HFScatterMoEGatedMLP forward with sigmoid routing on GPU.""" + + def test_forward_matches_reference_bias_on_gate(self): + """Forward pass with sigmoid routing (bias on gate) matches reference.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + HFScatterMoEGatedMLP, + _sigmoid_topk_route, + ) + + moe_block, T, H, FF, E, K = _make_mock_sigmoid_moe_block( + T=16, H=64, FF=32, E=8, K=2, n_group=2, topk_group=1, bias_on_gate=True + ) + + hidden = torch.randn(1, T, H, device="cuda") + + # Get routing for reference + gate = moe_block.gate + hidden_flat = hidden.view(-1, H) + routing_weights, selected_experts, _, _ = _sigmoid_topk_route( + moe_block, gate, hidden_flat, gate.weight, None + ) + + # Reference output + ref_output = _reference_moe_forward( + hidden_flat, + gate.weight, + moe_block.experts.gate_up_proj, + moe_block.experts.down_proj, + moe_block.experts.act_fn, + routing_weights, + selected_experts, + E, + ) + + # Kernel output + kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden) + kernel_output_flat = kernel_output.view(-1, H) + + torch.testing.assert_close( + kernel_output_flat.float(), + ref_output.float(), + atol=5e-2, + rtol=5e-2, + ) + + def test_forward_matches_reference_bias_on_block(self): + """Forward pass with sigmoid routing (minimax_m2 style, bias on block).""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + HFScatterMoEGatedMLP, + _sigmoid_topk_route, + ) + + moe_block, T, H, FF, E, K = _make_mock_sigmoid_moe_block( + T=16, H=64, FF=32, E=8, 
K=2, n_group=1, bias_on_gate=False + ) + + hidden = torch.randn(1, T, H, device="cuda") + hidden_flat = hidden.view(-1, H) + + gate = moe_block.gate + routing_weights, selected_experts, _, _ = _sigmoid_topk_route( + moe_block, gate, hidden_flat, gate.weight, None + ) + + ref_output = _reference_moe_forward( + hidden_flat, + gate.weight, + moe_block.experts.gate_up_proj, + moe_block.experts.down_proj, + moe_block.experts.act_fn, + routing_weights, + selected_experts, + E, + ) + + kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden) + kernel_output_flat = kernel_output.view(-1, H) + + torch.testing.assert_close( + kernel_output_flat.float(), + ref_output.float(), + atol=5e-2, + rtol=5e-2, + ) + + def test_softmax_routing_still_works(self): + """Verify softmax routing (Qwen/OLMoE) is not broken.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + HFScatterMoEGatedMLP, + _softmax_topk_route, + ) + + T, H, FF, E, K = 16, 64, 32, 4, 2 + gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02 + down_proj = torch.randn(E, H, FF, device="cuda") * 0.02 + act_fn = torch.nn.SiLU() + + experts = SimpleNamespace( + gate_up_proj=gate_up_proj, + down_proj=down_proj, + act_fn=act_fn, + num_experts=E, + ) + gate = SimpleNamespace( + weight=torch.randn(E, H, device="cuda") * 0.1, + top_k=K, + num_experts=E, + norm_topk_prob=True, + ) + moe_block = SimpleNamespace(gate=gate, experts=experts) + + hidden = torch.randn(1, T, H, device="cuda") + hidden_flat = hidden.view(-1, H) + + routing_weights, selected_experts, _, _ = _softmax_topk_route( + moe_block, gate, hidden_flat, gate.weight, None + ) + + ref_output = _reference_moe_forward( + hidden_flat, + gate.weight, + gate_up_proj, + down_proj, + act_fn, + routing_weights, + selected_experts, + E, + ) + + kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden) + kernel_output_flat = kernel_output.view(-1, H) + + torch.testing.assert_close( + kernel_output_flat.float(), + ref_output.float(), + atol=5e-2, + rtol=5e-2, + ) + + +class TestHFScatterMoESigmoidWithSharedExperts: + """Test HFScatterMoEGatedMLP with sigmoid routing + shared experts.""" + + def test_shared_experts_plural(self): + """DeepSeek V3 style: shared_experts attribute (plural).""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + HFScatterMoEGatedMLP, + ) + + T, H, FF, E, K = 8, 64, 32, 8, 2 + gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02 + down_proj = torch.randn(E, H, FF, device="cuda") * 0.02 + act_fn = torch.nn.SiLU() + + experts = SimpleNamespace( + gate_up_proj=gate_up_proj, + down_proj=down_proj, + act_fn=act_fn, + num_experts=E, + ) + + # Shared expert as a simple linear for testing + shared_W = torch.randn(H, H, device="cuda") * 0.01 + shared_experts_fn = lambda x: x @ shared_W.T # noqa: E731 + + gate = SimpleNamespace( + weight=torch.randn(E, H, device="cuda") * 0.1, + e_score_correction_bias=torch.zeros(E, device="cuda"), + ) + moe_block = SimpleNamespace( + gate=gate, + experts=experts, + shared_experts=shared_experts_fn, + top_k=K, + n_routed_experts=E, + n_group=1, + norm_topk_prob=True, + routed_scaling_factor=1.0, + ) + + hidden = torch.randn(1, T, H, device="cuda") + + # Should not raise; output should include shared expert contribution + output = HFScatterMoEGatedMLP.forward(moe_block, hidden) + assert output.shape == (1, T, H) + + # Run without shared expert to verify it changes the output + moe_block_no_shared = SimpleNamespace( + gate=gate, + experts=experts, + top_k=K, + 
n_routed_experts=E, + n_group=1, + norm_topk_prob=True, + routed_scaling_factor=1.0, + ) + output_no_shared = HFScatterMoEGatedMLP.forward(moe_block_no_shared, hidden) + assert not torch.equal(output, output_no_shared) + + def test_shared_expert_with_gate(self): + """Qwen2MoE style: shared_expert + shared_expert_gate.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + HFScatterMoEGatedMLP, + ) + + T, H, FF, E, K = 8, 64, 32, 4, 2 + gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02 + down_proj = torch.randn(E, H, FF, device="cuda") * 0.02 + act_fn = torch.nn.SiLU() + + experts = SimpleNamespace( + gate_up_proj=gate_up_proj, + down_proj=down_proj, + act_fn=act_fn, + num_experts=E, + ) + + shared_W = torch.randn(H, H, device="cuda") * 0.01 + shared_expert_fn = lambda x: x @ shared_W.T # noqa: E731 + # Gate that returns 0 -> sigmoid(0) = 0.5 + gate_W = torch.zeros(H, H, device="cuda") + shared_expert_gate_fn = lambda x: x @ gate_W.T # noqa: E731 + + gate = SimpleNamespace( + weight=torch.randn(E, H, device="cuda") * 0.1, + top_k=K, + num_experts=E, + norm_topk_prob=True, + ) + moe_block = SimpleNamespace( + gate=gate, + experts=experts, + shared_expert=shared_expert_fn, + shared_expert_gate=shared_expert_gate_fn, + ) + + hidden = torch.randn(1, T, H, device="cuda") + output = HFScatterMoEGatedMLP.forward(moe_block, hidden) + assert output.shape == (1, T, H) diff --git a/tests/integrations/test_routing_parity.py b/tests/integrations/test_routing_parity.py new file mode 100644 index 000000000..cc668671c --- /dev/null +++ b/tests/integrations/test_routing_parity.py @@ -0,0 +1,474 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +""" +Parity tests between scattermoe-lora and sonicmoe routing implementations. + +These tests verify that both implementations produce numerically identical +results for the same inputs, ensuring safe centralization of the routing code. + +ScatterMoE returns 2D tensors [T, K]; SonicMoE returns flattened 1D [T*K]. +The core algorithm should be identical — only the output format differs. 
+""" + +from types import SimpleNamespace + +import pytest +import torch + + +def _require_triton(): + pytest.importorskip("triton") + + +# ============================================================================ +# Fixtures / helpers +# ============================================================================ + + +def _make_softmax_block(T=8, H=16, E=4, K=2): + """Qwen/OLMoE-style block usable by both implementations.""" + gate = SimpleNamespace( + weight=torch.randn(E, H), + top_k=K, + num_experts=E, + norm_topk_prob=True, + ) + moe_block = SimpleNamespace(gate=gate) + hidden = torch.randn(T, H) + return moe_block, gate, hidden, T, H, E, K + + +def _make_sigmoid_block( + T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True +): + """GLM/DeepSeek-style block usable by both implementations.""" + if bias_on_gate: + gate = SimpleNamespace( + weight=torch.randn(E, H), + e_score_correction_bias=torch.zeros(E), + ) + moe_block = SimpleNamespace( + gate=gate, + top_k=K, + n_routed_experts=E, + n_group=n_group, + topk_group=topk_group, + norm_topk_prob=True, + routed_scaling_factor=1.0, + ) + else: + # minimax_m2 style: bias on block + gate = SimpleNamespace( + weight=torch.randn(E, H), + top_k=K, + ) + moe_block = SimpleNamespace( + gate=gate, + top_k=K, + e_score_correction_bias=torch.zeros(E), + ) + return moe_block, gate, hidden_states(T, H), T, H, E, K + + +def hidden_states(T, H): + return torch.randn(T, H) + + +# ============================================================================ +# 1. Softmax routing parity +# ============================================================================ + + +class TestSoftmaxRoutingParity: + """Verify scattermoe and sonicmoe softmax routing produce identical results.""" + + @pytest.fixture(autouse=True) + def _require(self): + _require_triton() + + def test_weights_match(self): + """2D weights from scattermoe == reshaped 1D weights from sonicmoe.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _softmax_topk_route, + ) + from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing + + moe_block, gate, hidden, T, H, E, K = _make_softmax_block() + + # ScatterMoE path (no LoRA delta) + sm_weights, sm_experts, sm_topk, sm_E = _softmax_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + + # SonicMoE path + sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = softmax_topk_routing( + hidden, moe_block + ) + + # ScatterMoE returns [T, K], SonicMoE returns [T*K] flattened + sonic_weights_2d = sonic_scores.reshape(T, K) + sonic_experts_2d = sonic_exp_idx.reshape(T, K) + + assert sm_topk == K + assert sm_E == E + + # Both should select the same experts and produce the same weights + assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype)) + assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6) + + def test_logits_not_returned_by_scattermoe(self): + """ScatterMoE doesn't return logits; SonicMoE does — verify SonicMoE logits shape.""" + from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing + + moe_block, gate, hidden, T, H, E, K = _make_softmax_block() + _, _, _, logits = softmax_topk_routing(hidden, moe_block) + assert logits.shape == (T, E) + + def test_no_renorm(self): + """With norm_topk_prob=False, both should skip renormalization.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _softmax_topk_route, + ) + from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing + + moe_block, 
gate, hidden, T, H, E, K = _make_softmax_block() + gate.norm_topk_prob = False + + sm_weights, sm_experts, _, _ = _softmax_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block) + + sonic_weights_2d = sonic_scores.reshape(T, K) + sonic_experts_2d = sonic_exp_idx.reshape(T, K) + + assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype)) + assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6) + + def test_various_expert_counts(self): + """Parity across different E and K values.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _softmax_topk_route, + ) + from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing + + for E, K in [(2, 1), (8, 2), (16, 4), (32, 8)]: + moe_block, gate, hidden, T, H, _, _ = _make_softmax_block(E=E, K=K) + + sm_weights, sm_experts, _, _ = _softmax_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block) + + sonic_weights_2d = sonic_scores.reshape(T, K) + sonic_experts_2d = sonic_exp_idx.reshape(T, K) + + assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype)), ( + f"Expert mismatch for E={E}, K={K}" + ) + assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6), ( + f"Weight mismatch for E={E}, K={K}" + ) + + +# ============================================================================ +# 2. Sigmoid routing parity +# ============================================================================ + + +class TestSigmoidRoutingParity: + """Verify scattermoe and sonicmoe sigmoid routing produce identical results.""" + + @pytest.fixture(autouse=True) + def _require(self): + _require_triton() + + def test_weights_match_with_groups(self): + """Both implementations should produce identical weights with group selection.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _sigmoid_topk_route, + ) + from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing + + moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( + E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True + ) + + sm_weights, sm_experts, sm_topk, sm_E = _sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + + sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = sigmoid_topk_routing( + hidden, moe_block + ) + + sonic_weights_2d = sonic_scores.reshape(T, K) + sonic_experts_2d = sonic_exp_idx.reshape(T, K) + + assert sm_topk == K + assert sm_E == E + + # Sort experts within each token to handle different topk orderings + sm_sorted, sm_order = sm_experts.sort(dim=-1) + sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1) + + assert torch.equal(sm_sorted, sonic_sorted) + + # Gather weights in sorted order for comparison + sm_weights_sorted = sm_weights.gather(1, sm_order) + sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order) + assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6) + + def test_weights_match_no_groups(self): + """Both implementations match without group selection (n_group=1).""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _sigmoid_topk_route, + ) + from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing + + moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( + E=16, K=4, n_group=1, topk_group=1, bias_on_gate=True + ) + + sm_weights, sm_experts, _, _ = 
_sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block) + + sonic_weights_2d = sonic_scores.reshape(T, K) + sonic_experts_2d = sonic_exp_idx.reshape(T, K) + + # Sort for comparison (topk with sorted=False may differ in order) + sm_sorted, sm_order = sm_experts.sort(dim=-1) + sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1) + + assert torch.equal(sm_sorted, sonic_sorted) + sm_weights_sorted = sm_weights.gather(1, sm_order) + sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order) + assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6) + + def test_bias_on_block_parity(self): + """minimax_m2 style: bias on block, not gate.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _sigmoid_topk_route, + ) + from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing + + moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( + E=16, K=4, n_group=1, bias_on_gate=False + ) + + sm_weights, sm_experts, _, _ = _sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block) + + sonic_weights_2d = sonic_scores.reshape(T, K) + sonic_experts_2d = sonic_exp_idx.reshape(T, K) + + sm_sorted, sm_order = sm_experts.sort(dim=-1) + sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1) + + assert torch.equal(sm_sorted, sonic_sorted) + sm_weights_sorted = sm_weights.gather(1, sm_order) + sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order) + assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6) + + def test_scaling_factor_parity(self): + """routed_scaling_factor applied identically by both.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _sigmoid_topk_route, + ) + from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing + + moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( + n_group=1, bias_on_gate=True + ) + moe_block.routed_scaling_factor = 2.5 + + sm_weights, sm_experts, _, _ = _sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block) + + sonic_weights_2d = sonic_scores.reshape(T, K) + sonic_experts_2d = sonic_exp_idx.reshape(T, K) + + sm_sorted, sm_order = sm_experts.sort(dim=-1) + sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1) + + assert torch.equal(sm_sorted, sonic_sorted) + sm_weights_sorted = sm_weights.gather(1, sm_order) + sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order) + assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6) + + def test_no_renorm_parity(self): + """norm_topk_prob=False produces same results in both.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _sigmoid_topk_route, + ) + from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing + + moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( + n_group=1, bias_on_gate=True + ) + moe_block.norm_topk_prob = False + + sm_weights, sm_experts, _, _ = _sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block) + + sonic_weights_2d = sonic_scores.reshape(T, K) + sonic_experts_2d = sonic_exp_idx.reshape(T, K) + + sm_sorted, sm_order = sm_experts.sort(dim=-1) + 
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1) + + assert torch.equal(sm_sorted, sonic_sorted) + sm_weights_sorted = sm_weights.gather(1, sm_order) + sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order) + assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6) + + +# ============================================================================ +# 3. Shared expert parity +# ============================================================================ + + +class TestSharedExpertParity: + """Verify both _compute_shared_expert implementations behave identically.""" + + @pytest.fixture(autouse=True) + def _require(self): + _require_triton() + + def _get_both_fns(self): + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _compute_shared_expert as scatter_compute, + ) + from axolotl.integrations.kernels.sonicmoe.patch import ( + _compute_shared_expert as sonic_compute, + ) + + return scatter_compute, sonic_compute + + def test_shared_expert_singular(self): + scatter_fn, sonic_fn = self._get_both_fns() + out = torch.randn(4, 8) + block = SimpleNamespace(shared_expert=lambda x: out) + hidden = torch.randn(4, 8) + + assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden)) + + def test_shared_experts_plural(self): + scatter_fn, sonic_fn = self._get_both_fns() + out = torch.randn(4, 8) + block = SimpleNamespace(shared_experts=lambda x: out) + hidden = torch.randn(4, 8) + + assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden)) + + def test_shared_mlp(self): + scatter_fn, sonic_fn = self._get_both_fns() + out = torch.randn(4, 8) + block = SimpleNamespace(shared_mlp=lambda x: out) + hidden = torch.randn(4, 8) + + assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden)) + + def test_no_shared_expert(self): + scatter_fn, sonic_fn = self._get_both_fns() + block = SimpleNamespace() + hidden = torch.randn(4, 8) + + assert scatter_fn(block, hidden) is None + assert sonic_fn(block, hidden) is None + + def test_shared_expert_gate_only_in_scattermoe(self): + """ScatterMoE's _compute_shared_expert handles shared_expert_gate; + SonicMoE's patch.py handles it externally in the forward function. + + This documents the known divergence: the scattermoe version applies + sigmoid gating inline, while sonicmoe applies it in the forward. + """ + scatter_fn, sonic_fn = self._get_both_fns() + + H = 8 + expert_out = torch.ones(4, H) + gate_fn = lambda x: torch.zeros(4, H) # noqa: E731 # sigmoid(0) = 0.5 + + block = SimpleNamespace( + shared_expert=lambda x: expert_out, + shared_expert_gate=gate_fn, + ) + hidden = torch.randn(4, H) + + scatter_result = scatter_fn(block, hidden) + sonic_result = sonic_fn(block, hidden) + + # ScatterMoE applies the gate: expert_out * sigmoid(0) = 0.5 + expected_gated = expert_out * 0.5 + assert torch.allclose(scatter_result, expected_gated, atol=1e-6) + + # SonicMoE does NOT apply the gate here (it does it in the forward) + assert torch.equal(sonic_result, expert_out) + + +# ============================================================================ +# 4. 
Route dispatcher parity +# ============================================================================ + + +class TestRouteDispatcherParity: + """Verify _route in scattermoe dispatches correctly and matches individual fns.""" + + @pytest.fixture(autouse=True) + def _require(self): + _require_triton() + + def test_route_dispatches_softmax(self): + """_route should use softmax when no e_score_correction_bias.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _route, + _softmax_topk_route, + ) + + moe_block, gate, hidden, T, H, E, K = _make_softmax_block() + + route_w, route_e, route_k, route_E = _route( + moe_block, gate, hidden, gate.weight, None + ) + direct_w, direct_e, direct_k, direct_E = _softmax_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + + assert torch.equal(route_w, direct_w) + assert torch.equal(route_e, direct_e) + assert route_k == direct_k + assert route_E == direct_E + + def test_route_dispatches_sigmoid(self): + """_route should use sigmoid when e_score_correction_bias is present.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _route, + _sigmoid_topk_route, + ) + + moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block( + n_group=1, bias_on_gate=True + ) + + route_w, route_e, route_k, route_E = _route( + moe_block, gate, hidden, gate.weight, None + ) + direct_w, direct_e, direct_k, direct_E = _sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + + assert torch.equal(route_w, direct_w) + assert torch.equal(route_e, direct_e) + assert route_k == direct_k + assert route_E == direct_E diff --git a/tests/integrations/test_scattermoe_autotune_telemetry.py b/tests/integrations/test_scattermoe_autotune_telemetry.py new file mode 100644 index 000000000..50ac56720 --- /dev/null +++ b/tests/integrations/test_scattermoe_autotune_telemetry.py @@ -0,0 +1,367 @@ +"""Tests for scattermoe autotune telemetry integration. + +These tests use mocking to verify the collection and reporting logic +without requiring Triton or CUDA. +""" + +import sys +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +# Simulate the hash-suffixed module name that LocalLayerRepository creates. 
+_FAKE_MODULE_NAME = "scattermoe_lora_abc123.kernels.lora_ops" + + +def _make_mock_config(kwargs, num_warps=4, num_stages=3): + """Create a mock triton.Config-like object.""" + return SimpleNamespace(kwargs=kwargs, num_warps=num_warps, num_stages=num_stages) + + +def _make_mock_kernel(cache=None): + """Create a mock autotuned kernel object with a ``.cache`` dict.""" + kernel = SimpleNamespace() + kernel.cache = cache if cache is not None else {} + return kernel + + +def _make_mock_lora_ops( + fwd_cache=None, dx_cache=None, bwd_cache=None, fused_cache=None +): + """Build a mock ``lora_ops`` module with the four kernel attributes.""" + mod = SimpleNamespace( + _scatter2scatter_lora=_make_mock_kernel(fwd_cache), + _scatter2scatter_lora_dX=_make_mock_kernel(dx_cache), + _group_bwd_lora=_make_mock_kernel(bwd_cache), + _group_bwd_lora_fused=_make_mock_kernel(fused_cache), + ) + return mod + + +# ========================================================================= +# TestAutotuneCollector +# ========================================================================= + + +class TestAutotuneCollector: + """Test ``collect_autotune_configs`` with mocked kernel objects.""" + + def test_empty_cache_returns_empty_list(self): + """When no kernel has been autotuned yet, return ``[]``.""" + mock_lora_ops = _make_mock_lora_ops() + + with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}): + from axolotl.integrations.kernels.autotune_collector import ( + collect_autotune_configs, + ) + + result = collect_autotune_configs() + assert result == [] + + def test_populated_cache_returns_configs(self): + """When a cache entry exists, it appears in the output.""" + cfg = _make_mock_config( + {"BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4 + ) + mock_lora_ops = _make_mock_lora_ops(fwd_cache={(2048, 4096, 1024): cfg}) + + with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}): + from axolotl.integrations.kernels.autotune_collector import ( + collect_autotune_configs, + ) + + result = collect_autotune_configs() + + assert len(result) == 1 + entry = result[0] + assert entry["kernel"] == "scatter2scatter_lora_fwd" + assert entry["key"] == {"M": 2048, "N": 4096, "K": 1024} + assert entry["config"]["BLOCK_N"] == 128 + assert entry["config"]["BLOCK_K"] == 64 + assert entry["config"]["num_warps"] == 8 + assert entry["config"]["num_stages"] == 4 + + def test_multiple_kernels_and_keys(self): + """Multiple cache entries across kernels are all returned.""" + cfg_fwd = _make_mock_config({"BLOCK_N": 128, "BLOCK_K": 32}) + cfg_dx = _make_mock_config({"BLOCK_K": 64, "BLOCK_N": 128}, num_warps=8) + + mock_lora_ops = _make_mock_lora_ops( + fwd_cache={(16, 256, 128): cfg_fwd}, + dx_cache={(16, 256, 128): cfg_dx}, + ) + + with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}): + from axolotl.integrations.kernels.autotune_collector import ( + collect_autotune_configs, + ) + + result = collect_autotune_configs() + + assert len(result) == 2 + names = {r["kernel"] for r in result} + assert "scatter2scatter_lora_fwd" in names + assert "scatter2scatter_lora_dX" in names + + def test_extra_key_elements_stored(self): + """Dtype or other extra elements in the cache key are captured.""" + cfg = _make_mock_config({"BLOCK_N": 64, "BLOCK_K": 32}) + cache_key = (512, 1024, 256, "float16", "float16") + + mock_lora_ops = _make_mock_lora_ops(fwd_cache={cache_key: cfg}) + + with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}): + from axolotl.integrations.kernels.autotune_collector import ( + 
collect_autotune_configs, + ) + + result = collect_autotune_configs() + + assert len(result) == 1 + key = result[0]["key"] + assert key["M"] == 512 + assert key["N"] == 1024 + assert key["K"] == 256 + assert key["_extra"] == ["float16", "float16"] + + def test_no_module_in_sys_modules_returns_empty(self): + """If no lora_ops module is loaded, return ``[]``.""" + from axolotl.integrations.kernels.autotune_collector import ( + collect_autotune_configs, + ) + + # Don't inject anything — the real lora_ops isn't loaded either + # (no triton on this machine), so _find_lora_ops_module returns None. + result = collect_autotune_configs() + assert result == [] + + def test_finds_module_under_hash_suffixed_name(self): + """Collector finds lora_ops regardless of the hash suffix.""" + cfg = _make_mock_config({"BLOCK_N": 256, "BLOCK_K": 128}) + mock_lora_ops = _make_mock_lora_ops(fwd_cache={(8, 512, 64): cfg}) + + # Use a different hash to prove it's not hardcoded. + alt_name = "scattermoe_lora_deadbeef.kernels.lora_ops" + with patch.dict(sys.modules, {alt_name: mock_lora_ops}): + from axolotl.integrations.kernels.autotune_collector import ( + collect_autotune_configs, + ) + + result = collect_autotune_configs() + + assert len(result) == 1 + assert result[0]["config"]["BLOCK_N"] == 256 + + +# ========================================================================= +# TestAutotuneReportCallback +# ========================================================================= + + +class TestAutotuneReportCallback: + """Test the callback fires once and sends the correct event.""" + + def test_reports_once_on_first_step(self): + """Callback should call ``send_event`` exactly once.""" + from axolotl.integrations.kernels.autotune_callback import ( + AutotuneReportCallback, + ) + + cb = AutotuneReportCallback() + mock_state = MagicMock() + mock_state.global_step = 1 + + fake_configs = [{"kernel": "test_fwd", "key": {}, "config": {}}] + + with ( + patch( + "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs", + return_value=fake_configs, + ), + patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls, + ): + mock_tm = MagicMock() + mock_tm.enabled = True + mock_tm_cls.get_instance.return_value = mock_tm + + cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock()) + assert mock_tm.send_event.call_count == 1 + + call_kwargs = mock_tm.send_event.call_args[1] + assert call_kwargs["event_type"] == "scattermoe-autotune" + assert call_kwargs["properties"]["kernel_count"] == 1 + + # Second call should NOT send again. 
+ cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock()) + assert mock_tm.send_event.call_count == 1 + + def test_retries_until_step_5_then_gives_up(self): + """If no configs found by step 5, stop retrying.""" + from axolotl.integrations.kernels.autotune_callback import ( + AutotuneReportCallback, + ) + + cb = AutotuneReportCallback() + + with patch( + "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs", + return_value=[], + ): + for step in range(1, 7): + mock_state = MagicMock() + mock_state.global_step = step + cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock()) + + assert cb._reported is True + + def test_reports_on_retry_when_data_arrives(self): + """If step 1 has no data but step 2 does, report at step 2.""" + from axolotl.integrations.kernels.autotune_callback import ( + AutotuneReportCallback, + ) + + cb = AutotuneReportCallback() + fake_configs = [{"kernel": "fwd", "key": {}, "config": {}}] + + call_count = 0 + + def _collector(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return [] + return fake_configs + + with ( + patch( + "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs", + side_effect=_collector, + ), + patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls, + ): + mock_tm = MagicMock() + mock_tm.enabled = True + mock_tm_cls.get_instance.return_value = mock_tm + + # Step 1 — empty, no report + s1 = MagicMock() + s1.global_step = 1 + cb.on_step_end(args=MagicMock(), state=s1, control=MagicMock()) + assert mock_tm.send_event.call_count == 0 + + # Step 2 — data arrives, report + s2 = MagicMock() + s2.global_step = 2 + cb.on_step_end(args=MagicMock(), state=s2, control=MagicMock()) + assert mock_tm.send_event.call_count == 1 + + def test_includes_gpu_info(self): + """Event properties should include GPU identification.""" + from axolotl.integrations.kernels.autotune_callback import ( + AutotuneReportCallback, + ) + + cb = AutotuneReportCallback() + mock_state = MagicMock() + mock_state.global_step = 1 + + fake_configs = [{"kernel": "fwd", "key": {}, "config": {}}] + fake_gpu = { + "gpu_name": "NVIDIA H100", + "gpu_compute_capability": "9.0", + "gpu_memory_bytes": 85899345920, + } + + fake_smem = {"smem_capacity_bytes": 233472} + + with ( + patch( + "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs", + return_value=fake_configs, + ), + patch( + "axolotl.integrations.kernels.autotune_callback._get_gpu_info", + return_value=fake_gpu, + ), + patch( + "axolotl.integrations.kernels.autotune_callback._get_smem_capacity", + return_value=fake_smem, + ), + patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls, + ): + mock_tm = MagicMock() + mock_tm.enabled = True + mock_tm_cls.get_instance.return_value = mock_tm + + cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock()) + props = mock_tm.send_event.call_args[1]["properties"] + assert props["gpu_name"] == "NVIDIA H100" + assert props["gpu_compute_capability"] == "9.0" + assert props["gpu_memory_bytes"] == 85899345920 + assert props["smem_capacity_bytes"] == 233472 + + def test_skips_send_when_telemetry_disabled(self): + """If telemetry is disabled, no event is sent.""" + from axolotl.integrations.kernels.autotune_callback import ( + AutotuneReportCallback, + ) + + cb = AutotuneReportCallback() + mock_state = MagicMock() + mock_state.global_step = 1 + + with ( + patch( + "axolotl.integrations.kernels.autotune_collector.collect_autotune_configs", + return_value=[{"kernel": 
"fwd", "key": {}, "config": {}}], + ), + patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls, + ): + mock_tm = MagicMock() + mock_tm.enabled = False + mock_tm_cls.get_instance.return_value = mock_tm + + cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock()) + assert mock_tm.send_event.call_count == 0 + # Should still mark as reported so we don't retry. + assert cb._reported is True + + +# ========================================================================= +# TestKernelsPluginCallbackRegistration +# ========================================================================= + + +class TestKernelsPluginCallbackRegistration: + """Test that ``KernelsPlugin`` registers the callback correctly.""" + + def test_scattermoe_registers_callback(self): + """When ``use_scattermoe=True``, plugin returns the callback.""" + from axolotl.integrations.kernels.autotune_callback import ( + AutotuneReportCallback, + ) + from axolotl.integrations.kernels.plugin import KernelsPlugin + + plugin = KernelsPlugin() + cfg = MagicMock() + cfg.use_scattermoe = True + model = MagicMock() + + callbacks = plugin.add_callbacks_pre_trainer(cfg, model) + assert len(callbacks) == 1 + assert isinstance(callbacks[0], AutotuneReportCallback) + + def test_no_scattermoe_no_callback(self): + """When ``use_scattermoe=False``, plugin returns empty list.""" + from axolotl.integrations.kernels.plugin import KernelsPlugin + + plugin = KernelsPlugin() + cfg = MagicMock() + cfg.use_scattermoe = False + model = MagicMock() + + callbacks = plugin.add_callbacks_pre_trainer(cfg, model) + assert callbacks == [] diff --git a/tests/integrations/test_scattermoe_lora.py b/tests/integrations/test_scattermoe_lora.py index d498c8010..bd50d06fe 100644 --- a/tests/integrations/test_scattermoe_lora.py +++ b/tests/integrations/test_scattermoe_lora.py @@ -3,17 +3,19 @@ # Licensed under the Apache License, Version 2.0 """ -Unit tests for scattermoe-lora code-review fixes. +Unit tests for scattermoe-lora. Tests cover: - KernelsArgs validator: disable_mlp_kernel -- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward - ParallelExperts: scaling=0.0 not treated as falsy - single2scatter: non-aligned K/N dimensions - group_compileable: coeff=None accepted - HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract +- Routing strategy detection and sigmoid routing +- Generic shared expert handling """ +from types import SimpleNamespace from unittest.mock import patch import pytest @@ -321,3 +323,389 @@ class TestLayerReturnValues: assert "Router logits" not in docstring, ( "Docstring should not mention 'Router logits' in Returns section" ) + + +# ============================================================================ +# 7. 
Routing strategy detection and sigmoid routing +# ============================================================================ + + +def _make_softmax_gate(E=4, H=16, K=2): + """Create a mock softmax-style gate (Qwen/OLMoE).""" + return SimpleNamespace( + weight=torch.randn(E, H), + top_k=K, + num_experts=E, + norm_topk_prob=True, + ) + + +def _make_sigmoid_gate_with_bias(E=16, H=16): + """Create a mock sigmoid-style gate with e_score_correction_bias on gate.""" + return SimpleNamespace( + weight=torch.randn(E, H), + e_score_correction_bias=torch.zeros(E), + ) + + +def _make_sigmoid_moe_block( + T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True +): + """Create a mock GLM/DeepSeek-style MoE block for sigmoid routing tests.""" + if bias_on_gate: + gate = SimpleNamespace( + weight=torch.randn(E, H), + e_score_correction_bias=torch.zeros(E), + ) + moe_block = SimpleNamespace( + gate=gate, + top_k=K, + n_routed_experts=E, + n_group=n_group, + topk_group=topk_group, + norm_topk_prob=True, + routed_scaling_factor=1.0, + ) + else: + # minimax_m2 style: bias on block, not gate + gate = SimpleNamespace( + weight=torch.randn(E, H), + top_k=K, + ) + moe_block = SimpleNamespace( + gate=gate, + top_k=K, + e_score_correction_bias=torch.zeros(E), + ) + return moe_block, T, H, E, K + + +def _skip_without_triton(): + pytest.importorskip("triton") + + +class TestSigmoidRoutingInScatterMoE: + """Test _sigmoid_topk_route from layers.py.""" + + @pytest.fixture(autouse=True) + def _require_triton(self): + _skip_without_triton() + + def test_output_shapes(self): + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _sigmoid_topk_route, + ) + + moe_block, T, H, E, K = _make_sigmoid_moe_block() + gate = moe_block.gate + hidden = torch.randn(T, H) + + weights, experts, top_k, num_experts = _sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + + assert weights.shape == (T, K) + assert experts.shape == (T, K) + assert top_k == K + assert num_experts == E + + def test_weights_nonnegative(self): + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _sigmoid_topk_route, + ) + + moe_block, T, H, E, K = _make_sigmoid_moe_block() + gate = moe_block.gate + hidden = torch.randn(T, H) + + weights, _, _, _ = _sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + assert (weights >= 0).all() + + def test_group_selection_restricts_experts(self): + """With n_group=4, topk_group=1, experts should be from selected groups.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _sigmoid_topk_route, + ) + + moe_block, T, H, E, K = _make_sigmoid_moe_block( + E=16, K=2, n_group=4, topk_group=1 + ) + gate = moe_block.gate + hidden = torch.randn(T, H) + + _, expert_idx, _, _ = _sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + + # Each token's experts should fall within a single group (size E//n_group=4) + for t in range(T): + experts_t = expert_idx[t] + groups = experts_t // (E // moe_block.n_group) + assert (groups == groups[0]).all() + + def test_scaling_factor_applied(self): + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _sigmoid_topk_route, + ) + + moe_block, T, H, E, K = _make_sigmoid_moe_block(n_group=1) + gate = moe_block.gate + hidden = torch.randn(T, H) + + weights_1x, _, _, _ = _sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + + moe_block.routed_scaling_factor = 2.0 + weights_2x, _, _, _ = _sigmoid_topk_route( + moe_block, gate, hidden, 
gate.weight, None + ) + + assert torch.allclose(weights_2x, weights_1x * 2.0, atol=1e-5) + + def test_bias_on_gate(self): + """e_score_correction_bias on gate is found.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _sigmoid_topk_route, + ) + + moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=True) + gate = moe_block.gate + hidden = torch.randn(T, H) + + weights, experts, _, _ = _sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + assert weights.shape == (T, K) + + def test_bias_on_block(self): + """e_score_correction_bias on moe_block (not gate) is found.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _sigmoid_topk_route, + ) + + moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=False) + gate = moe_block.gate + hidden = torch.randn(T, H) + + weights, experts, _, _ = _sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + assert weights.shape == (T, K) + + def test_gate_lora_delta_applied(self): + """Gate LoRA delta should affect routing logits.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _sigmoid_topk_route, + ) + + moe_block, T, H, E, K = _make_sigmoid_moe_block(n_group=1) + gate = moe_block.gate + hidden = torch.randn(T, H) + + weights_no_lora, _, _, _ = _sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + + # Large delta should change the results + delta = torch.randn(E, H) * 10.0 + weights_with_lora, _, _, _ = _sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, delta + ) + + assert not torch.equal(weights_no_lora, weights_with_lora) + + def test_no_bias_does_not_crash(self): + """Calling _sigmoid_topk_route with no e_score_correction_bias should not crash.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _sigmoid_topk_route, + ) + + T, H, E, K = 8, 16, 8, 2 + gate = SimpleNamespace(weight=torch.randn(E, H)) + moe_block = SimpleNamespace( + gate=gate, + top_k=K, + n_routed_experts=E, + n_group=1, + norm_topk_prob=True, + routed_scaling_factor=1.0, + ) + hidden = torch.randn(T, H) + + weights, experts, top_k, num_experts = _sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + assert weights.shape == (T, K) + assert experts.shape == (T, K) + # Without bias, scores_for_choice == sigmoid(logits) — all positive + assert (weights >= 0).all() + + def test_missing_topk_group_defaults_to_n_group(self): + """When topk_group is absent but n_group > 1, should default to n_group (no-op masking).""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _sigmoid_topk_route, + ) + + T, H, E, K, n_group = 8, 16, 16, 2, 4 + gate = SimpleNamespace( + weight=torch.randn(E, H), + e_score_correction_bias=torch.zeros(E), + ) + # Intentionally omit topk_group + moe_block = SimpleNamespace( + gate=gate, + top_k=K, + n_routed_experts=E, + n_group=n_group, + norm_topk_prob=True, + routed_scaling_factor=1.0, + ) + hidden = torch.randn(T, H) + + # Should not raise AttributeError; defaults topk_group to n_group + weights, experts, top_k_out, num_experts = _sigmoid_topk_route( + moe_block, gate, hidden, gate.weight, None + ) + assert weights.shape == (T, K) + assert experts.shape == (T, K) + + +class TestRoutingStrategyDetection: + """Test that _route dispatches to the correct strategy.""" + + @pytest.fixture(autouse=True) + def _require_triton(self): + _skip_without_triton() + + def test_softmax_for_qwen_style(self): + """Block without e_score_correction_bias should use 
softmax.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route + + gate = _make_softmax_gate(E=4, H=16, K=2) + moe_block = SimpleNamespace(gate=gate) + hidden = torch.randn(8, 16) + + weights, experts, top_k, num_experts = _route( + moe_block, gate, hidden, gate.weight, None + ) + + assert weights.shape == (8, 2) + assert experts.shape == (8, 2) + assert top_k == 2 + assert num_experts == 4 + per_token_sums = weights.sum(dim=-1) + assert torch.allclose(per_token_sums, torch.ones(8), atol=1e-5) + + def test_sigmoid_for_glm_style(self): + """Block with e_score_correction_bias on gate should use sigmoid.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route + + moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=True, n_group=1) + gate = moe_block.gate + hidden = torch.randn(T, H) + + weights, experts, top_k, num_experts = _route( + moe_block, gate, hidden, gate.weight, None + ) + + assert weights.shape == (T, K) + assert experts.shape == (T, K) + assert (weights >= 0).all() + + def test_sigmoid_for_minimax_m2_style(self): + """Block with e_score_correction_bias on block (not gate) should use sigmoid.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route + + moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=False) + gate = moe_block.gate + hidden = torch.randn(T, H) + + weights, experts, top_k, num_experts = _route( + moe_block, gate, hidden, gate.weight, None + ) + + assert weights.shape == (T, K) + assert (weights >= 0).all() + + +# ============================================================================ +# 8. Generic shared expert handling +# ============================================================================ + + +class TestGenericSharedExpert: + """Test _compute_shared_expert from layers.py.""" + + @pytest.fixture(autouse=True) + def _require_triton(self): + _skip_without_triton() + + def test_shared_expert_singular(self): + """shared_expert attribute (Qwen2MoE style).""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _compute_shared_expert, + ) + + called = torch.randn(4, 8) + moe_block = SimpleNamespace( + shared_expert=lambda x: called, + ) + result = _compute_shared_expert(moe_block, torch.randn(4, 8)) + assert torch.equal(result, called) + + def test_shared_experts_plural(self): + """shared_experts attribute (DeepSeek V3 style).""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _compute_shared_expert, + ) + + called = torch.randn(4, 8) + moe_block = SimpleNamespace( + shared_experts=lambda x: called, + ) + result = _compute_shared_expert(moe_block, torch.randn(4, 8)) + assert torch.equal(result, called) + + def test_shared_mlp(self): + """shared_mlp attribute (Hunyuan style).""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _compute_shared_expert, + ) + + called = torch.randn(4, 8) + moe_block = SimpleNamespace( + shared_mlp=lambda x: called, + ) + result = _compute_shared_expert(moe_block, torch.randn(4, 8)) + assert torch.equal(result, called) + + def test_shared_expert_with_gate(self): + """shared_expert + shared_expert_gate applies sigmoid gating.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _compute_shared_expert, + ) + + H = 8 + expert_out = torch.ones(4, H) + gate_fn = lambda x: torch.zeros(4, H) # noqa: E731 + + moe_block = SimpleNamespace( + shared_expert=lambda x: expert_out, + shared_expert_gate=gate_fn, + ) + result = 
_compute_shared_expert(moe_block, torch.randn(4, H))
+        expected = expert_out * 0.5  # sigmoid(0) = 0.5
+        assert torch.allclose(result, expected, atol=1e-6)
+
+    def test_no_shared_expert(self):
+        """Returns None when no shared expert attribute is present."""
+        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
+            _compute_shared_expert,
+        )
+
+        moe_block = SimpleNamespace()
+        result = _compute_shared_expert(moe_block, torch.randn(4, 8))
+        assert result is None