fix: aux loss calculation
This commit is contained in:
@@ -647,24 +647,19 @@ class KimiMoEGate(nn.Module):
|
|||||||
|
|
||||||
if self.training:
|
if self.training:
|
||||||
# Training path: standard, differentiable top-k routing
|
# Training path: standard, differentiable top-k routing
|
||||||
|
|
||||||
# Use softmax for probabilities, as it's standard for MoE routing loss
|
|
||||||
gating_scores = F.softmax(router_logits, dim=-1, dtype=torch.float32)
|
gating_scores = F.softmax(router_logits, dim=-1, dtype=torch.float32)
|
||||||
|
|
||||||
# Get top-k scores and their indices
|
|
||||||
topk_weight, topk_idx = torch.topk(
|
topk_weight, topk_idx = torch.topk(
|
||||||
gating_scores, self.top_k, dim=-1, sorted=False
|
gating_scores, self.top_k, dim=-1, sorted=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Re-normalize top-k weights to sum to 1
|
|
||||||
if self.top_k > 1 and self.moe_renormalize:
|
if self.top_k > 1 and self.moe_renormalize:
|
||||||
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
||||||
topk_weight = topk_weight / denominator
|
topk_weight = topk_weight / denominator
|
||||||
|
|
||||||
# Apply scaling factor
|
|
||||||
topk_weight = topk_weight * self.routed_scaling_factor
|
topk_weight = topk_weight * self.routed_scaling_factor
|
||||||
|
|
||||||
# During training, we return the raw logits for the aux loss calculation
|
# Return logits AND topk_idx for aux loss calculation
|
||||||
return router_logits, topk_idx, topk_weight
|
return router_logits, topk_idx, topk_weight
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -885,37 +880,34 @@ class KimiSparseMoeBlock(nn.Module):
|
|||||||
config=config, intermediate_size=intermediate_size
|
config=config, intermediate_size=intermediate_size
|
||||||
)
|
)
|
||||||
|
|
||||||
def calculate_aux_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
|
def calculate_aux_loss(
|
||||||
|
self, router_logits: torch.Tensor, topk_idx: torch.Tensor
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Calculates the auxiliary load-balancing loss for the MoE layer.
|
Calculates the auxiliary load-balancing loss for the MoE layer.
|
||||||
This is a critical component for stable training of MoE models.
|
Uses the standard Switch Transformer formulation.
|
||||||
It encourages the router to send a balanced number of tokens to each expert.
|
|
||||||
|
|
||||||
The loss is a combination of:
|
|
||||||
1. A loss that encourages the router to distribute tokens evenly.
|
|
||||||
2. A loss that encourages the router logits to have a small magnitude (z-loss).
|
|
||||||
"""
|
"""
|
||||||
if router_logits is None or not self.training:
|
if router_logits is None or not self.training:
|
||||||
# Return a zero tensor without accessing router_logits attributes when it's None
|
return None
|
||||||
return torch.zeros(1, requires_grad=False)[
|
|
||||||
0
|
|
||||||
] # Returns a scalar zero tensor
|
|
||||||
|
|
||||||
num_tokens, num_experts = router_logits.shape
|
num_tokens, num_experts = router_logits.shape
|
||||||
|
|
||||||
# Calculate the probabilities and their mean across all tokens
|
# P_i: Mean router probability per expert (differentiable)
|
||||||
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
|
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
|
||||||
mean_router_probs_per_expert = torch.mean(router_probs, dim=0)
|
mean_router_prob_per_expert = torch.mean(router_probs, dim=0)
|
||||||
|
|
||||||
# Calculate the fraction of tokens dispatched to each expert
|
# f_i: Each expert's share of all top-k routing slots (expert_counts / topk_idx.numel(), so the f_i sum to 1); non-differentiable
|
||||||
# Create a one-hot representation of the router's choices
|
expert_counts = torch.zeros(
|
||||||
# For top_k > 1, a token is "dispatched" to all its chosen experts
|
num_experts, device=router_logits.device, dtype=torch.float32
|
||||||
# We can approximate this with router_probs during training for a differentiable loss
|
)
|
||||||
tokens_per_expert = torch.mean(router_probs, dim=0)
|
flat_topk_idx = topk_idx.view(-1)
|
||||||
|
ones = torch.ones_like(flat_topk_idx, dtype=torch.float32)
|
||||||
|
expert_counts.scatter_add_(0, flat_topk_idx, ones)
|
||||||
|
tokens_per_expert_fraction = expert_counts / flat_topk_idx.numel()
|
||||||
|
|
||||||
# The standard load balancing loss from the Switch Transformer paper
|
# Load balancing loss: L = N * Σ(f_i * P_i)
|
||||||
load_balancing_loss = num_experts * torch.sum(
|
load_balancing_loss = num_experts * torch.sum(
|
||||||
mean_router_probs_per_expert * tokens_per_expert
|
tokens_per_expert_fraction * mean_router_prob_per_expert
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.router_aux_loss_coef * load_balancing_loss
|
return self.router_aux_loss_coef * load_balancing_loss
|
||||||
@@ -973,24 +965,18 @@ class KimiSparseMoeBlock(nn.Module):
|
|||||||
# # Return the final states and the auxiliary loss
|
# # Return the final states and the auxiliary loss
|
||||||
# return final_hidden_states, aux_loss
|
# return final_hidden_states, aux_loss
|
||||||
def forward(self, hidden_states: torch.Tensor):
|
def forward(self, hidden_states: torch.Tensor):
|
||||||
"""
|
|
||||||
Optimized forward pass for MoE training that avoids materializing all expert outputs at once.
|
|
||||||
"""
|
|
||||||
identity = hidden_states
|
identity = hidden_states
|
||||||
batch_size, seq_len, hidden_dim = hidden_states.shape
|
batch_size, seq_len, hidden_dim = hidden_states.shape
|
||||||
num_tokens = batch_size * seq_len
|
num_tokens = batch_size * seq_len
|
||||||
|
|
||||||
# Reshape for routing
|
|
||||||
hidden_states = hidden_states.view(num_tokens, hidden_dim)
|
hidden_states = hidden_states.view(num_tokens, hidden_dim)
|
||||||
|
|
||||||
# Get routing decisions from the gate
|
|
||||||
# router_logits is None during inference
|
|
||||||
router_logits, topk_idx, topk_weight = self.gate(
|
router_logits, topk_idx, topk_weight = self.gate(
|
||||||
hidden_states.view(batch_size, seq_len, hidden_dim)
|
hidden_states.view(batch_size, seq_len, hidden_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calculate auxiliary loss (will be 0.0 during inference)
|
# Pass topk_idx to aux loss calculation
|
||||||
aux_loss = self.calculate_aux_loss(router_logits)
|
aux_loss = self.calculate_aux_loss(router_logits, topk_idx)
|
||||||
|
|
||||||
if self.training:
|
if self.training:
|
||||||
# =================================================================================
|
# =================================================================================
|
||||||
|
|||||||
Reference in New Issue
Block a user