fix compile
This commit is contained in:
@@ -10,6 +10,7 @@ import weakref
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch._dynamo as dynamo
|
||||
|
||||
try:
|
||||
from axolotl.kernels.moe import torch_grouped as tg
|
||||
@@ -75,11 +76,11 @@ def load_hf_block(
|
||||
def main() -> None:
|
||||
p = argparse.ArgumentParser(description="Qwen2 MoE grouped_mm benchmark")
|
||||
p.add_argument("--bsz", type=int, default=8)
|
||||
p.add_argument("--seq", type=int, default=512)
|
||||
p.add_argument("--hidden", type=int, default=1024)
|
||||
p.add_argument("--inter", type=int, default=2048)
|
||||
p.add_argument("--experts", type=int, default=8)
|
||||
p.add_argument("--top_k", type=int, default=2)
|
||||
p.add_argument("--seq", type=int, default=1024)
|
||||
p.add_argument("--hidden", type=int, default=4096)
|
||||
p.add_argument("--inter", type=int, default=14336)
|
||||
p.add_argument("--experts", type=int, default=32)
|
||||
p.add_argument("--top_k", type=int, default=4)
|
||||
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
|
||||
p.add_argument("--iters", type=int, default=50)
|
||||
p.add_argument("--warmup", type=int, default=10)
|
||||
@@ -123,6 +124,8 @@ def main() -> None:
|
||||
# Optional torch.compile
|
||||
run_grouped_impl = None
|
||||
if args.compile:
|
||||
dynamo.config.capture_scalar_outputs = True
|
||||
dynamo.config.allow_unspec_int_on_nn_module = True
|
||||
try:
|
||||
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
|
||||
except Exception as exc: # pragma: no cover
|
||||
|
||||
@@ -13,6 +13,7 @@ from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch._dynamo as dynamo
|
||||
|
||||
try:
|
||||
from axolotl.kernels.moe import torch_grouped as tg
|
||||
@@ -163,6 +164,8 @@ def main() -> None:
|
||||
|
||||
compiled_impl = None
|
||||
if args.compile:
|
||||
dynamo.config.capture_scalar_outputs = True
|
||||
dynamo.config.allow_unspec_int_on_nn_module = True
|
||||
try:
|
||||
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
|
||||
except Exception as exc:
|
||||
|
||||
Reference in New Issue
Block a user