consolidate behavioud of routing in scattermoe kernels (#3475)

* consolidate behavioud of routing in scattermoe kernels

* collect telemetry on best chosen autotuned kernel

* properly collect data

* Fix property name and get smem too

* handle issues raised by coderabbit

* add tests for parity before refactoring
This commit is contained in:
Wing Lian
2026-03-16 23:47:40 -04:00
committed by GitHub
parent 830e9f7eaf
commit 8f3fb517b3
8 changed files with 1988 additions and 35 deletions

View File

@@ -3,17 +3,19 @@
# Licensed under the Apache License, Version 2.0
"""
Unit tests for scattermoe-lora code-review fixes.
Unit tests for scattermoe-lora.
Tests cover:
- KernelsArgs validator: disable_mlp_kernel
- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
- ParallelExperts: scaling=0.0 not treated as falsy
- single2scatter: non-aligned K/N dimensions
- group_compileable: coeff=None accepted
- HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract
- Routing strategy detection and sigmoid routing
- Generic shared expert handling
"""
from types import SimpleNamespace
from unittest.mock import patch
import pytest
@@ -321,3 +323,389 @@ class TestLayerReturnValues:
assert "Router logits" not in docstring, (
"Docstring should not mention 'Router logits' in Returns section"
)
# ============================================================================
# 7. Routing strategy detection and sigmoid routing
# ============================================================================
def _make_softmax_gate(E=4, H=16, K=2):
"""Create a mock softmax-style gate (Qwen/OLMoE)."""
return SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
num_experts=E,
norm_topk_prob=True,
)
def _make_sigmoid_gate_with_bias(E=16, H=16):
"""Create a mock sigmoid-style gate with e_score_correction_bias on gate."""
return SimpleNamespace(
weight=torch.randn(E, H),
e_score_correction_bias=torch.zeros(E),
)
def _make_sigmoid_moe_block(
T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
):
"""Create a mock GLM/DeepSeek-style MoE block for sigmoid routing tests."""
if bias_on_gate:
gate = SimpleNamespace(
weight=torch.randn(E, H),
e_score_correction_bias=torch.zeros(E),
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
n_routed_experts=E,
n_group=n_group,
topk_group=topk_group,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
else:
# minimax_m2 style: bias on block, not gate
gate = SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
e_score_correction_bias=torch.zeros(E),
)
return moe_block, T, H, E, K
def _skip_without_triton():
pytest.importorskip("triton")
class TestSigmoidRoutingInScatterMoE:
"""Test _sigmoid_topk_route from layers.py."""
@pytest.fixture(autouse=True)
def _require_triton(self):
_skip_without_triton()
def test_output_shapes(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block()
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, experts, top_k, num_experts = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
assert experts.shape == (T, K)
assert top_k == K
assert num_experts == E
def test_weights_nonnegative(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block()
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, _, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert (weights >= 0).all()
def test_group_selection_restricts_experts(self):
"""With n_group=4, topk_group=1, experts should be from selected groups."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block(
E=16, K=2, n_group=4, topk_group=1
)
gate = moe_block.gate
hidden = torch.randn(T, H)
_, expert_idx, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
# Each token's experts should fall within a single group (size E//n_group=4)
for t in range(T):
experts_t = expert_idx[t]
groups = experts_t // (E // moe_block.n_group)
assert (groups == groups[0]).all()
def test_scaling_factor_applied(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block(n_group=1)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights_1x, _, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
moe_block.routed_scaling_factor = 2.0
weights_2x, _, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert torch.allclose(weights_2x, weights_1x * 2.0, atol=1e-5)
def test_bias_on_gate(self):
"""e_score_correction_bias on gate is found."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=True)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
def test_bias_on_block(self):
"""e_score_correction_bias on moe_block (not gate) is found."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=False)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
def test_gate_lora_delta_applied(self):
"""Gate LoRA delta should affect routing logits."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block(n_group=1)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights_no_lora, _, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
# Large delta should change the results
delta = torch.randn(E, H) * 10.0
weights_with_lora, _, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, delta
)
assert not torch.equal(weights_no_lora, weights_with_lora)
def test_no_bias_does_not_crash(self):
"""Calling _sigmoid_topk_route with no e_score_correction_bias should not crash."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
T, H, E, K = 8, 16, 8, 2
gate = SimpleNamespace(weight=torch.randn(E, H))
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
n_routed_experts=E,
n_group=1,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
hidden = torch.randn(T, H)
weights, experts, top_k, num_experts = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
assert experts.shape == (T, K)
# Without bias, scores_for_choice == sigmoid(logits) — all positive
assert (weights >= 0).all()
def test_missing_topk_group_defaults_to_n_group(self):
"""When topk_group is absent but n_group > 1, should default to n_group (no-op masking)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
T, H, E, K, n_group = 8, 16, 16, 2, 4
gate = SimpleNamespace(
weight=torch.randn(E, H),
e_score_correction_bias=torch.zeros(E),
)
# Intentionally omit topk_group
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
n_routed_experts=E,
n_group=n_group,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
hidden = torch.randn(T, H)
# Should not raise AttributeError; defaults topk_group to n_group
weights, experts, top_k_out, num_experts = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
assert experts.shape == (T, K)
class TestRoutingStrategyDetection:
"""Test that _route dispatches to the correct strategy."""
@pytest.fixture(autouse=True)
def _require_triton(self):
_skip_without_triton()
def test_softmax_for_qwen_style(self):
"""Block without e_score_correction_bias should use softmax."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route
gate = _make_softmax_gate(E=4, H=16, K=2)
moe_block = SimpleNamespace(gate=gate)
hidden = torch.randn(8, 16)
weights, experts, top_k, num_experts = _route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (8, 2)
assert experts.shape == (8, 2)
assert top_k == 2
assert num_experts == 4
per_token_sums = weights.sum(dim=-1)
assert torch.allclose(per_token_sums, torch.ones(8), atol=1e-5)
def test_sigmoid_for_glm_style(self):
"""Block with e_score_correction_bias on gate should use sigmoid."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=True, n_group=1)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, experts, top_k, num_experts = _route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
assert experts.shape == (T, K)
assert (weights >= 0).all()
def test_sigmoid_for_minimax_m2_style(self):
"""Block with e_score_correction_bias on block (not gate) should use sigmoid."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=False)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, experts, top_k, num_experts = _route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
assert (weights >= 0).all()
# ============================================================================
# 8. Generic shared expert handling
# ============================================================================
class TestGenericSharedExpert:
"""Test _compute_shared_expert from layers.py."""
@pytest.fixture(autouse=True)
def _require_triton(self):
_skip_without_triton()
def test_shared_expert_singular(self):
"""shared_expert attribute (Qwen2MoE style)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert,
)
called = torch.randn(4, 8)
moe_block = SimpleNamespace(
shared_expert=lambda x: called,
)
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
assert torch.equal(result, called)
def test_shared_experts_plural(self):
"""shared_experts attribute (DeepSeek V3 style)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert,
)
called = torch.randn(4, 8)
moe_block = SimpleNamespace(
shared_experts=lambda x: called,
)
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
assert torch.equal(result, called)
def test_shared_mlp(self):
"""shared_mlp attribute (Hunyuan style)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert,
)
called = torch.randn(4, 8)
moe_block = SimpleNamespace(
shared_mlp=lambda x: called,
)
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
assert torch.equal(result, called)
def test_shared_expert_with_gate(self):
"""shared_expert + shared_expert_gate applies sigmoid gating."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert,
)
H = 8
expert_out = torch.ones(4, H)
gate_fn = lambda x: torch.zeros(4, H) # noqa: E731
moe_block = SimpleNamespace(
shared_expert=lambda x: expert_out,
shared_expert_gate=gate_fn,
)
result = _compute_shared_expert(moe_block, torch.randn(4, H))
expected = expert_out * 0.5 # sigmoid(0) = 0.5
assert torch.allclose(result, expected, atol=1e-6)
def test_no_shared_expert(self):
"""No shared expert attributes returns None."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert,
)
moe_block = SimpleNamespace()
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
assert result is None