From db61e0d4ff22de211a4cdd4dfd638acedcaea6da Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Wed, 17 Sep 2025 14:26:25 -0400
Subject: [PATCH] fix

---
 src/axolotl/kernels/moe/torch_grouped.py | 22 ++++++++++++++++-
 tests/monkeypatch/test_moe_grouped.py    | 30 ++++++++++++++++++++++++
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py
index db285b1cc..9453ad437 100644
--- a/src/axolotl/kernels/moe/torch_grouped.py
+++ b/src/axolotl/kernels/moe/torch_grouped.py
@@ -42,6 +42,17 @@ def _is_qwen_layout(mod: torch.nn.Module) -> bool:
     return (has_fused or has_split) and hasattr(mod, "down_proj")
 
 
+def _num_experts(module: torch.nn.Module) -> int:
+    """Return expert count, supporting ModuleList-style inputs."""
+    count = getattr(module, "num_experts", None)
+    if count is not None:
+        return int(count() if callable(count) else count)
+    try:
+        return len(module)  # type: ignore[arg-type]
+    except TypeError as exc:  # pragma: no cover - defensive
+        raise AttributeError("experts module missing num_experts/len support") from exc
+
+
 def _call_grouped_mm(
     As: List[torch.Tensor], Bs: List[torch.Tensor]
 ) -> Optional[List[torch.Tensor]]:
@@ -103,7 +114,16 @@ def moe_ffn_forward_grouped(
     flat_idx = topk_idx.view(-1)
     x_rep = x.repeat_interleave(top_k, dim=0)
 
-    E = experts_module.num_experts
+    try:
+        E = _num_experts(experts_module)
+    except AttributeError as err:
+        LAST_ERROR = str(err)
+        if not getattr(experts_module, "_ax_grouped_logged_fail", False):
+            _LOGGER.warning(
+                "torch_grouped: could not determine expert count; falling back to naive"
+            )
+            experts_module._ax_grouped_logged_fail = True
+        return None, None
     dev, dt = x.device, x.dtype
 
     first = experts_module[0]
diff --git a/tests/monkeypatch/test_moe_grouped.py b/tests/monkeypatch/test_moe_grouped.py
index 9b409e344..f668f2886 100644
--- a/tests/monkeypatch/test_moe_grouped.py
+++ b/tests/monkeypatch/test_moe_grouped.py
@@ -100,6 +100,36 @@ def test_grouped_uses_per_expert_nested_modules(monkeypatch):
     assert not torch.equal(first_call[0], first_call[1])
 
 
+def test_grouped_accepts_module_list_experts(monkeypatch):
+    hidden = 4
+    intermediate = 2
+    experts = nn.ModuleList(
+        [DummyQwenExpert(i, hidden, intermediate) for i in range(2)]
+    )
+    gate = nn.Linear(hidden, len(experts), bias=False)
+    nn.init.zeros_(gate.weight)
+
+    calls = {"count": 0}
+
+    def fake_grouped_mm(As, Bs):
+        calls["count"] += 1
+        return [
+            torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
+            for a, b in zip(As, Bs, strict=False)
+        ]
+
+    monkeypatch.setattr(torch_grouped_module, "_call_grouped_mm", fake_grouped_mm)
+
+    hidden_states = torch.randn(1, 2, hidden)
+    y, router_logits = torch_grouped_module.moe_ffn_forward_grouped(
+        hidden_states, gate, experts, top_k=2
+    )
+
+    assert y is not None
+    assert router_logits is not None
+    assert calls["count"] > 0
+
+
 class _DummyCfg:
     moe_backend = "torch_grouped"
 