fix
This commit is contained in:
@@ -71,7 +71,7 @@ def _call_grouped_mm(
|
||||
lengths = torch.tensor(
|
||||
[a.shape[0] for a in As2], device=device, dtype=torch.int32
|
||||
)
|
||||
offsets = torch.cumsum(lengths, dim=0)
|
||||
offsets = torch.cumsum(lengths, dim=0).to(torch.int32)
|
||||
Y_cat = torch._grouped_mm(
|
||||
torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user