fixes for scattermoe from latest peft upgrade

This commit is contained in:
Wing Lian
2026-04-21 08:00:16 -04:00
parent 4195605ab2
commit 02e4f2350d
5 changed files with 119 additions and 142 deletions

View File

@@ -21,6 +21,51 @@ from unittest.mock import patch
import pytest
import torch
class TestPeftScatterMoELoRALayout:
    """CPU-only checks for the PEFT -> ScatterMoE LoRA layout helpers."""

    def test_peft_layout_keeps_a_and_reorders_b(self):
        """Converted factors must reproduce the per-expert delta of the PEFT layout.

        The reference delta is built by viewing ``peft_A`` as ``(E, r, K)`` and
        ``peft_B`` as ``(N, r, E)``; each expert's slice of the converted
        factors has to multiply out to the same matrix.
        """
        from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import (
            peft_lora_to_scattermoe,
        )

        E, r, K, N = 3, 2, 5, 7
        scaling = 2.0
        # Draw A before B so the RNG stream is consumed in a fixed order.
        peft_A = torch.randn(E * r, K)
        peft_B = torch.randn(N, E * r)

        smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)

        # A must come back as the very same tensor object, not a copy.
        assert smoe_A is peft_A
        assert smoe_A.shape == (E * r, K)
        assert smoe_B.shape == (N, E * r)

        # Reference per-expert deltas, computed straight from the PEFT tensors.
        a_view = peft_A.reshape(E, r, K)
        b_view = peft_B.reshape(N, r, E)
        reference = torch.einsum("o r e, e r i -> e o i", b_view, a_view) * scaling

        for expert, expected_delta in enumerate(reference):
            start = expert * r
            stop = (expert + 1) * r
            a_block = smoe_A[start:stop, :]
            b_block = smoe_B[:, start:stop]
            torch.testing.assert_close(scaling * (b_block @ a_block), expected_delta)

    def test_swapped_layout_fails_before_kernel_dispatch(self):
        """Factors whose K/N dimensions are swapped must be rejected with ValueError."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import (
            validate_scattermoe_lora_shapes,
        )

        E, r, K, N = 3, 2, 5, 7
        expert_weights = torch.empty(E, K, N)
        # Deliberately swapped relative to the (E * r, K) / (N, E * r) shapes
        # used in the layout test above.
        swapped_A = torch.empty(E * r, N)
        swapped_B = torch.empty(K, E * r)

        with pytest.raises(ValueError, match="Invalid ScatterMoE LoRA layout"):
            validate_scattermoe_lora_shapes(
                expert_weights=expert_weights,
                lora_A=swapped_A,
                lora_B=swapped_B,
            )
# ============================================================================
# 1. KernelsArgs: disable_mlp_kernel validator
# ============================================================================