fix compile
This commit is contained in:
@@ -342,9 +342,6 @@ def moe_ffn_forward_grouped(
|
|||||||
|
|
||||||
counts_i32 = assignments.to(device=device, dtype=torch.int32)
|
counts_i32 = assignments.to(device=device, dtype=torch.int32)
|
||||||
offsets = torch.cumsum(counts_i32, dim=0).to(dtype=torch.int32)
|
offsets = torch.cumsum(counts_i32, dim=0).to(dtype=torch.int32)
|
||||||
if not torch.is_nonzero(offsets[-1]):
|
|
||||||
zero = torch.zeros_like(x_flat)
|
|
||||||
return zero.view(bsz, seqlen, hdim), router_logits
|
|
||||||
mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype
|
mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype
|
||||||
routed_in = routed_input.to(mm_dtype)
|
routed_in = routed_input.to(mm_dtype)
|
||||||
w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)
|
w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user