diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 2f4d8c6cb..baceb31d4 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -27,7 +27,16 @@ def available() -> bool: return False -def _iter_expert_impls(experts_module) -> List[torch.nn.Module]: +def _iter_expert_impls( + experts_module, visited: Optional[set[int]] = None +) -> List[torch.nn.Module]: + if visited is None: + visited = set() + module_id = id(experts_module) + if module_id in visited: + return [] + visited.add(module_id) + impls: List[torch.nn.Module] = [] for exp in experts_module: candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp)) @@ -36,7 +45,7 @@ def _iter_expert_impls(experts_module) -> List[torch.nn.Module]: continue nested = getattr(candidate, "experts", None) if nested is not None: - impls.extend(_iter_expert_impls(nested)) + impls.extend(_iter_expert_impls(nested, visited)) continue raise RuntimeError( "torch_grouped: unable to resolve expert implementation for module"