grid sweep
This commit is contained in:
@@ -6,6 +6,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import csv
|
import csv
|
||||||
|
import itertools
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
@@ -20,46 +21,6 @@ from scripts.benchmarks.deepseek_v3_moe import (
|
|||||||
benchmark_deepseek_v3,
|
benchmark_deepseek_v3,
|
||||||
)
|
)
|
||||||
|
|
||||||
DEFAULT_SWEEP = [
|
|
||||||
{
|
|
||||||
"batch": 4,
|
|
||||||
"seq_len": 1024,
|
|
||||||
"hidden_size": 2048,
|
|
||||||
"moe_intermediate_size": 4096,
|
|
||||||
"n_experts": 64,
|
|
||||||
"top_k": 4,
|
|
||||||
"groups": 4,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"batch": 8,
|
|
||||||
"seq_len": 2048,
|
|
||||||
"hidden_size": 2048,
|
|
||||||
"moe_intermediate_size": 4096,
|
|
||||||
"n_experts": 64,
|
|
||||||
"top_k": 4,
|
|
||||||
"groups": 4,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"batch": 8,
|
|
||||||
"seq_len": 2048,
|
|
||||||
"hidden_size": 4096,
|
|
||||||
"moe_intermediate_size": 8192,
|
|
||||||
"n_experts": 128,
|
|
||||||
"top_k": 8,
|
|
||||||
"groups": 8,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"batch": 8,
|
|
||||||
"seq_len": 2048,
|
|
||||||
"hidden_size": 4096,
|
|
||||||
"moe_intermediate_size": 8192,
|
|
||||||
"n_experts": 256,
|
|
||||||
"top_k": 8,
|
|
||||||
"groups": 8,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
parser = argparse.ArgumentParser(description=__doc__)
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -83,6 +44,41 @@ def parse_args() -> argparse.Namespace:
|
|||||||
default=128,
|
default=128,
|
||||||
help="GROUP_SIZE_M used by the Triton kernel",
|
help="GROUP_SIZE_M used by the Triton kernel",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batches",
|
||||||
|
default="4,8",
|
||||||
|
help="Comma separated list of batch sizes",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seq-lens",
|
||||||
|
default="1024,2048",
|
||||||
|
help="Comma separated list of sequence lengths",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hidden-sizes",
|
||||||
|
default="2048,4096",
|
||||||
|
help="Comma separated list of hidden sizes",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--moe-intermediates",
|
||||||
|
default="4096,8192",
|
||||||
|
help="Comma separated list of MoE intermediate sizes",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n-experts-list",
|
||||||
|
default="64,128,256",
|
||||||
|
help="Comma separated list of expert counts",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-ks",
|
||||||
|
default="4,8",
|
||||||
|
help="Comma separated list of top-k values",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--groups-list",
|
||||||
|
default="4,8",
|
||||||
|
help="Comma separated list of router group counts",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--uniform-routing",
|
"--uniform-routing",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -115,6 +111,44 @@ def make_namespace(base: dict, args: argparse.Namespace) -> SimpleNamespace:
|
|||||||
def main() -> None: # pragma: no cover - utility script
|
def main() -> None: # pragma: no cover - utility script
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
|
def _parse_list(text: str) -> list[int]:
|
||||||
|
return [int(item.strip()) for item in text.split(",") if item.strip()]
|
||||||
|
|
||||||
|
batch_values = _parse_list(args.batches)
|
||||||
|
seq_values = _parse_list(args.seq_lens)
|
||||||
|
hidden_values = _parse_list(args.hidden_sizes)
|
||||||
|
moe_values = _parse_list(args.moe_intermediates)
|
||||||
|
expert_values = _parse_list(args.n_experts_list)
|
||||||
|
topk_values = _parse_list(args.top_ks)
|
||||||
|
group_values = _parse_list(args.groups_list)
|
||||||
|
|
||||||
|
grid = []
|
||||||
|
for batch, seq_len, hidden, moe, n_experts, top_k, groups in itertools.product(
|
||||||
|
batch_values,
|
||||||
|
seq_values,
|
||||||
|
hidden_values,
|
||||||
|
moe_values,
|
||||||
|
expert_values,
|
||||||
|
topk_values,
|
||||||
|
group_values,
|
||||||
|
):
|
||||||
|
if n_experts % groups != 0 or top_k > n_experts:
|
||||||
|
continue
|
||||||
|
grid.append(
|
||||||
|
{
|
||||||
|
"batch": batch,
|
||||||
|
"seq_len": seq_len,
|
||||||
|
"hidden_size": hidden,
|
||||||
|
"moe_intermediate_size": moe,
|
||||||
|
"n_experts": n_experts,
|
||||||
|
"top_k": top_k,
|
||||||
|
"groups": groups,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not grid:
|
||||||
|
raise SystemExit("No valid parameter combinations produced")
|
||||||
|
|
||||||
header = (
|
header = (
|
||||||
"batch",
|
"batch",
|
||||||
"seq_len",
|
"seq_len",
|
||||||
@@ -122,6 +156,7 @@ def main() -> None: # pragma: no cover - utility script
|
|||||||
"moe_intermediate",
|
"moe_intermediate",
|
||||||
"n_experts",
|
"n_experts",
|
||||||
"top_k",
|
"top_k",
|
||||||
|
"groups",
|
||||||
"baseline_ms",
|
"baseline_ms",
|
||||||
"patched_ms",
|
"patched_ms",
|
||||||
"speedup",
|
"speedup",
|
||||||
@@ -135,10 +170,10 @@ def main() -> None: # pragma: no cover - utility script
|
|||||||
f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}"
|
f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}"
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'baseline':>12} {'patched':>12} {'speedup':>8}"
|
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'baseline':>12} {'patched':>12} {'speedup':>8}"
|
||||||
)
|
)
|
||||||
|
|
||||||
for cfg in DEFAULT_SWEEP:
|
for cfg in grid:
|
||||||
ns = make_namespace(cfg, args)
|
ns = make_namespace(cfg, args)
|
||||||
result = benchmark_deepseek_v3(ns)
|
result = benchmark_deepseek_v3(ns)
|
||||||
rows.append(
|
rows.append(
|
||||||
@@ -149,6 +184,7 @@ def main() -> None: # pragma: no cover - utility script
|
|||||||
cfg["moe_intermediate_size"],
|
cfg["moe_intermediate_size"],
|
||||||
cfg["n_experts"],
|
cfg["n_experts"],
|
||||||
cfg["top_k"],
|
cfg["top_k"],
|
||||||
|
cfg["groups"],
|
||||||
result["baseline_ms"],
|
result["baseline_ms"],
|
||||||
result["patched_ms"],
|
result["patched_ms"],
|
||||||
result["speedup"],
|
result["speedup"],
|
||||||
@@ -158,7 +194,7 @@ def main() -> None: # pragma: no cover - utility script
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4}"
|
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6}"
|
||||||
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
|
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user