This commit is contained in:
Dan Saunders
2025-09-22 16:21:50 -04:00
parent 9d69c6fb3e
commit 8d8fa834a2
4 changed files with 211 additions and 14 deletions

View File

@@ -1,3 +1,5 @@
#!/usr/bin/env python
# mypy: ignore-errors
"""Microbenchmark for DeepSeek V3 MoE block comparing baseline vs Triton CG kernels."""
from __future__ import annotations
@@ -7,10 +9,16 @@ import time
from types import MethodType
import torch
from transformers.models.deepseek_v3.configuration_deepseek_v3 import (
DeepseekV3Config,
)
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
try:
from transformers.models.deepseek_v3.configuration_deepseek_v3 import (
DeepseekV3Config,
)
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
except ImportError as exc: # pragma: no cover - utility script
raise SystemExit(
"Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
) from exc
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
@@ -102,7 +110,7 @@ def build_module(args: argparse.Namespace) -> DeepseekV3MoE:
@torch.no_grad()
def benchmark(
def timed_forward(
module: DeepseekV3MoE, inputs: torch.Tensor, iters: int, warmup: int
) -> float:
for _ in range(warmup):
@@ -118,8 +126,7 @@ def benchmark(
return (elapsed / iters) * 1000.0
def main() -> None: # pragma: no cover - CLI entrypoint
args = parse_args()
def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
torch.manual_seed(args.seed)
device = resolve_device(args.device)
@@ -178,6 +185,7 @@ def main() -> None: # pragma: no cover - CLI entrypoint
topk_idx = assignments.view(flat_inputs.size(0), args.top_k)
else:
topk_idx, _ = patched_module.gate(flat_inputs)
tokens_per_expert = torch.bincount(
topk_idx.reshape(-1), minlength=args.n_experts
)
@@ -209,22 +217,40 @@ def main() -> None: # pragma: no cover - CLI entrypoint
patched_output = patched_module(inputs)
max_diff = (ref_output - patched_output).abs().max().item()
baseline_ms = benchmark(baseline_module, inputs, args.iters, args.warmup)
patched_ms = benchmark(patched_module, inputs, args.iters, args.warmup)
baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup)
patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup)
speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan")
return {
"device": device,
"dtype": dtype,
"baseline_ms": baseline_ms,
"patched_ms": patched_ms,
"speedup": speedup,
"max_diff": max_diff,
"routed_tokens": routed_tokens,
"avg_tokens": avg_tokens_per_expert,
"min_tokens": min_tokens,
"max_tokens": max_tokens,
}
def main() -> None: # pragma: no cover - CLI entrypoint
args = parse_args()
result = benchmark_deepseek_v3(args)
print(
f"Device={device.type} dtype={dtype} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
f"Device={result['device'].type} dtype={result['dtype']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
)
print(
f"routed tokens={routed_tokens} avg tokens/expert={avg_tokens_per_expert:.1f} group_size={args.group_size}"
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
)
print(f"min/max tokens per expert: {min_tokens}/{max_tokens}")
print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}")
print(
f"Baseline: {baseline_ms:.3f} ms | Patched: {patched_ms:.3f} ms | x{speedup:.2f}"
f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}"
)
print(f"Max |Δ| between outputs: {max_diff:.2e}")
print(f"Max |Δ| between outputs: {result['max_diff']:.2e}")
if __name__ == "__main__":