fix
This commit is contained in:
@@ -27,7 +27,16 @@ def available() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
|
||||
def _iter_expert_impls(
|
||||
experts_module, visited: Optional[set[int]] = None
|
||||
) -> List[torch.nn.Module]:
|
||||
if visited is None:
|
||||
visited = set()
|
||||
module_id = id(experts_module)
|
||||
if module_id in visited:
|
||||
return []
|
||||
visited.add(module_id)
|
||||
|
||||
impls: List[torch.nn.Module] = []
|
||||
for exp in experts_module:
|
||||
candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp))
|
||||
@@ -36,7 +45,7 @@ def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
|
||||
continue
|
||||
nested = getattr(candidate, "experts", None)
|
||||
if nested is not None:
|
||||
impls.extend(_iter_expert_impls(nested))
|
||||
impls.extend(_iter_expert_impls(nested, visited))
|
||||
continue
|
||||
raise RuntimeError(
|
||||
"torch_grouped: unable to resolve expert implementation for module"
|
||||
|
||||
Reference in New Issue
Block a user