diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index c6d671241..4e2c6c196 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -295,13 +295,13 @@ def moe_ffn_forward_grouped( routed_in = routed_input.to(expert_dtype) gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets) - up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets) - activated = F.silu(gate_out) * up_out - down_out = torch._grouped_mm(activated, w2_t, offs=offsets) + gate_out = torch.ops.aten.silu_(gate_out) + gate_out.mul_(torch._grouped_mm(routed_in, w_up_t, offs=offsets)) + down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets) - weights_fp32 = scores_sorted.unsqueeze(-1).to(torch.float32) - weighted = (down_out.to(torch.float32) * weights_fp32).to(expert_dtype) + weights = scores_sorted.unsqueeze(-1).to(expert_dtype) + down_out.mul_(weights) combined = torch.zeros_like(x_flat) - combined.scatter_add_(0, gather_index, weighted) + combined.scatter_add_(0, gather_index, down_out) return combined.view(bsz, seqlen, hdim), router_logits