chore: lint

Wing Lian
2026-03-19 07:27:23 +00:00
parent 31d8d068bb
commit fec0c3a99e
8 changed files with 443 additions and 191 deletions

View File

@@ -12,14 +12,14 @@ Usage:
import argparse
import gc
import statistics
import time
from functools import partial
import torch
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
lora_ops,
ops as base_ops,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
flatten_sort_count,
@@ -36,7 +36,7 @@ ITERS = 20
# ─── Model configs ──────────────────────────────────────────────────────────
BUILTIN_CONFIGS = {
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
"Qwen3-30B-A3B": (128, 2048, 768, 8),
"OLMoE-1B-7B": (64, 2048, 1024, 8),
"Mixtral-8x7B": (8, 4096, 14336, 2),
@@ -50,26 +50,32 @@ def _resolve_config(spec):
if key in name.lower() or name.lower() in key:
return name, cfg
# Try HuggingFace AutoConfig
from transformers import AutoConfig
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
if callable(getattr(hf_cfg, "get_text_config", None)):
tc = hf_cfg.get_text_config()
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
hf_cfg = tc
hidden = hf_cfg.hidden_size
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
experts = (
getattr(hf_cfg, "num_experts", None)
or getattr(hf_cfg, "num_local_experts", None)
or getattr(hf_cfg, "n_routed_experts", None)
)
top_k = (
getattr(hf_cfg, "num_experts_per_tok", None)
or getattr(hf_cfg, "num_experts_per_token", None)
or 2
)
name = spec.split("/")[-1]
return name, (experts, hidden, inter, top_k)
# ─── Benchmark helpers ──────────────────────────────────────────────────────
def _clean():
gc.collect()
torch.cuda.empty_cache()
@@ -87,29 +93,88 @@ def _bench(fn, warmup=WARMUP, iters=ITERS):
fn()
torch.cuda.synchronize()
times.append((time.perf_counter() - t0) * 1000)
times.sort()
return times[len(times) // 2]
def _setup(num_experts, K, N, T, top_k, R):
torch.manual_seed(42)
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
logits = torch.randn(T, num_experts, device=DEVICE)
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
gx = base_ops.group(x, ssi, fan_out=top_k)
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
return lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
lora_A=lA,
lora_B=lB,
scaling=2.0,
)
def _call_base(x, W, sei, ssi, top_k):
return base_ops.scatter2scatter(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
)
def _call_dx(dy, W, sei, ssi, lA, lB):
return lora_ops.scatter2scatter_lora_dX(
DY=dy,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=1,
lora_A=lA,
lora_B=lB,
scaling=2.0,
dy_grouped=True,
dx_grouped=False,
)
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
return lora_ops.group_bwd_lora(
DY=dy,
X=gx,
lora_A=lA,
lora_B=lB,
expert_offsets=eo,
E=num_experts,
scaling=2.0,
)
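For context on the B023 note above: a lambda defined in a loop (or comprehension) closes over the loop variable itself, not its current value, so every callable ends up seeing the final iteration. A minimal standalone illustration of that pitfall and of binding the value up front, which is the same idea the module-level wrappers plus functools.partial apply here (snippet is illustrative, not part of the benchmark):

from functools import partial

fns = [lambda: i for i in range(3)]
print([f() for f in fns])  # [2, 2, 2] - every lambda sees the last i

fns = [partial(lambda v: v, i) for i in range(3)]  # bind the value at definition time
print([f() for f in fns])  # [0, 1, 2]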
# ─── Main ────────────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
parser.add_argument("--models", "-m", nargs="+",
help="Model names or HF IDs (default: all builtins)")
parser.add_argument(
"--models",
"-m",
nargs="+",
help="Model names or HF IDs (default: all builtins)",
)
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
parser.add_argument("--seq-len", "-T", type=int, default=2048)
args = parser.parse_args()
@@ -122,73 +187,84 @@ def main():
configs = [_resolve_config(m) for m in args.models]
else:
configs = list(BUILTIN_CONFIGS.items())
for model_name, (num_experts, hidden, inter, top_k) in configs:
print(f"{'=' * 70}")
print(f" {model_name}: E={E}, H={H}, I={I}, k={k}")
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
print(f"{'=' * 70}")
for R in args.ranks:
for proj, K, N in [("gate_up", H, 2 * I), ("down", I, H)]:
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
_clean()
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
num_experts, K, N, T, top_k, R
)
# Forward with LoRA (auto-dispatched: fused or split)
dispatch = "split" if (E <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD) else "fused"
t_fwd = _bench(lambda: lora_ops.scatter2scatter_lora(
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
k=k, lora_A=lA, lora_B=lB, scaling=2.0,
))
# Forward without LoRA (base)
t_base = _bench(lambda: base_ops.scatter2scatter(
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, k=k,
))
# Backward dX
t_dx = _bench(lambda: lora_ops.scatter2scatter_lora_dX(
DY=dy, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
k=1, lora_A=lA, lora_B=lB, scaling=2.0,
dy_grouped=True, dx_grouped=False,
))
# Backward dA/dB
t_bwd = _bench(lambda: lora_ops.group_bwd_lora(
DY=dy, X=gx, lora_A=lA, lora_B=lB,
expert_offsets=eo, E=E, scaling=2.0,
))
dispatch = (
"split"
if (
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
)
else "fused"
)
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
total = t_fwd + t_dx + t_bwd
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
print(f" R={R:>2} {proj:<8} "
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
f"base={t_base:>6.2f}ms "
f"(+{overhead*100:.0f}%) "
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
f"total={total:>6.2f}ms")
print(
f" R={R:>2} {proj:<8} "
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
f"base={t_base:>6.2f}ms "
f"(+{overhead * 100:.0f}%) "
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
f"total={total:>6.2f}ms"
)
# Full autograd fwd+bwd with memory measurement
x_ag = x.clone().requires_grad_(True)
lA_ag = lA.clone().requires_grad_(True)
lB_ag = lB.clone().requires_grad_(True)
def _run_autograd(
_x=x_ag,
_W=W,
_k=top_k,
_sei=sei,
_ssi=ssi,
_eo=eo,
_lA=lA_ag,
_lB=lB_ag,
):
out = ScatterMoELoRA.apply(
_x,
_W,
_k,
_sei,
_ssi,
_eo,
_lA,
_lB,
2.0,
None,
None,
False,
False,
True,
False,
)
out.sum().backward()
_x.grad = None
_lA.grad = None
_lB.grad = None
t_full = _bench(_run_autograd)
# Memory measurement
_clean()
torch.cuda.reset_peak_memory_stats()
mem_before = torch.cuda.memory_allocated()
@@ -196,8 +272,10 @@ def main():
torch.cuda.synchronize()
mem_peak = torch.cuda.max_memory_allocated() - mem_before
print(f" full_fwd_bwd={t_full:>6.2f}ms "
f"peak_delta={mem_peak/1e6:>6.1f}MB")
print(
f" full_fwd_bwd={t_full:>6.2f}ms "
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
)
print()

View File

@@ -566,30 +566,41 @@ def _scatter2scatter_lora_split(
# 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes)
output = scatter2scatter(
X=X,
W=W,
b=b,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
x_grouped=x_grouped,
y_grouped=y_grouped,
out=out,
)
# 2. XA = X @ A^T (tiny: output is [M*k, R])
# Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter)
W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous()
XA = scatter2scatter(
X=X, W=W_A,
X=X,
W=W_A,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
x_grouped=x_grouped,
y_grouped=True,
)
# 3. Y_lora = XA @ B^T (R is tiny, so this is very fast)
# Reshape B: [N, R*E] → [E, R, N]
W_B = lora_B.T.reshape(E, R, N).contiguous()
Y_lora = scatter2scatter(
X=XA,
W=W_B,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
x_grouped=True,
y_grouped=y_grouped,
)
# 4. Y = Y_base + scaling * Y_lora
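For intuition, steps 1-4 above amount to the following per-expert computation. A rough dense PyTorch sketch under the same weight layouts used here (lora_A stored as [R*E, K], lora_B as [N, R*E]); the helper name ref_split_lora is hypothetical, and the sketch ignores the scatter/gather bookkeeping the kernels handle:

import torch

def ref_split_lora(x, W, lora_A, lora_B, expert_idx, scaling):
    # x: [M, K] rows already routed to expert `expert_idx`; W: [E, K, N]
    E, K, N = W.shape
    R = lora_A.shape[0] // E
    e = expert_idx
    A_e = lora_A[e * R : (e + 1) * R]      # [R, K]
    B_e = lora_B[:, e * R : (e + 1) * R]   # [N, R]
    y_base = x @ W[e]                      # 1. base path: X @ W
    xa = x @ A_e.T                         # 2. XA = X @ A^T, shape [M, R]
    y_lora = xa @ B_e.T                    # 3. Y_lora = XA @ B^T
    return y_base + scaling * y_lora       # 4. Y = Y_base + scaling * Y_lora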
@@ -650,13 +661,20 @@ def scatter2scatter_lora(
N = W.size(2)
# Dispatch: split for few large experts, fused for many small experts
if E <= _SPLIT_LORA_FWD_MAX_EXPERTS and K * N >= _SPLIT_LORA_FWD_THRESHOLD:
return _scatter2scatter_lora_split(
X,
W,
sorted_expert_idxs,
sorted_scattered_idxs,
k,
lora_A,
lora_B,
scaling,
b,
x_grouped,
y_grouped,
out,
)
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
@@ -1443,7 +1461,6 @@ def _prune_split_configs(configs, named_args, **kwargs):
"""Prune split kernel configs based on SMEM capacity."""
smem_cap = _get_smem_capacity()
block_r = named_args.get("BLOCK_R", 64)
# Fixed inner tile for reduction dimension
BLOCK_INNER = 64
@@ -1470,33 +1487,47 @@ def _prune_split_configs(configs, named_args, **kwargs):
key=["M", "K", "N"],
prune_configs_by={"early_config_prune": _prune_split_configs},
)
@triton.heuristics(
{
"NO_DIM_MASK": lambda args: (
(args["K"] % args["BLOCK_DIM"]) == 0
if args["COMPUTE_DA"]
else (args["N"] % args["BLOCK_DIM"]) == 0
),
}
)
@triton.jit
def _group_bwd_lora_split(
# Data tensors (DY and X are always present)
DY_ptr,
stride_dym,
stride_dyn,
X_ptr,
stride_xm,
stride_xk,
# LoRA weight for the inner reduction (B for dA, A for dB)
LW_ptr,
stride_lw0,
stride_lw1,
# Output gradient tensor (dA or dB)
OUT_ptr,
stride_out0,
stride_out1,
# Expert offsets
expert_offsets_ptr,
# Dimensions
M,
K: tl.constexpr,
N: tl.constexpr,
ACTUAL_R: tl.constexpr,
BLOCK_R: tl.constexpr,
INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB)
scaling,
# Mode flag
COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB
# Tile sizes
BLOCK_M: tl.constexpr,
BLOCK_DIM: tl.constexpr,
ACC_TYPE: tl.constexpr,
allow_tf32: tl.constexpr,
NO_DIM_MASK: tl.constexpr,
@@ -1532,9 +1563,9 @@ def _group_bwd_lora_split(
# Output dimension tile (K for dA, N for dB)
if COMPUTE_DA:
OUT_DIM: tl.constexpr = K # type: ignore[no-redef]
else:
OUT_DIM: tl.constexpr = N # type: ignore[no-redef]
dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
dim_mask = dim_block < OUT_DIM
R_block = tl.arange(0, BLOCK_R)
@@ -1577,7 +1608,8 @@ def _group_bwd_lora_split(
# Load X[M, K_block] (the "outer" tensor for dA)
outer = tl.load(
X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk,
mask=M_mask[:, None] & dim_mask[None, :], other=0.0
mask=M_mask[:, None] & dim_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# Reduce DY[M, :] @ B[e][:, R] over N → [M, R]
@@ -1588,23 +1620,34 @@ def _group_bwd_lora_split(
inn_mask = inn_off < N
dy_tile = tl.load(
DY_ptr
+ M_idx[:, None] * stride_dym
+ inn_off[None, :] * stride_dyn,
mask=M_mask[:, None] & inn_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride
lw_tile = tl.load(
LW_ptr
+ inn_off[:, None] * stride_lw0
+ (lora_offset + R_block)[None, :] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32)
# dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block]
acc += tl.dot(
tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32
)
else:
# Load DY[M, N_block] (the "outer" tensor for dB)
outer = tl.load(
DY_ptr
+ M_idx[:, None] * stride_dym
+ dim_block[None, :] * stride_dyn,
mask=M_mask[:, None] & dim_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# Reduce X[M, :] @ A[e][:, :].T over K → [M, R]
@@ -1615,27 +1658,45 @@ def _group_bwd_lora_split(
inn_mask = inn_off < K
x_tile = tl.load(
X_ptr
+ M_idx[:, None] * stride_xm
+ inn_off[None, :] * stride_xk,
mask=M_mask[:, None] & inn_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride
# We want A[e]^T: [K, R], so load as [K_inner, R]
lw_tile = tl.load(
LW_ptr
+ (lora_offset + R_block)[None, :] * stride_lw0
+ inn_off[:, None] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32)
# dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R]
acc += tl.dot(
tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32
)
tl.store(
out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask
)
else:
# Zero out this expert's slice — needed because output uses empty_like
if COMPUTE_DA:
tl.store(
out_blk_ptrs,
tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty),
mask=out_mask,
)
else:
tl.store(
out_blk_ptrs,
tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty),
mask=out_mask,
)
def group_bwd_lora(
@@ -1683,34 +1744,58 @@ def group_bwd_lora(
return (E, triton.cdiv(K, META["BLOCK_DIM"]))
_group_bwd_lora_split[grid_dA](
DY,
DY.stride(0),
DY.stride(1),
X,
X.stride(0),
X.stride(1),
lora_B,
lora_B.stride(0),
lora_B.stride(1),
dA,
dA.stride(0),
dA.stride(1),
expert_offsets,
M=DY.size(0),
K=K,
N=N,
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
INNER_DIM=N,
scaling=scaling,
COMPUTE_DA=True,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)
def grid_dB(META):
return (E, triton.cdiv(N, META["BLOCK_DIM"]))
_group_bwd_lora_split[grid_dB](
DY,
DY.stride(0),
DY.stride(1),
X,
X.stride(0),
X.stride(1),
lora_A,
lora_A.stride(0),
lora_A.stride(1),
dB,
dB.stride(0),
dB.stride(1),
expert_offsets,
M=DY.size(0),
K=K,
N=N,
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
INNER_DIM=K,
scaling=scaling,
COMPUTE_DA=False,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)
return dA, dB
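The comments in the kernel above give the per-expert math: dA[e] = scaling * (DY_e @ B_e)^T @ X_e and dB[e] = scaling * DY_e^T @ (X_e @ A_e^T). A rough dense equivalent is sketched below, assuming expert_offsets[e] holds the cumulative end row of expert e's grouped block (consistent with the zero-fill for empty experts above); ref_group_bwd_lora is a hypothetical name, not the library API:

import torch

def ref_group_bwd_lora(DY, X, lora_A, lora_B, expert_offsets, E, scaling):
    # DY, X: rows grouped contiguously per expert; lora_A: [R*E, K]; lora_B: [N, R*E]
    R = lora_A.shape[0] // E
    dA, dB = torch.zeros_like(lora_A), torch.zeros_like(lora_B)
    start = 0
    for e in range(E):
        end = int(expert_offsets[e])
        dy_e, x_e = DY[start:end].float(), X[start:end].float()
        A_e = lora_A[e * R : (e + 1) * R].float()     # [R, K]
        B_e = lora_B[:, e * R : (e + 1) * R].float()  # [N, R]
        dA[e * R : (e + 1) * R] = (scaling * (dy_e @ B_e).T @ x_e).to(dA.dtype)
        dB[:, e * R : (e + 1) * R] = (scaling * dy_e.T @ (x_e @ A_e.T)).to(dB.dtype)
        start = end
    return dA, dB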

View File

@@ -511,20 +511,26 @@ class HFScatterMoEGatedMLP(nn.Module):
active_experts = get_active_experts(sorted_expert_idxs, num_experts)
remapped_expert_idxs, compact_offsets = remap_expert_indices(
sorted_expert_idxs,
expert_offsets,
active_experts,
num_experts,
)
num_active = len(active_experts)
# Dequantize only active experts' weights
gate_up_W = selective_expert_weights(
experts, "gate_up_proj", active_experts,
experts,
"gate_up_proj",
active_experts,
).transpose(2, 1) # [num_active, hidden, 2*inter]
# Remap LoRA weights to match compact expert indices
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup_A, gup_B = selective_lora_weights(
gup_A,
gup_B,
active_experts,
num_experts,
)
gup_lora = (gup_A, gup_B, gup_scaling)
@@ -576,13 +582,18 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
if use_selective:
down_W = selective_expert_weights(
experts, "down_proj", active_experts,
experts,
"down_proj",
active_experts,
).transpose(2, 1) # [num_active, inter, hidden]
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
down_A, down_B = selective_lora_weights(
down_A,
down_B,
active_experts,
num_experts,
)
down_lora = (down_A, down_B, down_scaling)

View File

@@ -21,8 +21,6 @@ from global (0..E-1) to compact (0..num_active-1) and pass the smaller
weight tensor.
"""
import torch
import torch.nn as nn
@@ -79,7 +77,7 @@ def _selective_dequant_bnb4(
raw_param: torch.Tensor,
quant_state,
active_experts: torch.Tensor,
expert_shape: tuple[int, int],
) -> torch.Tensor:
"""Dequantize only selected experts from BnB 4-bit packed data.
@@ -231,7 +229,9 @@ def selective_expert_weights(
if E_total is None:
E_total = int(active_experts.max().item()) + 1
expert_numel = orig_shape[0] // E_total
d2 = getattr(experts_module, "hidden_dim", None) or getattr(
experts_module, "intermediate_dim", None
)
if d2 and expert_numel % d2 == 0:
expert_shape = (expert_numel // d2, d2)
else:
@@ -241,9 +241,7 @@ def selective_expert_weights(
full = getattr(experts_module, param_name)
return full[active_experts]
return _selective_dequant_bnb4(raw_param, qs, active_experts, expert_shape)
# Dense parameter (bf16/fp32) — direct indexing
param = getattr(experts_module, param_name)
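As described in the module docstring, the selective path dequantizes only the active experts and remaps expert ids from global (0..E-1) to compact (0..num_active-1) so the kernels can index into the smaller weight tensor. A minimal sketch of that remapping idea (illustrative only; remap_to_compact is a hypothetical helper, and the real remap_expert_indices also produces compact offsets):

import torch

def remap_to_compact(sorted_expert_idxs, active_experts, num_experts):
    # Lookup table: global expert id -> position within active_experts
    lut = torch.full((num_experts,), -1, dtype=torch.long,
                     device=sorted_expert_idxs.device)
    lut[active_experts] = torch.arange(len(active_experts), device=lut.device)
    return lut[sorted_expert_idxs]  # compact ids; inactive experts never appear here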

View File

@@ -19,15 +19,25 @@ import torch
import triton
import triton.language as tl
# NF4 codebook (16 values, precomputed by BnB)
# These are the normalized float4 reconstruction values
NF4_CODEBOOK = [
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
]
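The Triton kernel below reconstructs weights from these 16 codebook values, per-block absmax scales, and the BnB nibble packing (high nibble = even element, low nibble = odd element, per the comment further down). A rough non-selective reference of that dequantization, assuming a flat uint8 buffer and one absmax per blocksize elements; ref_dequant_nf4 is a hypothetical helper, not this module's API:

import torch

def ref_dequant_nf4(packed, absmax, blocksize=64, dtype=torch.bfloat16):
    codebook = torch.tensor(NF4_CODEBOOK, device=packed.device)
    hi = (packed.to(torch.int64) >> 4) & 0xF            # even elements
    lo = packed.to(torch.int64) & 0xF                   # odd elements
    codes = torch.stack([hi, lo], dim=-1).reshape(-1)   # interleave back to original order
    vals = codebook[codes]                              # [2 * packed.numel()] floats
    vals = vals.reshape(-1, blocksize) * absmax.float().reshape(-1, 1)
    return vals.reshape(-1).to(dtype)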
@@ -46,8 +56,8 @@ def _selective_dequant_nf4_kernel(
stride_out_e, # stride for expert dim in output
# Dimensions
num_active,
packed_per_expert, # expert_numel // 2
blocks_per_expert, # expert_numel // blocksize
blocksize: tl.constexpr,
# Tile size
BLOCK_SIZE: tl.constexpr, # elements per thread block (must be multiple of 2)
@@ -79,7 +89,9 @@ def _selective_dequant_nf4_kernel(
# Read packed bytes from the global expert's region
packed_global_offset = expert_global * packed_per_expert + byte_idx
packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(
tl.int32
)
# Extract 4-bit nibble
# BnB packing: high nibble = even element, low nibble = odd element
@@ -133,8 +145,9 @@ def selective_dequant_nf4_triton(
# Prepare codebook on device
if codebook is None:
codebook = torch.tensor(
NF4_CODEBOOK, dtype=torch.float32, device=packed_data.device
)
else:
codebook = codebook.to(device=packed_data.device, dtype=torch.float32)
@@ -143,8 +156,7 @@ def selective_dequant_nf4_triton(
absmax_flat = absmax.reshape(-1).float() # absmax is usually fp32
# Output buffer
out = torch.empty(num_active, expert_numel, dtype=dtype, device=packed_data.device)
BLOCK_SIZE = 1024 # Process 1024 elements per thread block

View File

@@ -66,7 +66,10 @@ class KernelsPlugin(BasePlugin):
# Prefer text backbone type for VLMs, but fall back to base type
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
if (
moe_model_type not in SPARSE_MOE_BLOCK
and cfg.model_config_type in SPARSE_MOE_BLOCK
):
moe_model_type = cfg.model_config_type
if cfg.use_scattermoe:

View File

@@ -28,9 +28,7 @@ class PytorchProfilerCallback(TrainerCallback):
if profiler_steps_start == 0:
# start recording memory allocations before everything is allocated, because if we start
# at the beginning of step 0, we won't have any memory allocations in the traces
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
profiler_steps_start = -1
self.profiler_steps_start = profiler_steps_start
self._profiler = None
@@ -43,13 +41,11 @@ class PytorchProfilerCallback(TrainerCallback):
**kwargs,
):
if state.global_step == self.profiler_steps_start:
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
# Start torch.profiler on the first profiled step
if state.global_step == max(self.profiler_steps_start, 0):
profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
@@ -58,7 +54,8 @@ class PytorchProfilerCallback(TrainerCallback):
profile_memory=True,
with_stack=True,
)
profiler.__enter__()
self._profiler = profiler
def on_step_end(
self,

View File

@@ -19,8 +19,8 @@ import pytest
import torch
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
lora_ops,
ops as base_ops,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
flatten_sort_count,
@@ -151,8 +151,14 @@ class TestScatter2ScatterLoRAForward:
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
kernel_out = lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=k,
lora_A=lA,
lora_B=lB,
scaling=SCALING,
)
ref_out = _reference_fwd(x, W, sei, ssi, eo, k, lA, lB, SCALING, E)
@@ -164,8 +170,14 @@ class TestScatter2ScatterLoRAForward:
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
out = lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=k,
lora_A=lA,
lora_B=lB,
scaling=SCALING,
)
assert out.shape == (T * k, N)
assert out.dtype == DTYPE
@@ -188,9 +200,16 @@ class TestScatter2ScatterLoRADX:
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
kernel_dx = lora_ops.scatter2scatter_lora_dX(
DY=dy,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=1,
lora_A=lA,
lora_B=lB,
scaling=SCALING,
dy_grouped=True,
dx_grouped=False,
)
ref_dx = _reference_dX(dy, W, sei, ssi, eo, lA, lB, SCALING, E)
@@ -215,8 +234,13 @@ class TestGroupBwdLoRA:
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
kern_dA, kern_dB = lora_ops.group_bwd_lora(
DY=dy,
X=gx,
lora_A=lA,
lora_B=lB,
expert_offsets=eo,
E=E,
scaling=SCALING,
)
ref_dA, ref_dB = _reference_bwd_lora(dy, gx, lA, lB, eo, E, SCALING)
@@ -225,12 +249,10 @@ class TestGroupBwdLoRA:
# fp32 loop), so max absolute error can be large on individual elements
# while the overall tensor is correct.
dA_norm_err = (
(kern_dA.float() - ref_dA.float()).norm() / (ref_dA.float().norm() + 1e-6)
).item()
dB_norm_err = (
(kern_dB.float() - ref_dB.float()).norm() / (ref_dB.float().norm() + 1e-6)
).item()
assert dA_norm_err < 0.01, f"[{desc}] dA norm_rel_err={dA_norm_err}"
assert dB_norm_err < 0.01, f"[{desc}] dB norm_rel_err={dB_norm_err}"
@@ -249,14 +271,21 @@ class TestGroupBwdLoRA:
lB = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE)
dA, dB = lora_ops.group_bwd_lora(
DY=dy,
X=gx,
lora_A=lA,
lora_B=lB,
expert_offsets=eo,
E=E,
scaling=2.0,
)
# Experts 1..7 should have zero gradients
for e in range(1, E):
assert dA[e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dA not zero"
assert dB[:, e * R : (e + 1) * R].abs().max() == 0, (
f"Expert {e} dB not zero"
)
# ─── Full autograd tests ────────────────────────────────────────────────────
@@ -278,9 +307,21 @@ class TestScatterMoELoRAAutograd:
lB = lB.requires_grad_(True)
out = ScatterMoELoRA.apply(
x,
W,
k,
sei,
ssi,
eo,
lA,
lB,
SCALING,
None,
None,
False,
False,
True,
False,
)
out.sum().backward()
@@ -293,7 +334,6 @@ class TestScatterMoELoRAAutograd:
assert x.grad.abs().sum() > 0, f"[{desc}] x.grad all zero"
assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero"
def test_split_matches_fused(self):
"""Split dispatch (for few large experts) matches fused kernel."""
# Use a shape where split would be dispatched (large K*N, few E)
@@ -304,15 +344,27 @@ class TestScatterMoELoRAAutograd:
orig = lora_ops._SPLIT_LORA_FWD_THRESHOLD
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 10**18
out_fused = lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=k,
lora_A=lA,
lora_B=lB,
scaling=SCALING,
)
# Force split path
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 0
out_split = lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=k,
lora_A=lA,
lora_B=lB,
scaling=SCALING,
)
lora_ops._SPLIT_LORA_FWD_THRESHOLD = orig
@@ -328,12 +380,28 @@ class TestScatterMoELoRAAutograd:
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
out_lora = ScatterMoELoRA.apply(
x,
W,
k,
sei,
ssi,
eo,
lA,
lB,
0.0,
None,
None,
False,
False,
True,
False,
)
out_base = base_ops.scatter2scatter(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=k,
)
err = (out_lora.float() - out_base.float()).abs().max().item()
assert err < 0.01, f"scaling=0 should match base: err={err}"