From 7d867de9b23e7b9b58d2b2890f4f932aa6e362cd Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 18 Sep 2025 11:23:15 -0400 Subject: [PATCH] refactor --- src/axolotl/kernels/moe/torch_grouped.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 99d861551..9201d17a3 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -114,19 +114,26 @@ def moe_ffn_forward_grouped( zero = torch.zeros_like(x_flat) return zero.view(bsz, seqlen, hdim), router_logits - assignments = torch.bincount(flat_idx, minlength=num_experts) + perm = torch.argsort(flat_idx, stable=True) + sorted_experts = flat_idx[perm] + assignments = torch.bincount(sorted_experts, minlength=num_experts) if assignments.sum() == 0: zero = torch.zeros_like(x_flat) return zero.view(bsz, seqlen, hdim), router_logits - perm = torch.argsort(flat_idx, stable=True) token_indices_sorted = perm // top_k scores_sorted = topk_weight.view(-1)[perm] gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim) routed_input = torch.gather(x_flat, 0, gather_index).contiguous() - offsets = torch.cumsum(assignments.to(device=device, dtype=torch.int32), dim=0) + active_idx = torch.nonzero(assignments, as_tuple=False).squeeze(-1) + if active_idx.numel() == 0: + zero = torch.zeros_like(x_flat) + return zero.view(bsz, seqlen, hdim), router_logits + + counts_active = assignments[active_idx] + offsets = torch.cumsum(counts_active.to(device=device, dtype=torch.int32), dim=0) if offsets[-1].item() == 0: zero = torch.zeros_like(x_flat) return zero.view(bsz, seqlen, hdim), router_logits @@ -135,9 +142,9 @@ def moe_ffn_forward_grouped( w_gate = w13[..., :mid] w_up = w13[..., mid:] - w_gate_t = w_gate.transpose(-2, -1).contiguous() - w_up_t = w_up.transpose(-2, -1).contiguous() - w2_t = w2.transpose(-2, -1).contiguous() + w_gate_t = w_gate[active_idx].transpose(-2, -1).contiguous() + w_up_t = w_up[active_idx].transpose(-2, -1).contiguous() + w2_t = w2[active_idx].transpose(-2, -1).contiguous() routed_in = routed_input.to(expert_dtype) gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)