fix
This commit is contained in:
@@ -42,6 +42,17 @@ def _is_qwen_layout(mod: torch.nn.Module) -> bool:
|
||||
return (has_fused or has_split) and hasattr(mod, "down_proj")
|
||||
|
||||
|
||||
def _num_experts(module: torch.nn.Module) -> int:
|
||||
"""Return expert count, supporting ModuleList-style inputs."""
|
||||
count = getattr(module, "num_experts", None)
|
||||
if count is not None:
|
||||
return int(count() if callable(count) else count)
|
||||
try:
|
||||
return len(module) # type: ignore[arg-type]
|
||||
except TypeError as exc: # pragma: no cover - defensive
|
||||
raise AttributeError("experts module missing num_experts/len support") from exc
|
||||
|
||||
|
||||
def _call_grouped_mm(
|
||||
As: List[torch.Tensor], Bs: List[torch.Tensor]
|
||||
) -> Optional[List[torch.Tensor]]:
|
||||
@@ -103,7 +114,16 @@ def moe_ffn_forward_grouped(
|
||||
flat_idx = topk_idx.view(-1)
|
||||
x_rep = x.repeat_interleave(top_k, dim=0)
|
||||
|
||||
E = experts_module.num_experts
|
||||
try:
|
||||
E = _num_experts(experts_module)
|
||||
except AttributeError as err:
|
||||
LAST_ERROR = str(err)
|
||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||
_LOGGER.warning(
|
||||
"torch_grouped: could not determine expert count; falling back to naive"
|
||||
)
|
||||
experts_module._ax_grouped_logged_fail = True
|
||||
return None, None
|
||||
dev, dt = x.device, x.dtype
|
||||
first = experts_module[0]
|
||||
|
||||
|
||||
@@ -100,6 +100,36 @@ def test_grouped_uses_per_expert_nested_modules(monkeypatch):
|
||||
assert not torch.equal(first_call[0], first_call[1])
|
||||
|
||||
|
||||
def test_grouped_accepts_module_list_experts(monkeypatch):
|
||||
hidden = 4
|
||||
intermediate = 2
|
||||
experts = nn.ModuleList(
|
||||
[DummyQwenExpert(i, hidden, intermediate) for i in range(2)]
|
||||
)
|
||||
gate = nn.Linear(hidden, len(experts), bias=False)
|
||||
nn.init.zeros_(gate.weight)
|
||||
|
||||
calls = {"count": 0}
|
||||
|
||||
def fake_grouped_mm(As, Bs):
|
||||
calls["count"] += 1
|
||||
return [
|
||||
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
|
||||
for a, b in zip(As, Bs, strict=False)
|
||||
]
|
||||
|
||||
monkeypatch.setattr(torch_grouped_module, "_call_grouped_mm", fake_grouped_mm)
|
||||
|
||||
hidden_states = torch.randn(1, 2, hidden)
|
||||
y, router_logits = torch_grouped_module.moe_ffn_forward_grouped(
|
||||
hidden_states, gate, experts, top_k=2
|
||||
)
|
||||
|
||||
assert y is not None
|
||||
assert router_logits is not None
|
||||
assert calls["count"] > 0
|
||||
|
||||
|
||||
class _DummyCfg:
|
||||
moe_backend = "torch_grouped"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user