This commit is contained in:
Dan Saunders
2025-09-18 11:23:15 -04:00
parent 01b6792c2e
commit 7d867de9b2

View File

@@ -114,19 +114,26 @@ def moe_ffn_forward_grouped(
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
assignments = torch.bincount(flat_idx, minlength=num_experts)
perm = torch.argsort(flat_idx, stable=True)
sorted_experts = flat_idx[perm]
assignments = torch.bincount(sorted_experts, minlength=num_experts)
if assignments.sum() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
perm = torch.argsort(flat_idx, stable=True)
token_indices_sorted = perm // top_k
scores_sorted = topk_weight.view(-1)[perm]
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
routed_input = torch.gather(x_flat, 0, gather_index).contiguous()
offsets = torch.cumsum(assignments.to(device=device, dtype=torch.int32), dim=0)
active_idx = torch.nonzero(assignments, as_tuple=False).squeeze(-1)
if active_idx.numel() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
counts_active = assignments[active_idx]
offsets = torch.cumsum(counts_active.to(device=device, dtype=torch.int32), dim=0)
if offsets[-1].item() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
@@ -135,9 +142,9 @@ def moe_ffn_forward_grouped(
w_gate = w13[..., :mid]
w_up = w13[..., mid:]
w_gate_t = w_gate.transpose(-2, -1).contiguous()
w_up_t = w_up.transpose(-2, -1).contiguous()
w2_t = w2.transpose(-2, -1).contiguous()
w_gate_t = w_gate[active_idx].transpose(-2, -1).contiguous()
w_up_t = w_up[active_idx].transpose(-2, -1).contiguous()
w2_t = w2[active_idx].transpose(-2, -1).contiguous()
routed_in = routed_input.to(expert_dtype)
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)