fix
This commit is contained in:
@@ -71,7 +71,7 @@ def _call_grouped_mm(
|
|||||||
lengths = torch.tensor(
|
lengths = torch.tensor(
|
||||||
[a.shape[0] for a in As2], device=device, dtype=torch.int32
|
[a.shape[0] for a in As2], device=device, dtype=torch.int32
|
||||||
)
|
)
|
||||||
offsets = torch.cumsum(lengths, dim=0)
|
offsets = torch.cumsum(lengths, dim=0).to(torch.int32)
|
||||||
Y_cat = torch._grouped_mm(
|
Y_cat = torch._grouped_mm(
|
||||||
torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets
|
torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user