diff --git a/scripts/benchmarks/__init__.py b/scripts/benchmarks/__init__.py index 7bb8fc1a2..6413b9194 100644 --- a/scripts/benchmarks/__init__.py +++ b/scripts/benchmarks/__init__.py @@ -1,5 +1,5 @@ """Benchmark helpers.""" -from .deepseek_v3_moe import DTYPE_MAP, benchmark_deepseek_v3 +from .deepseek_v3_moe import ACCURACY_TOLERANCE, DTYPE_MAP, benchmark_deepseek_v3 -__all__ = ["benchmark_deepseek_v3", "DTYPE_MAP"] +__all__ = ["benchmark_deepseek_v3", "DTYPE_MAP", "ACCURACY_TOLERANCE"] diff --git a/scripts/benchmarks/deepseek_v3_group_gemm_table.py b/scripts/benchmarks/deepseek_v3_group_gemm_table.py index e0aae9cda..a3ef4d78d 100644 --- a/scripts/benchmarks/deepseek_v3_group_gemm_table.py +++ b/scripts/benchmarks/deepseek_v3_group_gemm_table.py @@ -4,12 +4,24 @@ from __future__ import annotations import argparse +import sys import time from dataclasses import dataclass +from pathlib import Path from typing import Iterable import torch +CURRENT_DIR = Path(__file__).resolve().parent +for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]: + repo_root = candidate / "axolotl" + if repo_root.exists(): + if str(repo_root) not in sys.path: + sys.path.insert(0, str(repo_root)) + break +else: + raise SystemExit("Unable to locate axolotl repository root for imports") + from axolotl.kernels.moe import ( cg_grouped_gemm_forward, cg_grouped_gemm_forward_dynamic, diff --git a/scripts/benchmarks/deepseek_v3_moe.py b/scripts/benchmarks/deepseek_v3_moe.py index 68492e39e..7a48e47ac 100644 --- a/scripts/benchmarks/deepseek_v3_moe.py +++ b/scripts/benchmarks/deepseek_v3_moe.py @@ -5,7 +5,9 @@ from __future__ import annotations import argparse +import sys import time +from pathlib import Path from types import MethodType import torch @@ -20,8 +22,20 @@ except ImportError as exc: # pragma: no cover - utility script "Transformers with DeepSeek-V3 support must be available in PYTHONPATH" ) from exc +CURRENT_DIR = Path(__file__).resolve().parent +for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]: + repo_root = candidate / "axolotl" + if repo_root.exists(): + if str(repo_root) not in sys.path: + sys.path.insert(0, str(repo_root)) + break +else: # pragma: no cover - defensive guard + raise SystemExit("Unable to locate axolotl repository root for imports") + from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe +ACCURACY_TOLERANCE = 5e-3 + DTYPE_MAP = { "bf16": torch.bfloat16, "fp16": torch.float16, @@ -221,8 +235,16 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict: patched_output = patched_module(inputs) max_diff = (ref_output - patched_output).abs().max().item() + baseline_vram = patched_vram = None + if device.type == "cuda": + torch.cuda.reset_peak_memory_stats(device) baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup) + if device.type == "cuda": + baseline_vram = torch.cuda.max_memory_allocated(device) + torch.cuda.reset_peak_memory_stats(device) patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup) + if device.type == "cuda": + patched_vram = torch.cuda.max_memory_allocated(device) speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan") @@ -237,6 +259,9 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict: "avg_tokens": avg_tokens_per_expert, "min_tokens": min_tokens, "max_tokens": max_tokens, + "baseline_vram": baseline_vram, + "patched_vram": patched_vram, + "accuracy_ok": max_diff <= ACCURACY_TOLERANCE, } @@ -251,10 +276,18 @@ def main() -> None: # pragma: no cover - CLI entrypoint f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}" ) print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}") + if result["baseline_vram"] is not None: + print( + f"VRAM baseline={result['baseline_vram'] / (1024**2):.1f} MiB | patched={result['patched_vram'] / (1024**2):.1f} MiB" + ) print( f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}" ) print(f"Max |Δ| between outputs: {result['max_diff']:.2e}") + if not result["accuracy_ok"]: + raise RuntimeError( + f"Accuracy check failed: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}" + ) if __name__ == "__main__": diff --git a/scripts/benchmarks/deepseek_v3_moe_sweep.py b/scripts/benchmarks/deepseek_v3_moe_sweep.py index 6611e27f0..7b7c68763 100644 --- a/scripts/benchmarks/deepseek_v3_moe_sweep.py +++ b/scripts/benchmarks/deepseek_v3_moe_sweep.py @@ -12,11 +12,17 @@ from pathlib import Path from types import SimpleNamespace CURRENT_DIR = Path(__file__).resolve().parent -REPO_ROOT = CURRENT_DIR.parent.parent -if str(REPO_ROOT) not in sys.path: - sys.path.insert(0, str(REPO_ROOT)) +for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]: + repo_root = candidate / "axolotl" + if repo_root.exists(): + if str(repo_root) not in sys.path: + sys.path.insert(0, str(repo_root)) + break +else: # pragma: no cover - defensive guard + raise SystemExit("Unable to locate axolotl repository root for imports") from scripts.benchmarks.deepseek_v3_moe import ( + ACCURACY_TOLERANCE, DTYPE_MAP, benchmark_deepseek_v3, ) @@ -161,9 +167,12 @@ def main() -> None: # pragma: no cover - utility script "baseline_ms", "patched_ms", "speedup", + "baseline_vram_mib", + "patched_vram_mib", "min_tokens", "max_tokens", "max_diff", + "accuracy_ok", ) rows = [] @@ -171,12 +180,23 @@ def main() -> None: # pragma: no cover - utility script f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}" ) print( - f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'baseline':>12} {'patched':>12} {'speedup':>8}" + f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6}" + f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}" ) for cfg in grid: ns = make_namespace(cfg, args) result = benchmark_deepseek_v3(ns) + baseline_vram_mib = ( + result["baseline_vram"] / (1024**2) + if result["baseline_vram"] is not None + else float("nan") + ) + patched_vram_mib = ( + result["patched_vram"] / (1024**2) + if result["patched_vram"] is not None + else float("nan") + ) rows.append( ( cfg["batch"], @@ -189,15 +209,24 @@ def main() -> None: # pragma: no cover - utility script result["baseline_ms"], result["patched_ms"], result["speedup"], + baseline_vram_mib, + patched_vram_mib, result["min_tokens"], result["max_tokens"], result["max_diff"], + result["accuracy_ok"], ) ) + status = "OK" if result["accuracy_ok"] else "FAIL" print( f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6}" f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x" + f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}" ) + if not result["accuracy_ok"]: + raise RuntimeError( + f"Accuracy check failed for config {cfg}: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}" + ) if args.output: args.output.parent.mkdir(parents=True, exist_ok=True)