diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 68ff631bc..dd434d7a8 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -342,7 +342,7 @@ def moe_ffn_forward_grouped( counts_i32 = assignments.to(device=device, dtype=torch.int32) offsets = torch.cumsum(counts_i32, dim=0).to(dtype=torch.int32) - if offsets[-1].item() == 0: + if not torch.is_nonzero(offsets[-1]): zero = torch.zeros_like(x_flat) return zero.view(bsz, seqlen, hdim), router_logits mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype