axolotl/tests/integrations/test_sonicmoe_gradients.py
"""
Gradient correctness tests for SonicMoE routing functions (CPU-only).
Uses torch.autograd.gradcheck with float32 inputs to match the production
code path where routing happens in float32.
"""
import torch
from axolotl.integrations.kernels.sonicmoe.routing import (
sigmoid_topk_routing,
softmax_topk_routing,
)
_GC_EPS = 1e-3
_GC_ATOL = 1e-3
_GC_RTOL = 1e-3
def _make_softmax_moe_block(weight):
    """Build a minimal mock MoE block with the gate attributes softmax routing reads."""
    gate = torch.nn.Module()
    gate.weight = weight
    gate.top_k = 2
    gate.norm_topk_prob = True
    moe_block = torch.nn.Module()
    moe_block.gate = gate
    return moe_block
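# For orientation, a rough sketch of the semantics these tests assume for
# softmax_topk_routing (hedged; the authoritative implementation lives in
# axolotl.integrations.kernels.sonicmoe.routing). The tests rely only on the
# first return value (scores) and the last (router_logits):
#
#     router_logits = hidden @ gate.weight.T          # (T, E)
#     probs = router_logits.softmax(dim=-1)
#     scores, topk_idx = probs.topk(gate.top_k, dim=-1)
#     if gate.norm_topk_prob:
#         scores = scores / scores.sum(dim=-1, keepdim=True)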
def _make_sigmoid_moe_block(weight, bias):
    """Build a minimal mock MoE block with the attributes sigmoid routing reads."""
    gate = torch.nn.Module()
    gate.weight = weight
    gate.e_score_correction_bias = bias
    moe_block = torch.nn.Module()
    moe_block.gate = gate
    moe_block.top_k = 2
    moe_block.n_routed_experts = weight.shape[0]
    moe_block.n_group = 1
    moe_block.norm_topk_prob = True
    moe_block.routed_scaling_factor = 1.0
    return moe_block
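# Likewise, an assumed sketch of sigmoid_topk_routing, roughly following
# DeepSeek-V3-style gating (hedged; see the real routing module for the
# authoritative version):
#
#     router_logits = hidden @ gate.weight.T
#     scores_full = router_logits.sigmoid()
#     biased = scores_full + gate.e_score_correction_bias  # bias steers selection only
#     _, topk_idx = biased.topk(moe_block.top_k, dim=-1)
#     scores = scores_full.gather(-1, topk_idx)
#     if moe_block.norm_topk_prob:
#         scores = scores / scores.sum(dim=-1, keepdim=True)
#     scores = scores * moe_block.routed_scaling_factor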
class TestSoftmaxTopkRoutingGradcheck:
    """Numerical gradient verification for softmax_topk_routing."""

    def test_gradcheck_wrt_gate_weight(self):
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)

        def fn(weight):
            moe_block = _make_softmax_moe_block(weight)
            scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
            return scores

        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(
            fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_gradcheck_wrt_hidden_states(self):
        T, H, E = 4, 8, 4
        weight = torch.randn(E, H, dtype=torch.float32)
        moe_block = _make_softmax_moe_block(weight)

        def fn(hidden):
            scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
            return scores

        hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(
            fn, (hidden,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_gradcheck_wrt_router_logits(self):
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)

        def fn(weight):
            moe_block = _make_softmax_moe_block(weight)
            _, _, _, router_logits = softmax_topk_routing(hidden, moe_block)
            return router_logits

        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(
            fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_no_norm_variant(self):
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)

        def fn(weight):
            moe_block = _make_softmax_moe_block(weight)
            moe_block.gate.norm_topk_prob = False
            scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
            return scores

        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(
            fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )
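# The sigmoid tests mirror the softmax ones, with the structural differences
# visible in the mock builders above: top_k and norm_topk_prob live on the
# MoE block rather than on the gate, and the gate carries an
# e_score_correction_bias (plus n_group / routed_scaling_factor on the block),
# matching DeepSeek-style sigmoid routing configurations.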
class TestSigmoidTopkRoutingGradcheck:
    """Numerical gradient verification for sigmoid_topk_routing."""

    def test_gradcheck_wrt_gate_weight(self):
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)
        bias = torch.zeros(E, dtype=torch.float32)

        def fn(weight):
            moe_block = _make_sigmoid_moe_block(weight, bias)
            scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
            return scores

        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(
            fn, (weight,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_gradcheck_wrt_hidden_states(self):
        T, H, E = 4, 8, 4
        weight = torch.randn(E, H, dtype=torch.float32)
        bias = torch.zeros(E, dtype=torch.float32)
        moe_block = _make_sigmoid_moe_block(weight, bias)

        def fn(hidden):
            scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
            return scores

        hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(
            fn, (hidden,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_gradcheck_wrt_bias(self):
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)
        weight = torch.randn(E, H, dtype=torch.float32)

        def fn(bias):
            moe_block = _make_sigmoid_moe_block(weight, bias)
            scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
            return scores

        bias = torch.zeros(E, dtype=torch.float32, requires_grad=True)
        torch.autograd.gradcheck(
            fn, (bias,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )
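# To run just this module locally (assuming a standard pytest setup; adjust
# the path to your checkout):
#   pytest tests/integrations/test_sonicmoe_gradients.py -q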