This commit is contained in:
Dan Saunders
2025-09-19 12:45:18 -04:00
parent 42aadc5069
commit d2f1e23bcd

View File

@@ -27,7 +27,16 @@ def available() -> bool:
return False
def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
def _iter_expert_impls(
experts_module, visited: Optional[set[int]] = None
) -> List[torch.nn.Module]:
if visited is None:
visited = set()
module_id = id(experts_module)
if module_id in visited:
return []
visited.add(module_id)
impls: List[torch.nn.Module] = []
for exp in experts_module:
candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp))
@@ -36,7 +45,7 @@ def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
continue
nested = getattr(candidate, "experts", None)
if nested is not None:
impls.extend(_iter_expert_impls(nested))
impls.extend(_iter_expert_impls(nested, visited))
continue
raise RuntimeError(
"torch_grouped: unable to resolve expert implementation for module"