diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index dd434d7a8..40dc77b74 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -342,9 +342,6 @@ def moe_ffn_forward_grouped( counts_i32 = assignments.to(device=device, dtype=torch.int32) offsets = torch.cumsum(counts_i32, dim=0).to(dtype=torch.int32) - if not torch.is_nonzero(offsets[-1]): - zero = torch.zeros_like(x_flat) - return zero.view(bsz, seqlen, hdim), router_logits mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype routed_in = routed_input.to(mm_dtype) w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)