diff --git a/src/axolotl/kernels/moe/hf_triton.py b/src/axolotl/kernels/moe/hf_triton.py index ec42e81e4..1fb3d084d 100644 --- a/src/axolotl/kernels/moe/hf_triton.py +++ b/src/axolotl/kernels/moe/hf_triton.py @@ -122,29 +122,28 @@ def moe_ffn_forward_stub( ) W1 = experts_module._stacked_w1 W3 = experts_module._stacked_w3 - # compute gathered inputs X_g according to gather_idx via matmul_ogs gather - # First matmul for w1: gather happens inside kernel using gather_indx - Y1 = handles.matmul_ogs.matmul_ogs( + # Fused up+gate: single matmul on concatenated weights [E, H, 2I] + W13 = getattr(experts_module, "_stacked_w13", None) + if ( + W13 is None + or W13.device != dev + or W13.dtype != dt + or W13.shape[-1] != (W1.shape[-1] + W3.shape[-1]) + ): + W13 = torch.cat([W1, W3], dim=-1).contiguous() + experts_module._stacked_w13 = W13 + Y13 = handles.matmul_ogs.matmul_ogs( flat, - W1, + W13, None, routing_data=routing_data, gather_indx=gather_idx, scatter_indx=None, precision_config=handles.matmul_ogs.PrecisionConfig(), ) - # Second matmul for w3 on the same gathered order - Y3 = handles.matmul_ogs.matmul_ogs( - flat, - W3, - None, - routing_data=routing_data, - gather_indx=gather_idx, - scatter_indx=None, - precision_config=handles.matmul_ogs.PrecisionConfig(), - ) - # SwiGLU: silu(Y1) * Y3 - Hidden = F.silu(Y1) * Y3 + # Use kernels hub SwiGLU for optimal MoE launch + sw_pc = handles.swiglu.PrecisionConfig(limit=1.0) + Hidden = handles.swiglu.swiglu(Y13, 1.0, sw_pc, routing_data) # Down projection weights [E, inter, hidden] W2 = experts_module._stacked_w2 # Down matmul with fused scatter back using scatter_indx