recurse fix
This commit is contained in:
@@ -286,7 +286,14 @@ def moe_ffn_forward_grouped(
|
|||||||
)
|
)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
parent_block = getattr(experts_module, "_ax_parent_block", None)
|
parent_block = None
|
||||||
|
parent_ref = getattr(experts_module, "_ax_parent_block_ref", None)
|
||||||
|
if parent_ref is not None:
|
||||||
|
try:
|
||||||
|
parent_block = parent_ref()
|
||||||
|
except TypeError:
|
||||||
|
parent_block = None
|
||||||
|
|
||||||
expert_container = getattr(experts_module, "experts", experts_module)
|
expert_container = getattr(experts_module, "experts", experts_module)
|
||||||
|
|
||||||
expert_impls = _iter_expert_impls(expert_container)
|
expert_impls = _iter_expert_impls(expert_container)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import weakref
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -78,7 +79,7 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
|
|||||||
bsz, seqlen, hdim = hidden_states.shape
|
bsz, seqlen, hdim = hidden_states.shape
|
||||||
# expose parent block so grouped backend can access shared expert context
|
# expose parent block so grouped backend can access shared expert context
|
||||||
try:
|
try:
|
||||||
self.experts._ax_parent_block = self
|
self.experts._ax_parent_block_ref = weakref.ref(self)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
y, router_logits = _tg.moe_ffn_forward_grouped(
|
y, router_logits = _tg.moe_ffn_forward_grouped(
|
||||||
|
|||||||
Reference in New Issue
Block a user