consolidate behaviour of routing in scattermoe kernels (#3475)

* consolidate behaviour of routing in scattermoe kernels

* collect telemetry on best chosen autotuned kernel

* properly collect data

* Fix property name and get smem too

* handle issues raised by coderabbit

* add tests for parity before refactoring
Wing Lian
2026-03-16 23:47:40 -04:00
committed by GitHub
parent 830e9f7eaf
commit 8f3fb517b3
8 changed files with 1988 additions and 35 deletions


@@ -0,0 +1,120 @@
"""Trainer callback for reporting Triton autotune results from scattermoe-lora kernels."""
import logging
import torch
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
LOG = logging.getLogger(__name__)
# Give up looking for autotune data after this many training steps.
_MAX_POLL_STEP = 5
def _get_gpu_info() -> dict:
"""Return basic GPU identification for the current device."""
if not torch.cuda.is_available():
return {}
try:
idx = torch.cuda.current_device()
props = torch.cuda.get_device_properties(idx)
return {
"gpu_name": props.name,
"gpu_compute_capability": f"{props.major}.{props.minor}",
"gpu_memory_bytes": props.total_memory,
}
except Exception: # pylint: disable=broad-exception-caught
return {}
def _get_smem_capacity() -> dict:
"""Return shared memory capacity from the runtime lora_ops module."""
try:
from axolotl.integrations.kernels.autotune_collector import (
_find_lora_ops_module,
)
lora_ops = _find_lora_ops_module()
if lora_ops is None:
return {}
fn = getattr(lora_ops, "_get_smem_capacity", None)
if fn is None:
return {}
return {"smem_capacity_bytes": fn()}
except Exception: # pylint: disable=broad-exception-caught
return {}
class AutotuneReportCallback(TrainerCallback):
"""Reports Triton kernel autotune selections via telemetry.
Fires **once** after the first training step completes (step 1), at
which point the forward and backward passes have both run and the
autotuned kernels have populated their caches. If for some reason
the caches are still empty (e.g. the kernel was never invoked), the
callback retries on subsequent steps up to ``_MAX_POLL_STEP`` and
then stops polling.
After reporting (or giving up), every subsequent ``on_step_end``
call short-circuits on the ``_reported`` flag — zero hot-path cost.
"""
def __init__(self):
self._reported = False
# pylint: disable=unused-argument
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if self._reported:
return
# Lazy import — Triton / scattermoe kernels may not be installed.
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
configs = collect_autotune_configs()
if not configs:
if state.global_step >= _MAX_POLL_STEP:
LOG.debug(
"No autotune data found after %d steps; giving up.",
state.global_step,
)
self._reported = True
return
self._reported = True
from axolotl.telemetry.manager import TelemetryManager
telemetry_manager = TelemetryManager.get_instance()
if not telemetry_manager.enabled:
return
properties = {
"kernel_count": len(configs),
"kernels": configs,
}
properties.update(_get_gpu_info())
properties.update(_get_smem_capacity())
telemetry_manager.send_event(
event_type="scattermoe-autotune",
properties=properties,
)
LOG.info(
"Reported %d scattermoe kernel autotune config(s) to telemetry.",
len(configs),
)


@@ -0,0 +1,114 @@
"""Collect Triton autotune results from scattermoe-lora kernels.
This module reads the ``.cache`` attribute from Triton ``@triton.autotune``
decorated kernel objects and returns structured dicts describing the selected
configurations. It has **no** telemetry dependency — callers decide what to
do with the data.
"""
import logging
import sys
from types import ModuleType
from typing import Any
LOG = logging.getLogger(__name__)
# (human-readable name, attribute on the lora_ops module)
_KERNEL_REGISTRY: list[tuple[str, str]] = [
("scatter2scatter_lora_fwd", "_scatter2scatter_lora"),
("scatter2scatter_lora_dX", "_scatter2scatter_lora_dX"),
("group_bwd_lora", "_group_bwd_lora"),
("group_bwd_lora_fused", "_group_bwd_lora_fused"),
]
# The autotune key declared on every kernel: key=["M", "N", "K"]
_KEY_NAMES: list[str] = ["M", "N", "K"]
def _parse_key_tuple(key_tuple: tuple) -> dict[str, Any]:
"""Turn the autotune cache key tuple into a labelled dict.
Triton builds the cache key from the values of the declared ``key``
args (``M``, ``N``, ``K``) followed by dtype signature elements.
We label the first three and store the rest under ``_extra``.
"""
result: dict[str, Any] = {}
for i, name in enumerate(_KEY_NAMES):
if i < len(key_tuple):
result[name] = key_tuple[i]
if len(key_tuple) > len(_KEY_NAMES):
result["_extra"] = [str(v) for v in key_tuple[len(_KEY_NAMES) :]]
return result
def _find_lora_ops_module() -> ModuleType | None:
"""Locate the *runtime* ``lora_ops`` module in ``sys.modules``.
The HF ``kernels`` package loads ``scattermoe_lora`` via
``import_from_path`` which registers it in ``sys.modules`` under a
hash-suffixed name (e.g. ``scattermoe_lora_a1b2c3d4``). A normal
import (``from axolotl.integrations.kernels...``) would create a
*separate* module instance whose kernel objects have empty
``.cache`` dicts because autotuning ran on the runtime copy.
We search ``sys.modules`` for any module whose name contains
``lora_ops`` and that has the ``_scatter2scatter_lora`` kernel
attribute — that is the runtime copy with populated caches.
"""
for name, module in sys.modules.items():
if (
module is not None
and "lora_ops" in name
and hasattr(module, "_scatter2scatter_lora")
):
return module
return None
def collect_autotune_configs() -> list[dict[str, Any]]:
"""Read autotune caches from the four scattermoe-lora kernels.
Returns a (possibly empty) list of dicts, each containing:
* ``kernel`` human-readable kernel name
* ``key`` dict with the ``M``/``N``/``K`` problem dimensions
* ``config`` dict with the selected tile sizes, ``num_warps``,
and ``num_stages``
Returns ``[]`` if the kernel module cannot be found or if no
autotune cache entries exist yet.
"""
lora_ops = _find_lora_ops_module()
if lora_ops is None:
LOG.debug(
"lora_ops module not found in sys.modules; skipping autotune collection"
)
return []
results: list[dict[str, Any]] = []
for friendly_name, attr_name in _KERNEL_REGISTRY:
kernel_fn = getattr(lora_ops, attr_name, None)
if kernel_fn is None:
continue
cache = getattr(kernel_fn, "cache", None)
if not cache:
continue
for key_tuple, config in cache.items():
config_dict = dict(config.kwargs)
config_dict["num_warps"] = config.num_warps
config_dict["num_stages"] = config.num_stages
if getattr(config, "num_ctas", None) is not None:
config_dict["num_ctas"] = config.num_ctas
results.append(
{
"kernel": friendly_name,
"key": _parse_key_tuple(key_tuple),
"config": config_dict,
}
)
return results
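For a concrete picture of the return value, one entry from ``collect_autotune_configs()`` has roughly the shape below; the dimensions, block-size names, and dtype strings are illustrative, not measured values:

    # Illustrative only: the actual config keys come from the kernel's
    # @triton.autotune configs, and the key values from runtime problem sizes.
    example_entry = {
        "kernel": "scatter2scatter_lora_fwd",
        "key": {"M": 4096, "N": 2048, "K": 1024, "_extra": ["torch.bfloat16"]},
        "config": {"BLOCK_N": 128, "BLOCK_K": 64, "num_warps": 4, "num_stages": 3},
    }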


@@ -220,6 +220,158 @@ def _unwrap_experts_lora(experts_module):
return base_experts, gup_lora, down_lora
# =============================================================================
# Routing helpers
# =============================================================================
def _softmax_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
):
"""Softmax→topk routing (Qwen, OLMoE, Mixtral, MiniMax).
Returns:
(routing_weights [T, K], selected_experts [T, K], top_k, num_experts)
"""
router_logits = F.linear(hidden_states, gate_weight)
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(hidden_states, gate_lora_delta)
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
top_k = base_gate.top_k
num_experts = base_gate.num_experts
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if getattr(base_gate, "norm_topk_prob", True):
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
return routing_weights, selected_experts, top_k, num_experts
def _sigmoid_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
):
"""Sigmoid→topk routing (GLM, DeepSeek V3, MiniMax M2).
Supports:
- ``e_score_correction_bias`` on gate or moe_block
- Group-based expert selection when ``n_group > 1``
- ``routed_scaling_factor`` applied to final weights
- Final weights gathered from original sigmoid probs (not bias-corrected)
Returns:
(routing_weights [T, K], selected_experts [T, K], top_k, num_experts)
"""
router_logits = F.linear(hidden_states.float(), gate_weight.float())
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states.float(), gate_lora_delta.float()
)
router_probs = router_logits.sigmoid() # [T, E]
top_k = getattr(moe_block, "top_k", getattr(base_gate, "top_k", None))
num_experts = getattr(moe_block, "n_routed_experts", gate_weight.shape[0])
# Bias-corrected scores for expert selection (not used for final weights).
# glm_moe_dsa/deepseek_v3 store the bias on gate; minimax_m2 on the block.
e_score_correction_bias = getattr(base_gate, "e_score_correction_bias", None)
if e_score_correction_bias is None:
e_score_correction_bias = getattr(moe_block, "e_score_correction_bias", None)
if e_score_correction_bias is not None:
scores_for_choice = router_probs + e_score_correction_bias
else:
scores_for_choice = router_probs
# Group-based selection: pick top groups, mask the rest
n_group = getattr(moe_block, "n_group", 1)
if n_group > 1:
group_scores = (
scores_for_choice.view(-1, n_group, num_experts // n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
) # [T, n_group]
topk_group = getattr(moe_block, "topk_group", n_group)
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1)
.expand(-1, n_group, num_experts // n_group)
.reshape(-1, num_experts)
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
# Final topk from (possibly masked) scores
topk_indices = torch.topk(scores_for_choice, k=top_k, dim=-1, sorted=False)[1]
# Gather weights from original sigmoid scores (not bias-corrected)
topk_weights = router_probs.gather(1, topk_indices)
# Optional renormalization + scaling
if getattr(moe_block, "norm_topk_prob", True):
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
topk_weights = topk_weights * routed_scaling_factor
return topk_weights, topk_indices, top_k, num_experts
def _route(moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta):
"""Dispatch to the correct routing strategy based on block attributes.
Detects sigmoid routing by the presence of ``e_score_correction_bias``
on either the gate or the moe_block.
"""
has_sigmoid = (
getattr(base_gate, "e_score_correction_bias", None) is not None
or getattr(moe_block, "e_score_correction_bias", None) is not None
)
if has_sigmoid:
return _sigmoid_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
)
return _softmax_topk_route(
moe_block, base_gate, hidden_states, gate_weight, gate_lora_delta
)
# =============================================================================
# Shared expert helpers
# =============================================================================
def _compute_shared_expert(moe_block, hidden_states_flat):
"""Compute shared expert output if the block has one.
Handles singular (qwen2_moe: ``shared_expert``), plural
(glm_moe_dsa/deepseek_v3: ``shared_experts``), and MLP
(hunyuan_v1_moe: ``shared_mlp``) attribute names.
peft wraps individual linear layers inside the shared expert with
standard LoRA — calling forward() handles this transparently.
"""
shared_expert = (
getattr(moe_block, "shared_expert", None)
or getattr(moe_block, "shared_experts", None)
or getattr(moe_block, "shared_mlp", None)
)
if shared_expert is None:
return None
shared_expert_output = shared_expert(hidden_states_flat)
# Optional sigmoid gate (Qwen2MoE pattern).
# shared_expert_gate may also be peft-wrapped (standard LoRA
# on nn.Linear); its forward() applies LoRA automatically.
shared_expert_gate = getattr(moe_block, "shared_expert_gate", None)
if shared_expert_gate is not None:
shared_expert_output = (
F.sigmoid(shared_expert_gate(hidden_states_flat)) * shared_expert_output
)
return shared_expert_output
# =============================================================================
# Layer classes
# =============================================================================
@@ -281,16 +433,18 @@ class ScatterMoEGatedMLP(nn.Module):
class HFScatterMoEGatedMLP(nn.Module):
"""
-ScatterMoE-accelerated forward pass for HF MoEs (OLMoE / Qwen2MoE).
+ScatterMoE-accelerated forward pass for HF MoEs.
Used as a kernel layer via the HF ``kernels`` library. The ``forward``
-method replaces the original ``OlmoeSparseMoeBlock.forward``.
+method replaces the original SparseMoeBlock.forward.
-Supports both full-parameter training and LoRA fine-tuning:
+Supports:
-* **Full-param**: uses ``parallel_linear`` (base ScatterMoE kernel)
-* **LoRA**: detects peft ``ParamWrapper`` on ``self.experts``, extracts
-adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
+* **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
+* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
+* **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
+* **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
+extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
"""
@staticmethod
@@ -302,7 +456,7 @@ class HFScatterMoEGatedMLP(nn.Module):
self: The MoeSparseMoeBlock module containing:
- self.gate: Router (or peft ParamWrapper wrapping it)
- self.experts: Experts module (or peft ParamWrapper chain)
-- self.shared_expert: Optional shared expert (e.g. Qwen2MoE)
+- self.shared_expert(s): Optional shared expert
- self.shared_expert_gate: Optional shared expert gate
layer_input: Input tensor [batch_size, seq_len, hidden_size]
@@ -313,38 +467,17 @@ class HFScatterMoEGatedMLP(nn.Module):
hidden_states_flat = layer_input.view(-1, hidden_dim)
# ====================================================================
-# Shared Expert (if present, e.g. Qwen2MoE)
+# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
# ====================================================================
-# peft wraps individual linear layers inside shared_expert with
-# standard LoRA — calling forward() handles this transparently.
-if hasattr(self, "shared_expert") and self.shared_expert is not None:
-shared_expert_output = self.shared_expert(hidden_states_flat)
-# shared_expert_gate may also be peft-wrapped (standard LoRA
-# on nn.Linear), its forward() applies LoRA automatically.
-shared_expert_gate_output = F.sigmoid(
-self.shared_expert_gate(hidden_states_flat)
-)
-shared_expert_output = shared_expert_output * shared_expert_gate_output
-else:
-shared_expert_output = None
+shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
# ====================================================================
# Router Computation (with optional gate LoRA)
# ====================================================================
base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate)
-router_logits = F.linear(hidden_states_flat, gate_weight)
-if gate_lora_delta is not None:
-router_logits = router_logits + F.linear(
-hidden_states_flat, gate_lora_delta
-)
-routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-top_k = base_gate.top_k
-num_experts = base_gate.num_experts
-routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
-if base_gate.norm_topk_prob:
-routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+routing_weights, selected_experts, top_k, num_experts = _route(
+self, base_gate, hidden_states_flat, gate_weight, gate_lora_delta
+)
routing_weights = routing_weights.to(hidden_states_flat.dtype)
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
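As an aside to the routing helpers above, here is a minimal standalone sketch of the group-based expert selection performed in ``_sigmoid_topk_route``; the sizes are illustrative and not taken from any model config:

    import torch

    # Illustrative sizes; mirrors the n_group masking in _sigmoid_topk_route.
    T, E, n_group, topk_group, top_k = 2, 8, 2, 1, 2
    scores_for_choice = torch.rand(T, E)  # sigmoid probs, optionally bias-corrected

    # Score each group by the sum of its top-2 experts, keep the best topk_group groups.
    group_scores = (
        scores_for_choice.view(-1, n_group, E // n_group).topk(2, dim=-1)[0].sum(dim=-1)
    )
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)

    # Zero out experts in unselected groups, then take the final top-k.
    score_mask = group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
    masked_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
    topk_indices = torch.topk(masked_scores, k=top_k, dim=-1, sorted=False)[1]
    # Every index in topk_indices now falls inside one of the selected groups.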


@@ -110,6 +110,16 @@ class KernelsPlugin(BasePlugin):
}
)
def add_callbacks_pre_trainer(self, cfg, model):
callbacks = []
if cfg.use_scattermoe:
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
callbacks.append(AutotuneReportCallback())
return callbacks
def _kernelize_model(self, model_type: str):
from kernels import replace_kernel_forward_from_hub


@@ -12,6 +12,7 @@ Tests verify correctness of:
3. Frozen weights: expert weight gradients are correctly skipped
4. Various configurations: top-k, grouped_in/out, with/without bias
5. Numerical stability: bf16/fp16 outputs within tolerance of fp32 reference
6. HFScatterMoEGatedMLP with sigmoid routing (GLM/DeepSeek/MiniMax M2)
Test strategy:
- Reference implementation uses pure PyTorch ops (no Triton)
@@ -19,6 +20,8 @@ Test strategy:
- Tolerances account for tf32 accumulation in Triton kernels
"""
from types import SimpleNamespace
import pytest
import torch
@@ -1476,3 +1479,347 @@ class TestCombinedOptimizations:
torch.testing.assert_close(X1.grad, X2.grad, atol=5e-2, rtol=5e-2)
torch.testing.assert_close(A1.grad, A2.grad, atol=5e-2, rtol=5e-2)
torch.testing.assert_close(B1.grad, B2.grad, atol=5e-2, rtol=5e-2)
# =============================================================================
# Test: HFScatterMoEGatedMLP with Sigmoid Routing
# =============================================================================
def _reference_moe_forward(
hidden_states,
gate_weight,
gate_up_proj,
down_proj,
act_fn,
routing_weights,
selected_experts,
num_experts,
):
"""Pure PyTorch reference for a full MoE forward pass.
Args:
hidden_states: [T, H]
gate_weight: [E, H]
gate_up_proj: [E, 2*FF, H]
down_proj: [E, H, FF]
act_fn: activation function (e.g. torch.nn.SiLU())
routing_weights: [T, K] routing weights
selected_experts: [T, K] expert indices
num_experts: int
Returns:
output: [T, H]
"""
T, H = hidden_states.shape
K = selected_experts.shape[1]
output = torch.zeros(T, H, device=hidden_states.device, dtype=hidden_states.dtype)
for t in range(T):
for j in range(K):
e = selected_experts[t, j].item()
w = routing_weights[t, j].item()
# gate_up projection
gup = hidden_states[t] @ gate_up_proj[e].T # [2*I]
I_dim = gup.shape[0] // 2
gates = gup[:I_dim]
up = gup[I_dim:]
# activation
h = act_fn(gates) * up
# down projection
out = h @ down_proj[e].T # [H]
output[t] += w * out
return output
def _make_mock_sigmoid_moe_block(
T=16, H=64, FF=32, E=8, K=2, n_group=2, topk_group=1, bias_on_gate=True
):
"""Create a mock MoE block with sigmoid routing for GPU testing."""
gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02
down_proj = torch.randn(E, H, FF, device="cuda") * 0.02
act_fn = torch.nn.SiLU()
experts = SimpleNamespace(
gate_up_proj=gate_up_proj,
down_proj=down_proj,
act_fn=act_fn,
num_experts=E,
)
if bias_on_gate:
gate = SimpleNamespace(
weight=torch.randn(E, H, device="cuda") * 0.1,
e_score_correction_bias=torch.zeros(E, device="cuda"),
)
moe_block = SimpleNamespace(
gate=gate,
experts=experts,
top_k=K,
n_routed_experts=E,
n_group=n_group,
topk_group=topk_group,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
else:
# minimax_m2 style
gate = SimpleNamespace(
weight=torch.randn(E, H, device="cuda") * 0.1,
top_k=K,
)
moe_block = SimpleNamespace(
gate=gate,
experts=experts,
top_k=K,
e_score_correction_bias=torch.zeros(E, device="cuda"),
)
return moe_block, T, H, FF, E, K
class TestHFScatterMoESigmoidRouting:
"""Test HFScatterMoEGatedMLP forward with sigmoid routing on GPU."""
def test_forward_matches_reference_bias_on_gate(self):
"""Forward pass with sigmoid routing (bias on gate) matches reference."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
HFScatterMoEGatedMLP,
_sigmoid_topk_route,
)
moe_block, T, H, FF, E, K = _make_mock_sigmoid_moe_block(
T=16, H=64, FF=32, E=8, K=2, n_group=2, topk_group=1, bias_on_gate=True
)
hidden = torch.randn(1, T, H, device="cuda")
# Get routing for reference
gate = moe_block.gate
hidden_flat = hidden.view(-1, H)
routing_weights, selected_experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden_flat, gate.weight, None
)
# Reference output
ref_output = _reference_moe_forward(
hidden_flat,
gate.weight,
moe_block.experts.gate_up_proj,
moe_block.experts.down_proj,
moe_block.experts.act_fn,
routing_weights,
selected_experts,
E,
)
# Kernel output
kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
kernel_output_flat = kernel_output.view(-1, H)
torch.testing.assert_close(
kernel_output_flat.float(),
ref_output.float(),
atol=5e-2,
rtol=5e-2,
)
def test_forward_matches_reference_bias_on_block(self):
"""Forward pass with sigmoid routing (minimax_m2 style, bias on block)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
HFScatterMoEGatedMLP,
_sigmoid_topk_route,
)
moe_block, T, H, FF, E, K = _make_mock_sigmoid_moe_block(
T=16, H=64, FF=32, E=8, K=2, n_group=1, bias_on_gate=False
)
hidden = torch.randn(1, T, H, device="cuda")
hidden_flat = hidden.view(-1, H)
gate = moe_block.gate
routing_weights, selected_experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden_flat, gate.weight, None
)
ref_output = _reference_moe_forward(
hidden_flat,
gate.weight,
moe_block.experts.gate_up_proj,
moe_block.experts.down_proj,
moe_block.experts.act_fn,
routing_weights,
selected_experts,
E,
)
kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
kernel_output_flat = kernel_output.view(-1, H)
torch.testing.assert_close(
kernel_output_flat.float(),
ref_output.float(),
atol=5e-2,
rtol=5e-2,
)
def test_softmax_routing_still_works(self):
"""Verify softmax routing (Qwen/OLMoE) is not broken."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
HFScatterMoEGatedMLP,
_softmax_topk_route,
)
T, H, FF, E, K = 16, 64, 32, 4, 2
gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02
down_proj = torch.randn(E, H, FF, device="cuda") * 0.02
act_fn = torch.nn.SiLU()
experts = SimpleNamespace(
gate_up_proj=gate_up_proj,
down_proj=down_proj,
act_fn=act_fn,
num_experts=E,
)
gate = SimpleNamespace(
weight=torch.randn(E, H, device="cuda") * 0.1,
top_k=K,
num_experts=E,
norm_topk_prob=True,
)
moe_block = SimpleNamespace(gate=gate, experts=experts)
hidden = torch.randn(1, T, H, device="cuda")
hidden_flat = hidden.view(-1, H)
routing_weights, selected_experts, _, _ = _softmax_topk_route(
moe_block, gate, hidden_flat, gate.weight, None
)
ref_output = _reference_moe_forward(
hidden_flat,
gate.weight,
gate_up_proj,
down_proj,
act_fn,
routing_weights,
selected_experts,
E,
)
kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
kernel_output_flat = kernel_output.view(-1, H)
torch.testing.assert_close(
kernel_output_flat.float(),
ref_output.float(),
atol=5e-2,
rtol=5e-2,
)
class TestHFScatterMoESigmoidWithSharedExperts:
"""Test HFScatterMoEGatedMLP with sigmoid routing + shared experts."""
def test_shared_experts_plural(self):
"""DeepSeek V3 style: shared_experts attribute (plural)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
HFScatterMoEGatedMLP,
)
T, H, FF, E, K = 8, 64, 32, 8, 2
gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02
down_proj = torch.randn(E, H, FF, device="cuda") * 0.02
act_fn = torch.nn.SiLU()
experts = SimpleNamespace(
gate_up_proj=gate_up_proj,
down_proj=down_proj,
act_fn=act_fn,
num_experts=E,
)
# Shared expert as a simple linear for testing
shared_W = torch.randn(H, H, device="cuda") * 0.01
shared_experts_fn = lambda x: x @ shared_W.T # noqa: E731
gate = SimpleNamespace(
weight=torch.randn(E, H, device="cuda") * 0.1,
e_score_correction_bias=torch.zeros(E, device="cuda"),
)
moe_block = SimpleNamespace(
gate=gate,
experts=experts,
shared_experts=shared_experts_fn,
top_k=K,
n_routed_experts=E,
n_group=1,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
hidden = torch.randn(1, T, H, device="cuda")
# Should not raise; output should include shared expert contribution
output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
assert output.shape == (1, T, H)
# Run without shared expert to verify it changes the output
moe_block_no_shared = SimpleNamespace(
gate=gate,
experts=experts,
top_k=K,
n_routed_experts=E,
n_group=1,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
output_no_shared = HFScatterMoEGatedMLP.forward(moe_block_no_shared, hidden)
assert not torch.equal(output, output_no_shared)
def test_shared_expert_with_gate(self):
"""Qwen2MoE style: shared_expert + shared_expert_gate."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
HFScatterMoEGatedMLP,
)
T, H, FF, E, K = 8, 64, 32, 4, 2
gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02
down_proj = torch.randn(E, H, FF, device="cuda") * 0.02
act_fn = torch.nn.SiLU()
experts = SimpleNamespace(
gate_up_proj=gate_up_proj,
down_proj=down_proj,
act_fn=act_fn,
num_experts=E,
)
shared_W = torch.randn(H, H, device="cuda") * 0.01
shared_expert_fn = lambda x: x @ shared_W.T # noqa: E731
# Gate that returns 0 -> sigmoid(0) = 0.5
gate_W = torch.zeros(H, H, device="cuda")
shared_expert_gate_fn = lambda x: x @ gate_W.T # noqa: E731
gate = SimpleNamespace(
weight=torch.randn(E, H, device="cuda") * 0.1,
top_k=K,
num_experts=E,
norm_topk_prob=True,
)
moe_block = SimpleNamespace(
gate=gate,
experts=experts,
shared_expert=shared_expert_fn,
shared_expert_gate=shared_expert_gate_fn,
)
hidden = torch.randn(1, T, H, device="cuda")
output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
assert output.shape == (1, T, H)


@@ -0,0 +1,474 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
Parity tests between scattermoe-lora and sonicmoe routing implementations.
These tests verify that both implementations produce numerically identical
results for the same inputs, ensuring safe centralization of the routing code.
ScatterMoE returns 2D tensors [T, K]; SonicMoE returns flattened 1D [T*K].
The core algorithm should be identical — only the output format differs.
"""
from types import SimpleNamespace
import pytest
import torch
def _require_triton():
pytest.importorskip("triton")
# ============================================================================
# Fixtures / helpers
# ============================================================================
def _make_softmax_block(T=8, H=16, E=4, K=2):
"""Qwen/OLMoE-style block usable by both implementations."""
gate = SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
num_experts=E,
norm_topk_prob=True,
)
moe_block = SimpleNamespace(gate=gate)
hidden = torch.randn(T, H)
return moe_block, gate, hidden, T, H, E, K
def _make_sigmoid_block(
T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
):
"""GLM/DeepSeek-style block usable by both implementations."""
if bias_on_gate:
gate = SimpleNamespace(
weight=torch.randn(E, H),
e_score_correction_bias=torch.zeros(E),
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
n_routed_experts=E,
n_group=n_group,
topk_group=topk_group,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
else:
# minimax_m2 style: bias on block
gate = SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
e_score_correction_bias=torch.zeros(E),
)
return moe_block, gate, hidden_states(T, H), T, H, E, K
def hidden_states(T, H):
return torch.randn(T, H)
# ============================================================================
# 1. Softmax routing parity
# ============================================================================
class TestSoftmaxRoutingParity:
"""Verify scattermoe and sonicmoe softmax routing produce identical results."""
@pytest.fixture(autouse=True)
def _require(self):
_require_triton()
def test_weights_match(self):
"""2D weights from scattermoe == reshaped 1D weights from sonicmoe."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
# ScatterMoE path (no LoRA delta)
sm_weights, sm_experts, sm_topk, sm_E = _softmax_topk_route(
moe_block, gate, hidden, gate.weight, None
)
# SonicMoE path
sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = softmax_topk_routing(
hidden, moe_block
)
# ScatterMoE returns [T, K], SonicMoE returns [T*K] flattened
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
assert sm_topk == K
assert sm_E == E
# Both should select the same experts and produce the same weights
assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype))
assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)
def test_logits_not_returned_by_scattermoe(self):
"""ScatterMoE doesn't return logits; SonicMoE does — verify SonicMoE logits shape."""
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
_, _, _, logits = softmax_topk_routing(hidden, moe_block)
assert logits.shape == (T, E)
def test_no_renorm(self):
"""With norm_topk_prob=False, both should skip renormalization."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
gate.norm_topk_prob = False
sm_weights, sm_experts, _, _ = _softmax_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype))
assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)
def test_various_expert_counts(self):
"""Parity across different E and K values."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
for E, K in [(2, 1), (8, 2), (16, 4), (32, 8)]:
moe_block, gate, hidden, T, H, _, _ = _make_softmax_block(E=E, K=K)
sm_weights, sm_experts, _, _ = _softmax_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype)), (
f"Expert mismatch for E={E}, K={K}"
)
assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6), (
f"Weight mismatch for E={E}, K={K}"
)
# ============================================================================
# 2. Sigmoid routing parity
# ============================================================================
class TestSigmoidRoutingParity:
"""Verify scattermoe and sonicmoe sigmoid routing produce identical results."""
@pytest.fixture(autouse=True)
def _require(self):
_require_triton()
def test_weights_match_with_groups(self):
"""Both implementations should produce identical weights with group selection."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
)
sm_weights, sm_experts, sm_topk, sm_E = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = sigmoid_topk_routing(
hidden, moe_block
)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
assert sm_topk == K
assert sm_E == E
# Sort experts within each token to handle different topk orderings
sm_sorted, sm_order = sm_experts.sort(dim=-1)
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
assert torch.equal(sm_sorted, sonic_sorted)
# Gather weights in sorted order for comparison
sm_weights_sorted = sm_weights.gather(1, sm_order)
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
def test_weights_match_no_groups(self):
"""Both implementations match without group selection (n_group=1)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
E=16, K=4, n_group=1, topk_group=1, bias_on_gate=True
)
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
# Sort for comparison (topk with sorted=False may differ in order)
sm_sorted, sm_order = sm_experts.sort(dim=-1)
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
assert torch.equal(sm_sorted, sonic_sorted)
sm_weights_sorted = sm_weights.gather(1, sm_order)
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
def test_bias_on_block_parity(self):
"""minimax_m2 style: bias on block, not gate."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
E=16, K=4, n_group=1, bias_on_gate=False
)
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
sm_sorted, sm_order = sm_experts.sort(dim=-1)
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
assert torch.equal(sm_sorted, sonic_sorted)
sm_weights_sorted = sm_weights.gather(1, sm_order)
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
def test_scaling_factor_parity(self):
"""routed_scaling_factor applied identically by both."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
n_group=1, bias_on_gate=True
)
moe_block.routed_scaling_factor = 2.5
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
sm_sorted, sm_order = sm_experts.sort(dim=-1)
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
assert torch.equal(sm_sorted, sonic_sorted)
sm_weights_sorted = sm_weights.gather(1, sm_order)
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
def test_no_renorm_parity(self):
"""norm_topk_prob=False produces same results in both."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
n_group=1, bias_on_gate=True
)
moe_block.norm_topk_prob = False
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
sm_sorted, sm_order = sm_experts.sort(dim=-1)
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
assert torch.equal(sm_sorted, sonic_sorted)
sm_weights_sorted = sm_weights.gather(1, sm_order)
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
# ============================================================================
# 3. Shared expert parity
# ============================================================================
class TestSharedExpertParity:
"""Verify both _compute_shared_expert implementations behave identically."""
@pytest.fixture(autouse=True)
def _require(self):
_require_triton()
def _get_both_fns(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert as scatter_compute,
)
from axolotl.integrations.kernels.sonicmoe.patch import (
_compute_shared_expert as sonic_compute,
)
return scatter_compute, sonic_compute
def test_shared_expert_singular(self):
scatter_fn, sonic_fn = self._get_both_fns()
out = torch.randn(4, 8)
block = SimpleNamespace(shared_expert=lambda x: out)
hidden = torch.randn(4, 8)
assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))
def test_shared_experts_plural(self):
scatter_fn, sonic_fn = self._get_both_fns()
out = torch.randn(4, 8)
block = SimpleNamespace(shared_experts=lambda x: out)
hidden = torch.randn(4, 8)
assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))
def test_shared_mlp(self):
scatter_fn, sonic_fn = self._get_both_fns()
out = torch.randn(4, 8)
block = SimpleNamespace(shared_mlp=lambda x: out)
hidden = torch.randn(4, 8)
assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))
def test_no_shared_expert(self):
scatter_fn, sonic_fn = self._get_both_fns()
block = SimpleNamespace()
hidden = torch.randn(4, 8)
assert scatter_fn(block, hidden) is None
assert sonic_fn(block, hidden) is None
def test_shared_expert_gate_only_in_scattermoe(self):
"""ScatterMoE's _compute_shared_expert handles shared_expert_gate;
SonicMoE's patch.py handles it externally in the forward function.
This documents the known divergence: the scattermoe version applies
sigmoid gating inline, while sonicmoe applies it in the forward.
"""
scatter_fn, sonic_fn = self._get_both_fns()
H = 8
expert_out = torch.ones(4, H)
gate_fn = lambda x: torch.zeros(4, H) # noqa: E731 # sigmoid(0) = 0.5
block = SimpleNamespace(
shared_expert=lambda x: expert_out,
shared_expert_gate=gate_fn,
)
hidden = torch.randn(4, H)
scatter_result = scatter_fn(block, hidden)
sonic_result = sonic_fn(block, hidden)
# ScatterMoE applies the gate: expert_out * sigmoid(0) = 0.5
expected_gated = expert_out * 0.5
assert torch.allclose(scatter_result, expected_gated, atol=1e-6)
# SonicMoE does NOT apply the gate here (it does it in the forward)
assert torch.equal(sonic_result, expert_out)
# ============================================================================
# 4. Route dispatcher parity
# ============================================================================
class TestRouteDispatcherParity:
"""Verify _route in scattermoe dispatches correctly and matches individual fns."""
@pytest.fixture(autouse=True)
def _require(self):
_require_triton()
def test_route_dispatches_softmax(self):
"""_route should use softmax when no e_score_correction_bias."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_route,
_softmax_topk_route,
)
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
route_w, route_e, route_k, route_E = _route(
moe_block, gate, hidden, gate.weight, None
)
direct_w, direct_e, direct_k, direct_E = _softmax_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert torch.equal(route_w, direct_w)
assert torch.equal(route_e, direct_e)
assert route_k == direct_k
assert route_E == direct_E
def test_route_dispatches_sigmoid(self):
"""_route should use sigmoid when e_score_correction_bias is present."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_route,
_sigmoid_topk_route,
)
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
n_group=1, bias_on_gate=True
)
route_w, route_e, route_k, route_E = _route(
moe_block, gate, hidden, gate.weight, None
)
direct_w, direct_e, direct_k, direct_E = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert torch.equal(route_w, direct_w)
assert torch.equal(route_e, direct_e)
assert route_k == direct_k
assert route_E == direct_E


@@ -0,0 +1,367 @@
"""Tests for scattermoe autotune telemetry integration.
These tests use mocking to verify the collection and reporting logic
without requiring Triton or CUDA.
"""
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Simulate the hash-suffixed module name that LocalLayerRepository creates.
_FAKE_MODULE_NAME = "scattermoe_lora_abc123.kernels.lora_ops"
def _make_mock_config(kwargs, num_warps=4, num_stages=3):
"""Create a mock triton.Config-like object."""
return SimpleNamespace(kwargs=kwargs, num_warps=num_warps, num_stages=num_stages)
def _make_mock_kernel(cache=None):
"""Create a mock autotuned kernel object with a ``.cache`` dict."""
kernel = SimpleNamespace()
kernel.cache = cache if cache is not None else {}
return kernel
def _make_mock_lora_ops(
fwd_cache=None, dx_cache=None, bwd_cache=None, fused_cache=None
):
"""Build a mock ``lora_ops`` module with the four kernel attributes."""
mod = SimpleNamespace(
_scatter2scatter_lora=_make_mock_kernel(fwd_cache),
_scatter2scatter_lora_dX=_make_mock_kernel(dx_cache),
_group_bwd_lora=_make_mock_kernel(bwd_cache),
_group_bwd_lora_fused=_make_mock_kernel(fused_cache),
)
return mod
# =========================================================================
# TestAutotuneCollector
# =========================================================================
class TestAutotuneCollector:
"""Test ``collect_autotune_configs`` with mocked kernel objects."""
def test_empty_cache_returns_empty_list(self):
"""When no kernel has been autotuned yet, return ``[]``."""
mock_lora_ops = _make_mock_lora_ops()
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
result = collect_autotune_configs()
assert result == []
def test_populated_cache_returns_configs(self):
"""When a cache entry exists, it appears in the output."""
cfg = _make_mock_config(
{"BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4
)
mock_lora_ops = _make_mock_lora_ops(fwd_cache={(2048, 4096, 1024): cfg})
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
result = collect_autotune_configs()
assert len(result) == 1
entry = result[0]
assert entry["kernel"] == "scatter2scatter_lora_fwd"
assert entry["key"] == {"M": 2048, "N": 4096, "K": 1024}
assert entry["config"]["BLOCK_N"] == 128
assert entry["config"]["BLOCK_K"] == 64
assert entry["config"]["num_warps"] == 8
assert entry["config"]["num_stages"] == 4
def test_multiple_kernels_and_keys(self):
"""Multiple cache entries across kernels are all returned."""
cfg_fwd = _make_mock_config({"BLOCK_N": 128, "BLOCK_K": 32})
cfg_dx = _make_mock_config({"BLOCK_K": 64, "BLOCK_N": 128}, num_warps=8)
mock_lora_ops = _make_mock_lora_ops(
fwd_cache={(16, 256, 128): cfg_fwd},
dx_cache={(16, 256, 128): cfg_dx},
)
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
result = collect_autotune_configs()
assert len(result) == 2
names = {r["kernel"] for r in result}
assert "scatter2scatter_lora_fwd" in names
assert "scatter2scatter_lora_dX" in names
def test_extra_key_elements_stored(self):
"""Dtype or other extra elements in the cache key are captured."""
cfg = _make_mock_config({"BLOCK_N": 64, "BLOCK_K": 32})
cache_key = (512, 1024, 256, "float16", "float16")
mock_lora_ops = _make_mock_lora_ops(fwd_cache={cache_key: cfg})
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
result = collect_autotune_configs()
assert len(result) == 1
key = result[0]["key"]
assert key["M"] == 512
assert key["N"] == 1024
assert key["K"] == 256
assert key["_extra"] == ["float16", "float16"]
def test_no_module_in_sys_modules_returns_empty(self):
"""If no lora_ops module is loaded, return ``[]``."""
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
# Don't inject anything — the real lora_ops isn't loaded either
# (no triton on this machine), so _find_lora_ops_module returns None.
result = collect_autotune_configs()
assert result == []
def test_finds_module_under_hash_suffixed_name(self):
"""Collector finds lora_ops regardless of the hash suffix."""
cfg = _make_mock_config({"BLOCK_N": 256, "BLOCK_K": 128})
mock_lora_ops = _make_mock_lora_ops(fwd_cache={(8, 512, 64): cfg})
# Use a different hash to prove it's not hardcoded.
alt_name = "scattermoe_lora_deadbeef.kernels.lora_ops"
with patch.dict(sys.modules, {alt_name: mock_lora_ops}):
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
result = collect_autotune_configs()
assert len(result) == 1
assert result[0]["config"]["BLOCK_N"] == 256
# =========================================================================
# TestAutotuneReportCallback
# =========================================================================
class TestAutotuneReportCallback:
"""Test the callback fires once and sends the correct event."""
def test_reports_once_on_first_step(self):
"""Callback should call ``send_event`` exactly once."""
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
cb = AutotuneReportCallback()
mock_state = MagicMock()
mock_state.global_step = 1
fake_configs = [{"kernel": "test_fwd", "key": {}, "config": {}}]
with (
patch(
"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
return_value=fake_configs,
),
patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
):
mock_tm = MagicMock()
mock_tm.enabled = True
mock_tm_cls.get_instance.return_value = mock_tm
cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
assert mock_tm.send_event.call_count == 1
call_kwargs = mock_tm.send_event.call_args[1]
assert call_kwargs["event_type"] == "scattermoe-autotune"
assert call_kwargs["properties"]["kernel_count"] == 1
# Second call should NOT send again.
cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
assert mock_tm.send_event.call_count == 1
def test_retries_until_step_5_then_gives_up(self):
"""If no configs found by step 5, stop retrying."""
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
cb = AutotuneReportCallback()
with patch(
"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
return_value=[],
):
for step in range(1, 7):
mock_state = MagicMock()
mock_state.global_step = step
cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
assert cb._reported is True
def test_reports_on_retry_when_data_arrives(self):
"""If step 1 has no data but step 2 does, report at step 2."""
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
cb = AutotuneReportCallback()
fake_configs = [{"kernel": "fwd", "key": {}, "config": {}}]
call_count = 0
def _collector():
nonlocal call_count
call_count += 1
if call_count == 1:
return []
return fake_configs
with (
patch(
"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
side_effect=_collector,
),
patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
):
mock_tm = MagicMock()
mock_tm.enabled = True
mock_tm_cls.get_instance.return_value = mock_tm
# Step 1 — empty, no report
s1 = MagicMock()
s1.global_step = 1
cb.on_step_end(args=MagicMock(), state=s1, control=MagicMock())
assert mock_tm.send_event.call_count == 0
# Step 2 — data arrives, report
s2 = MagicMock()
s2.global_step = 2
cb.on_step_end(args=MagicMock(), state=s2, control=MagicMock())
assert mock_tm.send_event.call_count == 1
def test_includes_gpu_info(self):
"""Event properties should include GPU identification."""
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
cb = AutotuneReportCallback()
mock_state = MagicMock()
mock_state.global_step = 1
fake_configs = [{"kernel": "fwd", "key": {}, "config": {}}]
fake_gpu = {
"gpu_name": "NVIDIA H100",
"gpu_compute_capability": "9.0",
"gpu_memory_bytes": 85899345920,
}
fake_smem = {"smem_capacity_bytes": 233472}
with (
patch(
"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
return_value=fake_configs,
),
patch(
"axolotl.integrations.kernels.autotune_callback._get_gpu_info",
return_value=fake_gpu,
),
patch(
"axolotl.integrations.kernels.autotune_callback._get_smem_capacity",
return_value=fake_smem,
),
patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
):
mock_tm = MagicMock()
mock_tm.enabled = True
mock_tm_cls.get_instance.return_value = mock_tm
cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
props = mock_tm.send_event.call_args[1]["properties"]
assert props["gpu_name"] == "NVIDIA H100"
assert props["gpu_compute_capability"] == "9.0"
assert props["gpu_memory_bytes"] == 85899345920
assert props["smem_capacity_bytes"] == 233472
def test_skips_send_when_telemetry_disabled(self):
"""If telemetry is disabled, no event is sent."""
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
cb = AutotuneReportCallback()
mock_state = MagicMock()
mock_state.global_step = 1
with (
patch(
"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
return_value=[{"kernel": "fwd", "key": {}, "config": {}}],
),
patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
):
mock_tm = MagicMock()
mock_tm.enabled = False
mock_tm_cls.get_instance.return_value = mock_tm
cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
assert mock_tm.send_event.call_count == 0
# Should still mark as reported so we don't retry.
assert cb._reported is True
# =========================================================================
# TestKernelsPluginCallbackRegistration
# =========================================================================
class TestKernelsPluginCallbackRegistration:
"""Test that ``KernelsPlugin`` registers the callback correctly."""
def test_scattermoe_registers_callback(self):
"""When ``use_scattermoe=True``, plugin returns the callback."""
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
from axolotl.integrations.kernels.plugin import KernelsPlugin
plugin = KernelsPlugin()
cfg = MagicMock()
cfg.use_scattermoe = True
model = MagicMock()
callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
assert len(callbacks) == 1
assert isinstance(callbacks[0], AutotuneReportCallback)
def test_no_scattermoe_no_callback(self):
"""When ``use_scattermoe=False``, plugin returns empty list."""
from axolotl.integrations.kernels.plugin import KernelsPlugin
plugin = KernelsPlugin()
cfg = MagicMock()
cfg.use_scattermoe = False
model = MagicMock()
callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
assert callbacks == []


@@ -3,17 +3,19 @@
# Licensed under the Apache License, Version 2.0
"""
-Unit tests for scattermoe-lora code-review fixes.
+Unit tests for scattermoe-lora.
Tests cover:
- KernelsArgs validator: disable_mlp_kernel
- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
- ParallelExperts: scaling=0.0 not treated as falsy
- single2scatter: non-aligned K/N dimensions
- group_compileable: coeff=None accepted
- HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract
+- Routing strategy detection and sigmoid routing
+- Generic shared expert handling
"""
from types import SimpleNamespace
from unittest.mock import patch
import pytest
@@ -321,3 +323,389 @@ class TestLayerReturnValues:
assert "Router logits" not in docstring, (
"Docstring should not mention 'Router logits' in Returns section"
)
# ============================================================================
# 7. Routing strategy detection and sigmoid routing
# ============================================================================
def _make_softmax_gate(E=4, H=16, K=2):
"""Create a mock softmax-style gate (Qwen/OLMoE)."""
return SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
num_experts=E,
norm_topk_prob=True,
)
def _make_sigmoid_gate_with_bias(E=16, H=16):
"""Create a mock sigmoid-style gate with e_score_correction_bias on gate."""
return SimpleNamespace(
weight=torch.randn(E, H),
e_score_correction_bias=torch.zeros(E),
)
def _make_sigmoid_moe_block(
T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
):
"""Create a mock GLM/DeepSeek-style MoE block for sigmoid routing tests."""
if bias_on_gate:
gate = SimpleNamespace(
weight=torch.randn(E, H),
e_score_correction_bias=torch.zeros(E),
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
n_routed_experts=E,
n_group=n_group,
topk_group=topk_group,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
else:
# minimax_m2 style: bias on block, not gate
gate = SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
e_score_correction_bias=torch.zeros(E),
)
return moe_block, T, H, E, K
def _skip_without_triton():
pytest.importorskip("triton")
class TestSigmoidRoutingInScatterMoE:
"""Test _sigmoid_topk_route from layers.py."""
@pytest.fixture(autouse=True)
def _require_triton(self):
_skip_without_triton()
def test_output_shapes(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block()
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, experts, top_k, num_experts = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
assert experts.shape == (T, K)
assert top_k == K
assert num_experts == E
def test_weights_nonnegative(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block()
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, _, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert (weights >= 0).all()
def test_group_selection_restricts_experts(self):
"""With n_group=4, topk_group=1, experts should be from selected groups."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block(
E=16, K=2, n_group=4, topk_group=1
)
gate = moe_block.gate
hidden = torch.randn(T, H)
_, expert_idx, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
# Each token's experts should fall within a single group (size E//n_group=4)
for t in range(T):
experts_t = expert_idx[t]
groups = experts_t // (E // moe_block.n_group)
assert (groups == groups[0]).all()
def test_scaling_factor_applied(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block(n_group=1)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights_1x, _, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
moe_block.routed_scaling_factor = 2.0
weights_2x, _, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert torch.allclose(weights_2x, weights_1x * 2.0, atol=1e-5)
def test_bias_on_gate(self):
"""e_score_correction_bias on gate is found."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=True)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
def test_bias_on_block(self):
"""e_score_correction_bias on moe_block (not gate) is found."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=False)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
def test_gate_lora_delta_applied(self):
"""Gate LoRA delta should affect routing logits."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block(n_group=1)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights_no_lora, _, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
# Large delta should change the results
delta = torch.randn(E, H) * 10.0
weights_with_lora, _, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, delta
)
assert not torch.equal(weights_no_lora, weights_with_lora)
def test_no_bias_does_not_crash(self):
"""Calling _sigmoid_topk_route with no e_score_correction_bias should not crash."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
T, H, E, K = 8, 16, 8, 2
gate = SimpleNamespace(weight=torch.randn(E, H))
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
n_routed_experts=E,
n_group=1,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
hidden = torch.randn(T, H)
weights, experts, top_k, num_experts = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
assert experts.shape == (T, K)
# Without bias, scores_for_choice == sigmoid(logits) — all positive
assert (weights >= 0).all()
def test_missing_topk_group_defaults_to_n_group(self):
"""When topk_group is absent but n_group > 1, should default to n_group (no-op masking)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
T, H, E, K, n_group = 8, 16, 16, 2, 4
gate = SimpleNamespace(
weight=torch.randn(E, H),
e_score_correction_bias=torch.zeros(E),
)
# Intentionally omit topk_group
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
n_routed_experts=E,
n_group=n_group,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
hidden = torch.randn(T, H)
# Should not raise AttributeError; defaults topk_group to n_group
weights, experts, top_k_out, num_experts = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
assert experts.shape == (T, K)
class TestRoutingStrategyDetection:
"""Test that _route dispatches to the correct strategy."""
@pytest.fixture(autouse=True)
def _require_triton(self):
_skip_without_triton()
def test_softmax_for_qwen_style(self):
"""Block without e_score_correction_bias should use softmax."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route
gate = _make_softmax_gate(E=4, H=16, K=2)
moe_block = SimpleNamespace(gate=gate)
hidden = torch.randn(8, 16)
weights, experts, top_k, num_experts = _route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (8, 2)
assert experts.shape == (8, 2)
assert top_k == 2
assert num_experts == 4
per_token_sums = weights.sum(dim=-1)
assert torch.allclose(per_token_sums, torch.ones(8), atol=1e-5)
def test_sigmoid_for_glm_style(self):
"""Block with e_score_correction_bias on gate should use sigmoid."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=True, n_group=1)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, experts, top_k, num_experts = _route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
assert experts.shape == (T, K)
assert (weights >= 0).all()
def test_sigmoid_for_minimax_m2_style(self):
"""Block with e_score_correction_bias on block (not gate) should use sigmoid."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=False)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, experts, top_k, num_experts = _route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
assert (weights >= 0).all()
# ============================================================================
# 8. Generic shared expert handling
# ============================================================================
class TestGenericSharedExpert:
"""Test _compute_shared_expert from layers.py."""
@pytest.fixture(autouse=True)
def _require_triton(self):
_skip_without_triton()
def test_shared_expert_singular(self):
"""shared_expert attribute (Qwen2MoE style)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert,
)
called = torch.randn(4, 8)
moe_block = SimpleNamespace(
shared_expert=lambda x: called,
)
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
assert torch.equal(result, called)
def test_shared_experts_plural(self):
"""shared_experts attribute (DeepSeek V3 style)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert,
)
called = torch.randn(4, 8)
moe_block = SimpleNamespace(
shared_experts=lambda x: called,
)
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
assert torch.equal(result, called)
def test_shared_mlp(self):
"""shared_mlp attribute (Hunyuan style)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert,
)
called = torch.randn(4, 8)
moe_block = SimpleNamespace(
shared_mlp=lambda x: called,
)
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
assert torch.equal(result, called)
def test_shared_expert_with_gate(self):
"""shared_expert + shared_expert_gate applies sigmoid gating."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert,
)
H = 8
expert_out = torch.ones(4, H)
gate_fn = lambda x: torch.zeros(4, H) # noqa: E731
moe_block = SimpleNamespace(
shared_expert=lambda x: expert_out,
shared_expert_gate=gate_fn,
)
result = _compute_shared_expert(moe_block, torch.randn(4, H))
expected = expert_out * 0.5 # sigmoid(0) = 0.5
assert torch.allclose(result, expected, atol=1e-6)
def test_no_shared_expert(self):
"""No shared expert attributes returns None."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert,
)
moe_block = SimpleNamespace()
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
assert result is None