# Patch for scripts/bench_moe.py: adds estimate_moe_flops() and appends a TFLOP/s
# column to the naive and hf_triton result rows, switching the row separators to tabs.
# NOTE(review): line structure and indentation were reconstructed from a copy whose
# whitespace was collapsed; spacing inside the removed print strings may differ from
# the target file -- confirm against scripts/bench_moe.py before applying.
diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py
index 644d22320..635ef6a79 100644
--- a/scripts/bench_moe.py
+++ b/scripts/bench_moe.py
@@ -86,6 +86,16 @@ def bench(fn, *args, iters=50, warmup=10, sync=True):
     return sum(times) / len(times)
 
 
+def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
+    """Estimate GEMM FLOPs for a SwiGLU MoE MLP.
+
+    Two up projections (w1, w3) + one down (w2), each token processed by top_k experts.
+    FLOPs ≈ 6 * m * hidden * inter with m = tokens * top_k (2*m*k*n per GEMM).
+    """
+    m_rep = tokens * top_k  # rows seen by the expert GEMMs (one per token-expert pair)
+    return 6.0 * m_rep * hidden * inter  # 3 GEMMs, 2*m*k*n FLOPs each
+
+
 def main():
     p = argparse.ArgumentParser(description="MoE microbenchmark")
     p.add_argument("--bsz", type=int, default=8)
@@ -139,9 +149,13 @@ def main():
         iters=args.iters,
         warmup=args.warmup,
     )
-    print(f"naive {t_naive:.2f} ms {tokens / (t_naive / 1000):.1f} tok/s")
+    flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
+    tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)  # ms -> s, FLOP -> TFLOP
+    print(
+        f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
+    )
 
-    # HF Triton (routing + stub compute for now)
+    # HF Triton (stub compute for now)
     os.environ.setdefault("AXOLOTL_MOE_BACKEND", "hf_triton")
     t_hf = forward_hf_triton
     y = t_hf(x, gate, experts, args.top_k)
@@ -149,12 +163,15 @@ def main():
         t_ms = bench(
             t_hf, x, gate, experts, args.top_k, iters=args.iters, warmup=args.warmup
         )
-        print(f"hf_triton {t_ms:.2f} ms {tokens / (t_ms / 1000):.1f} tok/s")
+        tflops = flops_total / ((t_ms / 1000.0) * 1e12)  # ms -> s, FLOP -> TFLOP
+        print(
+            f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s"
+        )
     else:
-        print("hf_triton N/A (kernels hub not available)")
+        print("hf_triton\tN/A (kernels hub not available)")
 
     # torch_grouped placeholder — not yet implemented
-    print("torch_grouped N/A (pending implementation)")
+    print("torch_grouped\tN/A (pending implementation)")
 
 
 if __name__ == "__main__":