vram
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
"""Benchmark helpers."""
|
"""Benchmark helpers."""
|
||||||
|
|
||||||
from .deepseek_v3_moe import DTYPE_MAP, benchmark_deepseek_v3
|
from .deepseek_v3_moe import ACCURACY_TOLERANCE, DTYPE_MAP, benchmark_deepseek_v3
|
||||||
|
|
||||||
__all__ = ["benchmark_deepseek_v3", "DTYPE_MAP"]
|
__all__ = ["benchmark_deepseek_v3", "DTYPE_MAP", "ACCURACY_TOLERANCE"]
|
||||||
|
|||||||
@@ -4,12 +4,24 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
CURRENT_DIR = Path(__file__).resolve().parent
|
||||||
|
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
||||||
|
repo_root = candidate / "axolotl"
|
||||||
|
if repo_root.exists():
|
||||||
|
if str(repo_root) not in sys.path:
|
||||||
|
sys.path.insert(0, str(repo_root))
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||||
|
|
||||||
from axolotl.kernels.moe import (
|
from axolotl.kernels.moe import (
|
||||||
cg_grouped_gemm_forward,
|
cg_grouped_gemm_forward,
|
||||||
cg_grouped_gemm_forward_dynamic,
|
cg_grouped_gemm_forward_dynamic,
|
||||||
|
|||||||
@@ -5,7 +5,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from pathlib import Path
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -20,8 +22,20 @@ except ImportError as exc: # pragma: no cover - utility script
|
|||||||
"Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
|
"Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
|
CURRENT_DIR = Path(__file__).resolve().parent
|
||||||
|
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
||||||
|
repo_root = candidate / "axolotl"
|
||||||
|
if repo_root.exists():
|
||||||
|
if str(repo_root) not in sys.path:
|
||||||
|
sys.path.insert(0, str(repo_root))
|
||||||
|
break
|
||||||
|
else: # pragma: no cover - defensive guard
|
||||||
|
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||||
|
|
||||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
||||||
|
|
||||||
|
ACCURACY_TOLERANCE = 5e-3
|
||||||
|
|
||||||
DTYPE_MAP = {
|
DTYPE_MAP = {
|
||||||
"bf16": torch.bfloat16,
|
"bf16": torch.bfloat16,
|
||||||
"fp16": torch.float16,
|
"fp16": torch.float16,
|
||||||
@@ -221,8 +235,16 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
|
|||||||
patched_output = patched_module(inputs)
|
patched_output = patched_module(inputs)
|
||||||
max_diff = (ref_output - patched_output).abs().max().item()
|
max_diff = (ref_output - patched_output).abs().max().item()
|
||||||
|
|
||||||
|
baseline_vram = patched_vram = None
|
||||||
|
if device.type == "cuda":
|
||||||
|
torch.cuda.reset_peak_memory_stats(device)
|
||||||
baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup)
|
baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup)
|
||||||
|
if device.type == "cuda":
|
||||||
|
baseline_vram = torch.cuda.max_memory_allocated(device)
|
||||||
|
torch.cuda.reset_peak_memory_stats(device)
|
||||||
patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup)
|
patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup)
|
||||||
|
if device.type == "cuda":
|
||||||
|
patched_vram = torch.cuda.max_memory_allocated(device)
|
||||||
|
|
||||||
speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan")
|
speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan")
|
||||||
|
|
||||||
@@ -237,6 +259,9 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
|
|||||||
"avg_tokens": avg_tokens_per_expert,
|
"avg_tokens": avg_tokens_per_expert,
|
||||||
"min_tokens": min_tokens,
|
"min_tokens": min_tokens,
|
||||||
"max_tokens": max_tokens,
|
"max_tokens": max_tokens,
|
||||||
|
"baseline_vram": baseline_vram,
|
||||||
|
"patched_vram": patched_vram,
|
||||||
|
"accuracy_ok": max_diff <= ACCURACY_TOLERANCE,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -251,10 +276,18 @@ def main() -> None: # pragma: no cover - CLI entrypoint
|
|||||||
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
|
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
|
||||||
)
|
)
|
||||||
print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}")
|
print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}")
|
||||||
|
if result["baseline_vram"] is not None:
|
||||||
|
print(
|
||||||
|
f"VRAM baseline={result['baseline_vram'] / (1024**2):.1f} MiB | patched={result['patched_vram'] / (1024**2):.1f} MiB"
|
||||||
|
)
|
||||||
print(
|
print(
|
||||||
f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}"
|
f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}"
|
||||||
)
|
)
|
||||||
print(f"Max |Δ| between outputs: {result['max_diff']:.2e}")
|
print(f"Max |Δ| between outputs: {result['max_diff']:.2e}")
|
||||||
|
if not result["accuracy_ok"]:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Accuracy check failed: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -12,11 +12,17 @@ from pathlib import Path
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
CURRENT_DIR = Path(__file__).resolve().parent
|
CURRENT_DIR = Path(__file__).resolve().parent
|
||||||
REPO_ROOT = CURRENT_DIR.parent.parent
|
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
||||||
if str(REPO_ROOT) not in sys.path:
|
repo_root = candidate / "axolotl"
|
||||||
sys.path.insert(0, str(REPO_ROOT))
|
if repo_root.exists():
|
||||||
|
if str(repo_root) not in sys.path:
|
||||||
|
sys.path.insert(0, str(repo_root))
|
||||||
|
break
|
||||||
|
else: # pragma: no cover - defensive guard
|
||||||
|
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||||
|
|
||||||
from scripts.benchmarks.deepseek_v3_moe import (
|
from scripts.benchmarks.deepseek_v3_moe import (
|
||||||
|
ACCURACY_TOLERANCE,
|
||||||
DTYPE_MAP,
|
DTYPE_MAP,
|
||||||
benchmark_deepseek_v3,
|
benchmark_deepseek_v3,
|
||||||
)
|
)
|
||||||
@@ -161,9 +167,12 @@ def main() -> None: # pragma: no cover - utility script
|
|||||||
"baseline_ms",
|
"baseline_ms",
|
||||||
"patched_ms",
|
"patched_ms",
|
||||||
"speedup",
|
"speedup",
|
||||||
|
"baseline_vram_mib",
|
||||||
|
"patched_vram_mib",
|
||||||
"min_tokens",
|
"min_tokens",
|
||||||
"max_tokens",
|
"max_tokens",
|
||||||
"max_diff",
|
"max_diff",
|
||||||
|
"accuracy_ok",
|
||||||
)
|
)
|
||||||
rows = []
|
rows = []
|
||||||
|
|
||||||
@@ -171,12 +180,23 @@ def main() -> None: # pragma: no cover - utility script
|
|||||||
f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}"
|
f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}"
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'baseline':>12} {'patched':>12} {'speedup':>8}"
|
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6}"
|
||||||
|
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}"
|
||||||
)
|
)
|
||||||
|
|
||||||
for cfg in grid:
|
for cfg in grid:
|
||||||
ns = make_namespace(cfg, args)
|
ns = make_namespace(cfg, args)
|
||||||
result = benchmark_deepseek_v3(ns)
|
result = benchmark_deepseek_v3(ns)
|
||||||
|
baseline_vram_mib = (
|
||||||
|
result["baseline_vram"] / (1024**2)
|
||||||
|
if result["baseline_vram"] is not None
|
||||||
|
else float("nan")
|
||||||
|
)
|
||||||
|
patched_vram_mib = (
|
||||||
|
result["patched_vram"] / (1024**2)
|
||||||
|
if result["patched_vram"] is not None
|
||||||
|
else float("nan")
|
||||||
|
)
|
||||||
rows.append(
|
rows.append(
|
||||||
(
|
(
|
||||||
cfg["batch"],
|
cfg["batch"],
|
||||||
@@ -189,15 +209,24 @@ def main() -> None: # pragma: no cover - utility script
|
|||||||
result["baseline_ms"],
|
result["baseline_ms"],
|
||||||
result["patched_ms"],
|
result["patched_ms"],
|
||||||
result["speedup"],
|
result["speedup"],
|
||||||
|
baseline_vram_mib,
|
||||||
|
patched_vram_mib,
|
||||||
result["min_tokens"],
|
result["min_tokens"],
|
||||||
result["max_tokens"],
|
result["max_tokens"],
|
||||||
result["max_diff"],
|
result["max_diff"],
|
||||||
|
result["accuracy_ok"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
status = "OK" if result["accuracy_ok"] else "FAIL"
|
||||||
print(
|
print(
|
||||||
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6}"
|
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6}"
|
||||||
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
|
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
|
||||||
|
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}"
|
||||||
)
|
)
|
||||||
|
if not result["accuracy_ok"]:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Accuracy check failed for config {cfg}: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
|
||||||
|
)
|
||||||
|
|
||||||
if args.output:
|
if args.output:
|
||||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user