From 02e4f2350dd3ebea8c4da5a682082874478aebb5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 21 Apr 2026 08:00:16 -0400 Subject: [PATCH] fixes for scattermoe from latest peft upgrade --- .../kernels/libs/scattermoe_lora/__init__.py | 40 ++++++--- .../kernels/libs/scattermoe_lora/layers.py | 86 +++---------------- .../scattermoe_lora/parallel_linear_lora.py | 7 +- .../test_scattermoe_lora_olmoe.py | 83 +++++++----------- tests/integrations/test_scattermoe_lora.py | 45 ++++++++++ 5 files changed, 119 insertions(+), 142 deletions(-) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/__init__.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/__init__.py index f5148634e..398779e3b 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/__init__.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/__init__.py @@ -2,17 +2,35 @@ # Copyright (c) Axolotl AI # Licensed under the Apache License, Version 2.0 -from . import layers -from .lora_ops import ParallelExperts -from .parallel_experts import flatten_sort_count, parallel_linear -from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora +from .lora_layout import ( + peft_down_proj_lora_to_scattermoe, + peft_lora_B_to_scattermoe, + peft_lora_to_scattermoe, + validate_scattermoe_lora_shapes, +) __all__ = [ - "layers", - "ParallelExperts", - "flatten_sort_count", - "parallel_linear", - "ScatterMoELoRA", - "parallel_linear_lora", - "lora_ops", + "peft_down_proj_lora_to_scattermoe", + "peft_lora_B_to_scattermoe", + "peft_lora_to_scattermoe", + "validate_scattermoe_lora_shapes", ] + +try: + from . import layers + from .lora_ops import ParallelExperts + from .parallel_experts import flatten_sort_count, parallel_linear + from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora +except ModuleNotFoundError as exc: + if exc.name != "triton": + raise +else: + __all__ += [ + "layers", + "ParallelExperts", + "flatten_sort_count", + "parallel_linear", + "ScatterMoELoRA", + "parallel_linear_lora", + "lora_ops", + ] diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py index c6c01e255..b2bd4f640 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py @@ -35,81 +35,19 @@ import torch from torch import nn from torch.nn import functional as F +from .lora_layout import ( + peft_down_proj_lora_to_scattermoe, + peft_lora_B_to_scattermoe, + peft_lora_to_scattermoe, +) from .parallel_experts import flatten_sort_count, parallel_linear from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora -# ============================================================================= -# LoRA layout conversion utilities (peft <-> scattermoe) -# ============================================================================= - - -def peft_lora_B_to_scattermoe(peft_B, num_experts, rank): - """Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe - expert-major ``[N, r*E]``. - - peft reshapes B to ``[out, r, E]`` (rank-major). - scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major). - """ - N = peft_B.shape[0] - return ( - peft_B.reshape(N, rank, num_experts) - .permute(0, 2, 1) - .contiguous() - .reshape(N, num_experts * rank) - ) - - -def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): - """Convert peft LoRA weights to scattermoe layout (with A<->B swap). 
- - peft operates on the parameter in its native storage layout ``[E, dim1, dim2]`` - where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the - parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with - ``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles - are swapped relative to scattermoe's convention. - - peft gives: - lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]`` - - scattermoe needs: - lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]`` - - This function swaps A<->B and converts B from rank-major to expert-major. - Uses vectorized tensor operations (no Python loop over experts). - - Works for **both** gate_up_proj and down_proj since the transposition - issue is the same for any parameter. - """ - peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank) - - dim1 = peft_A.shape[1] # peft in_features -> scattermoe N - dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K - - # smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2] - # [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2] - smoe_A = ( - peft_B_em.reshape(dim2, num_experts, rank) - .permute(1, 2, 0) - .contiguous() - .reshape(rank * num_experts, dim2) - ) - - # smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r] - # [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r] - smoe_B = ( - peft_A.reshape(num_experts, rank, dim1) - .permute(2, 0, 1) - .contiguous() - .reshape(dim1, num_experts * rank) - ) - - return smoe_A, smoe_B - - -def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): - """Deprecated alias for :func:`peft_lora_to_scattermoe`.""" - return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank) - +__all__ = [ + "peft_down_proj_lora_to_scattermoe", + "peft_lora_B_to_scattermoe", + "peft_lora_to_scattermoe", +] # ============================================================================= # ParamWrapper unwrapping @@ -199,7 +137,7 @@ def _unwrap_experts_lora(experts_module): if gup is not None: num_experts = gup.shape[0] - # Extract gate_up_proj LoRA (needs A<->B swap due to transposition) + # Extract gate_up_proj LoRA gup_lora = None gup_wrapper = wrappers.get("gate_up_proj") if gup_wrapper is not None: @@ -208,7 +146,7 @@ def _unwrap_experts_lora(experts_module): rank = lora_A.shape[0] // num_experts gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling) - # Extract down_proj LoRA (needs A<->B swap due to transposition) + # Extract down_proj LoRA down_lora = None down_wrapper = wrappers.get("down_proj") if down_wrapper is not None: diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py index 17dfd420c..71a31f86e 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py @@ -34,6 +34,7 @@ from .kernels.lora_ops import ( scatter2scatter_lora, scatter2scatter_lora_dX, ) +from .lora_layout import validate_scattermoe_lora_shapes class ScatterMoELoRA(torch.autograd.Function): @@ -422,11 +423,6 @@ def get_lora_params_from_wrapper(module) -> tuple: return lora_A, lora_B, scaling -# ============================================================================= -# Drop-in replacement for parallel_linear -# ============================================================================= - - def parallel_linear_lora( inputs: torch.Tensor, expert_weights: torch.Tensor, @@ 
-451,6 +447,7 @@ def parallel_linear_lora( Otherwise falls back to standard scatter2scatter. """ if lora_A is not None and lora_B is not None: + validate_scattermoe_lora_shapes(expert_weights, lora_A, lora_B) return ScatterMoELoRA.apply( inputs, expert_weights, diff --git a/tests/e2e/integrations/test_scattermoe_lora_olmoe.py b/tests/e2e/integrations/test_scattermoe_lora_olmoe.py index f9376b35f..44260eb9e 100644 --- a/tests/e2e/integrations/test_scattermoe_lora_olmoe.py +++ b/tests/e2e/integrations/test_scattermoe_lora_olmoe.py @@ -54,25 +54,7 @@ except (ImportError, ModuleNotFoundError): ) def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): - peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank) - K_inter, N_hidden = peft_B.shape[0], peft_A.shape[1] - smoe_A = torch.zeros( - rank * num_experts, - K_inter, - device=peft_A.device, - dtype=peft_A.dtype, - ) - smoe_B = torch.zeros( - N_hidden, - rank * num_experts, - device=peft_A.device, - dtype=peft_A.dtype, - ) - for e in range(num_experts): - s = e * rank - smoe_A[s : s + rank, :] = peft_B_em[:, s : s + rank].T - smoe_B[:, s : s + rank] = peft_A[s : s + rank, :].T - return smoe_A, smoe_B + return peft_A, peft_lora_B_to_scattermoe(peft_B, num_experts, rank) def _unwrap_experts_lora(experts_module): return experts_module, None, None @@ -145,11 +127,7 @@ def scattermoe_lora_B_to_peft(smoe_B, num_experts, rank): def peft_gate_up_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): - """Convert peft LoRA for gate_up_proj to scattermoe layout. - - Both gate_up_proj and down_proj need the A<->B swap because - scattermoe transposes the parameter (W = param.T). - """ + """Convert peft LoRA for gate_up_proj to scattermoe layout.""" return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank) @@ -322,14 +300,16 @@ class TestLoRABLayoutConversion: hidden, inter = 32, 16 scaling = 2.0 - peft_A = torch.randn(E * r, hidden) - peft_B = torch.randn(inter, E * r) + peft_A = torch.randn(E * r, inter) + peft_B = torch.randn(hidden, E * r) - A_r = peft_A.reshape(E, r, hidden) - B_r = peft_B.reshape(inter, r, E) - delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling + A_r = peft_A.reshape(E, r, inter) + B_r = peft_B.reshape(hidden, r, E) + delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r) + assert smoe_A.shape == (E * r, inter) + assert smoe_B.shape == (hidden, E * r) for e in range(E): A_e = smoe_A[e * r : (e + 1) * r, :] B_e = smoe_B[:, e * r : (e + 1) * r] @@ -342,27 +322,26 @@ class TestLoRABLayoutConversion: """Verify gate_up_proj LoRA conversion with non-square dims (Qwen3-like). gate_up_proj param: [E, 2*inter, hidden]. - peft: in_features=2*inter, out_features=hidden. - peft lora_A: [r*E, 2*inter], lora_B: [hidden, r*E]. + peft: in_features=hidden, out_features=2*inter. + peft lora_A: [r*E, hidden], lora_B: [2*inter, r*E]. scattermoe W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter. scattermoe needs: lora_A [r*E, K=hidden], lora_B [N=2*inter, r*E]. - Uses non-square dims (hidden=32 != 2*inter=24) to catch A<->B swap bugs. + Uses non-square dims (hidden=32 != 2*inter=24) to catch layout bugs. 
""" E, r = 4, 2 hidden, inter = 32, 12 # 2*inter=24 != hidden=32 scaling = 2.0 - # peft assigns: in_features=2*inter, out_features=hidden - peft_A = torch.randn(E * r, 2 * inter) # [r*E, in_features=2*inter] - peft_B = torch.randn(hidden, E * r) # [out_features=hidden, r*E] + # peft assigns: in_features=hidden, out_features=2*inter + peft_A = torch.randn(E * r, hidden) # [r*E, in_features=hidden] + peft_B = torch.randn(2 * inter, E * r) # [out_features=2*inter, r*E] - # peft delta via einsum: "o r e, e r i -> e i o" - A_r = peft_A.reshape(E, r, 2 * inter) - B_r = peft_B.reshape(hidden, r, E) - delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling - # delta_peft[e] has shape [in_features, out_features] = [2*inter, hidden] + A_r = peft_A.reshape(E, r, hidden) + B_r = peft_B.reshape(2 * inter, r, E) + delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling + # delta_peft[e] has shape [out_features, in_features] = [2*inter, hidden] # = param[e] shape [2*inter, hidden] smoe_A, smoe_B = peft_gate_up_lora_to_scattermoe(peft_A, peft_B, E, r) @@ -488,29 +467,29 @@ class TestPeftLoRAWeightExtraction: assert gup_lora is not None, "gate_up_proj LoRA not detected" assert down_lora is not None, "down_proj LoRA not detected" - # Check shapes (after peft->scattermoe conversion with A<->B swap) + # Check shapes after peft->scattermoe conversion. # gate_up_proj: peft A [E*r, hidden] / B [2*inter, E*r] - # After swap: smoe_A [E*r, 2*inter], smoe_B [hidden, E*r] + # scattermoe: smoe_A [E*r, hidden], smoe_B [2*inter, E*r] E, r = config.num_experts, 4 gup_A, gup_B, gup_s = gup_lora - assert gup_A.shape == (E * r, 2 * config.intermediate_size), ( - f"gate_up_proj smoe_A: expected [r*E, 2*inter]={(E * r, 2 * config.intermediate_size)}, " + assert gup_A.shape == (E * r, config.hidden_size), ( + f"gate_up_proj smoe_A: expected [r*E, hidden]={(E * r, config.hidden_size)}, " f"got {gup_A.shape}" ) - assert gup_B.shape == (config.hidden_size, E * r), ( - f"gate_up_proj smoe_B: expected [hidden, r*E]=" - f"{(config.hidden_size, E * r)}, got {gup_B.shape}" + assert gup_B.shape == (2 * config.intermediate_size, E * r), ( + f"gate_up_proj smoe_B: expected [2*inter, r*E]=" + f"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}" ) # down_proj: peft A [E*r, inter] / B [hidden, E*r] - # After swap: smoe_A [E*r, hidden], smoe_B [inter, E*r] + # scattermoe: smoe_A [E*r, inter], smoe_B [hidden, E*r] down_A, down_B, down_s = down_lora - assert down_A.shape == (E * r, config.hidden_size), ( - f"down_proj smoe_A: expected [r*E, hidden]={(E * r, config.hidden_size)}, " + assert down_A.shape == (E * r, config.intermediate_size), ( + f"down_proj smoe_A: expected [r*E, inter]={(E * r, config.intermediate_size)}, " f"got {down_A.shape}" ) - assert down_B.shape == (config.intermediate_size, E * r), ( - f"down_proj smoe_B: expected [inter, r*E]={(config.intermediate_size, E * r)}, " + assert down_B.shape == (config.hidden_size, E * r), ( + f"down_proj smoe_B: expected [hidden, r*E]={(config.hidden_size, E * r)}, " f"got {down_B.shape}" ) diff --git a/tests/integrations/test_scattermoe_lora.py b/tests/integrations/test_scattermoe_lora.py index bd50d06fe..ca095182a 100644 --- a/tests/integrations/test_scattermoe_lora.py +++ b/tests/integrations/test_scattermoe_lora.py @@ -21,6 +21,51 @@ from unittest.mock import patch import pytest import torch + +class TestPeftScatterMoELoRALayout: + """CPU-only guards for PEFT target_parameters layout conversion.""" + + def 
test_peft_layout_keeps_a_and_reorders_b(self): + from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import ( + peft_lora_to_scattermoe, + ) + + E, r, K, N = 3, 2, 5, 7 + scaling = 2.0 + peft_A = torch.randn(E * r, K) + peft_B = torch.randn(N, E * r) + + smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r) + + assert smoe_A is peft_A + assert smoe_A.shape == (E * r, K) + assert smoe_B.shape == (N, E * r) + + A_r = peft_A.reshape(E, r, K) + B_r = peft_B.reshape(N, r, E) + delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling + + for e in range(E): + A_e = smoe_A[e * r : (e + 1) * r, :] + B_e = smoe_B[:, e * r : (e + 1) * r] + torch.testing.assert_close(scaling * (B_e @ A_e), delta_peft[e]) + + def test_swapped_layout_fails_before_kernel_dispatch(self): + from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import ( + validate_scattermoe_lora_shapes, + ) + + E, r, K, N = 3, 2, 5, 7 + expert_weights = torch.empty(E, K, N) + + with pytest.raises(ValueError, match="Invalid ScatterMoE LoRA layout"): + validate_scattermoe_lora_shapes( + expert_weights=expert_weights, + lora_A=torch.empty(E * r, N), + lora_B=torch.empty(K, E * r), + ) + + # ============================================================================ # 1. KernelsArgs: disable_mlp_kernel validator # ============================================================================
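
A standalone sketch (not part of the patch, and not the library code itself) illustrating the lora_B re-layout that the patch's `peft_lora_B_to_scattermoe` performs: peft stores `lora_B` for `target_parameters` rank-major, `[out, E*r]` grouped as `[out, r, E]`, while scattermoe slices expert-major, `B[:, e*r:(e+1)*r]`. The helper name and the toy sizes below are illustrative assumptions; only the reshape/permute pattern mirrors the patch.

import torch


def lora_B_rank_major_to_expert_major(peft_B, num_experts, rank):
    # [out, E*r] rank-major -> [out, r, E] -> [out, E, r] -> [out, E*r] expert-major
    out_features = peft_B.shape[0]
    return (
        peft_B.reshape(out_features, rank, num_experts)
        .permute(0, 2, 1)
        .contiguous()
        .reshape(out_features, num_experts * rank)
    )


E, r, out_features = 3, 2, 5
peft_B = torch.randn(out_features, E * r)
smoe_B = lora_B_rank_major_to_expert_major(peft_B, E, r)

# In the rank-major layout, expert e's columns are e, e+E, e+2E, ...;
# after conversion they form the contiguous block e*r:(e+1)*r that the
# scattermoe kernels index per expert.
for e in range(E):
    torch.testing.assert_close(smoe_B[:, e * r : (e + 1) * r], peft_B[:, e::E])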