uniform routing default

This commit is contained in:
Dan Saunders
2025-09-25 15:47:23 -04:00
parent e003a05177
commit 824a641cee
2 changed files with 15 additions and 15 deletions

View File

@@ -50,9 +50,9 @@ def parse_args() -> argparse.Namespace:
help="Override GROUP_SIZE_M for every configuration", help="Override GROUP_SIZE_M for every configuration",
) )
parser.add_argument( parser.add_argument(
"--uniform-routing", "--no-uniform-routing",
action="store_true", action="store_true",
help="Force uniform routing for every configuration", help="Disable uniform routing for every configuration",
) )
parser.add_argument( parser.add_argument(
"--include-mixtral-long", "--include-mixtral-long",
@@ -79,7 +79,7 @@ def make_namespace(
"warmup": args.warmup, "warmup": args.warmup,
"iters": args.iters, "iters": args.iters,
"seed": args.seed, "seed": args.seed,
"uniform_routing": args.uniform_routing, "uniform_routing": not args.no_uniform_routing,
} }
) )
if args.group_size is not None: if args.group_size is not None:
@@ -196,8 +196,9 @@ def main() -> None: # pragma: no cover - utility script
) )
rows = [] rows = []
uniform_flag = not args.no_uniform_routing
print( print(
f"Running sweep on device={args.device} dtype={args.dtype} backends={BACKENDS} uniform_routing={args.uniform_routing}" f"Running sweep on device={args.device} dtype={args.dtype} backends={BACKENDS} uniform_routing={uniform_flag}"
) )
print( print(
f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}" f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"

View File

@@ -132,7 +132,7 @@ def _run_cg_grouped_gemm(
group_size_m: int, group_size_m: int,
hidden_dtype: torch.dtype, hidden_dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
_ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg") _ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg")
expert_index_tensor = torch.repeat_interleave( expert_index_tensor = torch.repeat_interleave(
@@ -167,17 +167,11 @@ def _run_cg_grouped_gemm(
expert_index_tensor, expert_index_tensor,
group_size_m, group_size_m,
) )
down_out = ContiguousGroupedGEMM.apply(
grouped_hidden,
down_weights,
expert_index_tensor,
group_size_m,
)
return ( return (
gate_out.to(hidden_dtype), gate_out.to(hidden_dtype),
up_out.to(hidden_dtype), up_out.to(hidden_dtype),
down_out.to(hidden_dtype), down_weights,
expert_index_tensor,
) )
gate_out = mg_grouped_gemm( gate_out = mg_grouped_gemm(
@@ -296,7 +290,7 @@ def _moe_triton_forward(
m_sizes_tensor, m_sizes_tensor,
).to(hidden_dtype) ).to(hidden_dtype)
else: else:
gate_out, up_out, down_out_cg = _run_cg_grouped_gemm( gate_out, up_out, down_weights, expert_index_tensor = _run_cg_grouped_gemm(
module, module,
grouped_hidden, grouped_hidden,
m_sizes, m_sizes,
@@ -330,7 +324,12 @@ def _moe_triton_forward(
m_sizes_tensor, m_sizes_tensor,
).to(hidden_dtype) ).to(hidden_dtype)
else: else:
down_out = down_out_cg down_out = ContiguousGroupedGEMM.apply(
hidden_grouped,
down_weights,
expert_index_tensor,
group_size_m,
).to(hidden_dtype)
if valid_positions.numel() > 0: if valid_positions.numel() > 0:
down_valid = down_out.index_select(0, valid_positions) down_valid = down_out.index_select(0, valid_positions)