From a36aaa70cee298d524b310429194df1f8621ff2e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 7 Mar 2026 00:00:48 -0500 Subject: [PATCH] add gpu tests for scattermoe (#3474) [skip ci] --- .../test_scattermoe_lora_kernels.py | 1478 +++++++++++++++++ .../test_scattermoe_lora_olmoe.py | 1255 ++++++++++++++ 2 files changed, 2733 insertions(+) create mode 100644 tests/e2e/integrations/test_scattermoe_lora_kernels.py create mode 100644 tests/e2e/integrations/test_scattermoe_lora_olmoe.py diff --git a/tests/e2e/integrations/test_scattermoe_lora_kernels.py b/tests/e2e/integrations/test_scattermoe_lora_kernels.py new file mode 100644 index 000000000..d11272c8f --- /dev/null +++ b/tests/e2e/integrations/test_scattermoe_lora_kernels.py @@ -0,0 +1,1478 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +""" +Tests for ScatterMoE + LoRA Fused Kernels +========================================== + +Tests verify correctness of: +1. Forward pass: fused kernel matches naive PyTorch reference +2. Backward pass: gradients for LoRA A, B, and input match reference +3. Frozen weights: expert weight gradients are correctly skipped +4. Various configurations: top-k, grouped_in/out, with/without bias +5. 
def flatten_sort_count_ref(expert_idxs: torch.Tensor, num_experts: int):
    """Reference routing: sort flattened expert assignments and count them.

    Args:
        expert_idxs: integer tensor of expert ids. It is flattened before
            use, so any shape is accepted.
        num_experts: minimum length of the per-expert count vector (passed
            to ``bincount`` as ``minlength``).

    Returns:
        A ``(sorted_expert_idxs, sorted_scattered_idxs, offsets)`` tuple:
        the flattened ids in ascending order, the flat position each sorted
        entry came from, and the inclusive cumulative sum of the per-expert
        counts (``offsets[e]`` is the end of expert ``e``'s slice).
    """
    with torch.no_grad():
        flat = expert_idxs.flatten()
        # stable=True pins the relative order of entries that share an expert
        # id (ascending flat position). The default sort leaves tie order
        # unspecified, which makes the reference permutation irreproducible.
        sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flat, stable=True)
        offsets = torch.bincount(flat, minlength=num_experts).cumsum(-1)
    return sorted_expert_idxs, sorted_scattered_idxs, offsets
def reference_lora_backward(
    grad_out,
    X,
    W,
    lora_A,
    lora_B,
    scaling,
    sorted_expert_idxs,
    sorted_scattered_idxs,
    expert_offsets,
    k,
    E,
):
    """
    Plain-PyTorch LoRA backward over expert-grouped rows.

    Rows of ``grad_out`` and ``X`` are expected in grouped order, with
    ``expert_offsets[e]`` marking the end of expert ``e``'s contiguous slice.

    Args:
        grad_out: [M*k, N] upstream gradient (grouped order)
        X: [M*k, K] inputs (grouped order)
        W: [E, K, N] expert weights
        lora_A: [r*E, K] LoRA A weights
        lora_B: [N, r*E] LoRA B weights
        scaling: LoRA scaling factor
        sorted_expert_idxs: not read here; kept for call-site symmetry
        sorted_scattered_idxs: not read here; kept for call-site symmetry
        expert_offsets: cumulative per-expert row counts
        k: not read here; kept for call-site symmetry
        E: number of experts

    Returns:
        dX: [M*k, K] input gradient (in grouped order)
        dA: [r*E, K] LoRA A gradient
        dB: [N, r*E] LoRA B gradient
    """
    rank = lora_A.size(0) // E

    dX = torch.zeros_like(X)
    dA = torch.zeros_like(lora_A)
    dB = torch.zeros_like(lora_B)

    # Read the offsets once instead of one .item() call per expert.
    ends = expert_offsets.tolist()

    start = 0
    for e in range(E):
        end = ends[e]
        if end > start:
            rows = slice(start, end)  # this expert's tokens
            cols = slice(e * rank, (e + 1) * rank)  # this expert's LoRA slab

            dy = grad_out[rows]  # [M_e, N]
            x = X[rows]  # [M_e, K]
            a = lora_A[cols, :]  # [r, K]
            b = lora_B[:, cols]  # [N, r]

            dy_b = dy @ b  # [M_e, r], shared by dX and dA

            # dX = dY @ W^T + scaling * (dY @ B) @ A
            dX[rows] = dy @ W[e].T + scaling * (dy_b @ a)
            # dA = scaling * (dY @ B)^T @ X
            dA[cols, :] = scaling * (dy_b.T @ x)
            # dB = scaling * dY^T @ (X @ A^T)
            dB[:, cols] = scaling * (dy.T @ (x @ a.T))
        start = end

    return dX, dA, dB
class TestForwardPass:
    """Forward-pass checks for the fused scatter2scatter_lora kernel."""

    # Looser tolerance shared by the reduced-precision cases.
    _HALF_TOL = {"atol": 5e-2, "rtol": 5e-2}

    def _run_forward_test(
        self, M, K, N, E, R, k, dtype=torch.float32, atol=1e-2, rtol=1e-2
    ):
        """Compare the kernel output with the pure PyTorch reference."""
        from importlib import import_module

        lora_ops = import_module(f"{_SMOE}.kernels.lora_ops")

        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k, dtype=dtype)

        # Expected values from the loop-based reference.
        expected = reference_parallel_linear_lora(
            data["X"],
            data["W"],
            data["k"],
            data["sorted_expert_idxs"],
            data["sorted_scattered_idxs"],
            data["lora_A"],
            data["lora_B"],
            data["scaling"],
        )

        # The kernel takes the same tensors, all by keyword.
        kernel_kwargs = {
            name: data[name]
            for name in (
                "X",
                "W",
                "sorted_expert_idxs",
                "sorted_scattered_idxs",
                "k",
                "lora_A",
                "lora_B",
                "scaling",
            )
        }
        actual = lora_ops.scatter2scatter_lora(**kernel_kwargs)

        torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)

    def test_basic(self):
        """Small dimensions, top-1 routing."""
        self._run_forward_test(M=16, K=64, N=64, E=4, R=8, k=1)

    def test_topk2(self):
        """Top-2 routing."""
        self._run_forward_test(M=32, K=64, N=128, E=4, R=8, k=2)

    def test_larger_rank(self):
        """Larger LoRA rank."""
        self._run_forward_test(M=16, K=128, N=128, E=8, R=32, k=2)

    def test_small_rank(self):
        """Very small LoRA rank."""
        self._run_forward_test(M=32, K=64, N=64, E=4, R=4, k=1)

    def test_many_experts(self):
        """Many experts, so fewer tokens land on each one."""
        self._run_forward_test(M=64, K=64, N=64, E=16, R=8, k=2)

    def test_non_power_of_2_dims(self):
        """Dimensions that are not powers of two."""
        self._run_forward_test(M=17, K=96, N=80, E=6, R=16, k=2, atol=2e-2, rtol=2e-2)

    def test_single_token(self):
        """A single input token."""
        self._run_forward_test(M=1, K=64, N=64, E=4, R=8, k=1)

    def test_bf16(self):
        """bfloat16 inputs."""
        self._run_forward_test(
            M=32, K=64, N=128, E=4, R=8, k=2, dtype=torch.bfloat16, **self._HALF_TOL
        )

    def test_fp16(self):
        """float16 inputs."""
        self._run_forward_test(
            M=32, K=64, N=128, E=4, R=8, k=2, dtype=torch.float16, **self._HALF_TOL
        )
class TestLoRAGradients:
    """Check the grouped LoRA backward kernel (dA, dB) against the reference."""

    def _run_lora_grad_test(self, M, K, N, E, R, k, atol=1e-2, rtol=1e-2):
        """Compare group_bwd_lora with reference_lora_backward on one config."""
        from importlib import import_module

        lora_ops = import_module(f"{_SMOE}.kernels.lora_ops")
        base_ops = import_module(f"{_SMOE}.kernels.ops")

        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k)
        lora_A = data["lora_A"]
        lora_B = data["lora_B"]
        scaling = data["scaling"]
        offsets = data["expert_offsets"]

        # The backward helpers consume X in grouped order.
        x_grouped = base_ops.group(data["X"], data["sorted_scattered_idxs"], fan_out=k)

        # Synthetic upstream gradient, one row per routed (token, expert) pair.
        num_rows = data["sorted_expert_idxs"].size(0)
        dy = torch.randn(num_rows, N, device="cuda", dtype=torch.float32)

        # Reference gradients (the input gradient is not checked here).
        _, want_dA, want_dB = reference_lora_backward(
            dy,
            x_grouped,
            data["W"],
            lora_A,
            lora_B,
            scaling,
            data["sorted_expert_idxs"],
            data["sorted_scattered_idxs"],
            offsets,
            k,
            E,
        )

        # Kernel gradients.
        got_dA, got_dB = lora_ops.group_bwd_lora(
            DY=dy,
            X=x_grouped,
            lora_A=lora_A,
            lora_B=lora_B,
            expert_offsets=offsets,
            E=E,
            scaling=scaling,
        )

        torch.testing.assert_close(got_dA, want_dA, atol=atol, rtol=rtol)
        torch.testing.assert_close(got_dB, want_dB, atol=atol, rtol=rtol)

    def test_basic_lora_grads(self):
        self._run_lora_grad_test(M=32, K=64, N=128, E=4, R=8, k=2)

    def test_small_rank(self):
        self._run_lora_grad_test(M=16, K=64, N=64, E=4, R=4, k=1)

    def test_larger_rank(self):
        self._run_lora_grad_test(
            M=16, K=128, N=128, E=8, R=32, k=2, atol=5e-2, rtol=5e-2
        )

    def test_many_experts(self):
        self._run_lora_grad_test(M=64, K=64, N=64, E=16, R=8, k=2)

    def test_single_token_per_expert(self):
        """Edge case: about one token per expert."""
        self._run_lora_grad_test(M=8, K=64, N=64, E=8, R=4, k=1)
K=K, N=N, E=E, R=R, k=k) + + X = data["X"].clone().requires_grad_(True) + W = data["W"].clone().requires_grad_(False) # Frozen + lora_A = data["lora_A"].clone().requires_grad_(True) + lora_B = data["lora_B"].clone().requires_grad_(True) + + output = pll.ScatterMoELoRA.apply( + X, + W, + k, + data["sorted_expert_idxs"], + data["sorted_scattered_idxs"], + data["expert_offsets"], + lora_A, + lora_B, + data["scaling"], + None, + None, + False, + False, + ) + + loss = output.sum() + loss.backward() + + # LoRA params should have gradients + assert lora_A.grad is not None, "lora_A should have gradient" + assert lora_B.grad is not None, "lora_B should have gradient" + assert lora_A.grad.abs().sum() > 0, "lora_A gradient should be non-zero" + assert lora_B.grad.abs().sum() > 0, "lora_B gradient should be non-zero" + + # Input should have gradient (needed for upstream backprop) + assert X.grad is not None, "X should have gradient" + assert X.grad.abs().sum() > 0, "X gradient should be non-zero" + + def test_input_gradient_matches_reference(self): + """Input gradient from autograd matches pure PyTorch reference.""" + from importlib import import_module + + pll = import_module(f"{_SMOE}.parallel_linear_lora") + base_ops = import_module(f"{_SMOE}.kernels.ops") + + M, K, N, E, R, k = 16, 64, 64, 4, 8, 1 + data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k) + + # Autograd path + X_kern = data["X"].clone().requires_grad_(True) + lora_A_kern = data["lora_A"].clone().requires_grad_(True) + lora_B_kern = data["lora_B"].clone().requires_grad_(True) + + out_kern = pll.ScatterMoELoRA.apply( + X_kern, + data["W"], + k, + data["sorted_expert_idxs"], + data["sorted_scattered_idxs"], + data["expert_offsets"], + lora_A_kern, + lora_B_kern, + data["scaling"], + None, + None, + False, + False, + ) + grad_out = torch.randn_like(out_kern) + out_kern.backward(grad_out) + + # Reference path + grouped_X = base_ops.group(data["X"], data["sorted_scattered_idxs"], fan_out=k) + grouped_grad = 
base_ops.group( + grad_out, data["sorted_scattered_idxs"], fan_out=1 + ) + + ref_dX, ref_dA, ref_dB = reference_lora_backward( + grouped_grad, + grouped_X, + data["W"], + data["lora_A"], + data["lora_B"], + data["scaling"], + data["sorted_expert_idxs"], + data["sorted_scattered_idxs"], + data["expert_offsets"], + k, + E, + ) + + # Compare input gradient (for k=1, no reduction needed) + # ref_dX is in grouped (expert-sorted) order; X_kern.grad is in original order. + # Ungroup ref_dX by scattering back to original positions. + ref_dX_ungrouped = torch.zeros_like(ref_dX) + ref_dX_ungrouped[data["sorted_scattered_idxs"]] = ref_dX + torch.testing.assert_close(X_kern.grad, ref_dX_ungrouped, atol=5e-2, rtol=5e-2) + + def test_lora_gradient_matches_reference(self): + """LoRA A/B gradients from autograd match reference.""" + from importlib import import_module + + pll = import_module(f"{_SMOE}.parallel_linear_lora") + base_ops = import_module(f"{_SMOE}.kernels.ops") + + M, K, N, E, R, k = 16, 64, 64, 4, 8, 1 + data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k) + + # Autograd path + X_kern = data["X"].clone().requires_grad_(True) + lora_A_kern = data["lora_A"].clone().requires_grad_(True) + lora_B_kern = data["lora_B"].clone().requires_grad_(True) + + out_kern = pll.ScatterMoELoRA.apply( + X_kern, + data["W"], + k, + data["sorted_expert_idxs"], + data["sorted_scattered_idxs"], + data["expert_offsets"], + lora_A_kern, + lora_B_kern, + data["scaling"], + None, + None, + False, + False, + ) + grad_out = torch.randn_like(out_kern) + out_kern.backward(grad_out) + + # Reference path + grouped_X = base_ops.group(data["X"], data["sorted_scattered_idxs"], fan_out=k) + grouped_grad = base_ops.group( + grad_out, data["sorted_scattered_idxs"], fan_out=1 + ) + + _, ref_dA, ref_dB = reference_lora_backward( + grouped_grad, + grouped_X, + data["W"], + data["lora_A"], + data["lora_B"], + data["scaling"], + data["sorted_expert_idxs"], + data["sorted_scattered_idxs"], + 
data["expert_offsets"], + k, + E, + ) + + torch.testing.assert_close(lora_A_kern.grad, ref_dA, atol=5e-2, rtol=5e-2) + torch.testing.assert_close(lora_B_kern.grad, ref_dB, atol=5e-2, rtol=5e-2) + + +# ============================================================================= +# Test: Equivalence with Base ScatterMoE (scaling=0 should match base) +# ============================================================================= + + +class TestBaseEquivalence: + """When scaling=0, fused kernel should match base scatter2scatter.""" + + def test_zero_scaling_matches_base(self): + """With scaling=0, LoRA contribution vanishes; should match base.""" + from importlib import import_module + + lora_ops = import_module(f"{_SMOE}.kernels.lora_ops") + base_ops = import_module(f"{_SMOE}.kernels.ops") + + data = make_test_data(M=32, K=64, N=128, E=4, R=8, k=2) + + base_output = base_ops.scatter2scatter( + X=data["X"], + W=data["W"], + sorted_expert_idxs=data["sorted_expert_idxs"], + sorted_scattered_idxs=data["sorted_scattered_idxs"], + k=data["k"], + ) + + lora_output = lora_ops.scatter2scatter_lora( + X=data["X"], + W=data["W"], + sorted_expert_idxs=data["sorted_expert_idxs"], + sorted_scattered_idxs=data["sorted_scattered_idxs"], + k=data["k"], + lora_A=data["lora_A"], + lora_B=data["lora_B"], + scaling=0.0, + ) + + torch.testing.assert_close(lora_output, base_output, atol=1e-3, rtol=1e-3) + + def test_zero_lora_weights_matches_base(self): + """With A=0, B=0, should match base scatter2scatter.""" + from importlib import import_module + + lora_ops = import_module(f"{_SMOE}.kernels.lora_ops") + base_ops = import_module(f"{_SMOE}.kernels.ops") + + data = make_test_data(M=32, K=64, N=128, E=4, R=8, k=2) + + zero_A = torch.zeros_like(data["lora_A"]) + zero_B = torch.zeros_like(data["lora_B"]) + + base_output = base_ops.scatter2scatter( + X=data["X"], + W=data["W"], + sorted_expert_idxs=data["sorted_expert_idxs"], + sorted_scattered_idxs=data["sorted_scattered_idxs"], + 
class TestLoRAAdditivity:
    """The fused output must equal the base output plus the LoRA term."""

    def test_lora_additivity(self):
        """
        Check: fused(X, W, A, B, s) == base(X, W) + s * per_expert_lora(X, A, B)
        """
        from importlib import import_module

        lora_ops = import_module(f"{_SMOE}.kernels.lora_ops")
        base_ops = import_module(f"{_SMOE}.kernels.ops")

        data = make_test_data(M=32, K=64, N=128, E=4, R=8, k=2)

        # Arguments shared by the base and fused kernels.
        routing = {
            "X": data["X"],
            "W": data["W"],
            "sorted_expert_idxs": data["sorted_expert_idxs"],
            "sorted_scattered_idxs": data["sorted_scattered_idxs"],
            "k": data["k"],
        }

        # Expert matmul alone.
        base_output = base_ops.scatter2scatter(**routing)

        # Expert matmul with the LoRA term fused in.
        fused_output = lora_ops.scatter2scatter_lora(
            **routing,
            lora_A=data["lora_A"],
            lora_B=data["lora_B"],
            scaling=data["scaling"],
        )

        # Zeroing W in the reference leaves only the LoRA contribution.
        lora_only = reference_parallel_linear_lora(
            data["X"],
            torch.zeros_like(data["W"]),
            data["k"],
            data["sorted_expert_idxs"],
            data["sorted_scattered_idxs"],
            data["lora_A"],
            data["lora_B"],
            data["scaling"],
        )

        torch.testing.assert_close(
            fused_output, base_output + lora_only, atol=2e-2, rtol=2e-2
        )
class TestParallelExpertsModule:
    """Exercise the ParallelExperts module with LoRA attached."""

    def test_set_and_clear_lora(self):
        """set_lora stores the given objects; clear_lora resets them to None."""
        from importlib import import_module

        lora_module = import_module(f"{_SMOE}.lora_ops")

        experts = lora_module.ParallelExperts(4, 64, 128).cuda()

        # r=8, E=4 -> 32 LoRA rows / columns
        lora_a = torch.randn(32, 64, device="cuda")
        lora_b = torch.randn(128, 32, device="cuda")
        experts.set_lora(lora_a, lora_b, 0.5)

        assert experts._lora_A is lora_a
        assert experts._lora_B is lora_b
        assert experts._lora_scaling == 0.5

        experts.clear_lora()
        assert experts._lora_A is None
        assert experts._lora_B is None

    def test_forward_with_lora(self):
        """Module forward with LoRA set agrees with the PyTorch reference."""
        from importlib import import_module

        lora_module = import_module(f"{_SMOE}.lora_ops")

        E, K, N, R = 4, 64, 128, 8
        M, k = 16, 2
        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k)

        experts = lora_module.ParallelExperts(E, K, N).cuda()
        # Load the test weights; the module's weight is laid out as [E, N, K].
        with torch.no_grad():
            experts.weight.copy_(data["W"].permute(0, 2, 1))

        experts.set_lora(data["lora_A"], data["lora_B"], data["scaling"])

        routing = (
            data["sorted_expert_idxs"],
            data["sorted_scattered_idxs"],
        )

        actual = experts(data["X"], k, *routing, data["expert_offsets"])

        expected = reference_parallel_linear_lora(
            data["X"],
            data["W"],
            k,
            *routing,
            data["lora_A"],
            data["lora_B"],
            data["scaling"],
        )

        torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2)
import_module + + lora_ops = import_module(f"{_SMOE}.kernels.lora_ops") + + M, K, N, E, R, k = 16, 64, 64, 4, 8, 1 + torch.manual_seed(42) + + X = torch.randn(M, K, device="cuda") + W = torch.randn(E, K, N, device="cuda") * 0.02 + lora_A = torch.randn(R * E, K, device="cuda") * 0.01 + lora_B = torch.randn(N, R * E, device="cuda") * 0.01 + + # All tokens go to expert 0 + selected_experts = torch.zeros(M, k, device="cuda", dtype=torch.long) + sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = ( + flatten_sort_count_ref(selected_experts, E) + ) + + ref = reference_parallel_linear_lora( + X, + W, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + lora_A, + lora_B, + 0.5, + ) + + kernel = lora_ops.scatter2scatter_lora( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + k=k, + lora_A=lora_A, + lora_B=lora_B, + scaling=0.5, + ) + + torch.testing.assert_close(kernel, ref, atol=1e-2, rtol=1e-2) + + def test_empty_experts(self): + """Some experts have no tokens assigned.""" + from importlib import import_module + + lora_ops = import_module(f"{_SMOE}.kernels.lora_ops") + + M, K, N, E, R, k = 8, 64, 64, 8, 4, 1 + torch.manual_seed(42) + + X = torch.randn(M, K, device="cuda") + W = torch.randn(E, K, N, device="cuda") * 0.02 + lora_A = torch.randn(R * E, K, device="cuda") * 0.01 + lora_B = torch.randn(N, R * E, device="cuda") * 0.01 + + # Only use experts 0 and 1 + selected_experts = torch.randint(0, 2, (M, k), device="cuda") + sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = ( + flatten_sort_count_ref(selected_experts, E) + ) + + ref = reference_parallel_linear_lora( + X, + W, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + lora_A, + lora_B, + 0.5, + ) + + kernel = lora_ops.scatter2scatter_lora( + X=X, + W=W, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + k=k, + lora_A=lora_A, + lora_B=lora_B, + scaling=0.5, + ) + + torch.testing.assert_close(kernel, ref, 
class TestFusedDX:
    """Fused backward dX kernel: dX = dY @ W^T + scaling * (dY @ B) @ A."""

    def _run_fused_dX_test(
        self, M, K, N, E, R, k, dtype=torch.float32, atol=5e-2, rtol=5e-2
    ):
        """Compare the fused dX kernel (ungrouped output) with the two-step path."""
        from importlib import import_module

        lora_ops = import_module(f"{_SMOE}.kernels.lora_ops")
        base_ops = import_module(f"{_SMOE}.kernels.ops")
        pll = import_module(f"{_SMOE}.parallel_linear_lora")

        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k, dtype=dtype)
        expert_ids = data["sorted_expert_idxs"]
        scatter_ids = data["sorted_scattered_idxs"]

        # Random upstream gradient, then reordered with group().
        grad_out = torch.randn(expert_ids.size(0), N, device="cuda", dtype=dtype)
        dy_grouped = base_ops.group(grad_out, scatter_ids, fan_out=1)

        # Two-step reference: scatter2scatter(DY, W^T) ...
        base_part = base_ops.scatter2scatter(
            X=dy_grouped,
            x_grouped=True,
            W=data["W"].permute(0, 2, 1),
            sorted_expert_idxs=expert_ids,
            sorted_scattered_idxs=scatter_ids,
            k=1,
            y_grouped=False,
        )
        # ... plus the separately computed LoRA input gradient.
        lora_part = pll._compute_lora_input_grad(
            dy_grouped,
            data["lora_A"],
            data["lora_B"],
            data["expert_offsets"],
            E,
            data["scaling"],
        )
        # The LoRA part comes back grouped; scatter it to ungrouped order.
        lora_scattered = torch.zeros_like(base_part)
        lora_scattered[scatter_ids] = lora_part
        expected = base_part + lora_scattered

        # Single fused kernel.
        actual = lora_ops.scatter2scatter_lora_dX(
            DY=dy_grouped,
            W=data["W"],
            sorted_expert_idxs=expert_ids,
            sorted_scattered_idxs=scatter_ids,
            k=1,
            lora_A=data["lora_A"],
            lora_B=data["lora_B"],
            scaling=data["scaling"],
            dy_grouped=True,
            dx_grouped=False,
        )

        torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)

    def test_basic(self):
        self._run_fused_dX_test(M=32, K=64, N=128, E=4, R=8, k=2)

    def test_large(self):
        self._run_fused_dX_test(M=256, K=256, N=512, E=8, R=16, k=2)

    def test_single_expert(self):
        self._run_fused_dX_test(M=64, K=128, N=256, E=1, R=8, k=1)

    def test_k1(self):
        self._run_fused_dX_test(M=64, K=64, N=128, E=4, R=8, k=1)

    def test_bf16(self):
        self._run_fused_dX_test(
            M=64,
            K=128,
            N=256,
            E=4,
            R=16,
            k=2,
            dtype=torch.bfloat16,
            atol=1e-1,
            rtol=1e-1,
        )

    def test_grouped_output(self):
        """Fused dX with dx_grouped=True matches the two-step grouped result."""
        from importlib import import_module

        lora_ops = import_module(f"{_SMOE}.kernels.lora_ops")
        base_ops = import_module(f"{_SMOE}.kernels.ops")
        pll = import_module(f"{_SMOE}.parallel_linear_lora")

        M, K, N, E, R, k = 32, 64, 128, 4, 8, 2
        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k)
        expert_ids = data["sorted_expert_idxs"]
        scatter_ids = data["sorted_scattered_idxs"]

        grad_out = torch.randn(expert_ids.size(0), N, device="cuda")
        dy_grouped = base_ops.group(grad_out, scatter_ids, fan_out=1)

        # Both reference terms stay grouped, so they can be added directly.
        base_part = base_ops.scatter2scatter(
            X=dy_grouped,
            x_grouped=True,
            W=data["W"].permute(0, 2, 1),
            sorted_expert_idxs=expert_ids,
            sorted_scattered_idxs=scatter_ids,
            k=1,
            y_grouped=True,
        )
        lora_part = pll._compute_lora_input_grad(
            dy_grouped,
            data["lora_A"],
            data["lora_B"],
            data["expert_offsets"],
            E,
            data["scaling"],
        )
        expected = base_part + lora_part

        actual = lora_ops.scatter2scatter_lora_dX(
            DY=dy_grouped,
            W=data["W"],
            sorted_expert_idxs=expert_ids,
            sorted_scattered_idxs=scatter_ids,
            k=1,
            lora_A=data["lora_A"],
            lora_B=data["lora_B"],
            scaling=data["scaling"],
            dy_grouped=True,
            dx_grouped=True,
        )

        torch.testing.assert_close(actual, expected, atol=5e-2, rtol=5e-2)

    def test_autograd_with_fused_dX(self):
        """Autograd round-trip gives matching results with and without fused dX."""
        from importlib import import_module

        pll = import_module(f"{_SMOE}.parallel_linear_lora")

        M, K, N, E, R, k = 32, 64, 128, 4, 8, 2
        data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k)

        def roundtrip(use_fused_dX):
            # Fresh leaf tensors per run so gradients do not accumulate.
            X = data["X"].clone().requires_grad_(True)
            A = data["lora_A"].clone().requires_grad_(True)
            B = data["lora_B"].clone().requires_grad_(True)
            out = pll.ScatterMoELoRA.apply(
                X,
                data["W"],
                k,
                data["sorted_expert_idxs"],
                data["sorted_scattered_idxs"],
                data["expert_offsets"],
                A,
                B,
                data["scaling"],
                None,
                None,
                False,
                False,
                use_fused_dX,
            )
            out.sum().backward()
            return out, X.grad, A.grad, B.grad

        out_plain, dX_plain, dA_plain, dB_plain = roundtrip(False)
        out_fused, dX_fused, dA_fused, dB_fused = roundtrip(True)

        # The forward pass should be identical.
        torch.testing.assert_close(out_plain, out_fused, atol=1e-5, rtol=1e-5)

        # Gradients should agree within kernel tolerance.
        torch.testing.assert_close(dX_plain, dX_fused, atol=5e-2, rtol=5e-2)
        torch.testing.assert_close(dA_plain, dA_fused, atol=5e-2, rtol=5e-2)
        torch.testing.assert_close(dB_plain, dB_fused, atol=5e-2, rtol=5e-2)
import_module(f"{_SMOE}.kernels.lora_ops") + base_ops = import_module(f"{_SMOE}.kernels.ops") + + data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k, dtype=dtype) + + # Create grad_out in ungrouped order (M*k, N) + M_total = data["sorted_expert_idxs"].size(0) + grad_out = torch.randn(M_total, N, device="cuda", dtype=dtype) + + # Reference: group() + group_bwd_lora() + grouped_grad = base_ops.group( + grad_out, data["sorted_scattered_idxs"], fan_out=1 + ) + grouped_x = base_ops.group(data["X"], data["sorted_scattered_idxs"], fan_out=k) + + ref_dA, ref_dB = lora_ops.group_bwd_lora( + DY=grouped_grad, + X=grouped_x, + lora_A=data["lora_A"], + lora_B=data["lora_B"], + expert_offsets=data["expert_offsets"], + E=E, + scaling=data["scaling"], + ) + + # Fused kernel: no group() calls + fused_dA, fused_dB = lora_ops.group_bwd_lora_fused( + DY=grad_out, + X=data["X"], + lora_A=data["lora_A"], + lora_B=data["lora_B"], + expert_offsets=data["expert_offsets"], + sorted_scattered_idxs=data["sorted_scattered_idxs"], + E=E, + k=k, + scaling=data["scaling"], + ) + + torch.testing.assert_close(fused_dA, ref_dA, atol=atol, rtol=rtol) + torch.testing.assert_close(fused_dB, ref_dB, atol=atol, rtol=rtol) + + def test_basic(self): + self._run_fused_gather_test(M=32, K=64, N=128, E=4, R=8, k=2) + + def test_large(self): + self._run_fused_gather_test(M=256, K=256, N=512, E=8, R=16, k=2) + + def test_single_expert(self): + self._run_fused_gather_test(M=64, K=128, N=256, E=1, R=8, k=1) + + def test_k1(self): + self._run_fused_gather_test(M=64, K=64, N=128, E=4, R=8, k=1) + + def test_many_experts(self): + self._run_fused_gather_test(M=128, K=64, N=128, E=16, R=8, k=4) + + def test_bf16(self): + self._run_fused_gather_test( + M=64, + K=128, + N=256, + E=4, + R=16, + k=2, + dtype=torch.bfloat16, + atol=1e-1, + rtol=1e-1, + ) + + def test_autograd_with_fused_gather(self): + """Full autograd round-trip with use_fused_gather=True.""" + from importlib import import_module + + pll = 
import_module(f"{_SMOE}.parallel_linear_lora") + + M, K, N, E, R, k = 32, 64, 128, 4, 8, 2 + data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k) + + # Run without fused gather + X1 = data["X"].clone().requires_grad_(True) + A1 = data["lora_A"].clone().requires_grad_(True) + B1 = data["lora_B"].clone().requires_grad_(True) + out1 = pll.ScatterMoELoRA.apply( + X1, + data["W"], + k, + data["sorted_expert_idxs"], + data["sorted_scattered_idxs"], + data["expert_offsets"], + A1, + B1, + data["scaling"], + None, + None, + False, + False, + False, + False, # use_fused_dX=False, use_fused_gather=False + ) + out1.sum().backward() + + # Run with fused gather + X2 = data["X"].clone().requires_grad_(True) + A2 = data["lora_A"].clone().requires_grad_(True) + B2 = data["lora_B"].clone().requires_grad_(True) + out2 = pll.ScatterMoELoRA.apply( + X2, + data["W"], + k, + data["sorted_expert_idxs"], + data["sorted_scattered_idxs"], + data["expert_offsets"], + A2, + B2, + data["scaling"], + None, + None, + False, + False, + False, + True, # use_fused_dX=False, use_fused_gather=True + ) + out2.sum().backward() + + # Forward identical + torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5) + + # dA/dB should match + torch.testing.assert_close(A1.grad, A2.grad, atol=5e-2, rtol=5e-2) + torch.testing.assert_close(B1.grad, B2.grad, atol=5e-2, rtol=5e-2) + # dX should also match (same path for dX) + torch.testing.assert_close(X1.grad, X2.grad, atol=5e-2, rtol=5e-2) + + +# ============================================================================= +# Test: Optimization 3 - Token Rounding +# ============================================================================= + + +class TestTokenRounding: + """Test token rounding utility and its integration with backward kernels.""" + + def test_round_expert_counts_basic(self): + """Verify round_expert_counts produces correct shapes and values.""" + from importlib import import_module + + lora_ops = import_module(f"{_SMOE}.kernels.lora_ops") + 
+ M, K, N, E, R, k = 32, 64, 128, 4, 8, 2 + data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k) + + padded_ei, padded_si, padded_offsets, real_offsets = ( + lora_ops.round_expert_counts( + data["sorted_expert_idxs"], + data["sorted_scattered_idxs"], + data["expert_offsets"], + E=E, + block_m=lora_ops.BLOCK_M, + ) + ) + + # Real offsets should match original + torch.testing.assert_close(real_offsets, data["expert_offsets"]) + + # Padded offsets should be >= real offsets + assert (padded_offsets >= real_offsets).all(), ( + "Padded offsets should be >= real offsets" + ) + + # Each expert's padded count should be multiple of BLOCK_M (if non-zero) + prev = 0 + for e in range(E): + count = padded_offsets[e].item() - prev + real_count = real_offsets[e].item() - ( + real_offsets[e - 1].item() if e > 0 else 0 + ) + if real_count > 0: + assert count % lora_ops.BLOCK_M == 0, ( + f"Expert {e}: padded count {count} not multiple of {lora_ops.BLOCK_M}" + ) + assert count >= real_count, ( + f"Expert {e}: padded count {count} < real count {real_count}" + ) + prev = padded_offsets[e].item() + + def test_round_with_fused_gather(self): + """Token rounding + fused gather gives same result as plain fused gather.""" + from importlib import import_module + + lora_ops = import_module(f"{_SMOE}.kernels.lora_ops") + base_ops = import_module(f"{_SMOE}.kernels.ops") + + M, K, N, E, R, k = 64, 64, 128, 4, 8, 2 + data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k) + + M_total = data["sorted_expert_idxs"].size(0) + grad_out = torch.randn(M_total, N, device="cuda") + + # Reference: group() + group_bwd_lora() (the gold standard) + grouped_grad = base_ops.group( + grad_out, data["sorted_scattered_idxs"], fan_out=1 + ) + grouped_x = base_ops.group(data["X"], data["sorted_scattered_idxs"], fan_out=k) + ref_dA, ref_dB = lora_ops.group_bwd_lora( + DY=grouped_grad, + X=grouped_x, + lora_A=data["lora_A"], + lora_B=data["lora_B"], + expert_offsets=data["expert_offsets"], + E=E, + scaling=data["scaling"], + 
) + + # Apply token rounding + padded_ei, padded_si, padded_offsets, real_offsets = ( + lora_ops.round_expert_counts( + data["sorted_expert_idxs"], + data["sorted_scattered_idxs"], + data["expert_offsets"], + E=E, + ) + ) + + # Fused gather with token rounding + rounded_dA, rounded_dB = lora_ops.group_bwd_lora_fused( + DY=grad_out, + X=data["X"], + lora_A=data["lora_A"], + lora_B=data["lora_B"], + expert_offsets=padded_offsets, + sorted_scattered_idxs=padded_si, + E=E, + k=k, + scaling=data["scaling"], + real_expert_offsets=real_offsets, + ) + + torch.testing.assert_close(rounded_dA, ref_dA, atol=5e-2, rtol=5e-2) + torch.testing.assert_close(rounded_dB, ref_dB, atol=5e-2, rtol=5e-2) + + def test_empty_experts_with_rounding(self): + """Token rounding handles experts with 0 tokens correctly.""" + from importlib import import_module + + lora_ops = import_module(f"{_SMOE}.kernels.lora_ops") + + E, k = 8, 1 + M = 8 + torch.manual_seed(42) + + # Only use experts 0 and 1 (rest have 0 tokens) + selected_experts = torch.randint(0, 2, (M, k), device="cuda") + sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = ( + flatten_sort_count_ref(selected_experts, E) + ) + + padded_ei, padded_si, padded_offsets, real_offsets = ( + lora_ops.round_expert_counts( + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + E=E, + ) + ) + + # Verify empty experts have same count (0) + for e in range(E): + real_count = real_offsets[e].item() - ( + real_offsets[e - 1].item() if e > 0 else 0 + ) + padded_count = padded_offsets[e].item() - ( + padded_offsets[e - 1].item() if e > 0 else 0 + ) + if real_count == 0: + assert padded_count == 0, ( + f"Expert {e}: empty expert should have padded_count=0, got {padded_count}" + ) + + +# ============================================================================= +# Test: Combined Optimizations +# ============================================================================= + + +class TestCombinedOptimizations: + """Test all 
optimizations together.""" + + def test_fused_dX_and_fused_gather(self): + """Both fused dX and fused gather together.""" + from importlib import import_module + + pll = import_module(f"{_SMOE}.parallel_linear_lora") + + M, K, N, E, R, k = 64, 128, 256, 4, 8, 2 + data = make_test_data(M=M, K=K, N=N, E=E, R=R, k=k) + + # Baseline: no optimizations + X1 = data["X"].clone().requires_grad_(True) + A1 = data["lora_A"].clone().requires_grad_(True) + B1 = data["lora_B"].clone().requires_grad_(True) + out1 = pll.ScatterMoELoRA.apply( + X1, + data["W"], + k, + data["sorted_expert_idxs"], + data["sorted_scattered_idxs"], + data["expert_offsets"], + A1, + B1, + data["scaling"], + None, + None, + False, + False, + False, + False, # no optimizations + ) + out1.sum().backward() + + # Both optimizations + X2 = data["X"].clone().requires_grad_(True) + A2 = data["lora_A"].clone().requires_grad_(True) + B2 = data["lora_B"].clone().requires_grad_(True) + out2 = pll.ScatterMoELoRA.apply( + X2, + data["W"], + k, + data["sorted_expert_idxs"], + data["sorted_scattered_idxs"], + data["expert_offsets"], + A2, + B2, + data["scaling"], + None, + None, + False, + False, + True, + True, # use_fused_dX=True, use_fused_gather=True + ) + out2.sum().backward() + + # Forward identical + torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5) + + # All gradients match + torch.testing.assert_close(X1.grad, X2.grad, atol=5e-2, rtol=5e-2) + torch.testing.assert_close(A1.grad, A2.grad, atol=5e-2, rtol=5e-2) + torch.testing.assert_close(B1.grad, B2.grad, atol=5e-2, rtol=5e-2) diff --git a/tests/e2e/integrations/test_scattermoe_lora_olmoe.py b/tests/e2e/integrations/test_scattermoe_lora_olmoe.py new file mode 100644 index 000000000..048147632 --- /dev/null +++ b/tests/e2e/integrations/test_scattermoe_lora_olmoe.py @@ -0,0 +1,1255 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +""" +Integration tests: OLMoE + peft LoRA + 
ScatterMoE fused kernels. + +Validates that scattermoe_lora fused kernels produce correct results when used +with HuggingFace OLMoE models and peft LoRA adapters applied via +``target_parameters``. + +Key things tested +----------------- +- LoRA weight layout conversion between peft (rank-major) and scattermoe (expert-major) +- Base forward equivalence: per-expert reference vs ScatterMoE kernels (no LoRA) +- LoRA forward equivalence: peft merged-weight approach vs scattermoe fused kernels +- Backward gradient correctness through the fused LoRA path +- ``kernelize()`` integration via ``LocalLayerRepository`` +""" + +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from peft import LoraConfig, get_peft_model +from transformers import OlmoeConfig +from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock + +_SMOE = "axolotl.integrations.kernels.libs.scattermoe_lora" + +# Try to import from axolotl's scattermoe_lora.layers; may fail on CPU without triton. +try: + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + _unwrap_experts_lora, + _unwrap_gate_lora, + peft_lora_B_to_scattermoe, + peft_lora_to_scattermoe, + ) + + HAS_SCATTERMOE = True +except (ImportError, ModuleNotFoundError): + HAS_SCATTERMOE = False + + # Provide pure-torch fallbacks for CPU-only layout conversion tests. 
+ def peft_lora_B_to_scattermoe(peft_B, num_experts, rank): + N = peft_B.shape[0] + return ( + peft_B.reshape(N, rank, num_experts) + .permute(0, 2, 1) + .contiguous() + .reshape(N, num_experts * rank) + ) + + def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): + peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank) + K_inter, N_hidden = peft_B.shape[0], peft_A.shape[1] + smoe_A = torch.zeros( + rank * num_experts, + K_inter, + device=peft_A.device, + dtype=peft_A.dtype, + ) + smoe_B = torch.zeros( + N_hidden, + rank * num_experts, + device=peft_A.device, + dtype=peft_A.dtype, + ) + for e in range(num_experts): + s = e * rank + smoe_A[s : s + rank, :] = peft_B_em[:, s : s + rank].T + smoe_B[:, s : s + rank] = peft_A[s : s + rank, :].T + return smoe_A, smoe_B + + def _unwrap_experts_lora(experts_module): + return experts_module, None, None + + def _unwrap_gate_lora(gate_module): + if hasattr(gate_module, "base_layer") and hasattr(gate_module, "lora_A"): + base_gate = gate_module.base_layer + active = getattr(gate_module, "active_adapters", ["default"]) + name = active[0] if active else "default" + lora_A_dict = getattr(gate_module, "lora_A", {}) + lora_B_dict = getattr(gate_module, "lora_B", {}) + scaling_dict = getattr(gate_module, "scaling", {}) + if name in lora_A_dict: + lora_A = lora_A_dict[name].weight + lora_B = lora_B_dict[name].weight + s = scaling_dict[name] + delta = s * (lora_B @ lora_A) + return base_gate, base_gate.weight, delta + return base_gate, base_gate.weight, None + return gate_module, gate_module.weight, None + + +# ============================================================================= +# Configuration +# ============================================================================= + +FULL_OLMOE_CONFIG = dict( + hidden_size=2048, + intermediate_size=1024, + num_experts=64, + num_experts_per_tok=8, + hidden_act="silu", + norm_topk_prob=False, +) + +SMALL_OLMOE_CONFIG = dict( + hidden_size=128, + 
intermediate_size=48, # non-square: 2*inter=96 != hidden=128 + num_experts=8, + num_experts_per_tok=2, + hidden_act="silu", + norm_topk_prob=False, +) + +requires_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" +) + + +def make_olmoe_config(use_full=False): + cfg = dict(FULL_OLMOE_CONFIG if use_full else SMALL_OLMOE_CONFIG) + cfg["experts_implementation"] = "grouped_mm" + return OlmoeConfig(**cfg) + + +# ============================================================================= +# Layout conversion utilities (test-local helpers) +# ============================================================================= + + +def scattermoe_lora_B_to_peft(smoe_B, num_experts, rank): + """Inverse of ``peft_lora_B_to_scattermoe``.""" + N = smoe_B.shape[0] + return ( + smoe_B.reshape(N, num_experts, rank) + .permute(0, 2, 1) + .contiguous() + .reshape(N, num_experts * rank) + ) + + +def peft_gate_up_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): + """Convert peft LoRA for gate_up_proj to scattermoe layout. + + Both gate_up_proj and down_proj need the A<->B swap because + scattermoe transposes the parameter (W = param.T). + """ + return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank) + + +# ============================================================================= +# Helpers +# ============================================================================= + + +def _init_expert_weights(moe_block): + """Initialize OlmoeExperts parameters which use torch.empty (uninitialized). + + Without this, gate_up_proj and down_proj contain garbage/NaN values. 
+ """ + with torch.no_grad(): + nn.init.kaiming_uniform_(moe_block.experts.gate_up_proj) + nn.init.kaiming_uniform_(moe_block.experts.down_proj) + return moe_block + + +class MinimalOLMoEModel(nn.Module): + """Thin wrapper so peft's get_peft_model can attach adapters.""" + + def __init__(self, config): + super().__init__() + self.moe = OlmoeSparseMoeBlock(config) + _init_expert_weights(self.moe) + + def forward(self, x): + return self.moe(x) + + +def _get_routing(moe_block, hidden_states): + """Run the router and return (routing_weights, selected_experts).""" + with torch.no_grad(): + _, routing_weights, selected_experts = moe_block.gate( + hidden_states.view(-1, hidden_states.size(-1)) + ) + return routing_weights, selected_experts + + +def _reference_moe_forward( + x_flat, + gate_up_proj, + down_proj, + act_fn, + top_k_index, + top_k_weights, + num_experts, +): + """Pure-PyTorch per-expert reference MoE forward (no LoRA). + + Uses F.linear per expert for an apples-to-apples comparison with + the ScatterMoE kernel path. 
+ """ + final = torch.zeros_like(x_flat) + expert_mask = F.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0) + for e in range(num_experts): + top_k_pos, token_idx = torch.where(expert_mask[e]) + if token_idx.numel() == 0: + continue + cur = x_flat[token_idx] + gate_up = F.linear(cur, gate_up_proj[e]) + g, u = gate_up.chunk(2, dim=-1) + h = act_fn(g) * u + out = F.linear(h, down_proj[e]) + out = out * top_k_weights[token_idx, top_k_pos, None] + final.index_add_(0, token_idx, out.to(final.dtype)) + return final + + +def _reference_moe_forward_with_lora( + x_flat, + gate_up_proj, + down_proj, + act_fn, + top_k_index, + top_k_weights, + num_experts, + gup_delta, + down_delta, +): + """Pure-PyTorch reference MoE forward with pre-computed weight deltas.""" + merged_gup = gate_up_proj + gup_delta + merged_down = down_proj + down_delta + return _reference_moe_forward( + x_flat, + merged_gup, + merged_down, + act_fn, + top_k_index, + top_k_weights, + num_experts, + ) + + +def _compute_delta_from_scattermoe_lora(lora_A, lora_B, scaling, E, r, param_shape): + """Compute additive weight delta from scattermoe-layout LoRA weights. + + delta[e] = scaling * B_e @ A_e where A_e [r,K], B_e [N,r] -> [N,K]. 
+ """ + delta = torch.zeros(param_shape, device=lora_A.device, dtype=lora_A.dtype) + for e in range(E): + A_e = lora_A[e * r : (e + 1) * r, :] + B_e = lora_B[:, e * r : (e + 1) * r] + delta[e] = scaling * (B_e @ A_e) + return delta + + +# ============================================================================= +# Tests: Layout conversion +# ============================================================================= + + +class TestLoRABLayoutConversion: + """Test the peft <-> scattermoe lora_B layout conversion.""" + + def test_roundtrip(self): + E, r, N = 8, 4, 64 + original = torch.randn(N, E * r) + converted = peft_lora_B_to_scattermoe(original, E, r) + back = scattermoe_lora_B_to_peft(converted, E, r) + torch.testing.assert_close(back, original) + + def test_per_expert_slices(self): + """After conversion, scattermoe slicing gives the same per-expert + matrices as peft's reshape slicing.""" + E, r, N = 4, 2, 16 + peft_B = torch.randn(N, E * r) + smoe_B = peft_lora_B_to_scattermoe(peft_B, E, r) + + peft_reshaped = peft_B.reshape(N, r, E) + for e in range(E): + torch.testing.assert_close( + smoe_B[:, e * r : (e + 1) * r], + peft_reshaped[:, :, e], + ) + + def test_lora_A_already_compatible(self): + """lora_A layout is identical between peft and scattermoe.""" + E, r, K = 4, 2, 16 + lora_A = torch.randn(E * r, K) + peft_reshaped = lora_A.reshape(E, r, K) + for e in range(E): + torch.testing.assert_close( + lora_A[e * r : (e + 1) * r, :], + peft_reshaped[e], + ) + + def test_delta_weight_equivalence(self): + """peft's einsum delta matches per-expert B @ A with converted layouts.""" + E, r, K, N = 8, 4, 32, 64 + peft_A = torch.randn(E * r, K) + peft_B = torch.randn(N, E * r) + scaling = 2.0 + + A_r = peft_A.reshape(E, r, K) + B_r = peft_B.reshape(N, r, E) + delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling + + smoe_B = peft_lora_B_to_scattermoe(peft_B, E, r) + for e in range(E): + A_e = peft_A[e * r : (e + 1) * r, :] + B_e = smoe_B[:, e * r 
: (e + 1) * r] + delta_e = scaling * (B_e @ A_e) + torch.testing.assert_close(delta_e, delta_peft[e].T, atol=1e-5, rtol=1e-5) + + def test_down_proj_conversion(self): + """Verify peft_lora_to_scattermoe produces correct delta.""" + E, r = 4, 2 + hidden, inter = 32, 16 + scaling = 2.0 + + peft_A = torch.randn(E * r, hidden) + peft_B = torch.randn(inter, E * r) + + A_r = peft_A.reshape(E, r, hidden) + B_r = peft_B.reshape(inter, r, E) + delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling + + smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r) + for e in range(E): + A_e = smoe_A[e * r : (e + 1) * r, :] + B_e = smoe_B[:, e * r : (e + 1) * r] + delta_smoe_e = scaling * (B_e @ A_e) + torch.testing.assert_close( + delta_smoe_e, delta_peft[e], atol=1e-5, rtol=1e-5 + ) + + def test_gate_up_proj_conversion(self): + """Verify gate_up_proj LoRA conversion with non-square dims (Qwen3-like). + + gate_up_proj param: [E, 2*inter, hidden]. + peft: in_features=2*inter, out_features=hidden. + peft lora_A: [r*E, 2*inter], lora_B: [hidden, r*E]. + + scattermoe W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter. + scattermoe needs: lora_A [r*E, K=hidden], lora_B [N=2*inter, r*E]. + + Uses non-square dims (hidden=32 != 2*inter=24) to catch A<->B swap bugs. 
+ """ + E, r = 4, 2 + hidden, inter = 32, 12 # 2*inter=24 != hidden=32 + scaling = 2.0 + + # peft assigns: in_features=2*inter, out_features=hidden + peft_A = torch.randn(E * r, 2 * inter) # [r*E, in_features=2*inter] + peft_B = torch.randn(hidden, E * r) # [out_features=hidden, r*E] + + # peft delta via einsum: "o r e, e r i -> e i o" + A_r = peft_A.reshape(E, r, 2 * inter) + B_r = peft_B.reshape(hidden, r, E) + delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling + # delta_peft[e] has shape [in_features, out_features] = [2*inter, hidden] + # = param[e] shape [2*inter, hidden] + + smoe_A, smoe_B = peft_gate_up_lora_to_scattermoe(peft_A, peft_B, E, r) + # smoe_A should be [r*E, K=hidden], smoe_B should be [N=2*inter, r*E] + assert smoe_A.shape == (E * r, hidden), ( + f"Expected {(E * r, hidden)}, got {smoe_A.shape}" + ) + assert smoe_B.shape == (2 * inter, E * r), ( + f"Expected {(2 * inter, E * r)}, got {smoe_B.shape}" + ) + + for e in range(E): + A_e = smoe_A[e * r : (e + 1) * r, :] # [r, K=hidden] + B_e = smoe_B[:, e * r : (e + 1) * r] # [N=2*inter, r] + delta_smoe_e = scaling * (B_e @ A_e) # [2*inter, hidden] + # Should match peft delta which is [2*inter, hidden] = param[e] + torch.testing.assert_close( + delta_smoe_e, delta_peft[e], atol=1e-5, rtol=1e-5 + ) + + +# ============================================================================= +# Tests: peft weight extraction +# ============================================================================= + + +class TestPeftLoRAWeightExtraction: + """Test extracting peft LoRA weights for OLMoE.""" + + def test_peft_creates_correct_shapes(self): + config = make_olmoe_config(use_full=False) + E, r = config.num_experts, 4 + + model = MinimalOLMoEModel(config) + lora_config = LoraConfig( + r=r, + lora_alpha=16, + target_modules=[], + target_parameters=[ + "gate.weight", + "experts.gate_up_proj", + "experts.down_proj", + ], + bias="none", + ) + peft_model = get_peft_model(model, lora_config) + 
trainable = {n: p for n, p in peft_model.named_parameters() if p.requires_grad} + + # Gate router + assert trainable["base_model.model.moe.gate.lora_A.default.weight"].shape == ( + r, + config.hidden_size, + ) + assert trainable["base_model.model.moe.gate.lora_B.default.weight"].shape == ( + E, + r, + ) + + # gate_up_proj [E, 2*inter, hidden] + # peft: in_features=2*inter (dim 1), out_features=hidden (dim 2) + assert trainable[ + "base_model.model.moe.experts.base_layer.lora_A.default.weight" + ].shape == (E * r, 2 * config.intermediate_size) + assert trainable[ + "base_model.model.moe.experts.base_layer.lora_B.default.weight" + ].shape == (config.hidden_size, E * r) + + # down_proj [E, hidden, inter] + # peft: in_features=hidden (dim 1), out_features=inter (dim 2) + assert trainable[ + "base_model.model.moe.experts.lora_A.default.weight" + ].shape == (E * r, config.hidden_size) + assert trainable[ + "base_model.model.moe.experts.lora_B.default.weight" + ].shape == (config.intermediate_size, E * r) + + @requires_cuda + def test_peft_forward_runs(self): + """Smoke test: peft model forward pass completes (needs CUDA for grouped_mm).""" + config = make_olmoe_config(use_full=False) + model = MinimalOLMoEModel(config) + lora_config = LoraConfig( + r=4, + lora_alpha=16, + target_modules=[], + target_parameters=[ + "gate.weight", + "experts.gate_up_proj", + "experts.down_proj", + ], + bias="none", + ) + peft_model = get_peft_model(model, lora_config) + x = torch.randn(1, 4, config.hidden_size) + out = peft_model(x) + assert out.shape == x.shape + + @pytest.mark.skipif( + not HAS_SCATTERMOE, reason="scattermoe_lora not importable (no triton)" + ) + def test_unwrap_experts_lora(self): + """Test that _unwrap_experts_lora correctly detects LoRA wrappers.""" + config = make_olmoe_config(use_full=False) + model = MinimalOLMoEModel(config) + lora_config = LoraConfig( + r=4, + lora_alpha=16, + target_modules=[], + target_parameters=["experts.gate_up_proj", "experts.down_proj"], + 
bias="none", + ) + peft_model = get_peft_model(model, lora_config) + base_moe = peft_model.base_model.model.moe + + # Experts should be wrapped by ParamWrapper + experts, gup_lora, down_lora = _unwrap_experts_lora(base_moe.experts) + + # Base experts should have the raw parameters + assert hasattr(experts, "gate_up_proj") + assert hasattr(experts, "down_proj") + + # LoRA should be detected + assert gup_lora is not None, "gate_up_proj LoRA not detected" + assert down_lora is not None, "down_proj LoRA not detected" + + # Check shapes (after peft->scattermoe conversion with A<->B swap) + # gate_up_proj W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter + E, r = config.num_experts, 4 + gup_A, gup_B, gup_s = gup_lora + assert gup_A.shape == (E * r, config.hidden_size), ( + f"gate_up_proj smoe_A: expected [r*E, K=hidden]={(E * r, config.hidden_size)}, " + f"got {gup_A.shape}" + ) + assert gup_B.shape == (2 * config.intermediate_size, E * r), ( + f"gate_up_proj smoe_B: expected [N=2*inter, r*E]=" + f"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}" + ) + + # down_proj W = param.T = [E, inter, hidden], K=inter, N=hidden + down_A, down_B, down_s = down_lora + assert down_A.shape == (E * r, config.intermediate_size), ( + f"down_proj smoe_A: expected [r*E, K=inter]={(E * r, config.intermediate_size)}, " + f"got {down_A.shape}" + ) + assert down_B.shape == (config.hidden_size, E * r), ( + f"down_proj smoe_B: expected [N=hidden, r*E]={(config.hidden_size, E * r)}, " + f"got {down_B.shape}" + ) + + def test_unwrap_no_lora(self): + """Without peft, _unwrap_experts_lora returns no LoRA.""" + config = make_olmoe_config(use_full=False) + moe = OlmoeSparseMoeBlock(config) + experts, gup_lora, down_lora = _unwrap_experts_lora(moe.experts) + assert gup_lora is None + assert down_lora is None + assert hasattr(experts, "gate_up_proj") + + def test_unwrap_gate_lora(self): + """Test that _unwrap_gate_lora detects LoRA on the router gate.""" + config = 
make_olmoe_config(use_full=False) + model = MinimalOLMoEModel(config) + r = 4 + lora_config = LoraConfig( + r=r, + lora_alpha=16, + target_modules=[], + target_parameters=["gate.weight"], + bias="none", + ) + peft_model = get_peft_model(model, lora_config) + base_moe = peft_model.base_model.model.moe + + # Set non-zero LoRA weights (peft initializes lora_B to zeros) + with torch.no_grad(): + base_moe.gate.lora_B["default"].weight.normal_(0, 0.01) + + base_gate, gate_weight, gate_delta = _unwrap_gate_lora(base_moe.gate) + + # Base gate should be the original router + assert hasattr(base_gate, "top_k") + assert hasattr(base_gate, "num_experts") + assert base_gate.top_k == config.num_experts_per_tok + assert base_gate.num_experts == config.num_experts + + # Gate weight should be the base weight (delta returned separately) + assert gate_weight.shape == (config.num_experts, config.hidden_size) + torch.testing.assert_close(gate_weight, base_gate.weight) + + # Delta should be non-zero (LoRA was applied) + assert gate_delta is not None + assert gate_delta.shape == (config.num_experts, config.hidden_size) + assert gate_delta.abs().max() > 0, "Gate LoRA delta should be non-zero" + + def test_unwrap_gate_no_lora(self): + """Without peft, _unwrap_gate_lora returns the original gate.""" + config = make_olmoe_config(use_full=False) + moe = OlmoeSparseMoeBlock(config) + base_gate, gate_weight, gate_delta = _unwrap_gate_lora(moe.gate) + assert base_gate is moe.gate + torch.testing.assert_close(gate_weight, moe.gate.weight) + assert gate_delta is None + + def test_gate_lora_delta_matches_peft(self): + """Verify _unwrap_gate_lora computes the same delta as peft.""" + config = make_olmoe_config(use_full=False) + model = MinimalOLMoEModel(config) + r = 4 + lora_alpha = 16 + scaling = lora_alpha / r + lora_config = LoraConfig( + r=r, + lora_alpha=lora_alpha, + target_modules=[], + target_parameters=["gate.weight"], + bias="none", + ) + peft_model = get_peft_model(model, lora_config) + 
base_moe = peft_model.base_model.model.moe + + # Our unwrapped weight + delta + _, gate_weight, gate_delta = _unwrap_gate_lora(base_moe.gate) + + # Manually compute expected delta + lora_A = base_moe.gate.lora_A["default"].weight # [r, hidden] + lora_B = base_moe.gate.lora_B["default"].weight # [E, r] + base_weight = base_moe.gate.base_layer.weight # [E, hidden] + expected_delta = scaling * (lora_B @ lora_A) + + torch.testing.assert_close(gate_weight, base_weight) + torch.testing.assert_close(gate_delta, expected_delta) + # Combined should match the old behavior + torch.testing.assert_close( + gate_weight + gate_delta, base_weight + expected_delta + ) + + +# ============================================================================= +# Tests: Base forward equivalence (no LoRA) +# ============================================================================= + + +@requires_cuda +class TestOLMoEReferenceVsScatterMoE: + """Base forward equivalence: per-expert reference vs ScatterMoE kernels.""" + + def test_small(self): + self._run(use_full=False, M=16) + + @pytest.mark.slow + def test_full(self): + self._run(use_full=True, M=32) + + def _run(self, use_full, M): + from axolotl.integrations.kernels.libs.scattermoe_lora import ( + flatten_sort_count, + parallel_linear, + ) + + config = make_olmoe_config(use_full=use_full) + torch.manual_seed(42) + moe = _init_expert_weights(OlmoeSparseMoeBlock(config)).cuda().float() + E, k = config.num_experts, config.num_experts_per_tok + + x = torch.randn(1, M, config.hidden_size, device="cuda") + x_flat = x.view(-1, config.hidden_size) + + with torch.no_grad(): + # Shared routing for both paths + _, rw, sel = moe.gate(x_flat) + sei, ssi, eo = flatten_sort_count(sel, num_experts=E) + + # Per-expert reference + ref_out = _reference_moe_forward( + x_flat, + moe.experts.gate_up_proj, + moe.experts.down_proj, + moe.experts.act_fn, + sel, + rw, + E, + ).view(1, M, config.hidden_size) + + # ScatterMoE kernel path + gup = parallel_linear( 
+ x_flat, + moe.experts.gate_up_proj.transpose(2, 1), + k, + sei, + ssi, + eo, + grouped_in=False, + grouped_out=True, + ) + g, u = gup.chunk(2, dim=-1) + h = moe.experts.act_fn(g) * u + + smoe_out = parallel_linear( + h, + moe.experts.down_proj.transpose(2, 1), + 1, + sei, + ssi, + eo, + grouped_in=True, + grouped_out=False, + gates=rw, + ).view(1, M, config.hidden_size) + + torch.testing.assert_close(smoe_out, ref_out, atol=1e-3, rtol=1e-3) + + +# ============================================================================= +# Tests: LoRA forward equivalence (peft vs scattermoe fused) +# ============================================================================= + + +@requires_cuda +class TestOLMoEPeftLoRAForward: + """Fused LoRA forward: peft merged-weight vs scattermoe_lora kernel.""" + + def test_small(self): + self._run(use_full=False, M=16, r=4) + + @pytest.mark.slow + def test_full(self): + self._run(use_full=True, M=32, r=8) + + def _run(self, use_full, M, r): + from axolotl.integrations.kernels.libs.scattermoe_lora import ( + flatten_sort_count, + parallel_linear_lora, + ) + + config = make_olmoe_config(use_full=use_full) + E, k = config.num_experts, config.num_experts_per_tok + lora_alpha = 16 + scaling = lora_alpha / r + + # Create peft model + model = MinimalOLMoEModel(config).cuda().float() + lora_config = LoraConfig( + r=r, + lora_alpha=lora_alpha, + target_modules=[], + target_parameters=["experts.gate_up_proj", "experts.down_proj"], + bias="none", + ) + peft_model = get_peft_model(model, lora_config) + + torch.manual_seed(42) + x = torch.randn(1, M, config.hidden_size, device="cuda") + + # peft forward + with torch.no_grad(): + peft_out = peft_model(x) + + # Extract base weights and LoRA weights + base_moe = peft_model.base_model.model.moe + base_experts = base_moe.experts.base_layer.base_layer + gate_up_proj = base_experts.gate_up_proj + down_proj = base_experts.down_proj + act_fn = base_experts.act_fn + + # gate_up_proj LoRA + gup_w = 
base_moe.experts.base_layer + peft_gup_A = gup_w.lora_A["default"].weight.detach() + peft_gup_B = gup_w.lora_B["default"].weight.detach() + smoe_gup_A, smoe_gup_B = peft_gate_up_lora_to_scattermoe( + peft_gup_A, peft_gup_B, E, r + ) + + # down_proj LoRA + down_w = base_moe.experts + peft_down_A = down_w.lora_A["default"].weight.detach() + peft_down_B = down_w.lora_B["default"].weight.detach() + smoe_down_A, smoe_down_B = peft_lora_to_scattermoe( + peft_down_A, peft_down_B, E, r + ) + + # ScatterMoE fused forward -- gate is NOT peft-wrapped, access directly + x_flat = x.view(-1, config.hidden_size) + + with torch.no_grad(): + _, rw, sel = base_moe.gate(x_flat) + sei, ssi, eo = flatten_sort_count(sel, num_experts=E) + + gup = parallel_linear_lora( + x_flat, + gate_up_proj.transpose(2, 1), + k, + sei, + ssi, + eo, + lora_A=smoe_gup_A, + lora_B=smoe_gup_B, + scaling=scaling, + grouped_in=False, + grouped_out=True, + ) + g, u = gup.chunk(2, dim=-1) + h = act_fn(g) * u + + smoe_out = parallel_linear_lora( + h, + down_proj.transpose(2, 1), + 1, + sei, + ssi, + eo, + lora_A=smoe_down_A, + lora_B=smoe_down_B, + scaling=scaling, + grouped_in=True, + grouped_out=False, + gates=rw, + ).view(1, M, config.hidden_size) + + torch.testing.assert_close(smoe_out, peft_out, atol=5e-3, rtol=5e-3) + + +# ============================================================================= +# Tests: Backward gradient correctness +# ============================================================================= + + +@requires_cuda +class TestOLMoEPeftLoRABackward: + """Backward gradients through scattermoe_lora vs pure-PyTorch reference.""" + + def test_small(self): + self._run(use_full=False, M=16, r=4) + + def _run(self, use_full, M, r): + from axolotl.integrations.kernels.libs.scattermoe_lora import ( + flatten_sort_count, + parallel_linear_lora, + ) + + config = make_olmoe_config(use_full=use_full) + E, k = config.num_experts, config.num_experts_per_tok + lora_alpha = 16 + scaling = lora_alpha 
/ r + + torch.manual_seed(42) + moe = _init_expert_weights(OlmoeSparseMoeBlock(config)).cuda().float() + x = torch.randn(1, M, config.hidden_size, device="cuda") + x_flat = x.view(-1, config.hidden_size) + gate_up_proj = moe.experts.gate_up_proj + down_proj = moe.experts.down_proj + + # Create LoRA weights in scattermoe layout directly + gup_A = torch.randn(r * E, config.hidden_size, device="cuda") * 0.01 + gup_B = torch.randn(2 * config.intermediate_size, r * E, device="cuda") * 0.01 + down_A = torch.randn(r * E, config.intermediate_size, device="cuda") * 0.01 + down_B = torch.randn(config.hidden_size, r * E, device="cuda") * 0.01 + + rw, sel = _get_routing(moe, x) + sei, ssi, eo = flatten_sort_count(sel, num_experts=E) + + # --- Reference --- + gup_delta = _compute_delta_from_scattermoe_lora( + gup_A, gup_B, scaling, E, r, gate_up_proj.shape + ) + down_delta = _compute_delta_from_scattermoe_lora( + down_A, down_B, scaling, E, r, down_proj.shape + ) + + x_ref = x_flat.clone().detach().requires_grad_(True) + ref_out = _reference_moe_forward_with_lora( + x_ref, + gate_up_proj, + down_proj, + moe.experts.act_fn, + sel, + rw, + E, + gup_delta, + down_delta, + ) + ref_out.sum().backward() + + # --- ScatterMoE fused path --- + x_smoe = x_flat.clone().detach().requires_grad_(True) + gup_A_s = gup_A.clone().requires_grad_(True) + gup_B_s = gup_B.clone().requires_grad_(True) + down_A_s = down_A.clone().requires_grad_(True) + down_B_s = down_B.clone().requires_grad_(True) + + gup_out = parallel_linear_lora( + x_smoe, + gate_up_proj.transpose(2, 1), + k, + sei, + ssi, + eo, + lora_A=gup_A_s, + lora_B=gup_B_s, + scaling=scaling, + grouped_in=False, + grouped_out=True, + ) + g, u = gup_out.chunk(2, dim=-1) + h = moe.experts.act_fn(g) * u + + smoe_out = parallel_linear_lora( + h, + down_proj.transpose(2, 1), + 1, + sei, + ssi, + eo, + lora_A=down_A_s, + lora_B=down_B_s, + scaling=scaling, + grouped_in=True, + grouped_out=False, + gates=rw, + ) + smoe_out.sum().backward() + + 
torch.testing.assert_close( + smoe_out.detach(), + ref_out.detach(), + atol=5e-3, + rtol=5e-3, + ) + torch.testing.assert_close( + x_smoe.grad, + x_ref.grad, + atol=5e-2, + rtol=5e-2, + ) + + +# ============================================================================= +# Tests: kernelize() integration via LocalLayerRepository +# ============================================================================= + + +@requires_cuda +class TestKernelizeIntegration: + """Test the HF kernels library integration with LocalLayerRepository.""" + + @staticmethod + def _get_kernelize_imports(): + """Import kernels library components, skip if not available.""" + try: + from kernels import ( + LocalLayerRepository, + Mode, + kernelize, + register_kernel_mapping, + replace_kernel_forward_from_hub, + ) + + return ( + LocalLayerRepository, + Mode, + register_kernel_mapping, + replace_kernel_forward_from_hub, + kernelize, + ) + except ImportError: + pytest.skip("kernels library not installed") + + @staticmethod + def _get_repo_path(): + """Get the path to scattermoe_lora within axolotl's plugin.""" + return ( + Path(__file__).parent.parent.parent + / "src" + / "axolotl" + / "integrations" + / "kernels" + / "libs" + / "scattermoe_lora" + ) + + def _setup_kernels( + self, + LocalLayerRepository, + Mode, + register_kernel_mapping, + replace_kernel_forward_from_hub, + ): + """Register kernel mapping for tests.""" + repo_path = self._get_repo_path() + local_repo = LocalLayerRepository( + repo_path=repo_path, + package_name="scattermoe_lora", + layer_name="HFScatterMoEGatedMLP", + ) + + replace_kernel_forward_from_hub( + OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts" + ) + register_kernel_mapping( + { + "HFScatterMoEParallelExperts": { + "cuda": { + Mode.TRAINING: local_repo, + Mode.INFERENCE: local_repo, + }, + } + } + ) + + def test_base_forward_via_kernelize(self): + """Kernelized OlmoeSparseMoeBlock (no LoRA) matches per-expert reference.""" + ( + LocalLayerRepository, + Mode, + 
register_kernel_mapping, + replace_kernel_forward_from_hub, + kernelize, + ) = self._get_kernelize_imports() + + config = make_olmoe_config(use_full=False) + E = config.num_experts + + # Create model + torch.manual_seed(42) + moe = _init_expert_weights(OlmoeSparseMoeBlock(config)).cuda().float() + x = torch.randn(1, 8, config.hidden_size, device="cuda") + x_flat = x.view(-1, config.hidden_size) + + # Compute reference BEFORE kernelizing + with torch.no_grad(): + _, rw, sel = moe.gate(x_flat) + ref_out = _reference_moe_forward( + x_flat, + moe.experts.gate_up_proj, + moe.experts.down_proj, + moe.experts.act_fn, + sel, + rw, + E, + ).view(1, 8, config.hidden_size) + + # Set up kernel mapping + self._setup_kernels( + LocalLayerRepository, + Mode, + register_kernel_mapping, + replace_kernel_forward_from_hub, + ) + + # Kernelize the model + kernelize(moe, mode=Mode.TRAINING, device="cuda") + + # Forward through kernelized model + with torch.no_grad(): + kern_out = moe(x) + + torch.testing.assert_close(kern_out, ref_out, atol=1e-3, rtol=1e-3) + + def test_lora_forward_via_kernelize(self): + """Kernelized OlmoeSparseMoeBlock with peft LoRA matches reference.""" + ( + LocalLayerRepository, + Mode, + register_kernel_mapping, + replace_kernel_forward_from_hub, + kernelize, + ) = self._get_kernelize_imports() + + config = make_olmoe_config(use_full=False) + r = 4 + + # Create peft model + torch.manual_seed(42) + model = MinimalOLMoEModel(config).cuda().float() + lora_config = LoraConfig( + r=r, + lora_alpha=16, + target_modules=[], + target_parameters=["experts.gate_up_proj", "experts.down_proj"], + bias="none", + ) + peft_model = get_peft_model(model, lora_config) + + x = torch.randn(1, 8, config.hidden_size, device="cuda") + + # Reference: peft's own forward (uses _activate_lora context manager) + with torch.no_grad(): + ref_out = peft_model(x) + + # Set up kernel mapping + self._setup_kernels( + LocalLayerRepository, + Mode, + register_kernel_mapping, + 
replace_kernel_forward_from_hub, + ) + + # Kernelize the MoE block inside the peft model + base_moe = peft_model.base_model.model.moe + kernelize(base_moe, mode=Mode.TRAINING, device="cuda") + + # Forward through kernelized peft model + with torch.no_grad(): + kern_out = peft_model(x) + + torch.testing.assert_close(kern_out, ref_out, atol=5e-3, rtol=5e-3) + + def test_gate_lora_forward_via_kernelize(self): + """Kernelized forward with gate LoRA matches peft reference.""" + ( + LocalLayerRepository, + Mode, + register_kernel_mapping, + replace_kernel_forward_from_hub, + kernelize, + ) = self._get_kernelize_imports() + + config = make_olmoe_config(use_full=False) + r = 4 + + # Create peft model with gate + experts LoRA + torch.manual_seed(42) + model = MinimalOLMoEModel(config).cuda().float() + lora_config = LoraConfig( + r=r, + lora_alpha=16, + target_modules=[], + target_parameters=[ + "gate.weight", + "experts.gate_up_proj", + "experts.down_proj", + ], + bias="none", + ) + peft_model = get_peft_model(model, lora_config) + + x = torch.randn(1, 8, config.hidden_size, device="cuda") + + # Reference: peft's own forward + with torch.no_grad(): + ref_out = peft_model(x) + + # Set up kernel mapping + self._setup_kernels( + LocalLayerRepository, + Mode, + register_kernel_mapping, + replace_kernel_forward_from_hub, + ) + + # Kernelize the MoE block inside the peft model + base_moe = peft_model.base_model.model.moe + kernelize(base_moe, mode=Mode.TRAINING, device="cuda") + + # Forward through kernelized peft model + with torch.no_grad(): + kern_out = peft_model(x) + + torch.testing.assert_close(kern_out, ref_out, atol=5e-3, rtol=5e-3) + + +# ============================================================================= +# Tests: Shared expert handling +# ============================================================================= + + +class TestSharedExpertHandling: + """Test that HFScatterMoEGatedMLP.forward handles shared experts.""" + + @staticmethod + def 
_make_shared_expert_block(config): + """Create an OlmoeSparseMoeBlock with a mock shared expert attached.""" + moe = OlmoeSparseMoeBlock(config) + _init_expert_weights(moe) + + hidden = config.hidden_size + inter = config.intermediate_size + + # Attach a simple shared expert MLP (mimics Qwen2MoE structure) + class SharedExpertMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.act_fn = nn.SiLU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + moe.shared_expert = SharedExpertMLP(hidden, inter) + moe.shared_expert_gate = nn.Linear(hidden, 1, bias=False) + + return moe + + def test_shared_expert_is_used(self): + """Verify shared expert output affects final result.""" + config = make_olmoe_config(use_full=False) + moe = self._make_shared_expert_block(config) + + # Compute reference without shared expert + torch.manual_seed(42) + x = torch.randn(1, 4, config.hidden_size) + x_flat = x.view(-1, config.hidden_size) + + with torch.no_grad(): + # Shared expert contribution + shared_out = moe.shared_expert(x_flat) + gate_val = F.sigmoid(moe.shared_expert_gate(x_flat)) + shared_contribution = shared_out * gate_val + + # Verify shared expert produces non-zero output + assert shared_contribution.abs().max() > 0 + + @requires_cuda + def test_shared_expert_forward_via_kernelize(self): + """Kernelized forward with shared expert matches manual reference.""" + try: + from kernels import ( + LocalLayerRepository, + Mode, + kernelize, + register_kernel_mapping, + replace_kernel_forward_from_hub, + ) + except ImportError: + pytest.skip("kernels library not installed") + + config = make_olmoe_config(use_full=False) + E = config.num_experts + + torch.manual_seed(42) + moe = 
self._make_shared_expert_block(config).cuda().float() + x = torch.randn(1, 8, config.hidden_size, device="cuda") + x_flat = x.view(-1, config.hidden_size) + + # Compute reference: per-expert + shared expert + with torch.no_grad(): + _, rw, sel = moe.gate(x_flat) + + expert_out = _reference_moe_forward( + x_flat, + moe.experts.gate_up_proj, + moe.experts.down_proj, + moe.experts.act_fn, + sel, + rw, + E, + ) + shared_out = moe.shared_expert(x_flat) + gate_val = F.sigmoid(moe.shared_expert_gate(x_flat)) + ref_out = (expert_out + shared_out * gate_val).view( + 1, 8, config.hidden_size + ) + + # Kernelize + repo_path = ( + Path(__file__).parent.parent.parent + / "src" + / "axolotl" + / "integrations" + / "kernels" + / "libs" + / "scattermoe_lora" + ) + local_repo = LocalLayerRepository( + repo_path=repo_path, + package_name="scattermoe_lora", + layer_name="HFScatterMoEGatedMLP", + ) + + replace_kernel_forward_from_hub( + OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts" + ) + register_kernel_mapping( + { + "HFScatterMoEParallelExperts": { + "cuda": { + Mode.TRAINING: local_repo, + Mode.INFERENCE: local_repo, + }, + } + } + ) + + kernelize(moe, mode=Mode.TRAINING, device="cuda") + + with torch.no_grad(): + kern_out = moe(x) + + torch.testing.assert_close(kern_out, ref_out, atol=1e-3, rtol=1e-3)