grouped_mm

This commit is contained in:
Dan Saunders
2025-09-15 19:31:21 -04:00
parent 3c6648678f
commit d7de6b0e96
3 changed files with 222 additions and 27 deletions

View File

@@ -109,9 +109,6 @@ def main():
)
p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10)
p.add_argument(
"--check", action="store_true", help="Check numerical equivalence (outputs)"
)
args = p.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -166,30 +163,56 @@ def main():
t_hf, x, gate, experts, args.top_k, iters=args.iters, warmup=args.warmup
)
tflops = flops_total / ((t_ms / 1000.0) * 1e12)
speedup = t_naive / t_ms
print(
f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s"
f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
)
if args.check:
with torch.no_grad():
y_ref = forward_naive(x, gate, experts, args.top_k)
y_fast = y
# align dtypes for error metrics
y_ref32 = y_ref.float()
y_fast32 = y_fast.float()
diff = (y_ref32 - y_fast32).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (
(diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
)
print(
f"check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
)
with torch.no_grad():
y_ref = forward_naive(x, gate, experts, args.top_k)
y_fast = y
y_ref32 = y_ref.float()
y_fast32 = y_fast.float()
diff = (y_ref32 - y_fast32).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
print(
f"check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
)
else:
print("hf_triton\tN/A (kernels hub not available)")
# torch_grouped placeholder — not yet implemented
print("torch_grouped\tN/A (pending implementation)")
# torch_grouped backend (PyTorch 2.8+)
try:
from axolotl.kernels.moe import torch_grouped as tg
except Exception:
tg = None
if tg is not None and tg.available():
def forward_tg(a, g, ex, topk):
y, _ = tg.moe_ffn_forward_grouped(a, g, ex, topk)
return y
y_tg = forward_tg(x, gate, experts, args.top_k)
if y_tg is not None:
t_ms = bench(
forward_tg,
x,
gate,
experts,
args.top_k,
iters=args.iters,
warmup=args.warmup,
)
tflops = flops_total / ((t_ms / 1000.0) * 1e12)
speedup = t_naive / t_ms
print(
f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
)
else:
print("torch_grouped\tN/A (op not callable)")
else:
print("torch_grouped\tN/A (unavailable)")
if __name__ == "__main__":

View File

@@ -0,0 +1,47 @@
#!/usr/bin/env python
"""
Probe PyTorch for grouped GEMM operator names and namespaces.
Run: python scripts/probe_torch_grouped_ops.py
"""
import sys
def main():
try:
import torch
except Exception as e:
print("Failed to import torch:", e)
sys.exit(1)
print("torch version:", torch.__version__)
namespaces = [n for n in dir(torch.ops) if not n.startswith("_")]
print("ops namespaces:", namespaces)
found_any = False
for ns in namespaces:
obj = getattr(torch.ops, ns, None)
ops = []
if obj is not None:
try:
ops = dir(obj)
except Exception as e:
print(f"warning: failed to list ops for namespace {ns}: {e}")
cands = [
o
for o in ops
if ("group" in o.lower())
or ("mm_grouped" in o.lower())
or ("matmul_grouped" in o.lower())
or ("grouped" in o.lower())
]
if cands:
found_any = True
print(f"namespace {ns} candidates:", cands)
if not found_any:
print("No grouped GEMM candidates found. PyTorch >= 2.8 is recommended.")
if __name__ == "__main__":
main()