From 0295df5bca97a093cdb46d143004d7dd787d99fd Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 18 Sep 2025 12:10:46 -0400 Subject: [PATCH] precompute fuse --- src/axolotl/kernels/moe/torch_grouped.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index c2a33e455..797c499d0 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -40,9 +40,9 @@ class _GroupedWeightStorage: gate: torch.Tensor up: torch.Tensor down: torch.Tensor + fused_gate_up: torch.Tensor dtype: torch.dtype device: torch.device - base_gate: Optional[torch.Tensor] = None def _ensure_grouped_weights( @@ -102,12 +102,14 @@ def _ensure_grouped_weights( mod.w2.weight.detach_() mod.w2.weight.set_(down[idx]) + fused = torch.cat((gate, up), dim=1).contiguous() return _store( _GroupedWeightStorage( pattern=pattern, gate=gate, up=up, down=down, + fused_gate_up=fused, dtype=gate.dtype, device=gate.device, ) @@ -154,9 +156,9 @@ def _ensure_grouped_weights( gate=gate, up=up, down=down, + fused_gate_up=gate_full, dtype=gate.dtype, device=gate.device, - base_gate=gate_full, ) ) @@ -205,12 +207,14 @@ def _ensure_grouped_weights( mod.down_proj.weight.detach_() mod.down_proj.weight.set_(down[idx]) + fused = torch.cat((gate, up), dim=1).contiguous() return _store( _GroupedWeightStorage( pattern=pattern, gate=gate, up=up, down=down, + fused_gate_up=fused, dtype=gate.dtype, device=gate.device, ) @@ -248,8 +252,8 @@ def moe_ffn_forward_grouped( sample_mod = expert_impls[0] storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod) w_gate = storage.gate - w_up = storage.up w2 = storage.down + w_gate_up = storage.fused_gate_up x_flat = hidden_states.view(tokens, hdim).to(expert_dtype) router_logits = gate_linear(x_flat.to(routing_dtype)) @@ -289,15 +293,12 @@ def moe_ffn_forward_grouped( zero = torch.zeros_like(x_flat) return zero.view(bsz, seqlen, hdim), router_logits - w_gate_t = w_gate[active_idx].transpose(-2, -1) - w_up_t = w_up[active_idx].transpose(-2, -1) + w_gate_up_t = w_gate_up[active_idx].transpose(-2, -1) w2_t = w2[active_idx].transpose(-2, -1) routed_in = routed_input.to(expert_dtype) - gate_up_out = torch._grouped_mm( - routed_in, torch.cat((w_gate_t, w_up_t), dim=-1), offs=offsets - ) - inter_dim = w_gate_t.shape[-1] + gate_up_out = torch._grouped_mm(routed_in, w_gate_up_t, offs=offsets) + inter_dim = w_gate.shape[1] gate_out = torch.ops.aten.silu_(gate_up_out[..., :inter_dim]) gate_out.mul_(gate_up_out[..., inter_dim:]) down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets)