narrow sweep; compare both backends

This commit is contained in:
Dan Saunders
2025-09-25 14:54:03 -04:00
parent 91393c4dc8
commit e003a05177

View File

@@ -6,7 +6,6 @@ from __future__ import annotations
import argparse import argparse
import csv import csv
import itertools
import sys import sys
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
@@ -42,61 +41,24 @@ def parse_args() -> argparse.Namespace:
choices=["auto", "cpu", "cuda"], choices=["auto", "cpu", "cuda"],
help="Execution device", help="Execution device",
) )
parser.add_argument(
"--backend",
choices=["cg", "mg"],
default="mg",
help="MoE kernel backend to benchmark",
)
parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations") parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations") parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations")
parser.add_argument("--seed", type=int, default=0, help="Random seed") parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument( parser.add_argument(
"--group-size", "--group-size",
type=int, type=int,
default=128, help="Override GROUP_SIZE_M for every configuration",
help="GROUP_SIZE_M used by the Triton kernel",
)
parser.add_argument(
"--batches",
default="4,8",
help="Comma separated list of batch sizes",
)
parser.add_argument(
"--seq-lens",
default="1024,2048",
help="Comma separated list of sequence lengths",
)
parser.add_argument(
"--hidden-sizes",
default="2048,4096",
help="Comma separated list of hidden sizes",
)
parser.add_argument(
"--moe-intermediates",
default="4096,8192",
help="Comma separated list of MoE intermediate sizes",
)
parser.add_argument(
"--n-experts-list",
default="64,128,256",
help="Comma separated list of expert counts",
)
parser.add_argument(
"--top-ks",
default="4,8",
help="Comma separated list of top-k values",
)
parser.add_argument(
"--groups-list",
default="4,8",
help="Comma separated list of router group counts",
) )
parser.add_argument( parser.add_argument(
"--uniform-routing", "--uniform-routing",
action="store_true", action="store_true",
help="Force uniform routing for every configuration", help="Force uniform routing for every configuration",
) )
parser.add_argument(
"--include-mixtral-long",
action="store_true",
help="Add an 8×8192 Mixtral-style run to the sweep",
)
parser.add_argument( parser.add_argument(
"--output", "--output",
type=Path, type=Path,
@@ -105,65 +67,115 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args() return parser.parse_args()
def make_namespace(
    base: dict, args: argparse.Namespace, backend: str
) -> SimpleNamespace:
    """Build the settings namespace for one benchmark run.

    Args:
        base: Per-configuration values (for example the model dimensions,
            batch and sequence length). The mapping is copied, not mutated.
        args: Parsed command-line options. ``dtype``, ``device``, ``warmup``,
            ``iters``, ``seed`` and ``uniform_routing`` are always applied on
            top of ``base``.
        backend: Kernel backend name to record for this run.

    Returns:
        A ``SimpleNamespace`` exposing every key of ``base`` plus the
        run-level options as attributes.
    """
    combined = {
        **base,
        "dtype": args.dtype,
        "device": args.device,
        "backend": backend,
        "warmup": args.warmup,
        "iters": args.iters,
        "seed": args.seed,
        "uniform_routing": args.uniform_routing,
    }
    # --group-size is an optional override: when it is not supplied (None),
    # any "group_size" already present in ``base`` must survive untouched.
    if args.group_size is not None:
        combined["group_size"] = args.group_size
    return SimpleNamespace(**combined)
# Benchmark presets. Each entry is a 3-tuple:
#   (label, model-dimension settings, [(batch, seq_len), ...])
# The sweep builds one configuration per (entry, shape) pair and skips any
# entry whose expert count is not divisible by its group count or whose top_k
# exceeds its expert count.
ARCHETYPES = (
    (
        "mixtral",
        {
            "hidden_size": 4096,
            "moe_intermediate_size": 14336,
            "n_experts": 8,
            "top_k": 2,
            "groups": 1,
            "group_size": 128,
        },
        [(4, 2048), (8, 4096)],
    ),
    (
        "dbrx",
        {
            "hidden_size": 6144,
            "moe_intermediate_size": 24576,
            "n_experts": 16,
            "top_k": 2,
            "groups": 4,
            "group_size": 192,
        },
        [(4, 4096), (8, 8192)],
    ),
    (
        "qwen",
        {
            "hidden_size": 6144,
            "moe_intermediate_size": 24576,
            "n_experts": 16,
            "top_k": 4,
            "groups": 8,
            "group_size": 128,
        },
        [(4, 4096), (8, 8192)],
    ),
    (
        "deepseek_v3",
        {
            "hidden_size": 12288,
            "moe_intermediate_size": 49152,
            "n_experts": 128,
            "top_k": 8,
            "groups": 16,
            "group_size": 128,
        },
        [(4, 4096), (8, 8192)],
    ),
)

# Extra (batch, seq_len) shapes run with the first preset's settings when
# --include-mixtral-long is passed.
MIXTRAL_LONG_SHAPES = [(8, 8192)]

# Every configuration in the sweep is benchmarked once per backend listed here.
BACKENDS = ("cg", "mg")
def main() -> None: # pragma: no cover - utility script def main() -> None: # pragma: no cover - utility script
args = parse_args() args = parse_args()
def _parse_list(text: str) -> list[int]:
return [int(item.strip()) for item in text.split(",") if item.strip()]
batch_values = _parse_list(args.batches)
seq_values = _parse_list(args.seq_lens)
hidden_values = _parse_list(args.hidden_sizes)
moe_values = _parse_list(args.moe_intermediates)
expert_values = _parse_list(args.n_experts_list)
topk_values = _parse_list(args.top_ks)
group_values = _parse_list(args.groups_list)
grid = [] grid = []
for batch, seq_len, hidden, moe, n_experts, top_k, groups in itertools.product( for label, base_cfg, shapes in ARCHETYPES:
batch_values, for batch, seq_len in shapes:
seq_values, cfg = {
hidden_values, "label": label,
moe_values,
expert_values,
topk_values,
group_values,
):
if n_experts % groups != 0 or top_k > n_experts:
continue
grid.append(
{
"batch": batch, "batch": batch,
"seq_len": seq_len, "seq_len": seq_len,
"hidden_size": hidden, **base_cfg,
"moe_intermediate_size": moe,
"n_experts": n_experts,
"top_k": top_k,
"groups": groups,
} }
) if cfg["n_experts"] % cfg["groups"] != 0 or cfg["top_k"] > cfg["n_experts"]:
continue
grid.append(cfg)
if args.include_mixtral_long:
base_cfg = ARCHETYPES[0][1]
for batch, seq_len in MIXTRAL_LONG_SHAPES:
grid.append(
{
"label": "mixtral_long",
"batch": batch,
"seq_len": seq_len,
**base_cfg,
}
)
if not grid: if not grid:
raise SystemExit("No valid parameter combinations produced") raise SystemExit("No valid parameter combinations produced")
header = ( header = (
"model",
"batch", "batch",
"seq_len", "seq_len",
"hidden_size", "hidden_size",
@@ -185,57 +197,59 @@ def main() -> None: # pragma: no cover - utility script
rows = [] rows = []
print( print(
f"Running sweep on device={args.device} dtype={args.dtype} backend={args.backend} uniform_routing={args.uniform_routing}" f"Running sweep on device={args.device} dtype={args.dtype} backends={BACKENDS} uniform_routing={args.uniform_routing}"
) )
print( print(
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}" f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}" f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}"
) )
for cfg in grid: for cfg in grid:
ns = make_namespace(cfg, args) for backend in BACKENDS:
result = benchmark_deepseek_v3(ns) ns = make_namespace(cfg, args, backend)
baseline_vram_mib = ( result = benchmark_deepseek_v3(ns)
result["baseline_vram"] / (1024**2) baseline_vram_mib = (
if result["baseline_vram"] is not None result["baseline_vram"] / (1024**2)
else float("nan") if result["baseline_vram"] is not None
) else float("nan")
patched_vram_mib = (
result["patched_vram"] / (1024**2)
if result["patched_vram"] is not None
else float("nan")
)
rows.append(
(
cfg["batch"],
cfg["seq_len"],
cfg["hidden_size"],
cfg["moe_intermediate_size"],
cfg["n_experts"],
cfg["top_k"],
cfg["groups"],
args.backend,
result["baseline_ms"],
result["patched_ms"],
result["speedup"],
baseline_vram_mib,
patched_vram_mib,
result["min_tokens"],
result["max_tokens"],
result["max_diff"],
result["accuracy_ok"],
) )
) patched_vram_mib = (
status = "OK" if result["accuracy_ok"] else "FAIL" result["patched_vram"] / (1024**2)
print( if result["patched_vram"] is not None
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {args.backend:>8}" else float("nan")
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}"
)
if not result["accuracy_ok"]:
raise RuntimeError(
f"Accuracy check failed for config {cfg}: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
) )
rows.append(
(
cfg["label"],
cfg["batch"],
cfg["seq_len"],
cfg["hidden_size"],
cfg["moe_intermediate_size"],
cfg["n_experts"],
cfg["top_k"],
cfg["groups"],
backend,
result["baseline_ms"],
result["patched_ms"],
result["speedup"],
baseline_vram_mib,
patched_vram_mib,
result["min_tokens"],
result["max_tokens"],
result["max_diff"],
result["accuracy_ok"],
)
)
status = "OK" if result["accuracy_ok"] else "FAIL"
print(
f"{cfg['label']:>10} {cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {backend:>8}"
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}"
)
if not result["accuracy_ok"]:
raise RuntimeError(
f"Accuracy check failed for config {cfg}: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
)
if args.output: if args.output:
args.output.parent.mkdir(parents=True, exist_ok=True) args.output.parent.mkdir(parents=True, exist_ok=True)