feat: add sonicmoe fused lora support (#3519)

* feat: add sonicmoe fused lora support

* fix: forgot to add file

* feat: add test

* feat: add lora support for other routes

* fix: add int8 lora support

* fix: add qwen35_moe interleave support

* fix: qwen3_5_moe loss

* chore: lint

* address some PR comments

* fix test imports

* add support matrix for moe kernels [skip ci]

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
NanoCode012
2026-04-02 19:53:48 +07:00
committed by GitHub
parent 16e32232fb
commit 842fa039dd
16 changed files with 1249 additions and 126 deletions

View File

@@ -93,7 +93,9 @@ class TestSoftmaxRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
-from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
+from axolotl.integrations.kernels.libs.sonicmoe.routing import (
+softmax_topk_routing,
+)
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
@@ -120,7 +122,9 @@ class TestSoftmaxRoutingParity:
def test_logits_not_returned_by_scattermoe(self):
"""ScatterMoE doesn't return logits; SonicMoE does — verify SonicMoE logits shape."""
-from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
+from axolotl.integrations.kernels.libs.sonicmoe.routing import (
+softmax_topk_routing,
+)
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
_, _, _, logits = softmax_topk_routing(hidden, moe_block)
@@ -131,7 +135,9 @@ class TestSoftmaxRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
-from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
+from axolotl.integrations.kernels.libs.sonicmoe.routing import (
+softmax_topk_routing,
+)
moe_block, gate, hidden, T, H, E, K = _make_softmax_block()
gate.norm_topk_prob = False
@@ -152,7 +158,9 @@ class TestSoftmaxRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
-from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
+from axolotl.integrations.kernels.libs.sonicmoe.routing import (
+softmax_topk_routing,
+)
for E, K in [(2, 1), (8, 2), (16, 4), (32, 8)]:
moe_block, gate, hidden, T, H, _, _ = _make_softmax_block(E=E, K=K)
@@ -190,7 +198,9 @@ class TestSigmoidRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
-from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
+from axolotl.integrations.kernels.libs.sonicmoe.routing import (
+sigmoid_topk_routing,
+)
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
E=16, K=4, n_group=2, topk_group=1, bias_on_gate=True
@@ -226,7 +236,9 @@ class TestSigmoidRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
-from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
+from axolotl.integrations.kernels.libs.sonicmoe.routing import (
+sigmoid_topk_routing,
+)
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
E=16, K=4, n_group=1, topk_group=1, bias_on_gate=True
@@ -254,7 +266,9 @@ class TestSigmoidRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
-from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
+from axolotl.integrations.kernels.libs.sonicmoe.routing import (
+sigmoid_topk_routing,
+)
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
E=16, K=4, n_group=1, bias_on_gate=False
@@ -281,7 +295,9 @@ class TestSigmoidRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
-from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
+from axolotl.integrations.kernels.libs.sonicmoe.routing import (
+sigmoid_topk_routing,
+)
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
n_group=1, bias_on_gate=True
@@ -309,7 +325,9 @@ class TestSigmoidRoutingParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_sigmoid_topk_route,
)
-from axolotl.integrations.kernels.sonicmoe.routing import sigmoid_topk_routing
+from axolotl.integrations.kernels.libs.sonicmoe.routing import (
+sigmoid_topk_routing,
+)
moe_block, gate, hidden, T, H, E, K = _make_sigmoid_block(
n_group=1, bias_on_gate=True
@@ -349,7 +367,7 @@ class TestSharedExpertParity:
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_compute_shared_expert as scatter_compute,
)
-from axolotl.integrations.kernels.sonicmoe.patch import (
+from axolotl.integrations.kernels.libs.sonicmoe.patch import (
_compute_shared_expert as sonic_compute,
)