From 0f8b92139910795d0176fd8c6f90da23456683cf Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 19 Sep 2025 12:47:53 -0400 Subject: [PATCH] contig --- src/axolotl/kernels/moe/torch_grouped.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index baceb31d4..3ea7b9045 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -344,10 +344,15 @@ def moe_ffn_forward_grouped( w_up_t = w_up.transpose(-2, -1).to(mm_dtype) w2_t = w2.transpose(-2, -1).to(mm_dtype) + routed_in = routed_in.contiguous() + w_gate_t = w_gate_t.contiguous() gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets) torch.ops.aten.silu_(gate_out) + w_up_t = w_up_t.contiguous() up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets) gate_out.mul_(up_out) + gate_out = gate_out.contiguous() + w2_t = w2_t.contiguous() down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets).to(expert_dtype) weights = scores_sorted.unsqueeze(-1).to(expert_dtype)