fix compile

This commit is contained in:
Dan Saunders
2025-09-19 13:52:42 -04:00
parent 7327144344
commit b5dc58373f

View File

@@ -342,7 +342,7 @@ def moe_ffn_forward_grouped(
counts_i32 = assignments.to(device=device, dtype=torch.int32)
offsets = torch.cumsum(counts_i32, dim=0).to(dtype=torch.int32)
if offsets[-1].item() == 0:
if not torch.is_nonzero(offsets[-1]):
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype