fix: do not read device/dtype from router_logits when it is None
This commit is contained in:
@@ -896,9 +896,10 @@ class KimiSparseMoeBlock(nn.Module):
|
||||
2. A loss that encourages the router logits to have a small magnitude (z-loss).
|
||||
"""
|
||||
if router_logits is None or not self.training:
|
||||
return torch.tensor(
|
||||
0.0, device=router_logits.device, dtype=router_logits.dtype
|
||||
)
|
||||
# Return a zero tensor without accessing router_logits attributes when it's None
|
||||
return torch.zeros(1, requires_grad=False)[
|
||||
0
|
||||
] # Returns a scalar zero tensor
|
||||
|
||||
num_tokens, num_experts = router_logits.shape
|
||||
|
||||
@@ -1074,7 +1075,7 @@ class KimiSparseMoeBlock(nn.Module):
|
||||
).sum(dim=0)
|
||||
|
||||
# `split_indices` will be [count_e0, count_e0+count_e1, ...]
|
||||
split_indices = tokens_per_expert.cumsum(dim=0)
|
||||
# split_indices = tokens_per_expert.cumsum(dim=0)
|
||||
|
||||
# Process tokens expert by expert
|
||||
expert_outputs = []
|
||||
|
||||
Reference in New Issue
Block a user