# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0

"""
Integration tests: OLMoE + peft LoRA + ScatterMoE fused kernels.

Validates that scattermoe_lora fused kernels produce correct results when used
with HuggingFace OLMoE models and peft LoRA adapters applied via
``target_parameters``.

Key things tested
-----------------
- LoRA weight layout conversion between peft (rank-major) and scattermoe (expert-major)
- Base forward equivalence: per-expert reference vs ScatterMoE kernels (no LoRA)
- LoRA forward equivalence: peft merged-weight approach vs scattermoe fused kernels
- Backward gradient correctness through the fused LoRA path
- ``kernelize()`` integration via ``LocalLayerRepository``
"""

from pathlib import Path

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from transformers import OlmoeConfig
from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock

_SMOE = "axolotl.integrations.kernels.libs.scattermoe_lora"

# Try to import from axolotl's scattermoe_lora.layers; may fail on CPU without triton.
try:
    from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
        _unwrap_experts_lora,
        _unwrap_gate_lora,
        peft_lora_B_to_scattermoe,
        peft_lora_to_scattermoe,
    )

    HAS_SCATTERMOE = True
except (ImportError, ModuleNotFoundError):
    HAS_SCATTERMOE = False

    # Provide pure-torch fallbacks for CPU-only layout conversion tests.
    def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
        N = peft_B.shape[0]
        return (
            peft_B.reshape(N, rank, num_experts)
            .permute(0, 2, 1)
            .contiguous()
            .reshape(N, num_experts * rank)
        )

    def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
        peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
        K_inter, N_hidden = peft_B.shape[0], peft_A.shape[1]
        smoe_A = torch.zeros(
            rank * num_experts,
            K_inter,
            device=peft_A.device,
            dtype=peft_A.dtype,
        )
        smoe_B = torch.zeros(
            N_hidden,
            rank * num_experts,
            device=peft_A.device,
            dtype=peft_A.dtype,
        )
        for e in range(num_experts):
            s = e * rank
            smoe_A[s : s + rank, :] = peft_B_em[:, s : s + rank].T
            smoe_B[:, s : s + rank] = peft_A[s : s + rank, :].T
        return smoe_A, smoe_B

    def _unwrap_experts_lora(experts_module):
        return experts_module, None, None

    def _unwrap_gate_lora(gate_module):
        if hasattr(gate_module, "base_layer") and hasattr(gate_module, "lora_A"):
            base_gate = gate_module.base_layer
            active = getattr(gate_module, "active_adapters", ["default"])
            name = active[0] if active else "default"
            lora_A_dict = getattr(gate_module, "lora_A", {})
            lora_B_dict = getattr(gate_module, "lora_B", {})
            scaling_dict = getattr(gate_module, "scaling", {})
            if name in lora_A_dict:
                lora_A = lora_A_dict[name].weight
                lora_B = lora_B_dict[name].weight
                s = scaling_dict[name]
                delta = s * (lora_B @ lora_A)
                return base_gate, base_gate.weight, delta
            return base_gate, base_gate.weight, None
        return gate_module, gate_module.weight, None


# =============================================================================
# Configuration
# =============================================================================

FULL_OLMOE_CONFIG = dict(
    hidden_size=2048,
    intermediate_size=1024,
    num_experts=64,
    num_experts_per_tok=8,
    hidden_act="silu",
    norm_topk_prob=False,
)

SMALL_OLMOE_CONFIG = dict(
    hidden_size=128,
    intermediate_size=48,  # non-square: 2*inter=96 != hidden=128
    num_experts=8,
    num_experts_per_tok=2,
    hidden_act="silu",
    norm_topk_prob=False,
)

requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA not available"
)


def make_olmoe_config(use_full=False):
    cfg = dict(FULL_OLMOE_CONFIG if use_full else SMALL_OLMOE_CONFIG)
    cfg["experts_implementation"] = "grouped_mm"
    return OlmoeConfig(**cfg)


# =============================================================================
# Layout conversion utilities (test-local helpers)
# =============================================================================


def scattermoe_lora_B_to_peft(smoe_B, num_experts, rank):
    """Inverse of ``peft_lora_B_to_scattermoe``."""
    N = smoe_B.shape[0]
    return (
        smoe_B.reshape(N, num_experts, rank)
        .permute(0, 2, 1)
        .contiguous()
        .reshape(N, num_experts * rank)
    )


def peft_gate_up_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
    """Convert peft LoRA for gate_up_proj to scattermoe layout.

    Both gate_up_proj and down_proj need the A<->B swap because
    scattermoe transposes the parameter (W = param.T).
    """
    return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
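

# Layout note (illustrative only, derived from the conversions above): for E=2, r=2,
# peft's rank-major lora_B orders its columns [r0e0, r0e1, r1e0, r1e1], while the
# scattermoe expert-major layout orders them [e0r0, e0r1, e1r0, e1r1], so after
# conversion expert e occupies the contiguous column slice [:, e*r:(e+1)*r].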


# =============================================================================
# Helpers
# =============================================================================


def _init_expert_weights(moe_block):
    """Initialize OlmoeExperts parameters which use torch.empty (uninitialized).

    Without this, gate_up_proj and down_proj contain garbage/NaN values.
    """
    with torch.no_grad():
        nn.init.kaiming_uniform_(moe_block.experts.gate_up_proj)
        nn.init.kaiming_uniform_(moe_block.experts.down_proj)
    return moe_block


class MinimalOLMoEModel(nn.Module):
    """Thin wrapper so peft's get_peft_model can attach adapters."""

    def __init__(self, config):
        super().__init__()
        self.moe = OlmoeSparseMoeBlock(config)
        _init_expert_weights(self.moe)

    def forward(self, x):
        return self.moe(x)


def _get_routing(moe_block, hidden_states):
    """Run the router and return (routing_weights, selected_experts)."""
    with torch.no_grad():
        _, routing_weights, selected_experts = moe_block.gate(
            hidden_states.view(-1, hidden_states.size(-1))
        )
    return routing_weights, selected_experts


def _reference_moe_forward(
    x_flat,
    gate_up_proj,
    down_proj,
    act_fn,
    top_k_index,
    top_k_weights,
    num_experts,
):
    """Pure-PyTorch per-expert reference MoE forward (no LoRA).

    Uses F.linear per expert for an apples-to-apples comparison with
    the ScatterMoE kernel path.
    """
    final = torch.zeros_like(x_flat)
    expert_mask = F.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)
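    # expert_mask has shape [num_experts, top_k, n_tokens]; for each expert e,
    # torch.where(expert_mask[e]) below recovers which top-k slot and which token
    # were routed to it.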
    for e in range(num_experts):
        top_k_pos, token_idx = torch.where(expert_mask[e])
        if token_idx.numel() == 0:
            continue
        cur = x_flat[token_idx]
        gate_up = F.linear(cur, gate_up_proj[e])
        g, u = gate_up.chunk(2, dim=-1)
        h = act_fn(g) * u
        out = F.linear(h, down_proj[e])
        out = out * top_k_weights[token_idx, top_k_pos, None]
        final.index_add_(0, token_idx, out.to(final.dtype))
    return final


def _reference_moe_forward_with_lora(
    x_flat,
    gate_up_proj,
    down_proj,
    act_fn,
    top_k_index,
    top_k_weights,
    num_experts,
    gup_delta,
    down_delta,
):
    """Pure-PyTorch reference MoE forward with pre-computed weight deltas."""
    merged_gup = gate_up_proj + gup_delta
    merged_down = down_proj + down_delta
    return _reference_moe_forward(
        x_flat,
        merged_gup,
        merged_down,
        act_fn,
        top_k_index,
        top_k_weights,
        num_experts,
    )


def _compute_delta_from_scattermoe_lora(lora_A, lora_B, scaling, E, r, param_shape):
    """Compute additive weight delta from scattermoe-layout LoRA weights.

    delta[e] = scaling * B_e @ A_e where A_e [r,K], B_e [N,r] -> [N,K].
    """
    delta = torch.zeros(param_shape, device=lora_A.device, dtype=lora_A.dtype)
    for e in range(E):
        A_e = lora_A[e * r : (e + 1) * r, :]
        B_e = lora_B[:, e * r : (e + 1) * r]
        delta[e] = scaling * (B_e @ A_e)
    return delta


# =============================================================================
# Tests: Layout conversion
# =============================================================================


class TestLoRABLayoutConversion:
    """Test the peft <-> scattermoe lora_B layout conversion."""

    def test_roundtrip(self):
        E, r, N = 8, 4, 64
        original = torch.randn(N, E * r)
        converted = peft_lora_B_to_scattermoe(original, E, r)
        back = scattermoe_lora_B_to_peft(converted, E, r)
        torch.testing.assert_close(back, original)

    def test_per_expert_slices(self):
        """After conversion, scattermoe slicing gives the same per-expert
        matrices as peft's reshape slicing."""
        E, r, N = 4, 2, 16
        peft_B = torch.randn(N, E * r)
        smoe_B = peft_lora_B_to_scattermoe(peft_B, E, r)

        peft_reshaped = peft_B.reshape(N, r, E)
        for e in range(E):
            torch.testing.assert_close(
                smoe_B[:, e * r : (e + 1) * r],
                peft_reshaped[:, :, e],
            )

    def test_lora_A_already_compatible(self):
        """lora_A layout is identical between peft and scattermoe."""
        E, r, K = 4, 2, 16
        lora_A = torch.randn(E * r, K)
        peft_reshaped = lora_A.reshape(E, r, K)
        for e in range(E):
            torch.testing.assert_close(
                lora_A[e * r : (e + 1) * r, :],
                peft_reshaped[e],
            )

    def test_delta_weight_equivalence(self):
        """peft's einsum delta matches per-expert B @ A with converted layouts."""
        E, r, K, N = 8, 4, 32, 64
        peft_A = torch.randn(E * r, K)
        peft_B = torch.randn(N, E * r)
        scaling = 2.0

        A_r = peft_A.reshape(E, r, K)
        B_r = peft_B.reshape(N, r, E)
        delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling

        smoe_B = peft_lora_B_to_scattermoe(peft_B, E, r)
        for e in range(E):
            A_e = peft_A[e * r : (e + 1) * r, :]
            B_e = smoe_B[:, e * r : (e + 1) * r]
            delta_e = scaling * (B_e @ A_e)
            torch.testing.assert_close(delta_e, delta_peft[e].T, atol=1e-5, rtol=1e-5)

    def test_down_proj_conversion(self):
        """Verify peft_lora_to_scattermoe produces correct delta."""
        E, r = 4, 2
        hidden, inter = 32, 16
        scaling = 2.0

        peft_A = torch.randn(E * r, hidden)
        peft_B = torch.randn(inter, E * r)

        A_r = peft_A.reshape(E, r, hidden)
        B_r = peft_B.reshape(inter, r, E)
        delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling

        smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
        for e in range(E):
            A_e = smoe_A[e * r : (e + 1) * r, :]
            B_e = smoe_B[:, e * r : (e + 1) * r]
            delta_smoe_e = scaling * (B_e @ A_e)
            torch.testing.assert_close(
                delta_smoe_e, delta_peft[e], atol=1e-5, rtol=1e-5
            )

    def test_gate_up_proj_conversion(self):
        """Verify gate_up_proj LoRA conversion with non-square dims (Qwen3-like).

        gate_up_proj param: [E, 2*inter, hidden].
        peft: in_features=2*inter, out_features=hidden.
        peft lora_A: [r*E, 2*inter], lora_B: [hidden, r*E].

        scattermoe W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter.
        scattermoe needs: lora_A [r*E, K=hidden], lora_B [N=2*inter, r*E].

        Uses non-square dims (hidden=32 != 2*inter=24) to catch A<->B swap bugs.
        """
        E, r = 4, 2
        hidden, inter = 32, 12  # 2*inter=24 != hidden=32
        scaling = 2.0

        # peft assigns: in_features=2*inter, out_features=hidden
        peft_A = torch.randn(E * r, 2 * inter)  # [r*E, in_features=2*inter]
        peft_B = torch.randn(hidden, E * r)  # [out_features=hidden, r*E]

        # peft delta via einsum: "o r e, e r i -> e i o"
        A_r = peft_A.reshape(E, r, 2 * inter)
        B_r = peft_B.reshape(hidden, r, E)
        delta_peft = torch.einsum("o r e, e r i -> e i o", B_r, A_r) * scaling
        # delta_peft[e] has shape [in_features, out_features] = [2*inter, hidden]
        # = param[e] shape [2*inter, hidden]

        smoe_A, smoe_B = peft_gate_up_lora_to_scattermoe(peft_A, peft_B, E, r)
        # smoe_A should be [r*E, K=hidden], smoe_B should be [N=2*inter, r*E]
        assert smoe_A.shape == (E * r, hidden), (
            f"Expected {(E * r, hidden)}, got {smoe_A.shape}"
        )
        assert smoe_B.shape == (2 * inter, E * r), (
            f"Expected {(2 * inter, E * r)}, got {smoe_B.shape}"
        )

        for e in range(E):
            A_e = smoe_A[e * r : (e + 1) * r, :]  # [r, K=hidden]
            B_e = smoe_B[:, e * r : (e + 1) * r]  # [N=2*inter, r]
            delta_smoe_e = scaling * (B_e @ A_e)  # [2*inter, hidden]
            # Should match peft delta which is [2*inter, hidden] = param[e]
            torch.testing.assert_close(
                delta_smoe_e, delta_peft[e], atol=1e-5, rtol=1e-5
            )


# =============================================================================
# Tests: peft weight extraction
# =============================================================================


class TestPeftLoRAWeightExtraction:
    """Test extracting peft LoRA weights for OLMoE."""

    def test_peft_creates_correct_shapes(self):
        config = make_olmoe_config(use_full=False)
        E, r = config.num_experts, 4

        model = MinimalOLMoEModel(config)
        lora_config = LoraConfig(
            r=r,
            lora_alpha=16,
            target_modules=[],
            target_parameters=[
                "gate.weight",
                "experts.gate_up_proj",
                "experts.down_proj",
            ],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)
        trainable = {n: p for n, p in peft_model.named_parameters() if p.requires_grad}

        # Gate router
        assert trainable["base_model.model.moe.gate.lora_A.default.weight"].shape == (
            r,
            config.hidden_size,
        )
        assert trainable["base_model.model.moe.gate.lora_B.default.weight"].shape == (
            E,
            r,
        )

        # gate_up_proj [E, 2*inter, hidden]
        # peft: in_features=2*inter (dim 1), out_features=hidden (dim 2)
        assert trainable[
            "base_model.model.moe.experts.base_layer.lora_A.default.weight"
        ].shape == (E * r, 2 * config.intermediate_size)
        assert trainable[
            "base_model.model.moe.experts.base_layer.lora_B.default.weight"
        ].shape == (config.hidden_size, E * r)

        # down_proj [E, hidden, inter]
        # peft: in_features=hidden (dim 1), out_features=inter (dim 2)
        assert trainable[
            "base_model.model.moe.experts.lora_A.default.weight"
        ].shape == (E * r, config.hidden_size)
        assert trainable[
            "base_model.model.moe.experts.lora_B.default.weight"
        ].shape == (config.intermediate_size, E * r)

    @requires_cuda
    def test_peft_forward_runs(self):
        """Smoke test: peft model forward pass completes (needs CUDA for grouped_mm)."""
        config = make_olmoe_config(use_full=False)
        model = MinimalOLMoEModel(config)
        lora_config = LoraConfig(
            r=4,
            lora_alpha=16,
            target_modules=[],
            target_parameters=[
                "gate.weight",
                "experts.gate_up_proj",
                "experts.down_proj",
            ],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)
        x = torch.randn(1, 4, config.hidden_size)
        out = peft_model(x)
        assert out.shape == x.shape

    @pytest.mark.skipif(
        not HAS_SCATTERMOE, reason="scattermoe_lora not importable (no triton)"
    )
    def test_unwrap_experts_lora(self):
        """Test that _unwrap_experts_lora correctly detects LoRA wrappers."""
        config = make_olmoe_config(use_full=False)
        model = MinimalOLMoEModel(config)
        lora_config = LoraConfig(
            r=4,
            lora_alpha=16,
            target_modules=[],
            target_parameters=["experts.gate_up_proj", "experts.down_proj"],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)
        base_moe = peft_model.base_model.model.moe

        # Experts should be wrapped by ParamWrapper
        experts, gup_lora, down_lora = _unwrap_experts_lora(base_moe.experts)

        # Base experts should have the raw parameters
        assert hasattr(experts, "gate_up_proj")
        assert hasattr(experts, "down_proj")

        # LoRA should be detected
        assert gup_lora is not None, "gate_up_proj LoRA not detected"
        assert down_lora is not None, "down_proj LoRA not detected"

        # Check shapes (after peft->scattermoe conversion with A<->B swap)
        # gate_up_proj W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter
        E, r = config.num_experts, 4
        gup_A, gup_B, gup_s = gup_lora
        assert gup_A.shape == (E * r, config.hidden_size), (
            f"gate_up_proj smoe_A: expected [r*E, K=hidden]={(E * r, config.hidden_size)}, "
            f"got {gup_A.shape}"
        )
        assert gup_B.shape == (2 * config.intermediate_size, E * r), (
            f"gate_up_proj smoe_B: expected [N=2*inter, r*E]="
            f"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}"
        )

        # down_proj W = param.T = [E, inter, hidden], K=inter, N=hidden
        down_A, down_B, down_s = down_lora
        assert down_A.shape == (E * r, config.intermediate_size), (
            f"down_proj smoe_A: expected [r*E, K=inter]={(E * r, config.intermediate_size)}, "
            f"got {down_A.shape}"
        )
        assert down_B.shape == (config.hidden_size, E * r), (
            f"down_proj smoe_B: expected [N=hidden, r*E]={(config.hidden_size, E * r)}, "
            f"got {down_B.shape}"
        )

    def test_unwrap_no_lora(self):
        """Without peft, _unwrap_experts_lora returns no LoRA."""
        config = make_olmoe_config(use_full=False)
        moe = OlmoeSparseMoeBlock(config)
        experts, gup_lora, down_lora = _unwrap_experts_lora(moe.experts)
        assert gup_lora is None
        assert down_lora is None
        assert hasattr(experts, "gate_up_proj")

    def test_unwrap_gate_lora(self):
        """Test that _unwrap_gate_lora detects LoRA on the router gate."""
        config = make_olmoe_config(use_full=False)
        model = MinimalOLMoEModel(config)
        r = 4
        lora_config = LoraConfig(
            r=r,
            lora_alpha=16,
            target_modules=[],
            target_parameters=["gate.weight"],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)
        base_moe = peft_model.base_model.model.moe

        # Set non-zero LoRA weights (peft initializes lora_B to zeros)
        with torch.no_grad():
            base_moe.gate.lora_B["default"].weight.normal_(0, 0.01)

        base_gate, gate_weight, gate_delta = _unwrap_gate_lora(base_moe.gate)

        # Base gate should be the original router
        assert hasattr(base_gate, "top_k")
        assert hasattr(base_gate, "num_experts")
        assert base_gate.top_k == config.num_experts_per_tok
        assert base_gate.num_experts == config.num_experts

        # Gate weight should be the base weight (delta returned separately)
        assert gate_weight.shape == (config.num_experts, config.hidden_size)
        torch.testing.assert_close(gate_weight, base_gate.weight)

        # Delta should be non-zero (LoRA was applied)
        assert gate_delta is not None
        assert gate_delta.shape == (config.num_experts, config.hidden_size)
        assert gate_delta.abs().max() > 0, "Gate LoRA delta should be non-zero"

    def test_unwrap_gate_no_lora(self):
        """Without peft, _unwrap_gate_lora returns the original gate."""
        config = make_olmoe_config(use_full=False)
        moe = OlmoeSparseMoeBlock(config)
        base_gate, gate_weight, gate_delta = _unwrap_gate_lora(moe.gate)
        assert base_gate is moe.gate
        torch.testing.assert_close(gate_weight, moe.gate.weight)
        assert gate_delta is None

    def test_gate_lora_delta_matches_peft(self):
        """Verify _unwrap_gate_lora computes the same delta as peft."""
        config = make_olmoe_config(use_full=False)
        model = MinimalOLMoEModel(config)
        r = 4
        lora_alpha = 16
        scaling = lora_alpha / r
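        # peft's default (non-rsLoRA) scaling is lora_alpha / r, so the manual
        # delta computed below should match what peft applies internally.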
        lora_config = LoraConfig(
            r=r,
            lora_alpha=lora_alpha,
            target_modules=[],
            target_parameters=["gate.weight"],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)
        base_moe = peft_model.base_model.model.moe

        # Our unwrapped weight + delta
        _, gate_weight, gate_delta = _unwrap_gate_lora(base_moe.gate)

        # Manually compute expected delta
        lora_A = base_moe.gate.lora_A["default"].weight  # [r, hidden]
        lora_B = base_moe.gate.lora_B["default"].weight  # [E, r]
        base_weight = base_moe.gate.base_layer.weight  # [E, hidden]
        expected_delta = scaling * (lora_B @ lora_A)

        torch.testing.assert_close(gate_weight, base_weight)
        torch.testing.assert_close(gate_delta, expected_delta)
        # Combined should match the old behavior
        torch.testing.assert_close(
            gate_weight + gate_delta, base_weight + expected_delta
        )


# =============================================================================
# Tests: Base forward equivalence (no LoRA)
# =============================================================================


@requires_cuda
class TestOLMoEReferenceVsScatterMoE:
    """Base forward equivalence: per-expert reference vs ScatterMoE kernels."""

    def test_small(self):
        self._run(use_full=False, M=16)

    @pytest.mark.slow
    def test_full(self):
        self._run(use_full=True, M=32)

    def _run(self, use_full, M):
        from axolotl.integrations.kernels.libs.scattermoe_lora import (
            flatten_sort_count,
            parallel_linear,
        )

        config = make_olmoe_config(use_full=use_full)
        torch.manual_seed(42)
        moe = _init_expert_weights(OlmoeSparseMoeBlock(config)).cuda().float()
        E, k = config.num_experts, config.num_experts_per_tok

        x = torch.randn(1, M, config.hidden_size, device="cuda")
        x_flat = x.view(-1, config.hidden_size)

        with torch.no_grad():
            # Shared routing for both paths
            _, rw, sel = moe.gate(x_flat)
            sei, ssi, eo = flatten_sort_count(sel, num_experts=E)

            # Per-expert reference
            ref_out = _reference_moe_forward(
                x_flat,
                moe.experts.gate_up_proj,
                moe.experts.down_proj,
                moe.experts.act_fn,
                sel,
                rw,
                E,
            ).view(1, M, config.hidden_size)

            # ScatterMoE kernel path
            gup = parallel_linear(
                x_flat,
                moe.experts.gate_up_proj.transpose(2, 1),
                k,
                sei,
                ssi,
                eo,
                grouped_in=False,
                grouped_out=True,
            )
            g, u = gup.chunk(2, dim=-1)
            h = moe.experts.act_fn(g) * u

            smoe_out = parallel_linear(
                h,
                moe.experts.down_proj.transpose(2, 1),
                1,
                sei,
                ssi,
                eo,
                grouped_in=True,
                grouped_out=False,
                gates=rw,
            ).view(1, M, config.hidden_size)

        torch.testing.assert_close(smoe_out, ref_out, atol=1e-3, rtol=1e-3)


# =============================================================================
# Tests: LoRA forward equivalence (peft vs scattermoe fused)
# =============================================================================


@requires_cuda
class TestOLMoEPeftLoRAForward:
    """Fused LoRA forward: peft merged-weight vs scattermoe_lora kernel."""

    def test_small(self):
        self._run(use_full=False, M=16, r=4)

    @pytest.mark.slow
    def test_full(self):
        self._run(use_full=True, M=32, r=8)

    def _run(self, use_full, M, r):
        from axolotl.integrations.kernels.libs.scattermoe_lora import (
            flatten_sort_count,
            parallel_linear_lora,
        )

        config = make_olmoe_config(use_full=use_full)
        E, k = config.num_experts, config.num_experts_per_tok
        lora_alpha = 16
        scaling = lora_alpha / r

        # Create peft model
        model = MinimalOLMoEModel(config).cuda().float()
        lora_config = LoraConfig(
            r=r,
            lora_alpha=lora_alpha,
            target_modules=[],
            target_parameters=["experts.gate_up_proj", "experts.down_proj"],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)

        torch.manual_seed(42)
        x = torch.randn(1, M, config.hidden_size, device="cuda")

        # peft forward
        with torch.no_grad():
            peft_out = peft_model(x)

        # Extract base weights and LoRA weights
        base_moe = peft_model.base_model.model.moe
        base_experts = base_moe.experts.base_layer.base_layer
        gate_up_proj = base_experts.gate_up_proj
        down_proj = base_experts.down_proj
        act_fn = base_experts.act_fn

        # gate_up_proj LoRA
        gup_w = base_moe.experts.base_layer
        peft_gup_A = gup_w.lora_A["default"].weight.detach()
        peft_gup_B = gup_w.lora_B["default"].weight.detach()
        smoe_gup_A, smoe_gup_B = peft_gate_up_lora_to_scattermoe(
            peft_gup_A, peft_gup_B, E, r
        )

        # down_proj LoRA
        down_w = base_moe.experts
        peft_down_A = down_w.lora_A["default"].weight.detach()
        peft_down_B = down_w.lora_B["default"].weight.detach()
        smoe_down_A, smoe_down_B = peft_lora_to_scattermoe(
            peft_down_A, peft_down_B, E, r
        )

        # ScatterMoE fused forward -- gate is NOT peft-wrapped, access directly
        x_flat = x.view(-1, config.hidden_size)

        with torch.no_grad():
            _, rw, sel = base_moe.gate(x_flat)
            sei, ssi, eo = flatten_sort_count(sel, num_experts=E)

            gup = parallel_linear_lora(
                x_flat,
                gate_up_proj.transpose(2, 1),
                k,
                sei,
                ssi,
                eo,
                lora_A=smoe_gup_A,
                lora_B=smoe_gup_B,
                scaling=scaling,
                grouped_in=False,
                grouped_out=True,
            )
            g, u = gup.chunk(2, dim=-1)
            h = act_fn(g) * u

            smoe_out = parallel_linear_lora(
                h,
                down_proj.transpose(2, 1),
                1,
                sei,
                ssi,
                eo,
                lora_A=smoe_down_A,
                lora_B=smoe_down_B,
                scaling=scaling,
                grouped_in=True,
                grouped_out=False,
                gates=rw,
            ).view(1, M, config.hidden_size)

        torch.testing.assert_close(smoe_out, peft_out, atol=5e-3, rtol=5e-3)


# =============================================================================
# Tests: Backward gradient correctness
# =============================================================================


@requires_cuda
class TestOLMoEPeftLoRABackward:
    """Backward gradients through scattermoe_lora vs pure-PyTorch reference."""

    def test_small(self):
        self._run(use_full=False, M=16, r=4)

    def _run(self, use_full, M, r):
        from axolotl.integrations.kernels.libs.scattermoe_lora import (
            flatten_sort_count,
            parallel_linear_lora,
        )

        config = make_olmoe_config(use_full=use_full)
        E, k = config.num_experts, config.num_experts_per_tok
        lora_alpha = 16
        scaling = lora_alpha / r

        torch.manual_seed(42)
        moe = _init_expert_weights(OlmoeSparseMoeBlock(config)).cuda().float()
        x = torch.randn(1, M, config.hidden_size, device="cuda")
        x_flat = x.view(-1, config.hidden_size)
        gate_up_proj = moe.experts.gate_up_proj
        down_proj = moe.experts.down_proj

        # Create LoRA weights in scattermoe layout directly
        gup_A = torch.randn(r * E, config.hidden_size, device="cuda") * 0.01
        gup_B = torch.randn(2 * config.intermediate_size, r * E, device="cuda") * 0.01
        down_A = torch.randn(r * E, config.intermediate_size, device="cuda") * 0.01
        down_B = torch.randn(config.hidden_size, r * E, device="cuda") * 0.01

        rw, sel = _get_routing(moe, x)
        sei, ssi, eo = flatten_sort_count(sel, num_experts=E)

        # --- Reference ---
        gup_delta = _compute_delta_from_scattermoe_lora(
            gup_A, gup_B, scaling, E, r, gate_up_proj.shape
        )
        down_delta = _compute_delta_from_scattermoe_lora(
            down_A, down_B, scaling, E, r, down_proj.shape
        )

        x_ref = x_flat.clone().detach().requires_grad_(True)
        ref_out = _reference_moe_forward_with_lora(
            x_ref,
            gate_up_proj,
            down_proj,
            moe.experts.act_fn,
            sel,
            rw,
            E,
            gup_delta,
            down_delta,
        )
        ref_out.sum().backward()

        # --- ScatterMoE fused path ---
        x_smoe = x_flat.clone().detach().requires_grad_(True)
        gup_A_s = gup_A.clone().requires_grad_(True)
        gup_B_s = gup_B.clone().requires_grad_(True)
        down_A_s = down_A.clone().requires_grad_(True)
        down_B_s = down_B.clone().requires_grad_(True)

        gup_out = parallel_linear_lora(
            x_smoe,
            gate_up_proj.transpose(2, 1),
            k,
            sei,
            ssi,
            eo,
            lora_A=gup_A_s,
            lora_B=gup_B_s,
            scaling=scaling,
            grouped_in=False,
            grouped_out=True,
        )
        g, u = gup_out.chunk(2, dim=-1)
        h = moe.experts.act_fn(g) * u

        smoe_out = parallel_linear_lora(
            h,
            down_proj.transpose(2, 1),
            1,
            sei,
            ssi,
            eo,
            lora_A=down_A_s,
            lora_B=down_B_s,
            scaling=scaling,
            grouped_in=True,
            grouped_out=False,
            gates=rw,
        )
        smoe_out.sum().backward()

        torch.testing.assert_close(
            smoe_out.detach(),
            ref_out.detach(),
            atol=5e-3,
            rtol=5e-3,
        )
        torch.testing.assert_close(
            x_smoe.grad,
            x_ref.grad,
            atol=5e-2,
            rtol=5e-2,
        )


# =============================================================================
# Tests: kernelize() integration via LocalLayerRepository
# =============================================================================


@requires_cuda
class TestKernelizeIntegration:
    """Test the HF kernels library integration with LocalLayerRepository."""

    @staticmethod
    def _get_kernelize_imports():
        """Import kernels library components, skip if not available."""
        try:
            from kernels import (
                LocalLayerRepository,
                Mode,
                kernelize,
                register_kernel_mapping,
                replace_kernel_forward_from_hub,
            )

            return (
                LocalLayerRepository,
                Mode,
                register_kernel_mapping,
                replace_kernel_forward_from_hub,
                kernelize,
            )
        except ImportError:
            pytest.skip("kernels library not installed")

    @staticmethod
    def _get_repo_path():
        """Get the path to scattermoe_lora within axolotl's plugin."""
        return (
            Path(__file__).parent.parent.parent.parent
            / "src"
            / "axolotl"
            / "integrations"
            / "kernels"
            / "libs"
            / "scattermoe_lora"
        )

    def _setup_kernels(
        self,
        LocalLayerRepository,
        Mode,
        register_kernel_mapping,
        replace_kernel_forward_from_hub,
    ):
        """Register kernel mapping for tests."""
        repo_path = self._get_repo_path()
        local_repo = LocalLayerRepository(
            repo_path=repo_path,
            package_name="scattermoe_lora",
            layer_name="HFScatterMoEGatedMLP",
        )

        replace_kernel_forward_from_hub(
            OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts"
        )
        register_kernel_mapping(
            {
                "HFScatterMoEParallelExperts": {
                    "cuda": {
                        Mode.TRAINING: local_repo,
                        Mode.INFERENCE: local_repo,
                    },
                }
            }
        )
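
    # Roughly: replace_kernel_forward_from_hub associates OlmoeSparseMoeBlock with
    # the "HFScatterMoEParallelExperts" layer name, and register_kernel_mapping
    # points that name at the local scattermoe_lora repo, so kernelize(...) in the
    # tests below can swap in the HFScatterMoEGatedMLP forward.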

    def test_base_forward_via_kernelize(self):
        """Kernelized OlmoeSparseMoeBlock (no LoRA) matches per-expert reference."""
        (
            LocalLayerRepository,
            Mode,
            register_kernel_mapping,
            replace_kernel_forward_from_hub,
            kernelize,
        ) = self._get_kernelize_imports()

        config = make_olmoe_config(use_full=False)
        E = config.num_experts

        # Create model
        torch.manual_seed(42)
        moe = _init_expert_weights(OlmoeSparseMoeBlock(config)).cuda().float()
        x = torch.randn(1, 8, config.hidden_size, device="cuda")
        x_flat = x.view(-1, config.hidden_size)

        # Compute reference BEFORE kernelizing
        with torch.no_grad():
            _, rw, sel = moe.gate(x_flat)
            ref_out = _reference_moe_forward(
                x_flat,
                moe.experts.gate_up_proj,
                moe.experts.down_proj,
                moe.experts.act_fn,
                sel,
                rw,
                E,
            ).view(1, 8, config.hidden_size)

        # Set up kernel mapping
        self._setup_kernels(
            LocalLayerRepository,
            Mode,
            register_kernel_mapping,
            replace_kernel_forward_from_hub,
        )

        # Kernelize the model
        kernelize(moe, mode=Mode.TRAINING, device="cuda")

        # Forward through kernelized model
        with torch.no_grad():
            kern_out = moe(x)

        torch.testing.assert_close(kern_out, ref_out, atol=1e-3, rtol=1e-3)

    def test_lora_forward_via_kernelize(self):
        """Kernelized OlmoeSparseMoeBlock with peft LoRA matches reference."""
        (
            LocalLayerRepository,
            Mode,
            register_kernel_mapping,
            replace_kernel_forward_from_hub,
            kernelize,
        ) = self._get_kernelize_imports()

        config = make_olmoe_config(use_full=False)
        r = 4

        # Create peft model
        torch.manual_seed(42)
        model = MinimalOLMoEModel(config).cuda().float()
        lora_config = LoraConfig(
            r=r,
            lora_alpha=16,
            target_modules=[],
            target_parameters=["experts.gate_up_proj", "experts.down_proj"],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)

        x = torch.randn(1, 8, config.hidden_size, device="cuda")

        # Reference: peft's own forward (uses _activate_lora context manager)
        with torch.no_grad():
            ref_out = peft_model(x)

        # Set up kernel mapping
        self._setup_kernels(
            LocalLayerRepository,
            Mode,
            register_kernel_mapping,
            replace_kernel_forward_from_hub,
        )

        # Kernelize the MoE block inside the peft model
        base_moe = peft_model.base_model.model.moe
        kernelize(base_moe, mode=Mode.TRAINING, device="cuda")

        # Forward through kernelized peft model
        with torch.no_grad():
            kern_out = peft_model(x)

        torch.testing.assert_close(kern_out, ref_out, atol=5e-3, rtol=5e-3)

    def test_gate_lora_forward_via_kernelize(self):
        """Kernelized forward with gate LoRA matches peft reference."""
        (
            LocalLayerRepository,
            Mode,
            register_kernel_mapping,
            replace_kernel_forward_from_hub,
            kernelize,
        ) = self._get_kernelize_imports()

        config = make_olmoe_config(use_full=False)
        r = 4

        # Create peft model with gate + experts LoRA
        torch.manual_seed(42)
        model = MinimalOLMoEModel(config).cuda().float()
        lora_config = LoraConfig(
            r=r,
            lora_alpha=16,
            target_modules=[],
            target_parameters=[
                "gate.weight",
                "experts.gate_up_proj",
                "experts.down_proj",
            ],
            bias="none",
        )
        peft_model = get_peft_model(model, lora_config)

        x = torch.randn(1, 8, config.hidden_size, device="cuda")

        # Reference: peft's own forward
        with torch.no_grad():
            ref_out = peft_model(x)

        # Set up kernel mapping
        self._setup_kernels(
            LocalLayerRepository,
            Mode,
            register_kernel_mapping,
            replace_kernel_forward_from_hub,
        )

        # Kernelize the MoE block inside the peft model
        base_moe = peft_model.base_model.model.moe
        kernelize(base_moe, mode=Mode.TRAINING, device="cuda")

        # Forward through kernelized peft model
        with torch.no_grad():
            kern_out = peft_model(x)

        torch.testing.assert_close(kern_out, ref_out, atol=5e-3, rtol=5e-3)


# =============================================================================
# Tests: Shared expert handling
# =============================================================================


class TestSharedExpertHandling:
    """Test that HFScatterMoEGatedMLP.forward handles shared experts."""

    @staticmethod
    def _make_shared_expert_block(config):
        """Create an OlmoeSparseMoeBlock with a mock shared expert attached."""
        moe = OlmoeSparseMoeBlock(config)
        _init_expert_weights(moe)

        hidden = config.hidden_size
        inter = config.intermediate_size

        # Attach a simple shared expert MLP (mimics Qwen2MoE structure)
        class SharedExpertMLP(nn.Module):
            def __init__(self, hidden_size, intermediate_size):
                super().__init__()
                self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
                self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
                self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
                self.act_fn = nn.SiLU()

            def forward(self, x):
                return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        moe.shared_expert = SharedExpertMLP(hidden, inter)
        moe.shared_expert_gate = nn.Linear(hidden, 1, bias=False)

        return moe

    def test_shared_expert_is_used(self):
        """Verify shared expert output affects final result."""
        config = make_olmoe_config(use_full=False)
        moe = self._make_shared_expert_block(config)

        # Compute reference without shared expert
        torch.manual_seed(42)
        x = torch.randn(1, 4, config.hidden_size)
        x_flat = x.view(-1, config.hidden_size)

        with torch.no_grad():
            # Shared expert contribution
            shared_out = moe.shared_expert(x_flat)
            gate_val = F.sigmoid(moe.shared_expert_gate(x_flat))
            shared_contribution = shared_out * gate_val

        # Verify shared expert produces non-zero output
        assert shared_contribution.abs().max() > 0

    @requires_cuda
    def test_shared_expert_forward_via_kernelize(self):
        """Kernelized forward with shared expert matches manual reference."""
        try:
            from kernels import (
                LocalLayerRepository,
                Mode,
                kernelize,
                register_kernel_mapping,
                replace_kernel_forward_from_hub,
            )
        except ImportError:
            pytest.skip("kernels library not installed")

        config = make_olmoe_config(use_full=False)
        E = config.num_experts

        torch.manual_seed(42)
        moe = self._make_shared_expert_block(config).cuda().float()
        x = torch.randn(1, 8, config.hidden_size, device="cuda")
        x_flat = x.view(-1, config.hidden_size)

        # Compute reference: per-expert + shared expert
        with torch.no_grad():
            _, rw, sel = moe.gate(x_flat)

            expert_out = _reference_moe_forward(
                x_flat,
                moe.experts.gate_up_proj,
                moe.experts.down_proj,
                moe.experts.act_fn,
                sel,
                rw,
                E,
            )
            shared_out = moe.shared_expert(x_flat)
            gate_val = F.sigmoid(moe.shared_expert_gate(x_flat))
            ref_out = (expert_out + shared_out * gate_val).view(
                1, 8, config.hidden_size
            )

        # Kernelize
        repo_path = (
            Path(__file__).parent.parent.parent.parent
            / "src"
            / "axolotl"
            / "integrations"
            / "kernels"
            / "libs"
            / "scattermoe_lora"
        )
        local_repo = LocalLayerRepository(
            repo_path=repo_path,
            package_name="scattermoe_lora",
            layer_name="HFScatterMoEGatedMLP",
        )

        replace_kernel_forward_from_hub(
            OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts"
        )
        register_kernel_mapping(
            {
                "HFScatterMoEParallelExperts": {
                    "cuda": {
                        Mode.TRAINING: local_repo,
                        Mode.INFERENCE: local_repo,
                    },
                }
            }
        )

        kernelize(moe, mode=Mode.TRAINING, device="cuda")

        with torch.no_grad():
            kern_out = moe(x)

        torch.testing.assert_close(kern_out, ref_out, atol=1e-3, rtol=1e-3)