This commit is contained in:
Dan Saunders
2025-09-23 13:50:48 -04:00
parent fd312f6058
commit 146ca48cba
4 changed files with 80 additions and 6 deletions

View File

@@ -1,5 +1,5 @@
"""Benchmark helpers.""" """Benchmark helpers."""
from .deepseek_v3_moe import DTYPE_MAP, benchmark_deepseek_v3 from .deepseek_v3_moe import ACCURACY_TOLERANCE, DTYPE_MAP, benchmark_deepseek_v3
__all__ = ["benchmark_deepseek_v3", "DTYPE_MAP"] __all__ = ["benchmark_deepseek_v3", "DTYPE_MAP", "ACCURACY_TOLERANCE"]

View File

@@ -4,12 +4,24 @@
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import sys
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from typing import Iterable from typing import Iterable
import torch import torch
CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
repo_root = candidate / "axolotl"
if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else:
raise SystemExit("Unable to locate axolotl repository root for imports")
from axolotl.kernels.moe import ( from axolotl.kernels.moe import (
cg_grouped_gemm_forward, cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic, cg_grouped_gemm_forward_dynamic,

View File

@@ -5,7 +5,9 @@
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import sys
import time import time
from pathlib import Path
from types import MethodType from types import MethodType
import torch import torch
@@ -20,8 +22,20 @@ except ImportError as exc: # pragma: no cover - utility script
"Transformers with DeepSeek-V3 support must be available in PYTHONPATH" "Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
) from exc ) from exc
CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
repo_root = candidate / "axolotl"
if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports")
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
ACCURACY_TOLERANCE = 5e-3
DTYPE_MAP = { DTYPE_MAP = {
"bf16": torch.bfloat16, "bf16": torch.bfloat16,
"fp16": torch.float16, "fp16": torch.float16,
@@ -221,8 +235,16 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
patched_output = patched_module(inputs) patched_output = patched_module(inputs)
max_diff = (ref_output - patched_output).abs().max().item() max_diff = (ref_output - patched_output).abs().max().item()
baseline_vram = patched_vram = None
if device.type == "cuda":
torch.cuda.reset_peak_memory_stats(device)
baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup) baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup)
if device.type == "cuda":
baseline_vram = torch.cuda.max_memory_allocated(device)
torch.cuda.reset_peak_memory_stats(device)
patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup) patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup)
if device.type == "cuda":
patched_vram = torch.cuda.max_memory_allocated(device)
speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan") speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan")
@@ -237,6 +259,9 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
"avg_tokens": avg_tokens_per_expert, "avg_tokens": avg_tokens_per_expert,
"min_tokens": min_tokens, "min_tokens": min_tokens,
"max_tokens": max_tokens, "max_tokens": max_tokens,
"baseline_vram": baseline_vram,
"patched_vram": patched_vram,
"accuracy_ok": max_diff <= ACCURACY_TOLERANCE,
} }
@@ -251,10 +276,18 @@ def main() -> None: # pragma: no cover - CLI entrypoint
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}" f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
) )
print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}") print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}")
if result["baseline_vram"] is not None:
print(
f"VRAM baseline={result['baseline_vram'] / (1024**2):.1f} MiB | patched={result['patched_vram'] / (1024**2):.1f} MiB"
)
print( print(
f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}" f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}"
) )
print(f"Max |Δ| between outputs: {result['max_diff']:.2e}") print(f"Max |Δ| between outputs: {result['max_diff']:.2e}")
if not result["accuracy_ok"]:
raise RuntimeError(
f"Accuracy check failed: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -12,11 +12,17 @@ from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
CURRENT_DIR = Path(__file__).resolve().parent CURRENT_DIR = Path(__file__).resolve().parent
REPO_ROOT = CURRENT_DIR.parent.parent for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
if str(REPO_ROOT) not in sys.path: repo_root = candidate / "axolotl"
sys.path.insert(0, str(REPO_ROOT)) if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports")
from scripts.benchmarks.deepseek_v3_moe import ( from scripts.benchmarks.deepseek_v3_moe import (
ACCURACY_TOLERANCE,
DTYPE_MAP, DTYPE_MAP,
benchmark_deepseek_v3, benchmark_deepseek_v3,
) )
@@ -161,9 +167,12 @@ def main() -> None: # pragma: no cover - utility script
"baseline_ms", "baseline_ms",
"patched_ms", "patched_ms",
"speedup", "speedup",
"baseline_vram_mib",
"patched_vram_mib",
"min_tokens", "min_tokens",
"max_tokens", "max_tokens",
"max_diff", "max_diff",
"accuracy_ok",
) )
rows = [] rows = []
@@ -171,12 +180,23 @@ def main() -> None: # pragma: no cover - utility script
f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}" f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}"
) )
print( print(
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'baseline':>12} {'patched':>12} {'speedup':>8}" f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6}"
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}"
) )
for cfg in grid: for cfg in grid:
ns = make_namespace(cfg, args) ns = make_namespace(cfg, args)
result = benchmark_deepseek_v3(ns) result = benchmark_deepseek_v3(ns)
baseline_vram_mib = (
result["baseline_vram"] / (1024**2)
if result["baseline_vram"] is not None
else float("nan")
)
patched_vram_mib = (
result["patched_vram"] / (1024**2)
if result["patched_vram"] is not None
else float("nan")
)
rows.append( rows.append(
( (
cfg["batch"], cfg["batch"],
@@ -189,15 +209,24 @@ def main() -> None: # pragma: no cover - utility script
result["baseline_ms"], result["baseline_ms"],
result["patched_ms"], result["patched_ms"],
result["speedup"], result["speedup"],
baseline_vram_mib,
patched_vram_mib,
result["min_tokens"], result["min_tokens"],
result["max_tokens"], result["max_tokens"],
result["max_diff"], result["max_diff"],
result["accuracy_ok"],
) )
) )
status = "OK" if result["accuracy_ok"] else "FAIL"
print( print(
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6}" f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6}"
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x" f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}"
) )
if not result["accuracy_ok"]:
raise RuntimeError(
f"Accuracy check failed for config {cfg}: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
)
if args.output: if args.output:
args.output.parent.mkdir(parents=True, exist_ok=True) args.output.parent.mkdir(parents=True, exist_ok=True)