fix: aux loss calculation
This commit is contained in:
@@ -647,24 +647,19 @@ class KimiMoEGate(nn.Module):
|
||||
|
||||
if self.training:
|
||||
# Training path: standard, differentiable top-k routing
|
||||
|
||||
# Use softmax for probabilities, as it's standard for MoE routing loss
|
||||
gating_scores = F.softmax(router_logits, dim=-1, dtype=torch.float32)
|
||||
|
||||
# Get top-k scores and their indices
|
||||
topk_weight, topk_idx = torch.topk(
|
||||
gating_scores, self.top_k, dim=-1, sorted=False
|
||||
)
|
||||
|
||||
# Re-normalize top-k weights to sum to 1
|
||||
if self.top_k > 1 and self.moe_renormalize:
|
||||
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weight = topk_weight / denominator
|
||||
|
||||
# Apply scaling factor
|
||||
topk_weight = topk_weight * self.routed_scaling_factor
|
||||
|
||||
# During training, we return the raw logits for the aux loss calculation
|
||||
# Return logits AND topk_idx for aux loss calculation
|
||||
return router_logits, topk_idx, topk_weight
|
||||
|
||||
else:
|
||||
@@ -885,37 +880,34 @@ class KimiSparseMoeBlock(nn.Module):
|
||||
config=config, intermediate_size=intermediate_size
|
||||
)
|
||||
|
||||
def calculate_aux_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
|
||||
def calculate_aux_loss(
|
||||
self, router_logits: torch.Tensor, topk_idx: torch.Tensor
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Calculates the auxiliary load-balancing loss for the MoE layer.
|
||||
This is a critical component for stable training of MoE models.
|
||||
It encourages the router to send a balanced number of tokens to each expert.
|
||||
|
||||
The loss is a combination of:
|
||||
1. A loss that encourages the router to distribute tokens evenly.
|
||||
2. A loss that encourages the router logits to have a small magnitude (z-loss).
|
||||
Uses the standard Switch Transformer formulation.
|
||||
"""
|
||||
if router_logits is None or not self.training:
|
||||
# Return a zero tensor without accessing router_logits attributes when it's None
|
||||
return torch.zeros(1, requires_grad=False)[
|
||||
0
|
||||
] # Returns a scalar zero tensor
|
||||
return None
|
||||
|
||||
num_tokens, num_experts = router_logits.shape
|
||||
|
||||
# Calculate the probabilities and their mean across all tokens
|
||||
# P_i: Mean router probability per expert (differentiable)
|
||||
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
|
||||
mean_router_probs_per_expert = torch.mean(router_probs, dim=0)
|
||||
mean_router_prob_per_expert = torch.mean(router_probs, dim=0)
|
||||
|
||||
# Calculate the fraction of tokens dispatched to each expert
|
||||
# Create a one-hot representation of the router's choices
|
||||
# For top_k > 1, a token is "dispatched" to all its chosen experts
|
||||
# We can approximate this with router_probs during training for a differentiable loss
|
||||
tokens_per_expert = torch.mean(router_probs, dim=0)
|
||||
# f_i: Fraction of tokens routed to each expert (non-differentiable)
|
||||
expert_counts = torch.zeros(
|
||||
num_experts, device=router_logits.device, dtype=torch.float32
|
||||
)
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
ones = torch.ones_like(flat_topk_idx, dtype=torch.float32)
|
||||
expert_counts.scatter_add_(0, flat_topk_idx, ones)
|
||||
tokens_per_expert_fraction = expert_counts / flat_topk_idx.numel()
|
||||
|
||||
# The standard load balancing loss from the Switch Transformer paper
|
||||
# Load balancing loss: L = N * Σ(f_i * P_i)
|
||||
load_balancing_loss = num_experts * torch.sum(
|
||||
mean_router_probs_per_expert * tokens_per_expert
|
||||
tokens_per_expert_fraction * mean_router_prob_per_expert
|
||||
)
|
||||
|
||||
return self.router_aux_loss_coef * load_balancing_loss
|
||||
@@ -973,24 +965,18 @@ class KimiSparseMoeBlock(nn.Module):
|
||||
# # Return the final states and the auxiliary loss
|
||||
# return final_hidden_states, aux_loss
|
||||
def forward(self, hidden_states: torch.Tensor):
|
||||
"""
|
||||
Optimized forward pass for MoE training that avoids materializing all expert outputs at once.
|
||||
"""
|
||||
identity = hidden_states
|
||||
batch_size, seq_len, hidden_dim = hidden_states.shape
|
||||
num_tokens = batch_size * seq_len
|
||||
|
||||
# Reshape for routing
|
||||
hidden_states = hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
# Get routing decisions from the gate
|
||||
# router_logits is None during inference
|
||||
router_logits, topk_idx, topk_weight = self.gate(
|
||||
hidden_states.view(batch_size, seq_len, hidden_dim)
|
||||
)
|
||||
|
||||
# Calculate auxiliary loss (will be 0.0 during inference)
|
||||
aux_loss = self.calculate_aux_loss(router_logits)
|
||||
# Pass topk_idx to aux loss calculation
|
||||
aux_loss = self.calculate_aux_loss(router_logits, topk_idx)
|
||||
|
||||
if self.training:
|
||||
# =================================================================================
|
||||
|
||||
Reference in New Issue
Block a user