yet another refactor

This commit is contained in:
Dan Saunders
2025-09-18 13:03:28 -04:00
parent 7500641601
commit efcd032fce

View File

@@ -271,8 +271,8 @@ def moe_ffn_forward_grouped(
sample_mod = expert_impls[0]
storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
w_gate = storage.gate
w_up = storage.up
w2 = storage.down
w_gate_up = storage.fused_gate_up
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
router_logits = gate_linear(x_flat.to(routing_dtype))
@@ -296,33 +296,30 @@ def moe_ffn_forward_grouped(
token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor").contiguous()
scores_sorted = topk_weight.reshape(-1).index_select(0, perm)
routed_input = x_flat.index_select(0, token_indices_sorted).contiguous()
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
routed_input = torch.gather(x_flat, 0, gather_index)
active_idx = torch.nonzero(assignments, as_tuple=False).squeeze(-1).contiguous()
if active_idx.numel() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
counts_active = assignments[active_idx]
counts_active_i32 = counts_active.to(device=device, dtype=torch.int32)
offsets = torch.cumsum(counts_active_i32, dim=0)
counts_i32 = assignments.to(device=device, dtype=torch.int32)
offsets = torch.cumsum(counts_i32, dim=0)
if offsets[-1].item() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
w_gate_up_t = w_gate_up.index_select(0, active_idx).transpose(-2, -1)
w2_t = w2.index_select(0, active_idx).transpose(-2, -1)
mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype
routed_in = routed_input.to(mm_dtype)
w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)
w_up_t = w_up.transpose(-2, -1).to(mm_dtype)
w2_t = w2.transpose(-2, -1).to(mm_dtype)
routed_in = routed_input.to(expert_dtype)
gate_up_out = torch._grouped_mm(routed_in, w_gate_up_t, offs=offsets)
inter_dim = w_gate.shape[1]
gate_out = torch.ops.aten.silu_(gate_up_out[..., :inter_dim])
gate_out.mul_(gate_up_out[..., inter_dim:])
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets)
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
torch.ops.aten.silu_(gate_out)
up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
gate_out.mul_(up_out)
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets).to(expert_dtype)
weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
down_out.mul_(weights)
combined = torch.zeros_like(x_flat)
combined.index_add_(0, token_indices_sorted, down_out)
combined.scatter_add_(0, gather_index, down_out)
return combined.view(bsz, seqlen, hdim), router_logits