feat: add custom routing support for ernie4_5_moe, and hunyuan_v1_moe (#3526)

* feat: add Ernie 4.5 and, subsequently, custom routing support

* Update routing.py

* chore: lint

* fix minor nits

* removed deepseek v2

* remove unneeded change

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Avaya Aggarwal
2026-03-25 18:10:31 +05:30
committed by GitHub
parent 678ebb1bb2
commit ff0f67c730
5 changed files with 589 additions and 22 deletions

View File

@@ -6,6 +6,11 @@ Used by both ScatterMoE and SonicMoE kernel paths.
Values can be a single class name (str) or a list of class names for models
with multiple MoE block types (e.g. qwen3_omni_moe has Thinker + Talker).
Models with custom routing (see sonicmoe/routing.py for implementations):
- ernie4_5_moe: softmax→bias correction→topk (softmax_bias_topk_routing)
- deepseek_v2: softmax→group_limited_greedy (softmax_group_limited_topk_routing)
- hunyuan_v1_moe: softmax→topk via gate.wg (softmax_topk_wg_routing)
"""
import importlib
@@ -36,11 +41,15 @@ SPARSE_MOE_BLOCK = {
"glm4v_moe": "Glm4vMoeTextMoE",
# sigmoid -> topk routing (no group selection)
"minimax_m2": "MiniMaxM2SparseMoeBlock",
# Models below need custom routing (not yet implemented):
# "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk
# "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group)
# "hunyuan_v1_moe": "HunYuanMoEV1Moe", # softmax->topk, gate.wg (not gate.weight), scatter routing
# "gpt_oss": "GptOssMLP", # topk->softmax, transposed layout [E,H,2*I], custom GLU, expert biases
# softmax->topk, e_score_correction_bias between softmax and topk
"ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock",
# softmax->topk, group_limited_greedy, different attr names (num_group)
"deepseek_v2": "DeepseekV2Moe",
# softmax->topk, gate.wg (not gate.weight)
"hunyuan_v1_moe": "HunYuanMoEV1Moe",
# TODO: gpt_oss deferred — transposed weight layout [E,H,2*I], expert biases,
# and custom GLU activation require a dedicated forward path in patch.py.
# "gpt_oss": "GptOssMLP",
}

View File

@@ -3,9 +3,11 @@ Routing functions for SonicMoE integration.
Different MoE architectures use different routing strategies:
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
- mistral4: softmax -> group selection -> topk (with renormalization and scaling)
- glm_moe_dsa / deepseek_v3 / minimax_m2: sigmoid -> topk (with group-based expert selection)
- ernie4_5_moe: softmax -> bias correction -> topk -> gather (softmax_bias_topk_routing)
- hunyuan_v1_moe: softmax -> topk via gate.wg (softmax_topk_wg_routing)
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None) [NOT YET SUPPORTED]
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
@@ -57,17 +59,10 @@ def get_model_moe_config(model_type: str):
"minimax_m2",
):
return sigmoid_topk_routing, ActivationType.SWIGLU, "gate"
# elif model_type in ("ernie4_5_moe",):
# # Softmax→topk with e_score_correction_bias applied between softmax and topk.
# return ..., ActivationType.SWIGLU, "gate"
# elif model_type in ("deepseek_v2",):
# # Softmax→topk with group_limited_greedy. Different attr names: num_group
# # (not n_group), gate is nn.Linear (not a router class).
# return ..., ActivationType.SWIGLU, "gate"
# elif model_type in ("hunyuan_v1_moe",):
# # Softmax→topk but gate structure differs: gate.wg (not gate.weight),
# # top_k on block not gate, creates scatter routing matrix.
# return ..., ActivationType.SWIGLU, "gate"
elif model_type in ("ernie4_5_moe",):
return softmax_bias_topk_routing, ActivationType.SWIGLU, "gate"
elif model_type in ("hunyuan_v1_moe",):
return softmax_topk_wg_routing, ActivationType.SWIGLU, "gate"
# Fused topk -> softmax path (routing_fn=None):
# elif model_type in ("gpt_oss",):
# # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer
@@ -94,7 +89,7 @@ def softmax_topk_routing(
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
T, H = hidden_states.shape
T, _ = hidden_states.shape
K = gate.top_k
# Compute router logits and softmax over all experts
@@ -134,7 +129,7 @@ def softmax_group_topk_routing(
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
gate = moe_block.gate
T, H = hidden_states.shape
T, _ = hidden_states.shape
K = moe_block.top_k
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
n_group = getattr(moe_block, "n_group", 1)
@@ -212,7 +207,7 @@ def sigmoid_topk_routing(
router_logits: [T, E] original logits for aux loss
"""
gate = moe_block.gate
T, H = hidden_states.shape
T, _ = hidden_states.shape
K = moe_block.top_k
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
n_group = getattr(moe_block, "n_group", 1)
@@ -276,3 +271,203 @@ def sigmoid_topk_routing(
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
def softmax_bias_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Ernie 4.5 MoE routing: softmax → bias correction → topk → gather → renorm.

    Compared with the plain softmax_topk_routing path:
    1. gate.moe_statics adds a learned e_score_correction_bias to the softmax
       probabilities, and that *biased* score picks the top-k experts — but the
       final routing weights are gathered from the *unbiased* probabilities.
    2. Renormalization divides by clamp(sum, min=norm_min) rather than
       sum + epsilon.

    Reference: Ernie4_5_MoeTopKRouter.forward in transformers.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate.*)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    gate = moe_block.gate
    num_tokens = hidden_states.shape[0]
    top_k = gate.top_k

    # Logits and softmax in float32 for numerical stability.
    router_logits = F.linear(hidden_states.float(), gate.weight.float())  # [T, E]
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    # Selection uses the bias-corrected scores from the moe_statics module...
    biased_scores = gate.moe_statics(router_probs)  # [T, E]
    selected_experts = biased_scores.topk(top_k, dim=-1).indices  # [T, K]

    # ...while the routing weights come from the original (unbiased) probs.
    weights = router_probs.gather(-1, selected_experts)  # [T, K]

    # Renormalize with a clamped denominator (norm_min floor, not sum+eps).
    denom = weights.sum(dim=-1, keepdim=True).clamp(
        min=getattr(gate, "norm_min", 1e-20)
    )
    weights = weights / denom

    # Flatten for moe_general_routing_inputs; token indices ascend by design.
    token_indices = torch.arange(
        num_tokens, device=hidden_states.device, dtype=torch.int32
    ).repeat_interleave(top_k)  # [T*K]
    return (
        weights.to(torch.float32).reshape(-1),  # [T*K]
        token_indices,  # [T*K]
        selected_experts.reshape(-1).to(torch.int32),  # [T*K]
        router_logits,  # [T, E]
    )
def softmax_group_limited_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """DeepSeek V2 routing: softmax → group_limited_greedy/greedy → topk → scale.

    Differs from softmax_group_topk_routing (Mistral4) in several ways:
    1. Uses ``num_group`` attribute (not ``n_group``).
    2. Group score = max per group (not sum of top-2).
    3. Supports ``greedy`` method (plain topk without groups).
    4. No renormalization — just ``topk_weight * routed_scaling_factor``.
    5. Gate is ``nn.Linear`` (access weight via ``gate.weight``).

    Reference: DeepseekV2Moe.route_tokens_to_experts in transformers.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate, .num_group,
            .topk_group, .top_k, .topk_method, .routed_scaling_factor)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    gate = moe_block.gate
    num_tokens = hidden_states.shape[0]
    top_k = moe_block.top_k
    groups = getattr(moe_block, "num_group", 1)
    total_experts = gate.weight.shape[0]
    method = getattr(moe_block, "topk_method", "greedy")

    # Logits and softmax in float32 for numerical stability.
    router_logits = F.linear(hidden_states.float(), gate.weight.float())  # [T, E]
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    if method == "greedy" or groups == 1:
        # Plain topk over all experts, no group restriction.
        topk_weights, topk_indices = torch.topk(
            router_probs, k=top_k, dim=-1, sorted=False
        )
    elif method == "group_limited_greedy":
        per_group = total_experts // groups
        # Guard: the chosen groups must hold at least top_k experts combined.
        available = moe_block.topk_group * per_group
        if available < top_k:
            raise ValueError(
                f"DeepSeek V2: topk_group ({moe_block.topk_group}) * group_size "
                f"({per_group}) = {available} < top_k ({top_k}). "
                f"Not enough experts in selected groups for topk selection."
            )
        # Rank groups by their best (max) expert probability.
        group_scores = router_probs.view(num_tokens, groups, per_group).max(
            dim=-1
        ).values  # [T, num_group]
        keep_groups = torch.topk(
            group_scores, k=moe_block.topk_group, dim=-1, sorted=False
        ).indices
        # Build a per-expert keep mask by broadcasting the group mask.
        group_mask = torch.zeros_like(group_scores).scatter_(1, keep_groups, 1)
        expert_mask = (
            group_mask.unsqueeze(-1)
            .expand(num_tokens, groups, per_group)
            .reshape(num_tokens, -1)
        )
        # Zero out experts outside the kept groups, then take topk.
        masked_probs = router_probs.masked_fill(~expert_mask.bool(), 0.0)
        topk_weights, topk_indices = torch.topk(
            masked_probs, k=top_k, dim=-1, sorted=False
        )
    else:
        raise ValueError(
            f"DeepSeek V2: unsupported topk_method '{method}'. "
            f"Expected 'greedy' or 'group_limited_greedy'."
        )

    # Scale only — no renormalization (weights won't sum to 1.0 per token).
    # This matches the reference DeepseekV2Moe.route_tokens_to_experts behavior.
    topk_weights = topk_weights * getattr(moe_block, "routed_scaling_factor", 1.0)

    # Flatten for moe_general_routing_inputs; token indices ascend by design.
    token_indices = torch.arange(
        num_tokens, device=hidden_states.device, dtype=torch.int32
    ).repeat_interleave(top_k)  # [T*K]
    return (
        topk_weights.to(torch.float32).reshape(-1),  # [T*K]
        token_indices,  # [T*K]
        topk_indices.reshape(-1).to(torch.int32),  # [T*K]
        router_logits,  # [T, E]
    )
def softmax_topk_wg_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """HunYuan V1 MoE routing: softmax → topk → renorm (gate weight via gate.wg).

    Differs from standard softmax_topk_routing in:
    1. Gate weight lives at ``gate.wg.weight`` (not ``gate.weight``).
    2. ``top_k`` is on ``moe_block`` (not ``gate``).
    3. Always renormalizes (no ``norm_topk_prob`` flag).

    Reference: HunYuanMoEV1Moe.route_tokens_to_experts and
    HunYuanMoEV1Gate.forward in transformers.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate.wg, moe_block.top_k)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    num_tokens = hidden_states.shape[0]
    top_k = moe_block.top_k

    # HunYuan's gate computes logits through its inner wg linear layer.
    wg_weight = moe_block.gate.wg.weight
    router_logits = F.linear(hidden_states.float(), wg_weight.float())  # [T, E]
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    # Top-k selection, then unconditional renormalization (no flag to check).
    weights, indices = router_probs.topk(top_k, dim=-1)  # [T, K] each
    weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20)

    # Flatten for moe_general_routing_inputs; token indices ascend by design.
    token_indices = torch.arange(
        num_tokens, device=hidden_states.device, dtype=torch.int32
    ).repeat_interleave(top_k)  # [T*K]
    return (
        weights.to(torch.float32).reshape(-1),  # [T*K]
        token_indices,  # [T*K]
        indices.reshape(-1).to(torch.int32),  # [T*K]
        router_logits,  # [T, E]
    )

View File

@@ -590,7 +590,8 @@ class PatchManager:
def _apply_llama_flash_attn_patches(self, model):
"""Apply LLaMA-specific flash attention patches."""
if (
self.model_config.model_type in ["llama", "llama4"]
self.model_config.model_type
in ["llama", "llama4", "ernie4_5", "ernie4_5_moe"]
and not self.cfg.trust_remote_code
and not self.cfg.gptq
and self.cfg.flash_attention

View File

@@ -52,6 +52,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"seed_oss",
"lfm2",
"lfm2_moe",
"ernie4_5",
"ernie4_5_moe",
"olmo",
"olmo2",
"olmo3",

View File

@@ -426,3 +426,363 @@ class TestMiniMaxM2SigmoidRouting:
expert_idx_2d = expert_idx.reshape(T, K)
for t in range(T):
assert 0 in expert_idx_2d[t]
# ============================================================================
# Ernie 4.5 MoE: softmax -> bias correction -> topk
# ============================================================================
def _make_ernie_moe_block(T=8, H=16, E=8, K=2, norm_min=1e-20):
"""Create a mock Ernie 4.5 MoE block for routing tests.
Ernie 4.5 uses a gate.moe_statics module that adds bias to softmax probs
before topk selection, then gathers from original probs.
"""
bias = torch.zeros(E)
class MockMoeStatics:
def __init__(self, bias_tensor):
self.e_score_correction_bias = bias_tensor
def __call__(self, probs):
return probs + self.e_score_correction_bias
gate = SimpleNamespace(
weight=torch.randn(E, H),
top_k=K,
moe_statics=MockMoeStatics(bias),
norm_min=norm_min,
)
moe_block = SimpleNamespace(gate=gate)
return moe_block, bias, T, H, E, K
class TestSoftmaxBiasTopkRouting:
    """Tests for Ernie 4.5 MoE routing (softmax_bias_topk_routing).

    Each test builds a mock block via _make_ernie_moe_block and checks one
    aspect of the flattened routing outputs: shapes, dtypes, ordering,
    index ranges, renormalization, and bias-driven expert selection.
    """

    def test_output_shapes(self):
        # Flattened outputs are [T*K]; logits keep [T, E] for aux loss.
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, _, T, H, E, K = _make_ernie_moe_block()
        hidden = torch.randn(T, H)
        scores, token_idx, expert_idx, logits = softmax_bias_topk_routing(
            hidden, moe_block
        )
        assert scores.shape == (T * K,)
        assert token_idx.shape == (T * K,)
        assert expert_idx.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        # Routing weights must be float32 for the downstream kernel.
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, _, T, H, E, K = _make_ernie_moe_block()
        hidden = torch.randn(T, H)
        scores, _, _, _ = softmax_bias_topk_routing(hidden, moe_block)
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        # moe_general_routing_inputs requires token-major (ascending) order.
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, _, T, H, E, K = _make_ernie_moe_block()
        hidden = torch.randn(T, H)
        _, token_idx, _, _ = softmax_bias_topk_routing(hidden, moe_block)
        diffs = token_idx[1:] - token_idx[:-1]
        assert (diffs >= 0).all()

    def test_expert_indices_in_range(self):
        # Every selected expert index must fall in [0, E).
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, _, T, H, E, K = _make_ernie_moe_block()
        hidden = torch.randn(T, H)
        _, _, expert_idx, _ = softmax_bias_topk_routing(hidden, moe_block)
        assert (expert_idx >= 0).all()
        assert (expert_idx < E).all()

    def test_renormalized_scores_sum_to_one(self):
        # Ernie renormalizes the gathered probs, so per-token sums are 1.
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, _, T, H, E, K = _make_ernie_moe_block()
        hidden = torch.randn(T, H)
        scores, _, _, _ = softmax_bias_topk_routing(hidden, moe_block)
        per_token_sums = scores.reshape(T, K).sum(dim=-1)
        assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5)

    def test_bias_affects_expert_selection(self):
        """Large positive bias on expert 0 should make it always selected."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, bias, T, H, E, K = _make_ernie_moe_block()
        # In-place mutation works because the mock holds this same tensor.
        bias[0] = 100.0  # mutate the bias to strongly favor expert 0
        hidden = torch.randn(T, H)
        _, _, expert_idx, _ = softmax_bias_topk_routing(hidden, moe_block)
        expert_idx_2d = expert_idx.reshape(T, K)
        for t in range(T):
            assert 0 in expert_idx_2d[t]
# ============================================================================
# DeepSeek V2: softmax -> group_limited_greedy / greedy -> topk
# ============================================================================
def _make_deepseek_v2_moe_block(
T=8, H=16, E=16, K=4, num_group=2, topk_group=1, topk_method="group_limited_greedy"
):
"""Create a mock DeepSeek V2 MoE block for routing tests.
DeepSeek V2 uses num_group (not n_group), gate is nn.Linear,
and supports greedy / group_limited_greedy topk methods.
"""
gate = SimpleNamespace(weight=torch.randn(E, H))
moe_block = SimpleNamespace(
gate=gate,
top_k=K,
num_group=num_group,
topk_group=topk_group,
topk_method=topk_method,
routed_scaling_factor=1.0,
)
return moe_block, T, H, E, K
class TestSoftmaxGroupLimitedTopkRouting:
    """Tests for DeepSeek V2 routing (softmax_group_limited_topk_routing).

    Covers both topk methods (greedy and group_limited_greedy), output
    shapes/dtypes/ordering, the routed_scaling_factor, group-restricted
    expert selection, and the unsupported-method error path.
    """

    def test_output_shapes_group_limited(self):
        # group_limited_greedy path: flattened [T*K] outputs, [T, E] logits.
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block(
            topk_method="group_limited_greedy"
        )
        hidden = torch.randn(T, H)
        scores, token_idx, expert_idx, logits = softmax_group_limited_topk_routing(
            hidden, moe_block
        )
        assert scores.shape == (T * K,)
        assert token_idx.shape == (T * K,)
        assert expert_idx.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_output_shapes_greedy(self):
        # greedy path (plain topk, no group restriction) has the same shapes.
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block(topk_method="greedy")
        hidden = torch.randn(T, H)
        scores, token_idx, expert_idx, logits = softmax_group_limited_topk_routing(
            hidden, moe_block
        )
        assert scores.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        # Routing weights must be float32 for the downstream kernel.
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block()
        hidden = torch.randn(T, H)
        scores, _, _, _ = softmax_group_limited_topk_routing(hidden, moe_block)
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        # moe_general_routing_inputs requires token-major (ascending) order.
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block()
        hidden = torch.randn(T, H)
        _, token_idx, _, _ = softmax_group_limited_topk_routing(hidden, moe_block)
        diffs = token_idx[1:] - token_idx[:-1]
        assert (diffs >= 0).all()

    def test_expert_indices_in_range(self):
        # Every selected expert index must fall in [0, E).
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block()
        hidden = torch.randn(T, H)
        _, _, expert_idx, _ = softmax_group_limited_topk_routing(hidden, moe_block)
        assert (expert_idx >= 0).all()
        assert (expert_idx < E).all()

    def test_scaling_factor_applied(self):
        # DeepSeek V2 scales (not renormalizes): doubling the factor
        # must exactly scale the scores.
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block(topk_method="greedy")
        hidden = torch.randn(T, H)
        scores_1x, _, _, _ = softmax_group_limited_topk_routing(hidden, moe_block)
        moe_block.routed_scaling_factor = 2.5
        scores_2x, _, _, _ = softmax_group_limited_topk_routing(hidden, moe_block)
        assert torch.allclose(scores_2x, scores_1x * 2.5, atol=1e-5)

    def test_group_selection_restricts_experts(self):
        """With num_group=4 and topk_group=1, experts should come from selected groups."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block(
            E=16, K=2, num_group=4, topk_group=1, topk_method="group_limited_greedy"
        )
        hidden = torch.randn(T, H)
        _, _, expert_idx, _ = softmax_group_limited_topk_routing(hidden, moe_block)
        expert_idx_2d = expert_idx.reshape(T, K)
        group_size = E // moe_block.num_group
        for t in range(T):
            experts = expert_idx_2d[t]
            # Integer division maps each expert index to its group id.
            groups = experts // group_size
            # All selected experts should be from the same group
            assert (groups == groups[0]).all()

    def test_unsupported_topk_method_raises(self):
        # Unknown topk_method values must fail loudly, not fall through.
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block(topk_method="invalid")
        hidden = torch.randn(T, H)
        with pytest.raises(ValueError, match="unsupported topk_method"):
            softmax_group_limited_topk_routing(hidden, moe_block)
# ============================================================================
# HunYuan V1 MoE: softmax -> topk -> renorm (via gate.wg)
# ============================================================================
def _make_hunyuan_moe_block(T=8, H=16, E=8, K=2):
"""Create a mock HunYuan V1 MoE block for routing tests.
HunYuan V1 uses gate.wg (nn.Linear-like) instead of gate.weight,
and top_k on the moe_block instead of the gate.
"""
wg = SimpleNamespace(weight=torch.randn(E, H))
gate = SimpleNamespace(wg=wg)
moe_block = SimpleNamespace(gate=gate, top_k=K)
return moe_block, T, H, E, K
class TestSoftmaxTopkWgRouting:
    """Tests for HunYuan V1 MoE routing (softmax_topk_wg_routing).

    Covers output shapes/dtypes/ordering, index ranges, the unconditional
    renormalization, and that the gate's inner wg weight is what drives
    the routing decision.
    """

    def test_output_shapes(self):
        # Flattened outputs are [T*K]; logits keep [T, E] for aux loss.
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)
        scores, token_idx, expert_idx, logits = softmax_topk_wg_routing(
            hidden, moe_block
        )
        assert scores.shape == (T * K,)
        assert token_idx.shape == (T * K,)
        assert expert_idx.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        # Routing weights must be float32 for the downstream kernel.
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)
        scores, _, _, _ = softmax_topk_wg_routing(hidden, moe_block)
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        # moe_general_routing_inputs requires token-major (ascending) order.
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)
        _, token_idx, _, _ = softmax_topk_wg_routing(hidden, moe_block)
        diffs = token_idx[1:] - token_idx[:-1]
        assert (diffs >= 0).all()

    def test_expert_indices_in_range(self):
        # Every selected expert index must fall in [0, E).
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)
        _, _, expert_idx, _ = softmax_topk_wg_routing(hidden, moe_block)
        assert (expert_idx >= 0).all()
        assert (expert_idx < E).all()

    def test_renormalized_scores_sum_to_one(self):
        """HunYuan V1 always renormalizes (no norm_topk_prob flag)."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)
        scores, _, _, _ = softmax_topk_wg_routing(hidden, moe_block)
        per_token_sums = scores.reshape(T, K).sum(dim=-1)
        assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5)

    def test_uses_gate_wg_weight(self):
        """Verify that modifying gate.wg.weight changes the routing output."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)
        scores1, _, _, _ = softmax_topk_wg_routing(hidden, moe_block)
        # Change the wg weight and verify scores change
        moe_block.gate.wg.weight = torch.randn(E, H)
        scores2, _, _, _ = softmax_topk_wg_routing(hidden, moe_block)
        assert not torch.equal(scores1, scores2)