feat: add sonicmoe (#3411)

* feat: add sonicmoe

* feat: add torch compile for routing

* feat: add routing smoke test

* feat: add qwen3_5_moe, qwen3_vl_moe, qwen3_omni_moe

* fix: disable mlp kernel for sonicmoe too

* feat: update to sonicmoe release

* chore: update import following new sonicmoe changes

* feat: update handling for blackwell

* feat: add sonicmoe e2e test

* fix: installation for updated sonicmoe

* fix: git commit

* fix: ignore py req and fix metadata

* fix: increase min hidden size to match sonicmoe kernel min

* fix: attempt to properly interleave and handle unpatch mid-test

* chore: refactor teardown better

* chore: refactor to re-use rearrange

* fix: add idempotency guard

* fix: address comments on CI memory and interleave

* fix: tests grad, param doublewrapped
This commit is contained in:
NanoCode012
2026-03-06 01:43:31 +07:00
committed by GitHub
parent 1eaf4d7418
commit 6a8baf8fa7
12 changed files with 1698 additions and 42 deletions

View File

@@ -0,0 +1,288 @@
"""
End-to-end gradient and convergence tests for SonicMoE integration.
Requires:
- H100/H200 GPU (SonicMoE CUTLASS kernels target sm_90)
- sonicmoe package installed
- transformers with Qwen3MoE support
Usage:
pytest tests/e2e/integrations/test_sonicmoe.py -v -s
"""
import importlib.util
import math
import pytest
import torch
# True when the optional `sonicmoe` package can be located; find_spec only
# looks the module up, it does not import it.
_sonicmoe_available = importlib.util.find_spec("sonicmoe") is not None
# True only when CUDA is available AND the current device reports compute
# capability (9, 0). The `and` short-circuits, so the capability is never
# queried on a machine without CUDA.
_is_hopper = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
# Module-level marks: every test in this file is skipped unless CUDA is
# present, the device is sm_90, and sonicmoe is installed. Each condition has
# its own mark so the skip reason names the specific missing requirement.
pytestmark = [
    pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA GPU"),
    pytest.mark.skipif(
        not _is_hopper, reason="SonicMoE CUTLASS kernels require Hopper (sm_90)"
    ),
    pytest.mark.skipif(not _sonicmoe_available, reason="SonicMoE not installed"),
]
def _create_tiny_qwen3_config():
    """Build a deliberately small Qwen3MoE configuration for fast tests.

    Returns:
        The config produced by ``AutoConfig.for_model("qwen3_moe")`` with the
        size-related fields below overridden.
    """
    from transformers import AutoConfig

    # Applied in this order on top of the library defaults.
    tiny_overrides = {
        "hidden_size": 512,
        "intermediate_size": 1024,
        "moe_intermediate_size": 64,
        "num_attention_heads": 16,
        "num_key_value_heads": 2,
        "head_dim": 32,
        "num_hidden_layers": 2,
        "num_experts": 8,
        "num_experts_per_tok": 2,
        "vocab_size": 1000,
        "max_position_embeddings": 128,
        "norm_topk_prob": True,
        "torch_dtype": torch.bfloat16,
    }

    cfg = AutoConfig.for_model("qwen3_moe")
    for field, value in tiny_overrides.items():
        setattr(cfg, field, value)
    return cfg
def _interleave_gate_up_weights(model):
    """Rewrite every ``gate_up_proj`` parameter of *model* in place.

    Each parameter whose name contains ``"gate_up_proj"`` is overwritten with
    the result of ``interleave_gate_up`` applied to it, which is the layout the
    SonicMoE-patched forward is tested with.
    """
    from axolotl.integrations.kernels.sonicmoe.weight_converter import (
        interleave_gate_up,
    )

    fused_weights = [
        weight
        for param_name, weight in model.named_parameters()
        if "gate_up_proj" in param_name
    ]
    # In-place writes to leaf parameters must happen outside autograd.
    with torch.no_grad():
        for weight in fused_weights:
            weight.copy_(interleave_gate_up(weight))
def _unpatch_sonicmoe():
    """Undo the SonicMoE patch on the qwen3_moe MoE block classes.

    A class counts as patched when it carries ``_original_forward``; that
    attribute is moved back to ``forward`` and then removed. Classes without
    it are left untouched, so calling this repeatedly is harmless.
    """
    from axolotl.integrations.kernels.constants import resolve_moe_block_classes

    for block_cls in resolve_moe_block_classes("qwen3_moe"):
        if not hasattr(block_cls, "_original_forward"):
            continue
        block_cls.forward = block_cls._original_forward
        del block_cls._original_forward
class TestSonicMoEForwardCorrectness:
    """Check that the SonicMoE-patched forward reproduces the stock forward."""

    def teardown_method(self):
        # The patch is applied to the MoE block class, so undo it after each test.
        _unpatch_sonicmoe()

    def test_forward_output_matches(self):
        """Logits of the patched model must match the reference within tolerance."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        cfg = _create_tiny_qwen3_config()
        tokens = torch.randint(0, cfg.vocab_size, (1, 16), device="cuda")

        # Reference run on the unpatched model.
        reference = AutoModelForCausalLM.from_config(cfg).cuda().bfloat16()
        with torch.no_grad():
            ref_logits = reference(tokens).logits

        # Candidate: same weights, then patch and convert gate_up_proj layout.
        candidate = AutoModelForCausalLM.from_config(cfg).cuda().bfloat16()
        candidate.load_state_dict(reference.state_dict())
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(candidate)
        with torch.no_grad():
            cand_logits = candidate(tokens).logits

        max_diff = (ref_logits - cand_logits).abs().max().item()
        logits_close = torch.allclose(ref_logits, cand_logits, atol=1e-1, rtol=1e-1)
        assert logits_close, f"Output mismatch: max diff={max_diff:.6f}"
class TestSonicMoEGradientCorrectness:
    """Compare gradients between original HuggingFace and SonicMoE-patched forward."""

    def teardown_method(self):
        # The patch is applied to the MoE block class, so undo it after each test.
        _unpatch_sonicmoe()

    @staticmethod
    def _gradient_mismatches(grads_orig, grads_patched, rel_tol=0.1, abs_tol=1e-2):
        """Describe every parameter whose two gradients disagree.

        Args:
            grads_orig: mapping of parameter name to reference gradient tensor.
            grads_patched: mapping of parameter name to the patched model's
                gradient tensor; must contain every key of ``grads_orig``.
            rel_tol: threshold on ``max_abs_diff / max_abs_reference``.
            abs_tol: threshold on ``max_abs_diff``.

        Returns:
            A list with one message per offending parameter. A parameter is
            reported when either gradient contains NaN/Inf, or when BOTH the
            relative and the absolute threshold are exceeded.
        """
        report = []
        for name, g_orig in grads_orig.items():
            g_patched = grads_patched[name]
            # NaN compares False against every threshold, so without this
            # explicit check a NaN gradient would pass unnoticed.
            if not (
                torch.isfinite(g_orig).all() and torch.isfinite(g_patched).all()
            ):
                report.append(f"  {name}: non-finite gradient (NaN/Inf)")
                continue
            max_diff = (g_orig - g_patched).abs().max().item()
            rel_diff = max_diff / (g_orig.abs().max().item() + 1e-8)
            if rel_diff > rel_tol and max_diff > abs_tol:
                report.append(
                    f"  {name}: max_abs_diff={max_diff:.6f}, rel_diff={rel_diff:.4f}"
                )
        return report

    def test_gradients_match(self):
        """Verify all parameter gradients match between original and patched."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
        from axolotl.integrations.kernels.sonicmoe.weight_converter import (
            deinterleave_gate_up,
        )

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")

        # ---------- Original model ----------
        model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        out_orig = model_orig(input_ids, labels=input_ids)
        out_orig.loss.backward()
        grads_orig = {
            n: p.grad.float().clone()
            for n, p in model_orig.named_parameters()
            if p.grad is not None
        }
        loss_orig = out_orig.loss.item()

        # ---------- SonicMoE-patched model (same weights, interleaved) ----------
        model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        model_patched.load_state_dict(model_orig.state_dict())
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model_patched)
        out_patched = model_patched(input_ids, labels=input_ids)
        out_patched.loss.backward()
        grads_patched = {}
        for n, p in model_patched.named_parameters():
            if p.grad is None:
                continue
            g = p.grad.float().clone()
            # gate_up_proj grads are in interleaved layout, de-interleave to match orig
            if "gate_up_proj" in n:
                g = deinterleave_gate_up(g)
            grads_patched[n] = g
        loss_patched = out_patched.loss.item()

        # ---------- Compare ----------
        assert abs(loss_orig - loss_patched) < 0.5, (
            f"Loss mismatch: orig={loss_orig:.4f}, patched={loss_patched:.4f}"
        )

        # All parameters with gradients in original should have them in patched
        missing = set(grads_orig.keys()) - set(grads_patched.keys())
        assert not missing, f"Missing gradients in patched model: {missing}"

        # bf16 with different GEMM impls (cuBLAS vs CUTLASS) can diverge,
        # so use generous tolerance: flag only if both rel >10% AND abs >1e-2.
        # Non-finite gradients are always flagged.
        mismatches = self._gradient_mismatches(grads_orig, grads_patched)
        assert not mismatches, (
            "Gradient mismatches "
            "(non-finite, or rel_diff > 10% and abs_diff > 1e-2):\n"
            + "\n".join(mismatches)
        )

    def test_router_weights_receive_gradients(self):
        """Verify that router (gate) weights get non-zero gradients."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model)
        out = model(input_ids, labels=input_ids)
        out.loss.backward()

        gate_grads_found = False
        for name, param in model.named_parameters():
            # Match only the router's own weight. A plain substring test on
            # "gate"/"weight" would also accept gate_proj-style parameters.
            if not name.endswith(".gate.weight"):
                continue
            gate_grads_found = True
            assert param.grad is not None, f"No gradient for router: {name}"
            assert param.grad.abs().max() > 0, f"Zero gradient for router: {name}"
        assert gate_grads_found, "No gate.weight parameters found in model"
class TestSonicMoETrainingConvergence:
    """Short training runs with the SonicMoE patch applied."""

    def teardown_method(self):
        # The patch is applied to the MoE block class, so undo it after each test.
        _unpatch_sonicmoe()

    def test_loss_decreases(self):
        """Train for 30 steps on one fixed batch; loss must stay finite and drop."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        cfg = _create_tiny_qwen3_config()
        batch = torch.randint(0, cfg.vocab_size, (2, 32), device="cuda")
        model = AutoModelForCausalLM.from_config(cfg).cuda().bfloat16()
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model)
        opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

        history = []
        for step in range(30):
            loss = model(batch, labels=batch).loss
            assert not math.isnan(loss.item()), f"NaN loss at step {step}"
            assert not math.isinf(loss.item()), f"Inf loss at step {step}"
            history.append(loss.item())
            loss.backward()
            opt.step()
            opt.zero_grad()

        assert history[-1] < history[0], (
            f"Loss did not decrease: first={history[0]:.4f}, last={history[-1]:.4f}"
        )

    def test_expert_weights_update(self):
        """After 5 optimizer steps at least one expert parameter must differ."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        cfg = _create_tiny_qwen3_config()
        batch = torch.randint(0, cfg.vocab_size, (2, 32), device="cuda")
        model = AutoModelForCausalLM.from_config(cfg).cuda().bfloat16()
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model)

        # Copy of every parameter whose name mentions "experts", taken pre-training.
        snapshot = {
            name: param.data.clone()
            for name, param in model.named_parameters()
            if "experts" in name
        }
        assert snapshot, "No expert parameters found"

        opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
        for _ in range(5):
            model(batch, labels=batch).loss.backward()
            opt.step()
            opt.zero_grad()

        changed = sum(
            1
            for name, param in model.named_parameters()
            if name in snapshot and not torch.equal(param.data, snapshot[name])
        )
        assert changed > 0, "No expert weights changed after 5 training steps"

View File

@@ -6,7 +6,7 @@
Unit tests for scattermoe-lora code-review fixes.
Tests cover:
- KernelsArgs validator: disable_mlp_kernel_scattermoe
- KernelsArgs validator: disable_mlp_kernel
- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
- ParallelExperts: scaling=0.0 not treated as falsy
- single2scatter: non-aligned K/N dimensions
@@ -20,12 +20,12 @@ import pytest
import torch
# ============================================================================
# 1. KernelsArgs: disable_mlp_kernel_scattermoe validator
# 1. KernelsArgs: disable_mlp_kernel validator
# ============================================================================
class TestKernelsArgsValidator:
"""Test that disable_mlp_kernel_scattermoe sets both flags correctly.
"""Test that disable_mlp_kernel sets both flags correctly.
These tests call the validator classmethod directly on raw dicts,
since lora_mlp_kernel / mlp_kernel are not declared model fields.
@@ -40,7 +40,7 @@ class TestKernelsArgsValidator:
"use_scattermoe": True,
"lora_mlp_kernel": True,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
result = KernelsArgs.disable_mlp_kernel(data)
assert result["lora_mlp_kernel"] is False
assert result["mlp_kernel"] is False
@@ -52,7 +52,7 @@ class TestKernelsArgsValidator:
"use_kernels": True,
"use_scattermoe": True,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
result = KernelsArgs.disable_mlp_kernel(data)
assert result["mlp_kernel"] is False
# lora_mlp_kernel was not in data, should not be added
assert "lora_mlp_kernel" not in result
@@ -66,7 +66,7 @@ class TestKernelsArgsValidator:
"use_scattermoe": True,
"lora_mlp_kernel": False,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
result = KernelsArgs.disable_mlp_kernel(data)
assert result["lora_mlp_kernel"] is False
def test_no_change_when_scattermoe_disabled(self):
@@ -78,7 +78,7 @@ class TestKernelsArgsValidator:
"use_scattermoe": False,
"lora_mlp_kernel": True,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
result = KernelsArgs.disable_mlp_kernel(data)
assert result["lora_mlp_kernel"] is True

View File

@@ -0,0 +1,428 @@
"""Unit tests for the SonicMoE integration."""
from types import SimpleNamespace
import pytest
import torch
from axolotl.integrations.kernels.args import KernelsArgs
from axolotl.integrations.kernels.sonicmoe.routing import (
sigmoid_topk_routing,
softmax_topk_routing,
)
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
ConcatenatedToInterleaved,
InterleavedToConcatenated,
register_sonicmoe_weight_converter,
)
class TestKernelsArgs:
    """Validation behaviour of KernelsArgs for the scattermoe/sonicmoe flags."""

    def test_mutual_exclusivity_raises(self):
        """Enabling both MoE kernels at once is rejected."""
        both_enabled = {"use_scattermoe": True, "use_sonicmoe": True}
        with pytest.raises(ValueError, match="Cannot use both"):
            KernelsArgs.model_validate(both_enabled)

    def test_sonicmoe_only(self):
        """Only use_sonicmoe set: the other flag stays None."""
        args = KernelsArgs.model_validate({"use_sonicmoe": True})
        assert args.use_sonicmoe is True
        assert args.use_scattermoe is None

    def test_scattermoe_only(self):
        """Only use_scattermoe set: the other flag stays None."""
        args = KernelsArgs.model_validate({"use_scattermoe": True})
        assert args.use_scattermoe is True
        assert args.use_sonicmoe is None

    def test_neither_set(self):
        """An empty config leaves both flags as None."""
        args = KernelsArgs.model_validate({})
        assert args.use_scattermoe is None
        assert args.use_sonicmoe is None

    def test_disables_mlp_kernel_when_sonicmoe(self):
        """The validator, called directly on a raw dict, turns both MLP kernel flags off."""
        raw = {"use_sonicmoe": True, "lora_mlp_kernel": True}
        validated = KernelsArgs.disable_mlp_kernel(raw)
        assert validated["lora_mlp_kernel"] is False
        assert validated["mlp_kernel"] is False
class TestConcatenatedToInterleaved:
    """ConcatenatedToInterleaved turns [gate; up] blocks into alternating rows."""

    @pytest.fixture
    def sample_tensor(self):
        """Concatenated [E=2, 2*I=4, H=3] tensor with distinct gate/up values."""
        E, I, H = 2, 2, 3  # noqa: E741
        count = E * I * H
        gate = torch.arange(1, count + 1, dtype=torch.float32).reshape(E, I, H)
        up = torch.arange(100, 100 + count, dtype=torch.float32).reshape(E, I, H)
        return torch.cat([gate, up], dim=1)

    def test_interleave_rows_alternate(self, sample_tensor):
        """Even rows of the output are gate rows, odd rows are up rows."""
        converted = ConcatenatedToInterleaved(dim=1).convert(
            {"test": sample_tensor},
            source_patterns=["test"],
            target_patterns=["test"],
        )
        interleaved = converted["test"]

        _, two_I, _ = sample_tensor.shape
        half = two_I // 2
        expected_gate = sample_tensor[:, :half, :]
        expected_up = sample_tensor[:, half:, :]
        assert torch.equal(interleaved[:, 0::2, :], expected_gate)
        assert torch.equal(interleaved[:, 1::2, :], expected_up)

    def test_interleave_handles_list_input(self, sample_tensor):
        """A single-element list value is accepted and the shape is preserved."""
        converted = ConcatenatedToInterleaved(dim=1).convert(
            {"test": [sample_tensor]},
            source_patterns=["test"],
            target_patterns=["test"],
        )
        assert converted["test"].shape == sample_tensor.shape

    def test_reverse_op_type(self):
        """reverse_op is the matching de-interleave op on the same dim."""
        op = ConcatenatedToInterleaved(dim=1)
        assert isinstance(op.reverse_op, InterleavedToConcatenated)
        assert op.reverse_op.dim == 1
class TestInterleavedToConcatenated:
    """InterleavedToConcatenated regroups alternating rows into [gate; up]."""

    @pytest.fixture
    def interleaved_tensor(self):
        """Interleaved [E=2, 2*I=4, H=3] tensor: even rows gate, odd rows up."""
        E, I, H = 2, 2, 3  # noqa: E741
        count = E * I * H
        gate = torch.arange(1, count + 1, dtype=torch.float32).reshape(E, I, H)
        up = torch.arange(100, 100 + count, dtype=torch.float32).reshape(E, I, H)
        # Stacking on a new axis after the row axis and flattening yields
        # rows in the order gate[0], up[0], gate[1], up[1], ...
        return torch.stack([gate, up], dim=2).reshape(E, 2 * I, H)

    def test_deinterleave_gate_up_separated(self, interleaved_tensor):
        """First half of the output is the even rows, second half the odd rows."""
        concatenated = InterleavedToConcatenated(dim=1).convert(
            {"test": interleaved_tensor},
            source_patterns=["test"],
            target_patterns=["test"],
        )["test"]

        _, two_I, _ = concatenated.shape
        half = two_I // 2
        even_rows = interleaved_tensor[:, 0::2, :]
        odd_rows = interleaved_tensor[:, 1::2, :]
        assert torch.equal(concatenated[:, :half, :], even_rows)
        assert torch.equal(concatenated[:, half:, :], odd_rows)

    def test_reverse_op_type(self):
        """reverse_op is the matching interleave op on the same dim."""
        op = InterleavedToConcatenated(dim=1)
        assert isinstance(op.reverse_op, ConcatenatedToInterleaved)
        assert op.reverse_op.dim == 1
class TestRoundTrip:
    """Interleave followed by de-interleave must reproduce the input exactly."""

    @staticmethod
    def _apply(op, tensor):
        """Run a conversion op on one tensor stored under the key "k"."""
        return op.convert(
            {"k": tensor}, source_patterns=["k"], target_patterns=["k"]
        )["k"]

    @pytest.fixture
    def concat_tensor(self):
        """Random concatenated tensor of shape [4, 16, 16] (gate then up)."""
        E, I, H = 4, 8, 16  # noqa: E741
        gate_half = torch.randn(E, I, H)
        up_half = torch.randn(E, I, H)
        return torch.cat([gate_half, up_half], dim=1)

    def test_interleave_then_deinterleave_is_identity(self, concat_tensor):
        """Two independently constructed ops invert each other."""
        to_interleaved = ConcatenatedToInterleaved(dim=1)
        to_concat = InterleavedToConcatenated(dim=1)
        round_tripped = self._apply(
            to_concat, self._apply(to_interleaved, concat_tensor)
        )
        assert torch.equal(concat_tensor, round_tripped)

    def test_reverse_op_chain_is_identity(self, concat_tensor):
        """The op's own reverse_op is an exact inverse."""
        to_interleaved = ConcatenatedToInterleaved(dim=1)
        inverse = to_interleaved.reverse_op
        interleaved = self._apply(to_interleaved, concat_tensor)
        round_tripped = self._apply(inverse, interleaved)
        assert torch.equal(concat_tensor, round_tripped)

    def test_various_shapes(self):
        """Round trip holds for several expert counts and dimensions."""
        to_interleaved = ConcatenatedToInterleaved(dim=1)
        to_concat = InterleavedToConcatenated(dim=1)
        shapes = [(1, 4, 8), (8, 16, 32), (16, 128, 256)]
        for E, I, H in shapes:  # noqa: E741
            concat = torch.randn(E, 2 * I, H)
            interleaved = self._apply(to_interleaved, concat)
            recovered = self._apply(to_concat, interleaved)
            assert torch.equal(concat, recovered), (
                f"Failed for shape ({E}, {2 * I}, {H})"
            )
class TestWeightConverterRegistration:
    """register_sonicmoe_weight_converter must add exactly one interleave op."""

    @staticmethod
    def _find_gate_up_converter(model_type):
        """Return the first converter targeting ``gate_up_proj``, or None.

        Looks through the checkpoint conversion mapping of *model_type* for an
        entry that has ``operations`` and whose ``target_patterns`` mention
        ``gate_up_proj``.
        """
        from transformers.conversion_mapping import get_checkpoint_conversion_mapping

        mapping = get_checkpoint_conversion_mapping(model_type)
        assert mapping is not None, f"No conversion mapping for {model_type}"
        for conv in mapping:
            if hasattr(conv, "operations") and any(
                "gate_up_proj" in pat for pat in conv.target_patterns
            ):
                return conv
        return None

    def test_register_appends_interleave_op(self):
        """After registration the gate_up_proj converter ends with the interleave op."""
        register_sonicmoe_weight_converter("qwen3_moe")
        gate_up_converter = self._find_gate_up_converter("qwen3_moe")
        assert gate_up_converter is not None
        assert isinstance(gate_up_converter.operations[-1], ConcatenatedToInterleaved)

    def test_double_registration_is_idempotent(self):
        """Registering twice must not stack a second interleave op."""
        register_sonicmoe_weight_converter("qwen3_moe")
        register_sonicmoe_weight_converter("qwen3_moe")
        gate_up_converter = self._find_gate_up_converter("qwen3_moe")
        # Fail loudly when the converter is absent. Previously the count
        # assertion sat inside the search loop and was skipped entirely when
        # nothing matched, so the test passed without checking anything.
        assert gate_up_converter is not None, (
            "No gate_up_proj converter found for qwen3_moe"
        )
        interleave_count = sum(
            isinstance(op, ConcatenatedToInterleaved)
            for op in gate_up_converter.operations
        )
        assert interleave_count == 1, (
            f"Expected 1 ConcatenatedToInterleaved op, got {interleave_count}"
        )

    def test_register_unsupported_model_type_warns(self):
        """An unknown model type must not raise.

        Only the absence of an exception is checked here; the warning itself
        is not asserted.
        """
        register_sonicmoe_weight_converter("nonexistent_model_type_xyz")
def _make_qwen_moe_block(T=8, H=16, E=4, K=2):
    """Build a stand-in for a qwen-style MoE block used by the routing tests.

    The routing settings (``top_k``, ``num_experts``, ``norm_topk_prob``) and
    a random ``weight`` of shape (E, H) all live on ``block.gate``.

    Returns:
        ``(block, T, H, E, K)`` so callers can unpack the sizes alongside it.
    """
    router = SimpleNamespace()
    router.weight = torch.randn(E, H)
    router.top_k = K
    router.num_experts = E
    router.norm_topk_prob = True

    block = SimpleNamespace(gate=router)
    return block, T, H, E, K
def _make_glm_moe_block(T=8, H=16, E=16, K=4, n_group=2, topk_group=1):
    """Build a stand-in for a GLM5-style MoE block used by the routing tests.

    The gate carries a random (E, H) ``weight`` and a zero
    ``e_score_correction_bias``; the routing settings live on the block.

    Returns:
        ``(moe_block, T, H, E, K)``.
    """
    router = SimpleNamespace(
        **{
            "weight": torch.randn(E, H),
            "e_score_correction_bias": torch.zeros(E),
        }
    )
    block_fields = {
        "gate": router,
        "top_k": K,
        "n_routed_experts": E,
        "n_group": n_group,
        "topk_group": topk_group,
        "norm_topk_prob": True,
        "routed_scaling_factor": 1.0,
    }
    return SimpleNamespace(**block_fields), T, H, E, K
def _make_minimax_m2_moe_block(T=8, H=16, E=16, K=4):
    """Build a stand-in for a minimax_m2-style MoE block used by the routing tests.

    minimax_m2 routes sigmoid -> top-k with no group selection, so this mock:
    - puts e_score_correction_bias on the block rather than on the gate
    - defines no n_group / topk_group attributes
    - defines no norm_topk_prob (the routing is expected to default to True)
    - defines no routed_scaling_factor (expected to default to 1.0)

    Returns:
        ``(moe_block, T, H, E, K)``.
    """
    router = SimpleNamespace(weight=torch.randn(E, H), top_k=K)

    block = SimpleNamespace()
    block.gate = router
    block.top_k = K
    block.e_score_correction_bias = torch.zeros(E)
    return block, T, H, E, K
class TestSoftmaxTopkRouting:
    """Shape, dtype and ordering checks for softmax_topk_routing on a mock qwen block."""

    @staticmethod
    def _route():
        """Route random hidden states through a fresh mock block.

        Returns:
            ``(routing_outputs, (T, E, K))``.
        """
        moe_block, T, H, E, K = _make_qwen_moe_block()
        hidden = torch.randn(T, H)
        return softmax_topk_routing(hidden, moe_block), (T, E, K)

    def test_output_shapes(self):
        (scores, token_idx, expert_idx, logits), (T, E, K) = self._route()
        flat_shape = (T * K,)
        assert scores.shape == flat_shape
        assert token_idx.shape == flat_shape
        assert expert_idx.shape == flat_shape
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        (scores, _, _, _), _ = self._route()
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        (_, token_idx, _, _), _ = self._route()
        # Token indices must be sorted ascending (SonicMoE requirement)
        assert (token_idx[:-1] <= token_idx[1:]).all()

    def test_expert_indices_in_range(self):
        (_, _, expert_idx, _), (_, E, _) = self._route()
        assert (expert_idx >= 0).all()
        assert (expert_idx < E).all()

    def test_renormalized_scores_sum_to_one(self):
        """With norm_topk_prob=True each token's K scores add up to 1."""
        (scores, _, _, _), (T, _, K) = self._route()
        totals = scores.reshape(T, K).sum(dim=-1)
        assert torch.allclose(totals, torch.ones(T), atol=1e-5)
class TestSigmoidTopkRouting:
    """Checks for sigmoid_topk_routing on a mock GLM-style block."""

    @staticmethod
    def _route():
        """Route random hidden states through a default mock GLM block.

        Returns:
            ``(routing_outputs, (T, E, K))``.
        """
        moe_block, T, H, E, K = _make_glm_moe_block()
        hidden = torch.randn(T, H)
        return sigmoid_topk_routing(hidden, moe_block), (T, E, K)

    def test_output_shapes(self):
        (scores, token_idx, expert_idx, logits), (T, E, K) = self._route()
        flat_shape = (T * K,)
        assert scores.shape == flat_shape
        assert token_idx.shape == flat_shape
        assert expert_idx.shape == flat_shape
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        (scores, _, _, _), _ = self._route()
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        (_, token_idx, _, _), _ = self._route()
        assert (token_idx[:-1] <= token_idx[1:]).all()

    def test_expert_indices_in_range(self):
        (_, _, expert_idx, _), (_, E, _) = self._route()
        assert (expert_idx >= 0).all()
        assert (expert_idx < E).all()

    def test_scores_are_nonnegative(self):
        """Sigmoid outputs are in [0, 1], so scores should be non-negative."""
        (scores, _, _, _), _ = self._route()
        assert (scores >= 0).all()

    def test_scaling_factor_applied(self):
        """Doubling routed_scaling_factor doubles the scores for the same input."""
        moe_block, T, H, E, K = _make_glm_moe_block()
        hidden = torch.randn(T, H)

        # Baseline with the mock's factor of 1.0, then the same block at 2.0.
        baseline, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
        moe_block.routed_scaling_factor = 2.0
        doubled, _, _, _ = sigmoid_topk_routing(hidden, moe_block)

        assert torch.allclose(doubled, baseline * 2.0, atol=1e-5)

    def test_group_selection_restricts_experts(self):
        """With n_group=4 and topk_group=1, only 1/4 of experts should be selectable."""
        moe_block, T, H, E, K = _make_glm_moe_block(E=16, K=2, n_group=4, topk_group=1)
        hidden = torch.randn(T, H)
        _, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block)

        # Map each chosen expert to its group (group size E // n_group = 4);
        # within a token every chosen expert must share the first one's group.
        group_size = E // moe_block.n_group
        group_ids = expert_idx.reshape(T, K) // group_size
        assert (group_ids == group_ids[:, :1]).all()
class TestMiniMaxM2SigmoidRouting:
    """Tests for minimax_m2 routing: sigmoid->topk without group selection."""

    def test_output_shapes(self):
        """Validates getattr defaults work: n_group=1, E from gate.weight.shape[0]."""
        moe_block, T, H, E, K = _make_minimax_m2_moe_block()
        routed = sigmoid_topk_routing(torch.randn(T, H), moe_block)
        scores, token_idx, expert_idx, logits = routed
        flat_shape = (T * K,)
        assert scores.shape == flat_shape
        assert token_idx.shape == flat_shape
        assert expert_idx.shape == flat_shape
        assert logits.shape == (T, E)

    def test_bias_on_block_not_gate(self):
        """Verify that e_score_correction_bias on the block (not gate) is used."""
        T, H, E, K = 8, 16, 8, 2
        router = SimpleNamespace(weight=torch.randn(E, H), top_k=K)

        # A huge positive bias on expert 0 should force it into every top-k set.
        forced_bias = torch.zeros(E)
        forced_bias[0] = 100.0
        moe_block = SimpleNamespace(
            gate=router, top_k=K, e_score_correction_bias=forced_bias
        )

        hidden = torch.randn(T, H)
        _, _, expert_idx, _ = sigmoid_topk_routing(hidden, moe_block)

        # Every token's row of chosen experts must contain expert 0.
        chosen = expert_idx.reshape(T, K)
        assert (chosen == 0).any(dim=-1).all()

View File

@@ -0,0 +1,158 @@
"""
Gradient correctness tests for SonicMoE routing functions (CPU-only).
Uses torch.autograd.gradcheck with float32 inputs to match the production
code path where routing happens in float32.
"""
import torch
from axolotl.integrations.kernels.sonicmoe.routing import (
sigmoid_topk_routing,
softmax_topk_routing,
)
# Shared torch.autograd.gradcheck settings for every test in this module:
# finite-difference step, absolute tolerance, relative tolerance.
_GC_EPS, _GC_ATOL, _GC_RTOL = 1e-3, 1e-3, 1e-3
def _make_softmax_moe_block(weight):
    """Wrap *weight* in a minimal module tree for softmax_topk_routing.

    The returned module has a single ``gate`` submodule holding ``weight``
    (stored as given, not copied), ``top_k=2`` and ``norm_topk_prob=True``.
    """
    router = torch.nn.Module()
    for attr, value in (("weight", weight), ("top_k", 2), ("norm_topk_prob", True)):
        setattr(router, attr, value)

    block = torch.nn.Module()
    block.gate = router
    return block
def _make_sigmoid_moe_block(weight, bias):
    """Wrap *weight* and *bias* in a minimal module tree for sigmoid_topk_routing.

    ``weight`` and ``bias`` are stored on the ``gate`` submodule as given (not
    copied). The block carries ``top_k=2``, ``n_routed_experts`` taken from
    ``weight.shape[0]``, ``n_group=1``, ``norm_topk_prob=True`` and
    ``routed_scaling_factor=1.0``.
    """
    router = torch.nn.Module()
    router.weight = weight
    router.e_score_correction_bias = bias

    block = torch.nn.Module()
    block.gate = router
    block_settings = {
        "top_k": 2,
        "n_routed_experts": weight.shape[0],
        "n_group": 1,
        "norm_topk_prob": True,
        "routed_scaling_factor": 1.0,
    }
    for attr, value in block_settings.items():
        setattr(block, attr, value)
    return block
class TestSoftmaxTopkRoutingGradcheck:
    """Numerical gradient verification for softmax_topk_routing."""

    @staticmethod
    def _gradcheck(fn, leaf):
        """Run gradcheck for *fn* on a single input using the module settings."""
        torch.autograd.gradcheck(
            fn, (leaf,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_gradcheck_wrt_gate_weight(self):
        """Scores differentiated with respect to the gate weight."""
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)

        def scores_of(w):
            scores, _, _, _ = softmax_topk_routing(hidden, _make_softmax_moe_block(w))
            return scores

        gate_weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
        self._gradcheck(scores_of, gate_weight)

    def test_gradcheck_wrt_hidden_states(self):
        """Scores differentiated with respect to the hidden states."""
        T, H, E = 4, 8, 4
        gate_weight = torch.randn(E, H, dtype=torch.float32)
        moe_block = _make_softmax_moe_block(gate_weight)

        def scores_of(h):
            scores, _, _, _ = softmax_topk_routing(h, moe_block)
            return scores

        hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True)
        self._gradcheck(scores_of, hidden)

    def test_gradcheck_wrt_router_logits(self):
        """The returned router logits differentiated with respect to the gate weight."""
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)

        def logits_of(w):
            _, _, _, router_logits = softmax_topk_routing(
                hidden, _make_softmax_moe_block(w)
            )
            return router_logits

        gate_weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
        self._gradcheck(logits_of, gate_weight)

    def test_no_norm_variant(self):
        """Same as the gate-weight check but with norm_topk_prob switched off."""
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)

        def scores_of(w):
            moe_block = _make_softmax_moe_block(w)
            moe_block.gate.norm_topk_prob = False
            scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
            return scores

        gate_weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
        self._gradcheck(scores_of, gate_weight)
class TestSigmoidTopkRoutingGradcheck:
    """Numerical gradient verification for sigmoid_topk_routing."""

    @staticmethod
    def _gradcheck(fn, leaf):
        """Run gradcheck for *fn* on a single input using the module settings."""
        torch.autograd.gradcheck(
            fn, (leaf,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_gradcheck_wrt_gate_weight(self):
        """Scores differentiated with respect to the gate weight (zero bias)."""
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)
        zero_bias = torch.zeros(E, dtype=torch.float32)

        def scores_of(w):
            scores, _, _, _ = sigmoid_topk_routing(
                hidden, _make_sigmoid_moe_block(w, zero_bias)
            )
            return scores

        gate_weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
        self._gradcheck(scores_of, gate_weight)

    def test_gradcheck_wrt_hidden_states(self):
        """Scores differentiated with respect to the hidden states (zero bias)."""
        T, H, E = 4, 8, 4
        gate_weight = torch.randn(E, H, dtype=torch.float32)
        zero_bias = torch.zeros(E, dtype=torch.float32)
        moe_block = _make_sigmoid_moe_block(gate_weight, zero_bias)

        def scores_of(h):
            scores, _, _, _ = sigmoid_topk_routing(h, moe_block)
            return scores

        hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True)
        self._gradcheck(scores_of, hidden)

    def test_gradcheck_wrt_bias(self):
        """Scores differentiated with respect to e_score_correction_bias."""
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)
        gate_weight = torch.randn(E, H, dtype=torch.float32)

        def scores_of(b):
            scores, _, _, _ = sigmoid_topk_routing(
                hidden, _make_sigmoid_moe_block(gate_weight, b)
            )
            return scores

        correction_bias = torch.zeros(E, dtype=torch.float32, requires_grad=True)
        self._gradcheck(scores_of, correction_bias)