cumulative offsets

This commit is contained in:
Dan Saunders
2025-09-17 18:45:15 -04:00
parent f3b953e222
commit eaaf16aa00

View File

@@ -64,15 +64,18 @@ def _call_grouped_mm(
As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
device = As2[0].device
offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
lengths = torch.tensor(
[a.shape[0] for a in As2], device=device, dtype=torch.int32
)
offsets = torch.cumsum(lengths, dim=0)
Y_cat = torch.ops.aten._grouped_mm(
torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offs
torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets
)
outs: List[torch.Tensor] = []
start = 0
for m in offs.tolist():
outs.append(Y_cat[start : start + m])
start += m
for size in lengths.tolist():
outs.append(Y_cat[start : start + size])
start += size
return outs
except RuntimeError as err:
_LOGGER.warning("torch_grouped: _grouped_mm failed (%s)", err)