consolidate behaviour of routing in scattermoe kernels (#3475)
* consolidate behaviour of routing in scattermoe kernels * collect telemetry on best chosen autotuned kernel * properly collect data * Fix property name and get smem too * handle issues raised by coderabbit * add tests for parity before refactoring
This commit is contained in:
474
tests/integrations/test_routing_parity.py
Normal file
474
tests/integrations/test_routing_parity.py
Normal file
@@ -0,0 +1,474 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""
|
||||
Parity tests between scattermoe-lora and sonicmoe routing implementations.
|
||||
|
||||
These tests verify that both implementations produce numerically identical
|
||||
results for the same inputs, ensuring safe centralization of the routing code.
|
||||
|
||||
ScatterMoE returns 2D tensors [T, K]; SonicMoE returns flattened 1D [T*K].
|
||||
The core algorithm should be identical — only the output format differs.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def _require_triton():
    """Skip the calling test when the ``triton`` package cannot be imported."""
    required_module = "triton"
    pytest.importorskip(required_module)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures / helpers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _make_softmax_block(T=8, H=16, E=4, K=2):
|
||||
"""Qwen/OLMoE-style block usable by both implementations."""
|
||||
gate = SimpleNamespace(
|
||||
weight=torch.randn(E, H),
|
||||
top_k=K,
|
||||
num_experts=E,
|
||||
norm_topk_prob=True,
|
||||
)
|
||||
moe_block = SimpleNamespace(gate=gate)
|
||||
hidden = torch.randn(T, H)
|
||||
return moe_block, gate, hidden, T, H, E, K
|
||||
|
||||
|
||||
def _make_sigmoid_block(
    T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
):
    """Build a GLM/DeepSeek-style MoE stub plus random inputs for sigmoid routing.

    Two layouts are produced depending on ``bias_on_gate``:

    * ``True``: ``e_score_correction_bias`` (all zeros) lives on the gate, and
      the block carries ``top_k``, ``n_routed_experts``, ``n_group``,
      ``topk_group``, ``norm_topk_prob=True`` and ``routed_scaling_factor=1.0``.
    * ``False`` (minimax_m2 style): the zero bias lives on the block, ``top_k``
      is set on both gate and block, and ``n_group`` / ``topk_group`` are not
      stored anywhere.

    Returns:
        ``(moe_block, gate, hidden, T, H, E, K)`` with ``hidden`` produced by
        ``hidden_states(T, H)`` after the gate weight has been drawn.
    """
    router_weight = torch.randn(E, H)
    zero_bias = torch.zeros(E)

    if not bias_on_gate:
        # minimax_m2 style: bias on block
        gate = SimpleNamespace(weight=router_weight, top_k=K)
        moe_block = SimpleNamespace(
            gate=gate,
            top_k=K,
            e_score_correction_bias=zero_bias,
        )
    else:
        gate = SimpleNamespace(
            weight=router_weight,
            e_score_correction_bias=zero_bias,
        )
        block_attrs = {
            "gate": gate,
            "top_k": K,
            "n_routed_experts": E,
            "n_group": n_group,
            "topk_group": topk_group,
            "norm_topk_prob": True,
            "routed_scaling_factor": 1.0,
        }
        moe_block = SimpleNamespace(**block_attrs)

    return moe_block, gate, hidden_states(T, H), T, H, E, K
|
||||
|
||||
|
||||
def hidden_states(T, H):
    """Return a random ``[T, H]`` tensor drawn with ``torch.randn``."""
    shape = (T, H)
    return torch.randn(*shape)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 1. Softmax routing parity
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestSoftmaxRoutingParity:
    """Check that scattermoe and sonicmoe softmax routing agree.

    SonicMoE's flat ``[T*K]`` outputs are reshaped to ``[T, K]`` before being
    compared with ScatterMoE's: expert indices must be equal and weights must
    agree within ``atol=1e-6``.
    """

    @pytest.fixture(autouse=True)
    def _require(self):
        """Apply the triton requirement to every test in this class."""
        _require_triton()

    @staticmethod
    def _unflatten(scores, expert_idx, T, K):
        """Reshape SonicMoE's flat scores and expert indices to ``[T, K]``."""
        return scores.reshape(T, K), expert_idx.reshape(T, K)

    def test_weights_match(self):
        """ScatterMoE's 2D weights equal SonicMoE's 1D weights once reshaped."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _softmax_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        moe_block, gate, hidden, T, _H, E, K = _make_softmax_block()

        # ScatterMoE path (no LoRA delta)
        scatter_out = _softmax_topk_route(moe_block, gate, hidden, gate.weight, None)
        sm_weights, sm_experts, sm_topk, sm_E = scatter_out

        # SonicMoE path
        sonic_out = softmax_topk_routing(hidden, moe_block)
        sonic_scores, _tok_idx, sonic_exp_idx, _logits = sonic_out

        # ScatterMoE returns [T, K], SonicMoE returns [T*K] flattened
        sonic_weights_2d, sonic_experts_2d = self._unflatten(
            sonic_scores, sonic_exp_idx, T, K
        )

        assert sm_topk == K
        assert sm_E == E

        # Both should select the same experts and produce the same weights
        assert torch.equal(sm_experts, sonic_experts_2d.to(dtype=sm_experts.dtype))
        assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)

    def test_logits_not_returned_by_scattermoe(self):
        """Only SonicMoE returns logits; check that they have shape ``(T, E)``."""
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        moe_block, _gate, hidden, T, _H, E, _K = _make_softmax_block()
        _scores, _tok_idx, _exp_idx, logits = softmax_topk_routing(hidden, moe_block)
        expected_shape = (T, E)
        assert logits.shape == expected_shape

    def test_no_renorm(self):
        """Both paths still agree when ``norm_topk_prob`` is switched off."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _softmax_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        moe_block, gate, hidden, T, _H, _E, K = _make_softmax_block()
        gate.norm_topk_prob = False

        sm_weights, sm_experts, _topk, _num_experts = _softmax_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        sonic_scores, _tok_idx, sonic_exp_idx, _logits = softmax_topk_routing(
            hidden, moe_block
        )

        sonic_weights_2d, sonic_experts_2d = self._unflatten(
            sonic_scores, sonic_exp_idx, T, K
        )

        assert torch.equal(sm_experts, sonic_experts_2d.to(dtype=sm_experts.dtype))
        assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)

    def test_various_expert_counts(self):
        """Both paths agree for several ``(E, K)`` combinations."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _softmax_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        expert_configs = [(2, 1), (8, 2), (16, 4), (32, 8)]
        for E, K in expert_configs:
            moe_block, gate, hidden, T, _H, _, _ = _make_softmax_block(E=E, K=K)

            sm_weights, sm_experts, _topk, _num_experts = _softmax_topk_route(
                moe_block, gate, hidden, gate.weight, None
            )
            sonic_scores, _tok_idx, sonic_exp_idx, _logits = softmax_topk_routing(
                hidden, moe_block
            )

            sonic_weights_2d, sonic_experts_2d = self._unflatten(
                sonic_scores, sonic_exp_idx, T, K
            )

            experts_equal = torch.equal(
                sm_experts, sonic_experts_2d.to(dtype=sm_experts.dtype)
            )
            assert experts_equal, f"Expert mismatch for E={E}, K={K}"

            weights_close = torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)
            assert weights_close, f"Weight mismatch for E={E}, K={K}"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 2. Sigmoid routing parity
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestSigmoidRoutingParity:
    """Check that scattermoe and sonicmoe sigmoid routing agree.

    Each token's experts are sorted by index before comparison, so the two
    implementations may return them in different orders.
    """

    @pytest.fixture(autouse=True)
    def _require(self):
        """Apply the triton requirement to every test in this class."""
        _require_triton()

    @staticmethod
    def _assert_same_selection(
        sm_weights, sm_experts, sonic_scores, sonic_exp_idx, T, K
    ):
        """Assert both outputs pick the same experts with matching weights.

        SonicMoE's flat outputs are reshaped to ``[T, K]``. Expert indices are
        then sorted along the last dim on both sides (topk ordering may
        differ), and the weights are gathered in that sorted order and
        compared with ``atol=1e-6``.
        """
        sonic_weights_2d = sonic_scores.reshape(T, K)
        sonic_experts_2d = sonic_exp_idx.reshape(T, K)

        sm_sorted, sm_order = torch.sort(sm_experts, dim=-1)
        sonic_sorted, sonic_order = torch.sort(
            sonic_experts_2d.to(dtype=sm_experts.dtype), dim=-1
        )
        assert torch.equal(sm_sorted, sonic_sorted)

        # Gather weights in sorted order for comparison
        sm_weights_sorted = torch.gather(sm_weights, 1, sm_order)
        sonic_weights_sorted = torch.gather(sonic_weights_2d, 1, sonic_order)
        assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)

    def test_weights_match_with_groups(self):
        """Both paths agree when group selection is on (``n_group=2``)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing

        moe_block, gate, hidden, T, _H, E, K = _make_sigmoid_block(
            E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
        )

        scatter_out = _sigmoid_topk_route(moe_block, gate, hidden, gate.weight, None)
        sm_weights, sm_experts, sm_topk, sm_E = scatter_out

        sonic_out = sigmoid_topk_routing(hidden, moe_block)
        sonic_scores, _tok_idx, sonic_exp_idx, _logits = sonic_out

        assert sm_topk == K
        assert sm_E == E

        self._assert_same_selection(
            sm_weights, sm_experts, sonic_scores, sonic_exp_idx, T, K
        )

    def test_weights_match_no_groups(self):
        """Both paths agree without group selection (``n_group=1``)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing

        moe_block, gate, hidden, T, _H, _E, K = _make_sigmoid_block(
            E=16, K=4, n_group=1, topk_group=1, bias_on_gate=True
        )

        sm_weights, sm_experts, _topk, _num_experts = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        sonic_scores, _tok_idx, sonic_exp_idx, _logits = sigmoid_topk_routing(
            hidden, moe_block
        )

        self._assert_same_selection(
            sm_weights, sm_experts, sonic_scores, sonic_exp_idx, T, K
        )

    def test_bias_on_block_parity(self):
        """minimax_m2 style: the correction bias sits on the block, not the gate."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing

        moe_block, gate, hidden, T, _H, _E, K = _make_sigmoid_block(
            E=16, K=4, n_group=1, bias_on_gate=False
        )

        sm_weights, sm_experts, _topk, _num_experts = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        sonic_scores, _tok_idx, sonic_exp_idx, _logits = sigmoid_topk_routing(
            hidden, moe_block
        )

        self._assert_same_selection(
            sm_weights, sm_experts, sonic_scores, sonic_exp_idx, T, K
        )

    def test_scaling_factor_parity(self):
        """Both paths agree when ``routed_scaling_factor`` is set to 2.5."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing

        moe_block, gate, hidden, T, _H, _E, K = _make_sigmoid_block(
            n_group=1, bias_on_gate=True
        )
        moe_block.routed_scaling_factor = 2.5

        sm_weights, sm_experts, _topk, _num_experts = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        sonic_scores, _tok_idx, sonic_exp_idx, _logits = sigmoid_topk_routing(
            hidden, moe_block
        )

        self._assert_same_selection(
            sm_weights, sm_experts, sonic_scores, sonic_exp_idx, T, K
        )

    def test_no_renorm_parity(self):
        """Both paths agree when ``norm_topk_prob`` is switched off."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing

        moe_block, gate, hidden, T, _H, _E, K = _make_sigmoid_block(
            n_group=1, bias_on_gate=True
        )
        moe_block.norm_topk_prob = False

        sm_weights, sm_experts, _topk, _num_experts = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        sonic_scores, _tok_idx, sonic_exp_idx, _logits = sigmoid_topk_routing(
            hidden, moe_block
        )

        self._assert_same_selection(
            sm_weights, sm_experts, sonic_scores, sonic_exp_idx, T, K
        )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 3. Shared expert parity
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestSharedExpertParity:
    """Compare the two ``_compute_shared_expert`` implementations.

    Blocks are ``SimpleNamespace`` stubs whose shared-expert attribute is a
    callable returning a fixed tensor.
    """

    @pytest.fixture(autouse=True)
    def _require(self):
        """Apply the triton requirement to every test in this class."""
        _require_triton()

    def _get_both_fns(self):
        """Import and return ``(scattermoe_fn, sonicmoe_fn)``."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert as scatter_compute,
        )
        from axolotl.integrations.kernels.sonicmoe.patch import (
            _compute_shared_expert as sonic_compute,
        )

        both = (scatter_compute, sonic_compute)
        return both

    def _assert_equal_for_attr(self, attr_name):
        """Assert both functions return equal tensors for a block exposing only ``attr_name``."""
        scatter_fn, sonic_fn = self._get_both_fns()
        out = torch.randn(4, 8)
        block = SimpleNamespace(**{attr_name: lambda x: out})
        hidden = torch.randn(4, 8)

        scatter_result = scatter_fn(block, hidden)
        sonic_result = sonic_fn(block, hidden)
        assert torch.equal(scatter_result, sonic_result)

    def test_shared_expert_singular(self):
        """Block exposes ``shared_expert``."""
        self._assert_equal_for_attr("shared_expert")

    def test_shared_experts_plural(self):
        """Block exposes ``shared_experts``."""
        self._assert_equal_for_attr("shared_experts")

    def test_shared_mlp(self):
        """Block exposes ``shared_mlp``."""
        self._assert_equal_for_attr("shared_mlp")

    def test_no_shared_expert(self):
        """Both functions return ``None`` for a block with no shared-expert attribute."""
        scatter_fn, sonic_fn = self._get_both_fns()
        block = SimpleNamespace()
        hidden = torch.randn(4, 8)

        for compute in (scatter_fn, sonic_fn):
            assert compute(block, hidden) is None

    def test_shared_expert_gate_only_in_scattermoe(self):
        """ScatterMoE's _compute_shared_expert handles shared_expert_gate;
        SonicMoE's patch.py handles it externally in the forward function.

        This documents the known divergence: the scattermoe version applies
        sigmoid gating inline, while sonicmoe applies it in the forward.
        """
        scatter_fn, sonic_fn = self._get_both_fns()

        H = 8
        expert_out = torch.ones(4, H)

        def shared_expert(x):
            return expert_out

        def shared_expert_gate(x):
            # All-zero gate logits: sigmoid(0) = 0.5
            return torch.zeros(4, H)

        block = SimpleNamespace(
            shared_expert=shared_expert,
            shared_expert_gate=shared_expert_gate,
        )
        hidden = torch.randn(4, H)

        scatter_result = scatter_fn(block, hidden)
        sonic_result = sonic_fn(block, hidden)

        # ScatterMoE applies the gate: expert_out * sigmoid(0) = 0.5
        expected_gated = expert_out * 0.5
        assert torch.allclose(scatter_result, expected_gated, atol=1e-6)

        # SonicMoE does NOT apply the gate here (it does it in the forward)
        assert torch.equal(sonic_result, expert_out)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 4. Route dispatcher parity
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestRouteDispatcherParity:
    """Check that scattermoe's ``_route`` returns what the specific routing fn returns."""

    @pytest.fixture(autouse=True)
    def _require(self):
        """Apply the triton requirement to every test in this class."""
        _require_triton()

    @staticmethod
    def _assert_identical(via_route, direct):
        """Assert two ``(weights, experts, top_k, num_experts)`` results are equal.

        Tensors are compared with ``torch.equal``; the two scalars with ``==``.
        """
        route_w, route_e, route_k, route_E = via_route
        direct_w, direct_e, direct_k, direct_E = direct

        assert torch.equal(route_w, direct_w)
        assert torch.equal(route_e, direct_e)
        assert route_k == direct_k
        assert route_E == direct_E

    def test_route_dispatches_softmax(self):
        """_route should use softmax when no e_score_correction_bias."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _route,
            _softmax_topk_route,
        )

        moe_block, gate, hidden, _T, _H, _E, _K = _make_softmax_block()
        call_args = (moe_block, gate, hidden, gate.weight, None)

        via_route = _route(*call_args)
        direct = _softmax_topk_route(*call_args)

        self._assert_identical(via_route, direct)

    def test_route_dispatches_sigmoid(self):
        """_route should use sigmoid when e_score_correction_bias is present."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _route,
            _sigmoid_topk_route,
        )

        moe_block, gate, hidden, _T, _H, _E, _K = _make_sigmoid_block(
            n_group=1, bias_on_gate=True
        )
        call_args = (moe_block, gate, hidden, gate.weight, None)

        via_route = _route(*call_args)
        direct = _sigmoid_topk_route(*call_args)

        self._assert_identical(via_route, direct)
|
||||
Reference in New Issue
Block a user