# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0

"""
Parity tests between scattermoe-lora and sonicmoe routing implementations.

These tests verify that both implementations produce numerically identical
results for the same inputs, ensuring safe centralization of the routing code.

ScatterMoE returns 2D tensors [T, K]; SonicMoE returns flattened 1D [T*K].
The core algorithm should be identical — only the output format differs.
"""

from types import SimpleNamespace

import pytest
import torch
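
# The parity checks below rely on the fact that flattening a [T, K] routing
# result row-major and reshaping it back is lossless, so SonicMoE's 1D outputs
# can be compared against ScatterMoE's 2D outputs directly, e.g.:
#
#   weights_2d = torch.randn(8, 2)       # ScatterMoE-style [T, K]
#   weights_1d = weights_2d.reshape(-1)  # SonicMoE-style [T * K]
#   assert torch.equal(weights_1d.reshape(8, 2), weights_2d)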


def _require_triton():
    pytest.importorskip("triton")


# ============================================================================
# Fixtures / helpers
# ============================================================================


def _make_softmax_block(T=8, H=16, E=4, K=2):
    """Qwen/OLMoE-style block usable by both implementations."""
    gate = SimpleNamespace(
        weight=torch.randn(E, H),
        top_k=K,
        num_experts=E,
        norm_topk_prob=True,
    )
    moe_block = SimpleNamespace(gate=gate)
    hidden = torch.randn(T, H)
    return moe_block, gate, hidden, T, H, E, K
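

# For orientation, a minimal reference of what both softmax routers are
# expected to compute (a sketch distilled from the parity cases in this file,
# not code taken from either implementation; the helper name is ours):
def _reference_softmax_topk(hidden, gate_weight, top_k, norm_topk_prob):
    logits = hidden @ gate_weight.t()  # [T, E] router logits
    weights, experts = torch.softmax(logits, dim=-1).topk(top_k, dim=-1)
    if norm_topk_prob:
        # renormalize the selected probabilities to sum to 1 per token
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, experts  # both [T, K]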


def _make_sigmoid_block(
    T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
):
    """GLM/DeepSeek-style block usable by both implementations."""
    if bias_on_gate:
        gate = SimpleNamespace(
            weight=torch.randn(E, H),
            e_score_correction_bias=torch.zeros(E),
        )
        moe_block = SimpleNamespace(
            gate=gate,
            top_k=K,
            n_routed_experts=E,
            n_group=n_group,
            topk_group=topk_group,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
    else:
        # minimax_m2 style: bias on block
        gate = SimpleNamespace(
            weight=torch.randn(E, H),
            top_k=K,
        )
        moe_block = SimpleNamespace(
            gate=gate,
            top_k=K,
            e_score_correction_bias=torch.zeros(E),
        )
    hidden = torch.randn(T, H)
    return moe_block, gate, hidden, T, H, E, K
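

# A simplified sigmoid-routing reference in the same spirit, assuming
# n_group=1 (group-limited selection is exercised by the tests but omitted
# here; this is a sketch with names of our own choosing, not either
# implementation). The correction bias steers which experts are picked, while
# the returned weights come from the unbiased sigmoid scores:
def _reference_sigmoid_topk(
    hidden, gate_weight, bias, top_k, norm_topk_prob, scaling_factor
):
    scores = torch.sigmoid(hidden @ gate_weight.t())  # [T, E]
    _, experts = (scores + bias).topk(top_k, dim=-1)  # bias affects selection only
    weights = scores.gather(-1, experts)  # [T, K] weights from unbiased scores
    if norm_topk_prob:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights * scaling_factor, experts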


# ============================================================================
# 1. Softmax routing parity
# ============================================================================


class TestSoftmaxRoutingParity:
    """Verify scattermoe and sonicmoe softmax routing produce identical results."""

    @pytest.fixture(autouse=True)
    def _require(self):
        _require_triton()

    def test_weights_match(self):
        """2D weights from scattermoe == reshaped 1D weights from sonicmoe."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _softmax_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_softmax_block()

        # ScatterMoE path (no LoRA delta)
        sm_weights, sm_experts, sm_topk, sm_E = _softmax_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        # SonicMoE path
        sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = softmax_topk_routing(
            hidden, moe_block
        )

        # ScatterMoE returns [T, K], SonicMoE returns [T*K] flattened
        sonic_weights_2d = sonic_scores.reshape(T, K)
        sonic_experts_2d = sonic_exp_idx.reshape(T, K)

        assert sm_topk == K
        assert sm_E == E

        # Both should select the same experts and produce the same weights
        assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype))
        assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)

    def test_logits_not_returned_by_scattermoe(self):
        """ScatterMoE doesn't return logits; SonicMoE does — verify SonicMoE logits shape."""
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
        _, _, _, logits = softmax_topk_routing(hidden, moe_block)
        assert logits.shape == (T, E)

    def test_no_renorm(self):
        """With norm_topk_prob=False, both should skip renormalization."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _softmax_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
        gate.norm_topk_prob = False

        sm_weights, sm_experts, _, _ = _softmax_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block)

        sonic_weights_2d = sonic_scores.reshape(T, K)
        sonic_experts_2d = sonic_exp_idx.reshape(T, K)

        assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype))
        assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)

    def test_various_expert_counts(self):
        """Parity across different E and K values."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _softmax_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        for E, K in [(2, 1), (8, 2), (16, 4), (32, 8)]:
            moe_block, gate, hidden, T, H, _, _ = _make_softmax_block(E=E, K=K)

            sm_weights, sm_experts, _, _ = _softmax_topk_route(
                moe_block, gate, hidden, gate.weight, None
            )
            sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block)

            sonic_weights_2d = sonic_scores.reshape(T, K)
            sonic_experts_2d = sonic_exp_idx.reshape(T, K)

            assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype)), (
                f"Expert mismatch for E={E}, K={K}"
            )
            assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6), (
                f"Weight mismatch for E={E}, K={K}"
            )


# ============================================================================
# 2. Sigmoid routing parity
# ============================================================================


class TestSigmoidRoutingParity:
    """Verify scattermoe and sonicmoe sigmoid routing produce identical results."""

    @pytest.fixture(autouse=True)
    def _require(self):
        _require_triton()

    def test_weights_match_with_groups(self):
        """Both implementations should produce identical weights with group selection."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
            E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
        )

        sm_weights, sm_experts, sm_topk, sm_E = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = sigmoid_topk_routing(
            hidden, moe_block
        )

        sonic_weights_2d = sonic_scores.reshape(T, K)
        sonic_experts_2d = sonic_exp_idx.reshape(T, K)

        assert sm_topk == K
        assert sm_E == E

        # Sort experts within each token to handle different topk orderings
        sm_sorted, sm_order = sm_experts.sort(dim=-1)
        sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)

        assert torch.equal(sm_sorted, sonic_sorted)

        # Gather weights in sorted order for comparison
        sm_weights_sorted = sm_weights.gather(1, sm_order)
        sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
        assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)

    def test_weights_match_no_groups(self):
        """Both implementations match without group selection (n_group=1)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
            E=16, K=4, n_group=1, topk_group=1, bias_on_gate=True
        )

        sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)

        sonic_weights_2d = sonic_scores.reshape(T, K)
        sonic_experts_2d = sonic_exp_idx.reshape(T, K)

        # Sort for comparison (topk with sorted=False may differ in order)
        sm_sorted, sm_order = sm_experts.sort(dim=-1)
        sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)

        assert torch.equal(sm_sorted, sonic_sorted)
        sm_weights_sorted = sm_weights.gather(1, sm_order)
        sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
        assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)

    def test_bias_on_block_parity(self):
        """minimax_m2 style: bias on block, not gate."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
            E=16, K=4, n_group=1, bias_on_gate=False
        )

        sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)

        sonic_weights_2d = sonic_scores.reshape(T, K)
        sonic_experts_2d = sonic_exp_idx.reshape(T, K)

        sm_sorted, sm_order = sm_experts.sort(dim=-1)
        sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)

        assert torch.equal(sm_sorted, sonic_sorted)
        sm_weights_sorted = sm_weights.gather(1, sm_order)
        sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
        assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)

    def test_scaling_factor_parity(self):
        """routed_scaling_factor applied identically by both."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
            n_group=1, bias_on_gate=True
        )
        moe_block.routed_scaling_factor = 2.5

        sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)

        sonic_weights_2d = sonic_scores.reshape(T, K)
        sonic_experts_2d = sonic_exp_idx.reshape(T, K)

        sm_sorted, sm_order = sm_experts.sort(dim=-1)
        sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)

        assert torch.equal(sm_sorted, sonic_sorted)
        sm_weights_sorted = sm_weights.gather(1, sm_order)
        sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
        assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)

    def test_no_renorm_parity(self):
        """norm_topk_prob=False produces same results in both."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )
        from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing

        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
            n_group=1, bias_on_gate=True
        )
        moe_block.norm_topk_prob = False

        sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)

        sonic_weights_2d = sonic_scores.reshape(T, K)
        sonic_experts_2d = sonic_exp_idx.reshape(T, K)

        sm_sorted, sm_order = sm_experts.sort(dim=-1)
        sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)

        assert torch.equal(sm_sorted, sonic_sorted)
        sm_weights_sorted = sm_weights.gather(1, sm_order)
        sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
        assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)


# ============================================================================
# 3. Shared expert parity
# ============================================================================


class TestSharedExpertParity:
    """Verify both _compute_shared_expert implementations behave identically."""

    @pytest.fixture(autouse=True)
    def _require(self):
        _require_triton()

    def _get_both_fns(self):
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert as scatter_compute,
        )
        from axolotl.integrations.kernels.sonicmoe.patch import (
            _compute_shared_expert as sonic_compute,
        )

        return scatter_compute, sonic_compute

    def test_shared_expert_singular(self):
        scatter_fn, sonic_fn = self._get_both_fns()
        out = torch.randn(4, 8)
        block = SimpleNamespace(shared_expert=lambda x: out)
        hidden = torch.randn(4, 8)

        assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))

    def test_shared_experts_plural(self):
        scatter_fn, sonic_fn = self._get_both_fns()
        out = torch.randn(4, 8)
        block = SimpleNamespace(shared_experts=lambda x: out)
        hidden = torch.randn(4, 8)

        assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))

    def test_shared_mlp(self):
        scatter_fn, sonic_fn = self._get_both_fns()
        out = torch.randn(4, 8)
        block = SimpleNamespace(shared_mlp=lambda x: out)
        hidden = torch.randn(4, 8)

        assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))

    def test_no_shared_expert(self):
        scatter_fn, sonic_fn = self._get_both_fns()
        block = SimpleNamespace()
        hidden = torch.randn(4, 8)

        assert scatter_fn(block, hidden) is None
        assert sonic_fn(block, hidden) is None

    def test_shared_expert_gate_only_in_scattermoe(self):
        """ScatterMoE's _compute_shared_expert handles shared_expert_gate;
        SonicMoE's patch.py handles it externally in the forward function.

        This documents the known divergence: the scattermoe version applies
        sigmoid gating inline, while sonicmoe applies it in the forward.
        """
        scatter_fn, sonic_fn = self._get_both_fns()

        H = 8
        expert_out = torch.ones(4, H)
        gate_fn = lambda x: torch.zeros(4, H)  # noqa: E731  # sigmoid(0) = 0.5

        block = SimpleNamespace(
            shared_expert=lambda x: expert_out,
            shared_expert_gate=gate_fn,
        )
        hidden = torch.randn(4, H)

        scatter_result = scatter_fn(block, hidden)
        sonic_result = sonic_fn(block, hidden)

        # ScatterMoE applies the gate: expert_out * sigmoid(0) = 0.5
        expected_gated = expert_out * 0.5
        assert torch.allclose(scatter_result, expected_gated, atol=1e-6)

        # SonicMoE does NOT apply the gate here (it does it in the forward)
        assert torch.equal(sonic_result, expert_out)


# ============================================================================
# 4. Route dispatcher parity
# ============================================================================


class TestRouteDispatcherParity:
    """Verify _route in scattermoe dispatches correctly and matches individual fns."""

    @pytest.fixture(autouse=True)
    def _require(self):
        _require_triton()

    def test_route_dispatches_softmax(self):
        """_route should use softmax when no e_score_correction_bias."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _route,
            _softmax_topk_route,
        )

        moe_block, gate, hidden, T, H, E, K = _make_softmax_block()

        route_w, route_e, route_k, route_E = _route(
            moe_block, gate, hidden, gate.weight, None
        )
        direct_w, direct_e, direct_k, direct_E = _softmax_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert torch.equal(route_w, direct_w)
        assert torch.equal(route_e, direct_e)
        assert route_k == direct_k
        assert route_E == direct_E

    def test_route_dispatches_sigmoid(self):
        """_route should use sigmoid when e_score_correction_bias is present."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _route,
            _sigmoid_topk_route,
        )

        moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
            n_group=1, bias_on_gate=True
        )

        route_w, route_e, route_k, route_E = _route(
            moe_block, gate, hidden, gate.weight, None
        )
        direct_w, direct_e, direct_k, direct_E = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert torch.equal(route_w, direct_w)
        assert torch.equal(route_e, direct_e)
        assert route_k == direct_k
        assert route_E == direct_E