vram
This commit is contained in:
@@ -5,7 +5,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import MethodType
|
||||
|
||||
import torch
|
||||
@@ -20,8 +22,20 @@ except ImportError as exc: # pragma: no cover - utility script
|
||||
"Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
|
||||
) from exc
|
||||
|
||||
CURRENT_DIR = Path(__file__).resolve().parent

# Make the ``axolotl`` package importable when this script is run directly:
# probe this script's directory and then each of its parents for an
# ``axolotl`` subdirectory, and put the first one found at the front of
# ``sys.path`` (unless it is already listed there).
for candidate in (CURRENT_DIR, *CURRENT_DIR.parents):
    repo_root = candidate / "axolotl"
    if not repo_root.exists():
        continue
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    break
else:  # pragma: no cover - defensive guard
    # The loop ran out of directories without ever reaching ``break``.
    raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||
|
||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
||||
|
||||
# Largest absolute element-wise difference allowed between the reference and
# patched module outputs; the benchmark reports the accuracy check as passed
# only when the measured max difference is <= this value.
ACCURACY_TOLERANCE = 5e-3
|
||||
|
||||
DTYPE_MAP = {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp16": torch.float16,
|
||||
@@ -221,8 +235,16 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
|
||||
patched_output = patched_module(inputs)
|
||||
max_diff = (ref_output - patched_output).abs().max().item()
|
||||
|
||||
baseline_vram = patched_vram = None
|
||||
if device.type == "cuda":
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup)
|
||||
if device.type == "cuda":
|
||||
baseline_vram = torch.cuda.max_memory_allocated(device)
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup)
|
||||
if device.type == "cuda":
|
||||
patched_vram = torch.cuda.max_memory_allocated(device)
|
||||
|
||||
speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan")
|
||||
|
||||
@@ -237,6 +259,9 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
|
||||
"avg_tokens": avg_tokens_per_expert,
|
||||
"min_tokens": min_tokens,
|
||||
"max_tokens": max_tokens,
|
||||
"baseline_vram": baseline_vram,
|
||||
"patched_vram": patched_vram,
|
||||
"accuracy_ok": max_diff <= ACCURACY_TOLERANCE,
|
||||
}
|
||||
|
||||
|
||||
@@ -251,10 +276,18 @@ def main() -> None: # pragma: no cover - CLI entrypoint
|
||||
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
|
||||
)
|
||||
print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}")
|
||||
if result["baseline_vram"] is not None:
|
||||
print(
|
||||
f"VRAM baseline={result['baseline_vram'] / (1024**2):.1f} MiB | patched={result['patched_vram'] / (1024**2):.1f} MiB"
|
||||
)
|
||||
print(
|
||||
f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}"
|
||||
)
|
||||
print(f"Max |Δ| between outputs: {result['max_diff']:.2e}")
|
||||
if not result["accuracy_ok"]:
|
||||
raise RuntimeError(
|
||||
f"Accuracy check failed: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user