From 7327144344ff6f7926d3f0dd45a0fc02414940df Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Fri, 19 Sep 2025 13:41:12 -0400
Subject: [PATCH] compile

---
 scripts/bench_moe.py       | 42 +++++++++++++++++++++++++++++++-------
 scripts/bench_moe_sweep.py | 40 ++++++++++++++++++++++++++++++++++++--
 2 files changed, 73 insertions(+), 9 deletions(-)

diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py
index e91bfebb3..5bce3a33a 100644
--- a/scripts/bench_moe.py
+++ b/scripts/bench_moe.py
@@ -6,6 +6,7 @@ from __future__ import annotations
 import argparse
 import sys
 import time
+import weakref
 from pathlib import Path
 
 import torch
@@ -83,6 +84,11 @@ def main() -> None:
     p.add_argument("--iters", type=int, default=50)
     p.add_argument("--warmup", type=int, default=10)
     p.add_argument("--profile", action="store_true")
+    p.add_argument(
+        "--compile",
+        action="store_true",
+        help="Torch.compile both paths before benchmarking",
+    )
     args = p.parse_args()
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -114,17 +120,39 @@ def main() -> None:
 
     x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
 
-    def run_naive():
-        y, _ = block_naive(x)
+    # Optional torch.compile
+    run_grouped_impl = None
+    if args.compile:
+        try:
+            block_naive = torch.compile(block_naive)  # type: ignore[arg-type]
+        except Exception as exc:  # pragma: no cover
+            print(f"torch.compile naive failed ({exc}); using eager")
+        else:
+
+            def grouped_forward(inp, *, block=block_grouped):
+                block.experts._ax_parent_block_ref = weakref.ref(block)  # type: ignore[attr-defined]
+                y, _ = tg.moe_ffn_forward_grouped(
+                    inp, block.gate, block.experts, block.top_k
+                )
+                return y
+
+            try:
+                run_grouped_impl = torch.compile(grouped_forward)  # type: ignore[arg-type]
+            except Exception as exc:  # pragma: no cover
+                print(f"torch.compile grouped failed ({exc}); using eager")
+                run_grouped_impl = None
+
+    def run_naive(block=block_naive, data=x):
+        y, _ = block(data)
         return y
 
-    def run_grouped():
+    def run_grouped(block=block_grouped, data=x, impl=run_grouped_impl):
+        if impl is not None:
+            return impl(data)
         if tg is None or not tg.available():
             return torch.empty(0)
-        block_grouped.experts._ax_parent_block = block_grouped
-        y, _ = tg.moe_ffn_forward_grouped(
-            x, block_grouped.gate, block_grouped.experts, block_grouped.top_k
-        )
+        block.experts._ax_parent_block_ref = weakref.ref(block)  # type: ignore[attr-defined]
+        y, _ = tg.moe_ffn_forward_grouped(data, block.gate, block.experts, block.top_k)
         return y if y is not None else torch.empty(0)
 
     t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
diff --git a/scripts/bench_moe_sweep.py b/scripts/bench_moe_sweep.py
index d90b37fb4..d36495e33 100644
--- a/scripts/bench_moe_sweep.py
+++ b/scripts/bench_moe_sweep.py
@@ -7,6 +7,7 @@ import argparse
 import csv
 import sys
 import time
+import weakref
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List
@@ -108,6 +109,7 @@ def main() -> None:
     p.add_argument("--iters", type=int, default=25)
     p.add_argument("--warmup", type=int, default=5)
     p.add_argument("--csv", type=Path, default=None)
+    p.add_argument("--compile", action="store_true")
     args = p.parse_args()
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -159,12 +161,46 @@ def main() -> None:
                 bsz, seq, hidden, device=device, dtype=dtype
             )
 
+            compiled_impl = None
+            if args.compile:
+                try:
+                    block_naive = torch.compile(block_naive)  # type: ignore[arg-type]
+                except Exception as exc:
+                    print(
+                        f"torch.compile naive failed ({exc}); using eager"
+                    )
+                else:
+
+                    def grouped_forward(inp, *, block=block_grouped):
+                        block.experts._ax_parent_block_ref = (
+                            weakref.ref(block)
+                        )  # type: ignore[attr-defined]
+                        y, _ = tg.moe_ffn_forward_grouped(
+                            inp,
+                            block.gate,
+                            block.experts,
+                            block.top_k,
+                        )
+                        return y
+
+                    try:
+                        compiled_impl = torch.compile(grouped_forward)  # type: ignore[arg-type]
+                    except Exception as exc:
+                        print(
+                            f"torch.compile grouped failed ({exc}); using eager"
+                        )
+                        compiled_impl = None
+
             def run_naive(block=block_naive, data=x):
                 y, _ = block(data)
                 return y
 
-            def run_grouped(block=block_grouped, data=x):
-                block.experts._ax_parent_block_ref = weakref.ref(block)  # type: ignore
+            def run_grouped(
+                block=block_grouped, data=x, impl=compiled_impl
+            ):
+                if impl is not None:
+                    return impl(data)
+                block.experts._ax_parent_block_ref = weakref.ref(block)  # type: ignore[attr-defined]
                 y, _ = tg.moe_ffn_forward_grouped(
                     data,
                     block.gate,