consolidate behaviour of routing in scattermoe kernels (#3475)
* consolidate behaviour of routing in scattermoe kernels * collect telemetry on best chosen autotuned kernel * properly collect data * Fix property name and get smem too * handle issues raised by coderabbit * add tests for parity before refactoring
This commit is contained in:
@@ -3,17 +3,19 @@
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""
|
||||
Unit tests for scattermoe-lora code-review fixes.
|
||||
Unit tests for scattermoe-lora.
|
||||
|
||||
Tests cover:
|
||||
- KernelsArgs validator: disable_mlp_kernel
|
||||
- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
|
||||
- ParallelExperts: scaling=0.0 not treated as falsy
|
||||
- single2scatter: non-aligned K/N dimensions
|
||||
- group_compileable: coeff=None accepted
|
||||
- HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract
|
||||
- Routing strategy detection and sigmoid routing
|
||||
- Generic shared expert handling
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -321,3 +323,389 @@ class TestLayerReturnValues:
|
||||
assert "Router logits" not in docstring, (
|
||||
"Docstring should not mention 'Router logits' in Returns section"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 7. Routing strategy detection and sigmoid routing
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _make_softmax_gate(E=4, H=16, K=2):
    """Create a mock softmax-style gate (Qwen/OLMoE).

    Args:
        E: Number of experts; first dimension of the random ``weight``.
        H: Hidden size; second dimension of the random ``weight``.
        K: Value stored as ``top_k``.

    Returns:
        A ``SimpleNamespace`` with ``weight`` (random ``(E, H)`` tensor),
        ``top_k``, ``num_experts`` and ``norm_topk_prob=True``.
    """
    return SimpleNamespace(
        weight=torch.randn(E, H),
        top_k=K,
        num_experts=E,
        norm_topk_prob=True,
    )
|
||||
|
||||
|
||||
def _make_sigmoid_gate_with_bias(E=16, H=16):
    """Create a mock sigmoid-style gate with e_score_correction_bias on gate.

    Args:
        E: Number of experts; first dimension of ``weight`` and length of
            the bias vector.
        H: Hidden size; second dimension of ``weight``.

    Returns:
        A ``SimpleNamespace`` with a random ``(E, H)`` ``weight`` and an
        all-zero ``e_score_correction_bias`` of length ``E``.
    """
    return SimpleNamespace(
        weight=torch.randn(E, H),
        e_score_correction_bias=torch.zeros(E),
    )
|
||||
|
||||
|
||||
def _make_sigmoid_moe_block(
    T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
):
    """Create a mock GLM/DeepSeek-style MoE block for sigmoid routing tests.

    Args:
        T: Token count; not used to build the block, only echoed back so
            callers can size their hidden-state tensor.
        H: Hidden size; second dimension of the gate ``weight``.
        E: Number of experts; first dimension of the gate ``weight`` and
            length of the zero bias vector.
        K: Value stored as ``top_k``.
        n_group: Stored as ``n_group`` (only when ``bias_on_gate`` is true).
        topk_group: Stored as ``topk_group`` (only when ``bias_on_gate`` is
            true).
        bias_on_gate: When true, ``e_score_correction_bias`` lives on the
            gate and the block carries the group/normalisation attributes.
            When false, the bias lives on the block itself and the block
            only has ``gate``, ``top_k`` and the bias.

    Returns:
        Tuple ``(moe_block, T, H, E, K)``.
    """
    if bias_on_gate:
        gate = SimpleNamespace(
            weight=torch.randn(E, H),
            e_score_correction_bias=torch.zeros(E),
        )
        moe_block = SimpleNamespace(
            gate=gate,
            top_k=K,
            n_routed_experts=E,
            n_group=n_group,
            topk_group=topk_group,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
    else:
        # minimax_m2 style: bias on block, not gate
        gate = SimpleNamespace(
            weight=torch.randn(E, H),
            top_k=K,
        )
        moe_block = SimpleNamespace(
            gate=gate,
            top_k=K,
            e_score_correction_bias=torch.zeros(E),
        )
    return moe_block, T, H, E, K
|
||||
|
||||
|
||||
def _skip_without_triton():
    """Skip the calling test when the ``triton`` package cannot be imported."""
    pytest.importorskip("triton")
|
||||
|
||||
|
||||
class TestSigmoidRoutingInScatterMoE:
    """Test _sigmoid_topk_route from layers.py.

    The function under test is imported inside each test so that the import
    only happens after the autouse fixture has had the chance to skip the
    test when triton is unavailable.
    """

    @pytest.fixture(autouse=True)
    def _require_triton(self):
        # Applied to every test in this class: skip when triton is missing.
        _skip_without_triton()

    def test_output_shapes(self):
        """Weights/experts are (T, K); top_k and num_experts are echoed."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, K = _make_sigmoid_moe_block()
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, experts, top_k, num_experts = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert weights.shape == (T, K)
        assert experts.shape == (T, K)
        assert top_k == K
        assert num_experts == E

    def test_weights_nonnegative(self):
        """Routing weights are never negative."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, _, _ = _make_sigmoid_moe_block()
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        assert (weights >= 0).all()

    def test_group_selection_restricts_experts(self):
        """With n_group=4, topk_group=1, experts should be from selected groups."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, _ = _make_sigmoid_moe_block(
            E=16, K=2, n_group=4, topk_group=1
        )
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        _, expert_idx, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        # Each token's experts should fall within a single group
        # (size E // n_group = 4): compare every column's group id with the
        # first column's, for all tokens at once.
        groups = expert_idx // (E // moe_block.n_group)
        assert (groups == groups[:, :1]).all()

    def test_scaling_factor_applied(self):
        """Doubling routed_scaling_factor doubles the returned weights."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, _, _ = _make_sigmoid_moe_block(n_group=1)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights_1x, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        moe_block.routed_scaling_factor = 2.0
        weights_2x, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert torch.allclose(weights_2x, weights_1x * 2.0, atol=1e-5)

    def test_bias_on_gate(self):
        """e_score_correction_bias on gate is found."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, _, K = _make_sigmoid_moe_block(bias_on_gate=True)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        assert weights.shape == (T, K)

    def test_bias_on_block(self):
        """e_score_correction_bias on moe_block (not gate) is found."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, _, K = _make_sigmoid_moe_block(bias_on_gate=False)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        assert weights.shape == (T, K)

    def test_gate_lora_delta_applied(self):
        """Gate LoRA delta should affect routing logits."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, _ = _make_sigmoid_moe_block(n_group=1)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights_no_lora, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        # Large delta should change the results
        delta = torch.randn(E, H) * 10.0
        weights_with_lora, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, delta
        )

        assert not torch.equal(weights_no_lora, weights_with_lora)

    def test_no_bias_does_not_crash(self):
        """Calling _sigmoid_topk_route with no e_score_correction_bias should not crash."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        T, H, E, K = 8, 16, 8, 2
        gate = SimpleNamespace(weight=torch.randn(E, H))
        moe_block = SimpleNamespace(
            gate=gate,
            top_k=K,
            n_routed_experts=E,
            n_group=1,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
        hidden = torch.randn(T, H)

        weights, experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        assert weights.shape == (T, K)
        assert experts.shape == (T, K)
        # Without bias, scores_for_choice == sigmoid(logits) — all positive
        assert (weights >= 0).all()

    def test_missing_topk_group_defaults_to_n_group(self):
        """When topk_group is absent but n_group > 1, should default to n_group (no-op masking)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        T, H, E, K, n_group = 8, 16, 16, 2, 4
        gate = SimpleNamespace(
            weight=torch.randn(E, H),
            e_score_correction_bias=torch.zeros(E),
        )
        # Intentionally omit topk_group
        moe_block = SimpleNamespace(
            gate=gate,
            top_k=K,
            n_routed_experts=E,
            n_group=n_group,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
        hidden = torch.randn(T, H)

        # Should not raise AttributeError; defaults topk_group to n_group
        weights, experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        assert weights.shape == (T, K)
        assert experts.shape == (T, K)
|
||||
|
||||
|
||||
class TestRoutingStrategyDetection:
    """Test that _route dispatches to the correct strategy.

    ``_route`` is imported inside each test so the import only happens
    after the autouse fixture has had the chance to skip on missing triton.
    """

    @pytest.fixture(autouse=True)
    def _require_triton(self):
        # Applied to every test in this class: skip when triton is missing.
        _skip_without_triton()

    def test_softmax_for_qwen_style(self):
        """Block without e_score_correction_bias should use softmax."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route

        gate = _make_softmax_gate(E=4, H=16, K=2)
        moe_block = SimpleNamespace(gate=gate)
        hidden = torch.randn(8, 16)

        weights, experts, top_k, num_experts = _route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert weights.shape == (8, 2)
        assert experts.shape == (8, 2)
        assert top_k == 2
        assert num_experts == 4
        # The gate sets norm_topk_prob=True, so each token's selected
        # weights are expected to sum to one.
        per_token_sums = weights.sum(dim=-1)
        assert torch.allclose(per_token_sums, torch.ones(8), atol=1e-5)

    def test_sigmoid_for_glm_style(self):
        """Block with e_score_correction_bias on gate should use sigmoid."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route

        moe_block, T, H, _, K = _make_sigmoid_moe_block(bias_on_gate=True, n_group=1)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, experts, _, _ = _route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert weights.shape == (T, K)
        assert experts.shape == (T, K)
        assert (weights >= 0).all()

    def test_sigmoid_for_minimax_m2_style(self):
        """Block with e_score_correction_bias on block (not gate) should use sigmoid."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route

        moe_block, T, H, _, K = _make_sigmoid_moe_block(bias_on_gate=False)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, _, _, _ = _route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert weights.shape == (T, K)
        assert (weights >= 0).all()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 8. Generic shared expert handling
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestGenericSharedExpert:
    """Test _compute_shared_expert from layers.py.

    Each test builds a ``SimpleNamespace`` stand-in for the MoE block that
    exposes exactly one of the shared-expert attribute spellings, backed by
    a callable returning a known tensor.
    """

    @pytest.fixture(autouse=True)
    def _require_triton(self):
        # Applied to every test in this class: skip when triton is missing.
        _skip_without_triton()

    def test_shared_expert_singular(self):
        """shared_expert attribute (Qwen2MoE style)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert,
        )

        called = torch.randn(4, 8)
        moe_block = SimpleNamespace(
            shared_expert=lambda x: called,
        )
        result = _compute_shared_expert(moe_block, torch.randn(4, 8))
        assert torch.equal(result, called)

    def test_shared_experts_plural(self):
        """shared_experts attribute (DeepSeek V3 style)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert,
        )

        called = torch.randn(4, 8)
        moe_block = SimpleNamespace(
            shared_experts=lambda x: called,
        )
        result = _compute_shared_expert(moe_block, torch.randn(4, 8))
        assert torch.equal(result, called)

    def test_shared_mlp(self):
        """shared_mlp attribute (Hunyuan style)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert,
        )

        called = torch.randn(4, 8)
        moe_block = SimpleNamespace(
            shared_mlp=lambda x: called,
        )
        result = _compute_shared_expert(moe_block, torch.randn(4, 8))
        assert torch.equal(result, called)

    def test_shared_expert_with_gate(self):
        """shared_expert + shared_expert_gate applies sigmoid gating."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert,
        )

        H = 8
        expert_out = torch.ones(4, H)

        def gate_fn(x):
            # Zero logits so the expected gate value is sigmoid(0) = 0.5.
            return torch.zeros(4, H)

        moe_block = SimpleNamespace(
            shared_expert=lambda x: expert_out,
            shared_expert_gate=gate_fn,
        )
        result = _compute_shared_expert(moe_block, torch.randn(4, H))
        expected = expert_out * 0.5  # sigmoid(0) = 0.5
        assert torch.allclose(result, expected, atol=1e-6)

    def test_no_shared_expert(self):
        """No shared expert attributes returns None."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert,
        )

        moe_block = SimpleNamespace()
        result = _compute_shared_expert(moe_block, torch.randn(4, 8))
        assert result is None
|
||||
|
||||
Reference in New Issue
Block a user