grid sweep

This commit is contained in:
Dan Saunders
2025-09-22 16:34:55 -04:00
parent b670c45276
commit d2b25c7327

View File

@@ -6,6 +6,7 @@ from __future__ import annotations
import argparse import argparse
import csv import csv
import itertools
import sys import sys
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
@@ -20,46 +21,6 @@ from scripts.benchmarks.deepseek_v3_moe import (
benchmark_deepseek_v3, benchmark_deepseek_v3,
) )
# Fixed benchmark configurations for the sweep. Each row lists the values in
# the order of the key tuple below, so every entry is a dict with the same
# seven keys in the same order.
DEFAULT_SWEEP = [
    dict(
        zip(
            (
                "batch",
                "seq_len",
                "hidden_size",
                "moe_intermediate_size",
                "n_experts",
                "top_k",
                "groups",
            ),
            row,
        )
    )
    for row in (
        (4, 1024, 2048, 4096, 64, 4, 4),
        (8, 2048, 2048, 4096, 64, 4, 4),
        (8, 2048, 4096, 8192, 128, 8, 8),
        (8, 2048, 4096, 8192, 256, 8, 8),
    )
]
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument( parser.add_argument(
@@ -83,6 +44,41 @@ def parse_args() -> argparse.Namespace:
default=128, default=128,
help="GROUP_SIZE_M used by the Triton kernel", help="GROUP_SIZE_M used by the Triton kernel",
) )
parser.add_argument(
"--batches",
default="4,8",
help="Comma separated list of batch sizes",
)
parser.add_argument(
"--seq-lens",
default="1024,2048",
help="Comma separated list of sequence lengths",
)
parser.add_argument(
"--hidden-sizes",
default="2048,4096",
help="Comma separated list of hidden sizes",
)
parser.add_argument(
"--moe-intermediates",
default="4096,8192",
help="Comma separated list of MoE intermediate sizes",
)
parser.add_argument(
"--n-experts-list",
default="64,128,256",
help="Comma separated list of expert counts",
)
parser.add_argument(
"--top-ks",
default="4,8",
help="Comma separated list of top-k values",
)
parser.add_argument(
"--groups-list",
default="4,8",
help="Comma separated list of router group counts",
)
parser.add_argument( parser.add_argument(
"--uniform-routing", "--uniform-routing",
action="store_true", action="store_true",
@@ -115,6 +111,44 @@ def make_namespace(base: dict, args: argparse.Namespace) -> SimpleNamespace:
def main() -> None: # pragma: no cover - utility script def main() -> None: # pragma: no cover - utility script
args = parse_args() args = parse_args()
def _parse_list(text: str) -> list[int]:
return [int(item.strip()) for item in text.split(",") if item.strip()]
batch_values = _parse_list(args.batches)
seq_values = _parse_list(args.seq_lens)
hidden_values = _parse_list(args.hidden_sizes)
moe_values = _parse_list(args.moe_intermediates)
expert_values = _parse_list(args.n_experts_list)
topk_values = _parse_list(args.top_ks)
group_values = _parse_list(args.groups_list)
grid = []
for batch, seq_len, hidden, moe, n_experts, top_k, groups in itertools.product(
batch_values,
seq_values,
hidden_values,
moe_values,
expert_values,
topk_values,
group_values,
):
if n_experts % groups != 0 or top_k > n_experts:
continue
grid.append(
{
"batch": batch,
"seq_len": seq_len,
"hidden_size": hidden,
"moe_intermediate_size": moe,
"n_experts": n_experts,
"top_k": top_k,
"groups": groups,
}
)
if not grid:
raise SystemExit("No valid parameter combinations produced")
header = ( header = (
"batch", "batch",
"seq_len", "seq_len",
@@ -122,6 +156,7 @@ def main() -> None: # pragma: no cover - utility script
"moe_intermediate", "moe_intermediate",
"n_experts", "n_experts",
"top_k", "top_k",
"groups",
"baseline_ms", "baseline_ms",
"patched_ms", "patched_ms",
"speedup", "speedup",
@@ -135,10 +170,10 @@ def main() -> None: # pragma: no cover - utility script
f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}" f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}"
) )
print( print(
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'baseline':>12} {'patched':>12} {'speedup':>8}" f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'baseline':>12} {'patched':>12} {'speedup':>8}"
) )
for cfg in DEFAULT_SWEEP: for cfg in grid:
ns = make_namespace(cfg, args) ns = make_namespace(cfg, args)
result = benchmark_deepseek_v3(ns) result = benchmark_deepseek_v3(ns)
rows.append( rows.append(
@@ -149,6 +184,7 @@ def main() -> None: # pragma: no cover - utility script
cfg["moe_intermediate_size"], cfg["moe_intermediate_size"],
cfg["n_experts"], cfg["n_experts"],
cfg["top_k"], cfg["top_k"],
cfg["groups"],
result["baseline_ms"], result["baseline_ms"],
result["patched_ms"], result["patched_ms"],
result["speedup"], result["speedup"],
@@ -158,7 +194,7 @@ def main() -> None: # pragma: no cover - utility script
) )
) )
print( print(
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4}" f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6}"
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x" f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
) )