diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py index 1815096ec..95fe83793 100644 --- a/scripts/bench_moe.py +++ b/scripts/bench_moe.py @@ -74,11 +74,11 @@ def load_hf_block( def main() -> None: p = argparse.ArgumentParser(description="Qwen2 MoE grouped_mm benchmark") p.add_argument("--bsz", type=int, default=8) - p.add_argument("--seq", type=int, default=1024) - p.add_argument("--hidden", type=int, default=4096) - p.add_argument("--inter", type=int, default=14336) - p.add_argument("--experts", type=int, default=32) - p.add_argument("--top_k", type=int, default=4) + p.add_argument("--seq", type=int, default=512) + p.add_argument("--hidden", type=int, default=1024) + p.add_argument("--inter", type=int, default=2948) + p.add_argument("--experts", type=int, default=8) + p.add_argument("--top_k", type=int, default=2) p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16") p.add_argument("--iters", type=int, default=50) p.add_argument("--warmup", type=int, default=10)