fix einsum dims
@@ -491,7 +491,8 @@ class TestEfficientMerge:
         out_features = 4
         alpha = 4
 
-        base = torch.randn(num_experts, in_features, out_features)
+        # PEFT ParamWrapper treats non-transposed 3D weights as (experts, out, in)
+        base = torch.randn(num_experts, out_features, in_features)
         lora_a = torch.randn(r * num_experts, in_features)
         lora_b = torch.randn(out_features, r * num_experts)
 
@@ -507,7 +508,7 @@ class TestEfficientMerge:
         scale = alpha / r
         wa = lora_a.reshape(num_experts, r, in_features)
         wb = lora_b.reshape(out_features, r, num_experts)
-        manual_delta = torch.einsum("o r e, e r i -> e i o", wb, wa) * scale
+        manual_delta = torch.einsum("o r e, e r i -> e o i", wb, wa) * scale
         for e in range(num_experts):
             assert torch.allclose(merged[e], base[e] + manual_delta[e], atol=1e-5), (
                 f"Expert {e} mismatch"
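For reference, the corrected contraction can be sanity-checked in isolation. The sketch below uses plain PyTorch with made-up values for r and num_experts (neither appears in the diff) and does not reproduce the merge under test; it only confirms that "o r e, e r i -> e o i" yields a delta of shape (experts, out, in), matching the new base layout, and that it equals the per-expert product B_e @ A_e:

import torch

# Assumed sample values; out_features and alpha match the test,
# r and num_experts are hypothetical.
num_experts, in_features, out_features = 2, 3, 4
r, alpha = 2, 4
scale = alpha / r

# Flat LoRA factors laid out as in the test: A stacks experts along
# rows, B stacks experts along columns.
lora_a = torch.randn(r * num_experts, in_features)
lora_b = torch.randn(out_features, r * num_experts)

wa = lora_a.reshape(num_experts, r, in_features)   # (e, r, i)
wb = lora_b.reshape(out_features, r, num_experts)  # (o, r, e)

# Fixed contraction: (experts, out, in), same layout as the new base.
delta = torch.einsum("o r e, e r i -> e o i", wb, wa) * scale
assert delta.shape == (num_experts, out_features, in_features)

# Explicit per-expert check of the same product: B_e @ A_e, scaled.
for e in range(num_experts):
    ref = (wb[:, :, e] @ wa[e]) * scale            # (out, in)
    assert torch.allclose(delta[e], ref, atol=1e-6)

The old output subscripts "e i o" produced (experts, in, out), which could only match the old (experts, in, out) base; once base switched to (experts, out, in), the output had to become "e o i".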