This commit is contained in:
Dan Saunders
2025-09-22 15:58:28 -04:00
parent 5788832812
commit 4ab9e3f58b
2 changed files with 22 additions and 0 deletions

View File

@@ -152,6 +152,15 @@ def main() -> None: # pragma: no cover - CLI entrypoint
dtype=dtype,
)
with torch.no_grad():
flat_inputs = inputs.view(-1, args.hidden_size)
topk_idx, _ = patched_module.gate(flat_inputs)
tokens_per_expert = torch.bincount(
topk_idx.reshape(-1), minlength=args.n_experts
)
min_tokens = int(tokens_per_expert.min().item())
max_tokens = int(tokens_per_expert.max().item())
with torch.no_grad():
ref_output = baseline_module(inputs)
patched_output = patched_module(inputs)
@@ -168,6 +177,7 @@ def main() -> None: # pragma: no cover - CLI entrypoint
print(
f"routed tokens={routed_tokens} avg tokens/expert={avg_tokens_per_expert:.1f} group_size={args.group_size}"
)
print(f"min/max tokens per expert: {min_tokens}/{max_tokens}")
print(
f"Baseline: {baseline_ms:.3f} ms | Patched: {patched_ms:.3f} ms | x{speedup:.2f}"
)