Commit history: consolidate behaviour of routing in scattermoe kernels; collect telemetry on best chosen autotuned kernel; properly collect data; fix property name and get smem too; handle issues raised by coderabbit; add tests for parity before refactoring.
712 lines · 25 KiB · Python
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0

"""
Unit tests for scattermoe-lora.

Tests cover:
- KernelsArgs validator: disable_mlp_kernel
- ParallelExperts: scaling=0.0 not treated as falsy
- single2scatter: non-aligned K/N dimensions
- group_compileable: coeff=None accepted
- HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract
- Routing strategy detection and sigmoid routing
- Generic shared expert handling
"""

from types import SimpleNamespace
from unittest.mock import patch

import pytest
import torch

# ============================================================================
# 1. KernelsArgs: disable_mlp_kernel validator
# ============================================================================
|
class TestKernelsArgsValidator:
    """Validate that disable_mlp_kernel sets both kernel flags correctly.

    The validator classmethod is invoked directly on raw dicts, since
    lora_mlp_kernel / mlp_kernel are not declared model fields.
    """

    @staticmethod
    def _run_validator(data):
        # Shared entry point so each test reads as "input dict -> result dict".
        from axolotl.integrations.kernels.args import KernelsArgs

        return KernelsArgs.disable_mlp_kernel(data)

    def test_disables_lora_mlp_kernel_when_scattermoe(self):
        """lora_mlp_kernel=True gets set to False when use_scattermoe=True."""
        result = self._run_validator(
            {"use_kernels": True, "use_scattermoe": True, "lora_mlp_kernel": True}
        )
        assert result["lora_mlp_kernel"] is False
        assert result["mlp_kernel"] is False

    def test_mlp_kernel_disabled_without_lora(self):
        """Even without lora_mlp_kernel, mlp_kernel should be disabled."""
        result = self._run_validator({"use_kernels": True, "use_scattermoe": True})
        assert result["mlp_kernel"] is False
        # lora_mlp_kernel was not in the input, so it must not be injected.
        assert "lora_mlp_kernel" not in result

    def test_lora_mlp_kernel_false_unchanged(self):
        """lora_mlp_kernel=False should stay False (no warning, no change)."""
        result = self._run_validator(
            {"use_kernels": True, "use_scattermoe": True, "lora_mlp_kernel": False}
        )
        assert result["lora_mlp_kernel"] is False

    def test_no_change_when_scattermoe_disabled(self):
        """When use_scattermoe is not True, nothing should be changed."""
        result = self._run_validator(
            {"use_kernels": True, "use_scattermoe": False, "lora_mlp_kernel": True}
        )
        assert result["lora_mlp_kernel"] is True
|
|
|
|
|
|
class TestParallelExpertsScaling:
    """Test that scaling=0.0 is preserved and not overridden to 1.0.

    parallel_linear_lora is patched out so each test only observes the
    ``scaling`` argument that ParallelExperts.forward forwards to it.
    """

    def test_scaling_zero_preserved(self):
        """scaling=0.0 should be passed as 0.0, not replaced with 1.0."""
        pytest.importorskip("triton")
        from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
            ParallelExperts,
        )

        pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
        pe.set_lora(
            lora_A=torch.randn(4, 4),
            lora_B=torch.randn(4, 4),
            scaling=0.0,
        )
        assert pe._lora_scaling == 0.0

        # Patch parallel_linear_lora to capture the scaling arg
        with patch(
            "axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora"
        ) as mock_pll:
            mock_pll.return_value = torch.randn(4, 4)
            # Create dummy routing tensors
            pe.forward(
                inputs=torch.randn(2, 4),
                k=1,
                sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),
                sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),
                expert_offsets=torch.tensor([2, 4]),
            )
            # Check that scaling=0.0 was passed, not 1.0.
            # NOTE: 0.0 is falsy, so an `or`-fallback (as used in the tests
            # below) would be wrong here; both lookups are compared explicitly.
            call_kwargs = mock_pll.call_args
            assert (
                call_kwargs.kwargs.get("scaling") == 0.0
                or call_kwargs[1].get("scaling") == 0.0
            ), f"Expected scaling=0.0 but got {call_kwargs}"

    def test_scaling_none_defaults_to_one(self):
        """scaling=None (no LoRA attached) should default to 1.0."""
        pytest.importorskip("triton")
        from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
            ParallelExperts,
        )

        pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
        # No set_lora called, so _lora_scaling is None

        with patch(
            "axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora"
        ) as mock_pll:
            mock_pll.return_value = torch.randn(4, 4)
            pe.forward(
                inputs=torch.randn(2, 4),
                k=1,
                sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),
                sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),
                expert_offsets=torch.tensor([2, 4]),
            )
            call_kwargs = mock_pll.call_args
            # `or`-fallback is safe here because the expected value (1.0)
            # is truthy; call_kwargs[1] is the kwargs mapping by position.
            scaling_val = call_kwargs.kwargs.get("scaling") or call_kwargs[1].get(
                "scaling"
            )
            assert scaling_val == 1.0, (
                f"Expected scaling=1.0 for None but got {scaling_val}"
            )

    def test_scaling_positive_preserved(self):
        """Normal positive scaling should be preserved."""
        pytest.importorskip("triton")
        from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
            ParallelExperts,
        )

        pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
        pe.set_lora(
            lora_A=torch.randn(4, 4),
            lora_B=torch.randn(4, 4),
            scaling=0.5,
        )

        with patch(
            "axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora"
        ) as mock_pll:
            mock_pll.return_value = torch.randn(4, 4)
            pe.forward(
                inputs=torch.randn(2, 4),
                k=1,
                sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),
                sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),
                expert_offsets=torch.tensor([2, 4]),
            )
            call_kwargs = mock_pll.call_args
            # Truthy expected value (0.5), so the `or`-fallback is safe.
            scaling_val = call_kwargs.kwargs.get("scaling") or call_kwargs[1].get(
                "scaling"
            )
            assert scaling_val == 0.5
|
|
|
|
|
|
# ============================================================================
# 4. single2scatter: non-aligned K/N dimensions (GPU only)
# ============================================================================


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestSingle2ScatterBounds:
    """Test single2scatter with non-aligned dimensions.

    The three cases differ only in (E, K, N), so the setup and the
    reference-matmul verification live in a single shared helper.
    """

    @staticmethod
    def _run_case(E, K, N):
        """Run single2scatter for one (E, K, N) shape and verify vs matmul."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import (
            single2scatter,
        )

        W = torch.randn(E, K, N, device="cuda", dtype=torch.float32)
        X = torch.randn(1, K, device="cuda", dtype=torch.float32)
        expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long)

        Y = single2scatter(X, W, expert_idxs)
        assert Y.shape == (2, N)

        # Verify against manual computation for each selected expert.
        torch.testing.assert_close(Y[0], X[0] @ W[0], atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(Y[1], X[0] @ W[1], atol=1e-2, rtol=1e-2)

    def test_non_aligned_k(self):
        """K not a multiple of BLOCK_K should produce correct results."""
        self._run_case(E=2, K=100, N=128)  # K=100 not a multiple of 128

    def test_non_aligned_n(self):
        """N not a multiple of BLOCK_N should produce correct results."""
        self._run_case(E=2, K=128, N=100)  # N=100 not a multiple of 128

    def test_non_aligned_both(self):
        """Both K and N not aligned should produce correct results."""
        self._run_case(E=2, K=100, N=100)  # Neither aligned to 128
|
|
|
|
|
|
# ============================================================================
# 5. group_compileable: coeff=None accepted
# ============================================================================


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestGroupCoeffNone:
    """Test that group() works with coeff=None."""

    def test_group_with_none_coeff(self):
        """group() should accept coeff=None without errors."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group

        num_tokens, hidden = 4, 32
        activations = torch.randn(num_tokens, hidden, device="cuda", dtype=torch.float32)
        expert_order = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long)

        # A TypeError here would mean coeff=None is not handled.
        grouped = group(activations, expert_order, coeff=None, fan_out=1)
        assert grouped.shape == (num_tokens, hidden)

    def test_group_with_coeff(self):
        """group() should also work with actual coeff values."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group

        num_tokens, hidden = 4, 32
        activations = torch.randn(num_tokens, hidden, device="cuda", dtype=torch.float32)
        expert_order = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long)
        scale = torch.ones(num_tokens, device="cuda", dtype=torch.float32) * 0.5

        grouped = group(activations, expert_order, coeff=scale, fan_out=1)
        assert grouped.shape == (num_tokens, hidden)
|
|
|
|
|
|
# ============================================================================
# 6. Layer return value contracts
# ============================================================================


class TestLayerReturnValues:
    """Test that layer forward methods return the correct types."""

    def test_hf_scatter_moe_returns_single_tensor(self):
        """HFScatterMoEGatedMLP.forward should return a single tensor, not a tuple."""
        pytest.importorskip("triton")
        # Verify the forward method signature and return annotation
        import inspect

        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
        )

        forward_params = inspect.signature(HFScatterMoEGatedMLP.forward).parameters
        # Forward takes (self, layer_input); membership on the params
        # mapping checks parameter names directly.
        assert "self" in forward_params
        assert "layer_input" in forward_params

    def test_scatter_moe_gated_mlp_docstring_no_router_logits(self):
        """ScatterMoEGatedMLP.forward docstring should not mention router logits as return."""
        pytest.importorskip("triton")
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            ScatterMoEGatedMLP,
        )

        doc = ScatterMoEGatedMLP.forward.__doc__
        assert doc is not None
        # The docstring must describe the output tensor (case-insensitive)...
        assert "output tensor" in doc.lower()
        # ...and must NOT claim router logits are returned.
        assert "Router logits" not in doc, (
            "Docstring should not mention 'Router logits' in Returns section"
        )
|
|
|
|
|
|
# ============================================================================
|
|
# 7. Routing strategy detection and sigmoid routing
|
|
# ============================================================================
|
|
|
|
|
|
def _make_softmax_gate(E=4, H=16, K=2):
|
|
"""Create a mock softmax-style gate (Qwen/OLMoE)."""
|
|
return SimpleNamespace(
|
|
weight=torch.randn(E, H),
|
|
top_k=K,
|
|
num_experts=E,
|
|
norm_topk_prob=True,
|
|
)
|
|
|
|
|
|
def _make_sigmoid_gate_with_bias(E=16, H=16):
|
|
"""Create a mock sigmoid-style gate with e_score_correction_bias on gate."""
|
|
return SimpleNamespace(
|
|
weight=torch.randn(E, H),
|
|
e_score_correction_bias=torch.zeros(E),
|
|
)
|
|
|
|
|
|
def _make_sigmoid_moe_block(
|
|
T=8, H=16, E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
|
|
):
|
|
"""Create a mock GLM/DeepSeek-style MoE block for sigmoid routing tests."""
|
|
if bias_on_gate:
|
|
gate = SimpleNamespace(
|
|
weight=torch.randn(E, H),
|
|
e_score_correction_bias=torch.zeros(E),
|
|
)
|
|
moe_block = SimpleNamespace(
|
|
gate=gate,
|
|
top_k=K,
|
|
n_routed_experts=E,
|
|
n_group=n_group,
|
|
topk_group=topk_group,
|
|
norm_topk_prob=True,
|
|
routed_scaling_factor=1.0,
|
|
)
|
|
else:
|
|
# minimax_m2 style: bias on block, not gate
|
|
gate = SimpleNamespace(
|
|
weight=torch.randn(E, H),
|
|
top_k=K,
|
|
)
|
|
moe_block = SimpleNamespace(
|
|
gate=gate,
|
|
top_k=K,
|
|
e_score_correction_bias=torch.zeros(E),
|
|
)
|
|
return moe_block, T, H, E, K
|
|
|
|
|
|
def _skip_without_triton():
    """Skip the calling test when triton cannot be imported."""
    pytest.importorskip("triton")
|
|
|
|
|
|
class TestSigmoidRoutingInScatterMoE:
    """Test _sigmoid_topk_route from layers.py.

    Each test builds a mock MoE block via _make_sigmoid_moe_block and calls
    _sigmoid_topk_route(moe_block, gate, hidden, gate_weight, gate_lora_delta).
    """

    @pytest.fixture(autouse=True)
    def _require_triton(self):
        # Importing layers.py requires triton; skip the whole class otherwise.
        _skip_without_triton()

    def test_output_shapes(self):
        """Routing returns (weights, experts, top_k, num_experts) with expected shapes."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, K = _make_sigmoid_moe_block()
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, experts, top_k, num_experts = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert weights.shape == (T, K)
        assert experts.shape == (T, K)
        assert top_k == K
        assert num_experts == E

    def test_weights_nonnegative(self):
        """Sigmoid-derived routing weights can never be negative."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, K = _make_sigmoid_moe_block()
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        assert (weights >= 0).all()

    def test_group_selection_restricts_experts(self):
        """With n_group=4, topk_group=1, experts should be from selected groups."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, K = _make_sigmoid_moe_block(
            E=16, K=2, n_group=4, topk_group=1
        )
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        _, expert_idx, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        # Each token's experts should fall within a single group (size E//n_group=4)
        for t in range(T):
            experts_t = expert_idx[t]
            groups = experts_t // (E // moe_block.n_group)
            assert (groups == groups[0]).all()

    def test_scaling_factor_applied(self):
        """Doubling routed_scaling_factor should exactly double the weights."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, K = _make_sigmoid_moe_block(n_group=1)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights_1x, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        moe_block.routed_scaling_factor = 2.0
        weights_2x, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        assert torch.allclose(weights_2x, weights_1x * 2.0, atol=1e-5)

    def test_bias_on_gate(self):
        """e_score_correction_bias on gate is found."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=True)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        assert weights.shape == (T, K)

    def test_bias_on_block(self):
        """e_score_correction_bias on moe_block (not gate) is found."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, K = _make_sigmoid_moe_block(bias_on_gate=False)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights, experts, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        assert weights.shape == (T, K)

    def test_gate_lora_delta_applied(self):
        """Gate LoRA delta should affect routing logits."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        moe_block, T, H, E, K = _make_sigmoid_moe_block(n_group=1)
        gate = moe_block.gate
        hidden = torch.randn(T, H)

        weights_no_lora, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )

        # Large delta should change the results
        delta = torch.randn(E, H) * 10.0
        weights_with_lora, _, _, _ = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, delta
        )

        assert not torch.equal(weights_no_lora, weights_with_lora)

    def test_no_bias_does_not_crash(self):
        """Calling _sigmoid_topk_route with no e_score_correction_bias should not crash."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        # Hand-built block: no bias anywhere, n_group=1 (no group masking).
        T, H, E, K = 8, 16, 8, 2
        gate = SimpleNamespace(weight=torch.randn(E, H))
        moe_block = SimpleNamespace(
            gate=gate,
            top_k=K,
            n_routed_experts=E,
            n_group=1,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
        hidden = torch.randn(T, H)

        weights, experts, top_k, num_experts = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        assert weights.shape == (T, K)
        assert experts.shape == (T, K)
        # Without bias, scores_for_choice == sigmoid(logits) — all positive
        assert (weights >= 0).all()

    def test_missing_topk_group_defaults_to_n_group(self):
        """When topk_group is absent but n_group > 1, should default to n_group (no-op masking)."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _sigmoid_topk_route,
        )

        T, H, E, K, n_group = 8, 16, 16, 2, 4
        gate = SimpleNamespace(
            weight=torch.randn(E, H),
            e_score_correction_bias=torch.zeros(E),
        )
        # Intentionally omit topk_group
        moe_block = SimpleNamespace(
            gate=gate,
            top_k=K,
            n_routed_experts=E,
            n_group=n_group,
            norm_topk_prob=True,
            routed_scaling_factor=1.0,
        )
        hidden = torch.randn(T, H)

        # Should not raise AttributeError; defaults topk_group to n_group
        weights, experts, top_k_out, num_experts = _sigmoid_topk_route(
            moe_block, gate, hidden, gate.weight, None
        )
        assert weights.shape == (T, K)
        assert experts.shape == (T, K)
|
|
|
|
|
|
class TestRoutingStrategyDetection:
    """Test that _route dispatches to the correct strategy."""

    @pytest.fixture(autouse=True)
    def _require_triton(self):
        _skip_without_triton()

    def test_softmax_for_qwen_style(self):
        """Block without e_score_correction_bias should use softmax."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route

        gate = _make_softmax_gate(E=4, H=16, K=2)
        block = SimpleNamespace(gate=gate)
        hidden_states = torch.randn(8, 16)

        weights, experts, top_k, num_experts = _route(
            block, gate, hidden_states, gate.weight, None
        )

        assert weights.shape == (8, 2)
        assert experts.shape == (8, 2)
        assert top_k == 2
        assert num_experts == 4
        # Softmax-normalized weights must sum to one for every token.
        assert torch.allclose(weights.sum(dim=-1), torch.ones(8), atol=1e-5)

    def test_sigmoid_for_glm_style(self):
        """Block with e_score_correction_bias on gate should use sigmoid."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route

        block, n_tokens, hidden_dim, _, k = _make_sigmoid_moe_block(
            bias_on_gate=True, n_group=1
        )
        gate = block.gate
        hidden_states = torch.randn(n_tokens, hidden_dim)

        weights, experts, top_k, num_experts = _route(
            block, gate, hidden_states, gate.weight, None
        )

        assert weights.shape == (n_tokens, k)
        assert experts.shape == (n_tokens, k)
        assert (weights >= 0).all()

    def test_sigmoid_for_minimax_m2_style(self):
        """Block with e_score_correction_bias on block (not gate) should use sigmoid."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import _route

        block, n_tokens, hidden_dim, _, k = _make_sigmoid_moe_block(
            bias_on_gate=False
        )
        gate = block.gate
        hidden_states = torch.randn(n_tokens, hidden_dim)

        weights, experts, top_k, num_experts = _route(
            block, gate, hidden_states, gate.weight, None
        )

        assert weights.shape == (n_tokens, k)
        assert (weights >= 0).all()
|
|
|
|
|
|
# ============================================================================
# 8. Generic shared expert handling
# ============================================================================


class TestGenericSharedExpert:
    """Test _compute_shared_expert from layers.py."""

    @pytest.fixture(autouse=True)
    def _require_triton(self):
        _skip_without_triton()

    @staticmethod
    def _assert_passthrough(attr_name):
        """Build a block whose shared-expert attr returns a fixed tensor and
        verify _compute_shared_expert returns that tensor unchanged."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert,
        )

        expected = torch.randn(4, 8)
        block = SimpleNamespace(**{attr_name: lambda x: expected})
        out = _compute_shared_expert(block, torch.randn(4, 8))
        assert torch.equal(out, expected)

    def test_shared_expert_singular(self):
        """shared_expert attribute (Qwen2MoE style)."""
        self._assert_passthrough("shared_expert")

    def test_shared_experts_plural(self):
        """shared_experts attribute (DeepSeek V3 style)."""
        self._assert_passthrough("shared_experts")

    def test_shared_mlp(self):
        """shared_mlp attribute (Hunyuan style)."""
        self._assert_passthrough("shared_mlp")

    def test_shared_expert_with_gate(self):
        """shared_expert + shared_expert_gate applies sigmoid gating."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert,
        )

        hidden_dim = 8
        expert_out = torch.ones(4, hidden_dim)
        block = SimpleNamespace(
            shared_expert=lambda x: expert_out,
            shared_expert_gate=lambda x: torch.zeros(4, hidden_dim),
        )
        result = _compute_shared_expert(block, torch.randn(4, hidden_dim))
        # Gate emits zeros, and sigmoid(0) = 0.5, so output is halved.
        assert torch.allclose(result, expert_out * 0.5, atol=1e-6)

    def test_no_shared_expert(self):
        """No shared expert attributes returns None."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _compute_shared_expert,
        )

        assert _compute_shared_expert(SimpleNamespace(), torch.randn(4, 8)) is None
|