precompute fuse

This commit is contained in:
Dan Saunders
2025-09-18 12:10:46 -04:00
parent b39ef54833
commit 0295df5bca

View File

@@ -40,9 +40,9 @@ class _GroupedWeightStorage:
gate: torch.Tensor
up: torch.Tensor
down: torch.Tensor
fused_gate_up: torch.Tensor
dtype: torch.dtype
device: torch.device
base_gate: Optional[torch.Tensor] = None
def _ensure_grouped_weights(
@@ -102,12 +102,14 @@ def _ensure_grouped_weights(
mod.w2.weight.detach_()
mod.w2.weight.set_(down[idx])
fused = torch.cat((gate, up), dim=1).contiguous()
return _store(
_GroupedWeightStorage(
pattern=pattern,
gate=gate,
up=up,
down=down,
fused_gate_up=fused,
dtype=gate.dtype,
device=gate.device,
)
@@ -154,9 +156,9 @@ def _ensure_grouped_weights(
gate=gate,
up=up,
down=down,
fused_gate_up=gate_full,
dtype=gate.dtype,
device=gate.device,
base_gate=gate_full,
)
)
@@ -205,12 +207,14 @@ def _ensure_grouped_weights(
mod.down_proj.weight.detach_()
mod.down_proj.weight.set_(down[idx])
fused = torch.cat((gate, up), dim=1).contiguous()
return _store(
_GroupedWeightStorage(
pattern=pattern,
gate=gate,
up=up,
down=down,
fused_gate_up=fused,
dtype=gate.dtype,
device=gate.device,
)
@@ -248,8 +252,8 @@ def moe_ffn_forward_grouped(
sample_mod = expert_impls[0]
storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
w_gate = storage.gate
w_up = storage.up
w2 = storage.down
w_gate_up = storage.fused_gate_up
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
router_logits = gate_linear(x_flat.to(routing_dtype))
@@ -289,15 +293,12 @@ def moe_ffn_forward_grouped(
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
w_gate_t = w_gate[active_idx].transpose(-2, -1)
w_up_t = w_up[active_idx].transpose(-2, -1)
w_gate_up_t = w_gate_up[active_idx].transpose(-2, -1)
w2_t = w2[active_idx].transpose(-2, -1)
routed_in = routed_input.to(expert_dtype)
gate_up_out = torch._grouped_mm(
routed_in, torch.cat((w_gate_t, w_up_t), dim=-1), offs=offsets
)
inter_dim = w_gate_t.shape[-1]
gate_up_out = torch._grouped_mm(routed_in, w_gate_up_t, offs=offsets)
inter_dim = w_gate.shape[1]
gate_out = torch.ops.aten.silu_(gate_up_out[..., :inter_dim])
gate_out.mul_(gate_up_out[..., inter_dim:])
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets)