precompute fuse
This commit is contained in:
@@ -40,9 +40,9 @@ class _GroupedWeightStorage:
|
|||||||
gate: torch.Tensor
|
gate: torch.Tensor
|
||||||
up: torch.Tensor
|
up: torch.Tensor
|
||||||
down: torch.Tensor
|
down: torch.Tensor
|
||||||
|
fused_gate_up: torch.Tensor
|
||||||
dtype: torch.dtype
|
dtype: torch.dtype
|
||||||
device: torch.device
|
device: torch.device
|
||||||
base_gate: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
|
|
||||||
def _ensure_grouped_weights(
|
def _ensure_grouped_weights(
|
||||||
@@ -102,12 +102,14 @@ def _ensure_grouped_weights(
|
|||||||
mod.w2.weight.detach_()
|
mod.w2.weight.detach_()
|
||||||
mod.w2.weight.set_(down[idx])
|
mod.w2.weight.set_(down[idx])
|
||||||
|
|
||||||
|
fused = torch.cat((gate, up), dim=1).contiguous()
|
||||||
return _store(
|
return _store(
|
||||||
_GroupedWeightStorage(
|
_GroupedWeightStorage(
|
||||||
pattern=pattern,
|
pattern=pattern,
|
||||||
gate=gate,
|
gate=gate,
|
||||||
up=up,
|
up=up,
|
||||||
down=down,
|
down=down,
|
||||||
|
fused_gate_up=fused,
|
||||||
dtype=gate.dtype,
|
dtype=gate.dtype,
|
||||||
device=gate.device,
|
device=gate.device,
|
||||||
)
|
)
|
||||||
@@ -154,9 +156,9 @@ def _ensure_grouped_weights(
|
|||||||
gate=gate,
|
gate=gate,
|
||||||
up=up,
|
up=up,
|
||||||
down=down,
|
down=down,
|
||||||
|
fused_gate_up=gate_full,
|
||||||
dtype=gate.dtype,
|
dtype=gate.dtype,
|
||||||
device=gate.device,
|
device=gate.device,
|
||||||
base_gate=gate_full,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -205,12 +207,14 @@ def _ensure_grouped_weights(
|
|||||||
mod.down_proj.weight.detach_()
|
mod.down_proj.weight.detach_()
|
||||||
mod.down_proj.weight.set_(down[idx])
|
mod.down_proj.weight.set_(down[idx])
|
||||||
|
|
||||||
|
fused = torch.cat((gate, up), dim=1).contiguous()
|
||||||
return _store(
|
return _store(
|
||||||
_GroupedWeightStorage(
|
_GroupedWeightStorage(
|
||||||
pattern=pattern,
|
pattern=pattern,
|
||||||
gate=gate,
|
gate=gate,
|
||||||
up=up,
|
up=up,
|
||||||
down=down,
|
down=down,
|
||||||
|
fused_gate_up=fused,
|
||||||
dtype=gate.dtype,
|
dtype=gate.dtype,
|
||||||
device=gate.device,
|
device=gate.device,
|
||||||
)
|
)
|
||||||
@@ -248,8 +252,8 @@ def moe_ffn_forward_grouped(
|
|||||||
sample_mod = expert_impls[0]
|
sample_mod = expert_impls[0]
|
||||||
storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
|
storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
|
||||||
w_gate = storage.gate
|
w_gate = storage.gate
|
||||||
w_up = storage.up
|
|
||||||
w2 = storage.down
|
w2 = storage.down
|
||||||
|
w_gate_up = storage.fused_gate_up
|
||||||
|
|
||||||
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
|
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
|
||||||
router_logits = gate_linear(x_flat.to(routing_dtype))
|
router_logits = gate_linear(x_flat.to(routing_dtype))
|
||||||
@@ -289,15 +293,12 @@ def moe_ffn_forward_grouped(
|
|||||||
zero = torch.zeros_like(x_flat)
|
zero = torch.zeros_like(x_flat)
|
||||||
return zero.view(bsz, seqlen, hdim), router_logits
|
return zero.view(bsz, seqlen, hdim), router_logits
|
||||||
|
|
||||||
w_gate_t = w_gate[active_idx].transpose(-2, -1)
|
w_gate_up_t = w_gate_up[active_idx].transpose(-2, -1)
|
||||||
w_up_t = w_up[active_idx].transpose(-2, -1)
|
|
||||||
w2_t = w2[active_idx].transpose(-2, -1)
|
w2_t = w2[active_idx].transpose(-2, -1)
|
||||||
|
|
||||||
routed_in = routed_input.to(expert_dtype)
|
routed_in = routed_input.to(expert_dtype)
|
||||||
gate_up_out = torch._grouped_mm(
|
gate_up_out = torch._grouped_mm(routed_in, w_gate_up_t, offs=offsets)
|
||||||
routed_in, torch.cat((w_gate_t, w_up_t), dim=-1), offs=offsets
|
inter_dim = w_gate.shape[1]
|
||||||
)
|
|
||||||
inter_dim = w_gate_t.shape[-1]
|
|
||||||
gate_out = torch.ops.aten.silu_(gate_up_out[..., :inter_dim])
|
gate_out = torch.ops.aten.silu_(gate_up_out[..., :inter_dim])
|
||||||
gate_out.mul_(gate_up_out[..., inter_dim:])
|
gate_out.mul_(gate_up_out[..., inter_dim:])
|
||||||
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets)
|
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets)
|
||||||
|
|||||||
Reference in New Issue
Block a user