From 3c6648678f30b0fb25b0dfdd63ce6af3f501e7df Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 15 Sep 2025 19:08:30 -0400 Subject: [PATCH] numerics --- scripts/bench_moe.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py index 5c5f07d78..0f2ae7d12 100644 --- a/scripts/bench_moe.py +++ b/scripts/bench_moe.py @@ -109,6 +109,9 @@ def main(): ) p.add_argument("--iters", type=int, default=50) p.add_argument("--warmup", type=int, default=10) + p.add_argument( + "--check", action="store_true", help="Check numerical equivalence (outputs)" + ) args = p.parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" @@ -166,6 +169,22 @@ def main(): print( f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s" ) + if args.check: + with torch.no_grad(): + y_ref = forward_naive(x, gate, experts, args.top_k) + y_fast = y + # align dtypes for error metrics + y_ref32 = y_ref.float() + y_fast32 = y_fast.float() + diff = (y_ref32 - y_fast32).abs() + max_abs = diff.max().item() + mean_abs = diff.mean().item() + rel_l2 = ( + (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item() + ) + print( + f"check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}" + ) else: print("hf_triton\tN/A (kernels hub not available)")