From e62979d11dcbf40b0e093cd785070492089e5d03 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 17 Sep 2025 18:53:07 -0400 Subject: [PATCH] fix --- src/axolotl/kernels/moe/torch_grouped.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index b0f308497..50f1c4d8b 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -71,7 +71,7 @@ def _call_grouped_mm( lengths = torch.tensor( [a.shape[0] for a in As2], device=device, dtype=torch.int32 ) - offsets = torch.cumsum(lengths, dim=0) + offsets = torch.cumsum(lengths, dim=0).to(torch.int32) Y_cat = torch._grouped_mm( torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets )