From 7289e0cb55731d26cb3c15516b5c3dcbd731839d Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Tue, 16 Sep 2025 00:26:39 -0400
Subject: [PATCH] more logs

---
 src/axolotl/kernels/moe/torch_grouped.py | 52 ++++++++++++++++++++++--
 1 file changed, 48 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py
index 736cd7e9b..7b356f834 100644
--- a/src/axolotl/kernels/moe/torch_grouped.py
+++ b/src/axolotl/kernels/moe/torch_grouped.py
@@ -103,7 +103,26 @@ def moe_ffn_forward_grouped(
     is_mixtral = (
         hasattr(first, "w1") and hasattr(first, "w3") and hasattr(first, "w2")
     )
-    is_qwen2 = hasattr(first, "gate_up_proj") and hasattr(first, "down_proj")
+    is_qwen2 = (
+        hasattr(first, "gate_up_proj")
+        or hasattr(first, "gate_proj")
+        or hasattr(first, "up_proj")
+    ) and hasattr(first, "down_proj")
+    # try nested mlp/ffn module
+    nested = None
+    if not (is_mixtral or is_qwen2):
+        nested = getattr(first, "mlp", None) or getattr(first, "ffn", None)
+        if nested is not None:
+            is_mixtral = (
+                hasattr(nested, "w1")
+                and hasattr(nested, "w3")
+                and hasattr(nested, "w2")
+            )
+            is_qwen2 = (
+                hasattr(nested, "gate_up_proj")
+                or hasattr(nested, "gate_proj")
+                or hasattr(nested, "up_proj")
+            ) and hasattr(nested, "down_proj")
     if not (is_mixtral or is_qwen2):
         if not getattr(experts_module, "_ax_grouped_logged_fail", False):
             _LOGGER.warning(
@@ -142,14 +161,39 @@
             W13 = experts_module._stacked_w13
             W2 = experts_module._stacked_w2
         else:
-            # Qwen2-MoE style: gate_up_proj (2I x H), down_proj (H x I)
+            # Qwen2/3 MoE style: either gate_up_proj (2I x H) or (up_proj + gate_proj), down_proj (H x I)
             if (
                 not hasattr(experts_module, "_stacked_w13")
                 or experts_module._stacked_w13.device != dev
                 or experts_module._stacked_w13.dtype != dt
             ):
-                w13 = [experts_module[i].gate_up_proj.weight.t() for i in range(E)]
-                w2 = [experts_module[i].down_proj.weight.t() for i in range(E)]
+                w13 = []
+                w2 = []
+                for i in range(E):
+                    exp = experts_module[i]
+                    mod = nested if nested is not None else exp
+                    # prefer fused gate_up_proj if present
+                    if hasattr(mod, "gate_up_proj"):
+                        w13.append(mod.gate_up_proj.weight.t())
+                    elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"):
+                        # concatenate [up | gate] along N
+                        w13.append(
+                            torch.cat(
+                                [mod.up_proj.weight.t(), mod.gate_proj.weight.t()],
+                                dim=-1,
+                            )
+                        )
+                    else:
+                        LAST_ERROR = "unrecognized Qwen MoE expert weight layout"
+                        if not getattr(
+                            experts_module, "_ax_grouped_logged_fail", False
+                        ):
+                            _LOGGER.warning(
+                                "torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive"
+                            )
+                            experts_module._ax_grouped_logged_fail = True
+                        return None, None
+                    w2.append((mod.down_proj.weight.t()))
             experts_module._stacked_w13 = (
                 torch.stack(w13, dim=0)
                 .to(device=dev, dtype=dt, non_blocking=True)
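
For reference, a minimal standalone sketch of the layout resolution this patch
implements: prefer a fused gate_up_proj, otherwise concatenate up_proj and
gate_proj along the output dimension, then stack per-expert weights for a
grouped GEMM. The helper stack_expert_weights and the toy SplitExpert module
are illustrative names, not axolotl's actual API.

import torch
import torch.nn as nn


def stack_expert_weights(experts: nn.ModuleList):
    """Return (W13, W2) stacked across experts, or (None, None) if the
    per-expert weight layout is not recognized."""
    w13, w2 = [], []
    for exp in experts:
        # fall back to a nested mlp/ffn submodule, mirroring the patch
        mod = getattr(exp, "mlp", None) or getattr(exp, "ffn", None) or exp
        if hasattr(mod, "gate_up_proj"):
            # fused Qwen2-style layout: gate_up_proj is (2I x H)
            w13.append(mod.gate_up_proj.weight.t())
        elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"):
            # split Qwen3-style layout: concatenate [up | gate] along N
            w13.append(
                torch.cat([mod.up_proj.weight.t(), mod.gate_proj.weight.t()], dim=-1)
            )
        else:
            return None, None  # unrecognized layout; caller falls back to naive
        w2.append(mod.down_proj.weight.t())
    return torch.stack(w13, dim=0), torch.stack(w2, dim=0)


# Toy usage: two split-layout experts with hidden size H=8, intermediate I=16.
class SplitExpert(nn.Module):
    def __init__(self, h=8, i=16):
        super().__init__()
        self.gate_proj = nn.Linear(h, i, bias=False)
        self.up_proj = nn.Linear(h, i, bias=False)
        self.down_proj = nn.Linear(i, h, bias=False)


experts = nn.ModuleList([SplitExpert(), SplitExpert()])
W13, W2 = stack_expert_weights(experts)
print(W13.shape, W2.shape)  # torch.Size([2, 8, 32]) torch.Size([2, 16, 8])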