yet another refactor
This commit is contained in:
@@ -271,8 +271,8 @@ def moe_ffn_forward_grouped(
|
|||||||
sample_mod = expert_impls[0]
|
sample_mod = expert_impls[0]
|
||||||
storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
|
storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
|
||||||
w_gate = storage.gate
|
w_gate = storage.gate
|
||||||
|
w_up = storage.up
|
||||||
w2 = storage.down
|
w2 = storage.down
|
||||||
w_gate_up = storage.fused_gate_up
|
|
||||||
|
|
||||||
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
|
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
|
||||||
router_logits = gate_linear(x_flat.to(routing_dtype))
|
router_logits = gate_linear(x_flat.to(routing_dtype))
|
||||||
@@ -296,33 +296,30 @@ def moe_ffn_forward_grouped(
|
|||||||
token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor").contiguous()
|
token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor").contiguous()
|
||||||
scores_sorted = topk_weight.reshape(-1).index_select(0, perm)
|
scores_sorted = topk_weight.reshape(-1).index_select(0, perm)
|
||||||
|
|
||||||
routed_input = x_flat.index_select(0, token_indices_sorted).contiguous()
|
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
|
||||||
|
routed_input = torch.gather(x_flat, 0, gather_index)
|
||||||
|
|
||||||
active_idx = torch.nonzero(assignments, as_tuple=False).squeeze(-1).contiguous()
|
counts_i32 = assignments.to(device=device, dtype=torch.int32)
|
||||||
if active_idx.numel() == 0:
|
offsets = torch.cumsum(counts_i32, dim=0)
|
||||||
zero = torch.zeros_like(x_flat)
|
|
||||||
return zero.view(bsz, seqlen, hdim), router_logits
|
|
||||||
|
|
||||||
counts_active = assignments[active_idx]
|
|
||||||
counts_active_i32 = counts_active.to(device=device, dtype=torch.int32)
|
|
||||||
offsets = torch.cumsum(counts_active_i32, dim=0)
|
|
||||||
if offsets[-1].item() == 0:
|
if offsets[-1].item() == 0:
|
||||||
zero = torch.zeros_like(x_flat)
|
zero = torch.zeros_like(x_flat)
|
||||||
return zero.view(bsz, seqlen, hdim), router_logits
|
return zero.view(bsz, seqlen, hdim), router_logits
|
||||||
|
|
||||||
w_gate_up_t = w_gate_up.index_select(0, active_idx).transpose(-2, -1)
|
mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype
|
||||||
w2_t = w2.index_select(0, active_idx).transpose(-2, -1)
|
routed_in = routed_input.to(mm_dtype)
|
||||||
|
w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)
|
||||||
|
w_up_t = w_up.transpose(-2, -1).to(mm_dtype)
|
||||||
|
w2_t = w2.transpose(-2, -1).to(mm_dtype)
|
||||||
|
|
||||||
routed_in = routed_input.to(expert_dtype)
|
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
|
||||||
gate_up_out = torch._grouped_mm(routed_in, w_gate_up_t, offs=offsets)
|
torch.ops.aten.silu_(gate_out)
|
||||||
inter_dim = w_gate.shape[1]
|
up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
|
||||||
gate_out = torch.ops.aten.silu_(gate_up_out[..., :inter_dim])
|
gate_out.mul_(up_out)
|
||||||
gate_out.mul_(gate_up_out[..., inter_dim:])
|
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets).to(expert_dtype)
|
||||||
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets)
|
|
||||||
|
|
||||||
weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
|
weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
|
||||||
down_out.mul_(weights)
|
down_out.mul_(weights)
|
||||||
|
|
||||||
combined = torch.zeros_like(x_flat)
|
combined = torch.zeros_like(x_flat)
|
||||||
combined.index_add_(0, token_indices_sorted, down_out)
|
combined.scatter_add_(0, gather_index, down_out)
|
||||||
return combined.view(bsz, seqlen, hdim), router_logits
|
return combined.view(bsz, seqlen, hdim), router_logits
|
||||||
|
|||||||
Reference in New Issue
Block a user