This commit is contained in:
Dan Saunders
2025-09-17 18:53:07 -04:00
parent d57b9c67c2
commit e62979d11d

View File

@@ -71,7 +71,7 @@ def _call_grouped_mm(
lengths = torch.tensor(
[a.shape[0] for a in As2], device=device, dtype=torch.int32
)
offsets = torch.cumsum(lengths, dim=0)
offsets = torch.cumsum(lengths, dim=0).to(torch.int32)
Y_cat = torch._grouped_mm(
torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets
)