fix
This commit is contained in:
@@ -32,26 +32,43 @@ def _call_grouped_mm(
|
||||
As: List[torch.Tensor], Bs: List[torch.Tensor]
|
||||
) -> Optional[List[torch.Tensor]]:
|
||||
"""
|
||||
Try calling the appropriate grouped mm op available in this torch build.
|
||||
Returns list of outputs or None on failure.
|
||||
Call grouped mm using aten._grouped_mm with packed representation.
|
||||
|
||||
- A_cat: concat As along rows -> [sum_i Mi, K]
|
||||
- B_stk: stack Bs per group -> [G, K, N]
|
||||
- offs: lengths per group Mi -> [G] int32
|
||||
Returns list of per-group outputs split from concatenated result.
|
||||
"""
|
||||
global LAST_ERROR
|
||||
try:
|
||||
# Ensure 2D contiguous inputs
|
||||
As = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
||||
Bs = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
||||
As2 = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
||||
Bs2 = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
||||
|
||||
if not As2:
|
||||
return []
|
||||
device = As2[0].device
|
||||
A_cat = torch.cat(As2, dim=0)
|
||||
B_stk = torch.stack(Bs2, dim=0)
|
||||
offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
|
||||
|
||||
if hasattr(torch.ops.aten, "_grouped_mm"):
|
||||
try:
|
||||
# Some builds expect tuples rather than lists
|
||||
return torch.ops.aten._grouped_mm(tuple(As), tuple(Bs)) # type: ignore[attr-defined]
|
||||
Y_cat = torch.ops.aten._grouped_mm(A_cat, B_stk, offs) # type: ignore[attr-defined]
|
||||
outs: List[torch.Tensor] = []
|
||||
start = 0
|
||||
for m in offs.tolist():
|
||||
outs.append(Y_cat[start : start + m, :])
|
||||
start += m
|
||||
return outs
|
||||
except Exception as e:
|
||||
LAST_ERROR = f"_grouped_mm failed: {e}"
|
||||
# Avoid _scaled_grouped_mm for now; its signature requires packed inputs.
|
||||
return None
|
||||
LAST_ERROR = "aten._grouped_mm not present"
|
||||
return None
|
||||
except Exception as e:
|
||||
LAST_ERROR = f"call error: {e}"
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def moe_ffn_forward_grouped(
|
||||
|
||||
Reference in New Issue
Block a user