This commit is contained in:
Dan Saunders
2025-09-18 11:29:33 -04:00
parent 7d867de9b2
commit 2a176e4923

View File

@@ -134,6 +134,7 @@ def moe_ffn_forward_grouped(
counts_active = assignments[active_idx]
offsets = torch.cumsum(counts_active.to(device=device, dtype=torch.int32), dim=0)
offsets = offsets.to(torch.int32)
if offsets[-1].item() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits