From efcd032fce91941049ddf00b7c1b1b563c13e516 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Thu, 18 Sep 2025 13:03:28 -0400
Subject: [PATCH] yet another refactor

---
 src/axolotl/kernels/moe/torch_grouped.py | 35 +++++++++++-------------
 1 file changed, 16 insertions(+), 19 deletions(-)

diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py
index 526e9d468..4f161b571 100644
--- a/src/axolotl/kernels/moe/torch_grouped.py
+++ b/src/axolotl/kernels/moe/torch_grouped.py
@@ -271,8 +271,8 @@ def moe_ffn_forward_grouped(
     sample_mod = expert_impls[0]
     storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
     w_gate = storage.gate
+    w_up = storage.up
     w2 = storage.down
-    w_gate_up = storage.fused_gate_up
 
     x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
     router_logits = gate_linear(x_flat.to(routing_dtype))
@@ -296,33 +296,30 @@ def moe_ffn_forward_grouped(
     token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor").contiguous()
     scores_sorted = topk_weight.reshape(-1).index_select(0, perm)
 
-    routed_input = x_flat.index_select(0, token_indices_sorted).contiguous()
+    gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
+    routed_input = torch.gather(x_flat, 0, gather_index)
-    active_idx = torch.nonzero(assignments, as_tuple=False).squeeze(-1).contiguous()
-    if active_idx.numel() == 0:
-        zero = torch.zeros_like(x_flat)
-        return zero.view(bsz, seqlen, hdim), router_logits
-
-    counts_active = assignments[active_idx]
-    counts_active_i32 = counts_active.to(device=device, dtype=torch.int32)
-    offsets = torch.cumsum(counts_active_i32, dim=0)
+    counts_i32 = assignments.to(device=device, dtype=torch.int32)
+    offsets = torch.cumsum(counts_i32, dim=0)
     if offsets[-1].item() == 0:
         zero = torch.zeros_like(x_flat)
         return zero.view(bsz, seqlen, hdim), router_logits
 
-    w_gate_up_t = w_gate_up.index_select(0, active_idx).transpose(-2, -1)
-    w2_t = w2.index_select(0, active_idx).transpose(-2, -1)
+    mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype
+    routed_in = routed_input.to(mm_dtype)
+    w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)
+    w_up_t = w_up.transpose(-2, -1).to(mm_dtype)
+    w2_t = w2.transpose(-2, -1).to(mm_dtype)
 
-    routed_in = routed_input.to(expert_dtype)
-    gate_up_out = torch._grouped_mm(routed_in, w_gate_up_t, offs=offsets)
-    inter_dim = w_gate.shape[1]
-    gate_out = torch.ops.aten.silu_(gate_up_out[..., :inter_dim])
-    gate_out.mul_(gate_up_out[..., inter_dim:])
-    down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets)
+    gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
+    torch.ops.aten.silu_(gate_out)
+    up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
+    gate_out.mul_(up_out)
+    down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets).to(expert_dtype)
 
     weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
     down_out.mul_(weights)
 
     combined = torch.zeros_like(x_flat)
-    combined.index_add_(0, token_indices_sorted, down_out)
+    combined.scatter_add_(0, gather_index, down_out)
 
     return combined.view(bsz, seqlen, hdim), router_logits
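
For reference, the routing pattern the second hunk relies on can be sketched with an ordinary per-expert loop standing in for torch._grouped_mm (a private PyTorch op whose availability depends on the build; the patch's use of it suggests `offs` holds cumulative per-group end offsets). The helper grouped_mm_ref and the toy dimensions below are illustrative assumptions, not taken from the patched module.

import torch


def grouped_mm_ref(x, w_t, offs):
    # Reference grouped GEMM (hypothetical stand-in for torch._grouped_mm):
    # rows of x are grouped by expert, group e ending at offs[e]; each group
    # is multiplied by that expert's weight matrix w_t[e].
    out = torch.empty(x.shape[0], w_t.shape[-1], dtype=x.dtype)
    start = 0
    for e in range(offs.numel()):
        end = int(offs[e])
        out[start:end] = x[start:end] @ w_t[e]
        start = end
    return out


tokens, hdim, inter_dim, num_experts, top_k = 8, 16, 32, 4, 2
x_flat = torch.randn(tokens, hdim)
w_gate = torch.randn(num_experts, inter_dim, hdim)
w_up = torch.randn(num_experts, inter_dim, hdim)
w2 = torch.randn(num_experts, hdim, inter_dim)

# Toy top-k routing: expert ids and router weights per token.
topk_ids = torch.randint(0, num_experts, (tokens, top_k))
topk_weight = torch.rand(tokens, top_k)

# Sort (token, expert) pairs by expert so each expert's rows are contiguous,
# then build cumulative per-expert end offsets, mirroring the patch.
perm = topk_ids.reshape(-1).argsort()
token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor")
scores_sorted = topk_weight.reshape(-1).index_select(0, perm)
assignments = torch.bincount(topk_ids.reshape(-1), minlength=num_experts)
offsets = torch.cumsum(assignments.to(torch.int32), dim=0)

# Gather routed rows, run the SiLU-gated MLP per expert, weight by router
# scores, and scatter-add each row back to its source token.
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
routed_in = torch.gather(x_flat, 0, gather_index)
gate_out = torch.nn.functional.silu(grouped_mm_ref(routed_in, w_gate.transpose(-2, -1), offsets))
gate_out = gate_out * grouped_mm_ref(routed_in, w_up.transpose(-2, -1), offsets)
down_out = grouped_mm_ref(gate_out, w2.transpose(-2, -1), offsets)
down_out = down_out * scores_sorted.unsqueeze(-1)

combined = torch.zeros_like(x_flat)
combined.scatter_add_(0, gather_index, down_out)
assert combined.shape == (tokens, hdim)

The same gather_index drives both the torch.gather that collects routed rows and the scatter_add_ that combines expert outputs, which is the symmetry the patch exploits when it replaces the index_select / index_add_ pair.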