From 37acb28d02add5000e7bcf088d74c8ecfa72a330 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 20 Apr 2026 23:09:47 +0000
Subject: [PATCH] fix einsum dims

---
 tests/utils/lora/test_merge_lora.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/utils/lora/test_merge_lora.py b/tests/utils/lora/test_merge_lora.py
index b66ee8bf4..ea20c391f 100644
--- a/tests/utils/lora/test_merge_lora.py
+++ b/tests/utils/lora/test_merge_lora.py
@@ -491,7 +491,8 @@ class TestEfficientMerge:
         out_features = 4
         alpha = 4
 
-        base = torch.randn(num_experts, in_features, out_features)
+        # PEFT ParamWrapper treats non-transposed 3D weights as (experts, out, in)
+        base = torch.randn(num_experts, out_features, in_features)
         lora_a = torch.randn(r * num_experts, in_features)
         lora_b = torch.randn(out_features, r * num_experts)
 
@@ -507,7 +508,7 @@ class TestEfficientMerge:
         scale = alpha / r
         wa = lora_a.reshape(num_experts, r, in_features)
         wb = lora_b.reshape(out_features, r, num_experts)
-        manual_delta = torch.einsum("o r e, e r i -> e i o", wb, wa) * scale
+        manual_delta = torch.einsum("o r e, e r i -> e o i", wb, wa) * scale
         for e in range(num_experts):
             assert torch.allclose(merged[e], base[e] + manual_delta[e], atol=1e-5), (
                 f"Expert {e} mismatch"
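
Below is a minimal standalone sketch of why "e o i" is the right output order: reshaping lora_a to (experts, r, in) and lora_b to (out, r, experts), then contracting over r, reproduces a per-expert B @ A whose rows are out_features, matching the (experts, out, in) layout the base weights now use. Note that num_experts, r, and in_features are defined above the visible hunks, so the values here are assumptions for illustration only.

    import torch

    # Assumed values; only out_features = 4 is visible in the patch context.
    num_experts, r, in_features, out_features = 2, 2, 8, 4

    wa = torch.randn(num_experts, r, in_features)   # per-expert LoRA A: (e, r, i)
    wb = torch.randn(out_features, r, num_experts)  # LoRA B, expert on last dim: (o, r, e)

    # Contract over r for each expert: delta[e] = B_e @ A_e,
    # with B_e = wb[:, :, e] of shape (o, r) and A_e = wa[e] of shape (r, i).
    delta = torch.einsum("o r e, e r i -> e o i", wb, wa)

    for e in range(num_experts):
        ref = wb[:, :, e] @ wa[e]  # (out_features, in_features)
        assert torch.allclose(delta[e], ref, atol=1e-5)
    print("einsum matches per-expert B @ A in (experts, out, in) layout")

The old "e i o" order produced the transpose of each expert's delta, which only matched the old (experts, in, out) base shape; both sides of the allclose check had to change together.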