minify
This commit is contained in:
@@ -75,7 +75,7 @@ def test_grouped_uses_per_expert_nested_modules(monkeypatch):
|
||||
|
||||
captured = []
|
||||
|
||||
def fake_grouped_mm(As, Bs):
|
||||
def fake_grouped_mm(As, Bs, dtype):
|
||||
captured.append([b.detach().clone() for b in Bs])
|
||||
return [
|
||||
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
|
||||
@@ -111,7 +111,7 @@ def test_grouped_accepts_module_list_experts(monkeypatch):
|
||||
|
||||
calls = {"count": 0}
|
||||
|
||||
def fake_grouped_mm(As, Bs):
|
||||
def fake_grouped_mm(As, Bs, dtype):
|
||||
calls["count"] += 1
|
||||
return [
|
||||
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
|
||||
|
||||
Reference in New Issue
Block a user