ScatterMoE LoRA support (#3410)
* scattermoe lora support * fsdp, bf16, dim fixes * expert weights aren't needed in save for bwd since they are frozen * use sonicmoe optim options * update save model from upstream * fixes per code review feedback and add tests * revert removal of CP fix * misc fixes
This commit is contained in:
@@ -18,7 +18,7 @@ datasets==4.5.0
|
||||
deepspeed>=0.18.3
|
||||
trl==0.28.0
|
||||
hf_xet==1.2.0
|
||||
kernels==0.11.5
|
||||
kernels==0.12.1
|
||||
|
||||
trackio>=0.16.1
|
||||
typing-extensions>=4.15.0
|
||||
|
||||
@@ -719,6 +719,8 @@ class AxolotlTrainer(
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
LOG.info(f"Saving model checkpoint to {output_dir}")
|
||||
|
||||
# fix for Context Parallel save
|
||||
if state_dict is None:
|
||||
state_dict = self.accelerator.get_state_dict(self.model)
|
||||
if state_dict is not None:
|
||||
@@ -726,6 +728,7 @@ class AxolotlTrainer(
|
||||
k: v.clone() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
supported_classes = (
|
||||
(PreTrainedModel,)
|
||||
if not is_peft_available()
|
||||
@@ -736,6 +739,7 @@ class AxolotlTrainer(
|
||||
if not isinstance(self.model, supported_classes):
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
|
||||
if isinstance(
|
||||
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
|
||||
supported_classes,
|
||||
@@ -745,6 +749,7 @@ class AxolotlTrainer(
|
||||
).save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
is_main_process=self.accelerator.is_main_process,
|
||||
)
|
||||
else:
|
||||
LOG.info(
|
||||
@@ -756,11 +761,7 @@ class AxolotlTrainer(
|
||||
metadata={"format": "pt"},
|
||||
)
|
||||
else:
|
||||
self.model.save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
is_main_process=self.accelerator.is_main_process,
|
||||
)
|
||||
self.model.save_pretrained(output_dir, state_dict=state_dict)
|
||||
|
||||
if self.processing_class is not None:
|
||||
self.processing_class.save_pretrained(output_dir)
|
||||
@@ -772,11 +773,7 @@ class AxolotlTrainer(
|
||||
LOG.info(
|
||||
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
|
||||
)
|
||||
save_jinja_files = True
|
||||
if self.axolotl_cfg:
|
||||
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
|
||||
self.data_collator.tokenizer.save_pretrained(
|
||||
output_dir, save_jinja_files=save_jinja_files
|
||||
)
|
||||
self.data_collator.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||
|
||||
@@ -33,3 +33,16 @@ class KernelsArgs(BaseModel):
|
||||
data["experts_implementation"] = "eager"
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def disable_mlp_kernel_scattermoe(cls, data):
|
||||
if data.get("use_scattermoe") is True:
|
||||
if data.get("lora_mlp_kernel") is True:
|
||||
LOG.warning(
|
||||
"Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
|
||||
)
|
||||
data["lora_mlp_kernel"] = False
|
||||
data["mlp_kernel"] = False
|
||||
|
||||
return data
|
||||
|
||||
0
src/axolotl/integrations/kernels/libs/__init__.py
Normal file
0
src/axolotl/integrations/kernels/libs/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
from . import layers
|
||||
from .lora_ops import ParallelExperts
|
||||
from .parallel_experts import flatten_sort_count, parallel_linear
|
||||
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
|
||||
|
||||
__all__ = [
|
||||
"layers",
|
||||
"ParallelExperts",
|
||||
"flatten_sort_count",
|
||||
"parallel_linear",
|
||||
"ScatterMoELoRA",
|
||||
"parallel_linear_lora",
|
||||
"lora_ops",
|
||||
]
|
||||
@@ -0,0 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors
|
||||
# Adapted from https://github.com/shawntan/scattermoe
|
||||
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
|
||||
#
|
||||
# Modifications and LoRA adaptation Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
from . import lora_ops, ops
|
||||
|
||||
__all__ = ["ops", "lora_ops"]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,645 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Adapted from https://github.com/shawntan/scattermoe
|
||||
# Copyright (c) Shawn Tan and ScatterMoE Contributors
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
BLOCK_M = 128
|
||||
ALLOW_TF32 = True
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _compute_expert_block(
|
||||
E_idx,
|
||||
E_mask,
|
||||
M_in_idx,
|
||||
N_block,
|
||||
N_mask,
|
||||
X_ptr,
|
||||
stride_xm,
|
||||
stride_xk,
|
||||
W_ptr,
|
||||
stride_we,
|
||||
stride_wk,
|
||||
stride_wn,
|
||||
K,
|
||||
acc,
|
||||
no_k_mask,
|
||||
BLOCK_K,
|
||||
allow_tf32=True,
|
||||
):
|
||||
K_block = tl.arange(0, BLOCK_K)
|
||||
X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
|
||||
W_blk_ptrs = (
|
||||
W_ptr
|
||||
+ K_block[:, None] * stride_wk
|
||||
+ N_block[None, :] * stride_wn
|
||||
+ E_idx * stride_we
|
||||
)
|
||||
iters = tl.cdiv(K, BLOCK_K)
|
||||
|
||||
for K_block_id in range(iters):
|
||||
if no_k_mask:
|
||||
x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
|
||||
w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
|
||||
else:
|
||||
K_mask = (K_block_id * BLOCK_K + K_block) < K
|
||||
x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
|
||||
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
|
||||
|
||||
X_blk_ptrs += BLOCK_K * stride_xk
|
||||
W_blk_ptrs += BLOCK_K * stride_wk
|
||||
acc = tl.dot(x, w, acc, allow_tf32=allow_tf32)
|
||||
return acc
|
||||
|
||||
|
||||
def _scatter2scatter_configs():
|
||||
return [
|
||||
triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4),
|
||||
]
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=_scatter2scatter_configs(),
|
||||
key=["M", "N", "K"],
|
||||
)
|
||||
@triton.heuristics(
|
||||
{
|
||||
"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0,
|
||||
"NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0,
|
||||
}
|
||||
)
|
||||
@triton.jit
|
||||
def _scatter2scatter(
|
||||
X_ptr,
|
||||
stride_xm: tl.constexpr,
|
||||
stride_xk: tl.constexpr,
|
||||
W_ptr,
|
||||
stride_we,
|
||||
stride_wk: tl.constexpr,
|
||||
stride_wn: tl.constexpr,
|
||||
Y_ptr,
|
||||
stride_ym: tl.constexpr,
|
||||
stride_yn: tl.constexpr,
|
||||
B_ptr,
|
||||
stride_be: tl.constexpr,
|
||||
stride_bn: tl.constexpr,
|
||||
grouped_idx_ptr,
|
||||
expert_idxs_ptr,
|
||||
# block_start_idx_ptr,
|
||||
FAN_OUT: tl.constexpr,
|
||||
M,
|
||||
K: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
E: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr,
|
||||
# OUT_M,
|
||||
allow_tf32: tl.constexpr,
|
||||
x_grouped: tl.constexpr,
|
||||
y_grouped: tl.constexpr,
|
||||
NO_K_MASK: tl.constexpr,
|
||||
NO_N_MASK: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
|
||||
M_block_id = pid // N_BLOCK_COUNT
|
||||
N_block_id = pid % N_BLOCK_COUNT
|
||||
|
||||
M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
N_mask = N_block < N
|
||||
M_boundary_mask = M_block < (FAN_OUT * M)
|
||||
E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)
|
||||
|
||||
no_k_mask = K % BLOCK_K == 0
|
||||
|
||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
||||
E_first_idx = tl.min(E_idxs)
|
||||
E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
|
||||
M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)
|
||||
for E_idx in range(E_first_idx, E_last_idx + 1):
|
||||
E_mask = E_idxs == E_idx
|
||||
E_M_idx = M_idx
|
||||
if x_grouped:
|
||||
M_in_idx = M_block
|
||||
else:
|
||||
M_in_idx = E_M_idx // FAN_OUT
|
||||
acc = _compute_expert_block(
|
||||
E_idx,
|
||||
E_mask,
|
||||
M_in_idx,
|
||||
N_block,
|
||||
N_mask,
|
||||
X_ptr,
|
||||
stride_xm,
|
||||
stride_xk,
|
||||
W_ptr,
|
||||
stride_we,
|
||||
stride_wk,
|
||||
stride_wn,
|
||||
K,
|
||||
acc,
|
||||
no_k_mask,
|
||||
BLOCK_K,
|
||||
allow_tf32=allow_tf32,
|
||||
)
|
||||
|
||||
if B_ptr is not None:
|
||||
B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn
|
||||
acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])
|
||||
|
||||
if y_grouped:
|
||||
M_out_idx = M_block
|
||||
else:
|
||||
M_out_idx = M_idx
|
||||
Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
|
||||
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
|
||||
|
||||
|
||||
def scatter2scatter(
|
||||
X,
|
||||
W,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
k,
|
||||
b=None,
|
||||
x_grouped=False,
|
||||
y_grouped=False,
|
||||
out=None,
|
||||
):
|
||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||
# Pre-kernel setup
|
||||
y_dim = W.size(-1)
|
||||
L_scattered = sorted_expert_idxs.size(0)
|
||||
if out is None:
|
||||
output = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
|
||||
else:
|
||||
assert out.size(0) == L_scattered and out.size(1) == y_dim
|
||||
output = out
|
||||
|
||||
scatter2scatter_compileable(
|
||||
output,
|
||||
W,
|
||||
X,
|
||||
k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
b,
|
||||
x_grouped,
|
||||
y_grouped,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"})
|
||||
def scatter2scatter_compileable(
|
||||
output: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
X: torch.Tensor,
|
||||
k: int,
|
||||
sorted_expert_idxs: torch.Tensor,
|
||||
sorted_scattered_idxs: torch.Tensor,
|
||||
b: Optional[torch.Tensor],
|
||||
x_grouped: bool,
|
||||
y_grouped: bool,
|
||||
) -> None:
|
||||
def grid(META):
|
||||
grid_num = (
|
||||
triton.cdiv(sorted_expert_idxs.size(0), META["BLOCK_M"])
|
||||
* triton.cdiv(META["N"], META["BLOCK_N"]),
|
||||
)
|
||||
return grid_num
|
||||
|
||||
if b is None:
|
||||
b = None
|
||||
stride_be = stride_bn = 0
|
||||
else:
|
||||
stride_be, stride_bn = b.stride()
|
||||
|
||||
_scatter2scatter[grid](
|
||||
# X_ptr, stride_xm, stride_xk,
|
||||
X,
|
||||
X.stride(0),
|
||||
X.stride(1),
|
||||
# W_ptr, stride_we, stride_wk, stride_wn,
|
||||
W,
|
||||
W.stride(0),
|
||||
W.stride(1),
|
||||
W.stride(2),
|
||||
# Y_ptr, stride_ym, stride_yn,
|
||||
output,
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
# B_ptr, stride_be, stride_bn
|
||||
b,
|
||||
stride_be,
|
||||
stride_bn,
|
||||
grouped_idx_ptr=sorted_scattered_idxs,
|
||||
expert_idxs_ptr=sorted_expert_idxs,
|
||||
# block_start_idx_ptr=padded_block_idxs,
|
||||
FAN_OUT=k,
|
||||
M=X.size(0),
|
||||
K=X.size(1),
|
||||
N=output.size(1),
|
||||
E=W.size(0),
|
||||
BLOCK_M=BLOCK_M,
|
||||
ACC_TYPE=tl.float32,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
x_grouped=x_grouped,
|
||||
y_grouped=y_grouped,
|
||||
)
|
||||
|
||||
|
||||
def _config_XtY():
|
||||
return [
|
||||
triton.Config(
|
||||
{"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def group_bwd_W(DY, X, expert_offsets, E, has_bias=False):
|
||||
DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
|
||||
DW = DWt.permute(0, 2, 1)
|
||||
if has_bias:
|
||||
Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype)
|
||||
else:
|
||||
Db = None
|
||||
groupXtY_compileable(E, DW, Db, DY, X, expert_offsets)
|
||||
return DW, Db
|
||||
|
||||
|
||||
@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW", "Db"})
|
||||
def groupXtY_compileable(
|
||||
E: int,
|
||||
DW: torch.Tensor,
|
||||
Db: Optional[torch.Tensor],
|
||||
DY: torch.Tensor,
|
||||
X: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
) -> None:
|
||||
def grid(META):
|
||||
grid = (
|
||||
E * triton.cdiv(META["K"], META["BLOCK_K"]),
|
||||
triton.cdiv(META["N"], META["BLOCK_N"]),
|
||||
)
|
||||
return grid
|
||||
|
||||
if Db is None:
|
||||
stride_dbe = 0
|
||||
stride_dbn = 0
|
||||
else:
|
||||
stride_dbe, stride_dbn = Db.stride()
|
||||
|
||||
_groupXtY[grid](
|
||||
# DY_ptr, stride_dym, stride_dyk,
|
||||
DY,
|
||||
DY.stride(0),
|
||||
DY.stride(1),
|
||||
# X_ptr, stride_xm, stride_xn,
|
||||
X,
|
||||
X.stride(0),
|
||||
X.stride(1),
|
||||
# DW_ptr, stride_dwe, stride_dwk, stride_dwn,
|
||||
DW,
|
||||
DW.stride(0),
|
||||
DW.stride(1),
|
||||
DW.stride(2),
|
||||
# Db_ptr, stride_dwe, stride_dbn,
|
||||
Db,
|
||||
stride_dbe,
|
||||
stride_dbn,
|
||||
# expert_offsets_ptr,
|
||||
expert_offsets,
|
||||
# K: tl.constexpr, N: tl.constexpr,
|
||||
M=DY.size(0),
|
||||
N=DY.size(-1),
|
||||
K=X.size(-1),
|
||||
# ACC_TYPE: tl.constexpr,
|
||||
ACC_TYPE=tl.float32,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=_config_XtY(),
|
||||
key=["M", "N", "K"],
|
||||
)
|
||||
@triton.heuristics(
|
||||
{
|
||||
"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0,
|
||||
"NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0,
|
||||
}
|
||||
)
|
||||
@triton.jit
|
||||
def _groupXtY(
|
||||
DY_ptr,
|
||||
stride_dym,
|
||||
stride_dyk,
|
||||
X_ptr,
|
||||
stride_xm,
|
||||
stride_xn,
|
||||
DW_ptr,
|
||||
stride_dwe,
|
||||
stride_dwk,
|
||||
stride_dwn,
|
||||
Db_ptr,
|
||||
stride_dbe,
|
||||
stride_dbn,
|
||||
expert_offsets_ptr,
|
||||
M,
|
||||
K: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr,
|
||||
allow_tf32: tl.constexpr,
|
||||
NO_K_MASK: tl.constexpr,
|
||||
NO_N_MASK: tl.constexpr,
|
||||
):
|
||||
pid0 = tl.program_id(axis=0)
|
||||
pid1 = tl.program_id(axis=1)
|
||||
num0 = tl.num_programs(0)
|
||||
num1 = tl.num_programs(1)
|
||||
# pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
|
||||
pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
|
||||
|
||||
K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
|
||||
E_idx = pid0 // K_BLOCK_COUNT
|
||||
K_block_id = pid0 % K_BLOCK_COUNT
|
||||
N_block_id = pid1
|
||||
|
||||
if E_idx == 0:
|
||||
start_idx = 0
|
||||
else:
|
||||
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
|
||||
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
|
||||
|
||||
if end_idx > start_idx:
|
||||
M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
|
||||
|
||||
K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
K_mask = K_block < K
|
||||
K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
|
||||
|
||||
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
N_mask = N_block < N
|
||||
N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
|
||||
|
||||
M_idxs = M_block
|
||||
xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
|
||||
dy_blk_ptrs = (
|
||||
DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
|
||||
)
|
||||
if (Db_ptr is not None) and (K_block_id == 0):
|
||||
_xty_and_bias(
|
||||
E_idx,
|
||||
start_idx,
|
||||
end_idx,
|
||||
M_block,
|
||||
K_block,
|
||||
K_mask,
|
||||
N_block,
|
||||
N_mask,
|
||||
dy_blk_ptrs,
|
||||
stride_dym,
|
||||
xt_blk_ptrs,
|
||||
stride_xm,
|
||||
DW_ptr,
|
||||
stride_dwe,
|
||||
stride_dwk,
|
||||
stride_dwn,
|
||||
Db_ptr,
|
||||
stride_dbe,
|
||||
stride_dbn,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
ACC_TYPE,
|
||||
allow_tf32,
|
||||
NO_K_MASK,
|
||||
NO_N_MASK,
|
||||
compute_bias=True,
|
||||
)
|
||||
else:
|
||||
_xty_and_bias(
|
||||
E_idx,
|
||||
start_idx,
|
||||
end_idx,
|
||||
M_block,
|
||||
K_block,
|
||||
K_mask,
|
||||
N_block,
|
||||
N_mask,
|
||||
dy_blk_ptrs,
|
||||
stride_dym,
|
||||
xt_blk_ptrs,
|
||||
stride_xm,
|
||||
DW_ptr,
|
||||
stride_dwe,
|
||||
stride_dwk,
|
||||
stride_dwn,
|
||||
Db_ptr,
|
||||
stride_dbe,
|
||||
stride_dbn,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
ACC_TYPE,
|
||||
allow_tf32,
|
||||
NO_K_MASK,
|
||||
NO_N_MASK,
|
||||
compute_bias=False,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _xty_and_bias(
|
||||
E_idx,
|
||||
start_idx,
|
||||
end_idx,
|
||||
M_block,
|
||||
K_block,
|
||||
K_mask,
|
||||
N_block,
|
||||
N_mask,
|
||||
dy_blk_ptrs,
|
||||
stride_dym,
|
||||
xt_blk_ptrs,
|
||||
stride_xm,
|
||||
DW_ptr,
|
||||
stride_dwe,
|
||||
stride_dwk,
|
||||
stride_dwn,
|
||||
Db_ptr,
|
||||
stride_dbe,
|
||||
stride_dbn,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
ACC_TYPE,
|
||||
allow_tf32,
|
||||
NO_K_MASK,
|
||||
NO_N_MASK,
|
||||
compute_bias: tl.constexpr,
|
||||
):
|
||||
if compute_bias:
|
||||
db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE)
|
||||
else:
|
||||
db_acc = None
|
||||
|
||||
acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
|
||||
iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
|
||||
for i in range(0, iters):
|
||||
M_mask = (i * BLOCK_M + M_block) < end_idx
|
||||
if NO_K_MASK:
|
||||
xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
|
||||
else:
|
||||
xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
|
||||
if NO_N_MASK:
|
||||
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
|
||||
else:
|
||||
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
|
||||
|
||||
acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
|
||||
|
||||
xt_blk_ptrs += BLOCK_M * stride_xm
|
||||
dy_blk_ptrs += BLOCK_M * stride_dym
|
||||
|
||||
if compute_bias:
|
||||
db_acc += tl.sum(dy, axis=0)
|
||||
|
||||
DW_blk_ptrs = (
|
||||
DW_ptr
|
||||
+ E_idx * stride_dwe
|
||||
+ K_block[:, None] * stride_dwk
|
||||
+ N_block[None, :] * stride_dwn
|
||||
)
|
||||
acc = acc.to(DW_blk_ptrs.dtype.element_ty)
|
||||
tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
|
||||
if compute_bias:
|
||||
Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn
|
||||
tl.store(Db_blk_ptrs, db_acc, mask=N_mask)
|
||||
|
||||
|
||||
def _config_grouping():
|
||||
return [
|
||||
triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
|
||||
# triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
|
||||
]
|
||||
|
||||
|
||||
def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
|
||||
N = sorted_expert_idxs.size(0)
|
||||
K = A.size(1)
|
||||
assert A.size(0) * fan_out == N
|
||||
if out is not None:
|
||||
Y = out
|
||||
else:
|
||||
Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
|
||||
group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs)
|
||||
return Y
|
||||
|
||||
|
||||
@torch.library.custom_op("scattermoe::group", mutates_args={"Y"})
|
||||
def group_compileable(
|
||||
A: torch.Tensor,
|
||||
K: int,
|
||||
N: int,
|
||||
Y: torch.Tensor,
|
||||
coeff: Optional[torch.Tensor],
|
||||
has_coeff: bool,
|
||||
fan_out: int,
|
||||
sorted_expert_idxs: torch.Tensor,
|
||||
) -> None:
|
||||
def grid(META):
|
||||
grid_num = (triton.cdiv(META["N"], META["BLOCK_N"]),)
|
||||
return grid_num
|
||||
|
||||
_group[grid](
|
||||
# A_ptr, stride_an, stride_ai,
|
||||
A,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
has_coeff,
|
||||
coeff,
|
||||
fan_out,
|
||||
# Y_ptr, stride_yn, stride_yk,
|
||||
Y,
|
||||
Y.stride(0),
|
||||
Y.stride(1),
|
||||
# grouped_idx_ptr,
|
||||
sorted_expert_idxs,
|
||||
# N: tl.constexpr, K: tl.constexpr,
|
||||
N,
|
||||
K,
|
||||
)
|
||||
|
||||
|
||||
@triton.autotune(configs=_config_grouping(), key=["K"])
|
||||
@triton.heuristics({"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0})
|
||||
@triton.jit
|
||||
def _group(
|
||||
src_ptr,
|
||||
stride_sn,
|
||||
stride_sk,
|
||||
has_coeff: tl.constexpr,
|
||||
coeff_ptr,
|
||||
FAN_OUT: tl.constexpr,
|
||||
tgt_ptr,
|
||||
stride_tn,
|
||||
stride_ti,
|
||||
grouped_idx_ptr,
|
||||
N,
|
||||
K: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
NO_K_MASK: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
N_block_id = pid
|
||||
N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
N_mask = N_blk < N
|
||||
N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
|
||||
N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
|
||||
|
||||
K_blk = tl.arange(0, BLOCK_K)
|
||||
src_blk_ptrs = (
|
||||
src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
|
||||
)
|
||||
tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
|
||||
|
||||
if has_coeff:
|
||||
c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
|
||||
|
||||
iters = tl.cdiv(K, BLOCK_K)
|
||||
for i in range(0, iters):
|
||||
if NO_K_MASK or i < iters - 1:
|
||||
block = tl.load(src_blk_ptrs, mask=N_mask[:, None])
|
||||
if has_coeff:
|
||||
block *= c
|
||||
tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
|
||||
|
||||
else:
|
||||
K_mask = (i * BLOCK_K + K_blk) < K
|
||||
mask = N_mask[:, None] & K_mask[None, :]
|
||||
block = tl.load(src_blk_ptrs, mask=mask)
|
||||
if has_coeff:
|
||||
block *= c
|
||||
tl.store(tgt_blk_ptrs, block, mask=mask)
|
||||
src_blk_ptrs += BLOCK_K * stride_sk
|
||||
tgt_blk_ptrs += BLOCK_K * stride_ti
|
||||
@@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Adapted from https://github.com/shawntan/scattermoe
|
||||
# Copyright (c) Shawn Tan and ScatterMoE Contributors
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _single2scatter(
|
||||
X_ptr,
|
||||
stride_xm,
|
||||
stride_xk,
|
||||
W_ptr,
|
||||
stride_we,
|
||||
stride_wk,
|
||||
stride_wn,
|
||||
Y_ptr,
|
||||
stride_ym,
|
||||
stride_yn,
|
||||
expert_idxs_ptr,
|
||||
FAN_OUT: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
E: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr,
|
||||
):
|
||||
pid0 = tl.program_id(axis=0)
|
||||
pid1 = tl.program_id(axis=1)
|
||||
|
||||
N_block_id = pid0
|
||||
if FAN_OUT == 1:
|
||||
in_idx = pid1
|
||||
else:
|
||||
in_idx = 0
|
||||
out_idx = pid1
|
||||
|
||||
K_block = tl.arange(0, BLOCK_K)
|
||||
N_block = tl.max_contiguous(
|
||||
tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N),
|
||||
BLOCK_N,
|
||||
)
|
||||
E_idx = tl.load(expert_idxs_ptr + pid1)
|
||||
X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
|
||||
W_blk_ptrs = (
|
||||
W_ptr
|
||||
+ E_idx * stride_we
|
||||
+ K_block[:, None] * stride_wk
|
||||
+ N_block[None, :] * stride_wn
|
||||
)
|
||||
N_mask = N_block < N
|
||||
acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
|
||||
for _K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
|
||||
K_mask = K_block < K
|
||||
x = tl.load(X_blk_ptrs, mask=K_mask[:, None], other=0.0)
|
||||
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0)
|
||||
acc += tl.sum(x * w, axis=0)[None, :]
|
||||
X_blk_ptrs += BLOCK_K * stride_xk
|
||||
W_blk_ptrs += BLOCK_K * stride_wk
|
||||
K_block += BLOCK_K
|
||||
Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
|
||||
tl.store(Y_blk_ptrs, acc, mask=N_mask[None, :])
|
||||
|
||||
|
||||
def single2scatter(X, W, expert_idxs):
|
||||
E, xdim, ydim = W.size()
|
||||
k = expert_idxs.size(1)
|
||||
assert X.size(0) == k or X.size(0) == 1
|
||||
Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
|
||||
BLOCK_N = 128
|
||||
BLOCK_K = 128
|
||||
grid = triton.cdiv(ydim, BLOCK_N), k
|
||||
_single2scatter[grid](
|
||||
X,
|
||||
X.stride(0),
|
||||
X.stride(1),
|
||||
W,
|
||||
W.stride(0),
|
||||
W.stride(1),
|
||||
W.stride(2),
|
||||
Y,
|
||||
Y.stride(0),
|
||||
Y.stride(1),
|
||||
expert_idxs,
|
||||
FAN_OUT=Y.size(0) // X.size(0),
|
||||
K=xdim,
|
||||
N=ydim,
|
||||
E=E,
|
||||
BLOCK_N=BLOCK_N,
|
||||
BLOCK_K=BLOCK_K,
|
||||
ACC_TYPE=tl.float32,
|
||||
)
|
||||
return Y
|
||||
439
src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py
Normal file
439
src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py
Normal file
@@ -0,0 +1,439 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors
|
||||
# Adapted from https://github.com/shawntan/scattermoe
|
||||
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
|
||||
#
|
||||
# Modifications and LoRA adaptation Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""
|
||||
ScatterMoE layer replacements for HuggingFace MoE architectures.
|
||||
|
||||
Provides drop-in forward replacements that use ScatterMoE kernels for
|
||||
acceleration. When used via the HF ``kernels`` library
|
||||
(``replace_kernel_forward_from_hub``), these classes replace the forward
|
||||
method of the original MoE block.
|
||||
|
||||
LoRA support
|
||||
------------
|
||||
When peft wraps parameters via ``target_parameters``, the ``self.experts``
|
||||
submodule becomes a chain of ``ParamWrapper`` objects and the ``self.gate``
|
||||
router may also become a ``ParamWrapper``. The ``HFScatterMoEGatedMLP``
|
||||
forward detects this and automatically:
|
||||
|
||||
1. Unwraps ``self.gate`` to the base router, applying gate LoRA delta
|
||||
2. Unwraps ``self.experts`` to the base ``OlmoeExperts`` module
|
||||
3. Extracts LoRA A/B weights and scaling from each wrapper
|
||||
4. Converts B layout from peft rank-major to scattermoe expert-major
|
||||
5. Routes to ``parallel_linear_lora`` for fused LoRA computation
|
||||
6. Passes through ``self.shared_expert`` / ``self.shared_expert_gate``
|
||||
(peft wraps their linear layers with standard LoRA, no special handling)
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .parallel_experts import flatten_sort_count, parallel_linear
|
||||
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
|
||||
|
||||
# =============================================================================
|
||||
# LoRA layout conversion utilities (peft <-> scattermoe)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
|
||||
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
|
||||
expert-major ``[N, r*E]``.
|
||||
|
||||
peft reshapes B to ``[out, r, E]`` (rank-major).
|
||||
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
|
||||
"""
|
||||
N = peft_B.shape[0]
|
||||
return (
|
||||
peft_B.reshape(N, rank, num_experts)
|
||||
.permute(0, 2, 1)
|
||||
.contiguous()
|
||||
.reshape(N, num_experts * rank)
|
||||
)
|
||||
|
||||
|
||||
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||
"""Convert peft LoRA weights to scattermoe layout (with A<->B swap).
|
||||
|
||||
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
|
||||
where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the
|
||||
parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with
|
||||
``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles
|
||||
are swapped relative to scattermoe's convention.
|
||||
|
||||
peft gives:
|
||||
lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]``
|
||||
|
||||
scattermoe needs:
|
||||
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
|
||||
|
||||
This function swaps A<->B and converts B from rank-major to expert-major.
|
||||
Uses vectorized tensor operations (no Python loop over experts).
|
||||
|
||||
Works for **both** gate_up_proj and down_proj since the transposition
|
||||
issue is the same for any parameter.
|
||||
"""
|
||||
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||
|
||||
dim1 = peft_A.shape[1] # peft in_features -> scattermoe N
|
||||
dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K
|
||||
|
||||
# smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2]
|
||||
# [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2]
|
||||
smoe_A = (
|
||||
peft_B_em.reshape(dim2, num_experts, rank)
|
||||
.permute(1, 2, 0)
|
||||
.contiguous()
|
||||
.reshape(rank * num_experts, dim2)
|
||||
)
|
||||
|
||||
# smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r]
|
||||
# [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r]
|
||||
smoe_B = (
|
||||
peft_A.reshape(num_experts, rank, dim1)
|
||||
.permute(2, 0, 1)
|
||||
.contiguous()
|
||||
.reshape(dim1, num_experts * rank)
|
||||
)
|
||||
|
||||
return smoe_A, smoe_B
|
||||
|
||||
|
||||
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
|
||||
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ParamWrapper unwrapping
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _unwrap_gate_lora(gate_module):
|
||||
"""Unwrap peft ``ParamWrapper`` on the router gate.
|
||||
|
||||
When peft targets ``gate.weight``, ``self.gate`` becomes::
|
||||
|
||||
ParamWrapper(weight)
|
||||
-> base_layer: OlmoeTopKRouter (the real module)
|
||||
|
||||
This function detects the wrapping and returns the base router, its
|
||||
weight tensor, and an optional LoRA delta tensor.
|
||||
|
||||
Returns:
|
||||
(base_gate, gate_weight, gate_lora_delta_or_None)
|
||||
|
||||
``base_gate`` is the original router module (with ``.top_k``,
|
||||
``.num_experts``, ``.norm_topk_prob``).
|
||||
``gate_weight`` is the base router weight (may be a DTensor under FSDP).
|
||||
``gate_lora_delta_or_None`` is the LoRA delta tensor if LoRA is active,
|
||||
else ``None``. Kept separate to avoid mixing DTensor + Tensor in an add.
|
||||
"""
|
||||
if hasattr(gate_module, "base_layer") and hasattr(gate_module, "lora_A"):
|
||||
base_gate = gate_module.base_layer
|
||||
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gate_module)
|
||||
if lora_A is not None:
|
||||
# gate weight: [num_experts, hidden_size]
|
||||
# lora_A: [r, hidden_size], lora_B: [num_experts, r]
|
||||
# delta = scaling * B @ A = [num_experts, hidden_size]
|
||||
delta = scaling * (lora_B @ lora_A)
|
||||
return base_gate, base_gate.weight, delta
|
||||
else:
|
||||
return base_gate, base_gate.weight, None
|
||||
else:
|
||||
# No wrapping — gate is the original module
|
||||
return gate_module, gate_module.weight, None
|
||||
|
||||
|
||||
def _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling):
|
||||
"""Convert peft LoRA weights to scattermoe layout."""
|
||||
smoe_A, smoe_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank)
|
||||
return (smoe_A, smoe_B, scaling)
|
||||
|
||||
|
||||
def _unwrap_experts_lora(experts_module):
|
||||
"""Walk a peft ``ParamWrapper`` chain on ``self.experts``.
|
||||
|
||||
When peft targets ``experts.gate_up_proj`` and ``experts.down_proj`` via
|
||||
``target_parameters``, ``self.experts`` becomes a nested chain::
|
||||
|
||||
ParamWrapper(down_proj)
|
||||
-> base_layer: ParamWrapper(gate_up_proj)
|
||||
-> base_layer: OlmoeExperts (the real module)
|
||||
|
||||
This function walks the chain, collects LoRA params keyed by
|
||||
``parameter_name``, and returns the base experts module.
|
||||
|
||||
Returns:
|
||||
(base_experts, gup_lora, down_lora)
|
||||
|
||||
Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
|
||||
A/B are already in scattermoe layout.
|
||||
"""
|
||||
# Collect ParamWrapper layers by their parameter_name
|
||||
wrappers = {}
|
||||
module = experts_module
|
||||
while hasattr(module, "base_layer") and hasattr(module, "lora_A"):
|
||||
param_name = getattr(module, "parameter_name", None)
|
||||
if param_name is not None:
|
||||
wrappers[param_name] = module
|
||||
module = module.base_layer
|
||||
|
||||
base_experts = module
|
||||
|
||||
if not wrappers:
|
||||
return base_experts, None, None
|
||||
|
||||
# Determine num_experts from base module
|
||||
num_experts = getattr(base_experts, "num_experts", None)
|
||||
if num_experts is None:
|
||||
# Fallback: infer from parameter shape
|
||||
gup = getattr(base_experts, "gate_up_proj", None)
|
||||
if gup is not None:
|
||||
num_experts = gup.shape[0]
|
||||
|
||||
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
|
||||
gup_lora = None
|
||||
gup_wrapper = wrappers.get("gate_up_proj")
|
||||
if gup_wrapper is not None:
|
||||
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
|
||||
if lora_A is not None:
|
||||
rank = lora_A.shape[0] // num_experts
|
||||
gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
|
||||
|
||||
# Extract down_proj LoRA (needs A<->B swap due to transposition)
|
||||
down_lora = None
|
||||
down_wrapper = wrappers.get("down_proj")
|
||||
if down_wrapper is not None:
|
||||
lora_A, lora_B, scaling = get_lora_params_from_wrapper(down_wrapper)
|
||||
if lora_A is not None:
|
||||
rank = lora_A.shape[0] // num_experts
|
||||
down_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
|
||||
|
||||
return base_experts, gup_lora, down_lora
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Layer classes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ScatterMoEGatedMLP(nn.Module):
|
||||
def forward(self, layer_input):
|
||||
"""
|
||||
Forward pass of the mixture of experts layer.
|
||||
|
||||
Args:
|
||||
layer_input (Tensor):
|
||||
Input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor:
|
||||
Output tensor.
|
||||
"""
|
||||
bsz, length, emb_size = layer_input.size()
|
||||
layer_input = layer_input.reshape(-1, emb_size)
|
||||
# compute the top_k routing decision
|
||||
router_logits = self.router.layer(layer_input)
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(
|
||||
routing_weights, self.router.top_k, dim=-1
|
||||
)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
routing_weights = routing_weights.to(layer_input.dtype)
|
||||
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
|
||||
selected_experts, num_experts=self.router.num_experts
|
||||
)
|
||||
|
||||
# compute experts
|
||||
gates, h = parallel_linear(
|
||||
layer_input,
|
||||
self.input_linear.weight.transpose(2, 1),
|
||||
self.router.top_k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
grouped_in=False,
|
||||
grouped_out=True,
|
||||
).chunk(2, dim=-1)
|
||||
h = self.activation(gates) * h
|
||||
layer_output = parallel_linear(
|
||||
h,
|
||||
self.output_linear.weight.transpose(2, 1),
|
||||
1,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
grouped_in=True,
|
||||
grouped_out=False,
|
||||
gates=routing_weights,
|
||||
)
|
||||
layer_output = layer_output.view(bsz, length, emb_size)
|
||||
return layer_output
|
||||
|
||||
|
||||
class HFScatterMoEGatedMLP(nn.Module):
|
||||
"""
|
||||
ScatterMoE-accelerated forward pass for HF MoEs (OLMoE / Qwen2MoE).
|
||||
|
||||
Used as a kernel layer via the HF ``kernels`` library. The ``forward``
|
||||
method replaces the original ``OlmoeSparseMoeBlock.forward``.
|
||||
|
||||
Supports both full-parameter training and LoRA fine-tuning:
|
||||
|
||||
* **Full-param**: uses ``parallel_linear`` (base ScatterMoE kernel)
|
||||
* **LoRA**: detects peft ``ParamWrapper`` on ``self.experts``, extracts
|
||||
adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(self: nn.Module, layer_input: torch.Tensor):
|
||||
"""
|
||||
Forward pass using ScatterMoE kernels.
|
||||
|
||||
Args:
|
||||
self: The MoeSparseMoeBlock module containing:
|
||||
- self.gate: Router (or peft ParamWrapper wrapping it)
|
||||
- self.experts: Experts module (or peft ParamWrapper chain)
|
||||
- self.shared_expert: Optional shared expert (e.g. Qwen2MoE)
|
||||
- self.shared_expert_gate: Optional shared expert gate
|
||||
layer_input: Input tensor [batch_size, seq_len, hidden_size]
|
||||
|
||||
Returns:
|
||||
Tensor: [batch_size, seq_len, hidden_size]
|
||||
"""
|
||||
batch_size, sequence_length, hidden_dim = layer_input.shape
|
||||
hidden_states_flat = layer_input.view(-1, hidden_dim)
|
||||
|
||||
# ====================================================================
|
||||
# Shared Expert (if present, e.g. Qwen2MoE)
|
||||
# ====================================================================
|
||||
# peft wraps individual linear layers inside shared_expert with
|
||||
# standard LoRA — calling forward() handles this transparently.
|
||||
if hasattr(self, "shared_expert") and self.shared_expert is not None:
|
||||
shared_expert_output = self.shared_expert(hidden_states_flat)
|
||||
# shared_expert_gate may also be peft-wrapped (standard LoRA
|
||||
# on nn.Linear), its forward() applies LoRA automatically.
|
||||
shared_expert_gate_output = F.sigmoid(
|
||||
self.shared_expert_gate(hidden_states_flat)
|
||||
)
|
||||
shared_expert_output = shared_expert_output * shared_expert_gate_output
|
||||
else:
|
||||
shared_expert_output = None
|
||||
|
||||
# ====================================================================
|
||||
# Router Computation (with optional gate LoRA)
|
||||
# ====================================================================
|
||||
base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate)
|
||||
router_logits = F.linear(hidden_states_flat, gate_weight)
|
||||
if gate_lora_delta is not None:
|
||||
router_logits = router_logits + F.linear(
|
||||
hidden_states_flat, gate_lora_delta
|
||||
)
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
|
||||
top_k = base_gate.top_k
|
||||
num_experts = base_gate.num_experts
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
|
||||
|
||||
if base_gate.norm_topk_prob:
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
routing_weights = routing_weights.to(hidden_states_flat.dtype)
|
||||
|
||||
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
|
||||
selected_experts, num_experts=num_experts
|
||||
)
|
||||
|
||||
# ====================================================================
|
||||
# Detect LoRA (peft ParamWrapper) and extract adapter weights
|
||||
# ====================================================================
|
||||
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
|
||||
|
||||
# ====================================================================
|
||||
# Gate + Up projection
|
||||
# ====================================================================
|
||||
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
|
||||
|
||||
if gup_lora is not None:
|
||||
gup_A, gup_B, gup_scaling = gup_lora
|
||||
gup = parallel_linear_lora(
|
||||
hidden_states_flat,
|
||||
gate_up_W,
|
||||
top_k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
lora_A=gup_A,
|
||||
lora_B=gup_B,
|
||||
scaling=gup_scaling,
|
||||
grouped_in=False,
|
||||
grouped_out=True,
|
||||
use_fused_dX=True,
|
||||
use_fused_gather=True,
|
||||
)
|
||||
else:
|
||||
gup = parallel_linear(
|
||||
hidden_states_flat,
|
||||
gate_up_W,
|
||||
top_k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
grouped_in=False,
|
||||
grouped_out=True,
|
||||
)
|
||||
|
||||
gates, h = gup.chunk(2, dim=-1)
|
||||
h = experts.act_fn(gates) * h
|
||||
|
||||
# ====================================================================
|
||||
# Down projection
|
||||
# ====================================================================
|
||||
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
|
||||
|
||||
if down_lora is not None:
|
||||
down_A, down_B, down_scaling = down_lora
|
||||
expert_output = parallel_linear_lora(
|
||||
h,
|
||||
down_W,
|
||||
1,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
lora_A=down_A,
|
||||
lora_B=down_B,
|
||||
scaling=down_scaling,
|
||||
gates=routing_weights,
|
||||
grouped_in=True,
|
||||
grouped_out=False,
|
||||
use_fused_dX=True,
|
||||
use_fused_gather=True,
|
||||
)
|
||||
else:
|
||||
expert_output = parallel_linear(
|
||||
h,
|
||||
down_W,
|
||||
1,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
grouped_in=True,
|
||||
grouped_out=False,
|
||||
gates=routing_weights,
|
||||
)
|
||||
|
||||
# ====================================================================
|
||||
# Combine with shared expert and reshape
|
||||
# ====================================================================
|
||||
if shared_expert_output is not None:
|
||||
expert_output = expert_output + shared_expert_output
|
||||
|
||||
expert_output = expert_output.view(batch_size, sequence_length, hidden_dim)
|
||||
return expert_output
|
||||
@@ -0,0 +1,99 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""
|
||||
ParallelExperts module with LoRA support.
|
||||
|
||||
Provides a drop-in replacement for ScatterMoE's ParallelExperts that
|
||||
uses the fused LoRA kernel when adapter weights are attached.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .parallel_linear_lora import parallel_linear_lora
|
||||
|
||||
|
||||
class ParallelExperts(nn.Module):
|
||||
"""
|
||||
Parallel Experts with fused LoRA support.
|
||||
|
||||
Drop-in replacement for the original ParallelExperts. When LoRA parameters
|
||||
are attached via set_lora(), the forward pass uses a fused kernel:
|
||||
Y = X @ W + scaling * (X @ A^T) @ B^T
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.empty(num_experts, output_size))
|
||||
else:
|
||||
self.bias = None
|
||||
self.num_experts = num_experts
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self._lora_A: torch.Tensor | None = None
|
||||
self._lora_B: torch.Tensor | None = None
|
||||
self._lora_scaling: float | None = None
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
nn.init.normal_(self.weight, std=0.02)
|
||||
if self.bias is not None:
|
||||
nn.init.zeros_(self.bias)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return (
|
||||
f"num_experts={self.num_experts}, "
|
||||
f"input_size={self.input_size}, "
|
||||
f"output_size={self.output_size}"
|
||||
)
|
||||
|
||||
def set_lora(self, lora_A: torch.Tensor, lora_B: torch.Tensor, scaling: float):
|
||||
"""Attach LoRA parameters for fused computation."""
|
||||
self._lora_A = lora_A
|
||||
self._lora_B = lora_B
|
||||
self._lora_scaling = scaling
|
||||
|
||||
def clear_lora(self):
|
||||
"""Remove LoRA parameters."""
|
||||
self._lora_A = None
|
||||
self._lora_B = None
|
||||
self._lora_scaling = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: torch.Tensor,
|
||||
k: int,
|
||||
sorted_expert_idxs: torch.Tensor,
|
||||
sorted_scattered_idxs: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
gates: Optional[torch.Tensor] = None,
|
||||
grouped_in: bool = False,
|
||||
grouped_out: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return parallel_linear_lora(
|
||||
inputs,
|
||||
self.weight.permute(0, 2, 1), # [E, input, output]
|
||||
k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
lora_A=self._lora_A,
|
||||
lora_B=self._lora_B,
|
||||
scaling=self._lora_scaling if self._lora_scaling is not None else 1.0,
|
||||
expert_biases=self.bias,
|
||||
gates=gates,
|
||||
grouped_in=grouped_in,
|
||||
grouped_out=grouped_out,
|
||||
)
|
||||
@@ -0,0 +1,253 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Adapted from https://github.com/shawntan/scattermoe
|
||||
# Copyright (c) Shawn Tan and ScatterMoE Contributors
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from . import kernels
|
||||
|
||||
|
||||
@torch.library.custom_op("scattermoe::bincount", mutates_args={})
|
||||
def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
|
||||
return x.bincount(minlength=minlength)
|
||||
|
||||
|
||||
@compileable_bincount.register_fake
|
||||
def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
|
||||
return torch.empty(minlength, dtype=torch.long, device=x.device)
|
||||
|
||||
|
||||
@torch.compile
|
||||
def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int):
|
||||
with torch.no_grad():
|
||||
flattened_expert_idxs = expert_idxs.flatten()
|
||||
sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
|
||||
expert_counts = compileable_bincount(
|
||||
flattened_expert_idxs, minlength=num_experts
|
||||
)
|
||||
expert_offsets = expert_counts.cumsum(-1)
|
||||
return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets
|
||||
|
||||
|
||||
class ParallelLinear(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
x: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
k: int,
|
||||
sorted_expert_idxs: torch.Tensor,
|
||||
sorted_scattered_idxs: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
expert_biases: Optional[torch.Tensor] = None,
|
||||
gates: Optional[torch.Tensor] = None,
|
||||
grouped_in: bool = False,
|
||||
grouped_out: bool = False,
|
||||
):
|
||||
with torch.device(x.device):
|
||||
output = kernels.ops.scatter2scatter(
|
||||
X=x,
|
||||
W=expert_weights,
|
||||
b=expert_biases,
|
||||
k=k,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
x_grouped=grouped_in,
|
||||
y_grouped=grouped_out,
|
||||
)
|
||||
if gates is not None:
|
||||
output_expanded = output.view(
|
||||
gates.size(0), gates.size(1), output.size(-1)
|
||||
)
|
||||
output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
|
||||
else:
|
||||
output_expanded = None
|
||||
|
||||
ctx.save_for_backward(
|
||||
x,
|
||||
expert_weights,
|
||||
expert_biases,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
gates,
|
||||
output_expanded,
|
||||
)
|
||||
ctx.grouped_in = grouped_in
|
||||
ctx.grouped_out = grouped_out
|
||||
ctx.k = k
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out: torch.Tensor):
|
||||
with torch.device(grad_out.device):
|
||||
(
|
||||
x,
|
||||
expert_weights,
|
||||
expert_biases,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
gates,
|
||||
output_expanded,
|
||||
) = ctx.saved_tensors
|
||||
k = ctx.k
|
||||
grouped_in = ctx.grouped_in
|
||||
grouped_out = ctx.grouped_out
|
||||
|
||||
if gates is not None:
|
||||
# calculate gates gradient
|
||||
# d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
|
||||
d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
|
||||
gates_flat = gates.flatten()
|
||||
gate_fan = gates.size(1)
|
||||
grouped_grad_out = output_expanded.flatten(
|
||||
0, 1
|
||||
) # reuse expanded buffer later
|
||||
else:
|
||||
d_gates = None
|
||||
gates_flat = None
|
||||
gate_fan = 1
|
||||
grouped_grad_out = None
|
||||
|
||||
if grouped_out:
|
||||
grouped_grad_out = grad_out
|
||||
else:
|
||||
grouped_grad_out = kernels.ops.group(
|
||||
grad_out,
|
||||
sorted_scattered_idxs,
|
||||
fan_out=gate_fan,
|
||||
coeff=gates_flat,
|
||||
out=grouped_grad_out,
|
||||
)
|
||||
if grouped_in:
|
||||
grouped_x = x
|
||||
d_expanded_input = None
|
||||
else:
|
||||
grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)
|
||||
d_expanded_input = grouped_x
|
||||
|
||||
d_weights, d_biases = kernels.ops.group_bwd_W(
|
||||
DY=grouped_grad_out,
|
||||
X=grouped_x,
|
||||
expert_offsets=expert_offsets,
|
||||
E=expert_weights.size(0),
|
||||
has_bias=expert_biases is not None,
|
||||
)
|
||||
|
||||
d_expanded_input = kernels.ops.scatter2scatter(
|
||||
X=grouped_grad_out,
|
||||
x_grouped=True,
|
||||
W=expert_weights.permute(0, 2, 1),
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=1,
|
||||
y_grouped=grouped_in,
|
||||
out=d_expanded_input, # Reuse grouped_x buffer
|
||||
)
|
||||
|
||||
if k == 1:
|
||||
d_input = d_expanded_input
|
||||
else:
|
||||
d_input = d_expanded_input.view(
|
||||
x.size(0), k, d_expanded_input.size(-1)
|
||||
).sum(-2)
|
||||
return (
|
||||
# x, expert_weights,
|
||||
d_input,
|
||||
d_weights,
|
||||
# k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
# bias, gates
|
||||
d_biases,
|
||||
d_gates,
|
||||
# grouped_in, grouped_out,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def parallel_linear(
|
||||
inputs,
|
||||
expert_weights,
|
||||
k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
expert_biases=None,
|
||||
gates=None,
|
||||
grouped_in=False,
|
||||
grouped_out=False,
|
||||
):
|
||||
results = ParallelLinear.apply(
|
||||
inputs,
|
||||
expert_weights,
|
||||
k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
expert_biases,
|
||||
gates,
|
||||
grouped_in,
|
||||
grouped_out,
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
class ParallelExperts(nn.Module):
|
||||
def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.empty(num_experts, output_size))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.num_experts = num_experts
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.reset_parameters()
|
||||
|
||||
def extra_repr(self):
|
||||
return "num_experts={}, input_size={}, output_size={}".format(
|
||||
self.num_experts, self.input_size, self.output_size
|
||||
)
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
nn.init.normal_(self.weight, std=0.02)
|
||||
if self.bias is not None:
|
||||
nn.init.zeros_(self.bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs,
|
||||
k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
gates=None,
|
||||
grouped_in=False,
|
||||
grouped_out=False,
|
||||
):
|
||||
results = parallel_linear(
|
||||
inputs,
|
||||
self.weight.permute(0, 2, 1),
|
||||
k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
expert_biases=self.bias,
|
||||
gates=gates,
|
||||
grouped_in=grouped_in,
|
||||
grouped_out=grouped_out,
|
||||
)
|
||||
return results
|
||||
@@ -0,0 +1,480 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""
|
||||
ScatterMoE + LoRA Autograd Function
|
||||
====================================
|
||||
|
||||
Provides the autograd function and Python interface for fused ScatterMoE + LoRA.
|
||||
|
||||
Key design for LoRA training:
|
||||
- Expert weights W are FROZEN (no gradient computed for W).
|
||||
- Only LoRA adapter weights (A, B) receive gradients.
|
||||
- The input gradient dX is still computed (needed for upstream layers).
|
||||
- This avoids the expensive group_bwd_W computation entirely.
|
||||
|
||||
Forward:
|
||||
Y = X @ W + scaling * (X @ A^T) @ B^T
|
||||
|
||||
Backward (W frozen):
|
||||
dX = dY @ W^T + scaling * (dY @ B) @ A (via scatter2scatter for base, separate for LoRA)
|
||||
dA = scaling * (dY @ B)^T @ X (per-expert, on grouped data)
|
||||
dB = scaling * dY^T @ (X @ A^T) (per-expert, on grouped data)
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .kernels import ops as base_ops
|
||||
from .kernels.lora_ops import (
|
||||
group_bwd_lora,
|
||||
group_bwd_lora_fused,
|
||||
scatter2scatter_lora,
|
||||
scatter2scatter_lora_dX,
|
||||
)
|
||||
|
||||
|
||||
class ScatterMoELoRA(torch.autograd.Function):
|
||||
"""
|
||||
Autograd function for fused ScatterMoE + LoRA with frozen expert weights.
|
||||
|
||||
This function is optimized for the LoRA fine-tuning scenario where:
|
||||
- Expert weights W are frozen (requires_grad=False)
|
||||
- Only LoRA A and B matrices receive gradients
|
||||
- Input gradients are computed for upstream layer backprop
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
x: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
k: int,
|
||||
sorted_expert_idxs: torch.Tensor,
|
||||
sorted_scattered_idxs: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
lora_A: torch.Tensor,
|
||||
lora_B: torch.Tensor,
|
||||
scaling: float,
|
||||
expert_biases: Optional[torch.Tensor] = None,
|
||||
gates: Optional[torch.Tensor] = None,
|
||||
grouped_in: bool = False,
|
||||
grouped_out: bool = False,
|
||||
use_fused_dX: bool = False,
|
||||
use_fused_gather: bool = False,
|
||||
):
|
||||
with torch.device(x.device):
|
||||
# Fused forward: Y = X @ W + scaling * (X @ A^T) @ B^T
|
||||
output = scatter2scatter_lora(
|
||||
X=x,
|
||||
W=expert_weights,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=k,
|
||||
lora_A=lora_A,
|
||||
lora_B=lora_B,
|
||||
scaling=scaling,
|
||||
b=expert_biases,
|
||||
x_grouped=grouped_in,
|
||||
y_grouped=grouped_out,
|
||||
)
|
||||
|
||||
# Handle gating (weighted combination of top-k expert outputs)
|
||||
if gates is not None:
|
||||
output_expanded = output.view(
|
||||
gates.size(0), gates.size(1), output.size(-1)
|
||||
)
|
||||
output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
|
||||
else:
|
||||
output_expanded = None
|
||||
|
||||
ctx.save_for_backward(
|
||||
x,
|
||||
lora_A,
|
||||
lora_B,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
gates,
|
||||
output_expanded,
|
||||
)
|
||||
# Store frozen weights as plain Python attributes instead of
|
||||
# save_for_backward. This avoids:
|
||||
# 1. Version-check conflicts with FSDP unshard/reshard
|
||||
# 2. Pinning all-gathered parameters via saved_tensors hooks
|
||||
# 3. Interfering with activation offloading pack/unpack hooks
|
||||
# Safe because expert_weights are frozen (requires_grad=False).
|
||||
ctx.expert_weights = expert_weights
|
||||
ctx.expert_biases = expert_biases
|
||||
ctx.grouped_in = grouped_in
|
||||
ctx.grouped_out = grouped_out
|
||||
ctx.k = k
|
||||
ctx.scaling = scaling
|
||||
ctx.use_fused_dX = use_fused_dX
|
||||
ctx.use_fused_gather = use_fused_gather
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out: torch.Tensor):
|
||||
with torch.device(grad_out.device):
|
||||
(
|
||||
x,
|
||||
lora_A,
|
||||
lora_B,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
gates,
|
||||
output_expanded,
|
||||
) = ctx.saved_tensors
|
||||
expert_weights = ctx.expert_weights
|
||||
|
||||
k = ctx.k
|
||||
scaling = ctx.scaling
|
||||
grouped_in = ctx.grouped_in
|
||||
grouped_out = ctx.grouped_out
|
||||
E = expert_weights.size(0)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Gate gradients (if using top-k gating with routing weights)
|
||||
# ------------------------------------------------------------------
|
||||
if gates is not None:
|
||||
# d_gates[t, j] = output_expanded[t, j, :] . grad_out[t, :]
|
||||
d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
|
||||
gates_flat = gates.flatten()
|
||||
gate_fan = gates.size(1)
|
||||
# Reuse output_expanded buffer for grouped_grad_out
|
||||
grouped_grad_out = output_expanded.flatten(0, 1)
|
||||
else:
|
||||
d_gates = None
|
||||
gates_flat = None
|
||||
gate_fan = 1
|
||||
grouped_grad_out = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# LoRA gradients (dA, dB) and setup for dX
|
||||
# ------------------------------------------------------------------
|
||||
# Fused gather uses sorted_scattered_idxs for indirect X access
|
||||
# in the Triton kernel, avoiding the group(x) allocation.
|
||||
#
|
||||
# can_fuse_gather: X is ungrouped and not too large for scatter loads
|
||||
# - When gates is None and grouped_out=False: both DY and X ungrouped
|
||||
# - When grouped_out=True (gate_up_proj): DY already grouped, X ungrouped
|
||||
# -> use dy_grouped=True in the fused kernel
|
||||
M_total = sorted_scattered_idxs.size(0)
|
||||
K_dim = x.size(-1)
|
||||
N_dim = expert_weights.size(-1)
|
||||
fuse_gather_workload = M_total * max(K_dim, N_dim)
|
||||
_FUSE_GATHER_THRESHOLD = 2**24 # ~16M elements
|
||||
|
||||
can_fuse_gather = (
|
||||
ctx.use_fused_gather
|
||||
and not grouped_in # X must be ungrouped for scatter access
|
||||
and gates is None # gate coeff requires multiplicative gather
|
||||
and fuse_gather_workload < _FUSE_GATHER_THRESHOLD
|
||||
)
|
||||
|
||||
if can_fuse_gather:
|
||||
# ------------------------------------------------------------------
|
||||
# Fused path: skip group(x) entirely
|
||||
# ------------------------------------------------------------------
|
||||
d_expanded_input = None
|
||||
|
||||
d_lora_A, d_lora_B = group_bwd_lora_fused(
|
||||
DY=grad_out,
|
||||
X=x,
|
||||
lora_A=lora_A,
|
||||
lora_B=lora_B,
|
||||
expert_offsets=expert_offsets,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
E=E,
|
||||
k=k,
|
||||
scaling=scaling,
|
||||
dy_grouped=grouped_out,
|
||||
)
|
||||
|
||||
# Prepare grouped_grad_out for the dX path (needed by both
|
||||
# the fused dX kernel when grouped_out=True, and the non-fused path)
|
||||
if grouped_out:
|
||||
grouped_grad_out = grad_out
|
||||
elif not ctx.use_fused_dX:
|
||||
grouped_grad_out = base_ops.group(
|
||||
grad_out,
|
||||
sorted_scattered_idxs,
|
||||
fan_out=gate_fan,
|
||||
coeff=gates_flat,
|
||||
out=grouped_grad_out,
|
||||
)
|
||||
else:
|
||||
# ------------------------------------------------------------------
|
||||
# Original path: explicit group() calls
|
||||
# ------------------------------------------------------------------
|
||||
if grouped_out:
|
||||
grouped_grad_out = grad_out
|
||||
else:
|
||||
grouped_grad_out = base_ops.group(
|
||||
grad_out,
|
||||
sorted_scattered_idxs,
|
||||
fan_out=gate_fan,
|
||||
coeff=gates_flat,
|
||||
out=grouped_grad_out,
|
||||
)
|
||||
|
||||
if grouped_in:
|
||||
grouped_x = x
|
||||
d_expanded_input = None
|
||||
else:
|
||||
grouped_x = base_ops.group(x, sorted_scattered_idxs, fan_out=k)
|
||||
d_expanded_input = grouped_x # Will be overwritten; reuse buffer
|
||||
|
||||
d_lora_A, d_lora_B = group_bwd_lora(
|
||||
DY=grouped_grad_out,
|
||||
X=grouped_x,
|
||||
lora_A=lora_A,
|
||||
lora_B=lora_B,
|
||||
expert_offsets=expert_offsets,
|
||||
E=E,
|
||||
scaling=scaling,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Input gradient: dX = dY @ W^T + scaling * (dY @ B) @ A
|
||||
# ------------------------------------------------------------------
|
||||
if ctx.use_fused_dX:
|
||||
if can_fuse_gather and not grouped_out:
|
||||
# Fully fused: read ungrouped DY via scatter pattern
|
||||
d_expanded_input = scatter2scatter_lora_dX(
|
||||
DY=grad_out,
|
||||
W=expert_weights,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=1,
|
||||
lora_A=lora_A,
|
||||
lora_B=lora_B,
|
||||
scaling=scaling,
|
||||
dy_grouped=False,
|
||||
dx_grouped=grouped_in,
|
||||
out=d_expanded_input,
|
||||
)
|
||||
else:
|
||||
# Fused dX only: read from pre-grouped DY
|
||||
d_expanded_input = scatter2scatter_lora_dX(
|
||||
DY=grouped_grad_out,
|
||||
W=expert_weights,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=1,
|
||||
lora_A=lora_A,
|
||||
lora_B=lora_B,
|
||||
scaling=scaling,
|
||||
dy_grouped=True,
|
||||
dx_grouped=grouped_in,
|
||||
out=d_expanded_input,
|
||||
)
|
||||
else:
|
||||
# Original path: separate base scatter2scatter + LoRA Python loop
|
||||
d_expanded_input = base_ops.scatter2scatter(
|
||||
X=grouped_grad_out,
|
||||
x_grouped=True,
|
||||
W=expert_weights.permute(0, 2, 1), # [E, N, K]
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=1,
|
||||
y_grouped=grouped_in,
|
||||
out=d_expanded_input,
|
||||
)
|
||||
|
||||
# LoRA part: dX_lora = scaling * (dY @ B) @ A
|
||||
if scaling != 0.0:
|
||||
d_input_lora_grouped = _compute_lora_input_grad(
|
||||
grouped_grad_out,
|
||||
lora_A,
|
||||
lora_B,
|
||||
expert_offsets,
|
||||
E,
|
||||
scaling,
|
||||
)
|
||||
if grouped_in:
|
||||
d_expanded_input.add_(d_input_lora_grouped)
|
||||
else:
|
||||
# Scatter-add LoRA gradient directly into d_expanded_input.
|
||||
# Avoids allocating a zeros_like + add result
|
||||
d_expanded_input[sorted_scattered_idxs] += d_input_lora_grouped
|
||||
|
||||
# Reduce over top-k if k > 1
|
||||
if k == 1:
|
||||
d_input = d_expanded_input
|
||||
else:
|
||||
d_input = d_expanded_input.view(
|
||||
x.size(0), k, d_expanded_input.size(-1)
|
||||
).sum(-2)
|
||||
|
||||
# W is frozen during LoRA training -- skip weight gradient
|
||||
d_weights = (
|
||||
torch.zeros_like(expert_weights)
|
||||
if expert_weights.requires_grad
|
||||
else None
|
||||
)
|
||||
d_biases = None
|
||||
|
||||
return (
|
||||
d_input,
|
||||
d_weights,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None, # k, sorted indices, offsets
|
||||
d_lora_A,
|
||||
d_lora_B,
|
||||
None, # lora_A, lora_B, scaling
|
||||
d_biases,
|
||||
d_gates,
|
||||
None,
|
||||
None, # grouped_in, grouped_out
|
||||
None, # use_fused_dX
|
||||
None, # use_fused_gather
|
||||
)
|
||||
|
||||
|
||||
def _compute_lora_input_grad(
|
||||
grouped_grad_out: torch.Tensor,
|
||||
lora_A: torch.Tensor,
|
||||
lora_B: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
E: int,
|
||||
scaling: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the LoRA contribution to the input gradient:
|
||||
dX_lora = scaling * (dY @ B) @ A
|
||||
|
||||
Uses PyTorch ops on expert-grouped data.
|
||||
Each expert e: dX_e = scaling * (dY_e @ B_e) @ A_e
|
||||
"""
|
||||
R = lora_A.size(0) // E
|
||||
K = lora_A.size(1)
|
||||
M_total = grouped_grad_out.size(0)
|
||||
|
||||
d_input_lora = torch.zeros(
|
||||
(M_total, K), device=grouped_grad_out.device, dtype=grouped_grad_out.dtype
|
||||
)
|
||||
|
||||
compute_dtype = grouped_grad_out.dtype
|
||||
|
||||
prev_offset = 0
|
||||
for e in range(E):
|
||||
curr_offset = expert_offsets[e].item()
|
||||
if curr_offset > prev_offset:
|
||||
dy_e = grouped_grad_out[prev_offset:curr_offset] # [M_e, N]
|
||||
a_e = lora_A[e * R : (e + 1) * R, :].to(compute_dtype) # [r, K]
|
||||
b_e = lora_B[:, e * R : (e + 1) * R].to(compute_dtype) # [N, r]
|
||||
|
||||
# dX_e = scaling * (dY_e @ B_e) @ A_e
|
||||
dy_b = dy_e @ b_e # [M_e, r]
|
||||
dx_e = scaling * (dy_b @ a_e) # [M_e, K]
|
||||
d_input_lora[prev_offset:curr_offset] = dx_e
|
||||
|
||||
prev_offset = curr_offset
|
||||
|
||||
return d_input_lora
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper: Extract LoRA params from PEFT ParamWrapper
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_lora_params_from_wrapper(module) -> tuple:
|
||||
"""
|
||||
Extract LoRA parameters from a PEFT ParamWrapper.
|
||||
|
||||
Returns:
|
||||
(lora_A, lora_B, scaling) if LoRA is active, else (None, None, None)
|
||||
"""
|
||||
if not hasattr(module, "lora_A") or not hasattr(module, "lora_B"):
|
||||
return None, None, None
|
||||
|
||||
active_adapters = getattr(module, "active_adapters", ["default"])
|
||||
if not active_adapters:
|
||||
return None, None, None
|
||||
|
||||
adapter_name = active_adapters[0]
|
||||
|
||||
lora_A_dict = getattr(module, "lora_A", {})
|
||||
lora_B_dict = getattr(module, "lora_B", {})
|
||||
scaling_dict = getattr(module, "scaling", {})
|
||||
|
||||
if adapter_name not in lora_A_dict:
|
||||
return None, None, None
|
||||
|
||||
lora_A = lora_A_dict[adapter_name].weight
|
||||
lora_B = lora_B_dict[adapter_name].weight
|
||||
scaling = scaling_dict[adapter_name]
|
||||
|
||||
return lora_A, lora_B, scaling
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Drop-in replacement for parallel_linear
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def parallel_linear_lora(
|
||||
inputs: torch.Tensor,
|
||||
expert_weights: torch.Tensor,
|
||||
k: int,
|
||||
sorted_expert_idxs: torch.Tensor,
|
||||
sorted_scattered_idxs: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
lora_A: Optional[torch.Tensor] = None,
|
||||
lora_B: Optional[torch.Tensor] = None,
|
||||
scaling: float = 1.0,
|
||||
expert_biases: Optional[torch.Tensor] = None,
|
||||
gates: Optional[torch.Tensor] = None,
|
||||
grouped_in: bool = False,
|
||||
grouped_out: bool = False,
|
||||
use_fused_dX: bool = False,
|
||||
use_fused_gather: bool = False,
|
||||
):
|
||||
"""
|
||||
Drop-in replacement for parallel_linear that supports LoRA.
|
||||
|
||||
If lora_A and lora_B are provided, uses fused LoRA kernel.
|
||||
Otherwise falls back to standard scatter2scatter.
|
||||
"""
|
||||
if lora_A is not None and lora_B is not None:
|
||||
return ScatterMoELoRA.apply(
|
||||
inputs,
|
||||
expert_weights,
|
||||
k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
lora_A,
|
||||
lora_B,
|
||||
scaling,
|
||||
expert_biases,
|
||||
gates,
|
||||
grouped_in,
|
||||
grouped_out,
|
||||
use_fused_dX,
|
||||
use_fused_gather,
|
||||
)
|
||||
else:
|
||||
from .parallel_experts import ParallelLinear
|
||||
|
||||
return ParallelLinear.apply(
|
||||
inputs,
|
||||
expert_weights,
|
||||
k,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
expert_biases,
|
||||
gates,
|
||||
grouped_in,
|
||||
grouped_out,
|
||||
)
|
||||
@@ -1,5 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
from kernels import (
|
||||
LayerRepository,
|
||||
LocalLayerRepository,
|
||||
Mode,
|
||||
register_kernel_mapping,
|
||||
replace_kernel_forward_from_hub,
|
||||
@@ -19,16 +21,19 @@ class KernelsPlugin(BasePlugin):
|
||||
self._kernelize_model(cfg.model_config_type)
|
||||
|
||||
def _register_kernels(self):
|
||||
plugin_root = Path(__file__).parent
|
||||
register_kernel_mapping(
|
||||
{
|
||||
"HFScatterMoEParallelExperts": {
|
||||
"cuda": {
|
||||
Mode.TRAINING: LayerRepository(
|
||||
repo_id="axolotl-ai-co/scattermoe",
|
||||
Mode.TRAINING: LocalLayerRepository(
|
||||
repo_path=plugin_root / "libs" / "scattermoe_lora",
|
||||
package_name="scattermoe_lora",
|
||||
layer_name="HFScatterMoEGatedMLP",
|
||||
),
|
||||
Mode.INFERENCE: LayerRepository(
|
||||
repo_id="axolotl-ai-co/scattermoe",
|
||||
Mode.INFERENCE: LocalLayerRepository(
|
||||
repo_path=plugin_root / "libs" / "scattermoe_lora",
|
||||
package_name="scattermoe_lora",
|
||||
layer_name="HFScatterMoEGatedMLP",
|
||||
),
|
||||
},
|
||||
|
||||
@@ -329,7 +329,7 @@ class PatchManager:
|
||||
else:
|
||||
has_remote_code = False
|
||||
|
||||
if has_remote_code and self.cfg.trust_remote_code is False:
|
||||
if has_remote_code and self.cfg.trust_remote_code is not None:
|
||||
# If explicitly set in YAML, prefer that
|
||||
has_remote_code = self.cfg.trust_remote_code
|
||||
|
||||
|
||||
@@ -54,15 +54,19 @@ class FileLockLoader:
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up ready flag when last process is done."""
|
||||
with FileLock(str(self.lock_file_path)):
|
||||
counter_content = self.counter_path.read_text().strip()
|
||||
count = int(counter_content) if counter_content else 0
|
||||
count -= 1
|
||||
try:
|
||||
with FileLock(str(self.lock_file_path)):
|
||||
counter_content = self.counter_path.read_text().strip()
|
||||
count = int(counter_content) if counter_content else 0
|
||||
count -= 1
|
||||
|
||||
if count <= 0:
|
||||
# Last process cleans everything up
|
||||
self.ready_flag_path.unlink(missing_ok=True)
|
||||
self.counter_path.unlink(missing_ok=True)
|
||||
else:
|
||||
# Still have active processes
|
||||
self.counter_path.write_text(str(count))
|
||||
if count <= 0:
|
||||
# Last process cleans everything up
|
||||
self.ready_flag_path.unlink(missing_ok=True)
|
||||
self.counter_path.unlink(missing_ok=True)
|
||||
else:
|
||||
# Still have active processes
|
||||
self.counter_path.write_text(str(count))
|
||||
except FileNotFoundError:
|
||||
# Lock file might have already been deleted by another process
|
||||
pass
|
||||
|
||||
323
tests/integrations/test_scattermoe_lora.py
Normal file
323
tests/integrations/test_scattermoe_lora.py
Normal file
@@ -0,0 +1,323 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""
|
||||
Unit tests for scattermoe-lora code-review fixes.
|
||||
|
||||
Tests cover:
|
||||
- KernelsArgs validator: disable_mlp_kernel_scattermoe
|
||||
- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
|
||||
- ParallelExperts: scaling=0.0 not treated as falsy
|
||||
- single2scatter: non-aligned K/N dimensions
|
||||
- group_compileable: coeff=None accepted
|
||||
- HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# ============================================================================
|
||||
# 1. KernelsArgs: disable_mlp_kernel_scattermoe validator
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestKernelsArgsValidator:
|
||||
"""Test that disable_mlp_kernel_scattermoe sets both flags correctly.
|
||||
|
||||
These tests call the validator classmethod directly on raw dicts,
|
||||
since lora_mlp_kernel / mlp_kernel are not declared model fields.
|
||||
"""
|
||||
|
||||
def test_disables_lora_mlp_kernel_when_scattermoe(self):
|
||||
"""lora_mlp_kernel=True gets set to False when use_scattermoe=True."""
|
||||
from axolotl.integrations.kernels.args import KernelsArgs
|
||||
|
||||
data = {
|
||||
"use_kernels": True,
|
||||
"use_scattermoe": True,
|
||||
"lora_mlp_kernel": True,
|
||||
}
|
||||
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
|
||||
assert result["lora_mlp_kernel"] is False
|
||||
assert result["mlp_kernel"] is False
|
||||
|
||||
def test_mlp_kernel_disabled_without_lora(self):
|
||||
"""Even without lora_mlp_kernel, mlp_kernel should be disabled."""
|
||||
from axolotl.integrations.kernels.args import KernelsArgs
|
||||
|
||||
data = {
|
||||
"use_kernels": True,
|
||||
"use_scattermoe": True,
|
||||
}
|
||||
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
|
||||
assert result["mlp_kernel"] is False
|
||||
# lora_mlp_kernel was not in data, should not be added
|
||||
assert "lora_mlp_kernel" not in result
|
||||
|
||||
def test_lora_mlp_kernel_false_unchanged(self):
|
||||
"""lora_mlp_kernel=False should stay False (no warning, no change)."""
|
||||
from axolotl.integrations.kernels.args import KernelsArgs
|
||||
|
||||
data = {
|
||||
"use_kernels": True,
|
||||
"use_scattermoe": True,
|
||||
"lora_mlp_kernel": False,
|
||||
}
|
||||
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
|
||||
assert result["lora_mlp_kernel"] is False
|
||||
|
||||
def test_no_change_when_scattermoe_disabled(self):
|
||||
"""When use_scattermoe is not True, nothing should be changed."""
|
||||
from axolotl.integrations.kernels.args import KernelsArgs
|
||||
|
||||
data = {
|
||||
"use_kernels": True,
|
||||
"use_scattermoe": False,
|
||||
"lora_mlp_kernel": True,
|
||||
}
|
||||
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
|
||||
assert result["lora_mlp_kernel"] is True
|
||||
|
||||
|
||||
class TestParallelExpertsScaling:
|
||||
"""Test that scaling=0.0 is preserved and not overridden to 1.0."""
|
||||
|
||||
def test_scaling_zero_preserved(self):
|
||||
"""scaling=0.0 should be passed as 0.0, not replaced with 1.0."""
|
||||
pytest.importorskip("triton")
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
|
||||
ParallelExperts,
|
||||
)
|
||||
|
||||
pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
|
||||
pe.set_lora(
|
||||
lora_A=torch.randn(4, 4),
|
||||
lora_B=torch.randn(4, 4),
|
||||
scaling=0.0,
|
||||
)
|
||||
assert pe._lora_scaling == 0.0
|
||||
|
||||
# Patch parallel_linear_lora to capture the scaling arg
|
||||
with patch(
|
||||
"axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora"
|
||||
) as mock_pll:
|
||||
mock_pll.return_value = torch.randn(4, 4)
|
||||
# Create dummy routing tensors
|
||||
pe.forward(
|
||||
inputs=torch.randn(2, 4),
|
||||
k=1,
|
||||
sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),
|
||||
sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),
|
||||
expert_offsets=torch.tensor([2, 4]),
|
||||
)
|
||||
# Check that scaling=0.0 was passed, not 1.0
|
||||
call_kwargs = mock_pll.call_args
|
||||
assert (
|
||||
call_kwargs.kwargs.get("scaling") == 0.0
|
||||
or call_kwargs[1].get("scaling") == 0.0
|
||||
), f"Expected scaling=0.0 but got {call_kwargs}"
|
||||
|
||||
def test_scaling_none_defaults_to_one(self):
|
||||
"""scaling=None (no LoRA attached) should default to 1.0."""
|
||||
pytest.importorskip("triton")
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
|
||||
ParallelExperts,
|
||||
)
|
||||
|
||||
pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
|
||||
# No set_lora called, so _lora_scaling is None
|
||||
|
||||
with patch(
|
||||
"axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora"
|
||||
) as mock_pll:
|
||||
mock_pll.return_value = torch.randn(4, 4)
|
||||
pe.forward(
|
||||
inputs=torch.randn(2, 4),
|
||||
k=1,
|
||||
sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),
|
||||
sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),
|
||||
expert_offsets=torch.tensor([2, 4]),
|
||||
)
|
||||
call_kwargs = mock_pll.call_args
|
||||
scaling_val = call_kwargs.kwargs.get("scaling") or call_kwargs[1].get(
|
||||
"scaling"
|
||||
)
|
||||
assert scaling_val == 1.0, (
|
||||
f"Expected scaling=1.0 for None but got {scaling_val}"
|
||||
)
|
||||
|
||||
def test_scaling_positive_preserved(self):
|
||||
"""Normal positive scaling should be preserved."""
|
||||
pytest.importorskip("triton")
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
|
||||
ParallelExperts,
|
||||
)
|
||||
|
||||
pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
|
||||
pe.set_lora(
|
||||
lora_A=torch.randn(4, 4),
|
||||
lora_B=torch.randn(4, 4),
|
||||
scaling=0.5,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora"
|
||||
) as mock_pll:
|
||||
mock_pll.return_value = torch.randn(4, 4)
|
||||
pe.forward(
|
||||
inputs=torch.randn(2, 4),
|
||||
k=1,
|
||||
sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),
|
||||
sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),
|
||||
expert_offsets=torch.tensor([2, 4]),
|
||||
)
|
||||
call_kwargs = mock_pll.call_args
|
||||
scaling_val = call_kwargs.kwargs.get("scaling") or call_kwargs[1].get(
|
||||
"scaling"
|
||||
)
|
||||
assert scaling_val == 0.5
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 4. single2scatter: non-aligned K/N dimensions (GPU only)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
class TestSingle2ScatterBounds:
|
||||
"""Test single2scatter with non-aligned dimensions."""
|
||||
|
||||
def test_non_aligned_k(self):
|
||||
"""K not a multiple of BLOCK_K should produce correct results."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import (
|
||||
single2scatter,
|
||||
)
|
||||
|
||||
E, K, N = 2, 100, 128 # K=100 not a multiple of 128
|
||||
W = torch.randn(E, K, N, device="cuda", dtype=torch.float32)
|
||||
X = torch.randn(1, K, device="cuda", dtype=torch.float32)
|
||||
expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long)
|
||||
|
||||
Y = single2scatter(X, W, expert_idxs)
|
||||
assert Y.shape == (2, N)
|
||||
|
||||
# Verify against manual computation
|
||||
Y_ref_0 = X[0] @ W[0]
|
||||
Y_ref_1 = X[0] @ W[1]
|
||||
torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2)
|
||||
torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2)
|
||||
|
||||
def test_non_aligned_n(self):
|
||||
"""N not a multiple of BLOCK_N should produce correct results."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import (
|
||||
single2scatter,
|
||||
)
|
||||
|
||||
E, K, N = 2, 128, 100 # N=100 not a multiple of 128
|
||||
W = torch.randn(E, K, N, device="cuda", dtype=torch.float32)
|
||||
X = torch.randn(1, K, device="cuda", dtype=torch.float32)
|
||||
expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long)
|
||||
|
||||
Y = single2scatter(X, W, expert_idxs)
|
||||
assert Y.shape == (2, N)
|
||||
|
||||
Y_ref_0 = X[0] @ W[0]
|
||||
Y_ref_1 = X[0] @ W[1]
|
||||
torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2)
|
||||
torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2)
|
||||
|
||||
def test_non_aligned_both(self):
|
||||
"""Both K and N not aligned should produce correct results."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import (
|
||||
single2scatter,
|
||||
)
|
||||
|
||||
E, K, N = 2, 100, 100 # Neither aligned to 128
|
||||
W = torch.randn(E, K, N, device="cuda", dtype=torch.float32)
|
||||
X = torch.randn(1, K, device="cuda", dtype=torch.float32)
|
||||
expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long)
|
||||
|
||||
Y = single2scatter(X, W, expert_idxs)
|
||||
assert Y.shape == (2, N)
|
||||
|
||||
Y_ref_0 = X[0] @ W[0]
|
||||
Y_ref_1 = X[0] @ W[1]
|
||||
torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2)
|
||||
torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 5. group_compileable: coeff=None accepted
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
class TestGroupCoeffNone:
|
||||
"""Test that group() works with coeff=None."""
|
||||
|
||||
def test_group_with_none_coeff(self):
|
||||
"""group() should accept coeff=None without errors."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group
|
||||
|
||||
M, K = 4, 32
|
||||
A = torch.randn(M, K, device="cuda", dtype=torch.float32)
|
||||
sorted_expert_idxs = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long)
|
||||
|
||||
# This should not raise a TypeError
|
||||
Y = group(A, sorted_expert_idxs, coeff=None, fan_out=1)
|
||||
assert Y.shape == (M, K)
|
||||
|
||||
def test_group_with_coeff(self):
|
||||
"""group() should also work with actual coeff values."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group
|
||||
|
||||
M, K = 4, 32
|
||||
A = torch.randn(M, K, device="cuda", dtype=torch.float32)
|
||||
sorted_expert_idxs = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long)
|
||||
coeff = torch.ones(M, device="cuda", dtype=torch.float32) * 0.5
|
||||
|
||||
Y = group(A, sorted_expert_idxs, coeff=coeff, fan_out=1)
|
||||
assert Y.shape == (M, K)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 6. Layer return value contracts
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestLayerReturnValues:
|
||||
"""Test that layer forward methods return the correct types."""
|
||||
|
||||
def test_hf_scatter_moe_returns_single_tensor(self):
|
||||
"""HFScatterMoEGatedMLP.forward should return a single tensor, not a tuple."""
|
||||
pytest.importorskip("triton")
|
||||
# Verify the forward method signature and return annotation
|
||||
import inspect
|
||||
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
HFScatterMoEGatedMLP,
|
||||
)
|
||||
|
||||
sig = inspect.signature(HFScatterMoEGatedMLP.forward)
|
||||
# It's a staticmethod taking (self, layer_input)
|
||||
params = list(sig.parameters.keys())
|
||||
assert "self" in params
|
||||
assert "layer_input" in params
|
||||
|
||||
def test_scatter_moe_gated_mlp_docstring_no_router_logits(self):
|
||||
"""ScatterMoEGatedMLP.forward docstring should not mention router logits as return."""
|
||||
pytest.importorskip("triton")
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
ScatterMoEGatedMLP,
|
||||
)
|
||||
|
||||
docstring = ScatterMoEGatedMLP.forward.__doc__
|
||||
assert docstring is not None
|
||||
# The docstring should mention output tensor but NOT router logits
|
||||
assert "Output tensor" in docstring or "output tensor" in docstring.lower()
|
||||
assert "Router logits" not in docstring, (
|
||||
"Docstring should not mention 'Router logits' in Returns section"
|
||||
)
|
||||
Reference in New Issue
Block a user