This commit is contained in:
Dan Saunders
2025-09-15 19:04:58 -04:00
parent cfefad1eea
commit 5b19a1ea9c

View File

@@ -122,29 +122,28 @@ def moe_ffn_forward_stub(
)
W1 = experts_module._stacked_w1
W3 = experts_module._stacked_w3
# compute gathered inputs X_g according to gather_idx via matmul_ogs gather
# First matmul for w1: gather happens inside kernel using gather_indx
Y1 = handles.matmul_ogs.matmul_ogs(
# Fused up+gate: single matmul on concatenated weights [E, H, 2I]
W13 = getattr(experts_module, "_stacked_w13", None)
if (
W13 is None
or W13.device != dev
or W13.dtype != dt
or W13.shape[-1] != (W1.shape[-1] + W3.shape[-1])
):
W13 = torch.cat([W1, W3], dim=-1).contiguous()
experts_module._stacked_w13 = W13
Y13 = handles.matmul_ogs.matmul_ogs(
flat,
W1,
W13,
None,
routing_data=routing_data,
gather_indx=gather_idx,
scatter_indx=None,
precision_config=handles.matmul_ogs.PrecisionConfig(),
)
# Second matmul for w3 on the same gathered order
Y3 = handles.matmul_ogs.matmul_ogs(
flat,
W3,
None,
routing_data=routing_data,
gather_indx=gather_idx,
scatter_indx=None,
precision_config=handles.matmul_ogs.PrecisionConfig(),
)
# SwiGLU: silu(Y1) * Y3
Hidden = F.silu(Y1) * Y3
# Use kernels hub SwiGLU for optimal MoE launch
sw_pc = handles.swiglu.PrecisionConfig(limit=1.0)
Hidden = handles.swiglu.swiglu(Y13, 1.0, sw_pc, routing_data)
# Down projection weights [E, inter, hidden]
W2 = experts_module._stacked_w2
# Down matmul with fused scatter back using scatter_indx