fix: device for tensor

This commit is contained in:
NanoCode012
2025-12-22 18:40:40 +07:00
parent 9249b49a09
commit eec5342c76

View File

@@ -896,9 +896,10 @@ class KimiSparseMoeBlock(nn.Module):
2. A loss that encourages the router logits to have a small magnitude (z-loss).
"""
if router_logits is None or not self.training:
return torch.tensor(
0.0, device=router_logits.device, dtype=router_logits.dtype
)
# Return a zero tensor without accessing router_logits attributes when it's None
return torch.zeros(1, requires_grad=False)[
0
] # Returns a scalar zero tensor
num_tokens, num_experts = router_logits.shape
@@ -1074,7 +1075,7 @@ class KimiSparseMoeBlock(nn.Module):
).sum(dim=0)
# `split_indices` will be [count_e0, count_e0+count_e1, ...]
split_indices = tokens_per_expert.cumsum(dim=0)
# split_indices = tokens_per_expert.cumsum(dim=0)
# Process tokens expert by expert
expert_outputs = []