fixes for scattermoe from latest peft upgrade
This commit is contained in:
@@ -2,17 +2,35 @@
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
from . import layers
|
||||
from .lora_ops import ParallelExperts
|
||||
from .parallel_experts import flatten_sort_count, parallel_linear
|
||||
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
|
||||
from .lora_layout import (
|
||||
peft_down_proj_lora_to_scattermoe,
|
||||
peft_lora_B_to_scattermoe,
|
||||
peft_lora_to_scattermoe,
|
||||
validate_scattermoe_lora_shapes,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"layers",
|
||||
"ParallelExperts",
|
||||
"flatten_sort_count",
|
||||
"parallel_linear",
|
||||
"ScatterMoELoRA",
|
||||
"parallel_linear_lora",
|
||||
"lora_ops",
|
||||
"peft_down_proj_lora_to_scattermoe",
|
||||
"peft_lora_B_to_scattermoe",
|
||||
"peft_lora_to_scattermoe",
|
||||
"validate_scattermoe_lora_shapes",
|
||||
]
|
||||
|
||||
try:
|
||||
from . import layers
|
||||
from .lora_ops import ParallelExperts
|
||||
from .parallel_experts import flatten_sort_count, parallel_linear
|
||||
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
|
||||
except ModuleNotFoundError as exc:
|
||||
if exc.name != "triton":
|
||||
raise
|
||||
else:
|
||||
__all__ += [
|
||||
"layers",
|
||||
"ParallelExperts",
|
||||
"flatten_sort_count",
|
||||
"parallel_linear",
|
||||
"ScatterMoELoRA",
|
||||
"parallel_linear_lora",
|
||||
"lora_ops",
|
||||
]
|
||||
|
||||
@@ -35,81 +35,19 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .lora_layout import (
|
||||
peft_down_proj_lora_to_scattermoe,
|
||||
peft_lora_B_to_scattermoe,
|
||||
peft_lora_to_scattermoe,
|
||||
)
|
||||
from .parallel_experts import flatten_sort_count, parallel_linear
|
||||
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
|
||||
|
||||
# =============================================================================
|
||||
# LoRA layout conversion utilities (peft <-> scattermoe)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
|
||||
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
|
||||
expert-major ``[N, r*E]``.
|
||||
|
||||
peft reshapes B to ``[out, r, E]`` (rank-major).
|
||||
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
|
||||
"""
|
||||
N = peft_B.shape[0]
|
||||
return (
|
||||
peft_B.reshape(N, rank, num_experts)
|
||||
.permute(0, 2, 1)
|
||||
.contiguous()
|
||||
.reshape(N, num_experts * rank)
|
||||
)
|
||||
|
||||
|
||||
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||
"""Convert peft LoRA weights to scattermoe layout (with A<->B swap).
|
||||
|
||||
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
|
||||
where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the
|
||||
parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with
|
||||
``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles
|
||||
are swapped relative to scattermoe's convention.
|
||||
|
||||
peft gives:
|
||||
lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]``
|
||||
|
||||
scattermoe needs:
|
||||
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
|
||||
|
||||
This function swaps A<->B and converts B from rank-major to expert-major.
|
||||
Uses vectorized tensor operations (no Python loop over experts).
|
||||
|
||||
Works for **both** gate_up_proj and down_proj since the transposition
|
||||
issue is the same for any parameter.
|
||||
"""
|
||||
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||
|
||||
dim1 = peft_A.shape[1] # peft in_features -> scattermoe N
|
||||
dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K
|
||||
|
||||
# smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2]
|
||||
# [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2]
|
||||
smoe_A = (
|
||||
peft_B_em.reshape(dim2, num_experts, rank)
|
||||
.permute(1, 2, 0)
|
||||
.contiguous()
|
||||
.reshape(rank * num_experts, dim2)
|
||||
)
|
||||
|
||||
# smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r]
|
||||
# [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r]
|
||||
smoe_B = (
|
||||
peft_A.reshape(num_experts, rank, dim1)
|
||||
.permute(2, 0, 1)
|
||||
.contiguous()
|
||||
.reshape(dim1, num_experts * rank)
|
||||
)
|
||||
|
||||
return smoe_A, smoe_B
|
||||
|
||||
|
||||
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
|
||||
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
|
||||
|
||||
__all__ = [
|
||||
"peft_down_proj_lora_to_scattermoe",
|
||||
"peft_lora_B_to_scattermoe",
|
||||
"peft_lora_to_scattermoe",
|
||||
]
|
||||
|
||||
# =============================================================================
|
||||
# ParamWrapper unwrapping
|
||||
@@ -199,7 +137,7 @@ def _unwrap_experts_lora(experts_module):
|
||||
if gup is not None:
|
||||
num_experts = gup.shape[0]
|
||||
|
||||
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
|
||||
# Extract gate_up_proj LoRA
|
||||
gup_lora = None
|
||||
gup_wrapper = wrappers.get("gate_up_proj")
|
||||
if gup_wrapper is not None:
|
||||
@@ -208,7 +146,7 @@ def _unwrap_experts_lora(experts_module):
|
||||
rank = lora_A.shape[0] // num_experts
|
||||
gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
|
||||
|
||||
# Extract down_proj LoRA (needs A<->B swap due to transposition)
|
||||
# Extract down_proj LoRA
|
||||
down_lora = None
|
||||
down_wrapper = wrappers.get("down_proj")
|
||||
if down_wrapper is not None:
|
||||
|
||||
@@ -34,6 +34,7 @@ from .kernels.lora_ops import (
|
||||
scatter2scatter_lora,
|
||||
scatter2scatter_lora_dX,
|
||||
)
|
||||
from .lora_layout import validate_scattermoe_lora_shapes
|
||||
|
||||
|
||||
class ScatterMoELoRA(torch.autograd.Function):
|
||||
@@ -422,11 +423,6 @@ def get_lora_params_from_wrapper(module) -> tuple:
|
||||
return lora_A, lora_B, scaling
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Drop-in replacement for parallel_linear
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def parallel_linear_lora(
|
||||
inputs: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
@@ -451,6 +447,7 @@ def parallel_linear_lora(
|
||||
Otherwise falls back to standard scatter2scatter.
|
||||
"""
|
||||
if lora_A is not None and lora_B is not None:
|
||||
validate_scattermoe_lora_shapes(expert_weights, lora_A, lora_B)
|
||||
return ScatterMoELoRA.apply(
|
||||
inputs,
|
||||
expert_weights,
|
||||
|
||||
Reference in New Issue
Block a user