fixes for scattermoe from latest peft upgrade
This commit is contained in:
@@ -54,25 +54,7 @@ except (ImportError, ModuleNotFoundError):
|
||||
)
|
||||
|
||||
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||
K_inter, N_hidden = peft_B.shape[0], peft_A.shape[1]
|
||||
smoe_A = torch.zeros(
|
||||
rank * num_experts,
|
||||
K_inter,
|
||||
device=peft_A.device,
|
||||
dtype=peft_A.dtype,
|
||||
)
|
||||
smoe_B = torch.zeros(
|
||||
N_hidden,
|
||||
rank * num_experts,
|
||||
device=peft_A.device,
|
||||
dtype=peft_A.dtype,
|
||||
)
|
||||
for e in range(num_experts):
|
||||
s = e * rank
|
||||
smoe_A[s : s + rank, :] = peft_B_em[:, s : s + rank].T
|
||||
smoe_B[:, s : s + rank] = peft_A[s : s + rank, :].T
|
||||
return smoe_A, smoe_B
|
||||
return peft_A, peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||
|
||||
def _unwrap_experts_lora(experts_module):
|
||||
return experts_module, None, None
|
||||
@@ -145,11 +127,7 @@ def scattermoe_lora_B_to_peft(smoe_B, num_experts, rank):
|
||||
|
||||
|
||||
def peft_gate_up_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||
"""Convert peft LoRA for gate_up_proj to scattermoe layout.
|
||||
|
||||
Both gate_up_proj and down_proj need the A<->B swap because
|
||||
scattermoe transposes the parameter (W = param.T).
|
||||
"""
|
||||
"""Convert peft LoRA for gate_up_proj to scattermoe layout."""
|
||||
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
|
||||
|
||||
|
||||
@@ -322,14 +300,16 @@ class TestLoRABLayoutConversion:
|
||||
hidden, inter = 32, 16
|
||||
scaling = 2.0
|
||||
|
||||
peft_A = torch.randn(E * r, hidden)
|
||||
peft_B = torch.randn(inter, E * r)
|
||||
peft_A = torch.randn(E * r, inter)
|
||||
peft_B = torch.randn(hidden, E * r)
|
||||
|
||||
A_r = peft_A.reshape(E, r, hidden)
|
||||
B_r = peft_B.reshape(inter, r, E)
|
||||
delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling
|
||||
A_r = peft_A.reshape(E, r, inter)
|
||||
B_r = peft_B.reshape(hidden, r, E)
|
||||
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
||||
|
||||
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
|
||||
assert smoe_A.shape == (E * r, inter)
|
||||
assert smoe_B.shape == (hidden, E * r)
|
||||
for e in range(E):
|
||||
A_e = smoe_A[e * r : (e + 1) * r, :]
|
||||
B_e = smoe_B[:, e * r : (e + 1) * r]
|
||||
@@ -342,27 +322,26 @@ class TestLoRABLayoutConversion:
|
||||
"""Verify gate_up_proj LoRA conversion with non-square dims (Qwen3-like).
|
||||
|
||||
gate_up_proj param: [E, 2*inter, hidden].
|
||||
peft: in_features=2*inter, out_features=hidden.
|
||||
peft lora_A: [r*E, 2*inter], lora_B: [hidden, r*E].
|
||||
peft: in_features=hidden, out_features=2*inter.
|
||||
peft lora_A: [r*E, hidden], lora_B: [2*inter, r*E].
|
||||
|
||||
scattermoe W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter.
|
||||
scattermoe needs: lora_A [r*E, K=hidden], lora_B [N=2*inter, r*E].
|
||||
|
||||
Uses non-square dims (hidden=32 != 2*inter=24) to catch A<->B swap bugs.
|
||||
Uses non-square dims (hidden=32 != 2*inter=24) to catch layout bugs.
|
||||
"""
|
||||
E, r = 4, 2
|
||||
hidden, inter = 32, 12 # 2*inter=24 != hidden=32
|
||||
scaling = 2.0
|
||||
|
||||
# peft assigns: in_features=2*inter, out_features=hidden
|
||||
peft_A = torch.randn(E * r, 2 * inter) # [r*E, in_features=2*inter]
|
||||
peft_B = torch.randn(hidden, E * r) # [out_features=hidden, r*E]
|
||||
# peft assigns: in_features=hidden, out_features=2*inter
|
||||
peft_A = torch.randn(E * r, hidden) # [r*E, in_features=hidden]
|
||||
peft_B = torch.randn(2 * inter, E * r) # [out_features=2*inter, r*E]
|
||||
|
||||
# peft delta via einsum: "o r e, e r i -> e i o"
|
||||
A_r = peft_A.reshape(E, r, 2 * inter)
|
||||
B_r = peft_B.reshape(hidden, r, E)
|
||||
delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling
|
||||
# delta_peft[e] has shape [in_features, out_features] = [2*inter, hidden]
|
||||
A_r = peft_A.reshape(E, r, hidden)
|
||||
B_r = peft_B.reshape(2 * inter, r, E)
|
||||
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
||||
# delta_peft[e] has shape [out_features, in_features] = [2*inter, hidden]
|
||||
# = param[e] shape [2*inter, hidden]
|
||||
|
||||
smoe_A, smoe_B = peft_gate_up_lora_to_scattermoe(peft_A, peft_B, E, r)
|
||||
@@ -488,29 +467,29 @@ class TestPeftLoRAWeightExtraction:
|
||||
assert gup_lora is not None, "gate_up_proj LoRA not detected"
|
||||
assert down_lora is not None, "down_proj LoRA not detected"
|
||||
|
||||
# Check shapes (after peft->scattermoe conversion with A<->B swap)
|
||||
# Check shapes after peft->scattermoe conversion.
|
||||
# gate_up_proj: peft A [E*r, hidden] / B [2*inter, E*r]
|
||||
# After swap: smoe_A [E*r, 2*inter], smoe_B [hidden, E*r]
|
||||
# scattermoe: smoe_A [E*r, hidden], smoe_B [2*inter, E*r]
|
||||
E, r = config.num_experts, 4
|
||||
gup_A, gup_B, gup_s = gup_lora
|
||||
assert gup_A.shape == (E * r, 2 * config.intermediate_size), (
|
||||
f"gate_up_proj smoe_A: expected [r*E, 2*inter]={(E * r, 2 * config.intermediate_size)}, "
|
||||
assert gup_A.shape == (E * r, config.hidden_size), (
|
||||
f"gate_up_proj smoe_A: expected [r*E, hidden]={(E * r, config.hidden_size)}, "
|
||||
f"got {gup_A.shape}"
|
||||
)
|
||||
assert gup_B.shape == (config.hidden_size, E * r), (
|
||||
f"gate_up_proj smoe_B: expected [hidden, r*E]="
|
||||
f"{(config.hidden_size, E * r)}, got {gup_B.shape}"
|
||||
assert gup_B.shape == (2 * config.intermediate_size, E * r), (
|
||||
f"gate_up_proj smoe_B: expected [2*inter, r*E]="
|
||||
f"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}"
|
||||
)
|
||||
|
||||
# down_proj: peft A [E*r, inter] / B [hidden, E*r]
|
||||
# After swap: smoe_A [E*r, hidden], smoe_B [inter, E*r]
|
||||
# scattermoe: smoe_A [E*r, inter], smoe_B [hidden, E*r]
|
||||
down_A, down_B, down_s = down_lora
|
||||
assert down_A.shape == (E * r, config.hidden_size), (
|
||||
f"down_proj smoe_A: expected [r*E, hidden]={(E * r, config.hidden_size)}, "
|
||||
assert down_A.shape == (E * r, config.intermediate_size), (
|
||||
f"down_proj smoe_A: expected [r*E, inter]={(E * r, config.intermediate_size)}, "
|
||||
f"got {down_A.shape}"
|
||||
)
|
||||
assert down_B.shape == (config.intermediate_size, E * r), (
|
||||
f"down_proj smoe_B: expected [inter, r*E]={(config.intermediate_size, E * r)}, "
|
||||
assert down_B.shape == (config.hidden_size, E * r), (
|
||||
f"down_proj smoe_B: expected [hidden, r*E]={(config.hidden_size, E * r)}, "
|
||||
f"got {down_B.shape}"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user