Use in-place ops in moe_ffn_forward_grouped
This commit is contained in:
@@ -295,13 +295,13 @@ def moe_ffn_forward_grouped(
|
|||||||
|
|
||||||
routed_in = routed_input.to(expert_dtype)
|
routed_in = routed_input.to(expert_dtype)
|
||||||
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
|
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
|
||||||
up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
|
gate_out = torch.ops.aten.silu_(gate_out)
|
||||||
activated = F.silu(gate_out) * up_out
|
gate_out.mul_(torch._grouped_mm(routed_in, w_up_t, offs=offsets))
|
||||||
down_out = torch._grouped_mm(activated, w2_t, offs=offsets)
|
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets)
|
||||||
|
|
||||||
weights_fp32 = scores_sorted.unsqueeze(-1).to(torch.float32)
|
weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
|
||||||
weighted = (down_out.to(torch.float32) * weights_fp32).to(expert_dtype)
|
down_out.mul_(weights)
|
||||||
|
|
||||||
combined = torch.zeros_like(x_flat)
|
combined = torch.zeros_like(x_flat)
|
||||||
combined.scatter_add_(0, gather_index, weighted)
|
combined.scatter_add_(0, gather_index, down_out)
|
||||||
return combined.view(bsz, seqlen, hdim), router_logits
|
return combined.view(bsz, seqlen, hdim), router_logits
|
||||||
|
|||||||
Reference in New Issue
Block a user