improve
This commit is contained in:
@@ -122,29 +122,28 @@ def moe_ffn_forward_stub(
|
|||||||
)
|
)
|
||||||
W1 = experts_module._stacked_w1
|
W1 = experts_module._stacked_w1
|
||||||
W3 = experts_module._stacked_w3
|
W3 = experts_module._stacked_w3
|
||||||
# compute gathered inputs X_g according to gather_idx via matmul_ogs gather
|
# Fused up+gate: single matmul on concatenated weights [E, H, 2I]
|
||||||
# First matmul for w1: gather happens inside kernel using gather_indx
|
W13 = getattr(experts_module, "_stacked_w13", None)
|
||||||
Y1 = handles.matmul_ogs.matmul_ogs(
|
if (
|
||||||
|
W13 is None
|
||||||
|
or W13.device != dev
|
||||||
|
or W13.dtype != dt
|
||||||
|
or W13.shape[-1] != (W1.shape[-1] + W3.shape[-1])
|
||||||
|
):
|
||||||
|
W13 = torch.cat([W1, W3], dim=-1).contiguous()
|
||||||
|
experts_module._stacked_w13 = W13
|
||||||
|
Y13 = handles.matmul_ogs.matmul_ogs(
|
||||||
flat,
|
flat,
|
||||||
W1,
|
W13,
|
||||||
None,
|
None,
|
||||||
routing_data=routing_data,
|
routing_data=routing_data,
|
||||||
gather_indx=gather_idx,
|
gather_indx=gather_idx,
|
||||||
scatter_indx=None,
|
scatter_indx=None,
|
||||||
precision_config=handles.matmul_ogs.PrecisionConfig(),
|
precision_config=handles.matmul_ogs.PrecisionConfig(),
|
||||||
)
|
)
|
||||||
# Second matmul for w3 on the same gathered order
|
# Use kernels hub SwiGLU for optimal MoE launch
|
||||||
Y3 = handles.matmul_ogs.matmul_ogs(
|
sw_pc = handles.swiglu.PrecisionConfig(limit=1.0)
|
||||||
flat,
|
Hidden = handles.swiglu.swiglu(Y13, 1.0, sw_pc, routing_data)
|
||||||
W3,
|
|
||||||
None,
|
|
||||||
routing_data=routing_data,
|
|
||||||
gather_indx=gather_idx,
|
|
||||||
scatter_indx=None,
|
|
||||||
precision_config=handles.matmul_ogs.PrecisionConfig(),
|
|
||||||
)
|
|
||||||
# SwiGLU: silu(Y1) * Y3
|
|
||||||
Hidden = F.silu(Y1) * Y3
|
|
||||||
# Down projection weights [E, inter, hidden]
|
# Down projection weights [E, inter, hidden]
|
||||||
W2 = experts_module._stacked_w2
|
W2 = experts_module._stacked_w2
|
||||||
# Down matmul with fused scatter back using scatter_indx
|
# Down matmul with fused scatter back using scatter_indx
|
||||||
|
|||||||
Reference in New Issue
Block a user