This commit is contained in:
Dan Saunders
2025-09-15 19:34:08 -04:00
parent d7de6b0e96
commit 5c2229721d
2 changed files with 24 additions and 5 deletions

View File

@@ -210,7 +210,12 @@ def main():
f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
)
else:
print("torch_grouped\tN/A (op not callable)")
try:
from axolotl.kernels.moe.torch_grouped import LAST_ERROR as _TG_ERR
except Exception:
_TG_ERR = None
reason = f" (reason: {_TG_ERR})" if _TG_ERR else ""
print(f"torch_grouped\tN/A (op not callable){reason}")
else:
print("torch_grouped\tN/A (unavailable)")

View File

@@ -25,6 +25,9 @@ def available() -> bool:
return False
LAST_ERROR: Optional[str] = None
def _call_grouped_mm(
As: List[torch.Tensor], Bs: List[torch.Tensor]
) -> Optional[List[torch.Tensor]]:
@@ -32,13 +35,24 @@ def _call_grouped_mm(
Try calling the appropriate grouped mm op available in this torch build.
Returns list of outputs or None on failure.
"""
global LAST_ERROR
try:
# Ensure 2D contiguous inputs
As = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As]
Bs = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
if hasattr(torch.ops.aten, "_grouped_mm"):
return torch.ops.aten._grouped_mm(As, Bs) # type: ignore[attr-defined]
try:
return torch.ops.aten._grouped_mm(As, Bs) # type: ignore[attr-defined]
except Exception as e:
LAST_ERROR = f"_grouped_mm failed: {e}"
if hasattr(torch.ops.aten, "_scaled_grouped_mm"):
# signature likely (As, Bs, alpha, beta)
return torch.ops.aten._scaled_grouped_mm(As, Bs, 1.0, 0.0) # type: ignore[attr-defined]
except Exception:
try:
return torch.ops.aten._scaled_grouped_mm(As, Bs, 1.0, 0.0) # type: ignore[attr-defined]
except Exception as e:
LAST_ERROR = f"_scaled_grouped_mm failed: {e}"
except Exception as e:
LAST_ERROR = f"call error: {e}"
return None
return None