fix
This commit is contained in:
@@ -43,14 +43,11 @@ def _call_grouped_mm(
|
||||
|
||||
if hasattr(torch.ops.aten, "_grouped_mm"):
|
||||
try:
|
||||
return torch.ops.aten._grouped_mm(As, Bs) # type: ignore[attr-defined]
|
||||
# Some builds expect tuples rather than lists
|
||||
return torch.ops.aten._grouped_mm(tuple(As), tuple(Bs)) # type: ignore[attr-defined]
|
||||
except Exception as e:
|
||||
LAST_ERROR = f"_grouped_mm failed: {e}"
|
||||
if hasattr(torch.ops.aten, "_scaled_grouped_mm"):
|
||||
try:
|
||||
return torch.ops.aten._scaled_grouped_mm(As, Bs, 1.0, 0.0) # type: ignore[attr-defined]
|
||||
except Exception as e:
|
||||
LAST_ERROR = f"_scaled_grouped_mm failed: {e}"
|
||||
# Avoid _scaled_grouped_mm for now; its signature requires packed inputs.
|
||||
except Exception as e:
|
||||
LAST_ERROR = f"call error: {e}"
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user