diff --git a/src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py b/src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py
index e8cece256..577933ecf 100644
--- a/src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py
+++ b/src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py
@@ -647,24 +647,19 @@ class KimiMoEGate(nn.Module):
 
         if self.training:
             # Training path: standard, differentiable top-k routing
-
-            # Use softmax for probabilities, as it's standard for MoE routing loss
             gating_scores = F.softmax(router_logits, dim=-1, dtype=torch.float32)
 
-            # Get top-k scores and their indices
             topk_weight, topk_idx = torch.topk(
                 gating_scores, self.top_k, dim=-1, sorted=False
             )
 
-            # Re-normalize top-k weights to sum to 1
             if self.top_k > 1 and self.moe_renormalize:
                 denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
                 topk_weight = topk_weight / denominator
 
-            # Apply scaling factor
             topk_weight = topk_weight * self.routed_scaling_factor
 
-            # During training, we return the raw logits for the aux loss calculation
+            # Return logits AND topk_idx for aux loss calculation
             return router_logits, topk_idx, topk_weight
         else:
@@ -885,37 +880,34 @@ class KimiSparseMoeBlock(nn.Module):
                 config=config, intermediate_size=intermediate_size
             )
 
-    def calculate_aux_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
+    def calculate_aux_loss(
+        self, router_logits: torch.Tensor, topk_idx: torch.Tensor
+    ) -> Optional[torch.Tensor]:
         """
         Calculates the auxiliary load-balancing loss for the MoE layer.
-        This is a critical component for stable training of MoE models.
-        It encourages the router to send a balanced number of tokens to each expert.
-
-        The loss is a combination of:
-        1. A loss that encourages the router to distribute tokens evenly.
-        2. A loss that encourages the router logits to have a small magnitude (z-loss).
+        Uses the standard Switch Transformer formulation.
         """
         if router_logits is None or not self.training:
-            # Return a zero tensor without accessing router_logits attributes when it's None
-            return torch.zeros(1, requires_grad=False)[
-                0
-            ]  # Returns a scalar zero tensor
+            return None
 
         num_tokens, num_experts = router_logits.shape
 
-        # Calculate the probabilities and their mean across all tokens
+        # P_i: Mean router probability per expert (differentiable)
         router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
-        mean_router_probs_per_expert = torch.mean(router_probs, dim=0)
+        mean_router_prob_per_expert = torch.mean(router_probs, dim=0)
 
-        # Calculate the fraction of tokens dispatched to each expert
-        # Create a one-hot representation of the router's choices
-        # For top_k > 1, a token is "dispatched" to all its chosen experts
-        # We can approximate this with router_probs during training for a differentiable loss
-        tokens_per_expert = torch.mean(router_probs, dim=0)
+        # f_i: Fraction of tokens routed to each expert (non-differentiable)
+        expert_counts = torch.zeros(
+            num_experts, device=router_logits.device, dtype=torch.float32
+        )
+        flat_topk_idx = topk_idx.view(-1)
+        ones = torch.ones_like(flat_topk_idx, dtype=torch.float32)
+        expert_counts.scatter_add_(0, flat_topk_idx, ones)
+        tokens_per_expert_fraction = expert_counts / flat_topk_idx.numel()
 
-        # The standard load balancing loss from the Switch Transformer paper
+        # Load balancing loss: L = N * Σ(f_i * P_i)
         load_balancing_loss = num_experts * torch.sum(
-            mean_router_probs_per_expert * tokens_per_expert
+            tokens_per_expert_fraction * mean_router_prob_per_expert
         )
 
         return self.router_aux_loss_coef * load_balancing_loss
@@ -973,24 +965,18 @@ class KimiSparseMoeBlock(nn.Module):
     #     # Return the final states and the auxiliary loss
     #     return final_hidden_states, aux_loss
 
     def forward(self, hidden_states: torch.Tensor):
-        """
-        Optimized forward pass for MoE training that avoids materializing all expert outputs at once.
-        """
         identity = hidden_states
         batch_size, seq_len, hidden_dim = hidden_states.shape
         num_tokens = batch_size * seq_len
 
-        # Reshape for routing
         hidden_states = hidden_states.view(num_tokens, hidden_dim)
 
-        # Get routing decisions from the gate
-        # router_logits is None during inference
         router_logits, topk_idx, topk_weight = self.gate(
             hidden_states.view(batch_size, seq_len, hidden_dim)
        )
-        # Calculate auxiliary loss (will be 0.0 during inference)
-        aux_loss = self.calculate_aux_loss(router_logits)
+        # Pass topk_idx to aux loss calculation
+        aux_loss = self.calculate_aux_loss(router_logits, topk_idx)
 
         if self.training:
             # =================================================================================