remove contig
This commit is contained in:
@@ -289,9 +289,9 @@ def moe_ffn_forward_grouped(
|
|||||||
zero = torch.zeros_like(x_flat)
|
zero = torch.zeros_like(x_flat)
|
||||||
return zero.view(bsz, seqlen, hdim), router_logits
|
return zero.view(bsz, seqlen, hdim), router_logits
|
||||||
|
|
||||||
w_gate_t = w_gate[active_idx].transpose(-2, -1).contiguous()
|
w_gate_t = w_gate[active_idx].transpose(-2, -1)
|
||||||
w_up_t = w_up[active_idx].transpose(-2, -1).contiguous()
|
w_up_t = w_up[active_idx].transpose(-2, -1)
|
||||||
w2_t = w2[active_idx].transpose(-2, -1).contiguous()
|
w2_t = w2[active_idx].transpose(-2, -1)
|
||||||
|
|
||||||
routed_in = routed_input.to(expert_dtype)
|
routed_in = routed_input.to(expert_dtype)
|
||||||
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
|
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
|
||||||
|
|||||||
Reference in New Issue
Block a user