feat: add sonicmoe fused lora support (#3519)

* feat: add sonicmoe fused lora support

* fix: forgot to add file

* feat: add test

* feat: add lora support for other routes

* fix: add int8 lora support

* fix: add qwen35_moe interleave support

* fix: qwen3_5_moe loss

* chore: lint

* address some pr comments

* fix test imports

* add support matrix for moe kernels [skip ci]

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
NanoCode012
2026-04-02 19:53:48 +07:00
committed by GitHub
parent 16e32232fb
commit 842fa039dd
16 changed files with 1249 additions and 126 deletions

View File

@@ -6,11 +6,11 @@ import pytest
import torch
from axolotl.integrations.kernels.args import KernelsArgs
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
sigmoid_topk_routing,
softmax_topk_routing,
)
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
ConcatenatedToInterleaved,
InterleavedToConcatenated,
register_sonicmoe_weight_converter,
@@ -212,9 +212,40 @@ class TestWeightConverterRegistration:
)
break
def test_register_unsupported_model_type_warns(self):
# A model type with no conversion mapping should warn but not raise
register_sonicmoe_weight_converter("nonexistent_model_type_xyz")
def test_register_adds_same_key_converter(self):
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
register_sonicmoe_weight_converter("qwen3_moe")
modified = get_checkpoint_conversion_mapping("qwen3_moe")
# Should have a same-key converter for already-fused checkpoints
same_key = [
c
for c in modified
if hasattr(c, "source_patterns")
and c.source_patterns == ["mlp.experts.gate_up_proj"]
and c.target_patterns == ["mlp.experts.gate_up_proj"]
]
assert len(same_key) == 1
assert isinstance(same_key[0].operations[0], ConcatenatedToInterleaved)
def test_register_creates_mapping_when_none(self):
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
# qwen3_5_moe has no conversion mapping in transformers
register_sonicmoe_weight_converter("qwen3_5_moe")
mapping = get_checkpoint_conversion_mapping("qwen3_5_moe")
assert mapping is not None
same_key = [
c
for c in mapping
if hasattr(c, "source_patterns")
and c.source_patterns == ["mlp.experts.gate_up_proj"]
and c.target_patterns == ["mlp.experts.gate_up_proj"]
]
assert len(same_key) == 1
assert isinstance(same_key[0].operations[0], ConcatenatedToInterleaved)
def _make_qwen_moe_block(T=8, H=16, E=4, K=2):
@@ -462,7 +493,7 @@ class TestSoftmaxBiasTopkRouting:
"""Tests for Ernie 4.5 MoE routing (softmax_bias_topk_routing)."""
def test_output_shapes(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_bias_topk_routing,
)
@@ -479,7 +510,7 @@ class TestSoftmaxBiasTopkRouting:
assert logits.shape == (T, E)
def test_scores_are_float32(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_bias_topk_routing,
)
@@ -490,7 +521,7 @@ class TestSoftmaxBiasTopkRouting:
assert scores.dtype == torch.float32
def test_token_indices_sorted_ascending(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_bias_topk_routing,
)
@@ -502,7 +533,7 @@ class TestSoftmaxBiasTopkRouting:
assert (diffs >= 0).all()
def test_expert_indices_in_range(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_bias_topk_routing,
)
@@ -514,7 +545,7 @@ class TestSoftmaxBiasTopkRouting:
assert (expert_idx < E).all()
def test_renormalized_scores_sum_to_one(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_bias_topk_routing,
)
@@ -527,7 +558,7 @@ class TestSoftmaxBiasTopkRouting:
def test_bias_affects_expert_selection(self):
"""Large positive bias on expert 0 should make it always selected."""
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_bias_topk_routing,
)
@@ -570,7 +601,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
"""Tests for DeepSeek V2 routing (softmax_group_limited_topk_routing)."""
def test_output_shapes_group_limited(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -589,7 +620,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
assert logits.shape == (T, E)
def test_output_shapes_greedy(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -604,7 +635,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
assert logits.shape == (T, E)
def test_scores_are_float32(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -615,7 +646,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
assert scores.dtype == torch.float32
def test_token_indices_sorted_ascending(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -627,7 +658,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
assert (diffs >= 0).all()
def test_expert_indices_in_range(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -639,7 +670,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
assert (expert_idx < E).all()
def test_scaling_factor_applied(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -655,7 +686,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
def test_group_selection_restricts_experts(self):
"""With num_group=4 and topk_group=1, experts should come from selected groups."""
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -674,7 +705,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
assert (groups == groups[0]).all()
def test_unsupported_topk_method_raises(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_group_limited_topk_routing,
)
@@ -706,7 +737,7 @@ class TestSoftmaxTopkWgRouting:
"""Tests for HunYuan V1 MoE routing (softmax_topk_wg_routing)."""
def test_output_shapes(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_wg_routing,
)
@@ -723,7 +754,7 @@ class TestSoftmaxTopkWgRouting:
assert logits.shape == (T, E)
def test_scores_are_float32(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_wg_routing,
)
@@ -734,7 +765,7 @@ class TestSoftmaxTopkWgRouting:
assert scores.dtype == torch.float32
def test_token_indices_sorted_ascending(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_wg_routing,
)
@@ -746,7 +777,7 @@ class TestSoftmaxTopkWgRouting:
assert (diffs >= 0).all()
def test_expert_indices_in_range(self):
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_wg_routing,
)
@@ -759,7 +790,7 @@ class TestSoftmaxTopkWgRouting:
def test_renormalized_scores_sum_to_one(self):
"""HunYuan V1 always renormalizes (no norm_topk_prob flag)."""
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_wg_routing,
)
@@ -772,7 +803,7 @@ class TestSoftmaxTopkWgRouting:
def test_uses_gate_wg_weight(self):
"""Verify that modifying gate.wg.weight changes the routing output."""
from axolotl.integrations.kernels.sonicmoe.routing import (
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
softmax_topk_wg_routing,
)