vram
This commit is contained in:
@@ -12,11 +12,17 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
CURRENT_DIR = Path(__file__).resolve().parent
|
||||
REPO_ROOT = CURRENT_DIR.parent.parent
|
||||
if str(REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
||||
repo_root = candidate / "axolotl"
|
||||
if repo_root.exists():
|
||||
if str(repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(repo_root))
|
||||
break
|
||||
else: # pragma: no cover - defensive guard
|
||||
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||
|
||||
from scripts.benchmarks.deepseek_v3_moe import (
|
||||
ACCURACY_TOLERANCE,
|
||||
DTYPE_MAP,
|
||||
benchmark_deepseek_v3,
|
||||
)
|
||||
@@ -161,9 +167,12 @@ def main() -> None: # pragma: no cover - utility script
|
||||
"baseline_ms",
|
||||
"patched_ms",
|
||||
"speedup",
|
||||
"baseline_vram_mib",
|
||||
"patched_vram_mib",
|
||||
"min_tokens",
|
||||
"max_tokens",
|
||||
"max_diff",
|
||||
"accuracy_ok",
|
||||
)
|
||||
rows = []
|
||||
|
||||
@@ -171,12 +180,23 @@ def main() -> None: # pragma: no cover - utility script
|
||||
f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}"
|
||||
)
|
||||
print(
|
||||
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'baseline':>12} {'patched':>12} {'speedup':>8}"
|
||||
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6}"
|
||||
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}"
|
||||
)
|
||||
|
||||
for cfg in grid:
|
||||
ns = make_namespace(cfg, args)
|
||||
result = benchmark_deepseek_v3(ns)
|
||||
baseline_vram_mib = (
|
||||
result["baseline_vram"] / (1024**2)
|
||||
if result["baseline_vram"] is not None
|
||||
else float("nan")
|
||||
)
|
||||
patched_vram_mib = (
|
||||
result["patched_vram"] / (1024**2)
|
||||
if result["patched_vram"] is not None
|
||||
else float("nan")
|
||||
)
|
||||
rows.append(
|
||||
(
|
||||
cfg["batch"],
|
||||
@@ -189,15 +209,24 @@ def main() -> None: # pragma: no cover - utility script
|
||||
result["baseline_ms"],
|
||||
result["patched_ms"],
|
||||
result["speedup"],
|
||||
baseline_vram_mib,
|
||||
patched_vram_mib,
|
||||
result["min_tokens"],
|
||||
result["max_tokens"],
|
||||
result["max_diff"],
|
||||
result["accuracy_ok"],
|
||||
)
|
||||
)
|
||||
status = "OK" if result["accuracy_ok"] else "FAIL"
|
||||
print(
|
||||
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6}"
|
||||
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
|
||||
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}"
|
||||
)
|
||||
if not result["accuracy_ok"]:
|
||||
raise RuntimeError(
|
||||
f"Accuracy check failed for config {cfg}: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
|
||||
)
|
||||
|
||||
if args.output:
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
Reference in New Issue
Block a user