consolidate behavioud of routing in scattermoe kernels (#3475)
* consolidate behavioud of routing in scattermoe kernels * collect telemetry on best chosen autotuned kernel * properly collect data * Fix property name and get smem too * handle issues raised by coderabbit * add tests for parity before refactoring
This commit is contained in:
@@ -12,6 +12,7 @@ Tests verify correctness of:
|
||||
3. Frozen weights: expert weight gradients are correctly skipped
|
||||
4. Various configurations: top-k, grouped_in/out, with/without bias
|
||||
5. Numerical stability: bf16/fp16 outputs within tolerance of fp32 reference
|
||||
6. HFScatterMoEGatedMLP with sigmoid routing (GLM/DeepSeek/MiniMax M2)
|
||||
|
||||
Test strategy:
|
||||
- Reference implementation uses pure PyTorch ops (no Triton)
|
||||
@@ -19,6 +20,8 @@ Test strategy:
|
||||
- Tolerances account for tf32 accumulation in Triton kernels
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -1476,3 +1479,347 @@ class TestCombinedOptimizations:
|
||||
torch.testing.assert_close(X1.grad, X2.grad, atol=5e-2, rtol=5e-2)
|
||||
torch.testing.assert_close(A1.grad, A2.grad, atol=5e-2, rtol=5e-2)
|
||||
torch.testing.assert_close(B1.grad, B2.grad, atol=5e-2, rtol=5e-2)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test: HFScatterMoEGatedMLP with Sigmoid Routing
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _reference_moe_forward(
|
||||
hidden_states,
|
||||
gate_weight,
|
||||
gate_up_proj,
|
||||
down_proj,
|
||||
act_fn,
|
||||
routing_weights,
|
||||
selected_experts,
|
||||
num_experts,
|
||||
):
|
||||
"""Pure PyTorch reference for a full MoE forward pass.
|
||||
|
||||
Args:
|
||||
hidden_states: [T, H]
|
||||
gate_weight: [E, H]
|
||||
gate_up_proj: [E, 2*FF, H]
|
||||
down_proj: [E, H, FF]
|
||||
act_fn: activation function (e.g. torch.nn.SiLU())
|
||||
routing_weights: [T, K] routing weights
|
||||
selected_experts: [T, K] expert indices
|
||||
num_experts: int
|
||||
|
||||
Returns:
|
||||
output: [T, H]
|
||||
"""
|
||||
T, H = hidden_states.shape
|
||||
K = selected_experts.shape[1]
|
||||
output = torch.zeros(T, H, device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
|
||||
for t in range(T):
|
||||
for j in range(K):
|
||||
e = selected_experts[t, j].item()
|
||||
w = routing_weights[t, j].item()
|
||||
|
||||
# gate_up projection
|
||||
gup = hidden_states[t] @ gate_up_proj[e].T # [2*I]
|
||||
I_dim = gup.shape[0] // 2
|
||||
gates = gup[:I_dim]
|
||||
up = gup[I_dim:]
|
||||
|
||||
# activation
|
||||
h = act_fn(gates) * up
|
||||
|
||||
# down projection
|
||||
out = h @ down_proj[e].T # [H]
|
||||
|
||||
output[t] += w * out
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _make_mock_sigmoid_moe_block(
|
||||
T=16, H=64, FF=32, E=8, K=2, n_group=2, topk_group=1, bias_on_gate=True
|
||||
):
|
||||
"""Create a mock MoE block with sigmoid routing for GPU testing."""
|
||||
gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02
|
||||
down_proj = torch.randn(E, H, FF, device="cuda") * 0.02
|
||||
act_fn = torch.nn.SiLU()
|
||||
|
||||
experts = SimpleNamespace(
|
||||
gate_up_proj=gate_up_proj,
|
||||
down_proj=down_proj,
|
||||
act_fn=act_fn,
|
||||
num_experts=E,
|
||||
)
|
||||
|
||||
if bias_on_gate:
|
||||
gate = SimpleNamespace(
|
||||
weight=torch.randn(E, H, device="cuda") * 0.1,
|
||||
e_score_correction_bias=torch.zeros(E, device="cuda"),
|
||||
)
|
||||
moe_block = SimpleNamespace(
|
||||
gate=gate,
|
||||
experts=experts,
|
||||
top_k=K,
|
||||
n_routed_experts=E,
|
||||
n_group=n_group,
|
||||
topk_group=topk_group,
|
||||
norm_topk_prob=True,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
else:
|
||||
# minimax_m2 style
|
||||
gate = SimpleNamespace(
|
||||
weight=torch.randn(E, H, device="cuda") * 0.1,
|
||||
top_k=K,
|
||||
)
|
||||
moe_block = SimpleNamespace(
|
||||
gate=gate,
|
||||
experts=experts,
|
||||
top_k=K,
|
||||
e_score_correction_bias=torch.zeros(E, device="cuda"),
|
||||
)
|
||||
|
||||
return moe_block, T, H, FF, E, K
|
||||
|
||||
|
||||
class TestHFScatterMoESigmoidRouting:
|
||||
"""Test HFScatterMoEGatedMLP forward with sigmoid routing on GPU."""
|
||||
|
||||
def test_forward_matches_reference_bias_on_gate(self):
|
||||
"""Forward pass with sigmoid routing (bias on gate) matches reference."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
HFScatterMoEGatedMLP,
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
|
||||
moe_block, T, H, FF, E, K = _make_mock_sigmoid_moe_block(
|
||||
T=16, H=64, FF=32, E=8, K=2, n_group=2, topk_group=1, bias_on_gate=True
|
||||
)
|
||||
|
||||
hidden = torch.randn(1, T, H, device="cuda")
|
||||
|
||||
# Get routing for reference
|
||||
gate = moe_block.gate
|
||||
hidden_flat = hidden.view(-1, H)
|
||||
routing_weights, selected_experts, _, _ = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden_flat, gate.weight, None
|
||||
)
|
||||
|
||||
# Reference output
|
||||
ref_output = _reference_moe_forward(
|
||||
hidden_flat,
|
||||
gate.weight,
|
||||
moe_block.experts.gate_up_proj,
|
||||
moe_block.experts.down_proj,
|
||||
moe_block.experts.act_fn,
|
||||
routing_weights,
|
||||
selected_experts,
|
||||
E,
|
||||
)
|
||||
|
||||
# Kernel output
|
||||
kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
|
||||
kernel_output_flat = kernel_output.view(-1, H)
|
||||
|
||||
torch.testing.assert_close(
|
||||
kernel_output_flat.float(),
|
||||
ref_output.float(),
|
||||
atol=5e-2,
|
||||
rtol=5e-2,
|
||||
)
|
||||
|
||||
def test_forward_matches_reference_bias_on_block(self):
|
||||
"""Forward pass with sigmoid routing (minimax_m2 style, bias on block)."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
HFScatterMoEGatedMLP,
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
|
||||
moe_block, T, H, FF, E, K = _make_mock_sigmoid_moe_block(
|
||||
T=16, H=64, FF=32, E=8, K=2, n_group=1, bias_on_gate=False
|
||||
)
|
||||
|
||||
hidden = torch.randn(1, T, H, device="cuda")
|
||||
hidden_flat = hidden.view(-1, H)
|
||||
|
||||
gate = moe_block.gate
|
||||
routing_weights, selected_experts, _, _ = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden_flat, gate.weight, None
|
||||
)
|
||||
|
||||
ref_output = _reference_moe_forward(
|
||||
hidden_flat,
|
||||
gate.weight,
|
||||
moe_block.experts.gate_up_proj,
|
||||
moe_block.experts.down_proj,
|
||||
moe_block.experts.act_fn,
|
||||
routing_weights,
|
||||
selected_experts,
|
||||
E,
|
||||
)
|
||||
|
||||
kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
|
||||
kernel_output_flat = kernel_output.view(-1, H)
|
||||
|
||||
torch.testing.assert_close(
|
||||
kernel_output_flat.float(),
|
||||
ref_output.float(),
|
||||
atol=5e-2,
|
||||
rtol=5e-2,
|
||||
)
|
||||
|
||||
def test_softmax_routing_still_works(self):
|
||||
"""Verify softmax routing (Qwen/OLMoE) is not broken."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
HFScatterMoEGatedMLP,
|
||||
_softmax_topk_route,
|
||||
)
|
||||
|
||||
T, H, FF, E, K = 16, 64, 32, 4, 2
|
||||
gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02
|
||||
down_proj = torch.randn(E, H, FF, device="cuda") * 0.02
|
||||
act_fn = torch.nn.SiLU()
|
||||
|
||||
experts = SimpleNamespace(
|
||||
gate_up_proj=gate_up_proj,
|
||||
down_proj=down_proj,
|
||||
act_fn=act_fn,
|
||||
num_experts=E,
|
||||
)
|
||||
gate = SimpleNamespace(
|
||||
weight=torch.randn(E, H, device="cuda") * 0.1,
|
||||
top_k=K,
|
||||
num_experts=E,
|
||||
norm_topk_prob=True,
|
||||
)
|
||||
moe_block = SimpleNamespace(gate=gate, experts=experts)
|
||||
|
||||
hidden = torch.randn(1, T, H, device="cuda")
|
||||
hidden_flat = hidden.view(-1, H)
|
||||
|
||||
routing_weights, selected_experts, _, _ = _softmax_topk_route(
|
||||
moe_block, gate, hidden_flat, gate.weight, None
|
||||
)
|
||||
|
||||
ref_output = _reference_moe_forward(
|
||||
hidden_flat,
|
||||
gate.weight,
|
||||
gate_up_proj,
|
||||
down_proj,
|
||||
act_fn,
|
||||
routing_weights,
|
||||
selected_experts,
|
||||
E,
|
||||
)
|
||||
|
||||
kernel_output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
|
||||
kernel_output_flat = kernel_output.view(-1, H)
|
||||
|
||||
torch.testing.assert_close(
|
||||
kernel_output_flat.float(),
|
||||
ref_output.float(),
|
||||
atol=5e-2,
|
||||
rtol=5e-2,
|
||||
)
|
||||
|
||||
|
||||
class TestHFScatterMoESigmoidWithSharedExperts:
|
||||
"""Test HFScatterMoEGatedMLP with sigmoid routing + shared experts."""
|
||||
|
||||
def test_shared_experts_plural(self):
|
||||
"""DeepSeek V3 style: shared_experts attribute (plural)."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
HFScatterMoEGatedMLP,
|
||||
)
|
||||
|
||||
T, H, FF, E, K = 8, 64, 32, 8, 2
|
||||
gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02
|
||||
down_proj = torch.randn(E, H, FF, device="cuda") * 0.02
|
||||
act_fn = torch.nn.SiLU()
|
||||
|
||||
experts = SimpleNamespace(
|
||||
gate_up_proj=gate_up_proj,
|
||||
down_proj=down_proj,
|
||||
act_fn=act_fn,
|
||||
num_experts=E,
|
||||
)
|
||||
|
||||
# Shared expert as a simple linear for testing
|
||||
shared_W = torch.randn(H, H, device="cuda") * 0.01
|
||||
shared_experts_fn = lambda x: x @ shared_W.T # noqa: E731
|
||||
|
||||
gate = SimpleNamespace(
|
||||
weight=torch.randn(E, H, device="cuda") * 0.1,
|
||||
e_score_correction_bias=torch.zeros(E, device="cuda"),
|
||||
)
|
||||
moe_block = SimpleNamespace(
|
||||
gate=gate,
|
||||
experts=experts,
|
||||
shared_experts=shared_experts_fn,
|
||||
top_k=K,
|
||||
n_routed_experts=E,
|
||||
n_group=1,
|
||||
norm_topk_prob=True,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
|
||||
hidden = torch.randn(1, T, H, device="cuda")
|
||||
|
||||
# Should not raise; output should include shared expert contribution
|
||||
output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
|
||||
assert output.shape == (1, T, H)
|
||||
|
||||
# Run without shared expert to verify it changes the output
|
||||
moe_block_no_shared = SimpleNamespace(
|
||||
gate=gate,
|
||||
experts=experts,
|
||||
top_k=K,
|
||||
n_routed_experts=E,
|
||||
n_group=1,
|
||||
norm_topk_prob=True,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
output_no_shared = HFScatterMoEGatedMLP.forward(moe_block_no_shared, hidden)
|
||||
assert not torch.equal(output, output_no_shared)
|
||||
|
||||
def test_shared_expert_with_gate(self):
|
||||
"""Qwen2MoE style: shared_expert + shared_expert_gate."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
HFScatterMoEGatedMLP,
|
||||
)
|
||||
|
||||
T, H, FF, E, K = 8, 64, 32, 4, 2
|
||||
gate_up_proj = torch.randn(E, 2 * FF, H, device="cuda") * 0.02
|
||||
down_proj = torch.randn(E, H, FF, device="cuda") * 0.02
|
||||
act_fn = torch.nn.SiLU()
|
||||
|
||||
experts = SimpleNamespace(
|
||||
gate_up_proj=gate_up_proj,
|
||||
down_proj=down_proj,
|
||||
act_fn=act_fn,
|
||||
num_experts=E,
|
||||
)
|
||||
|
||||
shared_W = torch.randn(H, H, device="cuda") * 0.01
|
||||
shared_expert_fn = lambda x: x @ shared_W.T # noqa: E731
|
||||
# Gate that returns 0 -> sigmoid(0) = 0.5
|
||||
gate_W = torch.zeros(H, H, device="cuda")
|
||||
shared_expert_gate_fn = lambda x: x @ gate_W.T # noqa: E731
|
||||
|
||||
gate = SimpleNamespace(
|
||||
weight=torch.randn(E, H, device="cuda") * 0.1,
|
||||
top_k=K,
|
||||
num_experts=E,
|
||||
norm_topk_prob=True,
|
||||
)
|
||||
moe_block = SimpleNamespace(
|
||||
gate=gate,
|
||||
experts=experts,
|
||||
shared_expert=shared_expert_fn,
|
||||
shared_expert_gate=shared_expert_gate_fn,
|
||||
)
|
||||
|
||||
hidden = torch.randn(1, T, H, device="cuda")
|
||||
output = HFScatterMoEGatedMLP.forward(moe_block, hidden)
|
||||
assert output.shape == (1, T, H)
|
||||
|
||||
474
tests/integrations/test_routing_parity.py
Normal file
474
tests/integrations/test_routing_parity.py
Normal file
@@ -0,0 +1,474 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""
|
||||
Parity tests between scattermoe-lora and sonicmoe routing implementations.
|
||||
|
||||
These tests verify that both implementations produce numerically identical
|
||||
results for the same inputs, ensuring safe centralization of the routing code.
|
||||
|
||||
ScatterMoE returns 2D tensors [T, K]; SonicMoE returns flattened 1D [T*K].
|
||||
The core algorithm should be identical — only the output format differs.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def _require_triton():
|
||||
pytest.importorskip("triton")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures / helpers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _make_softmax_block(T=8, H=16, E=4, K=2):
|
||||
"""Qwen/OLMoE-style block usable by both implementations."""
|
||||
gate = SimpleNamespace(
|
||||
weight=torch.randn(E, H),
|
||||
top_k=K,
|
||||
num_experts=E,
|
||||
norm_topk_prob=True,
|
||||
)
|
||||
moe_block = SimpleNamespace(gate=gate)
|
||||
hidden = torch.randn(T, H)
|
||||
return moe_block, gate, hidden, T, H, E, K
|
||||
|
||||
|
||||
def _make_sigmoid_block(
|
||||
T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
|
||||
):
|
||||
"""GLM/DeepSeek-style block usable by both implementations."""
|
||||
if bias_on_gate:
|
||||
gate = SimpleNamespace(
|
||||
weight=torch.randn(E, H),
|
||||
e_score_correction_bias=torch.zeros(E),
|
||||
)
|
||||
moe_block = SimpleNamespace(
|
||||
gate=gate,
|
||||
top_k=K,
|
||||
n_routed_experts=E,
|
||||
n_group=n_group,
|
||||
topk_group=topk_group,
|
||||
norm_topk_prob=True,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
else:
|
||||
# minimax_m2 style: bias on block
|
||||
gate = SimpleNamespace(
|
||||
weight=torch.randn(E, H),
|
||||
top_k=K,
|
||||
)
|
||||
moe_block = SimpleNamespace(
|
||||
gate=gate,
|
||||
top_k=K,
|
||||
e_score_correction_bias=torch.zeros(E),
|
||||
)
|
||||
return moe_block, gate, hidden_states(T, H), T, H, E, K
|
||||
|
||||
|
||||
def hidden_states(T, H):
|
||||
return torch.randn(T, H)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 1. Softmax routing parity
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestSoftmaxRoutingParity:
|
||||
"""Verify scattermoe and sonicmoe softmax routing produce identical results."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _require(self):
|
||||
_require_triton()
|
||||
|
||||
def test_weights_match(self):
|
||||
"""2D weights from scattermoe == reshaped 1D weights from sonicmoe."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_softmax_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
|
||||
|
||||
# ScatterMoE path (no LoRA delta)
|
||||
sm_weights, sm_experts, sm_topk, sm_E = _softmax_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
|
||||
# SonicMoE path
|
||||
sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = softmax_topk_routing(
|
||||
hidden, moe_block
|
||||
)
|
||||
|
||||
# ScatterMoE returns [T, K], SonicMoE returns [T*K] flattened
|
||||
sonic_weights_2d = sonic_scores.reshape(T, K)
|
||||
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
|
||||
|
||||
assert sm_topk == K
|
||||
assert sm_E == E
|
||||
|
||||
# Both should select the same experts and produce the same weights
|
||||
assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype))
|
||||
assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)
|
||||
|
||||
def test_logits_not_returned_by_scattermoe(self):
|
||||
"""ScatterMoE doesn't return logits; SonicMoE does — verify SonicMoE logits shape."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
|
||||
_, _, _, logits = softmax_topk_routing(hidden, moe_block)
|
||||
assert logits.shape == (T, E)
|
||||
|
||||
def test_no_renorm(self):
|
||||
"""With norm_topk_prob=False, both should skip renormalization."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_softmax_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
|
||||
gate.norm_topk_prob = False
|
||||
|
||||
sm_weights, sm_experts, _, _ = _softmax_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block)
|
||||
|
||||
sonic_weights_2d = sonic_scores.reshape(T, K)
|
||||
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
|
||||
|
||||
assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype))
|
||||
assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6)
|
||||
|
||||
def test_various_expert_counts(self):
|
||||
"""Parity across different E and K values."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_softmax_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
|
||||
for E, K in [(2, 1), (8, 2), (16, 4), (32, 8)]:
|
||||
moe_block, gate, hidden, T, H, _, _ = _make_softmax_block(E=E, K=K)
|
||||
|
||||
sm_weights, sm_experts, _, _ = _softmax_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
sonic_scores, _, sonic_exp_idx, _ = softmax_topk_routing(hidden, moe_block)
|
||||
|
||||
sonic_weights_2d = sonic_scores.reshape(T, K)
|
||||
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
|
||||
|
||||
assert torch.equal(sm_experts, sonic_experts_2d.to(sm_experts.dtype)), (
|
||||
f"Expert mismatch for E={E}, K={K}"
|
||||
)
|
||||
assert torch.allclose(sm_weights, sonic_weights_2d, atol=1e-6), (
|
||||
f"Weight mismatch for E={E}, K={K}"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 2. Sigmoid routing parity
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestSigmoidRoutingParity:
|
||||
"""Verify scattermoe and sonicmoe sigmoid routing produce identical results."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _require(self):
|
||||
_require_triton()
|
||||
|
||||
def test_weights_match_with_groups(self):
|
||||
"""Both implementations should produce identical weights with group selection."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
|
||||
)
|
||||
|
||||
sm_weights, sm_experts, sm_topk, sm_E = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
|
||||
sonic_scores, sonic_tok_idx, sonic_exp_idx, sonic_logits = sigmoid_topk_routing(
|
||||
hidden, moe_block
|
||||
)
|
||||
|
||||
sonic_weights_2d = sonic_scores.reshape(T, K)
|
||||
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
|
||||
|
||||
assert sm_topk == K
|
||||
assert sm_E == E
|
||||
|
||||
# Sort experts within each token to handle different topk orderings
|
||||
sm_sorted, sm_order = sm_experts.sort(dim=-1)
|
||||
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
|
||||
|
||||
assert torch.equal(sm_sorted, sonic_sorted)
|
||||
|
||||
# Gather weights in sorted order for comparison
|
||||
sm_weights_sorted = sm_weights.gather(1, sm_order)
|
||||
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
|
||||
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
|
||||
|
||||
def test_weights_match_no_groups(self):
|
||||
"""Both implementations match without group selection (n_group=1)."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
E=16, K=4, n_group=1, topk_group=1, bias_on_gate=True
|
||||
)
|
||||
|
||||
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||
|
||||
sonic_weights_2d = sonic_scores.reshape(T, K)
|
||||
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
|
||||
|
||||
# Sort for comparison (topk with sorted=False may differ in order)
|
||||
sm_sorted, sm_order = sm_experts.sort(dim=-1)
|
||||
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
|
||||
|
||||
assert torch.equal(sm_sorted, sonic_sorted)
|
||||
sm_weights_sorted = sm_weights.gather(1, sm_order)
|
||||
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
|
||||
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
|
||||
|
||||
def test_bias_on_block_parity(self):
|
||||
"""minimax_m2 style: bias on block, not gate."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
E=16, K=4, n_group=1, bias_on_gate=False
|
||||
)
|
||||
|
||||
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||
|
||||
sonic_weights_2d = sonic_scores.reshape(T, K)
|
||||
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
|
||||
|
||||
sm_sorted, sm_order = sm_experts.sort(dim=-1)
|
||||
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
|
||||
|
||||
assert torch.equal(sm_sorted, sonic_sorted)
|
||||
sm_weights_sorted = sm_weights.gather(1, sm_order)
|
||||
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
|
||||
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
|
||||
|
||||
def test_scaling_factor_parity(self):
|
||||
"""routed_scaling_factor applied identically by both."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
n_group=1, bias_on_gate=True
|
||||
)
|
||||
moe_block.routed_scaling_factor = 2.5
|
||||
|
||||
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||
|
||||
sonic_weights_2d = sonic_scores.reshape(T, K)
|
||||
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
|
||||
|
||||
sm_sorted, sm_order = sm_experts.sort(dim=-1)
|
||||
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
|
||||
|
||||
assert torch.equal(sm_sorted, sonic_sorted)
|
||||
sm_weights_sorted = sm_weights.gather(1, sm_order)
|
||||
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
|
||||
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
|
||||
|
||||
def test_no_renorm_parity(self):
|
||||
"""norm_topk_prob=False produces same results in both."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
n_group=1, bias_on_gate=True
|
||||
)
|
||||
moe_block.norm_topk_prob = False
|
||||
|
||||
sm_weights, sm_experts, _, _ = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
sonic_scores, _, sonic_exp_idx, _ = sigmoid_topk_routing(hidden, moe_block)
|
||||
|
||||
sonic_weights_2d = sonic_scores.reshape(T, K)
|
||||
sonic_experts_2d = sonic_exp_idx.reshape(T, K)
|
||||
|
||||
sm_sorted, sm_order = sm_experts.sort(dim=-1)
|
||||
sonic_sorted, sonic_order = sonic_experts_2d.to(sm_experts.dtype).sort(dim=-1)
|
||||
|
||||
assert torch.equal(sm_sorted, sonic_sorted)
|
||||
sm_weights_sorted = sm_weights.gather(1, sm_order)
|
||||
sonic_weights_sorted = sonic_weights_2d.gather(1, sonic_order)
|
||||
assert torch.allclose(sm_weights_sorted, sonic_weights_sorted, atol=1e-6)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 3. Shared expert parity
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestSharedExpertParity:
|
||||
"""Verify both _compute_shared_expert implementations behave identically."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _require(self):
|
||||
_require_triton()
|
||||
|
||||
def _get_both_fns(self):
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_compute_shared_expert as scatter_compute,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.patch import (
|
||||
_compute_shared_expert as sonic_compute,
|
||||
)
|
||||
|
||||
return scatter_compute, sonic_compute
|
||||
|
||||
def test_shared_expert_singular(self):
|
||||
scatter_fn, sonic_fn = self._get_both_fns()
|
||||
out = torch.randn(4, 8)
|
||||
block = SimpleNamespace(shared_expert=lambda x: out)
|
||||
hidden = torch.randn(4, 8)
|
||||
|
||||
assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))
|
||||
|
||||
def test_shared_experts_plural(self):
|
||||
scatter_fn, sonic_fn = self._get_both_fns()
|
||||
out = torch.randn(4, 8)
|
||||
block = SimpleNamespace(shared_experts=lambda x: out)
|
||||
hidden = torch.randn(4, 8)
|
||||
|
||||
assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))
|
||||
|
||||
def test_shared_mlp(self):
|
||||
scatter_fn, sonic_fn = self._get_both_fns()
|
||||
out = torch.randn(4, 8)
|
||||
block = SimpleNamespace(shared_mlp=lambda x: out)
|
||||
hidden = torch.randn(4, 8)
|
||||
|
||||
assert torch.equal(scatter_fn(block, hidden), sonic_fn(block, hidden))
|
||||
|
||||
def test_no_shared_expert(self):
|
||||
scatter_fn, sonic_fn = self._get_both_fns()
|
||||
block = SimpleNamespace()
|
||||
hidden = torch.randn(4, 8)
|
||||
|
||||
assert scatter_fn(block, hidden) is None
|
||||
assert sonic_fn(block, hidden) is None
|
||||
|
||||
def test_shared_expert_gate_only_in_scattermoe(self):
|
||||
"""ScatterMoE's _compute_shared_expert handles shared_expert_gate;
|
||||
SonicMoE's patch.py handles it externally in the forward function.
|
||||
|
||||
This documents the known divergence: the scattermoe version applies
|
||||
sigmoid gating inline, while sonicmoe applies it in the forward.
|
||||
"""
|
||||
scatter_fn, sonic_fn = self._get_both_fns()
|
||||
|
||||
H = 8
|
||||
expert_out = torch.ones(4, H)
|
||||
gate_fn = lambda x: torch.zeros(4, H) # noqa: E731 # sigmoid(0) = 0.5
|
||||
|
||||
block = SimpleNamespace(
|
||||
shared_expert=lambda x: expert_out,
|
||||
shared_expert_gate=gate_fn,
|
||||
)
|
||||
hidden = torch.randn(4, H)
|
||||
|
||||
scatter_result = scatter_fn(block, hidden)
|
||||
sonic_result = sonic_fn(block, hidden)
|
||||
|
||||
# ScatterMoE applies the gate: expert_out * sigmoid(0) = 0.5
|
||||
expected_gated = expert_out * 0.5
|
||||
assert torch.allclose(scatter_result, expected_gated, atol=1e-6)
|
||||
|
||||
# SonicMoE does NOT apply the gate here (it does it in the forward)
|
||||
assert torch.equal(sonic_result, expert_out)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 4. Route dispatcher parity
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestRouteDispatcherParity:
|
||||
"""Verify _route in scattermoe dispatches correctly and matches individual fns."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _require(self):
|
||||
_require_triton()
|
||||
|
||||
def test_route_dispatches_softmax(self):
|
||||
"""_route should use softmax when no e_score_correction_bias."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_route,
|
||||
_softmax_topk_route,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
|
||||
|
||||
route_w, route_e, route_k, route_E = _route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
direct_w, direct_e, direct_k, direct_E = _softmax_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
|
||||
assert torch.equal(route_w, direct_w)
|
||||
assert torch.equal(route_e, direct_e)
|
||||
assert route_k == direct_k
|
||||
assert route_E == direct_E
|
||||
|
||||
def test_route_dispatches_sigmoid(self):
|
||||
"""_route should use sigmoid when e_score_correction_bias is present."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_route,
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
|
||||
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
|
||||
n_group=1, bias_on_gate=True
|
||||
)
|
||||
|
||||
route_w, route_e, route_k, route_E = _route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
direct_w, direct_e, direct_k, direct_E = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
|
||||
assert torch.equal(route_w, direct_w)
|
||||
assert torch.equal(route_e, direct_e)
|
||||
assert route_k == direct_k
|
||||
assert route_E == direct_E
|
||||
367
tests/integrations/test_scattermoe_autotune_telemetry.py
Normal file
367
tests/integrations/test_scattermoe_autotune_telemetry.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""Tests for scattermoe autotune telemetry integration.
|
||||
|
||||
These tests use mocking to verify the collection and reporting logic
|
||||
without requiring Triton or CUDA.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Simulate the hash-suffixed module name that LocalLayerRepository creates.
|
||||
_FAKE_MODULE_NAME = "scattermoe_lora_abc123.kernels.lora_ops"
|
||||
|
||||
|
||||
def _make_mock_config(kwargs, num_warps=4, num_stages=3):
|
||||
"""Create a mock triton.Config-like object."""
|
||||
return SimpleNamespace(kwargs=kwargs, num_warps=num_warps, num_stages=num_stages)
|
||||
|
||||
|
||||
def _make_mock_kernel(cache=None):
|
||||
"""Create a mock autotuned kernel object with a ``.cache`` dict."""
|
||||
kernel = SimpleNamespace()
|
||||
kernel.cache = cache if cache is not None else {}
|
||||
return kernel
|
||||
|
||||
|
||||
def _make_mock_lora_ops(
|
||||
fwd_cache=None, dx_cache=None, bwd_cache=None, fused_cache=None
|
||||
):
|
||||
"""Build a mock ``lora_ops`` module with the four kernel attributes."""
|
||||
mod = SimpleNamespace(
|
||||
_scatter2scatter_lora=_make_mock_kernel(fwd_cache),
|
||||
_scatter2scatter_lora_dX=_make_mock_kernel(dx_cache),
|
||||
_group_bwd_lora=_make_mock_kernel(bwd_cache),
|
||||
_group_bwd_lora_fused=_make_mock_kernel(fused_cache),
|
||||
)
|
||||
return mod
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# TestAutotuneCollector
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestAutotuneCollector:
|
||||
"""Test ``collect_autotune_configs`` with mocked kernel objects."""
|
||||
|
||||
def test_empty_cache_returns_empty_list(self):
|
||||
"""When no kernel has been autotuned yet, return ``[]``."""
|
||||
mock_lora_ops = _make_mock_lora_ops()
|
||||
|
||||
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
|
||||
result = collect_autotune_configs()
|
||||
assert result == []
|
||||
|
||||
def test_populated_cache_returns_configs(self):
|
||||
"""When a cache entry exists, it appears in the output."""
|
||||
cfg = _make_mock_config(
|
||||
{"BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=4
|
||||
)
|
||||
mock_lora_ops = _make_mock_lora_ops(fwd_cache={(2048, 4096, 1024): cfg})
|
||||
|
||||
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
|
||||
result = collect_autotune_configs()
|
||||
|
||||
assert len(result) == 1
|
||||
entry = result[0]
|
||||
assert entry["kernel"] == "scatter2scatter_lora_fwd"
|
||||
assert entry["key"] == {"M": 2048, "N": 4096, "K": 1024}
|
||||
assert entry["config"]["BLOCK_N"] == 128
|
||||
assert entry["config"]["BLOCK_K"] == 64
|
||||
assert entry["config"]["num_warps"] == 8
|
||||
assert entry["config"]["num_stages"] == 4
|
||||
|
||||
def test_multiple_kernels_and_keys(self):
|
||||
"""Multiple cache entries across kernels are all returned."""
|
||||
cfg_fwd = _make_mock_config({"BLOCK_N": 128, "BLOCK_K": 32})
|
||||
cfg_dx = _make_mock_config({"BLOCK_K": 64, "BLOCK_N": 128}, num_warps=8)
|
||||
|
||||
mock_lora_ops = _make_mock_lora_ops(
|
||||
fwd_cache={(16, 256, 128): cfg_fwd},
|
||||
dx_cache={(16, 256, 128): cfg_dx},
|
||||
)
|
||||
|
||||
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
|
||||
result = collect_autotune_configs()
|
||||
|
||||
assert len(result) == 2
|
||||
names = {r["kernel"] for r in result}
|
||||
assert "scatter2scatter_lora_fwd" in names
|
||||
assert "scatter2scatter_lora_dX" in names
|
||||
|
||||
def test_extra_key_elements_stored(self):
|
||||
"""Dtype or other extra elements in the cache key are captured."""
|
||||
cfg = _make_mock_config({"BLOCK_N": 64, "BLOCK_K": 32})
|
||||
cache_key = (512, 1024, 256, "float16", "float16")
|
||||
|
||||
mock_lora_ops = _make_mock_lora_ops(fwd_cache={cache_key: cfg})
|
||||
|
||||
with patch.dict(sys.modules, {_FAKE_MODULE_NAME: mock_lora_ops}):
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
|
||||
result = collect_autotune_configs()
|
||||
|
||||
assert len(result) == 1
|
||||
key = result[0]["key"]
|
||||
assert key["M"] == 512
|
||||
assert key["N"] == 1024
|
||||
assert key["K"] == 256
|
||||
assert key["_extra"] == ["float16", "float16"]
|
||||
|
||||
def test_no_module_in_sys_modules_returns_empty(self):
|
||||
"""If no lora_ops module is loaded, return ``[]``."""
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
|
||||
# Don't inject anything — the real lora_ops isn't loaded either
|
||||
# (no triton on this machine), so _find_lora_ops_module returns None.
|
||||
result = collect_autotune_configs()
|
||||
assert result == []
|
||||
|
||||
def test_finds_module_under_hash_suffixed_name(self):
|
||||
"""Collector finds lora_ops regardless of the hash suffix."""
|
||||
cfg = _make_mock_config({"BLOCK_N": 256, "BLOCK_K": 128})
|
||||
mock_lora_ops = _make_mock_lora_ops(fwd_cache={(8, 512, 64): cfg})
|
||||
|
||||
# Use a different hash to prove it's not hardcoded.
|
||||
alt_name = "scattermoe_lora_deadbeef.kernels.lora_ops"
|
||||
with patch.dict(sys.modules, {alt_name: mock_lora_ops}):
|
||||
from axolotl.integrations.kernels.autotune_collector import (
|
||||
collect_autotune_configs,
|
||||
)
|
||||
|
||||
result = collect_autotune_configs()
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["config"]["BLOCK_N"] == 256
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# TestAutotuneReportCallback
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestAutotuneReportCallback:
|
||||
"""Test the callback fires once and sends the correct event."""
|
||||
|
||||
def test_reports_once_on_first_step(self):
|
||||
"""Callback should call ``send_event`` exactly once."""
|
||||
from axolotl.integrations.kernels.autotune_callback import (
|
||||
AutotuneReportCallback,
|
||||
)
|
||||
|
||||
cb = AutotuneReportCallback()
|
||||
mock_state = MagicMock()
|
||||
mock_state.global_step = 1
|
||||
|
||||
fake_configs = [{"kernel": "test_fwd", "key": {}, "config": {}}]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
|
||||
return_value=fake_configs,
|
||||
),
|
||||
patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
|
||||
):
|
||||
mock_tm = MagicMock()
|
||||
mock_tm.enabled = True
|
||||
mock_tm_cls.get_instance.return_value = mock_tm
|
||||
|
||||
cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
|
||||
assert mock_tm.send_event.call_count == 1
|
||||
|
||||
call_kwargs = mock_tm.send_event.call_args[1]
|
||||
assert call_kwargs["event_type"] == "scattermoe-autotune"
|
||||
assert call_kwargs["properties"]["kernel_count"] == 1
|
||||
|
||||
# Second call should NOT send again.
|
||||
cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
|
||||
assert mock_tm.send_event.call_count == 1
|
||||
|
||||
def test_retries_until_step_5_then_gives_up(self):
|
||||
"""If no configs found by step 5, stop retrying."""
|
||||
from axolotl.integrations.kernels.autotune_callback import (
|
||||
AutotuneReportCallback,
|
||||
)
|
||||
|
||||
cb = AutotuneReportCallback()
|
||||
|
||||
with patch(
|
||||
"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
|
||||
return_value=[],
|
||||
):
|
||||
for step in range(1, 7):
|
||||
mock_state = MagicMock()
|
||||
mock_state.global_step = step
|
||||
cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
|
||||
|
||||
assert cb._reported is True
|
||||
|
||||
def test_reports_on_retry_when_data_arrives(self):
|
||||
"""If step 1 has no data but step 2 does, report at step 2."""
|
||||
from axolotl.integrations.kernels.autotune_callback import (
|
||||
AutotuneReportCallback,
|
||||
)
|
||||
|
||||
cb = AutotuneReportCallback()
|
||||
fake_configs = [{"kernel": "fwd", "key": {}, "config": {}}]
|
||||
|
||||
call_count = 0
|
||||
|
||||
def _collector():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return []
|
||||
return fake_configs
|
||||
|
||||
with (
|
||||
patch(
|
||||
"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
|
||||
side_effect=_collector,
|
||||
),
|
||||
patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
|
||||
):
|
||||
mock_tm = MagicMock()
|
||||
mock_tm.enabled = True
|
||||
mock_tm_cls.get_instance.return_value = mock_tm
|
||||
|
||||
# Step 1 — empty, no report
|
||||
s1 = MagicMock()
|
||||
s1.global_step = 1
|
||||
cb.on_step_end(args=MagicMock(), state=s1, control=MagicMock())
|
||||
assert mock_tm.send_event.call_count == 0
|
||||
|
||||
# Step 2 — data arrives, report
|
||||
s2 = MagicMock()
|
||||
s2.global_step = 2
|
||||
cb.on_step_end(args=MagicMock(), state=s2, control=MagicMock())
|
||||
assert mock_tm.send_event.call_count == 1
|
||||
|
||||
def test_includes_gpu_info(self):
|
||||
"""Event properties should include GPU identification."""
|
||||
from axolotl.integrations.kernels.autotune_callback import (
|
||||
AutotuneReportCallback,
|
||||
)
|
||||
|
||||
cb = AutotuneReportCallback()
|
||||
mock_state = MagicMock()
|
||||
mock_state.global_step = 1
|
||||
|
||||
fake_configs = [{"kernel": "fwd", "key": {}, "config": {}}]
|
||||
fake_gpu = {
|
||||
"gpu_name": "NVIDIA H100",
|
||||
"gpu_compute_capability": "9.0",
|
||||
"gpu_memory_bytes": 85899345920,
|
||||
}
|
||||
|
||||
fake_smem = {"smem_capacity_bytes": 233472}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
|
||||
return_value=fake_configs,
|
||||
),
|
||||
patch(
|
||||
"axolotl.integrations.kernels.autotune_callback._get_gpu_info",
|
||||
return_value=fake_gpu,
|
||||
),
|
||||
patch(
|
||||
"axolotl.integrations.kernels.autotune_callback._get_smem_capacity",
|
||||
return_value=fake_smem,
|
||||
),
|
||||
patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
|
||||
):
|
||||
mock_tm = MagicMock()
|
||||
mock_tm.enabled = True
|
||||
mock_tm_cls.get_instance.return_value = mock_tm
|
||||
|
||||
cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
|
||||
props = mock_tm.send_event.call_args[1]["properties"]
|
||||
assert props["gpu_name"] == "NVIDIA H100"
|
||||
assert props["gpu_compute_capability"] == "9.0"
|
||||
assert props["gpu_memory_bytes"] == 85899345920
|
||||
assert props["smem_capacity_bytes"] == 233472
|
||||
|
||||
def test_skips_send_when_telemetry_disabled(self):
|
||||
"""If telemetry is disabled, no event is sent."""
|
||||
from axolotl.integrations.kernels.autotune_callback import (
|
||||
AutotuneReportCallback,
|
||||
)
|
||||
|
||||
cb = AutotuneReportCallback()
|
||||
mock_state = MagicMock()
|
||||
mock_state.global_step = 1
|
||||
|
||||
with (
|
||||
patch(
|
||||
"axolotl.integrations.kernels.autotune_collector.collect_autotune_configs",
|
||||
return_value=[{"kernel": "fwd", "key": {}, "config": {}}],
|
||||
),
|
||||
patch("axolotl.telemetry.manager.TelemetryManager") as mock_tm_cls,
|
||||
):
|
||||
mock_tm = MagicMock()
|
||||
mock_tm.enabled = False
|
||||
mock_tm_cls.get_instance.return_value = mock_tm
|
||||
|
||||
cb.on_step_end(args=MagicMock(), state=mock_state, control=MagicMock())
|
||||
assert mock_tm.send_event.call_count == 0
|
||||
# Should still mark as reported so we don't retry.
|
||||
assert cb._reported is True
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# TestKernelsPluginCallbackRegistration
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestKernelsPluginCallbackRegistration:
|
||||
"""Test that ``KernelsPlugin`` registers the callback correctly."""
|
||||
|
||||
def test_scattermoe_registers_callback(self):
|
||||
"""When ``use_scattermoe=True``, plugin returns the callback."""
|
||||
from axolotl.integrations.kernels.autotune_callback import (
|
||||
AutotuneReportCallback,
|
||||
)
|
||||
from axolotl.integrations.kernels.plugin import KernelsPlugin
|
||||
|
||||
plugin = KernelsPlugin()
|
||||
cfg = MagicMock()
|
||||
cfg.use_scattermoe = True
|
||||
model = MagicMock()
|
||||
|
||||
callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
|
||||
assert len(callbacks) == 1
|
||||
assert isinstance(callbacks[0], AutotuneReportCallback)
|
||||
|
||||
def test_no_scattermoe_no_callback(self):
|
||||
"""When ``use_scattermoe=False``, plugin returns empty list."""
|
||||
from axolotl.integrations.kernels.plugin import KernelsPlugin
|
||||
|
||||
plugin = KernelsPlugin()
|
||||
cfg = MagicMock()
|
||||
cfg.use_scattermoe = False
|
||||
model = MagicMock()
|
||||
|
||||
callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
|
||||
assert callbacks == []
|
||||
@@ -3,17 +3,19 @@
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""
|
||||
Unit tests for scattermoe-lora code-review fixes.
|
||||
Unit tests for scattermoe-lora.
|
||||
|
||||
Tests cover:
|
||||
- KernelsArgs validator: disable_mlp_kernel
|
||||
- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
|
||||
- ParallelExperts: scaling=0.0 not treated as falsy
|
||||
- single2scatter: non-aligned K/N dimensions
|
||||
- group_compileable: coeff=None accepted
|
||||
- HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract
|
||||
- Routing strategy detection and sigmoid routing
|
||||
- Generic shared expert handling
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -321,3 +323,389 @@ class TestLayerReturnValues:
|
||||
assert "Router logits" not in docstring, (
|
||||
"Docstring should not mention 'Router logits' in Returns section"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 7. Routing strategy detection and sigmoid routing
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _make_softmax_gate(E=4, H=16, K=2):
|
||||
"""Create a mock softmax-style gate (Qwen/OLMoE)."""
|
||||
return SimpleNamespace(
|
||||
weight=torch.randn(E, H),
|
||||
top_k=K,
|
||||
num_experts=E,
|
||||
norm_topk_prob=True,
|
||||
)
|
||||
|
||||
|
||||
def _make_sigmoid_gate_with_bias(E=16, H=16):
|
||||
"""Create a mock sigmoid-style gate with e_score_correction_bias on gate."""
|
||||
return SimpleNamespace(
|
||||
weight=torch.randn(E, H),
|
||||
e_score_correction_bias=torch.zeros(E),
|
||||
)
|
||||
|
||||
|
||||
def _make_sigmoid_moe_block(
|
||||
T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
|
||||
):
|
||||
"""Create a mock GLM/DeepSeek-style MoE block for sigmoid routing tests."""
|
||||
if bias_on_gate:
|
||||
gate = SimpleNamespace(
|
||||
weight=torch.randn(E, H),
|
||||
e_score_correction_bias=torch.zeros(E),
|
||||
)
|
||||
moe_block = SimpleNamespace(
|
||||
gate=gate,
|
||||
top_k=K,
|
||||
n_routed_experts=E,
|
||||
n_group=n_group,
|
||||
topk_group=topk_group,
|
||||
norm_topk_prob=True,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
else:
|
||||
# minimax_m2 style: bias on block, not gate
|
||||
gate = SimpleNamespace(
|
||||
weight=torch.randn(E, H),
|
||||
top_k=K,
|
||||
)
|
||||
moe_block = SimpleNamespace(
|
||||
gate=gate,
|
||||
top_k=K,
|
||||
e_score_correction_bias=torch.zeros(E),
|
||||
)
|
||||
return moe_block, T, H, E, K
|
||||
|
||||
|
||||
def _skip_without_triton():
|
||||
pytest.importorskip("triton")
|
||||
|
||||
|
||||
class TestSigmoidRoutingInScatterMoE:
|
||||
"""Test _sigmoid_topk_route from layers.py."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _require_triton(self):
|
||||
_skip_without_triton()
|
||||
|
||||
def test_output_shapes(self):
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
|
||||
moe_block, T, H, E, K = _make_sigmoid_moe_block()
|
||||
gate = moe_block.gate
|
||||
hidden = torch.randn(T, H)
|
||||
|
||||
weights, experts, top_k, num_experts = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
|
||||
assert weights.shape == (T, K)
|
||||
assert experts.shape == (T, K)
|
||||
assert top_k == K
|
||||
assert num_experts == E
|
||||
|
||||
def test_weights_nonnegative(self):
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
|
||||
moe_block, T, H, E, K = _make_sigmoid_moe_block()
|
||||
gate = moe_block.gate
|
||||
hidden = torch.randn(T, H)
|
||||
|
||||
weights, _, _, _ = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
assert (weights >= 0).all()
|
||||
|
||||
def test_group_selection_restricts_experts(self):
|
||||
"""With n_group=4, topk_group=1, experts should be from selected groups."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
|
||||
moe_block, T, H, E, K = _make_sigmoid_moe_block(
|
||||
E=16, K=2, n_group=4, topk_group=1
|
||||
)
|
||||
gate = moe_block.gate
|
||||
hidden = torch.randn(T, H)
|
||||
|
||||
_, expert_idx, _, _ = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
|
||||
# Each token's experts should fall within a single group (size E//n_group=4)
|
||||
for t in range(T):
|
||||
experts_t = expert_idx[t]
|
||||
groups = experts_t // (E // moe_block.n_group)
|
||||
assert (groups == groups[0]).all()
|
||||
|
||||
def test_scaling_factor_applied(self):
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
|
||||
moe_block, T, H, E, K = _make_sigmoid_moe_block(n_group=1)
|
||||
gate = moe_block.gate
|
||||
hidden = torch.randn(T, H)
|
||||
|
||||
weights_1x, _, _, _ = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
|
||||
moe_block.routed_scaling_factor = 2.0
|
||||
weights_2x, _, _, _ = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
|
||||
assert torch.allclose(weights_2x, weights_1x * 2.0, atol=1e-5)
|
||||
|
||||
def test_bias_on_gate(self):
|
||||
"""e_score_correction_bias on gate is found."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
|
||||
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=True)
|
||||
gate = moe_block.gate
|
||||
hidden = torch.randn(T, H)
|
||||
|
||||
weights, experts, _, _ = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
assert weights.shape == (T, K)
|
||||
|
||||
def test_bias_on_block(self):
|
||||
"""e_score_correction_bias on moe_block (not gate) is found."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
|
||||
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=False)
|
||||
gate = moe_block.gate
|
||||
hidden = torch.randn(T, H)
|
||||
|
||||
weights, experts, _, _ = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
assert weights.shape == (T, K)
|
||||
|
||||
def test_gate_lora_delta_applied(self):
|
||||
"""Gate LoRA delta should affect routing logits."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
|
||||
moe_block, T, H, E, K = _make_sigmoid_moe_block(n_group=1)
|
||||
gate = moe_block.gate
|
||||
hidden = torch.randn(T, H)
|
||||
|
||||
weights_no_lora, _, _, _ = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
|
||||
# Large delta should change the results
|
||||
delta = torch.randn(E, H) * 10.0
|
||||
weights_with_lora, _, _, _ = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, delta
|
||||
)
|
||||
|
||||
assert not torch.equal(weights_no_lora, weights_with_lora)
|
||||
|
||||
def test_no_bias_does_not_crash(self):
|
||||
"""Calling _sigmoid_topk_route with no e_score_correction_bias should not crash."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
|
||||
T, H, E, K = 8, 16, 8, 2
|
||||
gate = SimpleNamespace(weight=torch.randn(E, H))
|
||||
moe_block = SimpleNamespace(
|
||||
gate=gate,
|
||||
top_k=K,
|
||||
n_routed_experts=E,
|
||||
n_group=1,
|
||||
norm_topk_prob=True,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
hidden = torch.randn(T, H)
|
||||
|
||||
weights, experts, top_k, num_experts = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
assert weights.shape == (T, K)
|
||||
assert experts.shape == (T, K)
|
||||
# Without bias, scores_for_choice == sigmoid(logits) — all positive
|
||||
assert (weights >= 0).all()
|
||||
|
||||
def test_missing_topk_group_defaults_to_n_group(self):
|
||||
"""When topk_group is absent but n_group > 1, should default to n_group (no-op masking)."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_sigmoid_topk_route,
|
||||
)
|
||||
|
||||
T, H, E, K, n_group = 8, 16, 16, 2, 4
|
||||
gate = SimpleNamespace(
|
||||
weight=torch.randn(E, H),
|
||||
e_score_correction_bias=torch.zeros(E),
|
||||
)
|
||||
# Intentionally omit topk_group
|
||||
moe_block = SimpleNamespace(
|
||||
gate=gate,
|
||||
top_k=K,
|
||||
n_routed_experts=E,
|
||||
n_group=n_group,
|
||||
norm_topk_prob=True,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
hidden = torch.randn(T, H)
|
||||
|
||||
# Should not raise AttributeError; defaults topk_group to n_group
|
||||
weights, experts, top_k_out, num_experts = _sigmoid_topk_route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
assert weights.shape == (T, K)
|
||||
assert experts.shape == (T, K)
|
||||
|
||||
|
||||
class TestRoutingStrategyDetection:
|
||||
"""Test that _route dispatches to the correct strategy."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _require_triton(self):
|
||||
_skip_without_triton()
|
||||
|
||||
def test_softmax_for_qwen_style(self):
|
||||
"""Block without e_score_correction_bias should use softmax."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route
|
||||
|
||||
gate = _make_softmax_gate(E=4, H=16, K=2)
|
||||
moe_block = SimpleNamespace(gate=gate)
|
||||
hidden = torch.randn(8, 16)
|
||||
|
||||
weights, experts, top_k, num_experts = _route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
|
||||
assert weights.shape == (8, 2)
|
||||
assert experts.shape == (8, 2)
|
||||
assert top_k == 2
|
||||
assert num_experts == 4
|
||||
per_token_sums = weights.sum(dim=-1)
|
||||
assert torch.allclose(per_token_sums, torch.ones(8), atol=1e-5)
|
||||
|
||||
def test_sigmoid_for_glm_style(self):
|
||||
"""Block with e_score_correction_bias on gate should use sigmoid."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route
|
||||
|
||||
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=True, n_group=1)
|
||||
gate = moe_block.gate
|
||||
hidden = torch.randn(T, H)
|
||||
|
||||
weights, experts, top_k, num_experts = _route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
|
||||
assert weights.shape == (T, K)
|
||||
assert experts.shape == (T, K)
|
||||
assert (weights >= 0).all()
|
||||
|
||||
def test_sigmoid_for_minimax_m2_style(self):
|
||||
"""Block with e_score_correction_bias on block (not gate) should use sigmoid."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route
|
||||
|
||||
moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=False)
|
||||
gate = moe_block.gate
|
||||
hidden = torch.randn(T, H)
|
||||
|
||||
weights, experts, top_k, num_experts = _route(
|
||||
moe_block, gate, hidden, gate.weight, None
|
||||
)
|
||||
|
||||
assert weights.shape == (T, K)
|
||||
assert (weights >= 0).all()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 8. Generic shared expert handling
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestGenericSharedExpert:
|
||||
"""Test _compute_shared_expert from layers.py."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _require_triton(self):
|
||||
_skip_without_triton()
|
||||
|
||||
def test_shared_expert_singular(self):
|
||||
"""shared_expert attribute (Qwen2MoE style)."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_compute_shared_expert,
|
||||
)
|
||||
|
||||
called = torch.randn(4, 8)
|
||||
moe_block = SimpleNamespace(
|
||||
shared_expert=lambda x: called,
|
||||
)
|
||||
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
|
||||
assert torch.equal(result, called)
|
||||
|
||||
def test_shared_experts_plural(self):
|
||||
"""shared_experts attribute (DeepSeek V3 style)."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_compute_shared_expert,
|
||||
)
|
||||
|
||||
called = torch.randn(4, 8)
|
||||
moe_block = SimpleNamespace(
|
||||
shared_experts=lambda x: called,
|
||||
)
|
||||
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
|
||||
assert torch.equal(result, called)
|
||||
|
||||
def test_shared_mlp(self):
|
||||
"""shared_mlp attribute (Hunyuan style)."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_compute_shared_expert,
|
||||
)
|
||||
|
||||
called = torch.randn(4, 8)
|
||||
moe_block = SimpleNamespace(
|
||||
shared_mlp=lambda x: called,
|
||||
)
|
||||
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
|
||||
assert torch.equal(result, called)
|
||||
|
||||
def test_shared_expert_with_gate(self):
|
||||
"""shared_expert + shared_expert_gate applies sigmoid gating."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_compute_shared_expert,
|
||||
)
|
||||
|
||||
H = 8
|
||||
expert_out = torch.ones(4, H)
|
||||
gate_fn = lambda x: torch.zeros(4, H) # noqa: E731
|
||||
|
||||
moe_block = SimpleNamespace(
|
||||
shared_expert=lambda x: expert_out,
|
||||
shared_expert_gate=gate_fn,
|
||||
)
|
||||
result = _compute_shared_expert(moe_block, torch.randn(4, H))
|
||||
expected = expert_out * 0.5 # sigmoid(0) = 0.5
|
||||
assert torch.allclose(result, expected, atol=1e-6)
|
||||
|
||||
def test_no_shared_expert(self):
|
||||
"""No shared expert attributes returns None."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_compute_shared_expert,
|
||||
)
|
||||
|
||||
moe_block = SimpleNamespace()
|
||||
result = _compute_shared_expert(moe_block, torch.randn(4, 8))
|
||||
assert result is None
|
||||
|
||||
Reference in New Issue
Block a user