fix: aux loss calculation

This commit is contained in:
NanoCode012
2025-12-22 18:56:39 +07:00
parent eec5342c76
commit 8a3cb223e6

View File

@@ -647,24 +647,19 @@ class KimiMoEGate(nn.Module):
if self.training: if self.training:
# Training path: standard, differentiable top-k routing # Training path: standard, differentiable top-k routing
# Use softmax for probabilities, as it's standard for MoE routing loss
gating_scores = F.softmax(router_logits, dim=-1, dtype=torch.float32) gating_scores = F.softmax(router_logits, dim=-1, dtype=torch.float32)
# Get top-k scores and their indices
topk_weight, topk_idx = torch.topk( topk_weight, topk_idx = torch.topk(
gating_scores, self.top_k, dim=-1, sorted=False gating_scores, self.top_k, dim=-1, sorted=False
) )
# Re-normalize top-k weights to sum to 1
if self.top_k > 1 and self.moe_renormalize: if self.top_k > 1 and self.moe_renormalize:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator topk_weight = topk_weight / denominator
# Apply scaling factor
topk_weight = topk_weight * self.routed_scaling_factor topk_weight = topk_weight * self.routed_scaling_factor
# During training, we return the raw logits for the aux loss calculation # Return logits AND topk_idx for aux loss calculation
return router_logits, topk_idx, topk_weight return router_logits, topk_idx, topk_weight
else: else:
@@ -885,37 +880,34 @@ class KimiSparseMoeBlock(nn.Module):
config=config, intermediate_size=intermediate_size config=config, intermediate_size=intermediate_size
) )
def calculate_aux_loss(self, router_logits: torch.Tensor) -> torch.Tensor: def calculate_aux_loss(
self, router_logits: torch.Tensor, topk_idx: torch.Tensor
) -> Optional[torch.Tensor]:
""" """
Calculates the auxiliary load-balancing loss for the MoE layer. Calculates the auxiliary load-balancing loss for the MoE layer.
This is a critical component for stable training of MoE models. Uses the standard Switch Transformer formulation.
It encourages the router to send a balanced number of tokens to each expert.
The loss is a combination of:
1. A loss that encourages the router to distribute tokens evenly.
2. A loss that encourages the router logits to have a small magnitude (z-loss).
""" """
if router_logits is None or not self.training: if router_logits is None or not self.training:
# Return a zero tensor without accessing router_logits attributes when it's None return None
return torch.zeros(1, requires_grad=False)[
0
] # Returns a scalar zero tensor
num_tokens, num_experts = router_logits.shape num_tokens, num_experts = router_logits.shape
# Calculate the probabilities and their mean across all tokens # P_i: Mean router probability per expert (differentiable)
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
mean_router_probs_per_expert = torch.mean(router_probs, dim=0) mean_router_prob_per_expert = torch.mean(router_probs, dim=0)
# Calculate the fraction of tokens dispatched to each expert # f_i: Fraction of tokens routed to each expert (non-differentiable)
# Create a one-hot representation of the router's choices expert_counts = torch.zeros(
# For top_k > 1, a token is "dispatched" to all its chosen experts num_experts, device=router_logits.device, dtype=torch.float32
# We can approximate this with router_probs during training for a differentiable loss )
tokens_per_expert = torch.mean(router_probs, dim=0) flat_topk_idx = topk_idx.view(-1)
ones = torch.ones_like(flat_topk_idx, dtype=torch.float32)
expert_counts.scatter_add_(0, flat_topk_idx, ones)
tokens_per_expert_fraction = expert_counts / flat_topk_idx.numel()
# The standard load balancing loss from the Switch Transformer paper # Load balancing loss: L = N * Σ(f_i * P_i)
load_balancing_loss = num_experts * torch.sum( load_balancing_loss = num_experts * torch.sum(
mean_router_probs_per_expert * tokens_per_expert tokens_per_expert_fraction * mean_router_prob_per_expert
) )
return self.router_aux_loss_coef * load_balancing_loss return self.router_aux_loss_coef * load_balancing_loss
@@ -973,24 +965,18 @@ class KimiSparseMoeBlock(nn.Module):
# # Return the final states and the auxiliary loss # # Return the final states and the auxiliary loss
# return final_hidden_states, aux_loss # return final_hidden_states, aux_loss
def forward(self, hidden_states: torch.Tensor): def forward(self, hidden_states: torch.Tensor):
"""
Optimized forward pass for MoE training that avoids materializing all expert outputs at once.
"""
identity = hidden_states identity = hidden_states
batch_size, seq_len, hidden_dim = hidden_states.shape batch_size, seq_len, hidden_dim = hidden_states.shape
num_tokens = batch_size * seq_len num_tokens = batch_size * seq_len
# Reshape for routing
hidden_states = hidden_states.view(num_tokens, hidden_dim) hidden_states = hidden_states.view(num_tokens, hidden_dim)
# Get routing decisions from the gate
# router_logits is None during inference
router_logits, topk_idx, topk_weight = self.gate( router_logits, topk_idx, topk_weight = self.gate(
hidden_states.view(batch_size, seq_len, hidden_dim) hidden_states.view(batch_size, seq_len, hidden_dim)
) )
# Calculate auxiliary loss (will be 0.0 during inference) # Pass topk_idx to aux loss calculation
aux_loss = self.calculate_aux_loss(router_logits) aux_loss = self.calculate_aux_loss(router_logits, topk_idx)
if self.training: if self.training:
# ================================================================================= # =================================================================================