yet another refactor

This commit is contained in:
Dan Saunders
2025-09-18 13:03:28 -04:00
parent 7500641601
commit efcd032fce

View File

@@ -271,8 +271,8 @@ def moe_ffn_forward_grouped(
sample_mod = expert_impls[0] sample_mod = expert_impls[0]
storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod) storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
w_gate = storage.gate w_gate = storage.gate
w_up = storage.up
w2 = storage.down w2 = storage.down
w_gate_up = storage.fused_gate_up
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype) x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
router_logits = gate_linear(x_flat.to(routing_dtype)) router_logits = gate_linear(x_flat.to(routing_dtype))
@@ -296,33 +296,30 @@ def moe_ffn_forward_grouped(
token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor").contiguous() token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor").contiguous()
scores_sorted = topk_weight.reshape(-1).index_select(0, perm) scores_sorted = topk_weight.reshape(-1).index_select(0, perm)
routed_input = x_flat.index_select(0, token_indices_sorted).contiguous() gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
routed_input = torch.gather(x_flat, 0, gather_index)
active_idx = torch.nonzero(assignments, as_tuple=False).squeeze(-1).contiguous() counts_i32 = assignments.to(device=device, dtype=torch.int32)
if active_idx.numel() == 0: offsets = torch.cumsum(counts_i32, dim=0)
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
counts_active = assignments[active_idx]
counts_active_i32 = counts_active.to(device=device, dtype=torch.int32)
offsets = torch.cumsum(counts_active_i32, dim=0)
if offsets[-1].item() == 0: if offsets[-1].item() == 0:
zero = torch.zeros_like(x_flat) zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits return zero.view(bsz, seqlen, hdim), router_logits
w_gate_up_t = w_gate_up.index_select(0, active_idx).transpose(-2, -1) mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype
w2_t = w2.index_select(0, active_idx).transpose(-2, -1) routed_in = routed_input.to(mm_dtype)
w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)
w_up_t = w_up.transpose(-2, -1).to(mm_dtype)
w2_t = w2.transpose(-2, -1).to(mm_dtype)
routed_in = routed_input.to(expert_dtype) gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
gate_up_out = torch._grouped_mm(routed_in, w_gate_up_t, offs=offsets) torch.ops.aten.silu_(gate_out)
inter_dim = w_gate.shape[1] up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
gate_out = torch.ops.aten.silu_(gate_up_out[..., :inter_dim]) gate_out.mul_(up_out)
gate_out.mul_(gate_up_out[..., inter_dim:]) down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets).to(expert_dtype)
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets)
weights = scores_sorted.unsqueeze(-1).to(expert_dtype) weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
down_out.mul_(weights) down_out.mul_(weights)
combined = torch.zeros_like(x_flat) combined = torch.zeros_like(x_flat)
combined.index_add_(0, token_indices_sorted, down_out) combined.scatter_add_(0, gather_index, down_out)
return combined.view(bsz, seqlen, hdim), router_logits return combined.view(bsz, seqlen, hdim), router_logits