This commit is contained in:
Dan Saunders
2025-09-19 13:41:12 -04:00
parent fb11f696e9
commit 7327144344
2 changed files with 73 additions and 9 deletions

View File

@@ -6,6 +6,7 @@ from __future__ import annotations
import argparse import argparse
import sys import sys
import time import time
import weakref
from pathlib import Path from pathlib import Path
import torch import torch
@@ -83,6 +84,11 @@ def main() -> None:
p.add_argument("--iters", type=int, default=50) p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10) p.add_argument("--warmup", type=int, default=10)
p.add_argument("--profile", action="store_true") p.add_argument("--profile", action="store_true")
p.add_argument(
"--compile",
action="store_true",
help="Torch.compile both paths before benchmarking",
)
args = p.parse_args() args = p.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -114,17 +120,39 @@ def main() -> None:
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype) x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
def run_naive(): # Optional torch.compile
y, _ = block_naive(x) run_grouped_impl = None
if args.compile:
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc: # pragma: no cover
print(f"torch.compile naive failed ({exc}); using eager")
else:
def grouped_forward(inp, *, block=block_grouped):
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
inp, block.gate, block.experts, block.top_k
)
return y
try:
run_grouped_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
except Exception as exc: # pragma: no cover
print(f"torch.compile grouped failed ({exc}); using eager")
run_grouped_impl = None
def run_naive(block=block_naive, data=x):
y, _ = block(data)
return y return y
def run_grouped(): def run_grouped(block=block_grouped, data=x, impl=run_grouped_impl):
if impl is not None:
return impl(data)
if tg is None or not tg.available(): if tg is None or not tg.available():
return torch.empty(0) return torch.empty(0)
block_grouped.experts._ax_parent_block = block_grouped block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped( y, _ = tg.moe_ffn_forward_grouped(data, block.gate, block.experts, block.top_k)
x, block_grouped.gate, block_grouped.experts, block_grouped.top_k
)
return y if y is not None else torch.empty(0) return y if y is not None else torch.empty(0)
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup) t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)

View File

@@ -7,6 +7,7 @@ import argparse
import csv import csv
import sys import sys
import time import time
import weakref
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import List from typing import List
@@ -108,6 +109,7 @@ def main() -> None:
p.add_argument("--iters", type=int, default=25) p.add_argument("--iters", type=int, default=25)
p.add_argument("--warmup", type=int, default=5) p.add_argument("--warmup", type=int, default=5)
p.add_argument("--csv", type=Path, default=None) p.add_argument("--csv", type=Path, default=None)
p.add_argument("--compile", action="store_true")
args = p.parse_args() args = p.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -159,12 +161,46 @@ def main() -> None:
bsz, seq, hidden, device=device, dtype=dtype bsz, seq, hidden, device=device, dtype=dtype
) )
compiled_impl = None
if args.compile:
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc:
print(
f"torch.compile naive failed ({exc}); using eager"
)
else:
def grouped_forward(inp, *, block=block_grouped):
block.experts._ax_parent_block_ref = (
weakref.ref(block)
) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
inp,
block.gate,
block.experts,
block.top_k,
)
return y
try:
compiled_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
except Exception as exc:
print(
f"torch.compile grouped failed ({exc}); using eager"
)
compiled_impl = None
def run_naive(block=block_naive, data=x): def run_naive(block=block_naive, data=x):
y, _ = block(data) y, _ = block(data)
return y return y
def run_grouped(block=block_grouped, data=x): def run_grouped(
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore block=block_grouped, data=x, impl=compiled_impl
):
if impl is not None:
return impl(data)
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped( y, _ = tg.moe_ffn_forward_grouped(
data, data,
block.gate, block.gate,