improve
This commit is contained in:
@@ -122,29 +122,28 @@ def moe_ffn_forward_stub(
|
||||
)
|
||||
W1 = experts_module._stacked_w1
|
||||
W3 = experts_module._stacked_w3
|
||||
# compute gathered inputs X_g according to gather_idx via matmul_ogs gather
|
||||
# First matmul for w1: gather happens inside kernel using gather_indx
|
||||
Y1 = handles.matmul_ogs.matmul_ogs(
|
||||
# Fused up+gate: single matmul on concatenated weights [E, H, 2I]
|
||||
W13 = getattr(experts_module, "_stacked_w13", None)
|
||||
if (
|
||||
W13 is None
|
||||
or W13.device != dev
|
||||
or W13.dtype != dt
|
||||
or W13.shape[-1] != (W1.shape[-1] + W3.shape[-1])
|
||||
):
|
||||
W13 = torch.cat([W1, W3], dim=-1).contiguous()
|
||||
experts_module._stacked_w13 = W13
|
||||
Y13 = handles.matmul_ogs.matmul_ogs(
|
||||
flat,
|
||||
W1,
|
||||
W13,
|
||||
None,
|
||||
routing_data=routing_data,
|
||||
gather_indx=gather_idx,
|
||||
scatter_indx=None,
|
||||
precision_config=handles.matmul_ogs.PrecisionConfig(),
|
||||
)
|
||||
# Second matmul for w3 on the same gathered order
|
||||
Y3 = handles.matmul_ogs.matmul_ogs(
|
||||
flat,
|
||||
W3,
|
||||
None,
|
||||
routing_data=routing_data,
|
||||
gather_indx=gather_idx,
|
||||
scatter_indx=None,
|
||||
precision_config=handles.matmul_ogs.PrecisionConfig(),
|
||||
)
|
||||
# SwiGLU: silu(Y1) * Y3
|
||||
Hidden = F.silu(Y1) * Y3
|
||||
# Use kernels hub SwiGLU for optimal MoE launch
|
||||
sw_pc = handles.swiglu.PrecisionConfig(limit=1.0)
|
||||
Hidden = handles.swiglu.swiglu(Y13, 1.0, sw_pc, routing_data)
|
||||
# Down projection weights [E, inter, hidden]
|
||||
W2 = experts_module._stacked_w2
|
||||
# Down matmul with fused scatter back using scatter_indx
|
||||
|
||||
Reference in New Issue
Block a user