Files
axolotl/tests/e2e/integrations/test_scattermoe_lora_olmoe.py
Wing Lian 0ee98a0309 fix token state json and mistral tokenizer issue (#3522) [skip ci]
* fix token state json and mistral tokenizer issue

* centralize constants

* forgot to commit constants file

* Fix weakref in pickling relora state dict

* make curl a bit quieter so it doesn't log 2K lines

* fix path traversal for olmoe test

* more test fixes that weren't flagged previously

* chore: lint

* skip tests that fail b/c of OutOfResources

* scattermoe as slow tests

* update fbgemm-genai for torch 2.10
2026-03-21 22:46:10 -04:00

1256 lines
42 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
Integration tests: OLMoE + peft LoRA + ScatterMoE fused kernels.
Validates that scattermoe_lora fused kernels produce correct results when used
with HuggingFace OLMoE models and peft LoRA adapters applied via
``target_parameters``.
Key things tested
-----------------
- LoRA weight layout conversion between peft (rank-major) and scattermoe (expert-major)
- Base forward equivalence: per-expert reference vs ScatterMoE kernels (no LoRA)
- LoRA forward equivalence: peft merged-weight approach vs scattermoe fused kernels
- Backward gradient correctness through the fused LoRA path
- ``kernelize()`` integration via ``LocalLayerRepository``
"""
from pathlib import Path
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from transformers import OlmoeConfig
from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
# Dotted module path of the ScatterMoE LoRA kernel package.
_SMOE = "axolotl.integrations.kernels.libs.scattermoe_lora"

# Try to import from axolotl's scattermoe_lora.layers; may fail on CPU without triton.
try:
    from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
        _unwrap_experts_lora,
        _unwrap_gate_lora,
        peft_lora_B_to_scattermoe,
        peft_lora_to_scattermoe,
    )

    HAS_SCATTERMOE = True
except ImportError:
    # ModuleNotFoundError is a subclass of ImportError, so a single handler
    # covers both failure modes.
    HAS_SCATTERMOE = False

    # Provide pure-torch fallbacks for CPU-only layout conversion tests.
    def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
        """Reorder lora_B columns from rank-major to expert-major.

        ``peft_B`` is ``[N, num_experts * rank]``. Input column
        ``r * num_experts + e`` becomes output column ``e * rank + r``.
        Returns a new contiguous tensor of the same shape.
        """
        N = peft_B.shape[0]
        return (
            peft_B.reshape(N, rank, num_experts)
            .permute(0, 2, 1)
            .contiguous()
            .reshape(N, num_experts * rank)
        )

    def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
        """Convert a peft LoRA pair to the (A, B) pair used on the scattermoe side.

        The roles of A and B are swapped: the returned ``smoe_A`` is the
        transpose of the expert-major ``peft_B`` and ``smoe_B`` is the
        transpose of ``peft_A``.

        Args:
            peft_A: ``[num_experts * rank, in_features]``.
            peft_B: ``[out_features, num_experts * rank]``.
            num_experts: number of experts packed along the rank axis.
            rank: LoRA rank per expert.

        Returns:
            ``(smoe_A, smoe_B)`` with shapes
            ``[rank * num_experts, peft_B.shape[0]]`` and
            ``[peft_A.shape[1], rank * num_experts]``, allocated with
            ``peft_A``'s device and dtype.

        Raises:
            ValueError: if ``peft_A`` does not have ``num_experts * rank`` rows.
        """
        total_rank = rank * num_experts
        if peft_A.shape[0] != total_rank:
            raise ValueError(
                f"peft_A must have {total_rank} rows (num_experts * rank), "
                f"got {peft_A.shape[0]}"
            )
        peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
        K_inter, N_hidden = peft_B.shape[0], peft_A.shape[1]
        smoe_A = torch.empty(
            total_rank,
            K_inter,
            device=peft_A.device,
            dtype=peft_A.dtype,
        )
        smoe_B = torch.empty(
            N_hidden,
            total_rank,
            device=peft_A.device,
            dtype=peft_A.dtype,
        )
        # Copying each expert's [rank, K] / [N, rank] slice transposed is the
        # same as transposing the whole matrices, so do it in one shot.
        smoe_A.copy_(peft_B_em.T)
        smoe_B.copy_(peft_A.T)
        return smoe_A, smoe_B

    def _unwrap_experts_lora(experts_module):
        """Fallback: return the module unchanged and report no expert LoRA."""
        return experts_module, None, None

    def _unwrap_gate_lora(gate_module):
        """Split a (possibly LoRA-wrapped) gate into base module, weight and delta.

        Returns ``(base_gate, base_weight, delta)``. ``delta`` is
        ``scaling * (lora_B @ lora_A)`` for the first active adapter, or
        ``None`` when the gate is not wrapped or the adapter has no weights.
        """
        is_wrapped = hasattr(gate_module, "base_layer") and hasattr(
            gate_module, "lora_A"
        )
        if not is_wrapped:
            return gate_module, gate_module.weight, None

        base_gate = gate_module.base_layer
        active = getattr(gate_module, "active_adapters", ["default"])
        name = active[0] if active else "default"
        lora_A_dict = getattr(gate_module, "lora_A", {})
        lora_B_dict = getattr(gate_module, "lora_B", {})
        scaling_dict = getattr(gate_module, "scaling", {})
        if name in lora_A_dict:
            lora_A = lora_A_dict[name].weight
            lora_B = lora_B_dict[name].weight
            s = scaling_dict[name]
            delta = s * (lora_B @ lora_A)
            return base_gate, base_gate.weight, delta
        return base_gate, base_gate.weight, None
# =============================================================================
# Configuration
# =============================================================================
# Full-size MoE dimensions (selected by ``make_olmoe_config(use_full=True)``).
FULL_OLMOE_CONFIG = {
    "hidden_size": 2048,
    "intermediate_size": 1024,
    "num_experts": 64,
    "num_experts_per_tok": 8,
    "hidden_act": "silu",
    "norm_topk_prob": False,
}

# Small MoE dimensions used by default.
SMALL_OLMOE_CONFIG = {
    "hidden_size": 128,
    "intermediate_size": 48,  # non-square: 2*inter=96 != hidden=128
    "num_experts": 8,
    "num_experts_per_tok": 2,
    "hidden_act": "silu",
    "norm_topk_prob": False,
}
# Skip marker for tests that need a CUDA device; evaluated once at import time.
requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA not available"
)
def make_olmoe_config(use_full=False):
    """Build an ``OlmoeConfig`` from the full or small preset.

    The preset dict is copied, never mutated, and
    ``experts_implementation`` is always forced to ``"grouped_mm"``.
    """
    preset = FULL_OLMOE_CONFIG if use_full else SMALL_OLMOE_CONFIG
    overrides = {"experts_implementation": "grouped_mm"}
    return OlmoeConfig(**{**preset, **overrides})
# =============================================================================
# Layout conversion utilities (test-local helpers)
# =============================================================================
def scattermoe_lora_B_to_peft(smoe_B, num_experts, rank):
    """Inverse of ``peft_lora_B_to_scattermoe``.

    Takes an expert-major ``[N, num_experts * rank]`` matrix and returns a
    new contiguous rank-major matrix of the same shape.
    """
    out_rows = smoe_B.shape[0]
    expert_major = smoe_B.reshape(out_rows, num_experts, rank)
    rank_major = expert_major.permute(0, 2, 1).contiguous()
    return rank_major.reshape(out_rows, num_experts * rank)
def peft_gate_up_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
    """Convert peft LoRA weights for gate_up_proj to the scattermoe layout.

    gate_up_proj goes through the same conversion as down_proj: both need
    the A<->B swap because scattermoe transposes the parameter
    (W = param.T). This is therefore a direct delegation to
    ``peft_lora_to_scattermoe``.
    """
    smoe_pair = peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
    return smoe_pair
# =============================================================================
# Helpers
# =============================================================================
def _init_expert_weights(moe_block):
    """Fill the expert projection weights of ``moe_block`` in place.

    OlmoeExperts parameters are created with torch.empty (uninitialized);
    without this step gate_up_proj and down_proj contain garbage/NaN values.
    Returns the same ``moe_block`` so the call can be chained.
    """
    experts = moe_block.experts
    with torch.no_grad():
        # Order matters for RNG reproducibility: gate_up_proj first.
        for weight in (experts.gate_up_proj, experts.down_proj):
            nn.init.kaiming_uniform_(weight)
    return moe_block
class MinimalOLMoEModel(nn.Module):
    """Thin wrapper so peft's get_peft_model can attach adapters.

    Holds a single ``OlmoeSparseMoeBlock`` as ``self.moe`` and forwards
    straight to it.
    """

    def __init__(self, config):
        """Create the MoE block from ``config`` and initialize its expert weights."""
        super().__init__()
        self.moe = OlmoeSparseMoeBlock(config)
        # Expert parameters need explicit initialization before use.
        _init_expert_weights(self.moe)

    def forward(self, x):
        """Return whatever the wrapped MoE block returns for ``x``."""
        return self.moe(x)
def _get_routing(moe_block, hidden_states):
    """Call the block's router without grad tracking.

    ``hidden_states`` is flattened to 2-D (last dimension kept) before it is
    handed to ``moe_block.gate``. The first element of the router's 3-tuple
    is discarded; ``(routing_weights, selected_experts)`` is returned.
    """
    with torch.no_grad():
        flat_tokens = hidden_states.view(-1, hidden_states.size(-1))
        _unused, routing_weights, selected_experts = moe_block.gate(flat_tokens)
    return routing_weights, selected_experts
def _reference_moe_forward(
    x_flat,
    gate_up_proj,
    down_proj,
    act_fn,
    top_k_index,
    top_k_weights,
    num_experts,
):
    """Per-expert MoE forward in plain PyTorch, without LoRA.

    Each expert is evaluated with ``F.linear`` on just the tokens routed to
    it, giving an apples-to-apples baseline for the ScatterMoE kernel path.
    The routing-weighted expert outputs are accumulated into a tensor shaped
    like ``x_flat``.
    """
    output = torch.zeros_like(x_flat)
    # one_hot gives [tokens, top_k, experts]; permute to [experts, top_k, tokens].
    dispatch = F.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)
    for expert_id in range(num_experts):
        slot_idx, rows = torch.where(dispatch[expert_id])
        if rows.numel() > 0:
            routed = x_flat[rows]
            gate_part, up_part = F.linear(routed, gate_up_proj[expert_id]).chunk(
                2, dim=-1
            )
            hidden = act_fn(gate_part) * up_part
            expert_out = F.linear(hidden, down_proj[expert_id])
            # Scale each routed row by its routing weight for this slot.
            expert_out = expert_out * top_k_weights[rows, slot_idx, None]
            output.index_add_(0, rows, expert_out.to(output.dtype))
    return output
def _reference_moe_forward_with_lora(
    x_flat,
    gate_up_proj,
    down_proj,
    act_fn,
    top_k_index,
    top_k_weights,
    num_experts,
    gup_delta,
    down_delta,
):
    """Reference MoE forward with the LoRA deltas merged into the weights.

    The pre-computed deltas are added to the base projections, then the
    merged weights are run through ``_reference_moe_forward``.
    """
    return _reference_moe_forward(
        x_flat,
        gate_up_proj + gup_delta,
        down_proj + down_delta,
        act_fn,
        top_k_index,
        top_k_weights,
        num_experts,
    )
def _compute_delta_from_scattermoe_lora(lora_A, lora_B, scaling, E, r, param_shape):
    """Build the additive weight delta from scattermoe-layout LoRA weights.

    For expert ``e`` the rank slice is ``[e*r, (e+1)*r)``:
    ``delta[e] = scaling * B_e @ A_e`` with ``A_e`` taken from the rows of
    ``lora_A`` and ``B_e`` from the columns of ``lora_B``. The result has
    shape ``param_shape`` and ``lora_A``'s device and dtype.
    """
    delta = torch.zeros(param_shape, device=lora_A.device, dtype=lora_A.dtype)
    for expert in range(E):
        rank_block = slice(expert * r, (expert + 1) * r)
        delta[expert] = scaling * (lora_B[:, rank_block] @ lora_A[rank_block, :])
    return delta
# =============================================================================
# Tests: Layout conversion
# =============================================================================
class TestLoRABLayoutConversion:
    """Checks for the peft <-> scattermoe LoRA layout conversions."""

    def test_roundtrip(self):
        """peft -> scattermoe -> peft must reproduce the original lora_B."""
        num_experts, rank, out_dim = 8, 4, 64
        peft_layout = torch.randn(out_dim, num_experts * rank)
        smoe_layout = peft_lora_B_to_scattermoe(peft_layout, num_experts, rank)
        restored = scattermoe_lora_B_to_peft(smoe_layout, num_experts, rank)
        torch.testing.assert_close(restored, peft_layout)

    def test_per_expert_slices(self):
        """Contiguous scattermoe column slices equal peft's reshape-based
        per-expert matrices."""
        num_experts, rank, out_dim = 4, 2, 16
        peft_B = torch.randn(out_dim, num_experts * rank)
        smoe_B = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
        peft_view = peft_B.reshape(out_dim, rank, num_experts)
        for expert in range(num_experts):
            cols = slice(expert * rank, (expert + 1) * rank)
            torch.testing.assert_close(smoe_B[:, cols], peft_view[:, :, expert])

    def test_lora_A_already_compatible(self):
        """Row slices of lora_A equal its ``[E, r, K]`` reshape, so lora_A
        needs no reordering between peft and scattermoe."""
        num_experts, rank, in_dim = 4, 2, 16
        lora_A = torch.randn(num_experts * rank, in_dim)
        per_expert = lora_A.reshape(num_experts, rank, in_dim)
        for expert in range(num_experts):
            rows = slice(expert * rank, (expert + 1) * rank)
            torch.testing.assert_close(lora_A[rows, :], per_expert[expert])

    def test_delta_weight_equivalence(self):
        """The einsum delta equals per-expert B @ A once lora_B is converted."""
        num_experts, rank, in_dim, out_dim = 8, 4, 32, 64
        scaling = 2.0
        peft_A = torch.randn(num_experts * rank, in_dim)
        peft_B = torch.randn(out_dim, num_experts * rank)

        delta_peft = scaling * torch.einsum(
            "o r e, e r i -> e i o",
            peft_B.reshape(out_dim, rank, num_experts),
            peft_A.reshape(num_experts, rank, in_dim),
        )

        smoe_B = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
        for expert in range(num_experts):
            block = slice(expert * rank, (expert + 1) * rank)
            delta_expert = scaling * (smoe_B[:, block] @ peft_A[block, :])
            # The einsum result is [in, out]; B @ A is [out, in].
            torch.testing.assert_close(
                delta_expert, delta_peft[expert].T, atol=1e-5, rtol=1e-5
            )

    def test_down_proj_conversion(self):
        """peft_lora_to_scattermoe must reproduce the einsum delta."""
        num_experts, rank = 4, 2
        hidden_dim, inter_dim = 32, 16
        scaling = 2.0
        peft_A = torch.randn(num_experts * rank, hidden_dim)
        peft_B = torch.randn(inter_dim, num_experts * rank)

        delta_peft = scaling * torch.einsum(
            "o r e, e r i -> e i o",
            peft_B.reshape(inter_dim, rank, num_experts),
            peft_A.reshape(num_experts, rank, hidden_dim),
        )

        smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
        for expert in range(num_experts):
            block = slice(expert * rank, (expert + 1) * rank)
            delta_smoe = scaling * (smoe_B[:, block] @ smoe_A[block, :])
            torch.testing.assert_close(
                delta_smoe, delta_peft[expert], atol=1e-5, rtol=1e-5
            )

    def test_gate_up_proj_conversion(self):
        """Verify gate_up_proj LoRA conversion with non-square dims (Qwen3-like).

        gate_up_proj param: [E, 2*inter, hidden].
        peft: in_features=2*inter, out_features=hidden.
        peft lora_A: [r*E, 2*inter], lora_B: [hidden, r*E].
        scattermoe W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter.
        scattermoe needs: lora_A [r*E, K=hidden], lora_B [N=2*inter, r*E].
        Uses non-square dims (hidden=32 != 2*inter=24) to catch A<->B swap bugs.
        """
        E, r = 4, 2
        hidden, inter = 32, 12  # 2*inter=24 != hidden=32
        scaling = 2.0
        gate_up_dim = 2 * inter

        # peft assigns: in_features=2*inter, out_features=hidden
        peft_A = torch.randn(E * r, gate_up_dim)  # [r*E, in_features=2*inter]
        peft_B = torch.randn(hidden, E * r)  # [out_features=hidden, r*E]

        # peft delta via einsum; delta_peft[e] is [in_features, out_features]
        # = [2*inter, hidden], i.e. the shape of param[e].
        delta_peft = scaling * torch.einsum(
            "o r e, e r i -> e i o",
            peft_B.reshape(hidden, r, E),
            peft_A.reshape(E, r, gate_up_dim),
        )

        smoe_A, smoe_B = peft_gate_up_lora_to_scattermoe(peft_A, peft_B, E, r)

        # smoe_A should be [r*E, K=hidden], smoe_B should be [N=2*inter, r*E]
        assert smoe_A.shape == (E * r, hidden), (
            f"Expected {(E * r, hidden)}, got {smoe_A.shape}"
        )
        assert smoe_B.shape == (2 * inter, E * r), (
            f"Expected {(2 * inter, E * r)}, got {smoe_B.shape}"
        )

        for e in range(E):
            block = slice(e * r, (e + 1) * r)
            # [N=2*inter, r] @ [r, K=hidden] -> [2*inter, hidden]
            delta_smoe_e = scaling * (smoe_B[:, block] @ smoe_A[block, :])
            # Should match peft delta which is [2*inter, hidden] = param[e]
            torch.testing.assert_close(
                delta_smoe_e, delta_peft[e], atol=1e-5, rtol=1e-5
            )
# =============================================================================
# Tests: peft weight extraction
# =============================================================================
class TestPeftLoRAWeightExtraction:
    """Test extracting peft LoRA weights for OLMoE."""

    def test_peft_creates_correct_shapes(self):
        """peft creates LoRA A/B parameters with the expected shapes for the
        gate, gate_up_proj and down_proj targets."""
        config = make_olmoe_config(use_full=False)
        E, r = config.num_experts, 4
        model = MinimalOLMoEModel(config)
        lora_config = LoraConfig(
            r=r,
            lora_alpha=16,
            target_modules=[],
            target_parameters=[
                "gate.weight",
                "experts.gate_up_proj",
                "experts.down_proj",
            ],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)
        trainable = {n: p for n, p in peft_model.named_parameters() if p.requires_grad}

        # Gate router
        assert trainable["base_model.model.moe.gate.lora_A.default.weight"].shape == (
            r,
            config.hidden_size,
        )
        assert trainable["base_model.model.moe.gate.lora_B.default.weight"].shape == (
            E,
            r,
        )

        # gate_up_proj [E, 2*inter, hidden]
        # peft: in_features=2*inter (dim 1), out_features=hidden (dim 2)
        assert trainable[
            "base_model.model.moe.experts.base_layer.lora_A.default.weight"
        ].shape == (E * r, 2 * config.intermediate_size)
        assert trainable[
            "base_model.model.moe.experts.base_layer.lora_B.default.weight"
        ].shape == (config.hidden_size, E * r)

        # down_proj [E, hidden, inter]
        # peft: in_features=hidden (dim 1), out_features=inter (dim 2)
        assert trainable[
            "base_model.model.moe.experts.lora_A.default.weight"
        ].shape == (E * r, config.hidden_size)
        assert trainable[
            "base_model.model.moe.experts.lora_B.default.weight"
        ].shape == (config.intermediate_size, E * r)

    @requires_cuda
    def test_peft_forward_runs(self):
        """Smoke test: peft model forward pass completes (needs CUDA for grouped_mm)."""
        config = make_olmoe_config(use_full=False)
        # NOTE(review): the model and input are never moved to the GPU even
        # though the test is gated on CUDA -- confirm that this is intended.
        model = MinimalOLMoEModel(config)
        lora_config = LoraConfig(
            r=4,
            lora_alpha=16,
            target_modules=[],
            target_parameters=[
                "gate.weight",
                "experts.gate_up_proj",
                "experts.down_proj",
            ],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)
        x = torch.randn(1, 4, config.hidden_size)
        out = peft_model(x)
        assert out.shape == x.shape

    @pytest.mark.skipif(
        not HAS_SCATTERMOE, reason="scattermoe_lora not importable (no triton)"
    )
    def test_unwrap_experts_lora(self):
        """Test that _unwrap_experts_lora correctly detects LoRA wrappers."""
        config = make_olmoe_config(use_full=False)
        model = MinimalOLMoEModel(config)
        lora_config = LoraConfig(
            r=4,
            lora_alpha=16,
            target_modules=[],
            target_parameters=["experts.gate_up_proj", "experts.down_proj"],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)
        base_moe = peft_model.base_model.model.moe

        # Experts should be wrapped by ParamWrapper
        experts, gup_lora, down_lora = _unwrap_experts_lora(base_moe.experts)

        # Base experts should have the raw parameters
        assert hasattr(experts, "gate_up_proj")
        assert hasattr(experts, "down_proj")

        # LoRA should be detected
        assert gup_lora is not None, "gate_up_proj LoRA not detected"
        assert down_lora is not None, "down_proj LoRA not detected"

        # Check shapes (after peft->scattermoe conversion with A<->B swap)
        # gate_up_proj W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter
        E, r = config.num_experts, 4
        gup_A, gup_B, gup_s = gup_lora
        assert gup_A.shape == (E * r, config.hidden_size), (
            f"gate_up_proj smoe_A: expected [r*E, K=hidden]={(E * r, config.hidden_size)}, "
            f"got {gup_A.shape}"
        )
        assert gup_B.shape == (2 * config.intermediate_size, E * r), (
            f"gate_up_proj smoe_B: expected [N=2*inter, r*E]="
            f"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}"
        )

        # down_proj W = param.T = [E, inter, hidden], K=inter, N=hidden
        down_A, down_B, down_s = down_lora
        assert down_A.shape == (E * r, config.intermediate_size), (
            f"down_proj smoe_A: expected [r*E, K=inter]={(E * r, config.intermediate_size)}, "
            f"got {down_A.shape}"
        )
        assert down_B.shape == (config.hidden_size, E * r), (
            f"down_proj smoe_B: expected [N=hidden, r*E]={(config.hidden_size, E * r)}, "
            f"got {down_B.shape}"
        )

    def test_unwrap_no_lora(self):
        """Without peft, _unwrap_experts_lora returns no LoRA."""
        config = make_olmoe_config(use_full=False)
        moe = OlmoeSparseMoeBlock(config)
        experts, gup_lora, down_lora = _unwrap_experts_lora(moe.experts)
        assert gup_lora is None
        assert down_lora is None
        assert hasattr(experts, "gate_up_proj")

    def test_unwrap_gate_lora(self):
        """Test that _unwrap_gate_lora detects LoRA on the router gate."""
        config = make_olmoe_config(use_full=False)
        model = MinimalOLMoEModel(config)
        r = 4
        lora_config = LoraConfig(
            r=r,
            lora_alpha=16,
            target_modules=[],
            target_parameters=["gate.weight"],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)
        base_moe = peft_model.base_model.model.moe

        # Set non-zero LoRA weights (peft initializes lora_B to zeros)
        with torch.no_grad():
            base_moe.gate.lora_B["default"].weight.normal_(0, 0.01)

        base_gate, gate_weight, gate_delta = _unwrap_gate_lora(base_moe.gate)

        # Base gate should be the original router
        assert hasattr(base_gate, "top_k")
        assert hasattr(base_gate, "num_experts")
        assert base_gate.top_k == config.num_experts_per_tok
        assert base_gate.num_experts == config.num_experts

        # Gate weight should be the base weight (delta returned separately)
        assert gate_weight.shape == (config.num_experts, config.hidden_size)
        torch.testing.assert_close(gate_weight, base_gate.weight)

        # Delta should be non-zero (LoRA was applied)
        assert gate_delta is not None
        assert gate_delta.shape == (config.num_experts, config.hidden_size)
        assert gate_delta.abs().max() > 0, "Gate LoRA delta should be non-zero"

    def test_unwrap_gate_no_lora(self):
        """Without peft, _unwrap_gate_lora returns the original gate."""
        config = make_olmoe_config(use_full=False)
        moe = OlmoeSparseMoeBlock(config)
        base_gate, gate_weight, gate_delta = _unwrap_gate_lora(moe.gate)
        assert base_gate is moe.gate
        torch.testing.assert_close(gate_weight, moe.gate.weight)
        assert gate_delta is None

    def test_gate_lora_delta_matches_peft(self):
        """Verify _unwrap_gate_lora computes the same delta as peft."""
        config = make_olmoe_config(use_full=False)
        model = MinimalOLMoEModel(config)
        r = 4
        lora_alpha = 16
        scaling = lora_alpha / r
        lora_config = LoraConfig(
            r=r,
            lora_alpha=lora_alpha,
            target_modules=[],
            target_parameters=["gate.weight"],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)
        base_moe = peft_model.base_model.model.moe

        # peft initializes lora_B to zeros, which would make both the expected
        # and the unwrapped delta all-zero and the comparison vacuous.
        with torch.no_grad():
            base_moe.gate.lora_B["default"].weight.normal_(0, 0.01)

        # Our unwrapped weight + delta
        _, gate_weight, gate_delta = _unwrap_gate_lora(base_moe.gate)

        # Manually compute expected delta
        lora_A = base_moe.gate.lora_A["default"].weight  # [r, hidden]
        lora_B = base_moe.gate.lora_B["default"].weight  # [E, r]
        base_weight = base_moe.gate.base_layer.weight  # [E, hidden]
        expected_delta = scaling * (lora_B @ lora_A)
        assert expected_delta.abs().max() > 0, "expected delta must be non-zero"

        torch.testing.assert_close(gate_weight, base_weight)
        torch.testing.assert_close(gate_delta, expected_delta)

        # Combined should match the old behavior
        torch.testing.assert_close(
            gate_weight + gate_delta, base_weight + expected_delta
        )
# =============================================================================
# Tests: Base forward equivalence (no LoRA)
# =============================================================================
@requires_cuda
class TestOLMoEReferenceVsScatterMoE:
    """Base forward equivalence: per-expert reference vs ScatterMoE kernels."""

    def test_small(self):
        """Run the comparison on the small config with 16 tokens."""
        self._run(use_full=False, M=16)

    @pytest.mark.slow
    def test_full(self):
        """Run the comparison on the full-size config with 32 tokens."""
        self._run(use_full=True, M=32)

    def _run(self, use_full, M):
        """Compare the reference forward with the ScatterMoE kernel forward.

        Both paths share the same routing decision, so only the expert
        computation differs. ``M`` is the number of tokens in the batch.
        """
        # Imported lazily so that collecting this module does not require the kernels.
        from axolotl.integrations.kernels.libs.scattermoe_lora import (
            flatten_sort_count,
            parallel_linear,
        )

        config = make_olmoe_config(use_full=use_full)
        # Seed before building the block so weights and input are reproducible.
        torch.manual_seed(42)
        moe = _init_expert_weights(OlmoeSparseMoeBlock(config)).cuda().float()
        E, k = config.num_experts, config.num_experts_per_tok
        x = torch.randn(1, M, config.hidden_size, device="cuda")
        x_flat = x.view(-1, config.hidden_size)
        with torch.no_grad():
            # Shared routing for both paths
            _, rw, sel = moe.gate(x_flat)
            sei, ssi, eo = flatten_sort_count(sel, num_experts=E)

            # Per-expert reference
            ref_out = _reference_moe_forward(
                x_flat,
                moe.experts.gate_up_proj,
                moe.experts.down_proj,
                moe.experts.act_fn,
                sel,
                rw,
                E,
            ).view(1, M, config.hidden_size)

            # ScatterMoE kernel path
            gup = parallel_linear(
                x_flat,
                moe.experts.gate_up_proj.transpose(2, 1),
                k,
                sei,
                ssi,
                eo,
                grouped_in=False,
                grouped_out=True,
            )
            g, u = gup.chunk(2, dim=-1)
            h = moe.experts.act_fn(g) * u
            smoe_out = parallel_linear(
                h,
                moe.experts.down_proj.transpose(2, 1),
                1,
                sei,
                ssi,
                eo,
                grouped_in=True,
                grouped_out=False,
                gates=rw,
            ).view(1, M, config.hidden_size)
        torch.testing.assert_close(smoe_out, ref_out, atol=1e-3, rtol=1e-3)
# =============================================================================
# Tests: LoRA forward equivalence (peft vs scattermoe fused)
# =============================================================================
@requires_cuda
class TestOLMoEPeftLoRAForward:
    """Fused LoRA forward: peft merged-weight vs scattermoe_lora kernel."""

    def test_small(self):
        """Run the comparison on the small config (16 tokens, rank 4)."""
        self._run(use_full=False, M=16, r=4)

    @pytest.mark.slow
    def test_full(self):
        """Run the comparison on the full-size config (32 tokens, rank 8)."""
        self._run(use_full=True, M=32, r=8)

    def _run(self, use_full, M, r):
        """Compare peft's forward with the fused scattermoe LoRA forward.

        Args:
            use_full: use the full-size OLMoE config instead of the small one.
            M: number of tokens in the single-batch input.
            r: LoRA rank per expert.
        """
        from axolotl.integrations.kernels.libs.scattermoe_lora import (
            flatten_sort_count,
            parallel_linear_lora,
        )

        config = make_olmoe_config(use_full=use_full)
        E, k = config.num_experts, config.num_experts_per_tok
        lora_alpha = 16
        scaling = lora_alpha / r

        # Seed BEFORE the model is built so the expert weights, the LoRA
        # weights and the input are all reproducible.
        torch.manual_seed(42)

        # Create peft model
        model = MinimalOLMoEModel(config).cuda().float()
        lora_config = LoraConfig(
            r=r,
            lora_alpha=lora_alpha,
            target_modules=[],
            target_parameters=["experts.gate_up_proj", "experts.down_proj"],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)

        # peft initializes lora_B to zeros, which makes the LoRA delta exactly
        # zero and reduces this test to a base-forward comparison. Give lora_B
        # non-zero values so the fused LoRA path is actually exercised.
        with torch.no_grad():
            for name, param in peft_model.named_parameters():
                if "lora_B" in name:
                    param.normal_(0, 0.01)

        x = torch.randn(1, M, config.hidden_size, device="cuda")

        # peft forward
        with torch.no_grad():
            peft_out = peft_model(x)

        # Extract base weights and LoRA weights
        base_moe = peft_model.base_model.model.moe
        base_experts = base_moe.experts.base_layer.base_layer
        gate_up_proj = base_experts.gate_up_proj
        down_proj = base_experts.down_proj
        act_fn = base_experts.act_fn

        # gate_up_proj LoRA
        gup_w = base_moe.experts.base_layer
        peft_gup_A = gup_w.lora_A["default"].weight.detach()
        peft_gup_B = gup_w.lora_B["default"].weight.detach()
        smoe_gup_A, smoe_gup_B = peft_gate_up_lora_to_scattermoe(
            peft_gup_A, peft_gup_B, E, r
        )

        # down_proj LoRA
        down_w = base_moe.experts
        peft_down_A = down_w.lora_A["default"].weight.detach()
        peft_down_B = down_w.lora_B["default"].weight.detach()
        smoe_down_A, smoe_down_B = peft_lora_to_scattermoe(
            peft_down_A, peft_down_B, E, r
        )

        # ScatterMoE fused forward -- gate is NOT peft-wrapped, access directly
        x_flat = x.view(-1, config.hidden_size)
        with torch.no_grad():
            _, rw, sel = base_moe.gate(x_flat)
            sei, ssi, eo = flatten_sort_count(sel, num_experts=E)
            gup = parallel_linear_lora(
                x_flat,
                gate_up_proj.transpose(2, 1),
                k,
                sei,
                ssi,
                eo,
                lora_A=smoe_gup_A,
                lora_B=smoe_gup_B,
                scaling=scaling,
                grouped_in=False,
                grouped_out=True,
            )
            g, u = gup.chunk(2, dim=-1)
            h = act_fn(g) * u
            smoe_out = parallel_linear_lora(
                h,
                down_proj.transpose(2, 1),
                1,
                sei,
                ssi,
                eo,
                lora_A=smoe_down_A,
                lora_B=smoe_down_B,
                scaling=scaling,
                grouped_in=True,
                grouped_out=False,
                gates=rw,
            ).view(1, M, config.hidden_size)
        torch.testing.assert_close(smoe_out, peft_out, atol=5e-3, rtol=5e-3)
# =============================================================================
# Tests: Backward gradient correctness
# =============================================================================
@requires_cuda
class TestOLMoEPeftLoRABackward:
    """Backward gradients through scattermoe_lora vs pure-PyTorch reference."""

    def test_small(self):
        """Run the gradient comparison on the small config (16 tokens, rank 4)."""
        self._run(use_full=False, M=16, r=4)

    def _run(self, use_full, M, r):
        """Compare outputs and input gradients of the fused and reference paths.

        LoRA weights are created directly in scattermoe layout, so no peft
        model is involved. The reference merges the LoRA deltas into the base
        weights; the fused path passes the LoRA matrices to the kernels.
        """
        from axolotl.integrations.kernels.libs.scattermoe_lora import (
            flatten_sort_count,
            parallel_linear_lora,
        )

        config = make_olmoe_config(use_full=use_full)
        E, k = config.num_experts, config.num_experts_per_tok
        lora_alpha = 16
        scaling = lora_alpha / r
        torch.manual_seed(42)
        moe = _init_expert_weights(OlmoeSparseMoeBlock(config)).cuda().float()
        x = torch.randn(1, M, config.hidden_size, device="cuda")
        x_flat = x.view(-1, config.hidden_size)
        gate_up_proj = moe.experts.gate_up_proj
        down_proj = moe.experts.down_proj

        # Create LoRA weights in scattermoe layout directly
        gup_A = torch.randn(r * E, config.hidden_size, device="cuda") * 0.01
        gup_B = torch.randn(2 * config.intermediate_size, r * E, device="cuda") * 0.01
        down_A = torch.randn(r * E, config.intermediate_size, device="cuda") * 0.01
        down_B = torch.randn(config.hidden_size, r * E, device="cuda") * 0.01

        # Routing is computed once and shared by both paths.
        rw, sel = _get_routing(moe, x)
        sei, ssi, eo = flatten_sort_count(sel, num_experts=E)

        # --- Reference ---
        gup_delta = _compute_delta_from_scattermoe_lora(
            gup_A, gup_B, scaling, E, r, gate_up_proj.shape
        )
        down_delta = _compute_delta_from_scattermoe_lora(
            down_A, down_B, scaling, E, r, down_proj.shape
        )
        x_ref = x_flat.clone().detach().requires_grad_(True)
        ref_out = _reference_moe_forward_with_lora(
            x_ref,
            gate_up_proj,
            down_proj,
            moe.experts.act_fn,
            sel,
            rw,
            E,
            gup_delta,
            down_delta,
        )
        ref_out.sum().backward()

        # --- ScatterMoE fused path ---
        x_smoe = x_flat.clone().detach().requires_grad_(True)
        # NOTE(review): these LoRA clones require grad, but their .grad is
        # never compared below -- only the input gradient is checked.
        gup_A_s = gup_A.clone().requires_grad_(True)
        gup_B_s = gup_B.clone().requires_grad_(True)
        down_A_s = down_A.clone().requires_grad_(True)
        down_B_s = down_B.clone().requires_grad_(True)
        gup_out = parallel_linear_lora(
            x_smoe,
            gate_up_proj.transpose(2, 1),
            k,
            sei,
            ssi,
            eo,
            lora_A=gup_A_s,
            lora_B=gup_B_s,
            scaling=scaling,
            grouped_in=False,
            grouped_out=True,
        )
        g, u = gup_out.chunk(2, dim=-1)
        h = moe.experts.act_fn(g) * u
        smoe_out = parallel_linear_lora(
            h,
            down_proj.transpose(2, 1),
            1,
            sei,
            ssi,
            eo,
            lora_A=down_A_s,
            lora_B=down_B_s,
            scaling=scaling,
            grouped_in=True,
            grouped_out=False,
            gates=rw,
        )
        smoe_out.sum().backward()

        # Forward outputs must agree ...
        torch.testing.assert_close(
            smoe_out.detach(),
            ref_out.detach(),
            atol=5e-3,
            rtol=5e-3,
        )
        # ... and so must the gradient w.r.t. the input (looser tolerance).
        torch.testing.assert_close(
            x_smoe.grad,
            x_ref.grad,
            atol=5e-2,
            rtol=5e-2,
        )
# =============================================================================
# Tests: kernelize() integration via LocalLayerRepository
# =============================================================================
@requires_cuda
class TestKernelizeIntegration:
    """Test the HF kernels library integration with LocalLayerRepository."""

    @staticmethod
    def _get_kernelize_imports():
        """Import kernels library components, skip if not available.

        Returns the tuple ``(LocalLayerRepository, Mode,
        register_kernel_mapping, replace_kernel_forward_from_hub, kernelize)``.
        """
        try:
            from kernels import (
                LocalLayerRepository,
                Mode,
                kernelize,
                register_kernel_mapping,
                replace_kernel_forward_from_hub,
            )
        except ImportError:
            # pytest.skip raises, so nothing after this line runs on failure.
            pytest.skip("kernels library not installed")
        return (
            LocalLayerRepository,
            Mode,
            register_kernel_mapping,
            replace_kernel_forward_from_hub,
            kernelize,
        )

    @staticmethod
    def _get_repo_path():
        """Get the path to scattermoe_lora within axolotl's plugin."""
        return (
            Path(__file__).parent.parent.parent.parent
            / "src"
            / "axolotl"
            / "integrations"
            / "kernels"
            / "libs"
            / "scattermoe_lora"
        )

    @staticmethod
    def _randomize_lora_B(peft_model):
        """Give every lora_B parameter small non-zero values.

        peft initializes lora_B to zeros, so without this the LoRA delta is
        exactly zero and a "LoRA" forward comparison only exercises the base
        weights.
        """
        with torch.no_grad():
            for name, param in peft_model.named_parameters():
                if "lora_B" in name:
                    param.normal_(0, 0.01)

    def _setup_kernels(
        self,
        LocalLayerRepository,
        Mode,
        register_kernel_mapping,
        replace_kernel_forward_from_hub,
    ):
        """Register kernel mapping for tests."""
        repo_path = self._get_repo_path()
        local_repo = LocalLayerRepository(
            repo_path=repo_path,
            package_name="scattermoe_lora",
            layer_name="HFScatterMoEGatedMLP",
        )
        replace_kernel_forward_from_hub(
            OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts"
        )
        register_kernel_mapping(
            {
                "HFScatterMoEParallelExperts": {
                    "cuda": {
                        Mode.TRAINING: local_repo,
                        Mode.INFERENCE: local_repo,
                    },
                }
            }
        )

    def test_base_forward_via_kernelize(self):
        """Kernelized OlmoeSparseMoeBlock (no LoRA) matches per-expert reference."""
        (
            LocalLayerRepository,
            Mode,
            register_kernel_mapping,
            replace_kernel_forward_from_hub,
            kernelize,
        ) = self._get_kernelize_imports()
        config = make_olmoe_config(use_full=False)
        E = config.num_experts

        # Create model
        torch.manual_seed(42)
        moe = _init_expert_weights(OlmoeSparseMoeBlock(config)).cuda().float()
        x = torch.randn(1, 8, config.hidden_size, device="cuda")
        x_flat = x.view(-1, config.hidden_size)

        # Compute reference BEFORE kernelizing
        with torch.no_grad():
            _, rw, sel = moe.gate(x_flat)
            ref_out = _reference_moe_forward(
                x_flat,
                moe.experts.gate_up_proj,
                moe.experts.down_proj,
                moe.experts.act_fn,
                sel,
                rw,
                E,
            ).view(1, 8, config.hidden_size)

        # Set up kernel mapping
        self._setup_kernels(
            LocalLayerRepository,
            Mode,
            register_kernel_mapping,
            replace_kernel_forward_from_hub,
        )

        # Kernelize the model
        kernelize(moe, mode=Mode.TRAINING, device="cuda")

        # Forward through kernelized model
        with torch.no_grad():
            kern_out = moe(x)
        torch.testing.assert_close(kern_out, ref_out, atol=1e-3, rtol=1e-3)

    def test_lora_forward_via_kernelize(self):
        """Kernelized OlmoeSparseMoeBlock with peft LoRA matches reference."""
        (
            LocalLayerRepository,
            Mode,
            register_kernel_mapping,
            replace_kernel_forward_from_hub,
            kernelize,
        ) = self._get_kernelize_imports()
        config = make_olmoe_config(use_full=False)
        r = 4

        # Create peft model
        torch.manual_seed(42)
        model = MinimalOLMoEModel(config).cuda().float()
        lora_config = LoraConfig(
            r=r,
            lora_alpha=16,
            target_modules=[],
            target_parameters=["experts.gate_up_proj", "experts.down_proj"],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)
        # Make the LoRA delta non-zero so the LoRA path is really compared.
        self._randomize_lora_B(peft_model)
        x = torch.randn(1, 8, config.hidden_size, device="cuda")

        # Reference: peft's own forward (uses _activate_lora context manager)
        with torch.no_grad():
            ref_out = peft_model(x)

        # Set up kernel mapping
        self._setup_kernels(
            LocalLayerRepository,
            Mode,
            register_kernel_mapping,
            replace_kernel_forward_from_hub,
        )

        # Kernelize the MoE block inside the peft model
        base_moe = peft_model.base_model.model.moe
        kernelize(base_moe, mode=Mode.TRAINING, device="cuda")

        # Forward through kernelized peft model
        with torch.no_grad():
            kern_out = peft_model(x)
        torch.testing.assert_close(kern_out, ref_out, atol=5e-3, rtol=5e-3)

    def test_gate_lora_forward_via_kernelize(self):
        """Kernelized forward with gate LoRA matches peft reference."""
        (
            LocalLayerRepository,
            Mode,
            register_kernel_mapping,
            replace_kernel_forward_from_hub,
            kernelize,
        ) = self._get_kernelize_imports()
        config = make_olmoe_config(use_full=False)
        r = 4

        # Create peft model with gate + experts LoRA
        torch.manual_seed(42)
        model = MinimalOLMoEModel(config).cuda().float()
        lora_config = LoraConfig(
            r=r,
            lora_alpha=16,
            target_modules=[],
            target_parameters=[
                "gate.weight",
                "experts.gate_up_proj",
                "experts.down_proj",
            ],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)
        # Make the gate and expert LoRA deltas non-zero before comparing.
        self._randomize_lora_B(peft_model)
        x = torch.randn(1, 8, config.hidden_size, device="cuda")

        # Reference: peft's own forward
        with torch.no_grad():
            ref_out = peft_model(x)

        # Set up kernel mapping
        self._setup_kernels(
            LocalLayerRepository,
            Mode,
            register_kernel_mapping,
            replace_kernel_forward_from_hub,
        )

        # Kernelize the MoE block inside the peft model
        base_moe = peft_model.base_model.model.moe
        kernelize(base_moe, mode=Mode.TRAINING, device="cuda")

        # Forward through kernelized peft model
        with torch.no_grad():
            kern_out = peft_model(x)
        torch.testing.assert_close(kern_out, ref_out, atol=5e-3, rtol=5e-3)
# =============================================================================
# Tests: Shared expert handling
# =============================================================================
class TestSharedExpertHandling:
    """Test that HFScatterMoEGatedMLP.forward handles shared experts."""

    @staticmethod
    def _make_shared_expert_block(config):
        """Create an OlmoeSparseMoeBlock with a mock shared expert attached."""
        moe = OlmoeSparseMoeBlock(config)
        _init_expert_weights(moe)
        hidden = config.hidden_size
        inter = config.intermediate_size

        # Attach a simple shared expert MLP (mimics Qwen2MoE structure)
        class SharedExpertMLP(nn.Module):
            def __init__(self, hidden_size, intermediate_size):
                super().__init__()
                self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
                self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
                self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
                self.act_fn = nn.SiLU()

            def forward(self, x):
                return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        moe.shared_expert = SharedExpertMLP(hidden, inter)
        moe.shared_expert_gate = nn.Linear(hidden, 1, bias=False)
        return moe

    def test_shared_expert_is_used(self):
        """Verify the gated shared-expert contribution is non-zero.

        This only checks the mock shared expert itself (CPU, no kernels);
        the end-to-end effect on the block output is covered by
        ``test_shared_expert_forward_via_kernelize``.
        """
        config = make_olmoe_config(use_full=False)
        # Seed before building the block so its weights are reproducible too,
        # matching the kernelized test in this class.
        torch.manual_seed(42)
        moe = self._make_shared_expert_block(config)
        x = torch.randn(1, 4, config.hidden_size)
        x_flat = x.view(-1, config.hidden_size)
        with torch.no_grad():
            # Shared expert contribution
            shared_out = moe.shared_expert(x_flat)
            gate_val = F.sigmoid(moe.shared_expert_gate(x_flat))
            shared_contribution = shared_out * gate_val

        # Verify shared expert produces non-zero output
        assert shared_contribution.abs().max() > 0

    @requires_cuda
    def test_shared_expert_forward_via_kernelize(self):
        """Kernelized forward with shared expert matches manual reference."""
        try:
            from kernels import (
                LocalLayerRepository,
                Mode,
                kernelize,
                register_kernel_mapping,
                replace_kernel_forward_from_hub,
            )
        except ImportError:
            pytest.skip("kernels library not installed")

        config = make_olmoe_config(use_full=False)
        E = config.num_experts
        torch.manual_seed(42)
        moe = self._make_shared_expert_block(config).cuda().float()
        x = torch.randn(1, 8, config.hidden_size, device="cuda")
        x_flat = x.view(-1, config.hidden_size)

        # Compute reference: per-expert + shared expert
        with torch.no_grad():
            _, rw, sel = moe.gate(x_flat)
            expert_out = _reference_moe_forward(
                x_flat,
                moe.experts.gate_up_proj,
                moe.experts.down_proj,
                moe.experts.act_fn,
                sel,
                rw,
                E,
            )
            shared_out = moe.shared_expert(x_flat)
            gate_val = F.sigmoid(moe.shared_expert_gate(x_flat))
            ref_out = (expert_out + shared_out * gate_val).view(
                1, 8, config.hidden_size
            )

        # Kernelize
        repo_path = (
            Path(__file__).parent.parent.parent.parent
            / "src"
            / "axolotl"
            / "integrations"
            / "kernels"
            / "libs"
            / "scattermoe_lora"
        )
        local_repo = LocalLayerRepository(
            repo_path=repo_path,
            package_name="scattermoe_lora",
            layer_name="HFScatterMoEGatedMLP",
        )
        replace_kernel_forward_from_hub(
            OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts"
        )
        register_kernel_mapping(
            {
                "HFScatterMoEParallelExperts": {
                    "cuda": {
                        Mode.TRAINING: local_repo,
                        Mode.INFERENCE: local_repo,
                    },
                }
            }
        )
        kernelize(moe, mode=Mode.TRAINING, device="cuda")
        with torch.no_grad():
            kern_out = moe(x)
        torch.testing.assert_close(kern_out, ref_out, atol=1e-3, rtol=1e-3)