diff --git a/tests/utils/lora/test_merge_lora.py b/tests/utils/lora/test_merge_lora.py index b66ee8bf4..ea20c391f 100644 --- a/tests/utils/lora/test_merge_lora.py +++ b/tests/utils/lora/test_merge_lora.py @@ -491,7 +491,8 @@ class TestEfficientMerge: out_features = 4 alpha = 4 - base = torch.randn(num_experts, in_features, out_features) + # PEFT ParamWrapper treats non-transposed 3D weights as (experts, out, in) + base = torch.randn(num_experts, out_features, in_features) lora_a = torch.randn(r * num_experts, in_features) lora_b = torch.randn(out_features, r * num_experts) @@ -507,7 +508,7 @@ class TestEfficientMerge: scale = alpha / r wa = lora_a.reshape(num_experts, r, in_features) wb = lora_b.reshape(out_features, r, num_experts) - manual_delta = torch.einsum("o r e, e r i -> e i o", wb, wa) * scale + manual_delta = torch.einsum("o r e, e r i -> e o i", wb, wa) * scale for e in range(num_experts): assert torch.allclose(merged[e], base[e] + manual_delta[e], atol=1e-5), ( f"Expert {e} mismatch"