add logs
This commit is contained in:
@@ -152,6 +152,15 @@ def main() -> None: # pragma: no cover - CLI entrypoint
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
flat_inputs = inputs.view(-1, args.hidden_size)
|
||||
topk_idx, _ = patched_module.gate(flat_inputs)
|
||||
tokens_per_expert = torch.bincount(
|
||||
topk_idx.reshape(-1), minlength=args.n_experts
|
||||
)
|
||||
min_tokens = int(tokens_per_expert.min().item())
|
||||
max_tokens = int(tokens_per_expert.max().item())
|
||||
|
||||
with torch.no_grad():
|
||||
ref_output = baseline_module(inputs)
|
||||
patched_output = patched_module(inputs)
|
||||
@@ -168,6 +177,7 @@ def main() -> None: # pragma: no cover - CLI entrypoint
|
||||
print(
|
||||
f"routed tokens={routed_tokens} avg tokens/expert={avg_tokens_per_expert:.1f} group_size={args.group_size}"
|
||||
)
|
||||
print(f"min/max tokens per expert: {min_tokens}/{max_tokens}")
|
||||
print(
|
||||
f"Baseline: {baseline_ms:.3f} ms | Patched: {patched_ms:.3f} ms | x{speedup:.2f}"
|
||||
)
|
||||
|
||||
@@ -119,6 +119,18 @@ def _moe_triton_forward(
|
||||
if total_actual == 0:
|
||||
return hidden_states.new_zeros_like(hidden_states)
|
||||
|
||||
if not getattr(module, "_axolotl_triton_logged", False):
|
||||
min_tokens = int(counts.min().item())
|
||||
max_tokens = int(counts.max().item())
|
||||
LOG.info(
|
||||
"DeepseekV3MoE Triton: tokens per expert (min=%s, max=%s, avg=%.1f) with group_size=%s",
|
||||
min_tokens,
|
||||
max_tokens,
|
||||
total_actual / max(1, num_experts),
|
||||
group_size_m,
|
||||
)
|
||||
module._axolotl_triton_logged = True
|
||||
|
||||
counts_int = counts.to(torch.int32)
|
||||
aligned_counts = (
|
||||
(torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m
|
||||
|
||||
Reference in New Issue
Block a user