diag
This commit is contained in:
@@ -210,7 +210,12 @@ def main():
|
||||
f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
|
||||
)
|
||||
else:
|
||||
print("torch_grouped\tN/A (op not callable)")
|
||||
try:
|
||||
from axolotl.kernels.moe.torch_grouped import LAST_ERROR as _TG_ERR
|
||||
except Exception:
|
||||
_TG_ERR = None
|
||||
reason = f" (reason: {_TG_ERR})" if _TG_ERR else ""
|
||||
print(f"torch_grouped\tN/A (op not callable){reason}")
|
||||
else:
|
||||
print("torch_grouped\tN/A (unavailable)")
|
||||
|
||||
|
||||
@@ -25,6 +25,9 @@ def available() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
LAST_ERROR: Optional[str] = None
|
||||
|
||||
|
||||
def _call_grouped_mm(
|
||||
As: List[torch.Tensor], Bs: List[torch.Tensor]
|
||||
) -> Optional[List[torch.Tensor]]:
|
||||
@@ -32,13 +35,24 @@ def _call_grouped_mm(
|
||||
Try calling the appropriate grouped mm op available in this torch build.
|
||||
Returns list of outputs or None on failure.
|
||||
"""
|
||||
global LAST_ERROR
|
||||
try:
|
||||
# Ensure 2D contiguous inputs
|
||||
As = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
||||
Bs = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
||||
|
||||
if hasattr(torch.ops.aten, "_grouped_mm"):
|
||||
return torch.ops.aten._grouped_mm(As, Bs) # type: ignore[attr-defined]
|
||||
try:
|
||||
return torch.ops.aten._grouped_mm(As, Bs) # type: ignore[attr-defined]
|
||||
except Exception as e:
|
||||
LAST_ERROR = f"_grouped_mm failed: {e}"
|
||||
if hasattr(torch.ops.aten, "_scaled_grouped_mm"):
|
||||
# signature likely (As, Bs, alpha, beta)
|
||||
return torch.ops.aten._scaled_grouped_mm(As, Bs, 1.0, 0.0) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
try:
|
||||
return torch.ops.aten._scaled_grouped_mm(As, Bs, 1.0, 0.0) # type: ignore[attr-defined]
|
||||
except Exception as e:
|
||||
LAST_ERROR = f"_scaled_grouped_mm failed: {e}"
|
||||
except Exception as e:
|
||||
LAST_ERROR = f"call error: {e}"
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user