compile
This commit is contained in:
@@ -6,6 +6,7 @@ from __future__ import annotations
|
|||||||
import argparse
|
import argparse
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import weakref
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -83,6 +84,11 @@ def main() -> None:
|
|||||||
p.add_argument("--iters", type=int, default=50)
|
p.add_argument("--iters", type=int, default=50)
|
||||||
p.add_argument("--warmup", type=int, default=10)
|
p.add_argument("--warmup", type=int, default=10)
|
||||||
p.add_argument("--profile", action="store_true")
|
p.add_argument("--profile", action="store_true")
|
||||||
|
p.add_argument(
|
||||||
|
"--compile",
|
||||||
|
action="store_true",
|
||||||
|
help="Torch.compile both paths before benchmarking",
|
||||||
|
)
|
||||||
args = p.parse_args()
|
args = p.parse_args()
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
@@ -114,17 +120,39 @@ def main() -> None:
|
|||||||
|
|
||||||
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
|
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
|
||||||
|
|
||||||
def run_naive():
|
# Optional torch.compile
|
||||||
y, _ = block_naive(x)
|
run_grouped_impl = None
|
||||||
|
if args.compile:
|
||||||
|
try:
|
||||||
|
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
|
||||||
|
except Exception as exc: # pragma: no cover
|
||||||
|
print(f"torch.compile naive failed ({exc}); using eager")
|
||||||
|
else:
|
||||||
|
|
||||||
|
def grouped_forward(inp, *, block=block_grouped):
|
||||||
|
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
|
||||||
|
y, _ = tg.moe_ffn_forward_grouped(
|
||||||
|
inp, block.gate, block.experts, block.top_k
|
||||||
|
)
|
||||||
|
return y
|
||||||
|
|
||||||
|
try:
|
||||||
|
run_grouped_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
|
||||||
|
except Exception as exc: # pragma: no cover
|
||||||
|
print(f"torch.compile grouped failed ({exc}); using eager")
|
||||||
|
run_grouped_impl = None
|
||||||
|
|
||||||
|
def run_naive(block=block_naive, data=x):
|
||||||
|
y, _ = block(data)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
def run_grouped():
|
def run_grouped(block=block_grouped, data=x, impl=run_grouped_impl):
|
||||||
|
if impl is not None:
|
||||||
|
return impl(data)
|
||||||
if tg is None or not tg.available():
|
if tg is None or not tg.available():
|
||||||
return torch.empty(0)
|
return torch.empty(0)
|
||||||
block_grouped.experts._ax_parent_block = block_grouped
|
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
|
||||||
y, _ = tg.moe_ffn_forward_grouped(
|
y, _ = tg.moe_ffn_forward_grouped(data, block.gate, block.experts, block.top_k)
|
||||||
x, block_grouped.gate, block_grouped.experts, block_grouped.top_k
|
|
||||||
)
|
|
||||||
return y if y is not None else torch.empty(0)
|
return y if y is not None else torch.empty(0)
|
||||||
|
|
||||||
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
|
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import argparse
|
|||||||
import csv
|
import csv
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import weakref
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
@@ -108,6 +109,7 @@ def main() -> None:
|
|||||||
p.add_argument("--iters", type=int, default=25)
|
p.add_argument("--iters", type=int, default=25)
|
||||||
p.add_argument("--warmup", type=int, default=5)
|
p.add_argument("--warmup", type=int, default=5)
|
||||||
p.add_argument("--csv", type=Path, default=None)
|
p.add_argument("--csv", type=Path, default=None)
|
||||||
|
p.add_argument("--compile", action="store_true")
|
||||||
args = p.parse_args()
|
args = p.parse_args()
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
@@ -159,12 +161,46 @@ def main() -> None:
|
|||||||
bsz, seq, hidden, device=device, dtype=dtype
|
bsz, seq, hidden, device=device, dtype=dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
|
compiled_impl = None
|
||||||
|
if args.compile:
|
||||||
|
try:
|
||||||
|
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
|
||||||
|
except Exception as exc:
|
||||||
|
print(
|
||||||
|
f"torch.compile naive failed ({exc}); using eager"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
def grouped_forward(inp, *, block=block_grouped):
|
||||||
|
block.experts._ax_parent_block_ref = (
|
||||||
|
weakref.ref(block)
|
||||||
|
) # type: ignore[attr-defined]
|
||||||
|
y, _ = tg.moe_ffn_forward_grouped(
|
||||||
|
inp,
|
||||||
|
block.gate,
|
||||||
|
block.experts,
|
||||||
|
block.top_k,
|
||||||
|
)
|
||||||
|
return y
|
||||||
|
|
||||||
|
try:
|
||||||
|
compiled_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
|
||||||
|
except Exception as exc:
|
||||||
|
print(
|
||||||
|
f"torch.compile grouped failed ({exc}); using eager"
|
||||||
|
)
|
||||||
|
compiled_impl = None
|
||||||
|
|
||||||
def run_naive(block=block_naive, data=x):
|
def run_naive(block=block_naive, data=x):
|
||||||
y, _ = block(data)
|
y, _ = block(data)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
def run_grouped(block=block_grouped, data=x):
|
def run_grouped(
|
||||||
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore
|
block=block_grouped, data=x, impl=compiled_impl
|
||||||
|
):
|
||||||
|
if impl is not None:
|
||||||
|
return impl(data)
|
||||||
|
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
|
||||||
y, _ = tg.moe_ffn_forward_grouped(
|
y, _ = tg.moe_ffn_forward_grouped(
|
||||||
data,
|
data,
|
||||||
block.gate,
|
block.gate,
|
||||||
|
|||||||
Reference in New Issue
Block a user