fix
This commit is contained in:
@@ -134,6 +134,7 @@ def moe_ffn_forward_grouped(
|
|||||||
|
|
||||||
counts_active = assignments[active_idx]
|
counts_active = assignments[active_idx]
|
||||||
offsets = torch.cumsum(counts_active.to(device=device, dtype=torch.int32), dim=0)
|
offsets = torch.cumsum(counts_active.to(device=device, dtype=torch.int32), dim=0)
|
||||||
|
offsets = offsets.to(torch.int32)
|
||||||
if offsets[-1].item() == 0:
|
if offsets[-1].item() == 0:
|
||||||
zero = torch.zeros_like(x_flat)
|
zero = torch.zeros_like(x_flat)
|
||||||
return zero.view(bsz, seqlen, hdim), router_logits
|
return zero.view(bsz, seqlen, hdim), router_logits
|
||||||
|
|||||||
Reference in New Issue
Block a user