default mg

This commit is contained in:
Dan Saunders
2025-09-25 16:30:23 -04:00
parent 55d98db0d0
commit dd85358543

View File

@@ -52,6 +52,11 @@ def parse_args() -> argparse.Namespace:
type=int,
help="Override GROUP_SIZE_M for every configuration",
)
parser.add_argument(
"--backends",
default="mg",
help="Comma separated list of backends to benchmark (subset of cg,mg)",
)
parser.add_argument(
"--no-uniform-routing",
action="store_true",
@@ -131,8 +136,6 @@ ARCHETYPES = (
MIXTRAL_LONG_SHAPES = [(8, 8192)]
BACKENDS = ("cg", "mg")
def main() -> None: # pragma: no cover - utility script
args = parse_args()
@@ -187,9 +190,21 @@ def main() -> None: # pragma: no cover - utility script
)
rows = []
raw_backends = [
token.strip() for token in args.backends.split(",") if token.strip()
]
if not raw_backends:
raw_backends = ["mg"]
valid_backends = []
for backend in raw_backends:
if backend not in {"cg", "mg"}:
raise SystemExit(f"Unsupported backend '{backend}' requested")
if backend not in valid_backends:
valid_backends.append(backend)
uniform_flag = not args.no_uniform_routing
print(
f"Running sweep on device={args.device} dtype={args.dtype} backends={BACKENDS} uniform_routing={uniform_flag}"
f"Running sweep on device={args.device} dtype={args.dtype} backends={tuple(valid_backends)} uniform_routing={uniform_flag}"
)
print(
f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
@@ -197,7 +212,7 @@ def main() -> None: # pragma: no cover - utility script
)
for cfg in grid:
for backend in BACKENDS:
for backend in valid_backends:
ns = make_namespace(cfg, args, backend)
result = benchmark_deepseek_v3(ns)
baseline_vram_mib = (