uniform routing:

This commit is contained in:
Dan Saunders
2025-09-22 16:03:38 -04:00
parent 4ab9e3f58b
commit e5d2aebe16

View File

@@ -65,6 +65,11 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=25, help="Benchmark iterations")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--uniform-routing",
action="store_true",
help="Override router to distribute tokens evenly across experts",
)
parser.add_argument(
"--group-size",
type=int,
@@ -154,13 +159,45 @@ def main() -> None: # pragma: no cover - CLI entrypoint
with torch.no_grad():
flat_inputs = inputs.view(-1, args.hidden_size)
topk_idx, _ = patched_module.gate(flat_inputs)
if args.uniform_routing:
total_assignments = flat_inputs.size(0) * args.top_k
base = total_assignments // args.n_experts
remainder = total_assignments % args.n_experts
counts = torch.full(
(args.n_experts,),
base,
dtype=torch.int64,
device=device,
)
if remainder:
counts[:remainder] += 1
assignments = torch.repeat_interleave(
torch.arange(args.n_experts, device=device), counts
)
assignments = assignments[torch.randperm(assignments.size(0))]
topk_idx = assignments.view(flat_inputs.size(0), args.top_k)
else:
topk_idx, _ = patched_module.gate(flat_inputs)
tokens_per_expert = torch.bincount(
topk_idx.reshape(-1), minlength=args.n_experts
)
min_tokens = int(tokens_per_expert.min().item())
max_tokens = int(tokens_per_expert.max().item())
if args.uniform_routing:
    # Every selected expert gets the same routing weight, 1 / top_k.
    # The dtype must be given explicitly: ``topk_idx`` holds integer expert
    # ids, and ``full_like`` would otherwise inherit that integer dtype,
    # which cannot represent a fractional weight.
    # NOTE(review): assumes ``inputs`` is a floating-point tensor whose dtype
    # is also appropriate for routing weights -- confirm against the real gate.
    weights = torch.full_like(
        topk_idx, 1.0 / args.top_k, dtype=inputs.dtype
    )

    def _uniform_gate(self, hidden_states):
        """Return the precomputed uniform routing instead of running the gate.

        Slices the fixed expert indices and their equal weights down to the
        first ``hidden_states.shape[0]`` rows; the values in
        ``hidden_states`` are not used.
        """
        batch_tokens = hidden_states.shape[0]
        return topk_idx[:batch_tokens], weights[:batch_tokens]

    # Install the same override on both modules so the baseline and the
    # patched implementation are driven by an identical routing decision.
    patched_module.gate.forward = _uniform_gate.__get__(
        patched_module.gate, patched_module.gate.__class__
    )
    baseline_module.gate.forward = _uniform_gate.__get__(
        baseline_module.gate, baseline_module.gate.__class__
    )
with torch.no_grad():
ref_output = baseline_module(inputs)
patched_output = patched_module(inputs)