From d2f1e23bcddf8c058060df8aaef52b0a14919dcc Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 19 Sep 2025 12:45:18 -0400 Subject: [PATCH] fix --- src/axolotl/kernels/moe/torch_grouped.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 2f4d8c6cb..baceb31d4 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -27,7 +27,16 @@ def available() -> bool: return False -def _iter_expert_impls(experts_module) -> List[torch.nn.Module]: +def _iter_expert_impls( + experts_module, visited: Optional[set[int]] = None +) -> List[torch.nn.Module]: + if visited is None: + visited = set() + module_id = id(experts_module) + if module_id in visited: + return [] + visited.add(module_id) + impls: List[torch.nn.Module] = [] for exp in experts_module: candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp)) @@ -36,7 +45,7 @@ def _iter_expert_impls(experts_module) -> List[torch.nn.Module]: continue nested = getattr(candidate, "experts", None) if nested is not None: - impls.extend(_iter_expert_impls(nested)) + impls.extend(_iter_expert_impls(nested, visited)) continue raise RuntimeError( "torch_grouped: unable to resolve expert implementation for module"