Fix recursion: store the MoE parent block on the experts module as a weakref

This commit is contained in:
Dan Saunders
2025-09-19 12:58:58 -04:00
parent 0f8b921399
commit 64345e7707
2 changed files with 10 additions and 2 deletions

View File

@@ -286,7 +286,14 @@ def moe_ffn_forward_grouped(
)
return None, None
parent_block = getattr(experts_module, "_ax_parent_block", None)
parent_block = None
parent_ref = getattr(experts_module, "_ax_parent_block_ref", None)
if parent_ref is not None:
try:
parent_block = parent_ref()
except TypeError:
parent_block = None
expert_container = getattr(experts_module, "experts", experts_module)
expert_impls = _iter_expert_impls(expert_container)

View File

@@ -1,4 +1,5 @@
import logging
import weakref
from functools import wraps
import torch
@@ -78,7 +79,7 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
bsz, seqlen, hdim = hidden_states.shape
# expose parent block so grouped backend can access shared expert context
try:
self.experts._ax_parent_block = self
self.experts._ax_parent_block_ref = weakref.ref(self)
except Exception:
pass
y, router_logits = _tg.moe_ffn_forward_grouped(