narrow sweep; compare both backends
This commit is contained in:
@@ -6,7 +6,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import csv
|
import csv
|
||||||
import itertools
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
@@ -42,61 +41,24 @@ def parse_args() -> argparse.Namespace:
|
|||||||
choices=["auto", "cpu", "cuda"],
|
choices=["auto", "cpu", "cuda"],
|
||||||
help="Execution device",
|
help="Execution device",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--backend",
|
|
||||||
choices=["cg", "mg"],
|
|
||||||
default="mg",
|
|
||||||
help="MoE kernel backend to benchmark",
|
|
||||||
)
|
|
||||||
parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
|
parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
|
||||||
parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations")
|
parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations")
|
||||||
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--group-size",
|
"--group-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=128,
|
help="Override GROUP_SIZE_M for every configuration",
|
||||||
help="GROUP_SIZE_M used by the Triton kernel",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--batches",
|
|
||||||
default="4,8",
|
|
||||||
help="Comma separated list of batch sizes",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--seq-lens",
|
|
||||||
default="1024,2048",
|
|
||||||
help="Comma separated list of sequence lengths",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--hidden-sizes",
|
|
||||||
default="2048,4096",
|
|
||||||
help="Comma separated list of hidden sizes",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--moe-intermediates",
|
|
||||||
default="4096,8192",
|
|
||||||
help="Comma separated list of MoE intermediate sizes",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--n-experts-list",
|
|
||||||
default="64,128,256",
|
|
||||||
help="Comma separated list of expert counts",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--top-ks",
|
|
||||||
default="4,8",
|
|
||||||
help="Comma separated list of top-k values",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--groups-list",
|
|
||||||
default="4,8",
|
|
||||||
help="Comma separated list of router group counts",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--uniform-routing",
|
"--uniform-routing",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Force uniform routing for every configuration",
|
help="Force uniform routing for every configuration",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--include-mixtral-long",
|
||||||
|
action="store_true",
|
||||||
|
help="Add an 8×8192 Mixtral-style run to the sweep",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output",
|
"--output",
|
||||||
type=Path,
|
type=Path,
|
||||||
@@ -105,65 +67,115 @@ def parse_args() -> argparse.Namespace:
|
|||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def make_namespace(base: dict, args: argparse.Namespace) -> SimpleNamespace:
|
def make_namespace(
|
||||||
|
base: dict, args: argparse.Namespace, backend: str
|
||||||
|
) -> SimpleNamespace:
|
||||||
combined = dict(base)
|
combined = dict(base)
|
||||||
combined.update(
|
combined.update(
|
||||||
{
|
{
|
||||||
"dtype": args.dtype,
|
"dtype": args.dtype,
|
||||||
"device": args.device,
|
"device": args.device,
|
||||||
"backend": args.backend,
|
"backend": backend,
|
||||||
"warmup": args.warmup,
|
"warmup": args.warmup,
|
||||||
"iters": args.iters,
|
"iters": args.iters,
|
||||||
"seed": args.seed,
|
"seed": args.seed,
|
||||||
"group_size": args.group_size,
|
|
||||||
"uniform_routing": args.uniform_routing,
|
"uniform_routing": args.uniform_routing,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
if args.group_size is not None:
|
||||||
|
combined["group_size"] = args.group_size
|
||||||
return SimpleNamespace(**combined)
|
return SimpleNamespace(**combined)
|
||||||
|
|
||||||
|
|
||||||
|
ARCHETYPES = (
|
||||||
|
(
|
||||||
|
"mixtral",
|
||||||
|
{
|
||||||
|
"hidden_size": 4096,
|
||||||
|
"moe_intermediate_size": 14336,
|
||||||
|
"n_experts": 8,
|
||||||
|
"top_k": 2,
|
||||||
|
"groups": 1,
|
||||||
|
"group_size": 128,
|
||||||
|
},
|
||||||
|
[(4, 2048), (8, 4096)],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"dbrx",
|
||||||
|
{
|
||||||
|
"hidden_size": 6144,
|
||||||
|
"moe_intermediate_size": 24576,
|
||||||
|
"n_experts": 16,
|
||||||
|
"top_k": 2,
|
||||||
|
"groups": 4,
|
||||||
|
"group_size": 192,
|
||||||
|
},
|
||||||
|
[(4, 4096), (8, 8192)],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"qwen",
|
||||||
|
{
|
||||||
|
"hidden_size": 6144,
|
||||||
|
"moe_intermediate_size": 24576,
|
||||||
|
"n_experts": 16,
|
||||||
|
"top_k": 4,
|
||||||
|
"groups": 8,
|
||||||
|
"group_size": 128,
|
||||||
|
},
|
||||||
|
[(4, 4096), (8, 8192)],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"deepseek_v3",
|
||||||
|
{
|
||||||
|
"hidden_size": 12288,
|
||||||
|
"moe_intermediate_size": 49152,
|
||||||
|
"n_experts": 128,
|
||||||
|
"top_k": 8,
|
||||||
|
"groups": 16,
|
||||||
|
"group_size": 128,
|
||||||
|
},
|
||||||
|
[(4, 4096), (8, 8192)],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
MIXTRAL_LONG_SHAPES = [(8, 8192)]
|
||||||
|
|
||||||
|
BACKENDS = ("cg", "mg")
|
||||||
|
|
||||||
|
|
||||||
def main() -> None: # pragma: no cover - utility script
|
def main() -> None: # pragma: no cover - utility script
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
def _parse_list(text: str) -> list[int]:
|
|
||||||
return [int(item.strip()) for item in text.split(",") if item.strip()]
|
|
||||||
|
|
||||||
batch_values = _parse_list(args.batches)
|
|
||||||
seq_values = _parse_list(args.seq_lens)
|
|
||||||
hidden_values = _parse_list(args.hidden_sizes)
|
|
||||||
moe_values = _parse_list(args.moe_intermediates)
|
|
||||||
expert_values = _parse_list(args.n_experts_list)
|
|
||||||
topk_values = _parse_list(args.top_ks)
|
|
||||||
group_values = _parse_list(args.groups_list)
|
|
||||||
|
|
||||||
grid = []
|
grid = []
|
||||||
for batch, seq_len, hidden, moe, n_experts, top_k, groups in itertools.product(
|
for label, base_cfg, shapes in ARCHETYPES:
|
||||||
batch_values,
|
for batch, seq_len in shapes:
|
||||||
seq_values,
|
cfg = {
|
||||||
hidden_values,
|
"label": label,
|
||||||
moe_values,
|
|
||||||
expert_values,
|
|
||||||
topk_values,
|
|
||||||
group_values,
|
|
||||||
):
|
|
||||||
if n_experts % groups != 0 or top_k > n_experts:
|
|
||||||
continue
|
|
||||||
grid.append(
|
|
||||||
{
|
|
||||||
"batch": batch,
|
"batch": batch,
|
||||||
"seq_len": seq_len,
|
"seq_len": seq_len,
|
||||||
"hidden_size": hidden,
|
**base_cfg,
|
||||||
"moe_intermediate_size": moe,
|
|
||||||
"n_experts": n_experts,
|
|
||||||
"top_k": top_k,
|
|
||||||
"groups": groups,
|
|
||||||
}
|
}
|
||||||
)
|
if cfg["n_experts"] % cfg["groups"] != 0 or cfg["top_k"] > cfg["n_experts"]:
|
||||||
|
continue
|
||||||
|
grid.append(cfg)
|
||||||
|
|
||||||
|
if args.include_mixtral_long:
|
||||||
|
base_cfg = ARCHETYPES[0][1]
|
||||||
|
for batch, seq_len in MIXTRAL_LONG_SHAPES:
|
||||||
|
grid.append(
|
||||||
|
{
|
||||||
|
"label": "mixtral_long",
|
||||||
|
"batch": batch,
|
||||||
|
"seq_len": seq_len,
|
||||||
|
**base_cfg,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if not grid:
|
if not grid:
|
||||||
raise SystemExit("No valid parameter combinations produced")
|
raise SystemExit("No valid parameter combinations produced")
|
||||||
|
|
||||||
header = (
|
header = (
|
||||||
|
"model",
|
||||||
"batch",
|
"batch",
|
||||||
"seq_len",
|
"seq_len",
|
||||||
"hidden_size",
|
"hidden_size",
|
||||||
@@ -185,57 +197,59 @@ def main() -> None: # pragma: no cover - utility script
|
|||||||
rows = []
|
rows = []
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Running sweep on device={args.device} dtype={args.dtype} backend={args.backend} uniform_routing={args.uniform_routing}"
|
f"Running sweep on device={args.device} dtype={args.dtype} backends={BACKENDS} uniform_routing={args.uniform_routing}"
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
|
f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
|
||||||
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}"
|
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}"
|
||||||
)
|
)
|
||||||
|
|
||||||
for cfg in grid:
|
for cfg in grid:
|
||||||
ns = make_namespace(cfg, args)
|
for backend in BACKENDS:
|
||||||
result = benchmark_deepseek_v3(ns)
|
ns = make_namespace(cfg, args, backend)
|
||||||
baseline_vram_mib = (
|
result = benchmark_deepseek_v3(ns)
|
||||||
result["baseline_vram"] / (1024**2)
|
baseline_vram_mib = (
|
||||||
if result["baseline_vram"] is not None
|
result["baseline_vram"] / (1024**2)
|
||||||
else float("nan")
|
if result["baseline_vram"] is not None
|
||||||
)
|
else float("nan")
|
||||||
patched_vram_mib = (
|
|
||||||
result["patched_vram"] / (1024**2)
|
|
||||||
if result["patched_vram"] is not None
|
|
||||||
else float("nan")
|
|
||||||
)
|
|
||||||
rows.append(
|
|
||||||
(
|
|
||||||
cfg["batch"],
|
|
||||||
cfg["seq_len"],
|
|
||||||
cfg["hidden_size"],
|
|
||||||
cfg["moe_intermediate_size"],
|
|
||||||
cfg["n_experts"],
|
|
||||||
cfg["top_k"],
|
|
||||||
cfg["groups"],
|
|
||||||
args.backend,
|
|
||||||
result["baseline_ms"],
|
|
||||||
result["patched_ms"],
|
|
||||||
result["speedup"],
|
|
||||||
baseline_vram_mib,
|
|
||||||
patched_vram_mib,
|
|
||||||
result["min_tokens"],
|
|
||||||
result["max_tokens"],
|
|
||||||
result["max_diff"],
|
|
||||||
result["accuracy_ok"],
|
|
||||||
)
|
)
|
||||||
)
|
patched_vram_mib = (
|
||||||
status = "OK" if result["accuracy_ok"] else "FAIL"
|
result["patched_vram"] / (1024**2)
|
||||||
print(
|
if result["patched_vram"] is not None
|
||||||
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {args.backend:>8}"
|
else float("nan")
|
||||||
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
|
|
||||||
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}"
|
|
||||||
)
|
|
||||||
if not result["accuracy_ok"]:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Accuracy check failed for config {cfg}: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
|
|
||||||
)
|
)
|
||||||
|
rows.append(
|
||||||
|
(
|
||||||
|
cfg["label"],
|
||||||
|
cfg["batch"],
|
||||||
|
cfg["seq_len"],
|
||||||
|
cfg["hidden_size"],
|
||||||
|
cfg["moe_intermediate_size"],
|
||||||
|
cfg["n_experts"],
|
||||||
|
cfg["top_k"],
|
||||||
|
cfg["groups"],
|
||||||
|
backend,
|
||||||
|
result["baseline_ms"],
|
||||||
|
result["patched_ms"],
|
||||||
|
result["speedup"],
|
||||||
|
baseline_vram_mib,
|
||||||
|
patched_vram_mib,
|
||||||
|
result["min_tokens"],
|
||||||
|
result["max_tokens"],
|
||||||
|
result["max_diff"],
|
||||||
|
result["accuracy_ok"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
status = "OK" if result["accuracy_ok"] else "FAIL"
|
||||||
|
print(
|
||||||
|
f"{cfg['label']:>10} {cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {backend:>8}"
|
||||||
|
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
|
||||||
|
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}"
|
||||||
|
)
|
||||||
|
if not result["accuracy_ok"]:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Accuracy check failed for config {cfg}: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
|
||||||
|
)
|
||||||
|
|
||||||
if args.output:
|
if args.output:
|
||||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user