From e003a05177e409e809b8330025d95a67367cf811 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 25 Sep 2025 14:54:03 -0400 Subject: [PATCH] narrow sweep; compare both backends --- scripts/benchmarks/deepseek_v3_moe_sweep.py | 254 +++++++++++--------- 1 file changed, 134 insertions(+), 120 deletions(-) diff --git a/scripts/benchmarks/deepseek_v3_moe_sweep.py b/scripts/benchmarks/deepseek_v3_moe_sweep.py index de7f9426e..162a3a53b 100644 --- a/scripts/benchmarks/deepseek_v3_moe_sweep.py +++ b/scripts/benchmarks/deepseek_v3_moe_sweep.py @@ -6,7 +6,6 @@ from __future__ import annotations import argparse import csv -import itertools import sys from pathlib import Path from types import SimpleNamespace @@ -42,61 +41,24 @@ def parse_args() -> argparse.Namespace: choices=["auto", "cpu", "cuda"], help="Execution device", ) - parser.add_argument( - "--backend", - choices=["cg", "mg"], - default="mg", - help="MoE kernel backend to benchmark", - ) parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations") parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations") parser.add_argument("--seed", type=int, default=0, help="Random seed") parser.add_argument( "--group-size", type=int, - default=128, - help="GROUP_SIZE_M used by the Triton kernel", - ) - parser.add_argument( - "--batches", - default="4,8", - help="Comma separated list of batch sizes", - ) - parser.add_argument( - "--seq-lens", - default="1024,2048", - help="Comma separated list of sequence lengths", - ) - parser.add_argument( - "--hidden-sizes", - default="2048,4096", - help="Comma separated list of hidden sizes", - ) - parser.add_argument( - "--moe-intermediates", - default="4096,8192", - help="Comma separated list of MoE intermediate sizes", - ) - parser.add_argument( - "--n-experts-list", - default="64,128,256", - help="Comma separated list of expert counts", - ) - parser.add_argument( - "--top-ks", - default="4,8", - help="Comma separated list of top-k values", - ) - parser.add_argument( - "--groups-list", - default="4,8", - help="Comma separated list of router group counts", + help="Override GROUP_SIZE_M for every configuration", ) parser.add_argument( "--uniform-routing", action="store_true", help="Force uniform routing for every configuration", ) + parser.add_argument( + "--include-mixtral-long", + action="store_true", + help="Add an 8×8192 Mixtral-style run to the sweep", + ) parser.add_argument( "--output", type=Path, @@ -105,65 +67,115 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def make_namespace(base: dict, args: argparse.Namespace) -> SimpleNamespace: +def make_namespace( + base: dict, args: argparse.Namespace, backend: str +) -> SimpleNamespace: combined = dict(base) combined.update( { "dtype": args.dtype, "device": args.device, - "backend": args.backend, + "backend": backend, "warmup": args.warmup, "iters": args.iters, "seed": args.seed, - "group_size": args.group_size, "uniform_routing": args.uniform_routing, } ) + if args.group_size is not None: + combined["group_size"] = args.group_size return SimpleNamespace(**combined) +ARCHETYPES = ( + ( + "mixtral", + { + "hidden_size": 4096, + "moe_intermediate_size": 14336, + "n_experts": 8, + "top_k": 2, + "groups": 1, + "group_size": 128, + }, + [(4, 2048), (8, 4096)], + ), + ( + "dbrx", + { + "hidden_size": 6144, + "moe_intermediate_size": 24576, + "n_experts": 16, + "top_k": 2, + "groups": 4, + "group_size": 192, + }, + [(4, 4096), (8, 8192)], + ), + ( + "qwen", + { + "hidden_size": 6144, + "moe_intermediate_size": 24576, + "n_experts": 16, + "top_k": 4, + "groups": 8, + "group_size": 128, + }, + [(4, 4096), (8, 8192)], + ), + ( + "deepseek_v3", + { + "hidden_size": 12288, + "moe_intermediate_size": 49152, + "n_experts": 128, + "top_k": 8, + "groups": 16, + "group_size": 128, + }, + [(4, 4096), (8, 8192)], + ), +) + +MIXTRAL_LONG_SHAPES = [(8, 8192)] + +BACKENDS = ("cg", "mg") + + def main() -> None: # pragma: no cover - utility script args = parse_args() - def _parse_list(text: str) -> list[int]: - return [int(item.strip()) for item in text.split(",") if item.strip()] - - batch_values = _parse_list(args.batches) - seq_values = _parse_list(args.seq_lens) - hidden_values = _parse_list(args.hidden_sizes) - moe_values = _parse_list(args.moe_intermediates) - expert_values = _parse_list(args.n_experts_list) - topk_values = _parse_list(args.top_ks) - group_values = _parse_list(args.groups_list) - grid = [] - for batch, seq_len, hidden, moe, n_experts, top_k, groups in itertools.product( - batch_values, - seq_values, - hidden_values, - moe_values, - expert_values, - topk_values, - group_values, - ): - if n_experts % groups != 0 or top_k > n_experts: - continue - grid.append( - { + for label, base_cfg, shapes in ARCHETYPES: + for batch, seq_len in shapes: + cfg = { + "label": label, "batch": batch, "seq_len": seq_len, - "hidden_size": hidden, - "moe_intermediate_size": moe, - "n_experts": n_experts, - "top_k": top_k, - "groups": groups, + **base_cfg, } - ) + if cfg["n_experts"] % cfg["groups"] != 0 or cfg["top_k"] > cfg["n_experts"]: + continue + grid.append(cfg) + + if args.include_mixtral_long: + base_cfg = ARCHETYPES[0][1] + for batch, seq_len in MIXTRAL_LONG_SHAPES: + grid.append( + { + "label": "mixtral_long", + "batch": batch, + "seq_len": seq_len, + **base_cfg, + } + ) if not grid: raise SystemExit("No valid parameter combinations produced") header = ( + "model", "batch", "seq_len", "hidden_size", @@ -185,57 +197,59 @@ def main() -> None: # pragma: no cover - utility script rows = [] print( - f"Running sweep on device={args.device} dtype={args.dtype} backend={args.backend} uniform_routing={args.uniform_routing}" + f"Running sweep on device={args.device} dtype={args.dtype} backends={BACKENDS} uniform_routing={args.uniform_routing}" ) print( - f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}" + f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}" f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}" ) for cfg in grid: - ns = make_namespace(cfg, args) - result = benchmark_deepseek_v3(ns) - baseline_vram_mib = ( - result["baseline_vram"] / (1024**2) - if result["baseline_vram"] is not None - else float("nan") - ) - patched_vram_mib = ( - result["patched_vram"] / (1024**2) - if result["patched_vram"] is not None - else float("nan") - ) - rows.append( - ( - cfg["batch"], - cfg["seq_len"], - cfg["hidden_size"], - cfg["moe_intermediate_size"], - cfg["n_experts"], - cfg["top_k"], - cfg["groups"], - args.backend, - result["baseline_ms"], - result["patched_ms"], - result["speedup"], - baseline_vram_mib, - patched_vram_mib, - result["min_tokens"], - result["max_tokens"], - result["max_diff"], - result["accuracy_ok"], + for backend in BACKENDS: + ns = make_namespace(cfg, args, backend) + result = benchmark_deepseek_v3(ns) + baseline_vram_mib = ( + result["baseline_vram"] / (1024**2) + if result["baseline_vram"] is not None + else float("nan") ) - ) - status = "OK" if result["accuracy_ok"] else "FAIL" - print( - f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {args.backend:>8}" - f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x" - f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}" - ) - if not result["accuracy_ok"]: - raise RuntimeError( - f"Accuracy check failed for config {cfg}: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}" + patched_vram_mib = ( + result["patched_vram"] / (1024**2) + if result["patched_vram"] is not None + else float("nan") ) + rows.append( + ( + cfg["label"], + cfg["batch"], + cfg["seq_len"], + cfg["hidden_size"], + cfg["moe_intermediate_size"], + cfg["n_experts"], + cfg["top_k"], + cfg["groups"], + backend, + result["baseline_ms"], + result["patched_ms"], + result["speedup"], + baseline_vram_mib, + patched_vram_mib, + result["min_tokens"], + result["max_tokens"], + result["max_diff"], + result["accuracy_ok"], + ) + ) + status = "OK" if result["accuracy_ok"] else "FAIL" + print( + f"{cfg['label']:>10} {cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {backend:>8}" + f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x" + f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}" + ) + if not result["accuracy_ok"]: + raise RuntimeError( + f"Accuracy check failed for config {cfg}: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}" + ) if args.output: args.output.parent.mkdir(parents=True, exist_ok=True)