add logs
This commit is contained in:
@@ -152,6 +152,15 @@ def main() -> None: # pragma: no cover - CLI entrypoint
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
flat_inputs = inputs.view(-1, args.hidden_size)
|
||||
topk_idx, _ = patched_module.gate(flat_inputs)
|
||||
tokens_per_expert = torch.bincount(
|
||||
topk_idx.reshape(-1), minlength=args.n_experts
|
||||
)
|
||||
min_tokens = int(tokens_per_expert.min().item())
|
||||
max_tokens = int(tokens_per_expert.max().item())
|
||||
|
||||
with torch.no_grad():
|
||||
ref_output = baseline_module(inputs)
|
||||
patched_output = patched_module(inputs)
|
||||
@@ -168,6 +177,7 @@ def main() -> None: # pragma: no cover - CLI entrypoint
|
||||
print(
|
||||
f"routed tokens={routed_tokens} avg tokens/expert={avg_tokens_per_expert:.1f} group_size={args.group_size}"
|
||||
)
|
||||
print(f"min/max tokens per expert: {min_tokens}/{max_tokens}")
|
||||
print(
|
||||
f"Baseline: {baseline_ms:.3f} ms | Patched: {patched_ms:.3f} ms | x{speedup:.2f}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user