fix
This commit is contained in:
@@ -27,7 +27,16 @@ def available() -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
|
def _iter_expert_impls(
|
||||||
|
experts_module, visited: Optional[set[int]] = None
|
||||||
|
) -> List[torch.nn.Module]:
|
||||||
|
if visited is None:
|
||||||
|
visited = set()
|
||||||
|
module_id = id(experts_module)
|
||||||
|
if module_id in visited:
|
||||||
|
return []
|
||||||
|
visited.add(module_id)
|
||||||
|
|
||||||
impls: List[torch.nn.Module] = []
|
impls: List[torch.nn.Module] = []
|
||||||
for exp in experts_module:
|
for exp in experts_module:
|
||||||
candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp))
|
candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp))
|
||||||
@@ -36,7 +45,7 @@ def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
|
|||||||
continue
|
continue
|
||||||
nested = getattr(candidate, "experts", None)
|
nested = getattr(candidate, "experts", None)
|
||||||
if nested is not None:
|
if nested is not None:
|
||||||
impls.extend(_iter_expert_impls(nested))
|
impls.extend(_iter_expert_impls(nested, visited))
|
||||||
continue
|
continue
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"torch_grouped: unable to resolve expert implementation for module"
|
"torch_grouped: unable to resolve expert implementation for module"
|
||||||
|
|||||||
Reference in New Issue
Block a user