Files
axolotl/tests/integrations/test_sonicmoe_gradients.py
NanoCode012 842fa039dd feat: add sonicmoe fused lora support (#3519)
* feat: add sonicmoe fused lora support

* fix: forgot to add file

* feat: add test

* feat: add lora support for other routes

* fix: add int8 lora support

* fix: add qwen35_moe interleave support

* fix: qwen3_5_moe loss

* chore: lint

* address some pr comments

* fix test imports

* add support matrix for moe kernels [skip ci]

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-02 08:53:48 -04:00

159 lines
4.8 KiB
Python

"""
Gradient correctness tests for SonicMoE routing functions (CPU-only).
Uses torch.autograd.gradcheck with float32 inputs to match the production
code path where routing happens in float32.
"""
import torch
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
sigmoid_topk_routing,
softmax_topk_routing,
)
# Finite-difference step (eps) and absolute/relative tolerances passed to every
# torch.autograd.gradcheck call in this module. eps and atol are larger than
# gradcheck's own defaults (1e-6 and 1e-5), presumably because the inputs here
# are float32 rather than float64.
_GC_EPS, _GC_ATOL, _GC_RTOL = 1e-3, 1e-3, 1e-3
def _make_softmax_moe_block(weight, *, top_k=2, norm_topk_prob=True):
    """Build a minimal stand-in MoE block for ``softmax_topk_routing`` tests.

    The routing settings live on the ``gate`` submodule as plain attributes.

    Args:
        weight: Router weight tensor. It is stored as-is on ``gate.weight``
            (the same tensor object, not copied), so a tensor created with
            ``requires_grad=True`` stays the leaf that autograd tracks.
        top_k: Value stored on ``gate.top_k``. Defaults to 2.
        norm_topk_prob: Value stored on ``gate.norm_topk_prob``. Defaults to
            True.

    Returns:
        A bare ``torch.nn.Module`` whose ``gate`` attribute is the configured
        gate module.
    """
    gate = torch.nn.Module()
    gate.weight = weight
    gate.top_k = top_k
    gate.norm_topk_prob = norm_topk_prob

    moe_block = torch.nn.Module()
    moe_block.gate = gate
    return moe_block
def _make_sigmoid_moe_block(
    weight,
    bias,
    *,
    top_k=2,
    n_group=1,
    norm_topk_prob=True,
    routed_scaling_factor=1.0,
):
    """Build a minimal stand-in MoE block for ``sigmoid_topk_routing`` tests.

    Unlike the softmax helper, the routing settings live on the block itself;
    the ``gate`` submodule only carries the weight and the correction bias.

    Args:
        weight: Router weight tensor, stored as-is on ``gate.weight``. Its
            first dimension is also recorded as ``n_routed_experts``.
        bias: Tensor stored as-is on ``gate.e_score_correction_bias``.
        top_k: Value stored on ``moe_block.top_k``. Defaults to 2.
        n_group: Value stored on ``moe_block.n_group``. Defaults to 1.
        norm_topk_prob: Value stored on ``moe_block.norm_topk_prob``.
            Defaults to True.
        routed_scaling_factor: Value stored on
            ``moe_block.routed_scaling_factor``. Defaults to 1.0.

    Returns:
        A bare ``torch.nn.Module`` with the ``gate`` submodule and the routing
        attributes listed above.
    """
    gate = torch.nn.Module()
    gate.weight = weight
    gate.e_score_correction_bias = bias

    moe_block = torch.nn.Module()
    moe_block.gate = gate
    moe_block.top_k = top_k
    moe_block.n_routed_experts = weight.shape[0]
    moe_block.n_group = n_group
    moe_block.norm_topk_prob = norm_topk_prob
    moe_block.routed_scaling_factor = routed_scaling_factor
    return moe_block
class TestSoftmaxTopkRoutingGradcheck:
    """Numerical gradient verification for softmax_topk_routing.

    Each test compares autograd's analytical Jacobian with a finite-difference
    estimate via ``torch.autograd.gradcheck``, using float32 tensors and the
    module-level ``_GC_*`` step size and tolerances.

    Shapes used throughout: ``hidden`` is ``(T, H)`` and the gate weight is
    ``(E, H)`` (presumably tokens, hidden size and experts).
    """

    def setup_method(self):
        """Seed torch's global RNG before each test.

        The inputs are drawn with ``torch.randn``; a fixed seed makes every
        run use the same tensors, so a gradcheck failure can be reproduced.
        """
        torch.manual_seed(0)

    @staticmethod
    def _run_gradcheck(fn, inp):
        """Run gradcheck on ``fn`` w.r.t. the single tensor ``inp``.

        gradcheck raises on a mismatch by default, so no assert is needed.
        """
        torch.autograd.gradcheck(
            fn, (inp,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_gradcheck_wrt_gate_weight(self):
        """Check the first routing output (scores) w.r.t. the gate weight."""
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)
        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)

        def fn(w):
            # Rebuild the block per call so the routing sees the tensor that
            # gradcheck is currently perturbing.
            moe_block = _make_softmax_moe_block(w)
            scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
            return scores

        self._run_gradcheck(fn, weight)

    def test_gradcheck_wrt_hidden_states(self):
        """Check the first routing output (scores) w.r.t. the hidden states."""
        T, H, E = 4, 8, 4
        weight = torch.randn(E, H, dtype=torch.float32)
        moe_block = _make_softmax_moe_block(weight)
        hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True)

        def fn(h):
            scores, _, _, _ = softmax_topk_routing(h, moe_block)
            return scores

        self._run_gradcheck(fn, hidden)

    def test_gradcheck_wrt_router_logits(self):
        """Check the fourth routing output (router logits) w.r.t. the gate weight.

        Despite the name, the differentiated input is the gate weight; the
        router logits are the output under test.
        """
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)
        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)

        def fn(w):
            moe_block = _make_softmax_moe_block(w)
            _, _, _, router_logits = softmax_topk_routing(hidden, moe_block)
            return router_logits

        self._run_gradcheck(fn, weight)

    def test_no_norm_variant(self):
        """Check scores w.r.t. the gate weight with ``norm_topk_prob`` disabled."""
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)
        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)

        def fn(w):
            moe_block = _make_softmax_moe_block(w)
            # The helper builds the gate with norm_topk_prob=True; flip it here.
            moe_block.gate.norm_topk_prob = False
            scores, _, _, _ = softmax_topk_routing(hidden, moe_block)
            return scores

        self._run_gradcheck(fn, weight)
class TestSigmoidTopkRoutingGradcheck:
    """Numerical gradient verification for sigmoid_topk_routing.

    Each test compares autograd's analytical Jacobian with a finite-difference
    estimate via ``torch.autograd.gradcheck``, using float32 tensors and the
    module-level ``_GC_*`` step size and tolerances.

    Shapes used throughout: ``hidden`` is ``(T, H)``, the gate weight is
    ``(E, H)`` and the correction bias is ``(E,)``, where ``E`` is what the
    helper records as ``n_routed_experts``.
    """

    def setup_method(self):
        """Seed torch's global RNG before each test.

        The inputs are drawn with ``torch.randn``; a fixed seed makes every
        run use the same tensors, so a gradcheck failure can be reproduced.
        """
        torch.manual_seed(0)

    @staticmethod
    def _run_gradcheck(fn, inp):
        """Run gradcheck on ``fn`` w.r.t. the single tensor ``inp``.

        gradcheck raises on a mismatch by default, so no assert is needed.
        """
        torch.autograd.gradcheck(
            fn, (inp,), eps=_GC_EPS, atol=_GC_ATOL, rtol=_GC_RTOL
        )

    def test_gradcheck_wrt_gate_weight(self):
        """Check the first routing output (scores) w.r.t. the gate weight."""
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)
        bias = torch.zeros(E, dtype=torch.float32)
        weight = torch.randn(E, H, dtype=torch.float32, requires_grad=True)

        def fn(w):
            # Rebuild the block per call so the routing sees the tensor that
            # gradcheck is currently perturbing.
            moe_block = _make_sigmoid_moe_block(w, bias)
            scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
            return scores

        self._run_gradcheck(fn, weight)

    def test_gradcheck_wrt_hidden_states(self):
        """Check the first routing output (scores) w.r.t. the hidden states."""
        T, H, E = 4, 8, 4
        weight = torch.randn(E, H, dtype=torch.float32)
        bias = torch.zeros(E, dtype=torch.float32)
        moe_block = _make_sigmoid_moe_block(weight, bias)
        hidden = torch.randn(T, H, dtype=torch.float32, requires_grad=True)

        def fn(h):
            scores, _, _, _ = sigmoid_topk_routing(h, moe_block)
            return scores

        self._run_gradcheck(fn, hidden)

    def test_gradcheck_wrt_bias(self):
        """Check the first routing output (scores) w.r.t. the correction bias.

        NOTE(review): the bias starts at zeros. Whether the scores depend on it
        differentiably is decided inside sigmoid_topk_routing, which is not
        visible from this file.
        """
        T, H, E = 4, 8, 4
        hidden = torch.randn(T, H, dtype=torch.float32)
        weight = torch.randn(E, H, dtype=torch.float32)
        bias = torch.zeros(E, dtype=torch.float32, requires_grad=True)

        def fn(b):
            moe_block = _make_sigmoid_moe_block(weight, b)
            scores, _, _, _ = sigmoid_topk_routing(hidden, moe_block)
            return scores

        self._run_gradcheck(fn, bias)