This commit is contained in:
Dan Saunders
2025-09-17 16:42:35 -04:00
parent 129db67705
commit fd87eed501
5 changed files with 91 additions and 297 deletions

View File

@@ -75,7 +75,7 @@ def test_grouped_uses_per_expert_nested_modules(monkeypatch):
captured = []
def fake_grouped_mm(As, Bs):
def fake_grouped_mm(As, Bs, dtype):
captured.append([b.detach().clone() for b in Bs])
return [
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
@@ -111,7 +111,7 @@ def test_grouped_accepts_module_list_experts(monkeypatch):
calls = {"count": 0}
def fake_grouped_mm(As, Bs):
def fake_grouped_mm(As, Bs, dtype):
calls["count"] += 1
return [
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)