uniform routing default
This commit is contained in:
@@ -50,9 +50,9 @@ def parse_args() -> argparse.Namespace:
|
||||
help="Override GROUP_SIZE_M for every configuration",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--uniform-routing",
|
||||
"--no-uniform-routing",
|
||||
action="store_true",
|
||||
help="Force uniform routing for every configuration",
|
||||
help="Disable uniform routing for every configuration",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-mixtral-long",
|
||||
@@ -79,7 +79,7 @@ def make_namespace(
|
||||
"warmup": args.warmup,
|
||||
"iters": args.iters,
|
||||
"seed": args.seed,
|
||||
"uniform_routing": args.uniform_routing,
|
||||
"uniform_routing": not args.no_uniform_routing,
|
||||
}
|
||||
)
|
||||
if args.group_size is not None:
|
||||
@@ -196,8 +196,9 @@ def main() -> None: # pragma: no cover - utility script
|
||||
)
|
||||
rows = []
|
||||
|
||||
uniform_flag = not args.no_uniform_routing
|
||||
print(
|
||||
f"Running sweep on device={args.device} dtype={args.dtype} backends={BACKENDS} uniform_routing={args.uniform_routing}"
|
||||
f"Running sweep on device={args.device} dtype={args.dtype} backends={BACKENDS} uniform_routing={uniform_flag}"
|
||||
)
|
||||
print(
|
||||
f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
|
||||
|
||||
Reference in New Issue
Block a user