contig
This commit is contained in:
@@ -344,10 +344,15 @@ def moe_ffn_forward_grouped(
|
|||||||
w_up_t = w_up.transpose(-2, -1).to(mm_dtype)
|
w_up_t = w_up.transpose(-2, -1).to(mm_dtype)
|
||||||
w2_t = w2.transpose(-2, -1).to(mm_dtype)
|
w2_t = w2.transpose(-2, -1).to(mm_dtype)
|
||||||
|
|
||||||
|
routed_in = routed_in.contiguous()
|
||||||
|
w_gate_t = w_gate_t.contiguous()
|
||||||
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
|
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
|
||||||
torch.ops.aten.silu_(gate_out)
|
torch.ops.aten.silu_(gate_out)
|
||||||
|
w_up_t = w_up_t.contiguous()
|
||||||
up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
|
up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
|
||||||
gate_out.mul_(up_out)
|
gate_out.mul_(up_out)
|
||||||
|
gate_out = gate_out.contiguous()
|
||||||
|
w2_t = w2_t.contiguous()
|
||||||
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets).to(expert_dtype)
|
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets).to(expert_dtype)
|
||||||
|
|
||||||
weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
|
weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user