From d2b25c732774bd5a1af98cd44d5f4f1a13b9ef75 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 22 Sep 2025 16:34:55 -0400 Subject: [PATCH] grid sweep --- scripts/benchmarks/deepseek_v3_moe_sweep.py | 122 +++++++++++++------- 1 file changed, 79 insertions(+), 43 deletions(-) diff --git a/scripts/benchmarks/deepseek_v3_moe_sweep.py b/scripts/benchmarks/deepseek_v3_moe_sweep.py index 87124056f..c4a176a9f 100644 --- a/scripts/benchmarks/deepseek_v3_moe_sweep.py +++ b/scripts/benchmarks/deepseek_v3_moe_sweep.py @@ -6,6 +6,7 @@ from __future__ import annotations import argparse import csv +import itertools import sys from pathlib import Path from types import SimpleNamespace @@ -20,46 +21,6 @@ from scripts.benchmarks.deepseek_v3_moe import ( benchmark_deepseek_v3, ) -DEFAULT_SWEEP = [ - { - "batch": 4, - "seq_len": 1024, - "hidden_size": 2048, - "moe_intermediate_size": 4096, - "n_experts": 64, - "top_k": 4, - "groups": 4, - }, - { - "batch": 8, - "seq_len": 2048, - "hidden_size": 2048, - "moe_intermediate_size": 4096, - "n_experts": 64, - "top_k": 4, - "groups": 4, - }, - { - "batch": 8, - "seq_len": 2048, - "hidden_size": 4096, - "moe_intermediate_size": 8192, - "n_experts": 128, - "top_k": 8, - "groups": 8, - }, - { - "batch": 8, - "seq_len": 2048, - "hidden_size": 4096, - "moe_intermediate_size": 8192, - "n_experts": 256, - "top_k": 8, - "groups": 8, - }, -] - - def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( @@ -83,6 +44,41 @@ def parse_args() -> argparse.Namespace: default=128, help="GROUP_SIZE_M used by the Triton kernel", ) + parser.add_argument( + "--batches", + default="4,8", + help="Comma separated list of batch sizes", + ) + parser.add_argument( + "--seq-lens", + default="1024,2048", + help="Comma separated list of sequence lengths", + ) + parser.add_argument( + "--hidden-sizes", + default="2048,4096", + help="Comma separated list of hidden sizes", + ) + parser.add_argument( + "--moe-intermediates", + default="4096,8192", + help="Comma separated list of MoE intermediate sizes", + ) + parser.add_argument( + "--n-experts-list", + default="64,128,256", + help="Comma separated list of expert counts", + ) + parser.add_argument( + "--top-ks", + default="4,8", + help="Comma separated list of top-k values", + ) + parser.add_argument( + "--groups-list", + default="4,8", + help="Comma separated list of router group counts", + ) parser.add_argument( "--uniform-routing", action="store_true", @@ -115,6 +111,44 @@ def make_namespace(base: dict, args: argparse.Namespace) -> SimpleNamespace: def main() -> None: # pragma: no cover - utility script args = parse_args() + def _parse_list(text: str) -> list[int]: + return [int(item.strip()) for item in text.split(",") if item.strip()] + + batch_values = _parse_list(args.batches) + seq_values = _parse_list(args.seq_lens) + hidden_values = _parse_list(args.hidden_sizes) + moe_values = _parse_list(args.moe_intermediates) + expert_values = _parse_list(args.n_experts_list) + topk_values = _parse_list(args.top_ks) + group_values = _parse_list(args.groups_list) + + grid = [] + for batch, seq_len, hidden, moe, n_experts, top_k, groups in itertools.product( + batch_values, + seq_values, + hidden_values, + moe_values, + expert_values, + topk_values, + group_values, + ): + if n_experts % groups != 0 or top_k > n_experts: + continue + grid.append( + { + "batch": batch, + "seq_len": seq_len, + "hidden_size": hidden, + "moe_intermediate_size": moe, + "n_experts": n_experts, + "top_k": top_k, + "groups": groups, + } + ) + + if not grid: + raise SystemExit("No valid parameter combinations produced") + header = ( "batch", "seq_len", @@ -122,6 +156,7 @@ def main() -> None: # pragma: no cover - utility script "moe_intermediate", "n_experts", "top_k", + "groups", "baseline_ms", "patched_ms", "speedup", @@ -135,10 +170,10 @@ def main() -> None: # pragma: no cover - utility script f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}" ) print( - f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'baseline':>12} {'patched':>12} {'speedup':>8}" + f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'baseline':>12} {'patched':>12} {'speedup':>8}" ) - for cfg in DEFAULT_SWEEP: + for cfg in grid: ns = make_namespace(cfg, args) result = benchmark_deepseek_v3(ns) rows.append( @@ -149,6 +184,7 @@ def main() -> None: # pragma: no cover - utility script cfg["moe_intermediate_size"], cfg["n_experts"], cfg["top_k"], + cfg["groups"], result["baseline_ms"], result["patched_ms"], result["speedup"], @@ -158,7 +194,7 @@ def main() -> None: # pragma: no cover - utility script ) ) print( - f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4}" + f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6}" f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x" )