Files
axolotl/scripts/benchmarks/deepseek_v3_moe_sweep.py
2025-09-25 14:27:34 -04:00

251 lines
7.5 KiB
Python

#!/usr/bin/env python
# mypy: ignore-errors
"""Sweep a set of DeepSeek V3 MoE benchmark configurations."""
from __future__ import annotations
import argparse
import csv
import itertools
import sys
from pathlib import Path
from types import SimpleNamespace
# Walk upwards from this file's directory until a directory containing an
# "axolotl" entry is found, and put that entry at the front of sys.path so the
# "scripts.benchmarks" import further down resolves.
CURRENT_DIR = Path(__file__).resolve().parent
for candidate in (CURRENT_DIR, *CURRENT_DIR.parents):
    repo_root = candidate / "axolotl"
    if not repo_root.exists():
        continue
    # Avoid stacking duplicate entries when the path is already registered.
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    break
else:  # pragma: no cover - defensive guard
    # The loop ran out of ancestors without hitting "break".
    raise SystemExit("Unable to locate axolotl repository root for imports")
from scripts.benchmarks.deepseek_v3_moe import ( # noqa: E402
ACCURACY_TOLERANCE,
DTYPE_MAP,
benchmark_deepseek_v3,
)
def parse_args() -> argparse.Namespace:
    """Build the sweep's command-line parser and parse the process arguments.

    Scalar options (dtype, device, backend, warmup, iters, seed, group size,
    uniform routing) apply to every benchmarked configuration. The
    list-valued options are kept as raw comma separated strings; splitting
    them into integers is left to the caller.

    Returns:
        The parsed arguments as an ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    add = cli.add_argument

    # Options restricted to a fixed set of choices.
    add(
        "--dtype",
        default="bf16",
        choices=DTYPE_MAP.keys(),
        help="Computation dtype for all benchmarks",
    )
    add(
        "--device",
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="Execution device",
    )
    add(
        "--backend",
        default="mg",
        choices=["cg", "mg"],
        help="MoE kernel backend to benchmark",
    )

    # Plain integer options: (flag, default, help).
    integer_options = (
        ("--warmup", 3, "Warmup iterations"),
        ("--iters", 15, "Benchmark iterations"),
        ("--seed", 0, "Random seed"),
        ("--group-size", 128, "GROUP_SIZE_M used by the Triton kernel"),
    )
    for flag, default, help_text in integer_options:
        add(flag, type=int, default=default, help=help_text)

    # Swept dimensions, each given as a comma separated string:
    # (flag, default, help).
    list_options = (
        ("--batches", "4,8", "Comma separated list of batch sizes"),
        ("--seq-lens", "1024,2048", "Comma separated list of sequence lengths"),
        ("--hidden-sizes", "2048,4096", "Comma separated list of hidden sizes"),
        (
            "--moe-intermediates",
            "4096,8192",
            "Comma separated list of MoE intermediate sizes",
        ),
        ("--n-experts-list", "64,128,256", "Comma separated list of expert counts"),
        ("--top-ks", "4,8", "Comma separated list of top-k values"),
        ("--groups-list", "4,8", "Comma separated list of router group counts"),
    )
    for flag, default, help_text in list_options:
        add(flag, default=default, help=help_text)

    add(
        "--uniform-routing",
        action="store_true",
        help="Force uniform routing for every configuration",
    )
    add(
        "--output",
        type=Path,
        help="Optional CSV file to store results",
    )
    return cli.parse_args()
def make_namespace(base: dict, args: argparse.Namespace) -> SimpleNamespace:
    """Combine one configuration dict with the sweep-wide CLI settings.

    Args:
        base: Per-configuration values; the dict itself is not modified.
        args: Parsed command-line arguments supplying the shared settings.

    Returns:
        A ``SimpleNamespace`` holding every key of ``base`` plus ``dtype``,
        ``device``, ``backend``, ``warmup``, ``iters``, ``seed``,
        ``group_size`` and ``uniform_routing`` taken from ``args``. When a
        key appears in both, the value from ``args`` wins.
    """
    shared_fields = (
        "dtype",
        "device",
        "backend",
        "warmup",
        "iters",
        "seed",
        "group_size",
        "uniform_routing",
    )
    shared = {field: getattr(args, field) for field in shared_fields}
    # Later entries override earlier ones, so CLI settings take precedence.
    return SimpleNamespace(**{**base, **shared})
def _parse_int_list(text: str) -> list[int]:
    """Parse a comma separated string such as ``"4, 8"`` into ``[4, 8]``.

    Empty items (for example from a trailing comma) are ignored. A
    non-integer item raises ``ValueError``.
    """
    return [int(item.strip()) for item in text.split(",") if item.strip()]


def _build_grid(args: argparse.Namespace) -> list[dict]:
    """Expand the list-valued CLI options into the valid sweep configurations.

    The cartesian product of all swept dimensions is generated, and
    combinations are dropped when the group count is not positive, when the
    expert count is not divisible by the group count, or when ``top_k``
    exceeds the expert count.
    """
    axes = (
        args.batches,
        args.seq_lens,
        args.hidden_sizes,
        args.moe_intermediates,
        args.n_experts_list,
        args.top_ks,
        args.groups_list,
    )
    grid = []
    for batch, seq_len, hidden, moe, n_experts, top_k, groups in itertools.product(
        *(_parse_int_list(axis) for axis in axes)
    ):
        # Check "groups <= 0" first so a zero group count is skipped instead
        # of raising ZeroDivisionError in the modulo below.
        if groups <= 0 or n_experts % groups != 0 or top_k > n_experts:
            continue
        grid.append(
            {
                "batch": batch,
                "seq_len": seq_len,
                "hidden_size": hidden,
                "moe_intermediate_size": moe,
                "n_experts": n_experts,
                "top_k": top_k,
                "groups": groups,
            }
        )
    return grid


def _to_mib(value) -> float:
    """Divide ``value`` by 1024**2, or return NaN when it is ``None``."""
    # Presumably the benchmark reports bytes; confirm against
    # benchmark_deepseek_v3.
    return value / (1024**2) if value is not None else float("nan")


def _write_csv(path: Path, header: tuple, rows: list) -> None:
    """Write ``header`` and ``rows`` to ``path``, creating parent directories."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="") as fp:
        writer = csv.writer(fp)
        writer.writerow(header)
        writer.writerows(rows)


def main() -> None:  # pragma: no cover - utility script
    """Run the benchmark for every valid configuration and report the results.

    A table row is printed for each configuration. When ``--output`` is
    given, the collected rows are written to that CSV file. The file is
    written even if the sweep aborts part-way, so results gathered before
    the failure (including the failing row of an accuracy check) are kept.

    Raises:
        SystemExit: If no parameter combination survives filtering.
        RuntimeError: If a configuration fails the accuracy check; the sweep
            stops at that configuration.
    """
    args = parse_args()

    grid = _build_grid(args)
    if not grid:
        raise SystemExit("No valid parameter combinations produced")

    header = (
        "batch",
        "seq_len",
        "hidden_size",
        "moe_intermediate",
        "n_experts",
        "top_k",
        "groups",
        "backend",
        "baseline_ms",
        "patched_ms",
        "speedup",
        "baseline_vram_mib",
        "patched_vram_mib",
        "min_tokens",
        "max_tokens",
        "max_diff",
        "accuracy_ok",
    )
    rows = []

    print(
        f"Running sweep on device={args.device} dtype={args.dtype} backend={args.backend} uniform_routing={args.uniform_routing}"
    )
    # The timing columns are 14 wide to match the rows below, which print an
    # 11 character number followed by " ms".
    print(
        f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
        f" {'baseline':>14} {'patched':>14} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}"
    )

    try:
        for cfg in grid:
            result = benchmark_deepseek_v3(make_namespace(cfg, args))

            baseline_vram_mib = _to_mib(result["baseline_vram"])
            patched_vram_mib = _to_mib(result["patched_vram"])

            rows.append(
                (
                    cfg["batch"],
                    cfg["seq_len"],
                    cfg["hidden_size"],
                    cfg["moe_intermediate_size"],
                    cfg["n_experts"],
                    cfg["top_k"],
                    cfg["groups"],
                    args.backend,
                    result["baseline_ms"],
                    result["patched_ms"],
                    result["speedup"],
                    baseline_vram_mib,
                    patched_vram_mib,
                    result["min_tokens"],
                    result["max_tokens"],
                    result["max_diff"],
                    result["accuracy_ok"],
                )
            )

            status = "OK" if result["accuracy_ok"] else "FAIL"
            print(
                f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {args.backend:>8}"
                f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
                f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}"
            )

            # The failing row is already recorded and printed above, so it is
            # still part of the CSV written in the "finally" clause.
            if not result["accuracy_ok"]:
                raise RuntimeError(
                    f"Accuracy check failed for config {cfg}: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
                )
    finally:
        # Persist whatever was collected, even if the sweep aborted. Nothing
        # is written when no configuration produced a row.
        if args.output and rows:
            _write_csv(args.output, header, rows)
            print(f"Results written to {args.output}")
# Run the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()