From 4ab9e3f58b019b0bad2a4a5f2e9cfd6a3a164111 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 22 Sep 2025 15:58:28 -0400 Subject: [PATCH] add logs --- scripts/benchmarks/deepseek_v3_moe.py | 10 ++++++++++ src/axolotl/monkeypatch/deepseek_v3/__init__.py | 12 ++++++++++++ 2 files changed, 22 insertions(+) diff --git a/scripts/benchmarks/deepseek_v3_moe.py b/scripts/benchmarks/deepseek_v3_moe.py index 61984527c..0f9296552 100644 --- a/scripts/benchmarks/deepseek_v3_moe.py +++ b/scripts/benchmarks/deepseek_v3_moe.py @@ -152,6 +152,15 @@ def main() -> None: # pragma: no cover - CLI entrypoint dtype=dtype, ) + with torch.no_grad(): + flat_inputs = inputs.view(-1, args.hidden_size) + topk_idx, _ = patched_module.gate(flat_inputs) + tokens_per_expert = torch.bincount( + topk_idx.reshape(-1), minlength=args.n_experts + ) + min_tokens = int(tokens_per_expert.min().item()) + max_tokens = int(tokens_per_expert.max().item()) + with torch.no_grad(): ref_output = baseline_module(inputs) patched_output = patched_module(inputs) @@ -168,6 +177,7 @@ def main() -> None: # pragma: no cover - CLI entrypoint print( f"routed tokens={routed_tokens} avg tokens/expert={avg_tokens_per_expert:.1f} group_size={args.group_size}" ) + print(f"min/max tokens per expert: {min_tokens}/{max_tokens}") print( f"Baseline: {baseline_ms:.3f} ms | Patched: {patched_ms:.3f} ms | x{speedup:.2f}" ) diff --git a/src/axolotl/monkeypatch/deepseek_v3/__init__.py b/src/axolotl/monkeypatch/deepseek_v3/__init__.py index c46ec34e8..e8a7408d8 100644 --- a/src/axolotl/monkeypatch/deepseek_v3/__init__.py +++ b/src/axolotl/monkeypatch/deepseek_v3/__init__.py @@ -119,6 +119,18 @@ def _moe_triton_forward( if total_actual == 0: return hidden_states.new_zeros_like(hidden_states) + if not getattr(module, "_axolotl_triton_logged", False): + min_tokens = int(counts.min().item()) + max_tokens = int(counts.max().item()) + LOG.info( + "DeepseekV3MoE Triton: tokens per expert (min=%s, max=%s, avg=%.1f) with group_size=%s", + min_tokens, + max_tokens, + total_actual / max(1, num_experts), + group_size_m, + ) + module._axolotl_triton_logged = True + counts_int = counts.to(torch.int32) aligned_counts = ( (torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m