This commit is contained in:
Dan Saunders
2025-09-15 19:36:00 -04:00
parent 5c2229721d
commit 556d6448fe

View File

@@ -43,14 +43,11 @@ def _call_grouped_mm(
if hasattr(torch.ops.aten, "_grouped_mm"):
try:
return torch.ops.aten._grouped_mm(As, Bs) # type: ignore[attr-defined]
# Some builds expect tuples rather than lists
return torch.ops.aten._grouped_mm(tuple(As), tuple(Bs)) # type: ignore[attr-defined]
except Exception as e:
LAST_ERROR = f"_grouped_mm failed: {e}"
if hasattr(torch.ops.aten, "_scaled_grouped_mm"):
try:
return torch.ops.aten._scaled_grouped_mm(As, Bs, 1.0, 0.0) # type: ignore[attr-defined]
except Exception as e:
LAST_ERROR = f"_scaled_grouped_mm failed: {e}"
# Avoid _scaled_grouped_mm for now; its signature requires packed inputs.
except Exception as e:
LAST_ERROR = f"call error: {e}"
return None