fix
This commit is contained in:
@@ -43,14 +43,11 @@ def _call_grouped_mm(
|
|||||||
|
|
||||||
if hasattr(torch.ops.aten, "_grouped_mm"):
|
if hasattr(torch.ops.aten, "_grouped_mm"):
|
||||||
try:
|
try:
|
||||||
return torch.ops.aten._grouped_mm(As, Bs) # type: ignore[attr-defined]
|
# Some builds expect tuples rather than lists
|
||||||
|
return torch.ops.aten._grouped_mm(tuple(As), tuple(Bs)) # type: ignore[attr-defined]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LAST_ERROR = f"_grouped_mm failed: {e}"
|
LAST_ERROR = f"_grouped_mm failed: {e}"
|
||||||
if hasattr(torch.ops.aten, "_scaled_grouped_mm"):
|
# Avoid _scaled_grouped_mm for now; its signature requires packed inputs.
|
||||||
try:
|
|
||||||
return torch.ops.aten._scaled_grouped_mm(As, Bs, 1.0, 0.0) # type: ignore[attr-defined]
|
|
||||||
except Exception as e:
|
|
||||||
LAST_ERROR = f"_scaled_grouped_mm failed: {e}"
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LAST_ERROR = f"call error: {e}"
|
LAST_ERROR = f"call error: {e}"
|
||||||
return None
|
return None
|
||||||
|
|||||||
Reference in New Issue
Block a user