combine mult
This commit is contained in:
@@ -294,9 +294,12 @@ def moe_ffn_forward_grouped(
|
|||||||
w2_t = w2[active_idx].transpose(-2, -1)
|
w2_t = w2[active_idx].transpose(-2, -1)
|
||||||
|
|
||||||
routed_in = routed_input.to(expert_dtype)
|
routed_in = routed_input.to(expert_dtype)
|
||||||
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
|
gate_up_out = torch._grouped_mm(
|
||||||
gate_out = torch.ops.aten.silu_(gate_out)
|
routed_in, torch.cat((w_gate_t, w_up_t), dim=-1), offs=offsets
|
||||||
gate_out.mul_(torch._grouped_mm(routed_in, w_up_t, offs=offsets))
|
)
|
||||||
|
inter_dim = w_gate_t.shape[-1]
|
||||||
|
gate_out = torch.ops.aten.silu_(gate_up_out[..., :inter_dim])
|
||||||
|
gate_out.mul_(gate_up_out[..., inter_dim:])
|
||||||
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets)
|
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets)
|
||||||
|
|
||||||
weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
|
weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user