default mg
This commit is contained in:
@@ -52,6 +52,11 @@ def parse_args() -> argparse.Namespace:
|
||||
type=int,
|
||||
help="Override GROUP_SIZE_M for every configuration",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backends",
|
||||
default="mg",
|
||||
help="Comma separated list of backends to benchmark (subset of cg,mg)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-uniform-routing",
|
||||
action="store_true",
|
||||
@@ -131,8 +136,6 @@ ARCHETYPES = (
|
||||
|
||||
MIXTRAL_LONG_SHAPES = [(8, 8192)]
|
||||
|
||||
BACKENDS = ("cg", "mg")
|
||||
|
||||
|
||||
def main() -> None: # pragma: no cover - utility script
|
||||
args = parse_args()
|
||||
@@ -187,9 +190,21 @@ def main() -> None: # pragma: no cover - utility script
|
||||
)
|
||||
rows = []
|
||||
|
||||
raw_backends = [
|
||||
token.strip() for token in args.backends.split(",") if token.strip()
|
||||
]
|
||||
if not raw_backends:
|
||||
raw_backends = ["mg"]
|
||||
valid_backends = []
|
||||
for backend in raw_backends:
|
||||
if backend not in {"cg", "mg"}:
|
||||
raise SystemExit(f"Unsupported backend '{backend}' requested")
|
||||
if backend not in valid_backends:
|
||||
valid_backends.append(backend)
|
||||
|
||||
uniform_flag = not args.no_uniform_routing
|
||||
print(
|
||||
f"Running sweep on device={args.device} dtype={args.dtype} backends={BACKENDS} uniform_routing={uniform_flag}"
|
||||
f"Running sweep on device={args.device} dtype={args.dtype} backends={tuple(valid_backends)} uniform_routing={uniform_flag}"
|
||||
)
|
||||
print(
|
||||
f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
|
||||
@@ -197,7 +212,7 @@ def main() -> None: # pragma: no cover - utility script
|
||||
)
|
||||
|
||||
for cfg in grid:
|
||||
for backend in BACKENDS:
|
||||
for backend in valid_backends:
|
||||
ns = make_namespace(cfg, args, backend)
|
||||
result = benchmark_deepseek_v3(ns)
|
||||
baseline_vram_mib = (
|
||||
|
||||
Reference in New Issue
Block a user