Fix recursion: reference the parent block via weakref
This commit is contained in:
@@ -286,7 +286,14 @@ def moe_ffn_forward_grouped(
|
||||
)
|
||||
return None, None
|
||||
|
||||
parent_block = getattr(experts_module, "_ax_parent_block", None)
|
||||
parent_block = None
|
||||
parent_ref = getattr(experts_module, "_ax_parent_block_ref", None)
|
||||
if parent_ref is not None:
|
||||
try:
|
||||
parent_block = parent_ref()
|
||||
except TypeError:
|
||||
parent_block = None
|
||||
|
||||
expert_container = getattr(experts_module, "experts", experts_module)
|
||||
|
||||
expert_impls = _iter_expert_impls(expert_container)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import weakref
|
||||
from functools import wraps
|
||||
|
||||
import torch
|
||||
@@ -78,7 +79,7 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
|
||||
bsz, seqlen, hdim = hidden_states.shape
|
||||
# expose parent block so grouped backend can access shared expert context
|
||||
try:
|
||||
self.experts._ax_parent_block = self
|
||||
self.experts._ax_parent_block_ref = weakref.ref(self)
|
||||
except Exception:
|
||||
pass
|
||||
y, router_logits = _tg.moe_ffn_forward_grouped(
|
||||
|
||||
Reference in New Issue
Block a user