refactor
This commit is contained in:
@@ -114,19 +114,26 @@ def moe_ffn_forward_grouped(
|
|||||||
zero = torch.zeros_like(x_flat)
|
zero = torch.zeros_like(x_flat)
|
||||||
return zero.view(bsz, seqlen, hdim), router_logits
|
return zero.view(bsz, seqlen, hdim), router_logits
|
||||||
|
|
||||||
assignments = torch.bincount(flat_idx, minlength=num_experts)
|
perm = torch.argsort(flat_idx, stable=True)
|
||||||
|
sorted_experts = flat_idx[perm]
|
||||||
|
assignments = torch.bincount(sorted_experts, minlength=num_experts)
|
||||||
if assignments.sum() == 0:
|
if assignments.sum() == 0:
|
||||||
zero = torch.zeros_like(x_flat)
|
zero = torch.zeros_like(x_flat)
|
||||||
return zero.view(bsz, seqlen, hdim), router_logits
|
return zero.view(bsz, seqlen, hdim), router_logits
|
||||||
|
|
||||||
perm = torch.argsort(flat_idx, stable=True)
|
|
||||||
token_indices_sorted = perm // top_k
|
token_indices_sorted = perm // top_k
|
||||||
scores_sorted = topk_weight.view(-1)[perm]
|
scores_sorted = topk_weight.view(-1)[perm]
|
||||||
|
|
||||||
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
|
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
|
||||||
routed_input = torch.gather(x_flat, 0, gather_index).contiguous()
|
routed_input = torch.gather(x_flat, 0, gather_index).contiguous()
|
||||||
|
|
||||||
offsets = torch.cumsum(assignments.to(device=device, dtype=torch.int32), dim=0)
|
active_idx = torch.nonzero(assignments, as_tuple=False).squeeze(-1)
|
||||||
|
if active_idx.numel() == 0:
|
||||||
|
zero = torch.zeros_like(x_flat)
|
||||||
|
return zero.view(bsz, seqlen, hdim), router_logits
|
||||||
|
|
||||||
|
counts_active = assignments[active_idx]
|
||||||
|
offsets = torch.cumsum(counts_active.to(device=device, dtype=torch.int32), dim=0)
|
||||||
if offsets[-1].item() == 0:
|
if offsets[-1].item() == 0:
|
||||||
zero = torch.zeros_like(x_flat)
|
zero = torch.zeros_like(x_flat)
|
||||||
return zero.view(bsz, seqlen, hdim), router_logits
|
return zero.view(bsz, seqlen, hdim), router_logits
|
||||||
@@ -135,9 +142,9 @@ def moe_ffn_forward_grouped(
|
|||||||
w_gate = w13[..., :mid]
|
w_gate = w13[..., :mid]
|
||||||
w_up = w13[..., mid:]
|
w_up = w13[..., mid:]
|
||||||
|
|
||||||
w_gate_t = w_gate.transpose(-2, -1).contiguous()
|
w_gate_t = w_gate[active_idx].transpose(-2, -1).contiguous()
|
||||||
w_up_t = w_up.transpose(-2, -1).contiguous()
|
w_up_t = w_up[active_idx].transpose(-2, -1).contiguous()
|
||||||
w2_t = w2.transpose(-2, -1).contiguous()
|
w2_t = w2[active_idx].transpose(-2, -1).contiguous()
|
||||||
|
|
||||||
routed_in = routed_input.to(expert_dtype)
|
routed_in = routed_input.to(expert_dtype)
|
||||||
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
|
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
|
||||||
|
|||||||
Reference in New Issue
Block a user