This commit is contained in:
Dan Saunders
2025-09-23 13:50:48 -04:00
parent fd312f6058
commit 146ca48cba
4 changed files with 80 additions and 6 deletions

View File

@@ -12,11 +12,17 @@ from pathlib import Path
from types import SimpleNamespace
CURRENT_DIR = Path(__file__).resolve().parent
REPO_ROOT = CURRENT_DIR.parent.parent
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
repo_root = candidate / "axolotl"
if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports")
from scripts.benchmarks.deepseek_v3_moe import (
ACCURACY_TOLERANCE,
DTYPE_MAP,
benchmark_deepseek_v3,
)
@@ -161,9 +167,12 @@ def main() -> None: # pragma: no cover - utility script
"baseline_ms",
"patched_ms",
"speedup",
"baseline_vram_mib",
"patched_vram_mib",
"min_tokens",
"max_tokens",
"max_diff",
"accuracy_ok",
)
rows = []
@@ -171,12 +180,23 @@ def main() -> None: # pragma: no cover - utility script
f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}"
)
print(
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'baseline':>12} {'patched':>12} {'speedup':>8}"
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6}"
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}"
)
for cfg in grid:
ns = make_namespace(cfg, args)
result = benchmark_deepseek_v3(ns)
baseline_vram_mib = (
result["baseline_vram"] / (1024**2)
if result["baseline_vram"] is not None
else float("nan")
)
patched_vram_mib = (
result["patched_vram"] / (1024**2)
if result["patched_vram"] is not None
else float("nan")
)
rows.append(
(
cfg["batch"],
@@ -189,15 +209,24 @@ def main() -> None: # pragma: no cover - utility script
result["baseline_ms"],
result["patched_ms"],
result["speedup"],
baseline_vram_mib,
patched_vram_mib,
result["min_tokens"],
result["max_tokens"],
result["max_diff"],
result["accuracy_ok"],
)
)
status = "OK" if result["accuracy_ok"] else "FAIL"
print(
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6}"
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}"
)
if not result["accuracy_ok"]:
raise RuntimeError(
f"Accuracy check failed for config {cfg}: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
)
if args.output:
args.output.parent.mkdir(parents=True, exist_ok=True)