yet another refactor
This commit is contained in:
@@ -271,8 +271,8 @@ def moe_ffn_forward_grouped(
|
||||
sample_mod = expert_impls[0]
|
||||
storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
|
||||
w_gate = storage.gate
|
||||
w_up = storage.up
|
||||
w2 = storage.down
|
||||
w_gate_up = storage.fused_gate_up
|
||||
|
||||
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
|
||||
router_logits = gate_linear(x_flat.to(routing_dtype))
|
||||
@@ -296,33 +296,30 @@ def moe_ffn_forward_grouped(
|
||||
token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor").contiguous()
|
||||
scores_sorted = topk_weight.reshape(-1).index_select(0, perm)
|
||||
|
||||
routed_input = x_flat.index_select(0, token_indices_sorted).contiguous()
|
||||
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
|
||||
routed_input = torch.gather(x_flat, 0, gather_index)
|
||||
|
||||
active_idx = torch.nonzero(assignments, as_tuple=False).squeeze(-1).contiguous()
|
||||
if active_idx.numel() == 0:
|
||||
zero = torch.zeros_like(x_flat)
|
||||
return zero.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
counts_active = assignments[active_idx]
|
||||
counts_active_i32 = counts_active.to(device=device, dtype=torch.int32)
|
||||
offsets = torch.cumsum(counts_active_i32, dim=0)
|
||||
counts_i32 = assignments.to(device=device, dtype=torch.int32)
|
||||
offsets = torch.cumsum(counts_i32, dim=0)
|
||||
if offsets[-1].item() == 0:
|
||||
zero = torch.zeros_like(x_flat)
|
||||
return zero.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
w_gate_up_t = w_gate_up.index_select(0, active_idx).transpose(-2, -1)
|
||||
w2_t = w2.index_select(0, active_idx).transpose(-2, -1)
|
||||
mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype
|
||||
routed_in = routed_input.to(mm_dtype)
|
||||
w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)
|
||||
w_up_t = w_up.transpose(-2, -1).to(mm_dtype)
|
||||
w2_t = w2.transpose(-2, -1).to(mm_dtype)
|
||||
|
||||
routed_in = routed_input.to(expert_dtype)
|
||||
gate_up_out = torch._grouped_mm(routed_in, w_gate_up_t, offs=offsets)
|
||||
inter_dim = w_gate.shape[1]
|
||||
gate_out = torch.ops.aten.silu_(gate_up_out[..., :inter_dim])
|
||||
gate_out.mul_(gate_up_out[..., inter_dim:])
|
||||
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets)
|
||||
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
|
||||
torch.ops.aten.silu_(gate_out)
|
||||
up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
|
||||
gate_out.mul_(up_out)
|
||||
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets).to(expert_dtype)
|
||||
|
||||
weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
|
||||
down_out.mul_(weights)
|
||||
|
||||
combined = torch.zeros_like(x_flat)
|
||||
combined.index_add_(0, token_indices_sorted, down_out)
|
||||
combined.scatter_add_(0, gather_index, down_out)
|
||||
return combined.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
Reference in New Issue
Block a user