From eaaf16aa0037a742438d868f79c0a3815147c357 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 17 Sep 2025 18:45:15 -0400 Subject: [PATCH] cumulative offsets --- src/axolotl/kernels/moe/torch_grouped.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index c442a0e9e..44447d83e 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -64,15 +64,18 @@ def _call_grouped_mm( As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As] Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs] device = As2[0].device - offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32) + lengths = torch.tensor( + [a.shape[0] for a in As2], device=device, dtype=torch.int32 + ) + offsets = torch.cumsum(lengths, dim=0) Y_cat = torch.ops.aten._grouped_mm( - torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offs + torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets ) outs: List[torch.Tensor] = [] start = 0 - for m in offs.tolist(): - outs.append(Y_cat[start : start + m]) - start += m + for size in lengths.tolist(): + outs.append(Y_cat[start : start + size]) + start += size return outs except RuntimeError as err: _LOGGER.warning("torch_grouped: _grouped_mm failed (%s)", err)