diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 9201d17a3..d5d311b6b 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -134,6 +134,7 @@ def moe_ffn_forward_grouped( counts_active = assignments[active_idx] offsets = torch.cumsum(counts_active.to(device=device, dtype=torch.int32), dim=0) + offsets = offsets.to(torch.int32) if offsets[-1].item() == 0: zero = torch.zeros_like(x_flat) return zero.view(bsz, seqlen, hdim), router_logits