From dd853585436b48409e51255a8d368e32618e4a33 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 25 Sep 2025 16:30:23 -0400 Subject: [PATCH] default mg --- scripts/benchmarks/deepseek_v3_moe_sweep.py | 23 +++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/scripts/benchmarks/deepseek_v3_moe_sweep.py b/scripts/benchmarks/deepseek_v3_moe_sweep.py index 944d48086..735f81210 100644 --- a/scripts/benchmarks/deepseek_v3_moe_sweep.py +++ b/scripts/benchmarks/deepseek_v3_moe_sweep.py @@ -52,6 +52,11 @@ def parse_args() -> argparse.Namespace: type=int, help="Override GROUP_SIZE_M for every configuration", ) + parser.add_argument( + "--backends", + default="mg", + help="Comma separated list of backends to benchmark (subset of cg,mg)", + ) parser.add_argument( "--no-uniform-routing", action="store_true", @@ -131,8 +136,6 @@ ARCHETYPES = ( MIXTRAL_LONG_SHAPES = [(8, 8192)] -BACKENDS = ("cg", "mg") - def main() -> None: # pragma: no cover - utility script args = parse_args() @@ -187,9 +190,21 @@ def main() -> None: # pragma: no cover - utility script ) rows = [] + raw_backends = [ + token.strip() for token in args.backends.split(",") if token.strip() + ] + if not raw_backends: + raw_backends = ["mg"] + valid_backends = [] + for backend in raw_backends: + if backend not in {"cg", "mg"}: + raise SystemExit(f"Unsupported backend '{backend}' requested") + if backend not in valid_backends: + valid_backends.append(backend) + uniform_flag = not args.no_uniform_routing print( - f"Running sweep on device={args.device} dtype={args.dtype} backends={BACKENDS} uniform_routing={uniform_flag}" + f"Running sweep on device={args.device} dtype={args.dtype} backends={tuple(valid_backends)} uniform_routing={uniform_flag}" ) print( f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}" @@ -197,7 +212,7 @@ def main() -> None: # pragma: no cover - utility script ) for cfg in grid: - for backend in BACKENDS: + for backend in valid_backends: ns = make_namespace(cfg, args, backend) result = benchmark_deepseek_v3(ns) baseline_vram_mib = (