This commit is contained in:
Dan Saunders
2025-09-17 14:26:25 -04:00
parent 51e565f60a
commit db61e0d4ff
2 changed files with 51 additions and 1 deletions

View File

@@ -100,6 +100,36 @@ def test_grouped_uses_per_expert_nested_modules(monkeypatch):
assert not torch.equal(first_call[0], first_call[1])
def test_grouped_accepts_module_list_experts(monkeypatch):
    """moe_ffn_forward_grouped should accept experts packaged as an nn.ModuleList.

    The grouped-matmul kernel is stubbed out so the test only verifies wiring:
    the forward pass completes, produces outputs, and routes through the stub.
    """
    hidden_dim = 4
    inter_dim = 2
    num_experts = 2

    expert_list = nn.ModuleList(
        DummyQwenExpert(idx, hidden_dim, inter_dim) for idx in range(num_experts)
    )
    router = nn.Linear(hidden_dim, num_experts, bias=False)
    # Zeroed router weights -> uniform logits, so routing is deterministic.
    nn.init.zeros_(router.weight)

    invocations = {"count": 0}

    def stub_grouped_mm(lhs_list, rhs_list):
        # Record the call and return correctly-shaped zero outputs.
        invocations["count"] += 1
        outputs = []
        for lhs, rhs in zip(lhs_list, rhs_list, strict=False):
            outputs.append(
                torch.zeros(
                    lhs.shape[0], rhs.shape[-1], device=lhs.device, dtype=lhs.dtype
                )
            )
        return outputs

    monkeypatch.setattr(torch_grouped_module, "_call_grouped_mm", stub_grouped_mm)

    inputs = torch.randn(1, 2, hidden_dim)
    out, logits = torch_grouped_module.moe_ffn_forward_grouped(
        inputs, router, expert_list, top_k=2
    )

    assert out is not None
    assert logits is not None
    assert invocations["count"] > 0
class _DummyCfg:
moe_backend = "torch_grouped"