consolidate behavioud of routing in scattermoe kernels (#3475)

* consolidate behavioud of routing in scattermoe kernels

* collect telemetry on best chosen autotuned kernel

* properly collect data

* Fix property name and get smem too

* handle issues raised by coderabbit

* add tests for parity before refactoring
This commit is contained in:
Wing Lian
2026-03-16 23:47:40 -04:00
committed by GitHub
parent 830e9f7eaf
commit 8f3fb517b3
8 changed files with 1988 additions and 35 deletions

View File

@@ -0,0 +1,474 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
Parity tests between scattermoe-lora and sonicmoe routing implementations.
These tests verify that both implementations produce numerically identical
results for the same inputs, ensuring safe centralization of the routing code.
ScatterMoE returns 2D tensors [T, K]; SonicMoE returns flattened 1D [T*K].
The core algorithm should be identical — only the output format differs.
"""
from types import SimpleNamespace
import pytest
import torch
def _require_triton():
pytest.importorskip("triton")
# ============================================================================
# Fixtures / helpers
# ============================================================================
def _make_softmax_block(T=8, H=16, E=4, K=2):
"""Qwen/OLMoE-style block usable by both implementations."""
gate = SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
num_experts=E,
norm_topk_prob=True,
)
moe_block = SimpleNamespace(gate=gate)
hidden = torch.randn(T, H)
return moe_block, gate, hidden, T, H, E, K
def _make_sigmoid_block(
T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
):
"""GLM/DeepSeek-style block usable by both implementations."""
if bias_on_gate:
gate = SimpleNamespace(
weight=torch.randn(E, H),
e_score_correction_bias=torch.zeros(E),
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
n_routed_experts=E,
n_group=n_group,
topk_group=topk_group,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
else:
# minimax_m2 style: bias on block
gate = SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
e_score_correction_bias=torch.zeros(E),
)
return moe_block, gate, hidden_states(T, H), T, H, E, K
def hidden_states(T, H):
return torch.randn(T, H)
# ============================================================================
# 1. Softmax routing parity
# ============================================================================
class TestSoftmaxRoutingParity:
"""Verify scattermoe and sonicmoe softmax routing produce identical results."""
@pytest.fixture(autouse=True)
def _require(self):
_require_triton()
def test_weights_match(self):
"""2D weights from scattermoe == reshaped 1D weights from sonicmoe."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
# ScatterMoE path (no LoRA delta)
sm_weights, sm_experts, sm_topk, sm_E = _softmax_topk_route(
moe_block, gate, hidden, gate.weight, None
)
# SonicMoE path
sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = softmax_topk_routing(
hidden, moe_block
)
# ScatterMoE returns [T, K], SonicMoE returns [T*K] flattened
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
assert sm_topk == K
assert sm_E == E
# Both should select the same experts and produce the same weights
assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype))
assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)
def test_logits_not_returned_by_scattermoe(self):
"""ScatterMoE doesn't return logits; SonicMoE does — verify SonicMoE logits shape."""
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
_, _, _, logits = softmax_topk_routing(hidden, moe_block)
assert logits.shape == (T, E)
def test_no_renorm(self):
"""With norm_topk_prob=False, both should skip renormalization."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
gate.norm_topk_prob = False
sm_weights, sm_experts, _, _ = _softmax_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype))
assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)
def test_various_expert_counts(self):
"""Parity across different E and K values."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
for E, K in [(2, 1), (8, 2), (16, 4), (32, 8)]:
moe_block, gate, hidden, T, H, _, _ = _make_softmax_block(E=E, K=K)
sm_weights, sm_experts, _, _ = _softmax_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype)), (
f"Expert mismatch for E={E}, K={K}"
)
assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6), (
f"Weight mismatch for E={E}, K={K}"
)
# ============================================================================
# 2. Sigmoid routing parity
# ============================================================================
class TestSigmoidRoutingParity:
"""Verify scattermoe and sonicmoe sigmoid routing produce identical results."""
@pytest.fixture(autouse=True)
def _require(self):
_require_triton()
def test_weights_match_with_groups(self):
"""Both implementations should produce identical weights with group selection."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
)
sm_weights, sm_experts, sm_topk, sm_E = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = sigmoid_topk_routing(
hidden, moe_block
)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
assert sm_topk == K
assert sm_E == E
# Sort experts within each token to handle different topk orderings
sm_sorted, sm_order = sm_experts.sort(dim=-1)
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
assert torch.equal(sm_sorted, sonic_sorted)
# Gather weights in sorted order for comparison
sm_weights_sorted = sm_weights.gather(1, sm_order)
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
def test_weights_match_no_groups(self):
"""Both implementations match without group selection (n_group=1)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
E=16, K=4, n_group=1, topk_group=1, bias_on_gate=True
)
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
# Sort for comparison (topk with sorted=False may differ in order)
sm_sorted, sm_order = sm_experts.sort(dim=-1)
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
assert torch.equal(sm_sorted, sonic_sorted)
sm_weights_sorted = sm_weights.gather(1, sm_order)
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
def test_bias_on_block_parity(self):
"""minimax_m2 style: bias on block, not gate."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
E=16, K=4, n_group=1, bias_on_gate=False
)
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
sm_sorted, sm_order = sm_experts.sort(dim=-1)
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
assert torch.equal(sm_sorted, sonic_sorted)
sm_weights_sorted = sm_weights.gather(1, sm_order)
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
def test_scaling_factor_parity(self):
"""routed_scaling_factor applied identically by both."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
n_group=1, bias_on_gate=True
)
moe_block.routed_scaling_factor = 2.5
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
sm_sorted, sm_order = sm_experts.sort(dim=-1)
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
assert torch.equal(sm_sorted, sonic_sorted)
sm_weights_sorted = sm_weights.gather(1, sm_order)
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
def test_no_renorm_parity(self):
"""norm_topk_prob=False produces same results in both."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
n_group=1, bias_on_gate=True
)
moe_block.norm_topk_prob = False
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
sonic_weights_2d = sonic_scores.reshape(T, K)
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
sm_sorted, sm_order = sm_experts.sort(dim=-1)
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
assert torch.equal(sm_sorted, sonic_sorted)
sm_weights_sorted = sm_weights.gather(1, sm_order)
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
# ============================================================================
# 3. Shared expert parity
# ============================================================================
class TestSharedExpertParity:
"""Verify both _compute_shared_expert implementations behave identically."""
@pytest.fixture(autouse=True)
def _require(self):
_require_triton()
def _get_both_fns(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert as scatter_compute,
)
from axolotl.integrations.kernels.sonicmoe.patch import (
_compute_shared_expert as sonic_compute,
)
return scatter_compute, sonic_compute
def test_shared_expert_singular(self):
scatter_fn, sonic_fn = self._get_both_fns()
out = torch.randn(4, 8)
block = SimpleNamespace(shared_expert=lambda x: out)
hidden = torch.randn(4, 8)
assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))
def test_shared_experts_plural(self):
scatter_fn, sonic_fn = self._get_both_fns()
out = torch.randn(4, 8)
block = SimpleNamespace(shared_experts=lambda x: out)
hidden = torch.randn(4, 8)
assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))
def test_shared_mlp(self):
scatter_fn, sonic_fn = self._get_both_fns()
out = torch.randn(4, 8)
block = SimpleNamespace(shared_mlp=lambda x: out)
hidden = torch.randn(4, 8)
assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))
def test_no_shared_expert(self):
scatter_fn, sonic_fn = self._get_both_fns()
block = SimpleNamespace()
hidden = torch.randn(4, 8)
assert scatter_fn(block, hidden) is None
assert sonic_fn(block, hidden) is None
def test_shared_expert_gate_only_in_scattermoe(self):
"""ScatterMoE's _compute_shared_expert handles shared_expert_gate;
SonicMoE's patch.py handles it externally in the forward function.
This documents the known divergence: the scattermoe version applies
sigmoid gating inline, while sonicmoe applies it in the forward.
"""
scatter_fn, sonic_fn = self._get_both_fns()
H = 8
expert_out = torch.ones(4, H)
gate_fn = lambda x: torch.zeros(4, H) # noqa: E731 # sigmoid(0) = 0.5
block = SimpleNamespace(
shared_expert=lambda x: expert_out,
shared_expert_gate=gate_fn,
)
hidden = torch.randn(4, H)
scatter_result = scatter_fn(block, hidden)
sonic_result = sonic_fn(block, hidden)
# ScatterMoE applies the gate: expert_out * sigmoid(0) = 0.5
expected_gated = expert_out * 0.5
assert torch.allclose(scatter_result, expected_gated, atol=1e-6)
# SonicMoE does NOT apply the gate here (it does it in the forward)
assert torch.equal(sonic_result, expert_out)
# ============================================================================
# 4. Route dispatcher parity
# ============================================================================
class TestRouteDispatcherParity:
"""Verify _route in scattermoe dispatches correctly and matches individual fns."""
@pytest.fixture(autouse=True)
def _require(self):
_require_triton()
def test_route_dispatches_softmax(self):
"""_route should use softmax when no e_score_correction_bias."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_route,
_softmax_topk_route,
)
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
route_w, route_e, route_k, route_E = _route(
moe_block, gate, hidden, gate.weight, None
)
direct_w, direct_e, direct_k, direct_E = _softmax_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert torch.equal(route_w, direct_w)
assert torch.equal(route_e, direct_e)
assert route_k == direct_k
assert route_E == direct_E
def test_route_dispatches_sigmoid(self):
"""_route should use sigmoid when e_score_correction_bias is present."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_route,
_sigmoid_topk_route,
)
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
n_group=1, bias_on_gate=True
)
route_w, route_e, route_k, route_E = _route(
moe_block, gate, hidden, gate.weight, None
)
direct_w, direct_e, direct_k, direct_E = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert torch.equal(route_w, direct_w)
assert torch.equal(route_e, direct_e)
assert route_k == direct_k
assert route_E == direct_E

View File

@@ -0,0 +1,367 @@
"""Tests for scattermoe autotune telemetry integration.
These tests use mocking to verify the collection and reporting logic
without requiring Triton or CUDA.
"""
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Simulate the hash-suffixed module name that LocalLayerRepository creates.
_FAKE_MODULE_NAME = "scattermoe_lora_abc123.kernels.lora_ops"
def _make_mock_config(kwargs, num_warps=4, num_stages=3):
"""Create a mock triton.Config-like object."""
return SimpleNamespace(kwargs=kwargs, num_warps=num_warps, num_stages=num_stages)
def _make_mock_kernel(cache=None):
"""Create a mock autotuned kernel object with a ``.cache`` dict."""
kernel = SimpleNamespace()
kernel.cache = cache if cache is not None else {}
return kernel
def _make_mock_lora_ops(
fwd_cache=None, dx_cache=None, bwd_cache=None, fused_cache=None
):
"""Build a mock ``lora_ops`` module with the four kernel attributes."""
mod = SimpleNamespace(
_scatter2scatter_lora=_make_mock_kernel(fwd_cache),
_scatter2scatter_lora_dX=_make_mock_kernel(dx_cache),
_group_bwd_lora=_make_mock_kernel(bwd_cache),
_group_bwd_lora_fused=_make_mock_kernel(fused_cache),
)
return mod
# =========================================================================
# TestAutotuneCollector
# =========================================================================
class TestAutotuneCollector:
"""Test ``collect_autotune_configs`` with mocked kernel objects."""
def test_empty_cache_returns_empty_list(self):
"""When no kernel has been autotuned yet, return ``[]``."""
mock_lora_ops = _make_mock_lora_ops()
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
result = collect_autotune_configs()
assert result == []
def test_populated_cache_returns_configs(self):
"""When a cache entry exists, it appears in the output."""
cfg = _make_mock_config(
{"BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4
)
mock_lora_ops = _make_mock_lora_ops(fwd_cache={(2048, 4096, 1024): cfg})
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
result = collect_autotune_configs()
assert len(result) == 1
entry = result[0]
assert entry["kernel"] == "scatter2scatter_lora_fwd"
assert entry["key"] == {"M": 2048, "N": 4096, "K": 1024}
assert entry["config"]["BLOCK_N"] == 128
assert entry["config"]["BLOCK_K"] == 64
assert entry["config"]["num_warps"] == 8
assert entry["config"]["num_stages"] == 4
def test_multiple_kernels_and_keys(self):
"""Multiple cache entries across kernels are all returned."""
cfg_fwd = _make_mock_config({"BLOCK_N": 128, "BLOCK_K": 32})
cfg_dx = _make_mock_config({"BLOCK_K": 64, "BLOCK_N": 128}, num_warps=8)
mock_lora_ops = _make_mock_lora_ops(
fwd_cache={(16, 256, 128): cfg_fwd},
dx_cache={(16, 256, 128): cfg_dx},
)
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
result = collect_autotune_configs()
assert len(result) == 2
names = {r["kernel"] for r in result}
assert "scatter2scatter_lora_fwd" in names
assert "scatter2scatter_lora_dX" in names
def test_extra_key_elements_stored(self):
"""Dtype or other extra elements in the cache key are captured."""
cfg = _make_mock_config({"BLOCK_N": 64, "BLOCK_K": 32})
cache_key = (512, 1024, 256, "float16", "float16")
mock_lora_ops = _make_mock_lora_ops(fwd_cache={cache_key: cfg})
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
result = collect_autotune_configs()
assert len(result) == 1
key = result[0]["key"]
assert key["M"] == 512
assert key["N"] == 1024
assert key["K"] == 256
assert key["_extra"] == ["float16", "float16"]
def test_no_module_in_sys_modules_returns_empty(self):
"""If no lora_ops module is loaded, return ``[]``."""
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
# Don't inject anything — the real lora_ops isn't loaded either
# (no triton on this machine), so _find_lora_ops_module returns None.
result = collect_autotune_configs()
assert result == []
def test_finds_module_under_hash_suffixed_name(self):
"""Collector finds lora_ops regardless of the hash suffix."""
cfg = _make_mock_config({"BLOCK_N": 256, "BLOCK_K": 128})
mock_lora_ops = _make_mock_lora_ops(fwd_cache={(8, 512, 64): cfg})
# Use a different hash to prove it's not hardcoded.
alt_name = "scattermoe_lora_deadbeef.kernels.lora_ops"
with patch.dict(sys.modules, {alt_name: mock_lora_ops}):
from axolotl.integrations.kernels.autotune_collector import (
collect_autotune_configs,
)
result = collect_autotune_configs()
assert len(result) == 1
assert result[0]["config"]["BLOCK_N"] == 256
# =========================================================================
# TestAutotuneReportCallback
# =========================================================================
class TestAutotuneReportCallback:
"""Test the callback fires once and sends the correct event."""
def test_reports_once_on_first_step(self):
"""Callback should call ``send_event`` exactly once."""
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
cb = AutotuneReportCallback()
mock_state = MagicMock()
mock_state.global_step = 1
fake_configs = [{"kernel": "test_fwd", "key": {}, "config": {}}]
with (
patch(
"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
return_value=fake_configs,
),
patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
):
mock_tm = MagicMock()
mock_tm.enabled = True
mock_tm_cls.get_instance.return_value = mock_tm
cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
assert mock_tm.send_event.call_count == 1
call_kwargs = mock_tm.send_event.call_args[1]
assert call_kwargs["event_type"] == "scattermoe-autotune"
assert call_kwargs["properties"]["kernel_count"] == 1
# Second call should NOT send again.
cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
assert mock_tm.send_event.call_count == 1
def test_retries_until_step_5_then_gives_up(self):
"""If no configs found by step 5, stop retrying."""
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
cb = AutotuneReportCallback()
with patch(
"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
return_value=[],
):
for step in range(1, 7):
mock_state = MagicMock()
mock_state.global_step = step
cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
assert cb._reported is True
def test_reports_on_retry_when_data_arrives(self):
"""If step 1 has no data but step 2 does, report at step 2."""
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
cb = AutotuneReportCallback()
fake_configs = [{"kernel": "fwd", "key": {}, "config": {}}]
call_count = 0
def _collector():
nonlocal call_count
call_count += 1
if call_count == 1:
return []
return fake_configs
with (
patch(
"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
side_effect=_collector,
),
patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
):
mock_tm = MagicMock()
mock_tm.enabled = True
mock_tm_cls.get_instance.return_value = mock_tm
# Step 1 — empty, no report
s1 = MagicMock()
s1.global_step = 1
cb.on_step_end(args=MagicMock(), state=s1, control=MagicMock())
assert mock_tm.send_event.call_count == 0
# Step 2 — data arrives, report
s2 = MagicMock()
s2.global_step = 2
cb.on_step_end(args=MagicMock(), state=s2, control=MagicMock())
assert mock_tm.send_event.call_count == 1
def test_includes_gpu_info(self):
"""Event properties should include GPU identification."""
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
cb = AutotuneReportCallback()
mock_state = MagicMock()
mock_state.global_step = 1
fake_configs = [{"kernel": "fwd", "key": {}, "config": {}}]
fake_gpu = {
"gpu_name": "NVIDIA H100",
"gpu_compute_capability": "9.0",
"gpu_memory_bytes": 85899345920,
}
fake_smem = {"smem_capacity_bytes": 233472}
with (
patch(
"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
return_value=fake_configs,
),
patch(
"axolotl.integrations.kernels.autotune_callback._get_gpu_info",
return_value=fake_gpu,
),
patch(
"axolotl.integrations.kernels.autotune_callback._get_smem_capacity",
return_value=fake_smem,
),
patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
):
mock_tm = MagicMock()
mock_tm.enabled = True
mock_tm_cls.get_instance.return_value = mock_tm
cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
props = mock_tm.send_event.call_args[1]["properties"]
assert props["gpu_name"] == "NVIDIA H100"
assert props["gpu_compute_capability"] == "9.0"
assert props["gpu_memory_bytes"] == 85899345920
assert props["smem_capacity_bytes"] == 233472
def test_skips_send_when_telemetry_disabled(self):
"""If telemetry is disabled, no event is sent."""
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
cb = AutotuneReportCallback()
mock_state = MagicMock()
mock_state.global_step = 1
with (
patch(
"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
return_value=[{"kernel": "fwd", "key": {}, "config": {}}],
),
patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
):
mock_tm = MagicMock()
mock_tm.enabled = False
mock_tm_cls.get_instance.return_value = mock_tm
cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
assert mock_tm.send_event.call_count == 0
# Should still mark as reported so we don't retry.
assert cb._reported is True
# =========================================================================
# TestKernelsPluginCallbackRegistration
# =========================================================================
class TestKernelsPluginCallbackRegistration:
"""Test that ``KernelsPlugin`` registers the callback correctly."""
def test_scattermoe_registers_callback(self):
"""When ``use_scattermoe=True``, plugin returns the callback."""
from axolotl.integrations.kernels.autotune_callback import (
AutotuneReportCallback,
)
from axolotl.integrations.kernels.plugin import KernelsPlugin
plugin = KernelsPlugin()
cfg = MagicMock()
cfg.use_scattermoe = True
model = MagicMock()
callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
assert len(callbacks) == 1
assert isinstance(callbacks[0], AutotuneReportCallback)
def test_no_scattermoe_no_callback(self):
"""When ``use_scattermoe=False``, plugin returns empty list."""
from axolotl.integrations.kernels.plugin import KernelsPlugin
plugin = KernelsPlugin()
cfg = MagicMock()
cfg.use_scattermoe = False
model = MagicMock()
callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
assert callbacks == []

View File

@@ -3,17 +3,19 @@
# Licensed under the Apache License, Version 2.0
"""
Unit tests for scattermoe-lora code-review fixes.
Unit tests for scattermoe-lora.
Tests cover:
- KernelsArgs validator: disable_mlp_kernel
- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
- ParallelExperts: scaling=0.0 not treated as falsy
- single2scatter: non-aligned K/N dimensions
- group_compileable: coeff=None accepted
- HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract
- Routing strategy detection and sigmoid routing
- Generic shared expert handling
"""
from types import SimpleNamespace
from unittest.mock import patch
import pytest
@@ -321,3 +323,389 @@ class TestLayerReturnValues:
assert "Router logits" not in docstring, (
"Docstring should not mention 'Router logits' in Returns section"
)
# ============================================================================
# 7. Routing strategy detection and sigmoid routing
# ============================================================================
def _make_softmax_gate(E=4, H=16, K=2):
"""Create a mock softmax-style gate (Qwen/OLMoE)."""
return SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
num_experts=E,
norm_topk_prob=True,
)
def _make_sigmoid_gate_with_bias(E=16, H=16):
"""Create a mock sigmoid-style gate with e_score_correction_bias on gate."""
return SimpleNamespace(
weight=torch.randn(E, H),
e_score_correction_bias=torch.zeros(E),
)
def _make_sigmoid_moe_block(
T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
):
"""Create a mock GLM/DeepSeek-style MoE block for sigmoid routing tests."""
if bias_on_gate:
gate = SimpleNamespace(
weight=torch.randn(E, H),
e_score_correction_bias=torch.zeros(E),
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
n_routed_experts=E,
n_group=n_group,
topk_group=topk_group,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
else:
# minimax_m2 style: bias on block, not gate
gate = SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
)
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
e_score_correction_bias=torch.zeros(E),
)
return moe_block, T, H, E, K
def _skip_without_triton():
pytest.importorskip("triton")
class TestSigmoidRoutingInScatterMoE:
"""Test _sigmoid_topk_route from layers.py."""
@pytest.fixture(autouse=True)
def _require_triton(self):
_skip_without_triton()
def test_output_shapes(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block()
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, experts, top_k, num_experts = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
assert experts.shape == (T, K)
assert top_k == K
assert num_experts == E
def test_weights_nonnegative(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block()
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, _, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert (weights >= 0).all()
def test_group_selection_restricts_experts(self):
"""With n_group=4, topk_group=1, experts should be from selected groups."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block(
E=16, K=2, n_group=4, topk_group=1
)
gate = moe_block.gate
hidden = torch.randn(T, H)
_, expert_idx, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
# Each token's experts should fall within a single group (size E//n_group=4)
for t in range(T):
experts_t = expert_idx[t]
groups = experts_t // (E // moe_block.n_group)
assert (groups == groups[0]).all()
def test_scaling_factor_applied(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block(n_group=1)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights_1x, _, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
moe_block.routed_scaling_factor = 2.0
weights_2x, _, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert torch.allclose(weights_2x, weights_1x * 2.0, atol=1e-5)
def test_bias_on_gate(self):
"""e_score_correction_bias on gate is found."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=True)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
def test_bias_on_block(self):
"""e_score_correction_bias on moe_block (not gate) is found."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=False)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, experts, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
def test_gate_lora_delta_applied(self):
"""Gate LoRA delta should affect routing logits."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
moe_block, T, H, E, K = _make_sigmoid_moe_block(n_group=1)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights_no_lora, _, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
# Large delta should change the results
delta = torch.randn(E, H) * 10.0
weights_with_lora, _, _, _ = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, delta
)
assert not torch.equal(weights_no_lora, weights_with_lora)
def test_no_bias_does_not_crash(self):
"""Calling _sigmoid_topk_route with no e_score_correction_bias should not crash."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
T, H, E, K = 8, 16, 8, 2
gate = SimpleNamespace(weight=torch.randn(E, H))
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
n_routed_experts=E,
n_group=1,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
hidden = torch.randn(T, H)
weights, experts, top_k, num_experts = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
assert experts.shape == (T, K)
# Without bias, scores_for_choice == sigmoid(logits) — all positive
assert (weights >= 0).all()
def test_missing_topk_group_defaults_to_n_group(self):
"""When topk_group is absent but n_group > 1, should default to n_group (no-op masking)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
T, H, E, K, n_group = 8, 16, 16, 2, 4
gate = SimpleNamespace(
weight=torch.randn(E, H),
e_score_correction_bias=torch.zeros(E),
)
# Intentionally omit topk_group
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
n_routed_experts=E,
n_group=n_group,
norm_topk_prob=True,
routed_scaling_factor=1.0,
)
hidden = torch.randn(T, H)
# Should not raise AttributeError; defaults topk_group to n_group
weights, experts, top_k_out, num_experts = _sigmoid_topk_route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
assert experts.shape == (T, K)
class TestRoutingStrategyDetection:
"""Test that _route dispatches to the correct strategy."""
@pytest.fixture(autouse=True)
def _require_triton(self):
_skip_without_triton()
def test_softmax_for_qwen_style(self):
"""Block without e_score_correction_bias should use softmax."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route
gate = _make_softmax_gate(E=4, H=16, K=2)
moe_block = SimpleNamespace(gate=gate)
hidden = torch.randn(8, 16)
weights, experts, top_k, num_experts = _route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (8, 2)
assert experts.shape == (8, 2)
assert top_k == 2
assert num_experts == 4
per_token_sums = weights.sum(dim=-1)
assert torch.allclose(per_token_sums, torch.ones(8), atol=1e-5)
def test_sigmoid_for_glm_style(self):
"""Block with e_score_correction_bias on gate should use sigmoid."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=True, n_group=1)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, experts, top_k, num_experts = _route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
assert experts.shape == (T, K)
assert (weights >= 0).all()
def test_sigmoid_for_minimax_m2_style(self):
"""Block with e_score_correction_bias on block (not gate) should use sigmoid."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=False)
gate = moe_block.gate
hidden = torch.randn(T, H)
weights, experts, top_k, num_experts = _route(
moe_block, gate, hidden, gate.weight, None
)
assert weights.shape == (T, K)
assert (weights >= 0).all()
# ============================================================================
# 8. Generic shared expert handling
# ============================================================================
class TestGenericSharedExpert:
"""Test _compute_shared_expert from layers.py."""
@pytest.fixture(autouse=True)
def _require_triton(self):
_skip_without_triton()
def test_shared_expert_singular(self):
"""shared_expert attribute (Qwen2MoE style)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert,
)
called = torch.randn(4, 8)
moe_block = SimpleNamespace(
shared_expert=lambda x: called,
)
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
assert torch.equal(result, called)
def test_shared_experts_plural(self):
"""shared_experts attribute (DeepSeek V3 style)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert,
)
called = torch.randn(4, 8)
moe_block = SimpleNamespace(
shared_experts=lambda x: called,
)
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
assert torch.equal(result, called)
def test_shared_mlp(self):
"""shared_mlp attribute (Hunyuan style)."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert,
)
called = torch.randn(4, 8)
moe_block = SimpleNamespace(
shared_mlp=lambda x: called,
)
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
assert torch.equal(result, called)
def test_shared_expert_with_gate(self):
"""shared_expert + shared_expert_gate applies sigmoid gating."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert,
)
H = 8
expert_out = torch.ones(4, H)
gate_fn = lambda x: torch.zeros(4, H) # noqa: E731
moe_block = SimpleNamespace(
shared_expert=lambda x: expert_out,
shared_expert_gate=gate_fn,
)
result = _compute_shared_expert(moe_block, torch.randn(4, H))
expected = expert_out * 0.5 # sigmoid(0) = 0.5
assert torch.allclose(result, expected, atol=1e-6)
def test_no_shared_expert(self):
"""No shared expert attributes returns None."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert,
)
moe_block = SimpleNamespace()
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
assert result is None