grouped_mm
This commit is contained in:
@@ -109,9 +109,6 @@ def main():
|
||||
)
|
||||
p.add_argument("--iters", type=int, default=50)
|
||||
p.add_argument("--warmup", type=int, default=10)
|
||||
p.add_argument(
|
||||
"--check", action="store_true", help="Check numerical equivalence (outputs)"
|
||||
)
|
||||
args = p.parse_args()
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -166,30 +163,56 @@ def main():
|
||||
t_hf, x, gate, experts, args.top_k, iters=args.iters, warmup=args.warmup
|
||||
)
|
||||
tflops = flops_total / ((t_ms / 1000.0) * 1e12)
|
||||
speedup = t_naive / t_ms
|
||||
print(
|
||||
f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s"
|
||||
f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
|
||||
)
|
||||
if args.check:
|
||||
with torch.no_grad():
|
||||
y_ref = forward_naive(x, gate, experts, args.top_k)
|
||||
y_fast = y
|
||||
# align dtypes for error metrics
|
||||
y_ref32 = y_ref.float()
|
||||
y_fast32 = y_fast.float()
|
||||
diff = (y_ref32 - y_fast32).abs()
|
||||
max_abs = diff.max().item()
|
||||
mean_abs = diff.mean().item()
|
||||
rel_l2 = (
|
||||
(diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
|
||||
)
|
||||
print(
|
||||
f"check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
|
||||
)
|
||||
with torch.no_grad():
|
||||
y_ref = forward_naive(x, gate, experts, args.top_k)
|
||||
y_fast = y
|
||||
y_ref32 = y_ref.float()
|
||||
y_fast32 = y_fast.float()
|
||||
diff = (y_ref32 - y_fast32).abs()
|
||||
max_abs = diff.max().item()
|
||||
mean_abs = diff.mean().item()
|
||||
rel_l2 = (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
|
||||
print(
|
||||
f"check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
|
||||
)
|
||||
else:
|
||||
print("hf_triton\tN/A (kernels hub not available)")
|
||||
|
||||
# torch_grouped placeholder — not yet implemented
|
||||
print("torch_grouped\tN/A (pending implementation)")
|
||||
# torch_grouped backend (PyTorch 2.8+)
|
||||
try:
|
||||
from axolotl.kernels.moe import torch_grouped as tg
|
||||
except Exception:
|
||||
tg = None
|
||||
if tg is not None and tg.available():
|
||||
|
||||
def forward_tg(a, g, ex, topk):
|
||||
y, _ = tg.moe_ffn_forward_grouped(a, g, ex, topk)
|
||||
return y
|
||||
|
||||
y_tg = forward_tg(x, gate, experts, args.top_k)
|
||||
if y_tg is not None:
|
||||
t_ms = bench(
|
||||
forward_tg,
|
||||
x,
|
||||
gate,
|
||||
experts,
|
||||
args.top_k,
|
||||
iters=args.iters,
|
||||
warmup=args.warmup,
|
||||
)
|
||||
tflops = flops_total / ((t_ms / 1000.0) * 1e12)
|
||||
speedup = t_naive / t_ms
|
||||
print(
|
||||
f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
|
||||
)
|
||||
else:
|
||||
print("torch_grouped\tN/A (op not callable)")
|
||||
else:
|
||||
print("torch_grouped\tN/A (unavailable)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
47
scripts/probe_torch_grouped_ops.py
Normal file
47
scripts/probe_torch_grouped_ops.py
Normal file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Probe PyTorch for grouped GEMM operator names and namespaces.
|
||||
Run: python scripts/probe_torch_grouped_ops.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
import torch
|
||||
except Exception as e:
|
||||
print("Failed to import torch:", e)
|
||||
sys.exit(1)
|
||||
|
||||
print("torch version:", torch.__version__)
|
||||
namespaces = [n for n in dir(torch.ops) if not n.startswith("_")]
|
||||
print("ops namespaces:", namespaces)
|
||||
|
||||
found_any = False
|
||||
for ns in namespaces:
|
||||
obj = getattr(torch.ops, ns, None)
|
||||
ops = []
|
||||
if obj is not None:
|
||||
try:
|
||||
ops = dir(obj)
|
||||
except Exception as e:
|
||||
print(f"warning: failed to list ops for namespace {ns}: {e}")
|
||||
cands = [
|
||||
o
|
||||
for o in ops
|
||||
if ("group" in o.lower())
|
||||
or ("mm_grouped" in o.lower())
|
||||
or ("matmul_grouped" in o.lower())
|
||||
or ("grouped" in o.lower())
|
||||
]
|
||||
if cands:
|
||||
found_any = True
|
||||
print(f"namespace {ns} candidates:", cands)
|
||||
|
||||
if not found_any:
|
||||
print("No grouped GEMM candidates found. PyTorch >= 2.8 is recommended.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user