This commit is contained in:
Dan Saunders
2025-09-22 15:58:28 -04:00
parent 5788832812
commit 4ab9e3f58b
2 changed files with 22 additions and 0 deletions

View File

@@ -152,6 +152,15 @@ def main() -> None: # pragma: no cover - CLI entrypoint
dtype=dtype,
)
with torch.no_grad():
flat_inputs = inputs.view(-1, args.hidden_size)
topk_idx, _ = patched_module.gate(flat_inputs)
tokens_per_expert = torch.bincount(
topk_idx.reshape(-1), minlength=args.n_experts
)
min_tokens = int(tokens_per_expert.min().item())
max_tokens = int(tokens_per_expert.max().item())
with torch.no_grad():
ref_output = baseline_module(inputs)
patched_output = patched_module(inputs)
@@ -168,6 +177,7 @@ def main() -> None: # pragma: no cover - CLI entrypoint
print(
f"routed tokens={routed_tokens} avg tokens/expert={avg_tokens_per_expert:.1f} group_size={args.group_size}"
)
print(f"min/max tokens per expert: {min_tokens}/{max_tokens}")
print(
f"Baseline: {baseline_ms:.3f} ms | Patched: {patched_ms:.3f} ms | x{speedup:.2f}"
)

View File

@@ -119,6 +119,18 @@ def _moe_triton_forward(
if total_actual == 0:
return hidden_states.new_zeros_like(hidden_states)
if not getattr(module, "_axolotl_triton_logged", False):
min_tokens = int(counts.min().item())
max_tokens = int(counts.max().item())
LOG.info(
"DeepseekV3MoE Triton: tokens per expert (min=%s, max=%s, avg=%.1f) with group_size=%s",
min_tokens,
max_tokens,
total_actual / max(1, num_experts),
group_size_m,
)
module._axolotl_triton_logged = True
counts_int = counts.to(torch.int32)
aligned_counts = (
(torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m