fix: device for tensor

This commit is contained in:
NanoCode012
2025-12-22 18:40:40 +07:00
parent 9249b49a09
commit eec5342c76

View File

@@ -896,9 +896,10 @@ class KimiSparseMoeBlock(nn.Module):
2. A loss that encourages the router logits to have a small magnitude (z-loss). 2. A loss that encourages the router logits to have a small magnitude (z-loss).
""" """
if router_logits is None or not self.training: if router_logits is None or not self.training:
return torch.tensor( # Return a zero tensor without accessing router_logits attributes when it's None
0.0, device=router_logits.device, dtype=router_logits.dtype return torch.zeros(1, requires_grad=False)[
) 0
] # Returns a scalar zero tensor
num_tokens, num_experts = router_logits.shape num_tokens, num_experts = router_logits.shape
@@ -1074,7 +1075,7 @@ class KimiSparseMoeBlock(nn.Module):
).sum(dim=0) ).sum(dim=0)
# `split_indices` will be [count_e0, count_e0+count_e1, ...] # `split_indices` will be [count_e0, count_e0+count_e1, ...]
split_indices = tokens_per_expert.cumsum(dim=0) # split_indices = tokens_per_expert.cumsum(dim=0)
# Process tokens expert by expert # Process tokens expert by expert
expert_outputs = [] expert_outputs = []