consolidate behavior of routing in scattermoe kernels (#3475)
* consolidate behavior of routing in scattermoe kernels * collect telemetry on best chosen autotuned kernel * properly collect data * Fix property name and get smem too * handle issues raised by coderabbit * add tests for parity before refactoring
This commit is contained in:
@@ -12,6 +12,7 @@ Tests verify correctness of:
|
||||
3. Frozen weights: expert weight gradients are correctly skipped
|
||||
4. Various configurations: top-k, grouped_in/out, with/without bias
|
||||
5. Numerical stability: bf16/fp16 outputs within tolerance of fp32 reference
|
||||
6. HFScatterMoEGatedMLP with sigmoid routing (GLM/DeepSeek/MiniMax M2)
|
||||
|
||||
Test strategy:
|
||||
- Reference implementation uses pure PyTorch ops (no Triton)
|
||||
@@ -19,6 +20,8 @@ Test strategy:
|
||||
- Tolerances account for tf32 accumulation in Triton kernels
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -1476,3 +1479,347 @@ class TestCombinedOptimizations:
|
||||
torch.testing.assert_close(X1.grad, X2.grad, atol=5e-2, rtol=5e-2)
|
||||
torch.testing.assert_close(A1.grad, A2.grad, atol=5e-2, rtol=5e-2)
|
||||
torch.testing.assert_close(B1.grad, B2.grad, atol=5e-2, rtol=5e-2)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: HFScatterMoEGatedMLP with Sigmoid Routing
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _reference_moe_forward(
    hidden_states,
    gate_weight,
    gate_up_proj,
    down_proj,
    act_fn,
    routing_weights,
    selected_experts,
    num_experts,
):
    """Token-by-token MoE forward in plain PyTorch, used as ground truth.

    Each (token, slot) pair is handled on its own: the token is projected
    through the selected expert's fused gate/up matrix, the first half of
    that projection is activated and multiplied by the second half, the
    product goes through the expert's down projection, and the result is
    scaled by the routing weight and accumulated into the token's output row.

    Args:
        hidden_states: [T, H] token activations.
        gate_weight: [E, H] router weight. Not read here; routing is supplied
            through ``routing_weights`` and ``selected_experts``.
        gate_up_proj: [E, 2*FF, H] fused gate/up projection per expert.
        down_proj: [E, H, FF] down projection per expert.
        act_fn: activation applied to the gate half (e.g. torch.nn.SiLU()).
        routing_weights: [T, K] weight of each selected expert.
        selected_experts: [T, K] expert index for each slot.
        num_experts: int. Not read here; kept for signature parity.

    Returns:
        [T, H] tensor on the device and in the dtype of ``hidden_states``.
    """
    num_tokens, hidden_dim = hidden_states.shape
    slots_per_token = selected_experts.shape[1]
    combined = torch.zeros(
        num_tokens,
        hidden_dim,
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    for tok in range(num_tokens):
        for slot in range(slots_per_token):
            # .item() turns the 0-dim tensors into Python scalars so they can
            # be used as an index and as a plain multiplier.
            expert_id = selected_experts[tok, slot].item()
            weight = routing_weights[tok, slot].item()

            # Fused gate/up projection: first half gates, second half is "up".
            fused = hidden_states[tok] @ gate_up_proj[expert_id].T
            half = fused.shape[0] // 2
            gate_half, up_half = fused[:half], fused[half:]

            # Gated activation followed by the down projection back to H.
            expert_out = (act_fn(gate_half) * up_half) @ down_proj[expert_id].T

            combined[tok] += weight * expert_out

    return combined
|
||||
|
||||
|
||||
def _make_mock_sigmoid_moe_block(
    T=16,
    H=64,
    FF=32,
    E=8,
    K=2,
    n_group=2,
    topk_group=1,
    bias_on_gate=True,
    device="cuda",
):
    """Create a mock MoE block with sigmoid routing for testing.

    The block is a tree of ``SimpleNamespace`` objects holding randomly
    initialised expert and router weights.

    Args:
        T: number of tokens; only echoed back in the return tuple.
        H: hidden size.
        FF: expert intermediate size (``gate_up_proj`` has ``2 * FF`` rows).
        E: number of experts.
        K: top-k value stored on the block (and on the gate when
            ``bias_on_gate`` is False).
        n_group: stored on the block only when ``bias_on_gate`` is True.
        topk_group: stored on the block only when ``bias_on_gate`` is True.
        bias_on_gate: if True, ``e_score_correction_bias`` lives on the gate
            and the block carries ``n_routed_experts``, ``n_group``,
            ``topk_group``, ``norm_topk_prob`` and ``routed_scaling_factor``.
            If False (minimax_m2 style), the bias lives on the block and the
            gate carries ``top_k`` instead.
        device: device for every tensor created here. Defaults to ``"cuda"``
            (the previous hard-coded value) so existing callers are
            unaffected; pass ``"cpu"`` to build the mock without a GPU.

    Returns:
        Tuple ``(moe_block, T, H, FF, E, K)``.
    """
    # Random draws happen in a fixed order (gate_up, down, router weight) so
    # a seeded RNG yields the same tensors regardless of ``bias_on_gate``.
    gate_up_proj = torch.randn(E, 2 * FF, H, device=device) * 0.02
    down_proj = torch.randn(E, H, FF, device=device) * 0.02
    experts = SimpleNamespace(
        gate_up_proj=gate_up_proj,
        down_proj=down_proj,
        act_fn=torch.nn.SiLU(),
        num_experts=E,
    )

    router_weight = torch.randn(E, H, device=device) * 0.1
    correction_bias = torch.zeros(E, device=device)

    if bias_on_gate:
        gate = SimpleNamespace(
            weight=router_weight,
            e_score_correction_bias=correction_bias,
        )
        moe_block = SimpleNamespace(
            gate=gate,
            experts=experts,
            top_k=K,
            n_routed_experts=E,
            n_group=n_group,
            topk_group=topk_group,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
    else:
        # minimax_m2 style
        gate = SimpleNamespace(
            weight=router_weight,
            top_k=K,
        )
        moe_block = SimpleNamespace(
            gate=gate,
            experts=experts,
            top_k=K,
            e_score_correction_bias=correction_bias,
        )

    return moe_block, T, H, FF, E, K
|
||||
|
||||
|
||||
class TestHFScatterMoESigmoidRouting:
    """Parity checks of HFScatterMoEGatedMLP.forward against the PyTorch reference."""

    @staticmethod
    def _assert_kernel_matches_reference(moe_cls, route_fn, moe_block, hidden, H, E):
        """Compare ``moe_cls.forward`` with ``_reference_moe_forward``.

        ``route_fn`` is called once to obtain the routing weights and expert
        indices that the reference implementation consumes; the kernel path
        is then run on the same block and input, and both outputs are
        compared as float32 within atol/rtol of 5e-2.
        """
        gate = moe_block.gate
        experts = moe_block.experts
        tokens = hidden.view(-1, H)

        # Routing decisions fed to the reference implementation.
        routing_weights, selected_experts, _, _ = route_fn(
            moe_block, gate, tokens, gate.weight, None
        )

        expected = _reference_moe_forward(
            tokens,
            gate.weight,
            experts.gate_up_proj,
            experts.down_proj,
            experts.act_fn,
            routing_weights,
            selected_experts,
            E,
        )

        # Kernel path under test, flattened to [T, H] for comparison.
        actual = moe_cls.forward(moe_block, hidden).view(-1, H)

        torch.testing.assert_close(
            actual.float(),
            expected.float(),
            atol=5e-2,
            rtol=5e-2,
        )

    def test_forward_matches_reference_bias_on_gate(self):
        """Sigmoid routing with the correction bias stored on the gate."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
            _sigmoid_topk_route,
        )

        moe_block, T, H, _FF, E, _K = _make_mock_sigmoid_moe_block(
            T=16, H=64, FF=32, E=8, K=2, n_group=2, topk_group=1, bias_on_gate=True
        )
        hidden = torch.randn(1, T, H, device="cuda")

        self._assert_kernel_matches_reference(
            HFScatterMoEGatedMLP, _sigmoid_topk_route, moe_block, hidden, H, E
        )

    def test_forward_matches_reference_bias_on_block(self):
        """Sigmoid routing, minimax_m2 style: correction bias stored on the block."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
            _sigmoid_topk_route,
        )

        moe_block, T, H, _FF, E, _K = _make_mock_sigmoid_moe_block(
            T=16, H=64, FF=32, E=8, K=2, n_group=1, bias_on_gate=False
        )
        hidden = torch.randn(1, T, H, device="cuda")

        self._assert_kernel_matches_reference(
            HFScatterMoEGatedMLP, _sigmoid_topk_route, moe_block, hidden, H, E
        )

    def test_softmax_routing_still_works(self):
        """Softmax routing (Qwen/OLMoE) must keep matching the reference."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
            _softmax_topk_route,
        )

        T, H, FF, E, K = 16, 64, 32, 4, 2

        experts = SimpleNamespace(
            gate_up_proj=torch.randn(E, 2 * FF, H, device="cuda") * 0.02,
            down_proj=torch.randn(E, H, FF, device="cuda") * 0.02,
            act_fn=torch.nn.SiLU(),
            num_experts=E,
        )
        gate = SimpleNamespace(
            weight=torch.randn(E, H, device="cuda") * 0.1,
            top_k=K,
            num_experts=E,
            norm_topk_prob=True,
        )
        moe_block = SimpleNamespace(gate=gate, experts=experts)
        hidden = torch.randn(1, T, H, device="cuda")

        self._assert_kernel_matches_reference(
            HFScatterMoEGatedMLP, _softmax_topk_route, moe_block, hidden, H, E
        )
|
||||
|
||||
|
||||
class TestHFScatterMoESigmoidWithSharedExperts:
    """Exercise HFScatterMoEGatedMLP.forward on blocks that carry shared experts."""

    def test_shared_experts_plural(self):
        """DeepSeek V3 style: the block exposes ``shared_experts`` (plural)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
        )

        T, H, FF, E, K = 8, 64, 32, 8, 2

        experts = SimpleNamespace(
            gate_up_proj=torch.randn(E, 2 * FF, H, device="cuda") * 0.02,
            down_proj=torch.randn(E, H, FF, device="cuda") * 0.02,
            act_fn=torch.nn.SiLU(),
            num_experts=E,
        )

        # A single linear map stands in for the shared expert.
        shared_W = torch.randn(H, H, device="cuda") * 0.01

        def shared_experts_fn(x):
            return x @ shared_W.T

        gate = SimpleNamespace(
            weight=torch.randn(E, H, device="cuda") * 0.1,
            e_score_correction_bias=torch.zeros(E, device="cuda"),
        )

        # Attributes common to the block with and without the shared expert.
        routed_attrs = {
            "gate": gate,
            "experts": experts,
            "top_k": K,
            "n_routed_experts": E,
            "n_group": 1,
            "norm_topk_prob": True,
            "routed_scaling_factor": 1.0,
        }
        block_with_shared = SimpleNamespace(
            shared_experts=shared_experts_fn, **routed_attrs
        )
        block_without_shared = SimpleNamespace(**routed_attrs)

        hidden = torch.randn(1, T, H, device="cuda")

        # Must run without raising and keep the input's [1, T, H] shape.
        output = HFScatterMoEGatedMLP.forward(block_with_shared, hidden)
        assert output.shape == (1, T, H)

        # Dropping the shared expert has to change the result.
        output_no_shared = HFScatterMoEGatedMLP.forward(block_without_shared, hidden)
        assert not torch.equal(output, output_no_shared)

    def test_shared_expert_with_gate(self):
        """Qwen2MoE style: ``shared_expert`` together with ``shared_expert_gate``."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
        )

        T, H, FF, E, K = 8, 64, 32, 4, 2

        experts = SimpleNamespace(
            gate_up_proj=torch.randn(E, 2 * FF, H, device="cuda") * 0.02,
            down_proj=torch.randn(E, H, FF, device="cuda") * 0.02,
            act_fn=torch.nn.SiLU(),
            num_experts=E,
        )

        shared_W = torch.randn(H, H, device="cuda") * 0.01

        def shared_expert_fn(x):
            return x @ shared_W.T

        # All-zero gate weights: the gate returns 0 -> sigmoid(0) = 0.5
        gate_W = torch.zeros(H, H, device="cuda")

        def shared_expert_gate_fn(x):
            return x @ gate_W.T

        gate = SimpleNamespace(
            weight=torch.randn(E, H, device="cuda") * 0.1,
            top_k=K,
            num_experts=E,
            norm_topk_prob=True,
        )
        moe_block = SimpleNamespace(
            gate=gate,
            experts=experts,
            shared_expert=shared_expert_fn,
            shared_expert_gate=shared_expert_gate_fn,
        )

        hidden = torch.randn(1, T, H, device="cuda")
        output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
        assert output.shape == (1, T, H)
|
||||
|
||||
Reference in New Issue
Block a user