diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py index 777a88024..db4044748 100644 --- a/scripts/bench_moe.py +++ b/scripts/bench_moe.py @@ -210,7 +210,12 @@ def main(): f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×" ) else: - print("torch_grouped\tN/A (op not callable)") + try: + from axolotl.kernels.moe.torch_grouped import LAST_ERROR as _TG_ERR + except Exception: + _TG_ERR = None + reason = f" (reason: {_TG_ERR})" if _TG_ERR else "" + print(f"torch_grouped\tN/A (op not callable){reason}") else: print("torch_grouped\tN/A (unavailable)") diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index d8689fba0..1ed2ce20b 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -25,6 +25,9 @@ def available() -> bool: return False +LAST_ERROR: Optional[str] = None + + def _call_grouped_mm( As: List[torch.Tensor], Bs: List[torch.Tensor] ) -> Optional[List[torch.Tensor]]: @@ -32,13 +35,24 @@ def _call_grouped_mm( Try calling the appropriate grouped mm op available in this torch build. Returns list of outputs or None on failure. """ + global LAST_ERROR try: + # Ensure 2D contiguous inputs + As = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As] + Bs = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs] + if hasattr(torch.ops.aten, "_grouped_mm"): - return torch.ops.aten._grouped_mm(As, Bs) # type: ignore[attr-defined] + try: + return torch.ops.aten._grouped_mm(As, Bs) # type: ignore[attr-defined] + except Exception as e: + LAST_ERROR = f"_grouped_mm failed: {e}" if hasattr(torch.ops.aten, "_scaled_grouped_mm"): - # signature likely (As, Bs, alpha, beta) - return torch.ops.aten._scaled_grouped_mm(As, Bs, 1.0, 0.0) # type: ignore[attr-defined] - except Exception: + try: + return torch.ops.aten._scaled_grouped_mm(As, Bs, 1.0, 0.0) # type: ignore[attr-defined] + except Exception as e: + LAST_ERROR = f"_scaled_grouped_mm failed: {e}" + except Exception as e: + LAST_ERROR = f"call error: {e}" return None return None