|
|
|
|
@@ -3,9 +3,11 @@ Routing functions for SonicMoE integration.
|
|
|
|
|
|
|
|
|
|
Different MoE architectures use different routing strategies:
|
|
|
|
|
- qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization)
|
|
|
|
|
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None)
|
|
|
|
|
- glm_moe_dsa: sigmoid -> topk (with group-based expert selection)
|
|
|
|
|
- mistral4: softmax -> group selection -> topk (with renormalization and scaling)
|
|
|
|
|
- glm_moe_dsa / deepseek_v3 / minimax_m2: sigmoid -> topk (with group-based expert selection)
|
|
|
|
|
- ernie4_5_moe: softmax -> bias correction -> topk -> gather (softmax_bias_topk_routing)
|
|
|
|
|
- hunyuan_v1_moe: softmax -> topk via gate.wg (softmax_topk_wg_routing)
|
|
|
|
|
- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None) [NOT YET SUPPORTED]
|
|
|
|
|
|
|
|
|
|
Each model type maps to a (routing_fn, activation_type, router_attr) triple.
|
|
|
|
|
When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used.
|
|
|
|
|
@@ -57,17 +59,10 @@ def get_model_moe_config(model_type: str):
|
|
|
|
|
"minimax_m2",
|
|
|
|
|
):
|
|
|
|
|
return sigmoid_topk_routing, ActivationType.SWIGLU, "gate"
|
|
|
|
|
# elif model_type in ("ernie4_5_moe",):
|
|
|
|
|
# # Softmax→topk with e_score_correction_bias applied between softmax and topk.
|
|
|
|
|
# return ..., ActivationType.SWIGLU, "gate"
|
|
|
|
|
# elif model_type in ("deepseek_v2",):
|
|
|
|
|
# # Softmax→topk with group_limited_greedy. Different attr names: num_group
|
|
|
|
|
# # (not n_group), gate is nn.Linear (not a router class).
|
|
|
|
|
# return ..., ActivationType.SWIGLU, "gate"
|
|
|
|
|
# elif model_type in ("hunyuan_v1_moe",):
|
|
|
|
|
# # Softmax→topk but gate structure differs: gate.wg (not gate.weight),
|
|
|
|
|
# # top_k on block not gate, creates scatter routing matrix.
|
|
|
|
|
# return ..., ActivationType.SWIGLU, "gate"
|
|
|
|
|
elif model_type in ("ernie4_5_moe",):
|
|
|
|
|
return softmax_bias_topk_routing, ActivationType.SWIGLU, "gate"
|
|
|
|
|
elif model_type in ("hunyuan_v1_moe",):
|
|
|
|
|
return softmax_topk_wg_routing, ActivationType.SWIGLU, "gate"
|
|
|
|
|
# Fused topk -> softmax path (routing_fn=None):
|
|
|
|
|
# elif model_type in ("gpt_oss",):
|
|
|
|
|
# # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer
|
|
|
|
|
@@ -94,7 +89,7 @@ def softmax_topk_routing(
|
|
|
|
|
router_logits: [T, E] original logits for aux loss
|
|
|
|
|
"""
|
|
|
|
|
gate = moe_block.gate
|
|
|
|
|
T, H = hidden_states.shape
|
|
|
|
|
T, _ = hidden_states.shape
|
|
|
|
|
K = gate.top_k
|
|
|
|
|
|
|
|
|
|
# Compute router logits and softmax over all experts
|
|
|
|
|
@@ -134,7 +129,7 @@ def softmax_group_topk_routing(
|
|
|
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
|
|
|
"""Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
|
|
|
|
|
gate = moe_block.gate
|
|
|
|
|
T, H = hidden_states.shape
|
|
|
|
|
T, _ = hidden_states.shape
|
|
|
|
|
K = moe_block.top_k
|
|
|
|
|
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
|
|
|
|
|
n_group = getattr(moe_block, "n_group", 1)
|
|
|
|
|
@@ -212,7 +207,7 @@ def sigmoid_topk_routing(
|
|
|
|
|
router_logits: [T, E] original logits for aux loss
|
|
|
|
|
"""
|
|
|
|
|
gate = moe_block.gate
|
|
|
|
|
T, H = hidden_states.shape
|
|
|
|
|
T, _ = hidden_states.shape
|
|
|
|
|
K = moe_block.top_k
|
|
|
|
|
E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
|
|
|
|
|
n_group = getattr(moe_block, "n_group", 1)
|
|
|
|
|
@@ -276,3 +271,203 @@ def sigmoid_topk_routing(
|
|
|
|
|
flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K]
|
|
|
|
|
|
|
|
|
|
return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def softmax_bias_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Ernie 4.5 MoE routing: softmax, biased expert selection, unbiased weights.

    Pipeline (mirrors Ernie4_5_MoeTopKRouter.forward in transformers):

    * Router logits and their softmax are computed in float32.
    * ``gate.moe_statics`` is applied to the probabilities; its output is
      used *only* to decide which ``gate.top_k`` experts are picked.
    * The weights of the picked experts are gathered from the original
      probabilities, not from the ``moe_statics`` output.
    * The gathered weights are divided by their per-token sum, with the
      sum clamped from below by ``gate.norm_min`` (1e-20 when the gate has
      no such attribute) rather than adding an epsilon.

    Args:
        hidden_states: [T, H] flattened token representations.
        moe_block: MoE block module; ``moe_block.gate`` must expose
            ``weight``, ``top_k`` and ``moe_statics``.

    Returns:
        router_scores: [T*K] renormalized weights (float32).
        token_indices: [T*K] owning token of each entry (int32), ascending.
        expert_indices: [T*K] selected expert of each entry (int32).
        router_logits: [T, E] raw float32 logits (e.g. for an aux loss).
    """
    router = moe_block.gate
    num_tokens, _ = hidden_states.shape
    top_k = router.top_k

    # Everything below runs in float32 for numerical stability.
    router_logits = F.linear(hidden_states.float(), router.weight.float())  # [T, E]
    probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    # Expert selection looks at the moe_statics-adjusted scores ...
    selection_scores = router.moe_statics(probs)  # [T, E]
    selected_experts = torch.topk(selection_scores, top_k, dim=-1).indices  # [T, K]

    # ... while the routing weights come from the unadjusted probabilities.
    weights = probs.gather(-1, selected_experts)  # [T, K]

    # Renormalize; the denominator is clamped instead of epsilon-shifted.
    lower_bound = getattr(router, "norm_min", 1e-20)
    denominator = weights.sum(dim=-1, keepdim=True).clamp(min=lower_bound)
    weights = weights / denominator

    # Flatten to the [T*K] layout expected by moe_general_routing_inputs.
    row_ids = torch.arange(num_tokens, device=hidden_states.device, dtype=torch.int32)
    flat_token_idx = row_ids.unsqueeze(1).expand(num_tokens, top_k).reshape(-1)
    flat_scores = weights.to(torch.float32).reshape(-1)
    flat_expert_idx = selected_experts.to(torch.int32).reshape(-1)

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def softmax_group_limited_topk_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """DeepSeek V2 routing: softmax → group_limited_greedy/greedy → topk → scale.

    Differs from softmax_group_topk_routing (Mistral4) in several ways:
    1. Uses ``num_group`` attribute (not ``n_group``).
    2. Group score = max per group (not sum of top-2).
    3. Supports ``greedy`` method (plain topk without groups).
    4. No renormalization — just ``topk_weight * routed_scaling_factor``.
    5. Gate is ``nn.Linear`` (access weight via ``gate.weight``).

    Reference: DeepseekV2Moe.route_tokens_to_experts in transformers.

    Args:
        hidden_states: [T, H] flattened token representations
        moe_block: MoE block module (accesses moe_block.gate, .num_group,
            .topk_group, .top_k, .topk_method, .routed_scaling_factor).
            ``num_group`` defaults to 1, ``topk_method`` to ``"greedy"`` and
            ``routed_scaling_factor`` to 1.0 when the attribute is absent.

    Returns:
        router_scores: [T*K] flattened scores (float32)
        token_indices: [T*K] which token each entry belongs to (int32), sorted ascending
        expert_indices: [T*K] which expert (int32); within one token the K
            entries are in no particular order (topk is called with
            ``sorted=False``)
        router_logits: [T, E] original logits for aux loss

    Raises:
        ValueError: if ``topk_method`` is not ``"greedy"`` or
            ``"group_limited_greedy"``; or, on the group-limited path, if the
            expert count is not a positive multiple of ``num_group`` or the
            selected groups hold fewer than ``top_k`` experts.
    """
    gate = moe_block.gate
    T, _ = hidden_states.shape
    K = moe_block.top_k
    num_group = getattr(moe_block, "num_group", 1)
    num_experts = gate.weight.shape[0]
    topk_method = getattr(moe_block, "topk_method", "greedy")

    # Compute logits in float32 and softmax
    router_logits = F.linear(hidden_states.float(), gate.weight.float())  # [T, E]
    router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    if topk_method == "greedy" or num_group == 1:
        topk_weights, topk_indices = torch.topk(router_probs, k=K, dim=-1, sorted=False)
    elif topk_method == "group_limited_greedy":
        topk_group = moe_block.topk_group

        # Guard: experts must split evenly into groups. Without this check the
        # floor division below silently truncates and the failure shows up
        # later as an opaque reshape error (or a misleading top_k message).
        if num_group <= 0 or num_experts % num_group != 0:
            raise ValueError(
                f"DeepSeek V2: number of experts ({num_experts}) must be a "
                f"positive multiple of num_group ({num_group})."
            )
        group_size = num_experts // num_group

        # Guard: selected groups must contain enough experts for topk
        if topk_group * group_size < K:
            raise ValueError(
                f"DeepSeek V2: topk_group ({topk_group}) * group_size "
                f"({group_size}) = {topk_group * group_size} < top_k ({K}). "
                "Not enough experts in selected groups for topk selection."
            )

        # Group selection: pick top groups by max score per group
        group_scores = (
            router_probs.view(T, num_group, group_size).max(dim=-1).values
        )  # [T, num_group]
        group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)

        # Broadcast the per-group mask to every expert of that group.
        score_mask = (
            group_mask.unsqueeze(-1).expand(T, num_group, group_size).reshape(T, -1)
        )  # [T, E]

        # Experts outside the selected groups get score 0 before the final topk.
        tmp_scores = router_probs.masked_fill(~score_mask.bool(), 0.0)
        topk_weights, topk_indices = torch.topk(tmp_scores, k=K, dim=-1, sorted=False)
    else:
        raise ValueError(
            f"DeepSeek V2: unsupported topk_method '{topk_method}'. "
            "Expected 'greedy' or 'group_limited_greedy'."
        )

    # Scale only — no renormalization (weights won't sum to 1.0 per token).
    # This matches the reference DeepseekV2Moe.route_tokens_to_experts behavior.
    routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0)
    topk_weights = topk_weights * routed_scaling_factor

    # Flatten for moe_general_routing_inputs
    token_indices = (
        torch.arange(T, device=hidden_states.device, dtype=torch.int32)
        .unsqueeze(1)
        .expand(T, K)
    )

    flat_scores = topk_weights.to(torch.float32).reshape(-1)  # [T*K]
    flat_token_idx = token_indices.reshape(-1)  # [T*K]
    flat_expert_idx = topk_indices.to(torch.int32).reshape(-1)  # [T*K]

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def softmax_topk_wg_routing(
    hidden_states: torch.Tensor, moe_block
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """HunYuan V1 MoE routing: softmax over experts, top-k, then renormalize.

    Compared with the plain softmax_topk_routing variant, this one reads
    the router projection from ``moe_block.gate.wg.weight`` and the number
    of experts per token from ``moe_block.top_k``, and it renormalizes the
    selected weights unconditionally.

    Reference: HunYuanMoEV1Moe.route_tokens_to_experts and
    HunYuanMoEV1Gate.forward in transformers.

    Args:
        hidden_states: [T, H] flattened token representations.
        moe_block: MoE block module (accesses moe_block.gate.wg and
            moe_block.top_k).

    Returns:
        router_scores: [T*K] renormalized weights (float32).
        token_indices: [T*K] owning token of each entry (int32), ascending.
        expert_indices: [T*K] selected expert of each entry (int32).
        router_logits: [T, E] raw float32 logits (e.g. for an aux loss).
    """
    router = moe_block.gate
    num_tokens, _ = hidden_states.shape
    top_k = moe_block.top_k

    # Project with the gate's wg weight; inputs and weight are cast to float32.
    projection = router.wg
    router_logits = F.linear(hidden_states.float(), projection.weight.float())  # [T, E]
    probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # [T, E]

    # Keep the K most probable experts per token.
    picked = torch.topk(probs, top_k, dim=-1)
    weights, experts = picked.values, picked.indices  # [T, K] each

    # Unconditional renormalization; the tiny epsilon guards a zero sum.
    weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20)

    # Flatten to the [T*K] layout expected by moe_general_routing_inputs.
    row_ids = torch.arange(num_tokens, device=hidden_states.device, dtype=torch.int32)
    flat_token_idx = row_ids.unsqueeze(1).expand(num_tokens, top_k).reshape(-1)
    flat_scores = weights.to(torch.float32).reshape(-1)
    flat_expert_idx = experts.to(torch.int32).reshape(-1)

    return flat_scores, flat_token_idx, flat_expert_idx, router_logits
|
|
|
|
|
|