diff --git a/src/axolotl/kernels/moe/hf_triton.py b/src/axolotl/kernels/moe/hf_triton.py index 08b8a740a..b79c01cb3 100644 --- a/src/axolotl/kernels/moe/hf_triton.py +++ b/src/axolotl/kernels/moe/hf_triton.py @@ -79,40 +79,20 @@ def moe_ffn_forward_stub( bsz, seqlen, hdim = hidden_states.shape flat = hidden_states.view(-1, hdim) router_logits = gate_linear(flat) - # use hub routing if available; otherwise fallback to softmax+topk - routed = None - if available(): - try: - routed = route_topk(router_logits, top_k) - except Exception: - routed = None - if routed is None: - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False) - topk_weight /= topk_weight.sum(dim=-1, keepdim=True) - topk_weight = topk_weight.to(flat.dtype) - x_rep = flat.repeat_interleave(top_k, dim=0) - y = torch.empty_like(x_rep) - flat_idx = topk_idx.view(-1) - for i in range(experts_module.num_experts): - expert = experts_module[i] - y[flat_idx == i] = expert(x_rep[flat_idx == i]) - y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) - return y.reshape(bsz, seqlen, hdim), router_logits - # If routed via hub, still fallback to per-expert compute until grouped GEMM path is wired. - ex_routing_data, gather_idx, scatter_idx = routed - # Convert to naive per-expert compute on packed tokens (future: call matmul_ogs + swiglu) - # For now, reconstruct the same result as naive path (no speedup but validates routing). - # We map the selected experts from gather_idx back to expert ids via router_logits argmax among top-k. - # Simpler: reuse naive computation for correctness; detailed integration will follow. - routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float) + # For now, do not call routing to avoid extra overhead until + # grouped GEMM integration is complete. Use the naive compute path + # for correctness and baseline performance. + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False) - topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(flat.dtype) + topk_weight /= topk_weight.sum(dim=-1, keepdim=True) + topk_weight = topk_weight.to(flat.dtype) x_rep = flat.repeat_interleave(top_k, dim=0) y = torch.empty_like(x_rep) flat_idx = topk_idx.view(-1) for i in range(experts_module.num_experts): expert = experts_module[i] - y[flat_idx == i] = expert(x_rep[flat_idx == i]) + sel = flat_idx == i + if sel.any(): + y[sel] = expert(x_rep[sel]) y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) return y.reshape(bsz, seqlen, hdim), router_logits