fix compile

This commit is contained in:
Dan Saunders
2025-09-19 13:55:54 -04:00
parent b5dc58373f
commit ce21da9177
2 changed files with 11 additions and 5 deletions

View File

@@ -10,6 +10,7 @@ import weakref
from pathlib import Path
import torch
import torch._dynamo as dynamo
try:
from axolotl.kernels.moe import torch_grouped as tg
@@ -75,11 +76,11 @@ def load_hf_block(
def main() -> None:
p = argparse.ArgumentParser(description="Qwen2 MoE grouped_mm benchmark")
p.add_argument("--bsz", type=int, default=8)
p.add_argument("--seq", type=int, default=512)
p.add_argument("--hidden", type=int, default=1024)
p.add_argument("--inter", type=int, default=2048)
p.add_argument("--experts", type=int, default=8)
p.add_argument("--top_k", type=int, default=2)
p.add_argument("--seq", type=int, default=1024)
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--experts", type=int, default=32)
p.add_argument("--top_k", type=int, default=4)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10)
@@ -123,6 +124,8 @@ def main() -> None:
# Optional torch.compile
run_grouped_impl = None
if args.compile:
dynamo.config.capture_scalar_outputs = True
dynamo.config.allow_unspec_int_on_nn_module = True
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc: # pragma: no cover

View File

@@ -13,6 +13,7 @@ from pathlib import Path
from typing import List
import torch
import torch._dynamo as dynamo
try:
from axolotl.kernels.moe import torch_grouped as tg
@@ -163,6 +164,8 @@ def main() -> None:
compiled_impl = None
if args.compile:
dynamo.config.capture_scalar_outputs = True
dynamo.config.allow_unspec_int_on_nn_module = True
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc: