This commit is contained in:
Dan Saunders
2025-09-15 19:34:08 -04:00
parent d7de6b0e96
commit 5c2229721d
2 changed files with 24 additions and 5 deletions

View File

@@ -210,7 +210,12 @@ def main():
f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×" f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
) )
else: else:
print("torch_grouped\tN/A (op not callable)") try:
from axolotl.kernels.moe.torch_grouped import LAST_ERROR as _TG_ERR
except Exception:
_TG_ERR = None
reason = f" (reason: {_TG_ERR})" if _TG_ERR else ""
print(f"torch_grouped\tN/A (op not callable){reason}")
else: else:
print("torch_grouped\tN/A (unavailable)") print("torch_grouped\tN/A (unavailable)")

View File

@@ -25,6 +25,9 @@ def available() -> bool:
return False return False
LAST_ERROR: Optional[str] = None
def _call_grouped_mm( def _call_grouped_mm(
As: List[torch.Tensor], Bs: List[torch.Tensor] As: List[torch.Tensor], Bs: List[torch.Tensor]
) -> Optional[List[torch.Tensor]]: ) -> Optional[List[torch.Tensor]]:
@@ -32,13 +35,24 @@ def _call_grouped_mm(
Try calling the appropriate grouped mm op available in this torch build. Try calling the appropriate grouped mm op available in this torch build.
Returns list of outputs or None on failure. Returns list of outputs or None on failure.
""" """
global LAST_ERROR
try: try:
# Ensure 2D contiguous inputs
As = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As]
Bs = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
if hasattr(torch.ops.aten, "_grouped_mm"): if hasattr(torch.ops.aten, "_grouped_mm"):
return torch.ops.aten._grouped_mm(As, Bs) # type: ignore[attr-defined] try:
return torch.ops.aten._grouped_mm(As, Bs) # type: ignore[attr-defined]
except Exception as e:
LAST_ERROR = f"_grouped_mm failed: {e}"
if hasattr(torch.ops.aten, "_scaled_grouped_mm"): if hasattr(torch.ops.aten, "_scaled_grouped_mm"):
# signature likely (As, Bs, alpha, beta) try:
return torch.ops.aten._scaled_grouped_mm(As, Bs, 1.0, 0.0) # type: ignore[attr-defined] return torch.ops.aten._scaled_grouped_mm(As, Bs, 1.0, 0.0) # type: ignore[attr-defined]
except Exception: except Exception as e:
LAST_ERROR = f"_scaled_grouped_mm failed: {e}"
except Exception as e:
LAST_ERROR = f"call error: {e}"
return None return None
return None return None