fix einsum dims

This commit is contained in:
Wing Lian
2026-04-20 23:09:47 +00:00
parent 4a5281e61a
commit 37acb28d02

View File

@@ -491,7 +491,8 @@ class TestEfficientMerge:
out_features = 4
alpha = 4
- base = torch.randn(num_experts, in_features, out_features)
+ # PEFT ParamWrapper treats non-transposed 3D weights as (experts, out, in)
+ base = torch.randn(num_experts, out_features, in_features)
lora_a = torch.randn(r * num_experts, in_features)
lora_b = torch.randn(out_features, r * num_experts)
@@ -507,7 +508,7 @@ class TestEfficientMerge:
scale = alpha / r
wa = lora_a.reshape(num_experts, r, in_features)
wb = lora_b.reshape(out_features, r, num_experts)
- manual_delta = torch.einsum("o r e, e r i -> e i o", wb, wa) * scale
+ manual_delta = torch.einsum("o r e, e r i -> e o i", wb, wa) * scale
for e in range(num_experts):
assert torch.allclose(merged[e], base[e] + manual_delta[e], atol=1e-5), (
f"Expert {e} mismatch"