diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py
index 2c5049244..736cd7e9b 100644
--- a/src/axolotl/kernels/moe/torch_grouped.py
+++ b/src/axolotl/kernels/moe/torch_grouped.py
@@ -5,6 +5,7 @@ This is a cautious first pass that probes available ops and only runs when suppo
 
 from __future__ import annotations
 
+import logging
 from typing import List, Optional, Tuple
 
 import torch
@@ -28,6 +29,7 @@ def available() -> bool:
 
 LAST_ERROR: Optional[str] = None
 
+_LOGGER = logging.getLogger("axolotl.moe.grouped")
 
 
 def _call_grouped_mm(
@@ -94,37 +96,72 @@ def moe_ffn_forward_grouped(
         flat_idx = topk_idx.view(-1)
         x_rep = x.repeat_interleave(top_k, dim=0)
 
-        # Cache stacked weights on experts
+        # Cache stacked weights on experts (support Mixtral and Qwen2-MoE layouts)
         E = experts_module.num_experts
         dev, dt = x.device, x.dtype
-        if (
-            not hasattr(experts_module, "_stacked_w1")
-            or experts_module._stacked_w1.device != dev
-            or experts_module._stacked_w1.dtype != dt
-        ):
-            w1 = [experts_module[i].w1.weight.t() for i in range(E)]
-            w3 = [experts_module[i].w3.weight.t() for i in range(E)]
-            w2 = [experts_module[i].w2.weight.t() for i in range(E)]
-            experts_module._stacked_w1 = (
-                torch.stack(w1, dim=0)
-                .to(device=dev, dtype=dt, non_blocking=True)
-                .contiguous()
-            )
-            experts_module._stacked_w3 = (
-                torch.stack(w3, dim=0)
-                .to(device=dev, dtype=dt, non_blocking=True)
-                .contiguous()
-            )
-            experts_module._stacked_w2 = (
-                torch.stack(w2, dim=0)
-                .to(device=dev, dtype=dt, non_blocking=True)
-                .contiguous()
-            )
-            experts_module._stacked_w13 = torch.cat(
-                [experts_module._stacked_w1, experts_module._stacked_w3], dim=-1
-            ).contiguous()
-        W13 = experts_module._stacked_w13
-        W2 = experts_module._stacked_w2
+        first = experts_module[0]
+        is_mixtral = (
+            hasattr(first, "w1") and hasattr(first, "w3") and hasattr(first, "w2")
+        )
+        is_qwen2 = hasattr(first, "gate_up_proj") and hasattr(first, "down_proj")
+        if not (is_mixtral or is_qwen2):
+            if not getattr(experts_module, "_ax_grouped_logged_fail", False):
+                _LOGGER.warning(
+                    "torch_grouped: unsupported expert layout; falling back to naive"
+                )
+                experts_module._ax_grouped_logged_fail = True
+            return None, None
+
+        if is_mixtral:
+            if (
+                not hasattr(experts_module, "_stacked_w1")
+                or experts_module._stacked_w1.device != dev
+                or experts_module._stacked_w1.dtype != dt
+            ):
+                w1 = [experts_module[i].w1.weight.t() for i in range(E)]
+                w3 = [experts_module[i].w3.weight.t() for i in range(E)]
+                w2 = [experts_module[i].w2.weight.t() for i in range(E)]
+                experts_module._stacked_w1 = (
+                    torch.stack(w1, dim=0)
+                    .to(device=dev, dtype=dt, non_blocking=True)
+                    .contiguous()
+                )
+                experts_module._stacked_w3 = (
+                    torch.stack(w3, dim=0)
+                    .to(device=dev, dtype=dt, non_blocking=True)
+                    .contiguous()
+                )
+                experts_module._stacked_w2 = (
+                    torch.stack(w2, dim=0)
+                    .to(device=dev, dtype=dt, non_blocking=True)
+                    .contiguous()
+                )
+                experts_module._stacked_w13 = torch.cat(
+                    [experts_module._stacked_w1, experts_module._stacked_w3], dim=-1
+                ).contiguous()
+            W13 = experts_module._stacked_w13
+            W2 = experts_module._stacked_w2
+        else:
+            # Qwen2-MoE style: gate_up_proj (2I x H), down_proj (H x I)
+            if (
+                not hasattr(experts_module, "_stacked_w13")
+                or experts_module._stacked_w13.device != dev
+                or experts_module._stacked_w13.dtype != dt
+            ):
+                w13 = [experts_module[i].gate_up_proj.weight.t() for i in range(E)]
+                w2 = [experts_module[i].down_proj.weight.t() for i in range(E)]
+                experts_module._stacked_w13 = (
+                    torch.stack(w13, dim=0)
+                    .to(device=dev, dtype=dt, non_blocking=True)
+                    .contiguous()
+                )
+                experts_module._stacked_w2 = (
+                    torch.stack(w2, dim=0)
+                    .to(device=dev, dtype=dt, non_blocking=True)
+                    .contiguous()
+                )
+            W13 = experts_module._stacked_w13
+            W2 = experts_module._stacked_w2
 
         # Grouped GEMM for up+gate
         As: List[torch.Tensor] = []
@@ -145,6 +182,11 @@ def moe_ffn_forward_grouped(
 
         Y_list = _call_grouped_mm(As, Bs)
         if Y_list is None:
+            if not getattr(experts_module, "_ax_grouped_logged_fail", False):
+                _LOGGER.warning(
+                    f"torch_grouped: grouped_mm up+gate failed; falling back to naive. Reason: {LAST_ERROR}"
+                )
+                experts_module._ax_grouped_logged_fail = True
             return None, None
 
         # SwiGLU on each expert block and prepare for down projection
@@ -160,12 +202,22 @@ def moe_ffn_forward_grouped(
         Y2_list = _call_grouped_mm(As2, Bs2)
         if Y2_list is None:
+            if not getattr(experts_module, "_ax_grouped_logged_fail", False):
+                _LOGGER.warning(
+                    f"torch_grouped: grouped_mm down failed; falling back to naive. Reason: {LAST_ERROR}"
+                )
+                experts_module._ax_grouped_logged_fail = True
             return None, None
 
         # Write back, apply per-token weighting, and reduce over top_k
         for (i, sel), Out_i in zip(expert_slices, Y2_list):
             y_buf[sel] = Out_i
 
         y = (y_buf.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+        if not getattr(experts_module, "_ax_grouped_logged_ok", False):
+            _LOGGER.info(
+                f"torch_grouped: engaged grouped GEMM path (experts={E}, top_k={top_k})"
+            )
+            experts_module._ax_grouped_logged_ok = True
         return y.view(bsz, seqlen, hdim), router_logits
     except Exception:
         return None, None
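
Not part of the patch: the sketch below is a standalone, minimal illustration of the computation the grouped path performs, i.e. stack per-expert gate/up weights into a fused W13, run per-expert GEMMs on the tokens routed to each expert, apply SwiGLU, project down, and combine the top_k expert outputs with the router weights. The tensor sizes, the naive per-expert loop standing in for the real grouped matmul op, and all variable names here are illustrative assumptions, not the axolotl implementation.

```python
import torch
import torch.nn.functional as F

E, H, I, top_k = 4, 8, 16, 2          # experts, hidden, intermediate, top_k
T = 6                                  # tokens (batch * seqlen)

# Per-expert weights, already transposed to (in, out) as the patch does with .t().
W1 = torch.randn(E, H, I)              # gate proj
W3 = torch.randn(E, H, I)              # up proj
W2 = torch.randn(E, I, H)              # down proj
W13 = torch.cat([W1, W3], dim=-1)      # fused gate+up, shape (E, H, 2I)

x = torch.randn(T, H)
router_logits = torch.randn(T, E)
topk_weight, topk_idx = router_logits.softmax(-1).topk(top_k, dim=-1)

# Repeat each token once per selected expert and group rows by expert id.
x_rep = x.repeat_interleave(top_k, dim=0)          # (T*top_k, H)
flat_idx = topk_idx.reshape(-1)                    # (T*top_k,)
y_buf = torch.zeros(T * top_k, H)

# Stand-in for the grouped GEMM: one matmul per expert over its routed rows.
for e in range(E):
    sel = (flat_idx == e).nonzero(as_tuple=True)[0]
    if sel.numel() == 0:
        continue
    gate_up = x_rep[sel] @ W13[e]                  # (n_e, 2I): gate | up halves
    gate, up = gate_up.split(I, dim=-1)
    y_buf[sel] = (F.silu(gate) * up) @ W2[e]       # SwiGLU, then down projection

# Weight each expert's output by its router probability and sum over top_k.
y = (y_buf.view(T, top_k, H) * topk_weight.unsqueeze(-1)).sum(dim=1)
print(y.shape)  # torch.Size([6, 8])
```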