refactor
This commit is contained in:
@@ -114,19 +114,26 @@ def moe_ffn_forward_grouped(
|
||||
zero = torch.zeros_like(x_flat)
|
||||
return zero.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
assignments = torch.bincount(flat_idx, minlength=num_experts)
|
||||
perm = torch.argsort(flat_idx, stable=True)
|
||||
sorted_experts = flat_idx[perm]
|
||||
assignments = torch.bincount(sorted_experts, minlength=num_experts)
|
||||
if assignments.sum() == 0:
|
||||
zero = torch.zeros_like(x_flat)
|
||||
return zero.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
perm = torch.argsort(flat_idx, stable=True)
|
||||
token_indices_sorted = perm // top_k
|
||||
scores_sorted = topk_weight.view(-1)[perm]
|
||||
|
||||
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
|
||||
routed_input = torch.gather(x_flat, 0, gather_index).contiguous()
|
||||
|
||||
offsets = torch.cumsum(assignments.to(device=device, dtype=torch.int32), dim=0)
|
||||
active_idx = torch.nonzero(assignments, as_tuple=False).squeeze(-1)
|
||||
if active_idx.numel() == 0:
|
||||
zero = torch.zeros_like(x_flat)
|
||||
return zero.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
counts_active = assignments[active_idx]
|
||||
offsets = torch.cumsum(counts_active.to(device=device, dtype=torch.int32), dim=0)
|
||||
if offsets[-1].item() == 0:
|
||||
zero = torch.zeros_like(x_flat)
|
||||
return zero.view(bsz, seqlen, hdim), router_logits
|
||||
@@ -135,9 +142,9 @@ def moe_ffn_forward_grouped(
|
||||
w_gate = w13[..., :mid]
|
||||
w_up = w13[..., mid:]
|
||||
|
||||
w_gate_t = w_gate.transpose(-2, -1).contiguous()
|
||||
w_up_t = w_up.transpose(-2, -1).contiguous()
|
||||
w2_t = w2.transpose(-2, -1).contiguous()
|
||||
w_gate_t = w_gate[active_idx].transpose(-2, -1).contiguous()
|
||||
w_up_t = w_up[active_idx].transpose(-2, -1).contiguous()
|
||||
w2_t = w2[active_idx].transpose(-2, -1).contiguous()
|
||||
|
||||
routed_in = routed_input.to(expert_dtype)
|
||||
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
|
||||
|
||||
Reference in New Issue
Block a user