commit f6ed8ddc01 (parent 556d6448fe)
Author: Dan Saunders
Date: 2025-09-15 19:39:19 -04:00

@@ -32,26 +32,43 @@ def _call_grouped_mm(
     As: List[torch.Tensor], Bs: List[torch.Tensor]
 ) -> Optional[List[torch.Tensor]]:
     """
-    Try calling the appropriate grouped mm op available in this torch build.
-    Returns list of outputs or None on failure.
+    Call grouped mm using aten._grouped_mm with a packed representation:
+    - A_cat: As concatenated along rows -> [sum_i Mi, K]
+    - B_stk: Bs stacked per group -> [G, K, N]
+    - offs: cumulative row offsets (group end indices) -> [G] int32
+    Returns the list of per-group outputs split from the concatenated
+    result, or None on failure (with LAST_ERROR set).
     """
     global LAST_ERROR
     try:
         # Ensure 2D contiguous inputs
-        As = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As]
-        Bs = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
+        As2 = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As]
+        Bs2 = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
+        if not As2:
+            return []
+        device = As2[0].device
+        A_cat = torch.cat(As2, dim=0)
+        B_stk = torch.stack(Bs2, dim=0)
+        # aten._grouped_mm expects cumulative group-end offsets, not per-group lengths
+        lengths = [a.shape[0] for a in As2]
+        offs = torch.tensor(lengths, device=device, dtype=torch.int32).cumsum(0, dtype=torch.int32)
         if hasattr(torch.ops.aten, "_grouped_mm"):
             try:
-                # Some builds expect tuples rather than lists
-                return torch.ops.aten._grouped_mm(tuple(As), tuple(Bs))  # type: ignore[attr-defined]
+                Y_cat = torch.ops.aten._grouped_mm(A_cat, B_stk, offs)  # type: ignore[attr-defined]
+                # Split the packed [sum_i Mi, N] result back into per-group outputs
+                outs: List[torch.Tensor] = list(torch.split(Y_cat, lengths, dim=0))
+                return outs
             except Exception as e:
                 LAST_ERROR = f"_grouped_mm failed: {e}"
+                # Avoid _scaled_grouped_mm for now; its signature requires packed inputs.
                 return None
+        LAST_ERROR = "aten._grouped_mm not present"
+        return None
     except Exception as e:
         LAST_ERROR = f"call error: {e}"
         return None
def moe_ffn_forward_grouped(
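
For reference, a minimal, hedged sketch of the packed calling convention this commit moves to. It assumes a torch build that exposes torch.ops.aten._grouped_mm (the op is generally CUDA-only and restricted to bf16/fp16 inputs, so it is guarded below), and it checks the packed call against plain per-group matmuls. The group sizes, shapes, and tolerances are illustrative, not taken from this repo.

import torch

def reference_grouped_mm(As, Bs):
    # Ground truth: independent per-group matmuls.
    return [a @ b for a, b in zip(As, Bs)]

if torch.cuda.is_available() and hasattr(torch.ops.aten, "_grouped_mm"):
    K, N = 16, 8
    Ms = [4, 1, 7]  # hypothetical ragged row counts per group
    As = [torch.randn(m, K, device="cuda", dtype=torch.bfloat16) for m in Ms]
    Bs = [torch.randn(K, N, device="cuda", dtype=torch.bfloat16) for _ in Ms]

    A_cat = torch.cat(As, dim=0)    # [sum_i Mi, K]
    B_stk = torch.stack(Bs, dim=0)  # [G, K, N]
    # Cumulative group-end offsets, int32, as _grouped_mm expects (assumption
    # based on the op's use elsewhere; verify against your torch build).
    offs = torch.tensor(Ms, device="cuda", dtype=torch.int32).cumsum(0, dtype=torch.int32)

    Y_cat = torch.ops.aten._grouped_mm(A_cat, B_stk, offs)  # [sum_i Mi, N]
    for y, r in zip(torch.split(Y_cat, Ms, dim=0), reference_grouped_mm(As, Bs)):
        torch.testing.assert_close(y, r, rtol=3e-2, atol=3e-2)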