This commit is contained in:
Dan Saunders
2025-09-17 18:53:07 -04:00
parent d57b9c67c2
commit e62979d11d

View File

@@ -71,7 +71,7 @@ def _call_grouped_mm(
     lengths = torch.tensor(
         [a.shape[0] for a in As2], device=device, dtype=torch.int32
     )
-    offsets = torch.cumsum(lengths, dim=0)
+    offsets = torch.cumsum(lengths, dim=0).to(torch.int32)
     Y_cat = torch._grouped_mm(
         torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets
     )