#!/usr/bin/env python
"""Probe PyTorch for grouped GEMM operator names and namespaces.

Usage:
    python scripts/probe_torch_grouped_ops.py
"""

import sys
|
|
|
|
|
|
def main():
    """Report which ``torch.ops`` operators look like grouped-GEMM candidates.

    Prints the installed torch version, the public (non-underscore) attribute
    names found on ``torch.ops``, and, for each of them, every attribute whose
    name contains ``"group"`` (case-insensitive). If nothing matches, prints a
    hint recommending PyTorch >= 2.8.

    Raises:
        SystemExit: with status 1 when ``torch`` cannot be imported.
    """
    # Import inside the function so a missing or broken install produces a
    # short message and a non-zero exit status instead of a traceback.
    try:
        import torch
    except Exception as e:  # deliberately broad: top-level script boundary
        print("Failed to import torch:", e)
        sys.exit(1)

    print("torch version:", torch.__version__)
    namespaces = [n for n in dir(torch.ops) if not n.startswith("_")]
    print("ops namespaces:", namespaces)

    found_any = False
    for ns in namespaces:
        obj = getattr(torch.ops, ns, None)
        if obj is None:
            continue

        # NOTE(review): dir() only reports attributes currently present on the
        # namespace object; if operators are materialized lazily on first
        # access, this listing may be incomplete -- confirm against torch.
        try:
            ops = dir(obj)
        except Exception as e:
            print(f"warning: failed to list ops for namespace {ns}: {e}")
            continue

        # A single substring test suffices: "grouped", "mm_grouped" and
        # "matmul_grouped" all contain "group", so checking them separately
        # could never add a match.
        cands = [o for o in ops if "group" in o.lower()]
        if cands:
            found_any = True
            print(f"namespace {ns} candidates:", cands)

    if not found_any:
        print("No grouped GEMM candidates found. PyTorch >= 2.8 is recommended.")
|
|
|
|
|
|
# Script entry point: run the probe only when this file is executed directly,
# so importing the module has no side effects.
if __name__ == "__main__":
    main()
|