This commit is contained in:
Dan Saunders
2025-09-15 18:52:55 -04:00
parent 68da65cba2
commit 479b6144df

View File

@@ -86,6 +86,16 @@ def bench(fn, *args, iters=50, warmup=10, sync=True):
return sum(times) / len(times)
def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
    """Estimate the GEMM FLOPs of one forward pass through a SwiGLU MoE MLP.

    The MLP is modelled as three matrix multiplies per routed token: two up
    projections (w1, w3) and one down projection (w2). Every token is
    processed by ``top_k`` experts, so the effective row count of each GEMM
    is ``tokens * top_k``. A GEMM of shape (m, k) x (k, n) is counted as
    ``2 * m * k * n`` FLOPs, which gives

        FLOPs ~= 3 * 2 * (tokens * top_k) * hidden * inter

    Only the expert GEMMs are counted; nothing else is included in the
    estimate.

    Args:
        tokens: Number of input tokens.
        hidden: Hidden dimension used in the GEMM shapes.
        inter: Intermediate dimension used in the GEMM shapes.
        top_k: Number of experts each token is processed by.

    Returns:
        The estimated FLOP count as a float.
    """
    # Each token is replicated once per selected expert.
    routed_rows = tokens * top_k
    # Three GEMMs at 2*m*k*n FLOPs each -> factor of 6.
    return 6.0 * routed_rows * hidden * inter
def main():
p = argparse.ArgumentParser(description="MoE microbenchmark")
p.add_argument("--bsz", type=int, default=8)
@@ -139,9 +149,13 @@ def main():
iters=args.iters,
warmup=args.warmup,
)
print(f"naive {t_naive:.2f} ms {tokens / (t_naive / 1000):.1f} tok/s")
flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
print(
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
)
# HF Triton (routing + stub compute for now)
# HF Triton (stub compute for now)
os.environ.setdefault("AXOLOTL_MOE_BACKEND", "hf_triton")
t_hf = forward_hf_triton
y = t_hf(x, gate, experts, args.top_k)
@@ -149,12 +163,15 @@ def main():
t_ms = bench(
t_hf, x, gate, experts, args.top_k, iters=args.iters, warmup=args.warmup
)
print(f"hf_triton {t_ms:.2f} ms {tokens / (t_ms / 1000):.1f} tok/s")
tflops = flops_total / ((t_ms / 1000.0) * 1e12)
print(
f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s"
)
else:
print("hf_triton N/A (kernels hub not available)")
print("hf_triton\tN/A (kernels hub not available)")
# torch_grouped placeholder — not yet implemented
print("torch_grouped N/A (pending implementation)")
print("torch_grouped\tN/A (pending implementation)")
if __name__ == "__main__":