feat: add sonicmoe fused lora support (#3519)
* feat: add sonicmoe fused lora support * fix: forgot to add file * feat: add test * feat: add lora support for other routes * fix: add int8 lora support * fix: add qwen35_moe interleave support * fix: qwen3_5_moe loss * chore: lint * address some pr comments * fix test imports * add support matrix for moe kernels [skip ci] --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -6,11 +6,11 @@ import pytest
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.kernels.args import KernelsArgs
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
sigmoid_topk_routing,
|
||||
softmax_topk_routing,
|
||||
)
|
||||
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
|
||||
ConcatenatedToInterleaved,
|
||||
InterleavedToConcatenated,
|
||||
register_sonicmoe_weight_converter,
|
||||
@@ -212,9 +212,40 @@ class TestWeightConverterRegistration:
|
||||
)
|
||||
break
|
||||
|
||||
def test_register_unsupported_model_type_warns(self):
|
||||
# A model type with no conversion mapping should warn but not raise
|
||||
register_sonicmoe_weight_converter("nonexistent_model_type_xyz")
|
||||
def test_register_adds_same_key_converter(self):
|
||||
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
|
||||
|
||||
register_sonicmoe_weight_converter("qwen3_moe")
|
||||
|
||||
modified = get_checkpoint_conversion_mapping("qwen3_moe")
|
||||
# Should have a same-key converter for already-fused checkpoints
|
||||
same_key = [
|
||||
c
|
||||
for c in modified
|
||||
if hasattr(c, "source_patterns")
|
||||
and c.source_patterns == ["mlp.experts.gate_up_proj"]
|
||||
and c.target_patterns == ["mlp.experts.gate_up_proj"]
|
||||
]
|
||||
assert len(same_key) == 1
|
||||
assert isinstance(same_key[0].operations[0], ConcatenatedToInterleaved)
|
||||
|
||||
def test_register_creates_mapping_when_none(self):
|
||||
from transformers.conversion_mapping import get_checkpoint_conversion_mapping
|
||||
|
||||
# qwen3_5_moe has no conversion mapping in transformers
|
||||
register_sonicmoe_weight_converter("qwen3_5_moe")
|
||||
|
||||
mapping = get_checkpoint_conversion_mapping("qwen3_5_moe")
|
||||
assert mapping is not None
|
||||
same_key = [
|
||||
c
|
||||
for c in mapping
|
||||
if hasattr(c, "source_patterns")
|
||||
and c.source_patterns == ["mlp.experts.gate_up_proj"]
|
||||
and c.target_patterns == ["mlp.experts.gate_up_proj"]
|
||||
]
|
||||
assert len(same_key) == 1
|
||||
assert isinstance(same_key[0].operations[0], ConcatenatedToInterleaved)
|
||||
|
||||
|
||||
def _make_qwen_moe_block(T=8, H=16, E=4, K=2):
|
||||
@@ -462,7 +493,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
"""Tests for Ernie 4.5 MoE routing (softmax_bias_topk_routing)."""
|
||||
|
||||
def test_output_shapes(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -479,7 +510,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
assert logits.shape == (T, E)
|
||||
|
||||
def test_scores_are_float32(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -490,7 +521,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
assert scores.dtype == torch.float32
|
||||
|
||||
def test_token_indices_sorted_ascending(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -502,7 +533,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
assert (diffs >= 0).all()
|
||||
|
||||
def test_expert_indices_in_range(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -514,7 +545,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
assert (expert_idx < E).all()
|
||||
|
||||
def test_renormalized_scores_sum_to_one(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -527,7 +558,7 @@ class TestSoftmaxBiasTopkRouting:
|
||||
|
||||
def test_bias_affects_expert_selection(self):
|
||||
"""Large positive bias on expert 0 should make it always selected."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_bias_topk_routing,
|
||||
)
|
||||
|
||||
@@ -570,7 +601,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
"""Tests for DeepSeek V2 routing (softmax_group_limited_topk_routing)."""
|
||||
|
||||
def test_output_shapes_group_limited(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -589,7 +620,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
assert logits.shape == (T, E)
|
||||
|
||||
def test_output_shapes_greedy(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -604,7 +635,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
assert logits.shape == (T, E)
|
||||
|
||||
def test_scores_are_float32(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -615,7 +646,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
assert scores.dtype == torch.float32
|
||||
|
||||
def test_token_indices_sorted_ascending(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -627,7 +658,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
assert (diffs >= 0).all()
|
||||
|
||||
def test_expert_indices_in_range(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -639,7 +670,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
assert (expert_idx < E).all()
|
||||
|
||||
def test_scaling_factor_applied(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -655,7 +686,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
|
||||
def test_group_selection_restricts_experts(self):
|
||||
"""With num_group=4 and topk_group=1, experts should come from selected groups."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -674,7 +705,7 @@ class TestSoftmaxGroupLimitedTopkRouting:
|
||||
assert (groups == groups[0]).all()
|
||||
|
||||
def test_unsupported_topk_method_raises(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_group_limited_topk_routing,
|
||||
)
|
||||
|
||||
@@ -706,7 +737,7 @@ class TestSoftmaxTopkWgRouting:
|
||||
"""Tests for HunYuan V1 MoE routing (softmax_topk_wg_routing)."""
|
||||
|
||||
def test_output_shapes(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_wg_routing,
|
||||
)
|
||||
|
||||
@@ -723,7 +754,7 @@ class TestSoftmaxTopkWgRouting:
|
||||
assert logits.shape == (T, E)
|
||||
|
||||
def test_scores_are_float32(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_wg_routing,
|
||||
)
|
||||
|
||||
@@ -734,7 +765,7 @@ class TestSoftmaxTopkWgRouting:
|
||||
assert scores.dtype == torch.float32
|
||||
|
||||
def test_token_indices_sorted_ascending(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_wg_routing,
|
||||
)
|
||||
|
||||
@@ -746,7 +777,7 @@ class TestSoftmaxTopkWgRouting:
|
||||
assert (diffs >= 0).all()
|
||||
|
||||
def test_expert_indices_in_range(self):
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_wg_routing,
|
||||
)
|
||||
|
||||
@@ -759,7 +790,7 @@ class TestSoftmaxTopkWgRouting:
|
||||
|
||||
def test_renormalized_scores_sum_to_one(self):
|
||||
"""HunYuan V1 always renormalizes (no norm_topk_prob flag)."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_wg_routing,
|
||||
)
|
||||
|
||||
@@ -772,7 +803,7 @@ class TestSoftmaxTopkWgRouting:
|
||||
|
||||
def test_uses_gate_wg_weight(self):
|
||||
"""Verify that modifying gate.wg.weight changes the routing output."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import (
|
||||
from axolotl.integrations.kernels.libs.sonicmoe.routing import (
|
||||
softmax_topk_wg_routing,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user