Dan Saunders
2025-09-22 16:21:50 -04:00
parent 9d69c6fb3e
commit 8d8fa834a2
4 changed files with 211 additions and 14 deletions

View File

@@ -0,0 +1 @@
"""Benchmark helpers."""

View File

@@ -1,3 +1,5 @@
#!/usr/bin/env python
+# mypy: ignore-errors
"""Microbenchmark for DeepSeek V3 MoE block comparing baseline vs Triton CG kernels."""
from __future__ import annotations
@@ -7,10 +9,16 @@ import time
from types import MethodType

import torch

-from transformers.models.deepseek_v3.configuration_deepseek_v3 import (
-    DeepseekV3Config,
-)
-from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
+try:
+    from transformers.models.deepseek_v3.configuration_deepseek_v3 import (
+        DeepseekV3Config,
+    )
+    from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
+except ImportError as exc:  # pragma: no cover - utility script
+    raise SystemExit(
+        "Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
+    ) from exc

from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
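The guarded import can also be probed without raising SystemExit; a minimal sketch using importlib, assuming only the module path from the import above:

# Availability probe; the module path comes from the guarded import above.
import importlib.util

HAS_DEEPSEEK_V3 = (
    importlib.util.find_spec("transformers.models.deepseek_v3") is not None
)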
@@ -102,7 +110,7 @@ def build_module(args: argparse.Namespace) -> DeepseekV3MoE:
@torch.no_grad()
-def benchmark(
+def timed_forward(
    module: DeepseekV3MoE, inputs: torch.Tensor, iters: int, warmup: int
) -> float:
    for _ in range(warmup):
@@ -118,8 +126,7 @@ def benchmark(
    return (elapsed / iters) * 1000.0


-def main() -> None:  # pragma: no cover - CLI entrypoint
-    args = parse_args()
+def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
    torch.manual_seed(args.seed)
    device = resolve_device(args.device)
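The hunks above elide the middle of timed_forward; given its signature and the final return expression, the body presumably runs untimed warmup passes, synchronizes the device, and averages wall-clock time across the timed iterations. A minimal sketch under that assumption (the synchronization calls are an assumption, not shown in the diff):

# Sketch of a timing helper shaped like timed_forward; the CUDA
# synchronization points are assumptions -- the diff elides this body.
import time

import torch

def timed_forward_sketch(module, inputs: torch.Tensor, iters: int, warmup: int) -> float:
    for _ in range(warmup):  # untimed warmup passes
        module(inputs)
    if inputs.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        module(inputs)
    if inputs.is_cuda:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return (elapsed / iters) * 1000.0  # mean per-iteration latency in ms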
@@ -178,6 +185,7 @@ def main() -> None: # pragma: no cover - CLI entrypoint
        topk_idx = assignments.view(flat_inputs.size(0), args.top_k)
    else:
        topk_idx, _ = patched_module.gate(flat_inputs)

    tokens_per_expert = torch.bincount(
        topk_idx.reshape(-1), minlength=args.n_experts
    )
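For reference, torch.bincount with minlength yields one count per expert even when some experts receive zero tokens, which keeps the min/max per-expert statistics below well defined:

# Self-contained example of per-expert token counting with torch.bincount.
import torch

topk_idx = torch.tensor([[0, 2], [2, 3], [0, 0]])  # 3 tokens, top_k=2
counts = torch.bincount(topk_idx.reshape(-1), minlength=6)
print(counts)  # tensor([3, 0, 2, 1, 0, 0]); experts 1, 4, 5 got no tokens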
@@ -209,22 +217,40 @@ def main() -> None: # pragma: no cover - CLI entrypoint
    patched_output = patched_module(inputs)
    max_diff = (ref_output - patched_output).abs().max().item()

-    baseline_ms = benchmark(baseline_module, inputs, args.iters, args.warmup)
-    patched_ms = benchmark(patched_module, inputs, args.iters, args.warmup)
+    baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup)
+    patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup)
    speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan")
+
+    return {
+        "device": device,
+        "dtype": dtype,
+        "baseline_ms": baseline_ms,
+        "patched_ms": patched_ms,
+        "speedup": speedup,
+        "max_diff": max_diff,
+        "routed_tokens": routed_tokens,
+        "avg_tokens": avg_tokens_per_expert,
+        "min_tokens": min_tokens,
+        "max_tokens": max_tokens,
+    }
+
+
+def main() -> None:  # pragma: no cover - CLI entrypoint
+    args = parse_args()
+    result = benchmark_deepseek_v3(args)
    print(
-        f"Device={device.type} dtype={dtype} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
+        f"Device={result['device'].type} dtype={result['dtype']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
    )
    print(
-        f"routed tokens={routed_tokens} avg tokens/expert={avg_tokens_per_expert:.1f} group_size={args.group_size}"
+        f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
    )
-    print(f"min/max tokens per expert: {min_tokens}/{max_tokens}")
+    print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}")
    print(
-        f"Baseline: {baseline_ms:.3f} ms | Patched: {patched_ms:.3f} ms | x{speedup:.2f}"
+        f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}"
    )
-    print(f"Max |Δ| between outputs: {max_diff:.2e}")
+    print(f"Max |Δ| between outputs: {result['max_diff']:.2e}")


if __name__ == "__main__":
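Because benchmark_deepseek_v3 accepts any namespace-like object, the rewritten module can also be driven programmatically; a minimal sketch with illustrative values, using only attribute names visible in this diff and in the sweep script below:

# Programmatic invocation sketch; all values are illustrative.
from types import SimpleNamespace

from scripts.benchmarks.deepseek_v3_moe import benchmark_deepseek_v3

ns = SimpleNamespace(
    batch=4, seq_len=1024, hidden_size=2048, moe_intermediate_size=4096,
    n_experts=64, top_k=4, groups=4, dtype="bf16", device="auto",
    warmup=3, iters=15, seed=0, group_size=128, uniform_routing=False,
)
result = benchmark_deepseek_v3(ns)
print(f"speedup x{result['speedup']:.2f}, max |diff| {result['max_diff']:.2e}")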

View File

@@ -0,0 +1,169 @@
#!/usr/bin/env python
# mypy: ignore-errors
"""Sweep a set of DeepSeek V3 MoE benchmark configurations."""
from __future__ import annotations

import argparse
import csv
from pathlib import Path
from types import SimpleNamespace

from scripts.benchmarks.deepseek_v3_moe import (
    DTYPE_MAP,
    benchmark_deepseek_v3,
)

DEFAULT_SWEEP = [
    {
        "batch": 4,
        "seq_len": 1024,
        "hidden_size": 2048,
        "moe_intermediate_size": 4096,
        "n_experts": 64,
        "top_k": 4,
        "groups": 4,
    },
    {
        "batch": 8,
        "seq_len": 2048,
        "hidden_size": 2048,
        "moe_intermediate_size": 4096,
        "n_experts": 64,
        "top_k": 4,
        "groups": 4,
    },
    {
        "batch": 8,
        "seq_len": 2048,
        "hidden_size": 4096,
        "moe_intermediate_size": 8192,
        "n_experts": 128,
        "top_k": 8,
        "groups": 8,
    },
    {
        "batch": 8,
        "seq_len": 2048,
        "hidden_size": 4096,
        "moe_intermediate_size": 8192,
        "n_experts": 256,
        "top_k": 8,
        "groups": 8,
    },
]

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--dtype",
        choices=DTYPE_MAP.keys(),
        default="bf16",
        help="Computation dtype for all benchmarks",
    )
    parser.add_argument(
        "--device",
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="Execution device",
    )
    parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
    parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--group-size",
        type=int,
        default=128,
        help="GROUP_SIZE_M used by the Triton kernel",
    )
    parser.add_argument(
        "--uniform-routing",
        action="store_true",
        help="Force uniform routing for every configuration",
    )
    parser.add_argument(
        "--output",
        type=Path,
        help="Optional CSV file to store results",
    )
    return parser.parse_args()

def make_namespace(base: dict, args: argparse.Namespace) -> SimpleNamespace:
    combined = dict(base)
    combined.update(
        {
            "dtype": args.dtype,
            "device": args.device,
            "warmup": args.warmup,
            "iters": args.iters,
            "seed": args.seed,
            "group_size": args.group_size,
            "uniform_routing": args.uniform_routing,
        }
    )
    return SimpleNamespace(**combined)
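# Example usage (hypothetical values): make_namespace overlays the shared
# CLI settings onto one sweep entry to build a full argument namespace:
#
#     shared = SimpleNamespace(
#         dtype="bf16", device="auto", warmup=3, iters=15,
#         seed=0, group_size=128, uniform_routing=False,
#     )
#     ns = make_namespace(DEFAULT_SWEEP[0], shared)
#     assert ns.batch == 4 and ns.dtype == "bf16"
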
def main() -> None:  # pragma: no cover - utility script
    args = parse_args()
    header = (
        "batch",
        "seq_len",
        "hidden_size",
        "moe_intermediate",
        "n_experts",
        "top_k",
        "baseline_ms",
        "patched_ms",
        "speedup",
        "min_tokens",
        "max_tokens",
        "max_diff",
    )
    rows = []
    print(
        f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}"
    )
    print(
        f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'baseline':>12} {'patched':>12} {'speedup':>8}"
    )
    for cfg in DEFAULT_SWEEP:
        ns = make_namespace(cfg, args)
        result = benchmark_deepseek_v3(ns)
        rows.append(
            (
                cfg["batch"],
                cfg["seq_len"],
                cfg["hidden_size"],
                cfg["moe_intermediate_size"],
                cfg["n_experts"],
                cfg["top_k"],
                result["baseline_ms"],
                result["patched_ms"],
                result["speedup"],
                result["min_tokens"],
                result["max_tokens"],
                result["max_diff"],
            )
        )
        print(
            f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4}"
            f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
        )
    if args.output:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        with args.output.open("w", newline="") as fp:
            writer = csv.writer(fp)
            writer.writerow(header)
            writer.writerows(rows)
        print(f"Results written to {args.output}")


if __name__ == "__main__":
    main()
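The CSV written via --output can be reloaded with the standard library for later analysis; a minimal sketch (the file name is illustrative; column names come from the header tuple in main):

# Reload sweep results produced with --output.
import csv

with open("sweep_results.csv", newline="") as fp:  # illustrative path
    for row in csv.DictReader(fp):
        print(row["n_experts"], row["speedup"])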