consolidate behaviour of routing in scattermoe kernels (#3475)

* consolidate behaviour of routing in scattermoe kernels

* collect telemetry on best chosen autotuned kernel

* properly collect data

* Fix property name and get smem too

* handle issues raised by coderabbit

* add tests for parity before refactoring
This commit is contained in:
Wing Lian
2026-03-16 23:47:40 -04:00
committed by GitHub
parent 830e9f7eaf
commit 8f3fb517b3
8 changed files with 1988 additions and 35 deletions

View File

@@ -12,6 +12,7 @@ Tests verify correctness of:
3. Frozen weights: expert weight gradients are correctly skipped
4. Various configurations: top-k, grouped_in/out, with/without bias
5. Numerical stability: bf16/fp16 outputs within tolerance of fp32 reference
6. HFScatterMoEGatedMLP with sigmoid routing (GLM/DeepSeek/MiniMax M2)
Test strategy:
- Reference implementation uses pure PyTorch ops (no Triton)
@@ -19,6 +20,8 @@ Test strategy:
- Tolerances account for tf32 accumulation in Triton kernels
"""
from types import SimpleNamespace
import pytest
import torch
@@ -1476,3 +1479,347 @@ class TestCombinedOptimizations:
torch.testing.assert_close(X1.grad, X2.grad, atol=5e-2, rtol=5e-2)
torch.testing.assert_close(A1.grad, A2.grad, atol=5e-2, rtol=5e-2)
torch.testing.assert_close(B1.grad, B2.grad, atol=5e-2, rtol=5e-2)
# =============================================================================
# Test: HFScatterMoEGatedMLP with Sigmoid Routing
# =============================================================================
def _reference_moe_forward(
    hidden_states,
    gate_weight,
    gate_up_proj,
    down_proj,
    act_fn,
    routing_weights,
    selected_experts,
    num_experts,
):
    """Pure PyTorch reference for a full MoE forward pass.

    Each (token, slot) pair is evaluated on its own: the token goes through
    the selected expert's fused gate/up projection, the gated activation and
    the down projection, and the result is accumulated into the output row
    of that token, scaled by the matching routing weight.

    Args:
        hidden_states: [T, H]
        gate_weight: [E, H] (not read here; kept so call sites stay unchanged)
        gate_up_proj: [E, 2*FF, H]
        down_proj: [E, H, FF]
        act_fn: activation function (e.g. torch.nn.SiLU())
        routing_weights: [T, K] routing weights
        selected_experts: [T, K] expert indices
        num_experts: int (not read here; kept so call sites stay unchanged)

    Returns:
        output: [T, H], same device and dtype as ``hidden_states``
    """
    T, H = hidden_states.shape
    K = selected_experts.shape[1]
    output = torch.zeros(T, H, device=hidden_states.device, dtype=hidden_states.dtype)

    # Pull the routing decisions to the host once. Calling ``.item()`` per
    # (token, slot) would trigger two device->host copies on every iteration;
    # ``tolist()`` yields the same Python ints/floats with a single copy each.
    expert_ids = selected_experts.tolist()
    expert_weights = routing_weights.tolist()

    for t in range(T):
        token = hidden_states[t]
        for j in range(K):
            e = expert_ids[t][j]
            w = expert_weights[t][j]

            # Fused gate/up projection: first half gates, second half is "up".
            gup = token @ gate_up_proj[e].T  # [2*FF]
            I_dim = gup.shape[0] // 2
            gates = gup[:I_dim]
            up = gup[I_dim:]

            # Gated activation followed by the down projection back to [H].
            h = act_fn(gates) * up
            out = h @ down_proj[e].T

            output[t] += w * out
    return output
def _make_mock_sigmoid_moe_block(
    T=16,
    H=64,
    FF=32,
    E=8,
    K=2,
    n_group=2,
    topk_group=1,
    bias_on_gate=True,
    device="cuda",
):
    """Create a mock MoE block with sigmoid routing for GPU testing.

    The block is a plain ``SimpleNamespace`` tree that carries only the
    attributes set below; it is not an ``nn.Module``.

    Two layouts are produced:

    * ``bias_on_gate=True``: ``e_score_correction_bias`` lives on the gate,
      and the block carries ``top_k``, ``n_routed_experts``, ``n_group``,
      ``topk_group``, ``norm_topk_prob`` and ``routed_scaling_factor``.
    * ``bias_on_gate=False`` (minimax_m2 style): ``top_k`` lives on both the
      gate and the block, and ``e_score_correction_bias`` lives on the block.
      ``n_group`` and ``topk_group`` are not attached in this layout.

    Args:
        T: number of tokens; only echoed back in the return tuple.
        H: hidden size.
        FF: expert intermediate size (``gate_up_proj`` has ``2 * FF`` rows).
        E: number of experts.
        K: top-k value stored on the block (and on the gate when
            ``bias_on_gate`` is False).
        n_group: stored on the block when ``bias_on_gate`` is True.
        topk_group: stored on the block when ``bias_on_gate`` is True.
        bias_on_gate: selects which of the two layouts above is built.
        device: device on which all tensors are created. Defaults to
            ``"cuda"``.

    Returns:
        Tuple ``(moe_block, T, H, FF, E, K)``.
    """
    # Keyword arguments are evaluated left to right, so the random draws
    # happen in the order gate_up_proj, down_proj, gate weight.
    experts = SimpleNamespace(
        gate_up_proj=torch.randn(E, 2 * FF, H, device=device) * 0.02,
        down_proj=torch.randn(E, H, FF, device=device) * 0.02,
        act_fn=torch.nn.SiLU(),
        num_experts=E,
    )
    gate_weight = torch.randn(E, H, device=device) * 0.1
    score_correction_bias = torch.zeros(E, device=device)

    if bias_on_gate:
        gate = SimpleNamespace(
            weight=gate_weight,
            e_score_correction_bias=score_correction_bias,
        )
        moe_block = SimpleNamespace(
            gate=gate,
            experts=experts,
            top_k=K,
            n_routed_experts=E,
            n_group=n_group,
            topk_group=topk_group,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
    else:
        # minimax_m2 style
        gate = SimpleNamespace(
            weight=gate_weight,
            top_k=K,
        )
        moe_block = SimpleNamespace(
            gate=gate,
            experts=experts,
            top_k=K,
            e_score_correction_bias=score_correction_bias,
        )
    return moe_block, T, H, FF, E, K
class TestHFScatterMoESigmoidRouting:
    """Test HFScatterMoEGatedMLP forward with sigmoid routing on GPU."""

    @staticmethod
    def _assert_forward_matches_reference(moe_block, hidden, route_fn, num_experts):
        """Compare ``HFScatterMoEGatedMLP.forward`` with the PyTorch reference.

        ``route_fn`` is called as ``route_fn(moe_block, gate, tokens,
        gate.weight, None)``; the first two items of its result are used as
        the routing weights and selected experts fed to the reference.
        """
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
        )

        hidden_dim = hidden.shape[-1]
        gate = moe_block.gate
        experts = moe_block.experts
        tokens = hidden.view(-1, hidden_dim)

        # Routing is computed once and shared with the reference so that only
        # the expert computation is being compared.
        routing_weights, selected_experts, _, _ = route_fn(
            moe_block, gate, tokens, gate.weight, None
        )
        expected = _reference_moe_forward(
            tokens,
            gate.weight,
            experts.gate_up_proj,
            experts.down_proj,
            experts.act_fn,
            routing_weights,
            selected_experts,
            num_experts,
        )

        actual = HFScatterMoEGatedMLP.forward(moe_block, hidden).view(-1, hidden_dim)

        # NOTE(review): the mock expert weights are scaled by 0.02, so outputs
        # may be small relative to atol=5e-2 -- confirm this tolerance is
        # still tight enough to catch a wrong result.
        torch.testing.assert_close(
            actual.float(),
            expected.float(),
            atol=5e-2,
            rtol=5e-2,
        )

    def test_forward_matches_reference_bias_on_gate(self):
        """Forward pass with sigmoid routing (bias on gate) matches reference."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, _FF, E, _K = _make_mock_sigmoid_moe_block(
            T=16, H=64, FF=32, E=8, K=2, n_group=2, topk_group=1, bias_on_gate=True
        )
        hidden = torch.randn(1, T, H, device="cuda")

        self._assert_forward_matches_reference(
            moe_block, hidden, _sigmoid_topk_route, E
        )

    def test_forward_matches_reference_bias_on_block(self):
        """Forward pass with sigmoid routing (minimax_m2 style, bias on block)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, _FF, E, _K = _make_mock_sigmoid_moe_block(
            T=16, H=64, FF=32, E=8, K=2, n_group=1, bias_on_gate=False
        )
        hidden = torch.randn(1, T, H, device="cuda")

        self._assert_forward_matches_reference(
            moe_block, hidden, _sigmoid_topk_route, E
        )

    def test_softmax_routing_still_works(self):
        """Verify softmax routing (Qwen/OLMoE) is not broken."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _softmax_topk_route,
        )

        T, H, FF, E, K = 16, 64, 32, 4, 2

        # Tensors are drawn in the order gate_up_proj, down_proj, gate weight,
        # hidden states.
        experts = SimpleNamespace(
            gate_up_proj=torch.randn(E, 2 * FF, H, device="cuda") * 0.02,
            down_proj=torch.randn(E, H, FF, device="cuda") * 0.02,
            act_fn=torch.nn.SiLU(),
            num_experts=E,
        )
        gate = SimpleNamespace(
            weight=torch.randn(E, H, device="cuda") * 0.1,
            top_k=K,
            num_experts=E,
            norm_topk_prob=True,
        )
        moe_block = SimpleNamespace(gate=gate, experts=experts)
        hidden = torch.randn(1, T, H, device="cuda")

        self._assert_forward_matches_reference(
            moe_block, hidden, _softmax_topk_route, E
        )
class TestHFScatterMoESigmoidWithSharedExperts:
    """Test HFScatterMoEGatedMLP with sigmoid routing + shared experts."""

    def test_shared_experts_plural(self):
        """DeepSeek V3 style: shared_experts attribute (plural)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
        )

        T, H, FF, E, K = 8, 64, 32, 8, 2

        experts = SimpleNamespace(
            gate_up_proj=torch.randn(E, 2 * FF, H, device="cuda") * 0.02,
            down_proj=torch.randn(E, H, FF, device="cuda") * 0.02,
            act_fn=torch.nn.SiLU(),
            num_experts=E,
        )

        # Shared expert as a simple linear for testing
        shared_W = torch.randn(H, H, device="cuda") * 0.01

        def shared_experts_fn(x):
            return x @ shared_W.T

        gate = SimpleNamespace(
            weight=torch.randn(E, H, device="cuda") * 0.1,
            e_score_correction_bias=torch.zeros(E, device="cuda"),
        )

        # Attributes common to the block with and without the shared expert.
        routing_attrs = {
            "gate": gate,
            "experts": experts,
            "top_k": K,
            "n_routed_experts": E,
            "n_group": 1,
            "norm_topk_prob": True,
            "routed_scaling_factor": 1.0,
        }
        moe_block = SimpleNamespace(shared_experts=shared_experts_fn, **routing_attrs)

        hidden = torch.randn(1, T, H, device="cuda")

        # Should not raise; output should include shared expert contribution
        output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
        assert output.shape == (1, T, H)

        # Same block minus ``shared_experts``: the output is expected to differ.
        moe_block_no_shared = SimpleNamespace(**routing_attrs)
        output_no_shared = HFScatterMoEGatedMLP.forward(moe_block_no_shared, hidden)
        assert not torch.equal(output, output_no_shared)

    def test_shared_expert_with_gate(self):
        """Qwen2MoE style: shared_expert + shared_expert_gate."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
        )

        T, H, FF, E, K = 8, 64, 32, 4, 2

        experts = SimpleNamespace(
            gate_up_proj=torch.randn(E, 2 * FF, H, device="cuda") * 0.02,
            down_proj=torch.randn(E, H, FF, device="cuda") * 0.02,
            act_fn=torch.nn.SiLU(),
            num_experts=E,
        )

        shared_W = torch.randn(H, H, device="cuda") * 0.01

        def shared_expert_fn(x):
            return x @ shared_W.T

        # Gate that returns 0 -> sigmoid(0) = 0.5
        gate_W = torch.zeros(H, H, device="cuda")

        def shared_expert_gate_fn(x):
            return x @ gate_W.T

        gate = SimpleNamespace(
            weight=torch.randn(E, H, device="cuda") * 0.1,
            top_k=K,
            num_experts=E,
            norm_topk_prob=True,
        )
        moe_block = SimpleNamespace(
            gate=gate,
            experts=experts,
            shared_expert=shared_expert_fn,
            shared_expert_gate=shared_expert_gate_fn,
        )

        hidden = torch.randn(1, T, H, device="cuda")

        # Only the output shape is checked here.
        output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
        assert output.shape == (1, T, H)