From 51e565f60a7a82146701defe3d5d33c5a126ad72 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 17 Sep 2025 14:15:51 -0400 Subject: [PATCH] logs --- src/axolotl/kernels/moe/torch_grouped.py | 376 +++++++++++------------ src/axolotl/monkeypatch/moe_grouped.py | 15 +- 2 files changed, 192 insertions(+), 199 deletions(-) diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 38c0bb2b7..db285b1cc 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -88,205 +88,189 @@ def _call_grouped_mm( def moe_ffn_forward_grouped( hidden_states, gate_linear, experts_module, top_k: int ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - """ - Attempt a grouped GEMM fast path using PyTorch 2.8+. - If unavailable or fails, returns (None, None) so caller can fallback. - """ - try: - bsz, seqlen, hdim = hidden_states.shape - x = hidden_states.view(-1, hdim) - router_logits = gate_linear(x) + """Attempt grouped GEMM fast path using PyTorch 2.8+.""" + global LAST_ERROR + LAST_ERROR = None + bsz, seqlen, hdim = hidden_states.shape + x = hidden_states.view(-1, hdim) + router_logits = gate_linear(x) - # topk routing in torch (keep simple to avoid dependency cycles) - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False) - topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(x.dtype) + # top-k routing executed in torch to avoid extra dependencies + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False) + topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(x.dtype) - # Build per-expert input lists - flat_idx = topk_idx.view(-1) - x_rep = x.repeat_interleave(top_k, dim=0) + flat_idx = topk_idx.view(-1) + x_rep = x.repeat_interleave(top_k, dim=0) - # Cache stacked weights on experts (support Mixtral and Qwen-style layouts) - E = experts_module.num_experts - dev, dt = x.device, x.dtype - first = experts_module[0] + E = experts_module.num_experts + dev, dt = x.device, x.dtype + first = experts_module[0] - is_mixtral = _is_mixtral_layout(first) - is_qwen2 = _is_qwen_layout(first) - nested_attr: Optional[str] = None - if not (is_mixtral or is_qwen2): - for candidate in ("mlp", "ffn"): - nested = getattr(first, candidate, None) - if nested is None: - continue - if _is_mixtral_layout(nested): - is_mixtral = True - nested_attr = candidate - break - if _is_qwen_layout(nested): - is_qwen2 = True - nested_attr = candidate - break - if not (is_mixtral or is_qwen2): - if not getattr(experts_module, "_ax_grouped_logged_fail", False): - _LOGGER.warning( - "torch_grouped: unsupported expert layout; falling back to naive" - ) - experts_module._ax_grouped_logged_fail = True - return None, None - - def _resolve_expert(idx: int): - expert = experts_module[idx] - if nested_attr is None: - return expert - nested_mod = getattr(expert, nested_attr, None) - if nested_mod is None: - raise AttributeError( - f"expert {idx} missing nested module '{nested_attr}'" - ) - return nested_mod - - try: - if is_mixtral: - if ( - not hasattr(experts_module, "_stacked_w1") - or experts_module._stacked_w1.device != dev - or experts_module._stacked_w1.dtype != dt - ): - mods = [_resolve_expert(i) for i in range(E)] - w1 = [mod.w1.weight.t() for mod in mods] - w3 = [mod.w3.weight.t() for mod in mods] - w2 = 
[mod.w2.weight.t() for mod in mods] - experts_module._stacked_w1 = ( - torch.stack(w1, dim=0) - .to(device=dev, dtype=dt, non_blocking=True) - .contiguous() - ) - experts_module._stacked_w3 = ( - torch.stack(w3, dim=0) - .to(device=dev, dtype=dt, non_blocking=True) - .contiguous() - ) - experts_module._stacked_w2 = ( - torch.stack(w2, dim=0) - .to(device=dev, dtype=dt, non_blocking=True) - .contiguous() - ) - experts_module._stacked_w13 = torch.cat( - [experts_module._stacked_w1, experts_module._stacked_w3], dim=-1 - ).contiguous() - W13 = experts_module._stacked_w13 - W2 = experts_module._stacked_w2 - else: - # Qwen-style MoE: either gate_up_proj (2I x H) or (up_proj + gate_proj), down_proj (H x I) - if ( - not hasattr(experts_module, "_stacked_w13") - or experts_module._stacked_w13.device != dev - or experts_module._stacked_w13.dtype != dt - ): - w13 = [] - w2 = [] - for i in range(E): - mod = _resolve_expert(i) - # prefer fused gate_up_proj if present - if hasattr(mod, "gate_up_proj"): - w13.append(mod.gate_up_proj.weight.t()) - elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"): - # concatenate [up | gate] along N - w13.append( - torch.cat( - [mod.up_proj.weight.t(), mod.gate_proj.weight.t()], - dim=-1, - ) - ) - else: - LAST_ERROR = "unrecognized Qwen MoE expert weight layout" - if not getattr( - experts_module, "_ax_grouped_logged_fail", False - ): - _LOGGER.warning( - "torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive" - ) - experts_module._ax_grouped_logged_fail = True - return None, None - w2.append(mod.down_proj.weight.t()) - experts_module._stacked_w13 = ( - torch.stack(w13, dim=0) - .to(device=dev, dtype=dt, non_blocking=True) - .contiguous() - ) - experts_module._stacked_w2 = ( - torch.stack(w2, dim=0) - .to(device=dev, dtype=dt, non_blocking=True) - .contiguous() - ) - W13 = experts_module._stacked_w13 - W2 = experts_module._stacked_w2 - except AttributeError as err: - LAST_ERROR = str(err) - if not getattr(experts_module, "_ax_grouped_logged_fail", False): - _LOGGER.warning( - "torch_grouped: expert weights missing expected attributes; falling back to naive" - ) - experts_module._ax_grouped_logged_fail = True - return None, None - - # Grouped GEMM for up+gate - As: List[torch.Tensor] = [] - Bs: List[torch.Tensor] = [] - expert_slices = [] - for i in range(E): - sel = flat_idx == i - if sel.any(): - Xi = x_rep[sel] - As.append(Xi) - Bs.append(W13[i]) - expert_slices.append((i, sel)) - - if not As: - # no tokens routed — edge case - out = torch.zeros_like(x) - return out.view(bsz, seqlen, hdim), router_logits - - Y_list = _call_grouped_mm(As, Bs) - if Y_list is None: - if not getattr(experts_module, "_ax_grouped_logged_fail", False): - _LOGGER.warning( - f"torch_grouped: grouped_mm up+gate failed; falling back to naive. Reason: {LAST_ERROR}" - ) - experts_module._ax_grouped_logged_fail = True - return None, None - - # SwiGLU on each expert block and prepare for down projection - As2: List[torch.Tensor] = [] - Bs2: List[torch.Tensor] = [] - y_buf = torch.empty_like(x_rep) - - # split Y into (I, I) - for Yi in Y_list: - I2 = Yi.shape[-1] // 2 - Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:] - As2.append(Yi_hidden) - Bs2.append(W2[i]) - - Y2_list = _call_grouped_mm(As2, Bs2) - if Y2_list is None: - if not getattr(experts_module, "_ax_grouped_logged_fail", False): - _LOGGER.warning( - f"torch_grouped: grouped_mm down failed; falling back to naive. 
Reason: {LAST_ERROR}" - ) - experts_module._ax_grouped_logged_fail = True - return None, None - - # Write back, apply per-token weighting, and reduce over top_k - for (_, sel), Out_i in zip(expert_slices, Y2_list, strict=False): - y_buf[sel] = Out_i - y = (y_buf.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) - if not getattr(experts_module, "_ax_grouped_logged_ok", False): - _LOGGER.info( - f"torch_grouped: engaged grouped GEMM path (experts={E}, top_k={top_k})" + is_mixtral = _is_mixtral_layout(first) + is_qwen2 = _is_qwen_layout(first) + nested_attr: Optional[str] = None + if not (is_mixtral or is_qwen2): + for candidate in ("mlp", "ffn"): + nested = getattr(first, candidate, None) + if nested is None: + continue + if _is_mixtral_layout(nested): + is_mixtral = True + nested_attr = candidate + break + if _is_qwen_layout(nested): + is_qwen2 = True + nested_attr = candidate + break + if not (is_mixtral or is_qwen2): + if not getattr(experts_module, "_ax_grouped_logged_fail", False): + _LOGGER.warning( + "torch_grouped: unsupported expert layout; falling back to naive" ) - experts_module._ax_grouped_logged_ok = True - return y.view(bsz, seqlen, hdim), router_logits - except Exception: + experts_module._ax_grouped_logged_fail = True + LAST_ERROR = "unsupported expert layout" return None, None + + def _resolve_expert(idx: int): + expert = experts_module[idx] + if nested_attr is None: + return expert + nested_mod = getattr(expert, nested_attr, None) + if nested_mod is None: + raise AttributeError(f"expert {idx} missing nested module '{nested_attr}'") + return nested_mod + + try: + if is_mixtral: + if ( + not hasattr(experts_module, "_stacked_w1") + or experts_module._stacked_w1.device != dev + or experts_module._stacked_w1.dtype != dt + ): + mods = [_resolve_expert(i) for i in range(E)] + w1 = [mod.w1.weight.t() for mod in mods] + w3 = [mod.w3.weight.t() for mod in mods] + w2 = [mod.w2.weight.t() for mod in mods] + experts_module._stacked_w1 = ( + torch.stack(w1, dim=0) + .to(device=dev, dtype=dt, non_blocking=True) + .contiguous() + ) + experts_module._stacked_w3 = ( + torch.stack(w3, dim=0) + .to(device=dev, dtype=dt, non_blocking=True) + .contiguous() + ) + experts_module._stacked_w2 = ( + torch.stack(w2, dim=0) + .to(device=dev, dtype=dt, non_blocking=True) + .contiguous() + ) + experts_module._stacked_w13 = torch.cat( + [experts_module._stacked_w1, experts_module._stacked_w3], dim=-1 + ).contiguous() + W13 = experts_module._stacked_w13 + W2 = experts_module._stacked_w2 + else: + if ( + not hasattr(experts_module, "_stacked_w13") + or experts_module._stacked_w13.device != dev + or experts_module._stacked_w13.dtype != dt + ): + w13 = [] + w2 = [] + for i in range(E): + mod = _resolve_expert(i) + if hasattr(mod, "gate_up_proj"): + w13.append(mod.gate_up_proj.weight.t()) + elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"): + w13.append( + torch.cat( + [mod.up_proj.weight.t(), mod.gate_proj.weight.t()], + dim=-1, + ) + ) + else: + LAST_ERROR = "unrecognized Qwen MoE expert weight layout" + if not getattr( + experts_module, "_ax_grouped_logged_fail", False + ): + _LOGGER.warning( + "torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive" + ) + experts_module._ax_grouped_logged_fail = True + return None, None + w2.append(mod.down_proj.weight.t()) + experts_module._stacked_w13 = ( + torch.stack(w13, dim=0) + .to(device=dev, dtype=dt, non_blocking=True) + .contiguous() + ) + experts_module._stacked_w2 = ( + torch.stack(w2, dim=0) + .to(device=dev, 
dtype=dt, non_blocking=True) + .contiguous() + ) + W13 = experts_module._stacked_w13 + W2 = experts_module._stacked_w2 + except AttributeError as err: + LAST_ERROR = str(err) + if not getattr(experts_module, "_ax_grouped_logged_fail", False): + _LOGGER.warning( + "torch_grouped: expert weights missing expected attributes; falling back to naive" + ) + experts_module._ax_grouped_logged_fail = True + return None, None + + As: List[torch.Tensor] = [] + Bs: List[torch.Tensor] = [] + expert_slices: List[Tuple[int, torch.Tensor]] = [] + for i in range(E): + sel = flat_idx == i + if sel.any(): + Xi = x_rep[sel] + As.append(Xi) + Bs.append(W13[i]) + expert_slices.append((i, sel)) + + if not As: + out = torch.zeros_like(x) + return out.view(bsz, seqlen, hdim), router_logits + + Y_list = _call_grouped_mm(As, Bs) + if Y_list is None: + if not getattr(experts_module, "_ax_grouped_logged_fail", False): + _LOGGER.warning( + f"torch_grouped: grouped_mm up+gate failed; falling back to naive. Reason: {LAST_ERROR}" + ) + experts_module._ax_grouped_logged_fail = True + return None, None + + As2: List[torch.Tensor] = [] + Bs2: List[torch.Tensor] = [] + y_buf = torch.empty_like(x_rep) + for (i, _sel), Yi in zip(expert_slices, Y_list, strict=False): + I2 = Yi.shape[-1] // 2 + Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:] + As2.append(Yi_hidden) + Bs2.append(W2[i]) + + Y2_list = _call_grouped_mm(As2, Bs2) + if Y2_list is None: + if not getattr(experts_module, "_ax_grouped_logged_fail", False): + _LOGGER.warning( + f"torch_grouped: grouped_mm down failed; falling back to naive. Reason: {LAST_ERROR}" + ) + experts_module._ax_grouped_logged_fail = True + return None, None + + for (_i, sel), Out_i in zip(expert_slices, Y2_list, strict=False): + y_buf[sel] = Out_i + y = (y_buf.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + if not getattr(experts_module, "_ax_grouped_logged_ok", False): + _LOGGER.info( + f"torch_grouped: engaged grouped GEMM path (experts={E}, top_k={top_k})" + ) + experts_module._ax_grouped_logged_ok = True + return y.view(bsz, seqlen, hdim), router_logits diff --git a/src/axolotl/monkeypatch/moe_grouped.py b/src/axolotl/monkeypatch/moe_grouped.py index 01ac6e101..ab5bfba27 100644 --- a/src/axolotl/monkeypatch/moe_grouped.py +++ b/src/axolotl/monkeypatch/moe_grouped.py @@ -82,9 +82,18 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None: # One-time log per block instance indicating whether grouped engaged or fallback occurred if not getattr(self, "_ax_grouped_wrapper_logged", False): if y is None: - _LOG.warning( - f"Grouped wrapper active but fell back to naive for {self.__class__.__name__}" - ) + reason = getattr(_tg, "LAST_ERROR", None) + if reason: + _LOG.warning( + "Grouped wrapper fell back to naive for %s (reason=%s)", + self.__class__.__name__, + reason, + ) + else: + _LOG.warning( + "Grouped wrapper active but fell back to naive for %s", + self.__class__.__name__, + ) else: _LOG.info( f"Grouped wrapper engaged for {self.__class__.__name__} (top_k={self.top_k})"
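
Note (not part of the patch): a minimal parity sketch for exercising the grouped GEMM fast path against a naive per-expert loop. The ToyExpert module, the tensor sizes, the import path, and the assumption that the layout check accepts experts exposing w1/w3/w2 Linear projections (Mixtral-style) are illustrative guesses, not code from this repository; on builds without the PyTorch 2.8+ grouped mm support the helper is expected to return (None, None) and the comparison is skipped.

import torch
import torch.nn as nn
import torch.nn.functional as F

# import path assumed from the file touched by this patch (src/axolotl/kernels/moe/torch_grouped.py)
from axolotl.kernels.moe.torch_grouped import moe_ffn_forward_grouped


class ToyExpert(nn.Module):
    """Hypothetical Mixtral-style expert: w1/w3 gate+up projections, w2 down projection."""

    def __init__(self, hidden: int, inter: int):
        super().__init__()
        self.w1 = nn.Linear(hidden, inter, bias=False)
        self.w3 = nn.Linear(hidden, inter, bias=False)
        self.w2 = nn.Linear(inter, hidden, bias=False)

    def forward(self, x):
        # SwiGLU-style expert MLP: down(silu(gate(x)) * up(x))
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


hidden, inter, n_experts, top_k = 64, 128, 4, 2
experts = nn.ModuleList(ToyExpert(hidden, inter) for _ in range(n_experts))
experts.num_experts = n_experts  # attribute the fast path reads before stacking weights
gate = nn.Linear(hidden, n_experts, bias=False)

x = torch.randn(2, 8, hidden)
y_fast, _router_logits = moe_ffn_forward_grouped(x, gate, experts, top_k)

if y_fast is None:
    # grouped mm unavailable (needs PyTorch 2.8+) or layout not recognized;
    # the caller would fall back to the naive MoE forward here
    print("grouped path declined; naive MoE forward would run instead")
else:
    # naive reference: softmax routing, renormalized top-k weights, per-expert SwiGLU MLP
    flat = x.view(-1, hidden)
    probs = F.softmax(gate(flat), dim=-1, dtype=torch.float)
    w, idx = torch.topk(probs, top_k, dim=-1)
    w = (w / w.sum(dim=-1, keepdim=True)).to(flat.dtype)
    ref = torch.zeros_like(flat)
    for e in range(n_experts):
        for k in range(top_k):
            sel = idx[:, k] == e
            if sel.any():
                ref[sel] += w[sel, k].unsqueeze(-1) * experts[e](flat[sel])
    print("max |grouped - naive|:", (y_fast.view(-1, hidden) - ref).abs().max().item())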