diff --git a/scripts/benchmarks/deepseek_v3_moe.py b/scripts/benchmarks/deepseek_v3_moe.py index 0f9296552..91c777048 100644 --- a/scripts/benchmarks/deepseek_v3_moe.py +++ b/scripts/benchmarks/deepseek_v3_moe.py @@ -65,6 +65,11 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations") parser.add_argument("--iters", type=int, default=25, help="Benchmark iterations") parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--uniform-routing", + action="store_true", + help="Override router to distribute tokens evenly across experts", + ) parser.add_argument( "--group-size", type=int, @@ -154,13 +159,45 @@ def main() -> None: # pragma: no cover - CLI entrypoint with torch.no_grad(): flat_inputs = inputs.view(-1, args.hidden_size) - topk_idx, _ = patched_module.gate(flat_inputs) + if args.uniform_routing: + total_assignments = flat_inputs.size(0) * args.top_k + base = total_assignments // args.n_experts + remainder = total_assignments % args.n_experts + counts = torch.full( + (args.n_experts,), + base, + dtype=torch.int64, + device=device, + ) + if remainder: + counts[:remainder] += 1 + assignments = torch.repeat_interleave( + torch.arange(args.n_experts, device=device), counts + ) + assignments = assignments[torch.randperm(assignments.size(0))] + topk_idx = assignments.view(flat_inputs.size(0), args.top_k) + else: + topk_idx, _ = patched_module.gate(flat_inputs) tokens_per_expert = torch.bincount( topk_idx.reshape(-1), minlength=args.n_experts ) min_tokens = int(tokens_per_expert.min().item()) max_tokens = int(tokens_per_expert.max().item()) + if args.uniform_routing: + weights = torch.full_like(topk_idx, 1.0 / args.top_k) + + def _uniform_gate(self, hidden_states): + batch_tokens = hidden_states.shape[0] + return topk_idx[:batch_tokens], weights[:batch_tokens] + + patched_module.gate.forward = _uniform_gate.__get__( + patched_module.gate, patched_module.gate.__class__ + ) + baseline_module.gate.forward = _uniform_gate.__get__( + baseline_module.gate, baseline_module.gate.__class__ + ) + with torch.no_grad(): ref_output = baseline_module(inputs) patched_output = patched_module(inputs)