# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
Parity tests between scattermoe-lora and sonicmoe routing implementations.
These tests verify that both implementations produce numerically identical
results for the same inputs, ensuring safe centralization of the routing code.
ScatterMoE returns 2D tensors [T, K]; SonicMoE returns flattened 1D [T*K].
The core algorithm should be identical — only the output format differs.
"""
from types import SimpleNamespace

import pytest
import torch


def _require_triton():
    pytest.importorskip("triton")
# ============================================================================
# Fixtures / helpers
# ============================================================================
def _make_softmax_block(T=8, H=16, E=4, K=2):
"""Qwen/OLMoE-style block usable by both implementations."""
gate = SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
num_experts=E,
norm_topk_prob=True,
)
moe_block = SimpleNamespace(gate=gate)
hidden = torch.randn(T, H)
return moe_block, gate, hidden, T, H, E, K
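# Reference sketch of the softmax routing that _make_softmax_block above is
# built for (stated as an assumption for readability; the imported kernels
# are authoritative):
#
#     logits = hidden @ gate.weight.T            # [T, E]
#     probs = logits.softmax(dim=-1)
#     weights, experts = probs.topk(K, dim=-1)   # [T, K] each
#     if gate.norm_topk_prob:
#         weights = weights / weights.sum(dim=-1, keepdim=True)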
def _make_sigmoid_block(
T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
):
"""GLM/DeepSeek-style block usable by both implementations."""
if bias_on_gate:
gate = SimpleNamespace(
weight=torch.randn(E, H),
e_score_correction_bias=torch.zeros(E),
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
n_routed_experts=E,
n_group=n_group,
topk_group=topk_group,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
else:
        # minimax_m2 style: bias lives on the block rather than the gate;
        # the n_group/topk_group parameters are unused in this branch.
gate = SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
e_score_correction_bias=torch.zeros(E),
)
    hidden = torch.randn(T, H)
    return moe_block, gate, hidden, T, H, E, K
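# Reference sketch of the sigmoid routing exercised by the sigmoid parity
# tests below (DeepSeek-style; stated as an assumption, the kernels under
# test are authoritative):
#
#     scores = (hidden @ gate.weight.T).sigmoid()    # [T, E]
#     selection = scores + e_score_correction_bias   # bias steers selection only
#     # optional group masking: keep the best topk_group of n_group groups
#     experts = topk over the masked selection scores
#     weights = scores gathered at `experts`         # from the unbiased scores
#     if norm_topk_prob: renormalize, then scale by routed_scaling_factor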
# ============================================================================
# 1. Softmax routing parity
# ============================================================================
class TestSoftmaxRoutingParity:
"""Verify scattermoe and sonicmoe softmax routing produce identical results."""
@pytest.fixture(autouse=True)
def _require(self):
_require_triton()
def test_weights_match(self):
"""2D weights from scattermoe == reshaped 1D weights from sonicmoe."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
# ScatterMoE path (no LoRA delta)
sm_weights, sm_experts, sm_topk, sm_E = _softmax_topk_route(
moe_block, gate, hidden, gate.weight, None
)
# SonicMoE path
sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = softmax_topk_routing(
hidden, moe_block
)
# ScatterMoE returns [T, K], SonicMoE returns [T*K] flattened
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
assert sm_topk == K
assert sm_E == E
# Both should select the same experts and produce the same weights
assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype))
assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)
def test_logits_not_returned_by_scattermoe(self):
"""ScatterMoE doesn't return logits; SonicMoE does — verify SonicMoE logits shape."""
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
_, _, _, logits = softmax_topk_routing(hidden, moe_block)
assert logits.shape == (T, E)
def test_no_renorm(self):
"""With norm_topk_prob=False, both should skip renormalization."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
gate.norm_topk_prob = False
sm_weights, sm_experts, _, _ = _softmax_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype))
assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)
def test_various_expert_counts(self):
"""Parity across different E and K values."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
for E, K in [(2, 1), (8, 2), (16, 4), (32, 8)]:
moe_block, gate, hidden, T, H, _, _ = _make_softmax_block(E=E, K=K)
sm_weights, sm_experts, _, _ = _softmax_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype)), (
f"Expert mismatch for E={E}, K={K}"
)
assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6), (
f"Weight mismatch for E={E}, K={K}"
)
# ============================================================================
# 2. Sigmoid routing parity
# ============================================================================
class TestSigmoidRoutingParity:
"""Verify scattermoe and sonicmoe sigmoid routing produce identical results."""
@pytest.fixture(autouse=True)
def _require(self):
_require_triton()
def test_weights_match_with_groups(self):
"""Both implementations should produce identical weights with group selection."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
)
sm_weights, sm_experts, sm_topk, sm_E = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = sigmoid_topk_routing(
hidden, moe_block
)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
assert sm_topk == K
assert sm_E == E
# Sort experts within each token to handle different topk orderings
sm_sorted, sm_order = sm_experts.sort(dim=-1)
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
assert torch.equal(sm_sorted, sonic_sorted)
# Gather weights in sorted order for comparison
sm_weights_sorted = sm_weights.gather(1, sm_order)
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
def test_weights_match_no_groups(self):
"""Both implementations match without group selection (n_group=1)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
E=16, K=4, n_group=1, topk_group=1, bias_on_gate=True
)
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
# Sort for comparison (topk with sorted=False may differ in order)
sm_sorted, sm_order = sm_experts.sort(dim=-1)
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
assert torch.equal(sm_sorted, sonic_sorted)
sm_weights_sorted = sm_weights.gather(1, sm_order)
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
def test_bias_on_block_parity(self):
"""minimax_m2 style: bias on block, not gate."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
E=16, K=4, n_group=1, bias_on_gate=False
)
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
sm_sorted, sm_order = sm_experts.sort(dim=-1)
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
assert torch.equal(sm_sorted, sonic_sorted)
sm_weights_sorted = sm_weights.gather(1, sm_order)
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
def test_scaling_factor_parity(self):
"""routed_scaling_factor applied identically by both."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
n_group=1, bias_on_gate=True
)
moe_block.routed_scaling_factor = 2.5
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
sm_sorted, sm_order = sm_experts.sort(dim=-1)
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
assert torch.equal(sm_sorted, sonic_sorted)
sm_weights_sorted = sm_weights.gather(1, sm_order)
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
def test_no_renorm_parity(self):
"""norm_topk_prob=False produces same results in both."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
n_group=1, bias_on_gate=True
)
moe_block.norm_topk_prob = False
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
sm_sorted, sm_order = sm_experts.sort(dim=-1)
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
assert torch.equal(sm_sorted, sonic_sorted)
sm_weights_sorted = sm_weights.gather(1, sm_order)
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
# ============================================================================
# 3. Shared expert parity
# ============================================================================
class TestSharedExpertParity:
"""Verify both _compute_shared_expert implementations behave identically."""
@pytest.fixture(autouse=True)
def _require(self):
_require_triton()
def _get_both_fns(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert as scatter_compute,
)
from axolotl.integrations.kernels.sonicmoe.patch import (
_compute_shared_expert as sonic_compute,
)
return scatter_compute, sonic_compute
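    # Behavior exercised by the tests below: both implementations accept any
    # of `shared_expert`, `shared_experts`, or `shared_mlp` on the block and
    # return None when none of them is present.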
def test_shared_expert_singular(self):
scatter_fn, sonic_fn = self._get_both_fns()
out = torch.randn(4, 8)
block = SimpleNamespace(shared_expert=lambda x: out)
hidden = torch.randn(4, 8)
assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))
def test_shared_experts_plural(self):
scatter_fn, sonic_fn = self._get_both_fns()
out = torch.randn(4, 8)
block = SimpleNamespace(shared_experts=lambda x: out)
hidden = torch.randn(4, 8)
assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))
def test_shared_mlp(self):
scatter_fn, sonic_fn = self._get_both_fns()
out = torch.randn(4, 8)
block = SimpleNamespace(shared_mlp=lambda x: out)
hidden = torch.randn(4, 8)
assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))
def test_no_shared_expert(self):
scatter_fn, sonic_fn = self._get_both_fns()
block = SimpleNamespace()
hidden = torch.randn(4, 8)
assert scatter_fn(block, hidden) is None
assert sonic_fn(block, hidden) is None
def test_shared_expert_gate_only_in_scattermoe(self):
"""ScatterMoE's _compute_shared_expert handles shared_expert_gate;
SonicMoE's patch.py handles it externally in the forward function.
This documents the known divergence: the scattermoe version applies
sigmoid gating inline, while sonicmoe applies it in the forward.
"""
scatter_fn, sonic_fn = self._get_both_fns()
H = 8
expert_out = torch.ones(4, H)
gate_fn = lambda x: torch.zeros(4, H) # noqa: E731 # sigmoid(0) = 0.5
block = SimpleNamespace(
shared_expert=lambda x: expert_out,
shared_expert_gate=gate_fn,
)
hidden = torch.randn(4, H)
scatter_result = scatter_fn(block, hidden)
sonic_result = sonic_fn(block, hidden)
# ScatterMoE applies the gate: expert_out * sigmoid(0) = 0.5
expected_gated = expert_out * 0.5
assert torch.allclose(scatter_result, expected_gated, atol=1e-6)
# SonicMoE does NOT apply the gate here (it does it in the forward)
assert torch.equal(sonic_result, expert_out)
# ============================================================================
# 4. Route dispatcher parity
# ============================================================================
class TestRouteDispatcherParity:
"""Verify _route in scattermoe dispatches correctly and matches individual fns."""
@pytest.fixture(autouse=True)
def _require(self):
_require_triton()
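    # Dispatch rule under test: _route picks the sigmoid path when an
    # e_score_correction_bias is present (here, on the gate) and the softmax
    # path otherwise.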
def test_route_dispatches_softmax(self):
"""_route should use softmax when no e_score_correction_bias."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_route,
_softmax_topk_route,
)
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
route_w, route_e, route_k, route_E = _route(
moe_block, gate, hidden, gate.weight, None
)
direct_w, direct_e, direct_k, direct_E = _softmax_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert torch.equal(route_w, direct_w)
assert torch.equal(route_e, direct_e)
assert route_k == direct_k
assert route_E == direct_E
def test_route_dispatches_sigmoid(self):
"""_route should use sigmoid when e_score_correction_bias is present."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_route,
_sigmoid_topk_route,
)
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
n_group=1, bias_on_gate=True
)
route_w, route_e, route_k, route_E = _route(
moe_block, gate, hidden, gate.weight, None
)
direct_w, direct_e, direct_k, direct_E = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert torch.equal(route_w, direct_w)
assert torch.equal(route_e, direct_e)
assert route_k == direct_k
assert route_E == direct_E