diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py index 5bce3a33a..85edc50fa 100644 --- a/scripts/bench_moe.py +++ b/scripts/bench_moe.py @@ -10,6 +10,7 @@ import weakref from pathlib import Path import torch +import torch._dynamo as dynamo try: from axolotl.kernels.moe import torch_grouped as tg @@ -75,11 +76,11 @@ def load_hf_block( def main() -> None: p = argparse.ArgumentParser(description="Qwen2 MoE grouped_mm benchmark") p.add_argument("--bsz", type=int, default=8) - p.add_argument("--seq", type=int, default=512) - p.add_argument("--hidden", type=int, default=1024) - p.add_argument("--inter", type=int, default=2048) - p.add_argument("--experts", type=int, default=8) - p.add_argument("--top_k", type=int, default=2) + p.add_argument("--seq", type=int, default=1024) + p.add_argument("--hidden", type=int, default=4096) + p.add_argument("--inter", type=int, default=14336) + p.add_argument("--experts", type=int, default=32) + p.add_argument("--top_k", type=int, default=4) p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16") p.add_argument("--iters", type=int, default=50) p.add_argument("--warmup", type=int, default=10) @@ -123,6 +124,8 @@ def main() -> None: # Optional torch.compile run_grouped_impl = None if args.compile: + dynamo.config.capture_scalar_outputs = True + dynamo.config.allow_unspec_int_on_nn_module = True try: block_naive = torch.compile(block_naive) # type: ignore[arg-type] except Exception as exc: # pragma: no cover diff --git a/scripts/bench_moe_sweep.py b/scripts/bench_moe_sweep.py index d36495e33..848ae3436 100644 --- a/scripts/bench_moe_sweep.py +++ b/scripts/bench_moe_sweep.py @@ -13,6 +13,7 @@ from pathlib import Path from typing import List import torch +import torch._dynamo as dynamo try: from axolotl.kernels.moe import torch_grouped as tg @@ -163,6 +164,8 @@ def main() -> None: compiled_impl = None if args.compile: + dynamo.config.capture_scalar_outputs = True + dynamo.config.allow_unspec_int_on_nn_module = True try: block_naive = torch.compile(block_naive) # type: ignore[arg-type] except Exception as exc: