fixes for scattermoe from latest peft upgrade

Wing Lian
2026-04-21 08:00:16 -04:00
parent 4195605ab2
commit 02e4f2350d
5 changed files with 119 additions and 142 deletions


@@ -2,17 +2,35 @@
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
from . import layers
from .lora_ops import ParallelExperts
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
from .lora_layout import (
peft_down_proj_lora_to_scattermoe,
peft_lora_B_to_scattermoe,
peft_lora_to_scattermoe,
validate_scattermoe_lora_shapes,
)
__all__ = [
"layers",
"ParallelExperts",
"flatten_sort_count",
"parallel_linear",
"ScatterMoELoRA",
"parallel_linear_lora",
"lora_ops",
"peft_down_proj_lora_to_scattermoe",
"peft_lora_B_to_scattermoe",
"peft_lora_to_scattermoe",
"validate_scattermoe_lora_shapes",
]
try:
from . import layers
from .lora_ops import ParallelExperts
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
except ModuleNotFoundError as exc:
if exc.name != "triton":
raise
else:
__all__ += [
"layers",
"ParallelExperts",
"flatten_sort_count",
"parallel_linear",
"ScatterMoELoRA",
"parallel_linear_lora",
"lora_ops",
]
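The reworked __init__ above exports the lora_layout helpers unconditionally and only extends __all__ with the triton-backed symbols once their import succeeds. A minimal usage sketch for downstream code; the helper name is hypothetical, and only the module path is taken from the test imports further down in this diff:

# Sketch: probing for the optional triton-backed exports of the package.
import importlib

def scattermoe_kernels_available(
    module_path="axolotl.integrations.kernels.libs.scattermoe_lora",
):
    try:
        pkg = importlib.import_module(module_path)
    except ModuleNotFoundError:
        return False
    # __init__ only appends "parallel_linear_lora" to __all__ when the
    # triton import succeeded, so its presence signals kernel availability.
    return "parallel_linear_lora" in getattr(pkg, "__all__", ())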


@@ -35,81 +35,19 @@ import torch
from torch import nn
from torch.nn import functional as F
from .lora_layout import (
peft_down_proj_lora_to_scattermoe,
peft_lora_B_to_scattermoe,
peft_lora_to_scattermoe,
)
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
# =============================================================================
# LoRA layout conversion utilities (peft <-> scattermoe)
# =============================================================================
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
expert-major ``[N, r*E]``.
peft reshapes B to ``[out, r, E]`` (rank-major).
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
"""
N = peft_B.shape[0]
return (
peft_B.reshape(N, rank, num_experts)
.permute(0, 2, 1)
.contiguous()
.reshape(N, num_experts * rank)
)
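The rank-major vs expert-major column order is easiest to see on a tiny worked example. This sketch repeats the reshape/permute from the helper above and checks the resulting column mapping (column j*E + e in the peft layout becomes column e*r + j in the scattermoe layout):

# Worked example of the rank-major -> expert-major column reordering.
import torch

E, r, out = 2, 2, 3
peft_B = torch.arange(out * E * r, dtype=torch.float32).reshape(out, E * r)

# Same operation as peft_lora_B_to_scattermoe above.
smoe_B = (
    peft_B.reshape(out, r, E)  # peft views columns as (rank, expert)
    .permute(0, 2, 1)          # reorder to (expert, rank)
    .contiguous()
    .reshape(out, E * r)       # columns are now grouped per expert
)

for e in range(E):
    for j in range(r):
        assert torch.equal(smoe_B[:, e * r + j], peft_B[:, j * E + e])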
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Convert peft LoRA weights to scattermoe layout (with A<->B swap).
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the
parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with
``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles
are swapped relative to scattermoe's convention.
peft gives:
lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]``
scattermoe needs:
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
This function swaps A<->B and converts B from rank-major to expert-major.
Uses vectorized tensor operations (no Python loop over experts).
Works for **both** gate_up_proj and down_proj since the transposition
issue is the same for any parameter.
"""
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
dim1 = peft_A.shape[1] # peft in_features -> scattermoe N
dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K
# smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2]
# [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2]
smoe_A = (
peft_B_em.reshape(dim2, num_experts, rank)
.permute(1, 2, 0)
.contiguous()
.reshape(rank * num_experts, dim2)
)
# smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r]
# [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r]
smoe_B = (
peft_A.reshape(num_experts, rank, dim1)
.permute(2, 0, 1)
.contiguous()
.reshape(dim1, num_experts * rank)
)
return smoe_A, smoe_B
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
__all__ = [
"peft_down_proj_lora_to_scattermoe",
"peft_lora_B_to_scattermoe",
"peft_lora_to_scattermoe",
]
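The block above is deleted in this commit (the helpers move to lora_layout, and the A<->B swap itself is dropped for the new peft layout), but the removed docstring's argument is easy to check numerically. A self-contained sketch under the old layout it describes (peft lora_A [r*E, dim1], lora_B [dim2, r*E]), verifying that the swapped per-expert product reproduces peft's delta:

# Numerical check of the (removed) swap-based conversion, small dims only.
import torch

E, r, dim1, dim2 = 3, 2, 5, 4
peft_A = torch.randn(E * r, dim1)   # old peft lora_A
peft_B = torch.randn(dim2, E * r)   # old peft lora_B, rank-major columns

# Same operations as the removed peft_lora_to_scattermoe above.
B_em = peft_B.reshape(dim2, r, E).permute(0, 2, 1).reshape(dim2, E * r)
smoe_A = B_em.reshape(dim2, E, r).permute(1, 2, 0).reshape(E * r, dim2)
smoe_B = peft_A.reshape(E, r, dim1).permute(2, 0, 1).reshape(dim1, E * r)

# peft's per-expert delta, computed in its own rank-major layout.
A_r = peft_A.reshape(E, r, dim1)
B_r = peft_B.reshape(dim2, r, E)
delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r)  # [E, dim1, dim2]

for e in range(E):
    A_e = smoe_A[e * r : (e + 1) * r, :]  # [r, dim2]
    B_e = smoe_B[:, e * r : (e + 1) * r]  # [dim1, r]
    torch.testing.assert_close(B_e @ A_e, delta_peft[e])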
# =============================================================================
# ParamWrapper unwrapping
@@ -199,7 +137,7 @@ def _unwrap_experts_lora(experts_module):
if gup is not None:
num_experts = gup.shape[0]
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
# Extract gate_up_proj LoRA
gup_lora = None
gup_wrapper = wrappers.get("gate_up_proj")
if gup_wrapper is not None:
@@ -208,7 +146,7 @@ def _unwrap_experts_lora(experts_module):
rank = lora_A.shape[0] // num_experts
gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
# Extract down_proj LoRA (needs A<->B swap due to transposition)
# Extract down_proj LoRA
down_lora = None
down_wrapper = wrappers.get("down_proj")
if down_wrapper is not None:


@@ -34,6 +34,7 @@ from .kernels.lora_ops import (
scatter2scatter_lora,
scatter2scatter_lora_dX,
)
from .lora_layout import validate_scattermoe_lora_shapes
class ScatterMoELoRA(torch.autograd.Function):
@@ -422,11 +423,6 @@ def get_lora_params_from_wrapper(module) -> tuple:
return lora_A, lora_B, scaling
# =============================================================================
# Drop-in replacement for parallel_linear
# =============================================================================
def parallel_linear_lora(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
@@ -451,6 +447,7 @@ def parallel_linear_lora(
Otherwise falls back to standard scatter2scatter.
"""
if lora_A is not None and lora_B is not None:
validate_scattermoe_lora_shapes(expert_weights, lora_A, lora_B)
return ScatterMoELoRA.apply(
inputs,
expert_weights,
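The added validate_scattermoe_lora_shapes call guards the kernel dispatch against the swapped layout this commit removes; its implementation is not shown in this diff. A sketch of the contract it appears to enforce, inferred only from the new CPU test at the end of this commit (the function name below is hypothetical; the message prefix matches the test's regex):

# Hypothetical shape check; the real validate_scattermoe_lora_shapes in
# lora_layout.py may differ.
import torch

def _check_scattermoe_lora_shapes(expert_weights, lora_A, lora_B):
    E, K, N = expert_weights.shape  # scattermoe expert weights are [E, K, N]
    er = lora_A.shape[0]            # expert-major rows: E * r
    ok = (
        er % E == 0
        and lora_A.shape == (er, K)  # lora_A acts on the K (input) dim
        and lora_B.shape == (N, er)  # lora_B produces the N (output) dim
    )
    if not ok:
        raise ValueError(
            "Invalid ScatterMoE LoRA layout: expected lora_A [E*r, K] and "
            f"lora_B [N, E*r] for expert_weights {tuple(expert_weights.shape)}, "
            f"got lora_A {tuple(lora_A.shape)}, lora_B {tuple(lora_B.shape)}"
        )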


@@ -54,25 +54,7 @@ except (ImportError, ModuleNotFoundError):
)
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
K_inter, N_hidden = peft_B.shape[0], peft_A.shape[1]
smoe_A = torch.zeros(
rank * num_experts,
K_inter,
device=peft_A.device,
dtype=peft_A.dtype,
)
smoe_B = torch.zeros(
N_hidden,
rank * num_experts,
device=peft_A.device,
dtype=peft_A.dtype,
)
for e in range(num_experts):
s = e * rank
smoe_A[s : s + rank, :] = peft_B_em[:, s : s + rank].T
smoe_B[:, s : s + rank] = peft_A[s : s + rank, :].T
return smoe_A, smoe_B
return peft_A, peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
def _unwrap_experts_lora(experts_module):
return experts_module, None, None
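With the upgraded peft, lora_A already arrives expert-major, so the fallback above keeps A and only reorders B. The inverse helper scattermoe_lora_B_to_peft in the next hunk presumably applies the opposite axis swap; a round-trip sketch (the real helper may be implemented differently):

# Round-trip sketch for the B-column reordering; both directions are the same
# reshape/permute with the rank and expert axes exchanged.
import torch

def peft_B_to_smoe(peft_B, E, r):
    N = peft_B.shape[0]
    # peft column j*E + e (rank-major) -> scattermoe column e*r + j
    return peft_B.reshape(N, r, E).permute(0, 2, 1).reshape(N, E * r)

def smoe_B_to_peft(smoe_B, E, r):
    N = smoe_B.shape[0]
    # scattermoe column e*r + j -> back to peft column j*E + e
    return smoe_B.reshape(N, E, r).permute(0, 2, 1).reshape(N, E * r)

E, r, N = 3, 2, 7
peft_B = torch.randn(N, E * r)
assert torch.equal(smoe_B_to_peft(peft_B_to_smoe(peft_B, E, r), E, r), peft_B)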
@@ -145,11 +127,7 @@ def scattermoe_lora_B_to_peft(smoe_B, num_experts, rank):
def peft_gate_up_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Convert peft LoRA for gate_up_proj to scattermoe layout.
Both gate_up_proj and down_proj need the A<->B swap because
scattermoe transposes the parameter (W = param.T).
"""
"""Convert peft LoRA for gate_up_proj to scattermoe layout."""
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
@@ -322,14 +300,16 @@ class TestLoRABLayoutConversion:
hidden, inter = 32, 16
scaling = 2.0
peft_A = torch.randn(E * r, hidden)
peft_B = torch.randn(inter, E * r)
peft_A = torch.randn(E * r, inter)
peft_B = torch.randn(hidden, E * r)
A_r = peft_A.reshape(E, r, hidden)
B_r = peft_B.reshape(inter, r, E)
delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling
A_r = peft_A.reshape(E, r, inter)
B_r = peft_B.reshape(hidden, r, E)
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
assert smoe_A.shape == (E * r, inter)
assert smoe_B.shape == (hidden, E * r)
for e in range(E):
A_e = smoe_A[e * r : (e + 1) * r, :]
B_e = smoe_B[:, e * r : (e + 1) * r]
@@ -342,27 +322,26 @@ class TestLoRABLayoutConversion:
"""Verify gate_up_proj LoRA conversion with non-square dims (Qwen3-like).
gate_up_proj param: [E, 2*inter, hidden].
peft: in_features=2*inter, out_features=hidden.
peft lora_A: [r*E, 2*inter], lora_B: [hidden, r*E].
peft: in_features=hidden, out_features=2*inter.
peft lora_A: [r*E, hidden], lora_B: [2*inter, r*E].
scattermoe W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter.
scattermoe needs: lora_A [r*E, K=hidden], lora_B [N=2*inter, r*E].
Uses non-square dims (hidden=32 != 2*inter=24) to catch A<->B swap bugs.
Uses non-square dims (hidden=32 != 2*inter=24) to catch layout bugs.
"""
E, r = 4, 2
hidden, inter = 32, 12 # 2*inter=24 != hidden=32
scaling = 2.0
# peft assigns: in_features=2*inter, out_features=hidden
peft_A = torch.randn(E * r, 2 * inter) # [r*E, in_features=2*inter]
peft_B = torch.randn(hidden, E * r) # [out_features=hidden, r*E]
# peft assigns: in_features=hidden, out_features=2*inter
peft_A = torch.randn(E * r, hidden) # [r*E, in_features=hidden]
peft_B = torch.randn(2 * inter, E * r) # [out_features=2*inter, r*E]
# peft delta via einsum: "o r e, e r i -> e i o"
A_r = peft_A.reshape(E, r, 2 * inter)
B_r = peft_B.reshape(hidden, r, E)
delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling
# delta_peft[e] has shape [in_features, out_features] = [2*inter, hidden]
A_r = peft_A.reshape(E, r, hidden)
B_r = peft_B.reshape(2 * inter, r, E)
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
# delta_peft[e] has shape [out_features, in_features] = [2*inter, hidden]
# = param[e] shape [2*inter, hidden]
smoe_A, smoe_B = peft_gate_up_lora_to_scattermoe(peft_A, peft_B, E, r)
@@ -488,29 +467,29 @@ class TestPeftLoRAWeightExtraction:
assert gup_lora is not None, "gate_up_proj LoRA not detected"
assert down_lora is not None, "down_proj LoRA not detected"
# Check shapes (after peft->scattermoe conversion with A<->B swap)
# Check shapes after peft->scattermoe conversion.
# gate_up_proj: peft A [E*r, hidden] / B [2*inter, E*r]
# After swap: smoe_A [E*r, 2*inter], smoe_B [hidden, E*r]
# scattermoe: smoe_A [E*r, hidden], smoe_B [2*inter, E*r]
E, r = config.num_experts, 4
gup_A, gup_B, gup_s = gup_lora
assert gup_A.shape == (E * r, 2 * config.intermediate_size), (
f"gate_up_proj smoe_A: expected [r*E, 2*inter]={(E * r, 2 * config.intermediate_size)}, "
assert gup_A.shape == (E * r, config.hidden_size), (
f"gate_up_proj smoe_A: expected [r*E, hidden]={(E * r, config.hidden_size)}, "
f"got {gup_A.shape}"
)
assert gup_B.shape == (config.hidden_size, E * r), (
f"gate_up_proj smoe_B: expected [hidden, r*E]="
f"{(config.hidden_size, E * r)}, got {gup_B.shape}"
assert gup_B.shape == (2 * config.intermediate_size, E * r), (
f"gate_up_proj smoe_B: expected [2*inter, r*E]="
f"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}"
)
# down_proj: peft A [E*r, inter] / B [hidden, E*r]
# After swap: smoe_A [E*r, hidden], smoe_B [inter, E*r]
# scattermoe: smoe_A [E*r, inter], smoe_B [hidden, E*r]
down_A, down_B, down_s = down_lora
assert down_A.shape == (E * r, config.hidden_size), (
f"down_proj smoe_A: expected [r*E, hidden]={(E * r, config.hidden_size)}, "
assert down_A.shape == (E * r, config.intermediate_size), (
f"down_proj smoe_A: expected [r*E, inter]={(E * r, config.intermediate_size)}, "
f"got {down_A.shape}"
)
assert down_B.shape == (config.intermediate_size, E * r), (
f"down_proj smoe_B: expected [inter, r*E]={(config.intermediate_size, E * r)}, "
assert down_B.shape == (config.hidden_size, E * r), (
f"down_proj smoe_B: expected [hidden, r*E]={(config.hidden_size, E * r)}, "
f"got {down_B.shape}"
)


@@ -21,6 +21,51 @@ from unittest.mock import patch
import pytest
import torch
class TestPeftScatterMoELoRALayout:
"""CPU-only guards for PEFT target_parameters layout conversion."""
def test_peft_layout_keeps_a_and_reorders_b(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import (
peft_lora_to_scattermoe,
)
E, r, K, N = 3, 2, 5, 7
scaling = 2.0
peft_A = torch.randn(E * r, K)
peft_B = torch.randn(N, E * r)
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
assert smoe_A is peft_A
assert smoe_A.shape == (E * r, K)
assert smoe_B.shape == (N, E * r)
A_r = peft_A.reshape(E, r, K)
B_r = peft_B.reshape(N, r, E)
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
for e in range(E):
A_e = smoe_A[e * r : (e + 1) * r, :]
B_e = smoe_B[:, e * r : (e + 1) * r]
torch.testing.assert_close(scaling * (B_e @ A_e), delta_peft[e])
def test_swapped_layout_fails_before_kernel_dispatch(self):
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import (
validate_scattermoe_lora_shapes,
)
E, r, K, N = 3, 2, 5, 7
expert_weights = torch.empty(E, K, N)
with pytest.raises(ValueError, match="Invalid ScatterMoE LoRA layout"):
validate_scattermoe_lora_shapes(
expert_weights=expert_weights,
lora_A=torch.empty(E * r, N),
lora_B=torch.empty(K, E * r),
)
# ============================================================================
# 1. KernelsArgs: disable_mlp_kernel validator
# ============================================================================