uniform routing:
This commit is contained in:
@@ -65,6 +65,11 @@ def parse_args() -> argparse.Namespace:
|
|||||||
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
|
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
|
||||||
parser.add_argument("--iters", type=int, default=25, help="Benchmark iterations")
|
parser.add_argument("--iters", type=int, default=25, help="Benchmark iterations")
|
||||||
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||||
|
parser.add_argument(
|
||||||
|
"--uniform-routing",
|
||||||
|
action="store_true",
|
||||||
|
help="Override router to distribute tokens evenly across experts",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--group-size",
|
"--group-size",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -154,13 +159,45 @@ def main() -> None: # pragma: no cover - CLI entrypoint
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
flat_inputs = inputs.view(-1, args.hidden_size)
|
flat_inputs = inputs.view(-1, args.hidden_size)
|
||||||
topk_idx, _ = patched_module.gate(flat_inputs)
|
if args.uniform_routing:
|
||||||
|
total_assignments = flat_inputs.size(0) * args.top_k
|
||||||
|
base = total_assignments // args.n_experts
|
||||||
|
remainder = total_assignments % args.n_experts
|
||||||
|
counts = torch.full(
|
||||||
|
(args.n_experts,),
|
||||||
|
base,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
if remainder:
|
||||||
|
counts[:remainder] += 1
|
||||||
|
assignments = torch.repeat_interleave(
|
||||||
|
torch.arange(args.n_experts, device=device), counts
|
||||||
|
)
|
||||||
|
assignments = assignments[torch.randperm(assignments.size(0))]
|
||||||
|
topk_idx = assignments.view(flat_inputs.size(0), args.top_k)
|
||||||
|
else:
|
||||||
|
topk_idx, _ = patched_module.gate(flat_inputs)
|
||||||
tokens_per_expert = torch.bincount(
|
tokens_per_expert = torch.bincount(
|
||||||
topk_idx.reshape(-1), minlength=args.n_experts
|
topk_idx.reshape(-1), minlength=args.n_experts
|
||||||
)
|
)
|
||||||
min_tokens = int(tokens_per_expert.min().item())
|
min_tokens = int(tokens_per_expert.min().item())
|
||||||
max_tokens = int(tokens_per_expert.max().item())
|
max_tokens = int(tokens_per_expert.max().item())
|
||||||
|
|
||||||
|
if args.uniform_routing:
|
||||||
|
weights = torch.full_like(topk_idx, 1.0 / args.top_k)
|
||||||
|
|
||||||
|
def _uniform_gate(self, hidden_states):
|
||||||
|
batch_tokens = hidden_states.shape[0]
|
||||||
|
return topk_idx[:batch_tokens], weights[:batch_tokens]
|
||||||
|
|
||||||
|
patched_module.gate.forward = _uniform_gate.__get__(
|
||||||
|
patched_module.gate, patched_module.gate.__class__
|
||||||
|
)
|
||||||
|
baseline_module.gate.forward = _uniform_gate.__get__(
|
||||||
|
baseline_module.gate, baseline_module.gate.__class__
|
||||||
|
)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
ref_output = baseline_module(inputs)
|
ref_output = baseline_module(inputs)
|
||||||
patched_output = patched_module(inputs)
|
patched_output = patched_module(inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user