fixes for scattermoe from latest peft upgrade
This commit is contained in:
@@ -21,6 +21,51 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
class TestPeftScatterMoELoRALayout:
|
||||
"""CPU-only guards for PEFT target_parameters layout conversion."""
|
||||
|
||||
def test_peft_layout_keeps_a_and_reorders_b(self):
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import (
|
||||
peft_lora_to_scattermoe,
|
||||
)
|
||||
|
||||
E, r, K, N = 3, 2, 5, 7
|
||||
scaling = 2.0
|
||||
peft_A = torch.randn(E * r, K)
|
||||
peft_B = torch.randn(N, E * r)
|
||||
|
||||
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
|
||||
|
||||
assert smoe_A is peft_A
|
||||
assert smoe_A.shape == (E * r, K)
|
||||
assert smoe_B.shape == (N, E * r)
|
||||
|
||||
A_r = peft_A.reshape(E, r, K)
|
||||
B_r = peft_B.reshape(N, r, E)
|
||||
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
||||
|
||||
for e in range(E):
|
||||
A_e = smoe_A[e * r : (e + 1) * r, :]
|
||||
B_e = smoe_B[:, e * r : (e + 1) * r]
|
||||
torch.testing.assert_close(scaling * (B_e @ A_e), delta_peft[e])
|
||||
|
||||
def test_swapped_layout_fails_before_kernel_dispatch(self):
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import (
|
||||
validate_scattermoe_lora_shapes,
|
||||
)
|
||||
|
||||
E, r, K, N = 3, 2, 5, 7
|
||||
expert_weights = torch.empty(E, K, N)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid ScatterMoE LoRA layout"):
|
||||
validate_scattermoe_lora_shapes(
|
||||
expert_weights=expert_weights,
|
||||
lora_A=torch.empty(E * r, N),
|
||||
lora_B=torch.empty(K, E * r),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 1. KernelsArgs: disable_mlp_kernel validator
|
||||
# ============================================================================
|
||||
|
||||
Reference in New Issue
Block a user