precompute fuse
This commit is contained in:
@@ -40,9 +40,9 @@ class _GroupedWeightStorage:
|
||||
gate: torch.Tensor
|
||||
up: torch.Tensor
|
||||
down: torch.Tensor
|
||||
fused_gate_up: torch.Tensor
|
||||
dtype: torch.dtype
|
||||
device: torch.device
|
||||
base_gate: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def _ensure_grouped_weights(
|
||||
@@ -102,12 +102,14 @@ def _ensure_grouped_weights(
|
||||
mod.w2.weight.detach_()
|
||||
mod.w2.weight.set_(down[idx])
|
||||
|
||||
fused = torch.cat((gate, up), dim=1).contiguous()
|
||||
return _store(
|
||||
_GroupedWeightStorage(
|
||||
pattern=pattern,
|
||||
gate=gate,
|
||||
up=up,
|
||||
down=down,
|
||||
fused_gate_up=fused,
|
||||
dtype=gate.dtype,
|
||||
device=gate.device,
|
||||
)
|
||||
@@ -154,9 +156,9 @@ def _ensure_grouped_weights(
|
||||
gate=gate,
|
||||
up=up,
|
||||
down=down,
|
||||
fused_gate_up=gate_full,
|
||||
dtype=gate.dtype,
|
||||
device=gate.device,
|
||||
base_gate=gate_full,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -205,12 +207,14 @@ def _ensure_grouped_weights(
|
||||
mod.down_proj.weight.detach_()
|
||||
mod.down_proj.weight.set_(down[idx])
|
||||
|
||||
fused = torch.cat((gate, up), dim=1).contiguous()
|
||||
return _store(
|
||||
_GroupedWeightStorage(
|
||||
pattern=pattern,
|
||||
gate=gate,
|
||||
up=up,
|
||||
down=down,
|
||||
fused_gate_up=fused,
|
||||
dtype=gate.dtype,
|
||||
device=gate.device,
|
||||
)
|
||||
@@ -248,8 +252,8 @@ def moe_ffn_forward_grouped(
|
||||
sample_mod = expert_impls[0]
|
||||
storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
|
||||
w_gate = storage.gate
|
||||
w_up = storage.up
|
||||
w2 = storage.down
|
||||
w_gate_up = storage.fused_gate_up
|
||||
|
||||
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
|
||||
router_logits = gate_linear(x_flat.to(routing_dtype))
|
||||
@@ -289,15 +293,12 @@ def moe_ffn_forward_grouped(
|
||||
zero = torch.zeros_like(x_flat)
|
||||
return zero.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
w_gate_t = w_gate[active_idx].transpose(-2, -1)
|
||||
w_up_t = w_up[active_idx].transpose(-2, -1)
|
||||
w_gate_up_t = w_gate_up[active_idx].transpose(-2, -1)
|
||||
w2_t = w2[active_idx].transpose(-2, -1)
|
||||
|
||||
routed_in = routed_input.to(expert_dtype)
|
||||
gate_up_out = torch._grouped_mm(
|
||||
routed_in, torch.cat((w_gate_t, w_up_t), dim=-1), offs=offsets
|
||||
)
|
||||
inter_dim = w_gate_t.shape[-1]
|
||||
gate_up_out = torch._grouped_mm(routed_in, w_gate_up_t, offs=offsets)
|
||||
inter_dim = w_gate.shape[1]
|
||||
gate_out = torch.ops.aten.silu_(gate_up_out[..., :inter_dim])
|
||||
gate_out.mul_(gate_up_out[..., inter_dim:])
|
||||
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets)
|
||||
|
||||
Reference in New Issue
Block a user