diff --git a/src/axolotl/integrations/kernels/constants.py b/src/axolotl/integrations/kernels/constants.py index 8002b3f79..a03761484 100644 --- a/src/axolotl/integrations/kernels/constants.py +++ b/src/axolotl/integrations/kernels/constants.py @@ -6,6 +6,11 @@ Used by both ScatterMoE and SonicMoE kernel paths. Values can be a single class name (str) or a list of class names for models with multiple MoE block types (e.g. qwen3_omni_moe has Thinker + Talker). + +Models with custom routing (see sonicmoe/routing.py for implementations): +- ernie4_5_moe: softmax→bias correction→topk (softmax_bias_topk_routing) +- deepseek_v2: softmax→group_limited_greedy (softmax_group_limited_topk_routing) +- hunyuan_v1_moe: softmax→topk via gate.wg (softmax_topk_wg_routing) """ import importlib @@ -36,11 +41,15 @@ SPARSE_MOE_BLOCK = { "glm4v_moe": "Glm4vMoeTextMoE", # sigmoid -> topk routing (no group selection) "minimax_m2": "MiniMaxM2SparseMoeBlock", - # Models below need custom routing (not yet implemented): - # "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk - # "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group) - # "hunyuan_v1_moe": "HunYuanMoEV1Moe", # softmax->topk, gate.wg (not gate.weight), scatter routing - # "gpt_oss": "GptOssMLP", # topk->softmax, transposed layout [E,H,2*I], custom GLU, expert biases + # softmax->topk, e_score_correction_bias between softmax and topk + "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", + # softmax->topk, group_limited_greedy, different attr names (num_group) + "deepseek_v2": "DeepseekV2Moe", + # softmax->topk, gate.wg (not gate.weight) + "hunyuan_v1_moe": "HunYuanMoEV1Moe", + # TODO: gpt_oss deferred — transposed weight layout [E,H,2*I], expert biases, + # and custom GLU activation require a dedicated forward path in patch.py. + # "gpt_oss": "GptOssMLP", } diff --git a/src/axolotl/integrations/kernels/sonicmoe/routing.py b/src/axolotl/integrations/kernels/sonicmoe/routing.py index fe2d12092..09bffc742 100644 --- a/src/axolotl/integrations/kernels/sonicmoe/routing.py +++ b/src/axolotl/integrations/kernels/sonicmoe/routing.py @@ -3,9 +3,11 @@ Routing functions for SonicMoE integration. Different MoE architectures use different routing strategies: - qwen3_moe / qwen2_moe / qwen3_5_moe / qwen3_vl_moe / qwen3_omni_moe: softmax -> topk (with optional renormalization) -- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None) -- glm_moe_dsa: sigmoid -> topk (with group-based expert selection) - mistral4: softmax -> group selection -> topk (with renormalization and scaling) +- glm_moe_dsa / deepseek_v3 / minimax_m2: sigmoid -> topk (with group-based expert selection) +- ernie4_5_moe: softmax -> bias correction -> topk -> gather (softmax_bias_topk_routing) +- hunyuan_v1_moe: softmax -> topk via gate.wg (softmax_topk_wg_routing) +- gpt_oss: topk -> softmax (uses fused moe_TC_softmax_topk_layer, routing_fn=None) [NOT YET SUPPORTED] Each model type maps to a (routing_fn, activation_type, router_attr) triple. When routing_fn is None, the fused moe_TC_softmax_topk_layer path is used. @@ -57,17 +59,10 @@ def get_model_moe_config(model_type: str): "minimax_m2", ): return sigmoid_topk_routing, ActivationType.SWIGLU, "gate" - # elif model_type in ("ernie4_5_moe",): - # # Softmax→topk with e_score_correction_bias applied between softmax and topk. 
-    #     return ..., ActivationType.SWIGLU, "gate"
-    # elif model_type in ("deepseek_v2",):
-    #     # Softmax→topk with group_limited_greedy. Different attr names: num_group
-    #     # (not n_group), gate is nn.Linear (not a router class).
-    #     return ..., ActivationType.SWIGLU, "gate"
-    # elif model_type in ("hunyuan_v1_moe",):
-    #     # Softmax→topk but gate structure differs: gate.wg (not gate.weight),
-    #     # top_k on block not gate, creates scatter routing matrix.
-    #     return ..., ActivationType.SWIGLU, "gate"
+    elif model_type in ("ernie4_5_moe",):
+        return softmax_bias_topk_routing, ActivationType.SWIGLU, "gate"
+    elif model_type in ("deepseek_v2",):
+        return softmax_group_limited_topk_routing, ActivationType.SWIGLU, "gate"
+    elif model_type in ("hunyuan_v1_moe",):
+        return softmax_topk_wg_routing, ActivationType.SWIGLU, "gate"
     # Fused topk -> softmax path (routing_fn=None):
     # elif model_type in ("gpt_oss",):
     #     # NOTE: gpt_oss has a router bias which moe_TC_softmax_topk_layer
@@ -94,7 +89,7 @@ def softmax_topk_routing(
         router_logits: [T, E] original logits for aux loss
     """
     gate = moe_block.gate
-    T, H = hidden_states.shape
+    T, _ = hidden_states.shape
     K = gate.top_k
 
     # Compute router logits and softmax over all experts
@@ -134,7 +129,7 @@ def softmax_group_topk_routing(
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """Mistral4-style routing: softmax -> group selection -> topk -> renorm -> scale."""
     gate = moe_block.gate
-    T, H = hidden_states.shape
+    T, _ = hidden_states.shape
     K = moe_block.top_k
     E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
     n_group = getattr(moe_block, "n_group", 1)
@@ -212,7 +207,7 @@ def sigmoid_topk_routing(
         router_logits: [T, E] original logits for aux loss
     """
     gate = moe_block.gate
-    T, H = hidden_states.shape
+    T, _ = hidden_states.shape
     K = moe_block.top_k
     E = getattr(moe_block, "n_routed_experts", gate.weight.shape[0])
     n_group = getattr(moe_block, "n_group", 1)
@@ -276,3 +271,203 @@ def sigmoid_topk_routing(
     flat_expert_idx = topk_indices.to(torch.int32).reshape(-1)  # [T*K]
 
     return flat_scores, flat_token_idx, flat_expert_idx, router_logits
+
+
+def softmax_bias_topk_routing(
+    hidden_states: torch.Tensor, moe_block
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Ernie 4.5 MoE routing: softmax → bias correction → topk → gather → renorm.
+
+    Differs from standard softmax_topk_routing in three ways:
+    1. A learned e_score_correction_bias is added to softmax probs *before* topk
+       (selection uses biased scores, but final weights use original probs).
+    2. The bias is applied via gate.moe_statics module (not a raw tensor).
+    3. Renormalization uses clamp(min=norm_min) instead of sum+epsilon.
+
+    Reference: Ernie4_5_MoeTopKRouter.forward in transformers.
+ + Args: + hidden_states: [T, H] flattened token representations + moe_block: MoE block module (accesses moe_block.gate.*) + + Returns: + router_scores: [T*K] flattened scores (float32) + token_indices: [T*K] which token each entry belongs to (int32), sorted ascending + expert_indices: [T*K] which expert (int32) + router_logits: [T, E] original logits for aux loss + """ + gate = moe_block.gate + T, _ = hidden_states.shape + K = gate.top_k + + # Compute router logits and softmax (force float32 for numerical stability) + router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E] + router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] + + # Bias-corrected scores for expert selection (via moe_statics module) + scores_for_choice = gate.moe_statics(router_probs) # [T, E] + + # Select top-k experts using biased scores + _, selected_experts = torch.topk(scores_for_choice, K, dim=-1) # [T, K] + + # Gather weights from *original* (unbiased) softmax probs + top_values = torch.gather(router_probs, dim=-1, index=selected_experts) # [T, K] + + # Renormalize with clamp(min=norm_min) instead of sum+epsilon + norm_min = getattr(gate, "norm_min", 1e-20) + top_values = top_values / torch.clamp( + top_values.sum(dim=-1, keepdim=True), min=norm_min + ) + + # Flatten for moe_general_routing_inputs + token_indices = ( + torch.arange(T, device=hidden_states.device, dtype=torch.int32) + .unsqueeze(1) + .expand(T, K) + ) + + flat_scores = top_values.to(torch.float32).reshape(-1) # [T*K] + flat_token_idx = token_indices.reshape(-1) # [T*K] + flat_expert_idx = selected_experts.to(torch.int32).reshape(-1) # [T*K] + + return flat_scores, flat_token_idx, flat_expert_idx, router_logits + + +def softmax_group_limited_topk_routing( + hidden_states: torch.Tensor, moe_block +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """DeepSeek V2 routing: softmax → group_limited_greedy/greedy → topk → scale. + + Differs from softmax_group_topk_routing (Mistral4) in several ways: + 1. Uses ``num_group`` attribute (not ``n_group``). + 2. Group score = max per group (not sum of top-2). + 3. Supports ``greedy`` method (plain topk without groups). + 4. No renormalization — just ``topk_weight * routed_scaling_factor``. + 5. Gate is ``nn.Linear`` (access weight via ``gate.weight``). + + Reference: DeepseekV2Moe.route_tokens_to_experts in transformers. 
+ + Args: + hidden_states: [T, H] flattened token representations + moe_block: MoE block module (accesses moe_block.gate, .num_group, + .topk_group, .top_k, .topk_method, .routed_scaling_factor) + + Returns: + router_scores: [T*K] flattened scores (float32) + token_indices: [T*K] which token each entry belongs to (int32), sorted ascending + expert_indices: [T*K] which expert (int32) + router_logits: [T, E] original logits for aux loss + """ + gate = moe_block.gate + T, _ = hidden_states.shape + K = moe_block.top_k + num_group = getattr(moe_block, "num_group", 1) + num_experts = gate.weight.shape[0] + topk_method = getattr(moe_block, "topk_method", "greedy") + + # Compute logits in float32 and softmax + router_logits = F.linear(hidden_states.float(), gate.weight.float()) # [T, E] + router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] + + if topk_method == "greedy" or num_group == 1: + topk_weights, topk_indices = torch.topk(router_probs, k=K, dim=-1, sorted=False) + elif topk_method == "group_limited_greedy": + # Guard: selected groups must contain enough experts for topk + group_size = num_experts // num_group + if moe_block.topk_group * group_size < K: + raise ValueError( + f"DeepSeek V2: topk_group ({moe_block.topk_group}) * group_size " + f"({group_size}) = {moe_block.topk_group * group_size} < top_k ({K}). " + f"Not enough experts in selected groups for topk selection." + ) + # Group selection: pick top groups by max score per group + group_scores = ( + router_probs.view(T, num_group, num_experts // num_group).max(dim=-1).values + ) # [T, num_group] + group_idx = torch.topk( + group_scores, k=moe_block.topk_group, dim=-1, sorted=False + )[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(T, num_group, num_experts // num_group) + .reshape(T, -1) + ) + tmp_scores = router_probs.masked_fill(~score_mask.bool(), 0.0) + topk_weights, topk_indices = torch.topk(tmp_scores, k=K, dim=-1, sorted=False) + else: + raise ValueError( + f"DeepSeek V2: unsupported topk_method '{topk_method}'. " + f"Expected 'greedy' or 'group_limited_greedy'." + ) + + # Scale only — no renormalization (weights won't sum to 1.0 per token). + # This matches the reference DeepseekV2Moe.route_tokens_to_experts behavior. + routed_scaling_factor = getattr(moe_block, "routed_scaling_factor", 1.0) + topk_weights = topk_weights * routed_scaling_factor + + # Flatten for moe_general_routing_inputs + token_indices = ( + torch.arange(T, device=hidden_states.device, dtype=torch.int32) + .unsqueeze(1) + .expand(T, K) + ) + + flat_scores = topk_weights.to(torch.float32).reshape(-1) # [T*K] + flat_token_idx = token_indices.reshape(-1) # [T*K] + flat_expert_idx = topk_indices.to(torch.int32).reshape(-1) # [T*K] + + return flat_scores, flat_token_idx, flat_expert_idx, router_logits + + +def softmax_topk_wg_routing( + hidden_states: torch.Tensor, moe_block +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """HunYuan V1 MoE routing: softmax → topk → renorm (gate weight via gate.wg). + + Differs from standard softmax_topk_routing in: + 1. Gate weight lives at ``gate.wg.weight`` (not ``gate.weight``). + 2. ``top_k`` is on ``moe_block`` (not ``gate``). + 3. Always renormalizes (no ``norm_topk_prob`` flag). + + Reference: HunYuanMoEV1Moe.route_tokens_to_experts and + HunYuanMoEV1Gate.forward in transformers. 
+ + Args: + hidden_states: [T, H] flattened token representations + moe_block: MoE block module (accesses moe_block.gate.wg, moe_block.top_k) + + Returns: + router_scores: [T*K] flattened scores (float32) + token_indices: [T*K] which token each entry belongs to (int32), sorted ascending + expert_indices: [T*K] which expert (int32) + router_logits: [T, E] original logits for aux loss + """ + gate = moe_block.gate + T, _ = hidden_states.shape + K = moe_block.top_k + + # Gate computes logits via gate.wg (nn.Linear, float32) + wg = gate.wg + router_logits = F.linear(hidden_states.float(), wg.weight.float()) # [T, E] + router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # [T, E] + + # Select top-k experts + top_values, top_indices = torch.topk(router_probs, K, dim=-1) # [T, K] each + + # Always renormalize (HunYuan V1 has no norm_topk_prob flag) + top_values = top_values / (top_values.sum(dim=-1, keepdim=True) + 1e-20) + + # Flatten for moe_general_routing_inputs + token_indices = ( + torch.arange(T, device=hidden_states.device, dtype=torch.int32) + .unsqueeze(1) + .expand(T, K) + ) + + flat_scores = top_values.to(torch.float32).reshape(-1) # [T*K] + flat_token_idx = token_indices.reshape(-1) # [T*K] + flat_expert_idx = top_indices.to(torch.int32).reshape(-1) # [T*K] + + return flat_scores, flat_token_idx, flat_expert_idx, router_logits diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 38cc198d3..756eef886 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -590,7 +590,8 @@ class PatchManager: def _apply_llama_flash_attn_patches(self, model): """Apply LLaMA-specific flash attention patches.""" if ( - self.model_config.model_type in ["llama", "llama4"] + self.model_config.model_type + in ["llama", "llama4", "ernie4_5", "ernie4_5_moe"] and not self.cfg.trust_remote_code and not self.cfg.gptq and self.cfg.flash_attention diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 78087acbc..8566af526 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -52,6 +52,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "seed_oss", "lfm2", "lfm2_moe", + "ernie4_5", + "ernie4_5_moe", "olmo", "olmo2", "olmo3", diff --git a/tests/integrations/test_sonicmoe.py b/tests/integrations/test_sonicmoe.py index e6294f564..7d26d9d93 100644 --- a/tests/integrations/test_sonicmoe.py +++ b/tests/integrations/test_sonicmoe.py @@ -426,3 +426,363 @@ class TestMiniMaxM2SigmoidRouting: expert_idx_2d = expert_idx.reshape(T, K) for t in range(T): assert 0 in expert_idx_2d[t] + + +# ============================================================================ +# Ernie 4.5 MoE: softmax -> bias correction -> topk +# ============================================================================ + + +def _make_ernie_moe_block(T=8, H=16, E=8, K=2, norm_min=1e-20): + """Create a mock Ernie 4.5 MoE block for routing tests. + + Ernie 4.5 uses a gate.moe_statics module that adds bias to softmax probs + before topk selection, then gathers from original probs. 
+ """ + bias = torch.zeros(E) + + class MockMoeStatics: + def __init__(self, bias_tensor): + self.e_score_correction_bias = bias_tensor + + def __call__(self, probs): + return probs + self.e_score_correction_bias + + gate = SimpleNamespace( + weight=torch.randn(E, H), + top_k=K, + moe_statics=MockMoeStatics(bias), + norm_min=norm_min, + ) + moe_block = SimpleNamespace(gate=gate) + return moe_block, bias, T, H, E, K + + +class TestSoftmaxBiasTopkRouting: + """Tests for Ernie 4.5 MoE routing (softmax_bias_topk_routing).""" + + def test_output_shapes(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_bias_topk_routing, + ) + + moe_block, _, T, H, E, K = _make_ernie_moe_block() + hidden = torch.randn(T, H) + + scores, token_idx, expert_idx, logits = softmax_bias_topk_routing( + hidden, moe_block + ) + + assert scores.shape == (T * K,) + assert token_idx.shape == (T * K,) + assert expert_idx.shape == (T * K,) + assert logits.shape == (T, E) + + def test_scores_are_float32(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_bias_topk_routing, + ) + + moe_block, _, T, H, E, K = _make_ernie_moe_block() + hidden = torch.randn(T, H) + + scores, _, _, _ = softmax_bias_topk_routing(hidden, moe_block) + assert scores.dtype == torch.float32 + + def test_token_indices_sorted_ascending(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_bias_topk_routing, + ) + + moe_block, _, T, H, E, K = _make_ernie_moe_block() + hidden = torch.randn(T, H) + + _, token_idx, _, _ = softmax_bias_topk_routing(hidden, moe_block) + diffs = token_idx[1:] - token_idx[:-1] + assert (diffs >= 0).all() + + def test_expert_indices_in_range(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_bias_topk_routing, + ) + + moe_block, _, T, H, E, K = _make_ernie_moe_block() + hidden = torch.randn(T, H) + + _, _, expert_idx, _ = softmax_bias_topk_routing(hidden, moe_block) + assert (expert_idx >= 0).all() + assert (expert_idx < E).all() + + def test_renormalized_scores_sum_to_one(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_bias_topk_routing, + ) + + moe_block, _, T, H, E, K = _make_ernie_moe_block() + hidden = torch.randn(T, H) + + scores, _, _, _ = softmax_bias_topk_routing(hidden, moe_block) + per_token_sums = scores.reshape(T, K).sum(dim=-1) + assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5) + + def test_bias_affects_expert_selection(self): + """Large positive bias on expert 0 should make it always selected.""" + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_bias_topk_routing, + ) + + moe_block, bias, T, H, E, K = _make_ernie_moe_block() + bias[0] = 100.0 # mutate the bias to strongly favor expert 0 + hidden = torch.randn(T, H) + + _, _, expert_idx, _ = softmax_bias_topk_routing(hidden, moe_block) + expert_idx_2d = expert_idx.reshape(T, K) + for t in range(T): + assert 0 in expert_idx_2d[t] + + +# ============================================================================ +# DeepSeek V2: softmax -> group_limited_greedy / greedy -> topk +# ============================================================================ + + +def _make_deepseek_v2_moe_block( + T=8, H=16, E=16, K=4, num_group=2, topk_group=1, topk_method="group_limited_greedy" +): + """Create a mock DeepSeek V2 MoE block for routing tests. + + DeepSeek V2 uses num_group (not n_group), gate is nn.Linear, + and supports greedy / group_limited_greedy topk methods. 
+ """ + gate = SimpleNamespace(weight=torch.randn(E, H)) + moe_block = SimpleNamespace( + gate=gate, + top_k=K, + num_group=num_group, + topk_group=topk_group, + topk_method=topk_method, + routed_scaling_factor=1.0, + ) + return moe_block, T, H, E, K + + +class TestSoftmaxGroupLimitedTopkRouting: + """Tests for DeepSeek V2 routing (softmax_group_limited_topk_routing).""" + + def test_output_shapes_group_limited(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_group_limited_topk_routing, + ) + + moe_block, T, H, E, K = _make_deepseek_v2_moe_block( + topk_method="group_limited_greedy" + ) + hidden = torch.randn(T, H) + + scores, token_idx, expert_idx, logits = softmax_group_limited_topk_routing( + hidden, moe_block + ) + + assert scores.shape == (T * K,) + assert token_idx.shape == (T * K,) + assert expert_idx.shape == (T * K,) + assert logits.shape == (T, E) + + def test_output_shapes_greedy(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_group_limited_topk_routing, + ) + + moe_block, T, H, E, K = _make_deepseek_v2_moe_block(topk_method="greedy") + hidden = torch.randn(T, H) + + scores, token_idx, expert_idx, logits = softmax_group_limited_topk_routing( + hidden, moe_block + ) + + assert scores.shape == (T * K,) + assert logits.shape == (T, E) + + def test_scores_are_float32(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_group_limited_topk_routing, + ) + + moe_block, T, H, E, K = _make_deepseek_v2_moe_block() + hidden = torch.randn(T, H) + + scores, _, _, _ = softmax_group_limited_topk_routing(hidden, moe_block) + assert scores.dtype == torch.float32 + + def test_token_indices_sorted_ascending(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_group_limited_topk_routing, + ) + + moe_block, T, H, E, K = _make_deepseek_v2_moe_block() + hidden = torch.randn(T, H) + + _, token_idx, _, _ = softmax_group_limited_topk_routing(hidden, moe_block) + diffs = token_idx[1:] - token_idx[:-1] + assert (diffs >= 0).all() + + def test_expert_indices_in_range(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_group_limited_topk_routing, + ) + + moe_block, T, H, E, K = _make_deepseek_v2_moe_block() + hidden = torch.randn(T, H) + + _, _, expert_idx, _ = softmax_group_limited_topk_routing(hidden, moe_block) + assert (expert_idx >= 0).all() + assert (expert_idx < E).all() + + def test_scaling_factor_applied(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_group_limited_topk_routing, + ) + + moe_block, T, H, E, K = _make_deepseek_v2_moe_block(topk_method="greedy") + hidden = torch.randn(T, H) + + scores_1x, _, _, _ = softmax_group_limited_topk_routing(hidden, moe_block) + + moe_block.routed_scaling_factor = 2.5 + scores_2x, _, _, _ = softmax_group_limited_topk_routing(hidden, moe_block) + + assert torch.allclose(scores_2x, scores_1x * 2.5, atol=1e-5) + + def test_group_selection_restricts_experts(self): + """With num_group=4 and topk_group=1, experts should come from selected groups.""" + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_group_limited_topk_routing, + ) + + moe_block, T, H, E, K = _make_deepseek_v2_moe_block( + E=16, K=2, num_group=4, topk_group=1, topk_method="group_limited_greedy" + ) + hidden = torch.randn(T, H) + + _, _, expert_idx, _ = softmax_group_limited_topk_routing(hidden, moe_block) + expert_idx_2d = expert_idx.reshape(T, K) + group_size = E // moe_block.num_group + for t in range(T): + 
experts = expert_idx_2d[t] + groups = experts // group_size + # All selected experts should be from the same group + assert (groups == groups[0]).all() + + def test_unsupported_topk_method_raises(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_group_limited_topk_routing, + ) + + moe_block, T, H, E, K = _make_deepseek_v2_moe_block(topk_method="invalid") + hidden = torch.randn(T, H) + + with pytest.raises(ValueError, match="unsupported topk_method"): + softmax_group_limited_topk_routing(hidden, moe_block) + + +# ============================================================================ +# HunYuan V1 MoE: softmax -> topk -> renorm (via gate.wg) +# ============================================================================ + + +def _make_hunyuan_moe_block(T=8, H=16, E=8, K=2): + """Create a mock HunYuan V1 MoE block for routing tests. + + HunYuan V1 uses gate.wg (nn.Linear-like) instead of gate.weight, + and top_k on the moe_block instead of the gate. + """ + wg = SimpleNamespace(weight=torch.randn(E, H)) + gate = SimpleNamespace(wg=wg) + moe_block = SimpleNamespace(gate=gate, top_k=K) + return moe_block, T, H, E, K + + +class TestSoftmaxTopkWgRouting: + """Tests for HunYuan V1 MoE routing (softmax_topk_wg_routing).""" + + def test_output_shapes(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_topk_wg_routing, + ) + + moe_block, T, H, E, K = _make_hunyuan_moe_block() + hidden = torch.randn(T, H) + + scores, token_idx, expert_idx, logits = softmax_topk_wg_routing( + hidden, moe_block + ) + + assert scores.shape == (T * K,) + assert token_idx.shape == (T * K,) + assert expert_idx.shape == (T * K,) + assert logits.shape == (T, E) + + def test_scores_are_float32(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_topk_wg_routing, + ) + + moe_block, T, H, E, K = _make_hunyuan_moe_block() + hidden = torch.randn(T, H) + + scores, _, _, _ = softmax_topk_wg_routing(hidden, moe_block) + assert scores.dtype == torch.float32 + + def test_token_indices_sorted_ascending(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_topk_wg_routing, + ) + + moe_block, T, H, E, K = _make_hunyuan_moe_block() + hidden = torch.randn(T, H) + + _, token_idx, _, _ = softmax_topk_wg_routing(hidden, moe_block) + diffs = token_idx[1:] - token_idx[:-1] + assert (diffs >= 0).all() + + def test_expert_indices_in_range(self): + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_topk_wg_routing, + ) + + moe_block, T, H, E, K = _make_hunyuan_moe_block() + hidden = torch.randn(T, H) + + _, _, expert_idx, _ = softmax_topk_wg_routing(hidden, moe_block) + assert (expert_idx >= 0).all() + assert (expert_idx < E).all() + + def test_renormalized_scores_sum_to_one(self): + """HunYuan V1 always renormalizes (no norm_topk_prob flag).""" + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_topk_wg_routing, + ) + + moe_block, T, H, E, K = _make_hunyuan_moe_block() + hidden = torch.randn(T, H) + + scores, _, _, _ = softmax_topk_wg_routing(hidden, moe_block) + per_token_sums = scores.reshape(T, K).sum(dim=-1) + assert torch.allclose(per_token_sums, torch.ones(T), atol=1e-5) + + def test_uses_gate_wg_weight(self): + """Verify that modifying gate.wg.weight changes the routing output.""" + from axolotl.integrations.kernels.sonicmoe.routing import ( + softmax_topk_wg_routing, + ) + + moe_block, T, H, E, K = _make_hunyuan_moe_block() + hidden = torch.randn(T, H) + + scores1, _, _, _ = 
softmax_topk_wg_routing(hidden, moe_block) + + # Change the wg weight and verify scores change + moe_block.gate.wg.weight = torch.randn(E, H) + scores2, _, _, _ = softmax_topk_wg_routing(hidden, moe_block) + + assert not torch.equal(scores1, scores2)
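
The snippet below is a minimal standalone sketch (not part of the diff) showing the shared contract of the three new routing functions against mock MoE blocks, mirroring the helper constructors in tests/integrations/test_sonicmoe.py. The mock attribute names (gate.weight, gate.moe_statics, gate.wg, top_k, num_group, topk_group, routed_scaling_factor) follow the docstrings above; running it assumes this branch is installed so the imports resolve.

from types import SimpleNamespace

import torch

from axolotl.integrations.kernels.sonicmoe.routing import (
    softmax_bias_topk_routing,
    softmax_group_limited_topk_routing,
    softmax_topk_wg_routing,
)

T, H, E, K = 4, 16, 8, 2
hidden = torch.randn(T, H)


class _Statics:
    """Stand-in for Ernie's gate.moe_statics: adds the correction bias to the probs."""

    def __init__(self, bias):
        self.e_score_correction_bias = bias

    def __call__(self, probs):
        return probs + self.e_score_correction_bias


ernie = SimpleNamespace(
    gate=SimpleNamespace(
        weight=torch.randn(E, H),
        top_k=K,
        moe_statics=_Statics(torch.zeros(E)),
        norm_min=1e-20,
    )
)
deepseek = SimpleNamespace(
    gate=SimpleNamespace(weight=torch.randn(E, H)),
    top_k=K,
    num_group=2,
    topk_group=1,
    topk_method="group_limited_greedy",
    routed_scaling_factor=1.0,
)
hunyuan = SimpleNamespace(
    gate=SimpleNamespace(wg=SimpleNamespace(weight=torch.randn(E, H))),
    top_k=K,
)

for name, routing_fn, block in [
    ("ernie4_5_moe", softmax_bias_topk_routing, ernie),
    ("deepseek_v2", softmax_group_limited_topk_routing, deepseek),
    ("hunyuan_v1_moe", softmax_topk_wg_routing, hunyuan),
]:
    scores, token_idx, expert_idx, router_logits = routing_fn(hidden, block)
    # Shared contract: three flat [T*K] tensors plus [T, E] logits for the aux loss.
    assert scores.shape == token_idx.shape == expert_idx.shape == (T * K,)
    assert router_logits.shape == (T, E)
    print(name, scores.dtype, router_logits.shape)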