diff --git a/src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py b/src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py index 5c01474d4..e8cece256 100644 --- a/src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py +++ b/src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py @@ -896,9 +896,10 @@ class KimiSparseMoeBlock(nn.Module): 2. A loss that encourages the router logits to have a small magnitude (z-loss). """ if router_logits is None or not self.training: - return torch.tensor( - 0.0, device=router_logits.device, dtype=router_logits.dtype - ) + # Return a zero tensor without accessing router_logits attributes when it's None + return torch.zeros(1, requires_grad=False)[ + 0 + ] # Returns a scalar zero tensor num_tokens, num_experts = router_logits.shape @@ -1074,7 +1075,7 @@ class KimiSparseMoeBlock(nn.Module): ).sum(dim=0) # `split_indices` will be [count_e0, count_e0+count_e1, ...] - split_indices = tokens_per_expert.cumsum(dim=0) + # split_indices = tokens_per_expert.cumsum(dim=0) # Process tokens expert by expert expert_outputs = []