This commit is contained in:
Dan Saunders
2025-09-15 19:04:58 -04:00
parent cfefad1eea
commit 5b19a1ea9c

View File

@@ -122,29 +122,28 @@ def moe_ffn_forward_stub(
) )
W1 = experts_module._stacked_w1 W1 = experts_module._stacked_w1
W3 = experts_module._stacked_w3 W3 = experts_module._stacked_w3
# compute gathered inputs X_g according to gather_idx via matmul_ogs gather # Fused up+gate: single matmul on concatenated weights [E, H, 2I]
# First matmul for w1: gather happens inside kernel using gather_indx W13 = getattr(experts_module, "_stacked_w13", None)
Y1 = handles.matmul_ogs.matmul_ogs( if (
W13 is None
or W13.device != dev
or W13.dtype != dt
or W13.shape[-1] != (W1.shape[-1] + W3.shape[-1])
):
W13 = torch.cat([W1, W3], dim=-1).contiguous()
experts_module._stacked_w13 = W13
Y13 = handles.matmul_ogs.matmul_ogs(
flat, flat,
W1, W13,
None, None,
routing_data=routing_data, routing_data=routing_data,
gather_indx=gather_idx, gather_indx=gather_idx,
scatter_indx=None, scatter_indx=None,
precision_config=handles.matmul_ogs.PrecisionConfig(), precision_config=handles.matmul_ogs.PrecisionConfig(),
) )
# Second matmul for w3 on the same gathered order # Use kernels hub SwiGLU for optimal MoE launch
Y3 = handles.matmul_ogs.matmul_ogs( sw_pc = handles.swiglu.PrecisionConfig(limit=1.0)
flat, Hidden = handles.swiglu.swiglu(Y13, 1.0, sw_pc, routing_data)
W3,
None,
routing_data=routing_data,
gather_indx=gather_idx,
scatter_indx=None,
precision_config=handles.matmul_ogs.PrecisionConfig(),
)
# SwiGLU: silu(Y1) * Y3
Hidden = F.silu(Y1) * Y3
# Down projection weights [E, inter, hidden] # Down projection weights [E, inter, hidden]
W2 = experts_module._stacked_w2 W2 = experts_module._stacked_w2
# Down matmul with fused scatter back using scatter_indx # Down matmul with fused scatter back using scatter_indx