fix compile
This commit is contained in:
@@ -10,6 +10,7 @@ import weakref
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch._dynamo as dynamo
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from axolotl.kernels.moe import torch_grouped as tg
|
from axolotl.kernels.moe import torch_grouped as tg
|
||||||
@@ -75,11 +76,11 @@ def load_hf_block(
|
|||||||
def main() -> None:
|
def main() -> None:
|
||||||
p = argparse.ArgumentParser(description="Qwen2 MoE grouped_mm benchmark")
|
p = argparse.ArgumentParser(description="Qwen2 MoE grouped_mm benchmark")
|
||||||
p.add_argument("--bsz", type=int, default=8)
|
p.add_argument("--bsz", type=int, default=8)
|
||||||
p.add_argument("--seq", type=int, default=512)
|
p.add_argument("--seq", type=int, default=1024)
|
||||||
p.add_argument("--hidden", type=int, default=1024)
|
p.add_argument("--hidden", type=int, default=4096)
|
||||||
p.add_argument("--inter", type=int, default=2048)
|
p.add_argument("--inter", type=int, default=14336)
|
||||||
p.add_argument("--experts", type=int, default=8)
|
p.add_argument("--experts", type=int, default=32)
|
||||||
p.add_argument("--top_k", type=int, default=2)
|
p.add_argument("--top_k", type=int, default=4)
|
||||||
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
|
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
|
||||||
p.add_argument("--iters", type=int, default=50)
|
p.add_argument("--iters", type=int, default=50)
|
||||||
p.add_argument("--warmup", type=int, default=10)
|
p.add_argument("--warmup", type=int, default=10)
|
||||||
@@ -123,6 +124,8 @@ def main() -> None:
|
|||||||
# Optional torch.compile
|
# Optional torch.compile
|
||||||
run_grouped_impl = None
|
run_grouped_impl = None
|
||||||
if args.compile:
|
if args.compile:
|
||||||
|
dynamo.config.capture_scalar_outputs = True
|
||||||
|
dynamo.config.allow_unspec_int_on_nn_module = True
|
||||||
try:
|
try:
|
||||||
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
|
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
|
||||||
except Exception as exc: # pragma: no cover
|
except Exception as exc: # pragma: no cover
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from pathlib import Path
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch._dynamo as dynamo
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from axolotl.kernels.moe import torch_grouped as tg
|
from axolotl.kernels.moe import torch_grouped as tg
|
||||||
@@ -163,6 +164,8 @@ def main() -> None:
|
|||||||
|
|
||||||
compiled_impl = None
|
compiled_impl = None
|
||||||
if args.compile:
|
if args.compile:
|
||||||
|
dynamo.config.capture_scalar_outputs = True
|
||||||
|
dynamo.config.allow_unspec_int_on_nn_module = True
|
||||||
try:
|
try:
|
||||||
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
|
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
|||||||
Reference in New Issue
Block a user