feat: add custom routing support for ernie4_5_moe and hunyuan_v1_moe (#3526)

* feat: add Ernie 4.5 and subsequent custom routing support
* Update routing.py
* chore: lint
* fix minor nits
* removed deepseek v2
* remove unneeded change

--------

Co-authored-by: Wing Lian <wing@axolotl.ai>
@@ -6,6 +6,11 @@ Used by both ScatterMoE and SonicMoE kernel paths.

Values can be a single class name (str) or a list of class names for models
with multiple MoE block types (e.g. qwen3_omni_moe has Thinker + Talker).

Models with custom routing (see sonicmoe/routing.py for implementations):
- ernie4_5_moe: softmax→bias correction→topk (softmax_bias_topk_routing)
- deepseek_v2: softmax→group_limited_greedy (softmax_group_limited_topk_routing)
- hunyuan_v1_moe: softmax→topk via gate.wg (softmax_topk_wg_routing)
"""

import importlib

@@ -36,11 +41,15 @@ SPARSE_MOE_BLOCK = {
    "glm4v_moe": "Glm4vMoeTextMoE",
    # sigmoid -> topk routing (no group selection)
    "minimax_m2": "MiniMaxM2SparseMoeBlock",
    # Models below need custom routing (not yet implemented):
    # "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk
    # "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group)
    # "hunyuan_v1_moe": "HunYuanMoEV1Moe", # softmax->topk, gate.wg (not gate.weight), scatter routing
    # "gpt_oss": "GptOssMLP", # topk->softmax, transposed layout [E,H,2*I], custom GLU, expert biases
    # softmax->topk, e_score_correction_bias between softmax and topk
    "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock",
    # softmax->topk, group_limited_greedy, different attr names (num_group)
    "deepseek_v2": "DeepseekV2Moe",
    # softmax->topk, gate.wg (not gate.weight)
    "hunyuan_v1_moe": "HunYuanMoEV1Moe",
    # TODO: gpt_oss deferred — transposed weight layout [E,H,2*I], expert biases,
    # and custom GLU activation require a dedicated forward path in patch.py.
    # "gpt_oss": "GptOssMLP",
}
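For orientation, a minimal sketch, assuming only what the docstring above states, of normalizing a SPARSE_MOE_BLOCK entry that may be either a single class name or a list of class names; `resolve_moe_block_class_names` is an illustrative helper and not part of this PR.

```python
def resolve_moe_block_class_names(model_type: str, sparse_moe_block: dict) -> list[str]:
    """Normalize a SPARSE_MOE_BLOCK entry (str or list of str) to a list of class names."""
    entry = sparse_moe_block[model_type]
    return [entry] if isinstance(entry, str) else list(entry)


# Example, using a value from the map above:
# resolve_moe_block_class_names("ernie4_5_moe", SPARSE_MOE_BLOCK)
# -> ["Ernie4_5_MoeSparseMoeBlock"]
```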
@@ -3,9 +3,11 @@ Routing functions for SonicMoE integration.

Different MoE architectures use different routing strategies:
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
- mistral4: softmax -> group selection -> topk (with renormalization and scaling)
- glm_moe_dsa / deepseek_v3 / minimax_m2: sigmoid -> topk (with group-based expert selection)
- ernie4_5_moe: softmax -> bias correction -> topk -> gather (softmax_bias_topk_routing)
- hunyuan_v1_moe: softmax -> topk via gate.wg (softmax_topk_wg_routing)
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None) [NOT YET SUPPORTED]

Each model type maps to a (routing_fn, activation_type, router_attr) triple.
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
@@ -57,17 +59,10 @@ def get_model_moe_config(model_type: str):
        "minimax_m2",
    ):
        return sigmoid_topk_routing, ActivationType.SWIGLU, "gate"
    # elif model_type in ("ernie4_5_moe",):
    #     # Softmax→topk with e_score_correction_bias applied between softmax and topk.
    #     return ..., ActivationType.SWIGLU, "gate"
    # elif model_type in ("deepseek_v2",):
    #     # Softmax→topk with group_limited_greedy. Different attr names: num_group
    #     # (not n_group), gate is nn.Linear (not a router class).
    #     return ..., ActivationType.SWIGLU, "gate"
    # elif model_type in ("hunyuan_v1_moe",):
    #     # Softmax→topk but gate structure differs: gate.wg (not gate.weight),
    #     # top_k on block not gate, creates scatter routing matrix.
    #     return ..., ActivationType.SWIGLU, "gate"
    elif model_type in ("ernie4_5_moe",):
        return softmax_bias_topk_routing, ActivationType.SWIGLU, "gate"
    elif model_type in ("hunyuan_v1_moe",):
        return softmax_topk_wg_routing, ActivationType.SWIGLU, "gate"
    # Fused topk -> softmax path (routing_fn=None):
    # elif model_type in ("gpt_oss",):
    #     # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer
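A hedged sketch of how a caller might consume the (routing_fn, activation_type, router_attr) triple returned above; `route_tokens` is a hypothetical wrapper, only `get_model_moe_config` and the routing functions come from this diff, and the import path is assumed from the tests added later in this commit.

```python
import torch

from axolotl.integrations.kernels.sonicmoe.routing import get_model_moe_config


def route_tokens(model_type: str, hidden_states: torch.Tensor, moe_block):
    """Dispatch routing based on the (routing_fn, activation_type, router_attr) triple."""
    routing_fn, activation_type, router_attr = get_model_moe_config(model_type)
    if routing_fn is None:
        # Fused topk->softmax path (e.g. gpt_oss, once supported) is handled elsewhere.
        raise NotImplementedError("fused moe_TC_softmax_topk_layer path")
    # Custom routing path: flattened scores/indices plus raw logits for the aux loss.
    scores, token_idx, expert_idx, router_logits = routing_fn(hidden_states, moe_block)
    # activation_type / router_attr would drive expert-MLP kernel selection downstream.
    return scores, token_idx, expert_idx, router_logits
```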
@@ -94,7 +89,7 @@ def softmax_topk_routing(
        router_logits: [T, E] original logits for aux loss
    """
    gate = moe_block.gate
    T, H = hidden_states.shape
    T, _ = hidden_states.shape
    K = gate.top_k

    # Compute router logits and softmax over all experts
@@ -134,7 +129,7 @@ def softmax_group_topk_routing(
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
    gate = moe_block.gate
    T, H = hidden_states.shape
    T, _ = hidden_states.shape
    K = moe_block.top_k
    E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
    n_group = getattr(moe_block, "n_group", 1)
@@ -212,7 +207,7 @@ def sigmoid_topk_routing(
        router_logits: [T, E] original logits for aux loss
    """
    gate = moe_block.gate
    T, H = hidden_states.shape
    T, _ = hidden_states.shape
    K = moe_block.top_k
    E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
    n_group = getattr(moe_block, "n_group", 1)
@@ -276,3 +271,203 @@ def sigmoid_topk_routing(
    flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits


def softmax_bias_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Ernie 4.5 MoE routing: softmax → bias correction → topk → gather → renorm.

    Differs from standard softmax_topk_routing in three ways:
    1. A learned e_score_correction_bias is added to softmax probs *before* topk
       (selection uses biased scores, but final weights use original probs).
    2. The bias is applied via gate.moe_statics module (not a raw tensor).
    3. Renormalization uses clamp(min=norm_min) instead of sum+epsilon.

    Reference: Ernie4_5_MoeTopKRouter.forward in transformers.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate.*)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    gate = moe_block.gate
    T, _ = hidden_states.shape
    K = gate.top_k

    # Compute router logits and softmax (force float32 for numerical stability)
    router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E]
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]

    # Bias-corrected scores for expert selection (via moe_statics module)
    scores_for_choice = gate.moe_statics(router_probs) # [T, E]

    # Select top-k experts using biased scores
    _, selected_experts = torch.topk(scores_for_choice, K, dim=-1) # [T, K]

    # Gather weights from *original* (unbiased) softmax probs
    top_values = torch.gather(router_probs, dim=-1, index=selected_experts) # [T, K]

    # Renormalize with clamp(min=norm_min) instead of sum+epsilon
    norm_min = getattr(gate, "norm_min", 1e-20)
    top_values = top_values / torch.clamp(
        top_values.sum(dim=-1, keepdim=True), min=norm_min
    )

    # Flatten for moe_general_routing_inputs
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = top_values.to(torch.float32).reshape(-1) # [T*K]
    flat_token_idx = token_indices.reshape(-1) # [T*K]
    flat_expert_idx = selected_experts.to(torch.int32).reshape(-1) # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits
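A small standalone illustration, not taken from this PR, of the selection-vs-weights split described in the docstring above: the correction bias only decides which experts win, while the returned weights still come from the unbiased softmax probabilities.

```python
import torch

# Illustrative only: Ernie-style bias correction with T=1 token and E=3 experts.
probs = torch.tensor([[0.50, 0.30, 0.20]])      # unbiased softmax probabilities
bias = torch.tensor([0.00, 0.45, 0.00])         # e_score_correction_bias
scores_for_choice = probs + bias                 # [[0.50, 0.75, 0.20]]

_, selected = torch.topk(scores_for_choice, k=2, dim=-1)   # experts 1 and 0 are chosen
weights = torch.gather(probs, dim=-1, index=selected)      # [[0.30, 0.50]] from *unbiased* probs
weights = weights / weights.sum(dim=-1, keepdim=True).clamp(min=1e-20)
print(selected.tolist(), weights.tolist())       # [[1, 0]], approximately [[0.375, 0.625]]
```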
def softmax_group_limited_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """DeepSeek V2 routing: softmax → group_limited_greedy/greedy → topk → scale.

    Differs from softmax_group_topk_routing (Mistral4) in several ways:
    1. Uses ``num_group`` attribute (not ``n_group``).
    2. Group score = max per group (not sum of top-2).
    3. Supports ``greedy`` method (plain topk without groups).
    4. No renormalization — just ``topk_weight * routed_scaling_factor``.
    5. Gate is ``nn.Linear`` (access weight via ``gate.weight``).

    Reference: DeepseekV2Moe.route_tokens_to_experts in transformers.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate, .num_group,
            .topk_group, .top_k, .topk_method, .routed_scaling_factor)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    gate = moe_block.gate
    T, _ = hidden_states.shape
    K = moe_block.top_k
    num_group = getattr(moe_block, "num_group", 1)
    num_experts = gate.weight.shape[0]
    topk_method = getattr(moe_block, "topk_method", "greedy")

    # Compute logits in float32 and softmax
    router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E]
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]

    if topk_method == "greedy" or num_group == 1:
        topk_weights, topk_indices = torch.topk(router_probs, k=K, dim=-1, sorted=False)
    elif topk_method == "group_limited_greedy":
        # Guard: selected groups must contain enough experts for topk
        group_size = num_experts // num_group
        if moe_block.topk_group * group_size < K:
            raise ValueError(
                f"DeepSeek V2: topk_group ({moe_block.topk_group}) * group_size "
                f"({group_size}) = {moe_block.topk_group * group_size} < top_k ({K}). "
                f"Not enough experts in selected groups for topk selection."
            )
        # Group selection: pick top groups by max score per group
        group_scores = (
            router_probs.view(T, num_group, num_experts // num_group).max(dim=-1).values
        ) # [T, num_group]
        group_idx = torch.topk(
            group_scores, k=moe_block.topk_group, dim=-1, sorted=False
        )[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(T, num_group, num_experts // num_group)
            .reshape(T, -1)
        )
        tmp_scores = router_probs.masked_fill(~score_mask.bool(), 0.0)
        topk_weights, topk_indices = torch.topk(tmp_scores, k=K, dim=-1, sorted=False)
    else:
        raise ValueError(
            f"DeepSeek V2: unsupported topk_method '{topk_method}'. "
            f"Expected 'greedy' or 'group_limited_greedy'."
        )

    # Scale only — no renormalization (weights won't sum to 1.0 per token).
    # This matches the reference DeepseekV2Moe.route_tokens_to_experts behavior.
    routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
    topk_weights = topk_weights * routed_scaling_factor

    # Flatten for moe_general_routing_inputs
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = topk_weights.to(torch.float32).reshape(-1) # [T*K]
    flat_token_idx = token_indices.reshape(-1) # [T*K]
    flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits
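A small standalone illustration, not taken from this PR, of the group_limited_greedy masking step used above: with 8 experts split into 2 groups and topk_group=1, only experts from the winning group can be selected.

```python
import torch

# Illustrative only: group_limited_greedy masking with T=1, E=8, 2 groups, topk_group=1, K=2.
T, E, num_group, topk_group, K = 1, 8, 2, 1, 2
router_probs = torch.tensor([[0.05, 0.10, 0.02, 0.03, 0.30, 0.25, 0.15, 0.10]])

group_size = E // num_group                                             # 4 experts per group
group_scores = router_probs.view(T, num_group, group_size).max(dim=-1).values  # [[0.10, 0.30]]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1)[1]           # group 1 wins

group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.unsqueeze(-1).expand(T, num_group, group_size).reshape(T, -1)
masked = router_probs.masked_fill(~score_mask.bool(), 0.0)              # zero out group 0

topk_weights, topk_indices = torch.topk(masked, k=K, dim=-1)            # experts 4 and 5 only
print(topk_indices.tolist(), topk_weights.tolist())                     # [[4, 5]], approximately [[0.30, 0.25]]
```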
def softmax_topk_wg_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """HunYuan V1 MoE routing: softmax → topk → renorm (gate weight via gate.wg).

    Differs from standard softmax_topk_routing in:
    1. Gate weight lives at ``gate.wg.weight`` (not ``gate.weight``).
    2. ``top_k`` is on ``moe_block`` (not ``gate``).
    3. Always renormalizes (no ``norm_topk_prob`` flag).

    Reference: HunYuanMoEV1Moe.route_tokens_to_experts and
    HunYuanMoEV1Gate.forward in transformers.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate.wg, moe_block.top_k)

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32)
        router_logits: [T, E] original logits for aux loss
    """
    gate = moe_block.gate
    T, _ = hidden_states.shape
    K = moe_block.top_k

    # Gate computes logits via gate.wg (nn.Linear, float32)
    wg = gate.wg
    router_logits = F.linear(hidden_states.float(), wg.weight.float()) # [T, E]
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E]

    # Select top-k experts
    top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each

    # Always renormalize (HunYuan V1 has no norm_topk_prob flag)
    top_values = top_values / (top_values.sum(dim=-1, keepdim=True) + 1e-20)

    # Flatten for moe_general_routing_inputs
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = top_values.to(torch.float32).reshape(-1) # [T*K]
    flat_token_idx = token_indices.reshape(-1) # [T*K]
    flat_expert_idx = top_indices.to(torch.int32).reshape(-1) # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits
@@ -590,7 +590,8 @@ class PatchManager:
    def _apply_llama_flash_attn_patches(self, model):
        """Apply LLaMA-specific flash attention patches."""
        if (
            self.model_config.model_type in ["llama", "llama4"]
            self.model_config.model_type
            in ["llama", "llama4", "ernie4_5", "ernie4_5_moe"]
            and not self.cfg.trust_remote_code
            and not self.cfg.gptq
            and self.cfg.flash_attention
@@ -52,6 +52,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
    "seed_oss",
    "lfm2",
    "lfm2_moe",
    "ernie4_5",
    "ernie4_5_moe",
    "olmo",
    "olmo2",
    "olmo3",
@@ -426,3 +426,363 @@ class TestMiniMaxM2SigmoidRouting:
        expert_idx_2d = expert_idx.reshape(T, K)
        for t in range(T):
            assert 0 in expert_idx_2d[t]


# ============================================================================
# Ernie 4.5 MoE: softmax -> bias correction -> topk
# ============================================================================


def _make_ernie_moe_block(T=8, H=16, E=8, K=2, norm_min=1e-20):
    """Create a mock Ernie 4.5 MoE block for routing tests.

    Ernie 4.5 uses a gate.moe_statics module that adds bias to softmax probs
    before topk selection, then gathers from original probs.
    """
    bias = torch.zeros(E)

    class MockMoeStatics:
        def __init__(self, bias_tensor):
            self.e_score_correction_bias = bias_tensor

        def __call__(self, probs):
            return probs + self.e_score_correction_bias

    gate = SimpleNamespace(
        weight=torch.randn(E, H),
        top_k=K,
        moe_statics=MockMoeStatics(bias),
        norm_min=norm_min,
    )
    moe_block = SimpleNamespace(gate=gate)
    return moe_block, bias, T, H, E, K


class TestSoftmaxBiasTopkRouting:
    """Tests for Ernie 4.5 MoE routing (softmax_bias_topk_routing)."""

    def test_output_shapes(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, _, T, H, E, K = _make_ernie_moe_block()
        hidden = torch.randn(T, H)

        scores, token_idx, expert_idx, logits = softmax_bias_topk_routing(
            hidden, moe_block
        )

        assert scores.shape == (T * K,)
        assert token_idx.shape == (T * K,)
        assert expert_idx.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, _, T, H, E, K = _make_ernie_moe_block()
        hidden = torch.randn(T, H)

        scores, _, _, _ = softmax_bias_topk_routing(hidden, moe_block)
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, _, T, H, E, K = _make_ernie_moe_block()
        hidden = torch.randn(T, H)

        _, token_idx, _, _ = softmax_bias_topk_routing(hidden, moe_block)
        diffs = token_idx[1:] - token_idx[:-1]
        assert (diffs >= 0).all()

    def test_expert_indices_in_range(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, _, T, H, E, K = _make_ernie_moe_block()
        hidden = torch.randn(T, H)

        _, _, expert_idx, _ = softmax_bias_topk_routing(hidden, moe_block)
        assert (expert_idx >= 0).all()
        assert (expert_idx < E).all()

    def test_renormalized_scores_sum_to_one(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, _, T, H, E, K = _make_ernie_moe_block()
        hidden = torch.randn(T, H)

        scores, _, _, _ = softmax_bias_topk_routing(hidden, moe_block)
        per_token_sums = scores.reshape(T, K).sum(dim=-1)
        assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5)

    def test_bias_affects_expert_selection(self):
        """Large positive bias on expert 0 should make it always selected."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_bias_topk_routing,
        )

        moe_block, bias, T, H, E, K = _make_ernie_moe_block()
        bias[0] = 100.0 # mutate the bias to strongly favor expert 0
        hidden = torch.randn(T, H)

        _, _, expert_idx, _ = softmax_bias_topk_routing(hidden, moe_block)
        expert_idx_2d = expert_idx.reshape(T, K)
        for t in range(T):
            assert 0 in expert_idx_2d[t]


# ============================================================================
# DeepSeek V2: softmax -> group_limited_greedy / greedy -> topk
# ============================================================================


def _make_deepseek_v2_moe_block(
    T=8, H=16, E=16, K=4, num_group=2, topk_group=1, topk_method="group_limited_greedy"
):
    """Create a mock DeepSeek V2 MoE block for routing tests.

    DeepSeek V2 uses num_group (not n_group), gate is nn.Linear,
    and supports greedy / group_limited_greedy topk methods.
    """
    gate = SimpleNamespace(weight=torch.randn(E, H))
    moe_block = SimpleNamespace(
        gate=gate,
        top_k=K,
        num_group=num_group,
        topk_group=topk_group,
        topk_method=topk_method,
        routed_scaling_factor=1.0,
    )
    return moe_block, T, H, E, K


class TestSoftmaxGroupLimitedTopkRouting:
    """Tests for DeepSeek V2 routing (softmax_group_limited_topk_routing)."""

    def test_output_shapes_group_limited(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block(
            topk_method="group_limited_greedy"
        )
        hidden = torch.randn(T, H)

        scores, token_idx, expert_idx, logits = softmax_group_limited_topk_routing(
            hidden, moe_block
        )

        assert scores.shape == (T * K,)
        assert token_idx.shape == (T * K,)
        assert expert_idx.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_output_shapes_greedy(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block(topk_method="greedy")
        hidden = torch.randn(T, H)

        scores, token_idx, expert_idx, logits = softmax_group_limited_topk_routing(
            hidden, moe_block
        )

        assert scores.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block()
        hidden = torch.randn(T, H)

        scores, _, _, _ = softmax_group_limited_topk_routing(hidden, moe_block)
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block()
        hidden = torch.randn(T, H)

        _, token_idx, _, _ = softmax_group_limited_topk_routing(hidden, moe_block)
        diffs = token_idx[1:] - token_idx[:-1]
        assert (diffs >= 0).all()

    def test_expert_indices_in_range(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block()
        hidden = torch.randn(T, H)

        _, _, expert_idx, _ = softmax_group_limited_topk_routing(hidden, moe_block)
        assert (expert_idx >= 0).all()
        assert (expert_idx < E).all()

    def test_scaling_factor_applied(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block(topk_method="greedy")
        hidden = torch.randn(T, H)

        scores_1x, _, _, _ = softmax_group_limited_topk_routing(hidden, moe_block)

        moe_block.routed_scaling_factor = 2.5
        scores_2x, _, _, _ = softmax_group_limited_topk_routing(hidden, moe_block)

        assert torch.allclose(scores_2x, scores_1x * 2.5, atol=1e-5)

    def test_group_selection_restricts_experts(self):
        """With num_group=4 and topk_group=1, experts should come from selected groups."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block(
            E=16, K=2, num_group=4, topk_group=1, topk_method="group_limited_greedy"
        )
        hidden = torch.randn(T, H)

        _, _, expert_idx, _ = softmax_group_limited_topk_routing(hidden, moe_block)
        expert_idx_2d = expert_idx.reshape(T, K)
        group_size = E // moe_block.num_group
        for t in range(T):
            experts = expert_idx_2d[t]
            groups = experts // group_size
            # All selected experts should be from the same group
            assert (groups == groups[0]).all()

    def test_unsupported_topk_method_raises(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_group_limited_topk_routing,
        )

        moe_block, T, H, E, K = _make_deepseek_v2_moe_block(topk_method="invalid")
        hidden = torch.randn(T, H)

        with pytest.raises(ValueError, match="unsupported topk_method"):
            softmax_group_limited_topk_routing(hidden, moe_block)


# ============================================================================
# HunYuan V1 MoE: softmax -> topk -> renorm (via gate.wg)
# ============================================================================


def _make_hunyuan_moe_block(T=8, H=16, E=8, K=2):
    """Create a mock HunYuan V1 MoE block for routing tests.

    HunYuan V1 uses gate.wg (nn.Linear-like) instead of gate.weight,
    and top_k on the moe_block instead of the gate.
    """
    wg = SimpleNamespace(weight=torch.randn(E, H))
    gate = SimpleNamespace(wg=wg)
    moe_block = SimpleNamespace(gate=gate, top_k=K)
    return moe_block, T, H, E, K


class TestSoftmaxTopkWgRouting:
    """Tests for HunYuan V1 MoE routing (softmax_topk_wg_routing)."""

    def test_output_shapes(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)

        scores, token_idx, expert_idx, logits = softmax_topk_wg_routing(
            hidden, moe_block
        )

        assert scores.shape == (T * K,)
        assert token_idx.shape == (T * K,)
        assert expert_idx.shape == (T * K,)
        assert logits.shape == (T, E)

    def test_scores_are_float32(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)

        scores, _, _, _ = softmax_topk_wg_routing(hidden, moe_block)
        assert scores.dtype == torch.float32

    def test_token_indices_sorted_ascending(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)

        _, token_idx, _, _ = softmax_topk_wg_routing(hidden, moe_block)
        diffs = token_idx[1:] - token_idx[:-1]
        assert (diffs >= 0).all()

    def test_expert_indices_in_range(self):
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)

        _, _, expert_idx, _ = softmax_topk_wg_routing(hidden, moe_block)
        assert (expert_idx >= 0).all()
        assert (expert_idx < E).all()

    def test_renormalized_scores_sum_to_one(self):
        """HunYuan V1 always renormalizes (no norm_topk_prob flag)."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)

        scores, _, _, _ = softmax_topk_wg_routing(hidden, moe_block)
        per_token_sums = scores.reshape(T, K).sum(dim=-1)
        assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5)

    def test_uses_gate_wg_weight(self):
        """Verify that modifying gate.wg.weight changes the routing output."""
        from axolotl.integrations.kernels.sonicmoe.routing import (
            softmax_topk_wg_routing,
        )

        moe_block, T, H, E, K = _make_hunyuan_moe_block()
        hidden = torch.randn(T, H)

        scores1, _, _, _ = softmax_topk_wg_routing(hidden, moe_block)

        # Change the wg weight and verify scores change
        moe_block.gate.wg.weight = torch.randn(E, H)
        scores2, _, _, _ = softmax_topk_wg_routing(hidden, moe_block)

        assert not torch.equal(scores1, scores2)