vram
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""Benchmark helpers."""
|
||||
|
||||
from .deepseek_v3_moe import DTYPE_MAP, benchmark_deepseek_v3
|
||||
from .deepseek_v3_moe import ACCURACY_TOLERANCE, DTYPE_MAP, benchmark_deepseek_v3
|
||||
|
||||
__all__ = ["benchmark_deepseek_v3", "DTYPE_MAP"]
|
||||
__all__ = ["benchmark_deepseek_v3", "DTYPE_MAP", "ACCURACY_TOLERANCE"]
|
||||
|
||||
@@ -4,12 +4,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
|
||||
CURRENT_DIR = Path(__file__).resolve().parent
|
||||
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
||||
repo_root = candidate / "axolotl"
|
||||
if repo_root.exists():
|
||||
if str(repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(repo_root))
|
||||
break
|
||||
else:
|
||||
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||
|
||||
from axolotl.kernels.moe import (
|
||||
cg_grouped_gemm_forward,
|
||||
cg_grouped_gemm_forward_dynamic,
|
||||
|
||||
@@ -5,7 +5,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import MethodType
|
||||
|
||||
import torch
|
||||
@@ -20,8 +22,20 @@ except ImportError as exc: # pragma: no cover - utility script
|
||||
"Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
|
||||
) from exc
|
||||
|
||||
CURRENT_DIR = Path(__file__).resolve().parent
|
||||
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
||||
repo_root = candidate / "axolotl"
|
||||
if repo_root.exists():
|
||||
if str(repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(repo_root))
|
||||
break
|
||||
else: # pragma: no cover - defensive guard
|
||||
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||
|
||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
||||
|
||||
ACCURACY_TOLERANCE = 5e-3
|
||||
|
||||
DTYPE_MAP = {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
@@ -221,8 +235,16 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
|
||||
patched_output = patched_module(inputs)
|
||||
max_diff = (ref_output - patched_output).abs().max().item()
|
||||
|
||||
baseline_vram = patched_vram = None
|
||||
if device.type == "cuda":
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup)
|
||||
if device.type == "cuda":
|
||||
baseline_vram = torch.cuda.max_memory_allocated(device)
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup)
|
||||
if device.type == "cuda":
|
||||
patched_vram = torch.cuda.max_memory_allocated(device)
|
||||
|
||||
speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan")
|
||||
|
||||
@@ -237,6 +259,9 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
|
||||
"avg_tokens": avg_tokens_per_expert,
|
||||
"min_tokens": min_tokens,
|
||||
"max_tokens": max_tokens,
|
||||
"baseline_vram": baseline_vram,
|
||||
"patched_vram": patched_vram,
|
||||
"accuracy_ok": max_diff <= ACCURACY_TOLERANCE,
|
||||
}
|
||||
|
||||
|
||||
@@ -251,10 +276,18 @@ def main() -> None: # pragma: no cover - CLI entrypoint
|
||||
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
|
||||
)
|
||||
print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}")
|
||||
if result["baseline_vram"] is not None:
|
||||
print(
|
||||
f"VRAM baseline={result['baseline_vram'] / (1024**2):.1f} MiB | patched={result['patched_vram'] / (1024**2):.1f} MiB"
|
||||
)
|
||||
print(
|
||||
f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}"
|
||||
)
|
||||
print(f"Max |Δ| between outputs: {result['max_diff']:.2e}")
|
||||
if not result["accuracy_ok"]:
|
||||
raise RuntimeError(
|
||||
f"Accuracy check failed: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -12,11 +12,17 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
CURRENT_DIR = Path(__file__).resolve().parent
|
||||
REPO_ROOT = CURRENT_DIR.parent.parent
|
||||
if str(REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
||||
repo_root = candidate / "axolotl"
|
||||
if repo_root.exists():
|
||||
if str(repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(repo_root))
|
||||
break
|
||||
else: # pragma: no cover - defensive guard
|
||||
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||
|
||||
from scripts.benchmarks.deepseek_v3_moe import (
|
||||
ACCURACY_TOLERANCE,
|
||||
DTYPE_MAP,
|
||||
benchmark_deepseek_v3,
|
||||
)
|
||||
@@ -161,9 +167,12 @@ def main() -> None: # pragma: no cover - utility script
|
||||
"baseline_ms",
|
||||
"patched_ms",
|
||||
"speedup",
|
||||
"baseline_vram_mib",
|
||||
"patched_vram_mib",
|
||||
"min_tokens",
|
||||
"max_tokens",
|
||||
"max_diff",
|
||||
"accuracy_ok",
|
||||
)
|
||||
rows = []
|
||||
|
||||
@@ -171,12 +180,23 @@ def main() -> None: # pragma: no cover - utility script
|
||||
f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}"
|
||||
)
|
||||
print(
|
||||
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'baseline':>12} {'patched':>12} {'speedup':>8}"
|
||||
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6}"
|
||||
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}"
|
||||
)
|
||||
|
||||
for cfg in grid:
|
||||
ns = make_namespace(cfg, args)
|
||||
result = benchmark_deepseek_v3(ns)
|
||||
baseline_vram_mib = (
|
||||
result["baseline_vram"] / (1024**2)
|
||||
if result["baseline_vram"] is not None
|
||||
else float("nan")
|
||||
)
|
||||
patched_vram_mib = (
|
||||
result["patched_vram"] / (1024**2)
|
||||
if result["patched_vram"] is not None
|
||||
else float("nan")
|
||||
)
|
||||
rows.append(
|
||||
(
|
||||
cfg["batch"],
|
||||
@@ -189,15 +209,24 @@ def main() -> None: # pragma: no cover - utility script
|
||||
result["baseline_ms"],
|
||||
result["patched_ms"],
|
||||
result["speedup"],
|
||||
baseline_vram_mib,
|
||||
patched_vram_mib,
|
||||
result["min_tokens"],
|
||||
result["max_tokens"],
|
||||
result["max_diff"],
|
||||
result["accuracy_ok"],
|
||||
)
|
||||
)
|
||||
status = "OK" if result["accuracy_ok"] else "FAIL"
|
||||
print(
|
||||
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6}"
|
||||
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
|
||||
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}"
|
||||
)
|
||||
if not result["accuracy_ok"]:
|
||||
raise RuntimeError(
|
||||
f"Accuracy check failed for config {cfg}: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
|
||||
)
|
||||
|
||||
if args.output:
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
Reference in New Issue
Block a user