fix: device for tensor
This commit is contained in:
@@ -896,9 +896,10 @@ class KimiSparseMoeBlock(nn.Module):
|
|||||||
2. A loss that encourages the router logits to have a small magnitude (z-loss).
|
2. A loss that encourages the router logits to have a small magnitude (z-loss).
|
||||||
"""
|
"""
|
||||||
if router_logits is None or not self.training:
|
if router_logits is None or not self.training:
|
||||||
return torch.tensor(
|
# Return a zero tensor without accessing router_logits attributes when it's None
|
||||||
0.0, device=router_logits.device, dtype=router_logits.dtype
|
return torch.zeros(1, requires_grad=False)[
|
||||||
)
|
0
|
||||||
|
] # Returns a scalar zero tensor
|
||||||
|
|
||||||
num_tokens, num_experts = router_logits.shape
|
num_tokens, num_experts = router_logits.shape
|
||||||
|
|
||||||
@@ -1074,7 +1075,7 @@ class KimiSparseMoeBlock(nn.Module):
|
|||||||
).sum(dim=0)
|
).sum(dim=0)
|
||||||
|
|
||||||
# `split_indices` will be [count_e0, count_e0+count_e1, ...]
|
# `split_indices` will be [count_e0, count_e0+count_e1, ...]
|
||||||
split_indices = tokens_per_expert.cumsum(dim=0)
|
# split_indices = tokens_per_expert.cumsum(dim=0)
|
||||||
|
|
||||||
# Process tokens expert by expert
|
# Process tokens expert by expert
|
||||||
expert_outputs = []
|
expert_outputs = []
|
||||||
|
|||||||
Reference in New Issue
Block a user