This commit is contained in:
Dan Saunders
2025-09-25 16:08:35 -04:00
parent d90ade3b1b
commit 55d98db0d0

View File

@@ -6,6 +6,7 @@ from __future__ import annotations
import argparse import argparse
import csv import csv
import logging
import sys import sys
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
@@ -26,6 +27,8 @@ from scripts.benchmarks.deepseek_v3_moe import ( # noqa: E402
benchmark_deepseek_v3, benchmark_deepseek_v3,
) )
LOG = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
@@ -190,7 +193,7 @@ def main() -> None: # pragma: no cover - utility script
) )
print( print(
f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}" f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}" f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'diff':>10} {'acc':>5}"
) )
for cfg in grid: for cfg in grid:
@@ -233,11 +236,15 @@ def main() -> None: # pragma: no cover - utility script
print( print(
f"{cfg['label']:>10} {cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {backend:>8}" f"{cfg['label']:>10} {cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {backend:>8}"
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x" f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}" f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {result['max_diff']:>10.3e} {status:>5}"
) )
if not result["accuracy_ok"]: if not result["accuracy_ok"]:
raise RuntimeError( LOG.warning(
f"Accuracy check failed for config {cfg}: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}" "Accuracy tolerance exceeded for %s backend=%s: diff=%.3e (> %.1e)",
cfg["label"],
backend,
result["max_diff"],
ACCURACY_TOLERANCE,
) )
if args.output: if args.output: