fix: aux loss calculation

This commit is contained in:
NanoCode012
2025-12-22 18:56:39 +07:00
parent eec5342c76
commit 8a3cb223e6

View File

@@ -647,24 +647,19 @@ class KimiMoEGate(nn.Module):
if self.training:
# Training path: standard, differentiable top-k routing
# Use softmax for probabilities, as it's standard for MoE routing loss
gating_scores = F.softmax(router_logits, dim=-1, dtype=torch.float32)
# Get top-k scores and their indices
topk_weight, topk_idx = torch.topk(
gating_scores, self.top_k, dim=-1, sorted=False
)
# Re-normalize top-k weights to sum to 1
if self.top_k > 1 and self.moe_renormalize:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
# Apply scaling factor
topk_weight = topk_weight * self.routed_scaling_factor
# During training, we return the raw logits for the aux loss calculation
# Return logits AND topk_idx for aux loss calculation
return router_logits, topk_idx, topk_weight
else:
@@ -885,37 +880,34 @@ class KimiSparseMoeBlock(nn.Module):
config=config, intermediate_size=intermediate_size
)
def calculate_aux_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
def calculate_aux_loss(
self, router_logits: torch.Tensor, topk_idx: torch.Tensor
) -> Optional[torch.Tensor]:
"""
Calculates the auxiliary load-balancing loss for the MoE layer.
This is a critical component for stable training of MoE models.
It encourages the router to send a balanced number of tokens to each expert.
The loss is a combination of:
1. A loss that encourages the router to distribute tokens evenly.
2. A loss that encourages the router logits to have a small magnitude (z-loss).
Uses the standard Switch Transformer formulation.
"""
if router_logits is None or not self.training:
# Return a zero tensor without accessing router_logits attributes when it's None
return torch.zeros(1, requires_grad=False)[
0
] # Returns a scalar zero tensor
return None
num_tokens, num_experts = router_logits.shape
# Calculate the probabilities and their mean across all tokens
# P_i: Mean router probability per expert (differentiable)
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
mean_router_probs_per_expert = torch.mean(router_probs, dim=0)
mean_router_prob_per_expert = torch.mean(router_probs, dim=0)
# Calculate the fraction of tokens dispatched to each expert
# Create a one-hot representation of the router's choices
# For top_k > 1, a token is "dispatched" to all its chosen experts
# We can approximate this with router_probs during training for a differentiable loss
tokens_per_expert = torch.mean(router_probs, dim=0)
# f_i: Fraction of tokens routed to each expert (non-differentiable)
expert_counts = torch.zeros(
num_experts, device=router_logits.device, dtype=torch.float32
)
flat_topk_idx = topk_idx.view(-1)
ones = torch.ones_like(flat_topk_idx, dtype=torch.float32)
expert_counts.scatter_add_(0, flat_topk_idx, ones)
tokens_per_expert_fraction = expert_counts / flat_topk_idx.numel()
# The standard load balancing loss from the Switch Transformer paper
# Load balancing loss: L = N * Σ(f_i * P_i)
load_balancing_loss = num_experts * torch.sum(
mean_router_probs_per_expert * tokens_per_expert
tokens_per_expert_fraction * mean_router_prob_per_expert
)
return self.router_aux_loss_coef * load_balancing_loss
@@ -973,24 +965,18 @@ class KimiSparseMoeBlock(nn.Module):
# # Return the final states and the auxiliary loss
# return final_hidden_states, aux_loss
def forward(self, hidden_states: torch.Tensor):
"""
Optimized forward pass for MoE training that avoids materializing all expert outputs at once.
"""
identity = hidden_states
batch_size, seq_len, hidden_dim = hidden_states.shape
num_tokens = batch_size * seq_len
# Reshape for routing
hidden_states = hidden_states.view(num_tokens, hidden_dim)
# Get routing decisions from the gate
# router_logits is None during inference
router_logits, topk_idx, topk_weight = self.gate(
hidden_states.view(batch_size, seq_len, hidden_dim)
)
# Calculate auxiliary loss (will be 0.0 during inference)
aux_loss = self.calculate_aux_loss(router_logits)
# Pass topk_idx to aux loss calculation
aux_loss = self.calculate_aux_loss(router_logits, topk_idx)
if self.training:
# =================================================================================