diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index ac9d19895..c2a33e455 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -294,9 +294,12 @@ def moe_ffn_forward_grouped( w2_t = w2[active_idx].transpose(-2, -1) routed_in = routed_input.to(expert_dtype) - gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets) - gate_out = torch.ops.aten.silu_(gate_out) - gate_out.mul_(torch._grouped_mm(routed_in, w_up_t, offs=offsets)) + gate_up_out = torch._grouped_mm( + routed_in, torch.cat((w_gate_t, w_up_t), dim=-1), offs=offsets + ) + inter_dim = w_gate_t.shape[-1] + gate_out = torch.ops.aten.silu_(gate_up_out[..., :inter_dim]) + gate_out.mul_(gate_up_out[..., inter_dim:]) down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets) weights = scores_sorted.unsqueeze(-1).to(expert_dtype)