feat: add custom routing support for ernie4_5_moe, and hunyuan_v1_moe (#3526)
* feat: add Ernie 4.5 and subsequently custom routing support
* Update routing.py
* chore: lint
* fix minor nits
* removed deepseek v2
* remove unneeded change

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -426,3 +426,363 @@ class TestMiniMaxM2SigmoidRouting:
|
||||
expert_idx_2d = expert_idx.reshape(T, K)
|
||||
for t in range(T):
|
||||
assert 0 in expert_idx_2d[t]
|
||||
|
||||
|
||||
# ============================================================================
# Ernie 4.5 MoE: softmax -> bias correction -> topk
# ============================================================================
|
||||
|
||||
|
||||
def _make_ernie_moe_block(T=8, H=16, E=8, K=2, norm_min=1e-20):
    """Create a mock Ernie 4.5 MoE block for routing tests.

    Ernie 4.5 uses a gate.moe_statics module that adds bias to softmax probs
    before topk selection, then gathers from original probs.

    Args:
        T: Number of tokens; only echoed back so tests can size their inputs.
        H: Hidden size (second dimension of the gate weight).
        E: Number of experts (first dimension of the gate weight, bias length).
        K: Value stored as ``gate.top_k``.
        norm_min: Value stored as ``gate.norm_min``.

    Returns:
        Tuple ``(moe_block, bias, T, H, E, K)``. ``bias`` is the very tensor
        held by ``moe_block.gate.moe_statics``, so mutating it in place changes
        what the mock statics module adds to the probabilities.
    """
    # Zero bias by default: the statics module is then an identity on probs.
    bias = torch.zeros(E)

    class MockMoeStatics:
        """Callable stand-in that adds a per-expert correction bias."""

        def __init__(self, bias_tensor):
            # Keep a reference (not a copy) so callers can mutate the bias.
            self.e_score_correction_bias = bias_tensor

        def __call__(self, probs):
            return probs + self.e_score_correction_bias

    gate = SimpleNamespace(
        weight=torch.randn(E, H),
        top_k=K,
        moe_statics=MockMoeStatics(bias),
        norm_min=norm_min,
    )
    moe_block = SimpleNamespace(gate=gate)
    return moe_block, bias, T, H, E, K
|
||||
|
||||
|
||||
class TestSoftmaxBiasTopkRouting:
    """Tests for Ernie 4.5 MoE routing (softmax_bias_topk_routing).

    Each test builds a fresh mock block via ``_make_ernie_moe_block`` and
    random hidden states, then checks one property of the routing outputs.
    The routing function is imported inside every test body.
    """

    def test_output_shapes(self):
        """Flattened outputs have T * K entries; logits are (T, E)."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, _, T, H, E, K = _make_ernie_moe_block()
        hidden = torch.randn(T, H)

        scores, token_idx, expert_idx, logits = softmax_bias_topk_routing(
            hidden, moe_block
        )

        assert scores.shape == (T * K,)
        assert token_idx.shape == (T * K,)
        assert expert_idx.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        """Routing scores are returned as float32."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, _, T, H, E, K = _make_ernie_moe_block()
        hidden = torch.randn(T, H)

        scores, _, _, _ = softmax_bias_topk_routing(hidden, moe_block)
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        """Token indices never decrease from one entry to the next."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, _, T, H, E, K = _make_ernie_moe_block()
        hidden = torch.randn(T, H)

        _, token_idx, _, _ = softmax_bias_topk_routing(hidden, moe_block)
        diffs = token_idx[1:] - token_idx[:-1]
        assert (diffs >= 0).all()

    def test_expert_indices_in_range(self):
        """Every selected expert index lies in [0, E)."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, _, T, H, E, K = _make_ernie_moe_block()
        hidden = torch.randn(T, H)

        _, _, expert_idx, _ = softmax_bias_topk_routing(hidden, moe_block)
        assert (expert_idx >= 0).all()
        assert (expert_idx < E).all()

    def test_renormalized_scores_sum_to_one(self):
        """The K scores of each token sum to 1 (within 1e-5)."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, _, T, H, E, K = _make_ernie_moe_block()
        hidden = torch.randn(T, H)

        scores, _, _, _ = softmax_bias_topk_routing(hidden, moe_block)
        per_token_sums = scores.reshape(T, K).sum(dim=-1)
        assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5)

    def test_bias_affects_expert_selection(self):
        """Large positive bias on expert 0 should make it always selected."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, bias, T, H, E, K = _make_ernie_moe_block()
        bias[0] = 100.0  # mutate the bias to strongly favor expert 0
        hidden = torch.randn(T, H)

        _, _, expert_idx, _ = softmax_bias_topk_routing(hidden, moe_block)
        expert_idx_2d = expert_idx.reshape(T, K)
        # Vectorized form of "0 appears among the K experts of every token".
        assert (expert_idx_2d == 0).any(dim=-1).all()
|
||||
|
||||
|
||||
# ============================================================================
# DeepSeek V2: softmax -> group_limited_greedy / greedy -> topk
# ============================================================================
|
||||
|
||||
|
||||
def _make_deepseek_v2_moe_block(
    T=8,
    H=16,
    E=16,
    K=4,
    num_group=2,
    topk_group=1,
    topk_method="group_limited_greedy",
    routed_scaling_factor=1.0,
):
    """Create a mock DeepSeek V2 MoE block for routing tests.

    DeepSeek V2 uses num_group (not n_group), gate is nn.Linear,
    and supports greedy / group_limited_greedy topk methods.

    Args:
        T: Number of tokens; only echoed back so tests can size their inputs.
        H: Hidden size (second dimension of the gate weight).
        E: Number of experts (first dimension of the gate weight).
        K: Value stored as ``moe_block.top_k``.
        num_group: Value stored as ``moe_block.num_group``.
        topk_group: Value stored as ``moe_block.topk_group``.
        topk_method: Value stored as ``moe_block.topk_method``.
        routed_scaling_factor: Value stored as
            ``moe_block.routed_scaling_factor``; defaults to the previously
            hard-coded 1.0.

    Returns:
        Tuple ``(moe_block, T, H, E, K)``.
    """
    # Only ``weight`` is mocked on the gate; routing settings live on the block.
    gate = SimpleNamespace(weight=torch.randn(E, H))
    moe_block = SimpleNamespace(
        gate=gate,
        top_k=K,
        num_group=num_group,
        topk_group=topk_group,
        topk_method=topk_method,
        routed_scaling_factor=routed_scaling_factor,
    )
    return moe_block, T, H, E, K
|
||||
|
||||
|
||||
class TestSoftmaxGroupLimitedTopkRouting:
    """Tests for DeepSeek V2 routing (softmax_group_limited_topk_routing).

    Each test builds a fresh mock block via ``_make_deepseek_v2_moe_block``
    and random hidden states. The routing function is imported inside every
    test body.
    """

    def test_output_shapes_group_limited(self):
        """With group_limited_greedy, outputs are (T * K,) and logits (T, E)."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block(
            topk_method="group_limited_greedy"
        )
        hidden = torch.randn(T, H)

        scores, token_idx, expert_idx, logits = softmax_group_limited_topk_routing(
            hidden, moe_block
        )

        assert scores.shape == (T * K,)
        assert token_idx.shape == (T * K,)
        assert expert_idx.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_output_shapes_greedy(self):
        """With greedy, scores are (T * K,) and logits are (T, E)."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block(topk_method="greedy")
        hidden = torch.randn(T, H)

        scores, token_idx, expert_idx, logits = softmax_group_limited_topk_routing(
            hidden, moe_block
        )

        assert scores.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        """Routing scores are returned as float32."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block()
        hidden = torch.randn(T, H)

        scores, _, _, _ = softmax_group_limited_topk_routing(hidden, moe_block)
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        """Token indices never decrease from one entry to the next."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block()
        hidden = torch.randn(T, H)

        _, token_idx, _, _ = softmax_group_limited_topk_routing(hidden, moe_block)
        diffs = token_idx[1:] - token_idx[:-1]
        assert (diffs >= 0).all()

    def test_expert_indices_in_range(self):
        """Every selected expert index lies in [0, E)."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block()
        hidden = torch.randn(T, H)

        _, _, expert_idx, _ = softmax_group_limited_topk_routing(hidden, moe_block)
        assert (expert_idx >= 0).all()
        assert (expert_idx < E).all()

    def test_scaling_factor_applied(self):
        """Raising routed_scaling_factor from 1.0 to 2.5 scales scores by 2.5."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block(topk_method="greedy")
        hidden = torch.randn(T, H)

        scores_1x, _, _, _ = softmax_group_limited_topk_routing(hidden, moe_block)

        # Mutate the same block so both calls share identical gate weights.
        moe_block.routed_scaling_factor = 2.5
        scores_2x, _, _, _ = softmax_group_limited_topk_routing(hidden, moe_block)

        assert torch.allclose(scores_2x, scores_1x * 2.5, atol=1e-5)

    def test_group_selection_restricts_experts(self):
        """With num_group=4 and topk_group=1, experts should come from selected groups."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block(
            E=16, K=2, num_group=4, topk_group=1, topk_method="group_limited_greedy"
        )
        hidden = torch.randn(T, H)

        _, _, expert_idx, _ = softmax_group_limited_topk_routing(hidden, moe_block)
        expert_idx_2d = expert_idx.reshape(T, K)
        group_size = E // moe_block.num_group
        groups = expert_idx_2d // group_size
        # All selected experts of a token should be from the same group:
        # compare every column against the token's first selected group.
        assert (groups == groups[:, :1]).all()

    def test_unsupported_topk_method_raises(self):
        """An unknown topk_method raises ValueError with a matching message."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block(topk_method="invalid")
        hidden = torch.randn(T, H)

        with pytest.raises(ValueError, match="unsupported topk_method"):
            softmax_group_limited_topk_routing(hidden, moe_block)
|
||||
|
||||
|
||||
# ============================================================================
# HunYuan V1 MoE: softmax -> topk -> renorm (via gate.wg)
# ============================================================================
|
||||
|
||||
|
||||
def _make_hunyuan_moe_block(T=8, H=16, E=8, K=2):
    """Create a mock HunYuan V1 MoE block for routing tests.

    HunYuan V1 uses gate.wg (nn.Linear-like) instead of gate.weight,
    and top_k on the moe_block instead of the gate.

    Args:
        T: Number of tokens; only echoed back so tests can size their inputs.
        H: Hidden size (second dimension of ``gate.wg.weight``).
        E: Number of experts (first dimension of ``gate.wg.weight``).
        K: Value stored as ``moe_block.top_k``.

    Returns:
        Tuple ``(moe_block, T, H, E, K)``.
    """
    # The router weight sits one level deeper than in the other mocks.
    wg = SimpleNamespace(weight=torch.randn(E, H))
    gate = SimpleNamespace(wg=wg)
    moe_block = SimpleNamespace(gate=gate, top_k=K)
    return moe_block, T, H, E, K
|
||||
|
||||
|
||||
class TestSoftmaxTopkWgRouting:
    """Tests for HunYuan V1 MoE routing (softmax_topk_wg_routing).

    Each test builds a fresh mock block via ``_make_hunyuan_moe_block`` and
    random hidden states. The routing function is imported inside every test
    body.
    """

    def test_output_shapes(self):
        """Flattened outputs have T * K entries; logits are (T, E)."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)

        scores, token_idx, expert_idx, logits = softmax_topk_wg_routing(
            hidden, moe_block
        )

        assert scores.shape == (T * K,)
        assert token_idx.shape == (T * K,)
        assert expert_idx.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        """Routing scores are returned as float32."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)

        scores, _, _, _ = softmax_topk_wg_routing(hidden, moe_block)
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        """Token indices never decrease from one entry to the next."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)

        _, token_idx, _, _ = softmax_topk_wg_routing(hidden, moe_block)
        diffs = token_idx[1:] - token_idx[:-1]
        assert (diffs >= 0).all()

    def test_expert_indices_in_range(self):
        """Every selected expert index lies in [0, E)."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)

        _, _, expert_idx, _ = softmax_topk_wg_routing(hidden, moe_block)
        assert (expert_idx >= 0).all()
        assert (expert_idx < E).all()

    def test_renormalized_scores_sum_to_one(self):
        """HunYuan V1 always renormalizes (no norm_topk_prob flag)."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)

        scores, _, _, _ = softmax_topk_wg_routing(hidden, moe_block)
        per_token_sums = scores.reshape(T, K).sum(dim=-1)
        assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5)

    def test_uses_gate_wg_weight(self):
        """Verify that modifying gate.wg.weight changes the routing output."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)

        scores1, _, _, _ = softmax_topk_wg_routing(hidden, moe_block)

        # Change the wg weight and verify scores change
        moe_block.gate.wg.weight = torch.randn(E, H)
        scores2, _, _, _ = softmax_topk_wg_routing(hidden, moe_block)

        assert not torch.equal(scores1, scores2)
|
||||
|
||||
Reference in New Issue
Block a user