cumulative offsets
This commit is contained in:
@@ -64,15 +64,18 @@ def _call_grouped_mm(
|
|||||||
As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
||||||
Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
||||||
device = As2[0].device
|
device = As2[0].device
|
||||||
offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
|
lengths = torch.tensor(
|
||||||
|
[a.shape[0] for a in As2], device=device, dtype=torch.int32
|
||||||
|
)
|
||||||
|
offsets = torch.cumsum(lengths, dim=0)
|
||||||
Y_cat = torch.ops.aten._grouped_mm(
|
Y_cat = torch.ops.aten._grouped_mm(
|
||||||
torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offs
|
torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets
|
||||||
)
|
)
|
||||||
outs: List[torch.Tensor] = []
|
outs: List[torch.Tensor] = []
|
||||||
start = 0
|
start = 0
|
||||||
for m in offs.tolist():
|
for size in lengths.tolist():
|
||||||
outs.append(Y_cat[start : start + m])
|
outs.append(Y_cat[start : start + size])
|
||||||
start += m
|
start += size
|
||||||
return outs
|
return outs
|
||||||
except RuntimeError as err:
|
except RuntimeError as err:
|
||||||
_LOGGER.warning("torch_grouped: _grouped_mm failed (%s)", err)
|
_LOGGER.warning("torch_grouped: _grouped_mm failed (%s)", err)
|
||||||
|
|||||||
Reference in New Issue
Block a user