This commit is contained in:
Dan Saunders
2025-09-17 14:26:25 -04:00
parent 51e565f60a
commit db61e0d4ff
2 changed files with 51 additions and 1 deletions

View File

@@ -42,6 +42,17 @@ def _is_qwen_layout(mod: torch.nn.Module) -> bool:
return (has_fused or has_split) and hasattr(mod, "down_proj")
def _num_experts(module: torch.nn.Module) -> int:
"""Return expert count, supporting ModuleList-style inputs."""
count = getattr(module, "num_experts", None)
if count is not None:
return int(count() if callable(count) else count)
try:
return len(module) # type: ignore[arg-type]
except TypeError as exc: # pragma: no cover - defensive
raise AttributeError("experts module missing num_experts/len support") from exc
def _call_grouped_mm(
As: List[torch.Tensor], Bs: List[torch.Tensor]
) -> Optional[List[torch.Tensor]]:
@@ -103,7 +114,16 @@ def moe_ffn_forward_grouped(
flat_idx = topk_idx.view(-1)
x_rep = x.repeat_interleave(top_k, dim=0)
E = experts_module.num_experts
try:
E = _num_experts(experts_module)
except AttributeError as err:
LAST_ERROR = str(err)
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
_LOGGER.warning(
"torch_grouped: could not determine expert count; falling back to naive"
)
experts_module._ax_grouped_logged_fail = True
return None, None
dev, dt = x.device, x.dtype
first = experts_module[0]