tflops
This commit is contained in:
@@ -86,6 +86,16 @@ def bench(fn, *args, iters=50, warmup=10, sync=True):
|
||||
return sum(times) / len(times)
|
||||
|
||||
|
||||
def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
    """Estimate GEMM FLOPs for a SwiGLU MoE MLP.

    The MLP performs two up projections (w1, w3) and one down projection (w2),
    and each token is processed by ``top_k`` experts. A GEMM is costed at
    ``2 * m * k * n`` FLOPs, so the three GEMMs together give::

        FLOPs ≈ 6 * (tokens * top_k) * hidden * inter

    Only these three GEMMs are counted; nothing else is added to the estimate.

    Args:
        tokens: Number of input tokens.
        hidden: Hidden (model) dimension.
        inter: Intermediate (expert MLP) dimension.
        top_k: Number of experts each token is routed to.

    Returns:
        The estimated FLOP count as a float.
    """
    # Every token is replicated once per selected expert, so the effective
    # GEMM row count is tokens * top_k.
    routed_rows = tokens * top_k
    return 6.0 * routed_rows * hidden * inter
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser(description="MoE microbenchmark")
|
||||
p.add_argument("--bsz", type=int, default=8)
|
||||
@@ -139,9 +149,13 @@ def main():
|
||||
iters=args.iters,
|
||||
warmup=args.warmup,
|
||||
)
|
||||
print(f"naive {t_naive:.2f} ms {tokens / (t_naive / 1000):.1f} tok/s")
|
||||
flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
|
||||
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
|
||||
print(
|
||||
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
|
||||
)
|
||||
|
||||
# HF Triton (routing + stub compute for now)
|
||||
# HF Triton (stub compute for now)
|
||||
os.environ.setdefault("AXOLOTL_MOE_BACKEND", "hf_triton")
|
||||
t_hf = forward_hf_triton
|
||||
y = t_hf(x, gate, experts, args.top_k)
|
||||
@@ -149,12 +163,15 @@ def main():
|
||||
t_ms = bench(
|
||||
t_hf, x, gate, experts, args.top_k, iters=args.iters, warmup=args.warmup
|
||||
)
|
||||
print(f"hf_triton {t_ms:.2f} ms {tokens / (t_ms / 1000):.1f} tok/s")
|
||||
tflops = flops_total / ((t_ms / 1000.0) * 1e12)
|
||||
print(
|
||||
f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s"
|
||||
)
|
||||
else:
|
||||
print("hf_triton N/A (kernels hub not available)")
|
||||
print("hf_triton\tN/A (kernels hub not available)")
|
||||
|
||||
# torch_grouped placeholder — not yet implemented
|
||||
print("torch_grouped N/A (pending implementation)")
|
||||
print("torch_grouped\tN/A (pending implementation)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user