diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py index db4044748..14653aba1 100644 --- a/scripts/bench_moe.py +++ b/scripts/bench_moe.py @@ -155,7 +155,11 @@ def main(): f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s" ) - # HF Triton (stub compute for now) + # Prepare reference output once for checks + with torch.no_grad(): + y_ref = forward_naive(x, gate, experts, args.top_k) + + # HF Triton t_hf = forward_hf_triton y = t_hf(x, gate, experts, args.top_k) if y is not None: @@ -167,8 +171,8 @@ def main(): print( f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×" ) + # parity for hf_triton vs naive with torch.no_grad(): - y_ref = forward_naive(x, gate, experts, args.top_k) y_fast = y y_ref32 = y_ref.float() y_fast32 = y_fast.float() @@ -177,7 +181,7 @@ def main(): mean_abs = diff.mean().item() rel_l2 = (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item() print( - f"check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}" + f"hf_triton_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}" ) else: print("hf_triton\tN/A (kernels hub not available)") @@ -209,6 +213,19 @@ def main(): print( f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×" ) + with torch.no_grad(): + y_fast = y_tg + y_ref32 = y_ref.float() + y_fast32 = y_fast.float() + diff = (y_ref32 - y_fast32).abs() + max_abs = diff.max().item() + mean_abs = diff.mean().item() + rel_l2 = ( + (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item() + ) + print( + f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}" + ) else: try: from axolotl.kernels.moe.torch_grouped import LAST_ERROR as _TG_ERR