Scattermoe LoRA optimizations (#3513)

* optimize moe + lora

* more scattermoe optims

* selective dequant

* add correctness unit tests and benchmarks for scattermoe + lora

* handle base+lora split kernel for older moe models

* chore: lint

* fix casting for H200 and B200

* register pressure estimation and pruning for h200/b200

* use soft limit for pruning

* qkv patch for qwen3.5moe

* support text_model for qwen3.5 moe

* nesting of qwen3

* use udpated cce with zero3 support

* Fix decomposed backward for QKV and O projections

eliminates B @ A materialization in LoRA attention backward, replacing full [out, in] matmuls with two small [T, R] matmuls.
This commit is contained in:
Wing Lian
2026-03-19 23:07:42 -04:00
committed by GitHub
parent bb483ad4c4
commit 1fc86d5295
16 changed files with 1901 additions and 132 deletions

View File

@@ -0,0 +1,284 @@
"""Benchmark for ScatterMoE LoRA Triton kernels.
Measures forward, backward dX, and backward dA/dB kernels at common MoE
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
and full fwd+bwd autograd throughput.
Usage:
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
"""
import argparse
import gc
import time
from functools import partial
import torch
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
lora_ops,
ops as base_ops,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
flatten_sort_count,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
ScatterMoELoRA,
)
DEVICE = "cuda"
DTYPE = torch.bfloat16
WARMUP = 5
ITERS = 20
# ─── Model configs ──────────────────────────────────────────────────────────
BUILTIN_CONFIGS = {
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
"Qwen3-30B-A3B": (128, 2048, 768, 8),
"OLMoE-1B-7B": (64, 2048, 1024, 8),
"Mixtral-8x7B": (8, 4096, 14336, 2),
}
def _resolve_config(spec):
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
key = spec.lower().replace("/", "-")
for name, cfg in BUILTIN_CONFIGS.items():
if key in name.lower() or name.lower() in key:
return name, cfg
from transformers import AutoConfig
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
if callable(getattr(hf_cfg, "get_text_config", None)):
tc = hf_cfg.get_text_config()
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
hf_cfg = tc
hidden = hf_cfg.hidden_size
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
experts = (
getattr(hf_cfg, "num_experts", None)
or getattr(hf_cfg, "num_local_experts", None)
or getattr(hf_cfg, "n_routed_experts", None)
)
top_k = (
getattr(hf_cfg, "num_experts_per_tok", None)
or getattr(hf_cfg, "num_experts_per_token", None)
or 2
)
name = spec.split("/")[-1]
return name, (experts, hidden, inter, top_k)
# ─── Benchmark helpers ──────────────────────────────────────────────────────
def _clean():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
def _bench(fn, warmup=WARMUP, iters=ITERS):
for _ in range(warmup):
fn()
torch.cuda.synchronize()
times = []
for _ in range(iters):
torch.cuda.synchronize()
t0 = time.perf_counter()
fn()
torch.cuda.synchronize()
times.append((time.perf_counter() - t0) * 1000)
times.sort()
return times[len(times) // 2]
def _setup(num_experts, K, N, T, top_k, R):
torch.manual_seed(42)
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
logits = torch.randn(T, num_experts, device=DEVICE)
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
gx = base_ops.group(x, ssi, fan_out=top_k)
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
return lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
lora_A=lA,
lora_B=lB,
scaling=2.0,
)
def _call_base(x, W, sei, ssi, top_k):
return base_ops.scatter2scatter(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
)
def _call_dx(dy, W, sei, ssi, lA, lB):
return lora_ops.scatter2scatter_lora_dX(
DY=dy,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=1,
lora_A=lA,
lora_B=lB,
scaling=2.0,
dy_grouped=True,
dx_grouped=False,
)
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
return lora_ops.group_bwd_lora(
DY=dy,
X=gx,
lora_A=lA,
lora_B=lB,
expert_offsets=eo,
E=num_experts,
scaling=2.0,
)
# ─── Main ────────────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
parser.add_argument(
"--models",
"-m",
nargs="+",
help="Model names or HF IDs (default: all builtins)",
)
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
parser.add_argument("--seq-len", "-T", type=int, default=2048)
args = parser.parse_args()
T = args.seq_len
print(f"GPU: {torch.cuda.get_device_name()}")
print(f"T={T}, ranks={args.ranks}\n")
if args.models:
configs = [_resolve_config(m) for m in args.models]
else:
configs = list(BUILTIN_CONFIGS.items())
for model_name, (num_experts, hidden, inter, top_k) in configs:
print(f"{'=' * 70}")
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
print(f"{'=' * 70}")
for R in args.ranks:
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
_clean()
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
num_experts, K, N, T, top_k, R
)
# Forward with LoRA (auto-dispatched: fused or split)
dispatch = (
"split"
if (
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
)
else "fused"
)
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
total = t_fwd + t_dx + t_bwd
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
print(
f" R={R:>2} {proj:<8} "
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
f"base={t_base:>6.2f}ms "
f"(+{overhead * 100:.0f}%) "
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
f"total={total:>6.2f}ms"
)
# Full autograd fwd+bwd with memory measurement
x_ag = x.clone().requires_grad_(True)
lA_ag = lA.clone().requires_grad_(True)
lB_ag = lB.clone().requires_grad_(True)
def _run_autograd(
_x=x_ag,
_W=W,
_k=top_k,
_sei=sei,
_ssi=ssi,
_eo=eo,
_lA=lA_ag,
_lB=lB_ag,
):
out = ScatterMoELoRA.apply(
_x,
_W,
_k,
_sei,
_ssi,
_eo,
_lA,
_lB,
2.0,
None,
None,
False,
False,
True,
False,
)
out.sum().backward()
_x.grad = None
_lA.grad = None
_lB.grad = None
t_full = _bench(_run_autograd)
_clean()
torch.cuda.reset_peak_memory_stats()
mem_before = torch.cuda.memory_allocated()
_run_autograd()
torch.cuda.synchronize()
mem_peak = torch.cuda.max_memory_allocated() - mem_before
print(
f" full_fwd_bwd={t_full:>6.2f}ms "
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
)
print()
if __name__ == "__main__":
main()

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
]
},
{

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
)

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
```
## Usage

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
)

View File

@@ -195,6 +195,36 @@ def _estimate_smem_usage(
_SMEM_SLACK = 10_000
def _estimate_register_pressure(
num_warps: int,
*tile_sizes: tuple[int, int],
) -> float:
"""Rough estimate of per-thread register footprint from live tile sizes.
This is a heuristic, NOT an accurate register count. Triton uses tensor
core MMA fragments that pack multiple elements per register, and can spill
to local memory when the hardware limit (255 regs/thread) is exceeded.
The estimate is used to prune only truly extreme configs that would cause
excessive spilling or compilation failures. The threshold is set high
(``_MAX_REGS_SOFT_LIMIT``) because the heuristic overestimates — it
doesn't account for MMA fragment packing. Configs like M=64,N=64,K=64
(est ~520) work fine in practice via spilling.
Returns estimated registers per thread.
"""
# Each thread in a warp holds ~1/32 of the tile elements
tile_regs = sum(r * c for r, c in tile_sizes) / 32
scalar_overhead = 40
return tile_regs + scalar_overhead
# Soft limit for register pressure pruning. Only prune configs with extreme
# tile products (e.g. M=128,K=256,N=256) that reliably crash on Blackwell.
# Moderate configs (M=64,N=64,K=64, est ~520) work via register spilling.
_MAX_REGS_SOFT_LIMIT = 1024
# =============================================================================
# Forward Kernel: scatter2scatter with fused LoRA
# =============================================================================
@@ -313,12 +343,11 @@ def _compute_expert_block_lora(
B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0
) # [BLOCK_N, BLOCK_R]
# Cast xa_acc and b to same dtype for tl.dot (required when input is bf16/fp16)
# Both operands must match; cast to float32 (accumulator type) for precision.
b_f32 = b.to(tl.float32)
# tl.dot requires non-float32 inputs (tensor cores); cast back to input dtype
b_inp = b.to(INPUT_DTYPE)
# (X @ A^T) @ B^T: [M, R] @ [R, N] -> [M, N]
lora_out = tl.dot(xa_acc, tl.trans(b_f32), allow_tf32=allow_tf32)
lora_out = tl.dot(xa_acc.to(INPUT_DTYPE), tl.trans(b_inp), allow_tf32=allow_tf32)
acc += scaling * lora_out
return acc
@@ -327,20 +356,21 @@ def _compute_expert_block_lora(
def _scatter2scatter_lora_configs():
"""Generate forward kernel autotune configs.
Search space includes smaller tile sizes and fewer pipeline stages to
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
Search space includes BLOCK_M to allow trading token-tile size for
larger BLOCK_K/BLOCK_N tiles. On GPUs with ~99KB SMEM, BLOCK_M=128
forces BLOCK_K=32 and BLOCK_N=32; BLOCK_M=64 allows BLOCK_K=128
(4× fewer inner-loop iterations).
Search space:
BLOCK_M: {32, 64, 128}
BLOCK_N: {32, 64, 128, 256}
BLOCK_K: {32, 64, 128}
num_warps: {4, 8}
num_stages: {3, 4, 5}
BLOCK_M is fixed at 128 (module-level constant, not autotuned in the
scatter2scatter pattern).
"""
configs = []
for block_n, block_k, warps, stages in product(
for block_m, block_n, block_k, warps, stages in product(
[32, 64, 128], # BLOCK_M
[32, 64, 128, 256], # BLOCK_N
[32, 64, 128], # BLOCK_K
[4, 8], # num_warps
@@ -348,7 +378,7 @@ def _scatter2scatter_lora_configs():
):
configs.append(
triton.Config(
{"BLOCK_N": block_n, "BLOCK_K": block_k},
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
num_stages=stages,
num_warps=warps,
)
@@ -357,7 +387,7 @@ def _scatter2scatter_lora_configs():
def _prune_fwd_configs(configs, named_args, **kwargs):
"""Prune forward configs based on SMEM capacity.
"""Prune forward configs based on SMEM capacity and register pressure.
The forward kernel inner loop loads three tiles per pipeline stage:
X[BLOCK_M, BLOCK_K], W[BLOCK_K, BLOCK_N], A[BLOCK_R, BLOCK_K].
@@ -373,23 +403,49 @@ def _prune_fwd_configs(configs, named_args, **kwargs):
scored = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_n = config.kwargs["BLOCK_N"]
block_k = config.kwargs["BLOCK_K"]
# Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_n, block_k)
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k)
# A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop
smem_lora_loop = config.num_stages * block_r * block_k * 2
# B tile [BLOCK_N, BLOCK_R] loaded once in epilogue
smem_lora_epilogue = block_n * block_r * 2
smem = smem_base + smem_lora_loop + smem_lora_epilogue
# Register pressure: live tiles are acc[M,N], xa_acc[M,R],
# x[M,K], w[K,N], a[R,K], plus epilogue b[N,R]
est_regs = _estimate_register_pressure(
config.num_warps,
(block_m, block_n), # acc
(block_m, block_r), # xa_acc
(block_m, block_k), # x tile
(block_k, block_n), # w tile
(block_r, block_k), # a tile
(block_n, block_r), # b tile (epilogue)
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
scored.append((smem, config))
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
if pruned:
return pruned
# All configs exceed SMEM — return the one with smallest estimated usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
if scored:
# All surviving configs exceed SMEM — return the one with smallest usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
# All configs pruned by register pressure — fall back to smallest tiles
return [
min(
configs,
key=lambda c: (
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_N"] * c.kwargs["BLOCK_K"]
),
)
]
@triton.autotune(
@@ -531,6 +587,89 @@ def _scatter2scatter_lora(
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
def _scatter2scatter_lora_split(
X: torch.Tensor,
W: torch.Tensor,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
k: int,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
scaling: float,
b: Optional[torch.Tensor] = None,
x_grouped: bool = False,
y_grouped: bool = False,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Split base+LoRA forward: 3 scatter2scatter calls, no fused LoRA kernel.
Faster for models with few large experts (e.g. Mixtral E=8, I=14336)
because the base kernel runs at full speed without LoRA SMEM overhead,
and the LoRA matmuls (R=16) are tiny separate passes.
Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T)
"""
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import (
scatter2scatter,
)
E = W.size(0)
R = lora_A.size(0) // E
K = W.size(1)
N = W.size(2)
# 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes)
output = scatter2scatter(
X=X,
W=W,
b=b,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
x_grouped=x_grouped,
y_grouped=y_grouped,
out=out,
)
# 2. XA = X @ A^T (tiny: output is [M*k, R])
# Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter)
W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous()
XA = scatter2scatter(
X=X,
W=W_A,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
x_grouped=x_grouped,
y_grouped=True,
)
# 3. Y_lora = XA @ B^T (R is tiny, so this is very fast)
# Reshape B: [N, R*E] → [E, R, N]
W_B = lora_B.T.reshape(E, R, N).contiguous()
Y_lora = scatter2scatter(
X=XA,
W=W_B,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
x_grouped=True,
y_grouped=y_grouped,
)
# 4. Y = Y_base + scaling * Y_lora
output.add_(Y_lora, alpha=scaling)
return output
# Threshold for switching from fused to split LoRA forward.
# Split wins when per-expert matmul is large (bandwidth-bound LoRA tile
# loads dominate in the fused kernel's inner loop).
# Empirically: split wins for E<=32 with K*N > 20M (e.g. Mixtral, Phi-MoE).
_SPLIT_LORA_FWD_THRESHOLD = 20_000_000 # per-expert K*N
_SPLIT_LORA_FWD_MAX_EXPERTS = 32
def scatter2scatter_lora(
X: torch.Tensor,
W: torch.Tensor,
@@ -546,7 +685,13 @@ def scatter2scatter_lora(
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Fused scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
Scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
Automatically selects between:
- Fused kernel: single Triton kernel with LoRA in the inner loop.
Best for many small experts (E>=64, small K*N).
- Split dispatch: 3 separate scatter2scatter calls (base + XA + lora).
Best for few large experts (E<=32, large K*N like Mixtral).
Args:
X: Input [M, K] or [M*k, K] if x_grouped
@@ -565,12 +710,30 @@ def scatter2scatter_lora(
Returns:
Y: Output [M*k, N]
"""
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
E = W.size(0)
K = W.size(1)
N = W.size(2)
# Dispatch: split for few large experts, fused for many small experts
if E <= _SPLIT_LORA_FWD_MAX_EXPERTS and K * N >= _SPLIT_LORA_FWD_THRESHOLD:
return _scatter2scatter_lora_split(
X,
W,
sorted_expert_idxs,
sorted_scattered_idxs,
k,
lora_A,
lora_B,
scaling,
b,
x_grouped,
y_grouped,
out,
)
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
R = lora_A.size(0) // E
# Pad R to power of 2 for Triton tile size
@@ -610,11 +773,9 @@ def scatter2scatter_lora(
b_ptr,
stride_be,
stride_bn,
# A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride
lora_A,
lora_A.stride(0),
lora_A.stride(1),
# B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride
lora_B,
lora_B.stride(0),
lora_B.stride(1),
@@ -625,9 +786,8 @@ def scatter2scatter_lora(
K=K,
N=N,
E=E,
ACTUAL_R=R, # True LoRA rank for weight indexing
BLOCK_M=BLOCK_M,
BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16)
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
ACC_TYPE=tl.float32,
scaling=scaling,
allow_tf32=ALLOW_TF32,
@@ -761,13 +921,13 @@ def _compute_expert_block_lora_dX(
+ (A_expert_offset + R_block)[:, None] * stride_ar
+ K_block[None, :] * stride_ak
)
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0)
# Cast to float32 for precision
a_f32 = a_e.to(tl.float32)
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to(
INPUT_DTYPE
)
# (DY @ B) @ A: [M, R] @ [R, K] -> [M, K]
lora_dx = tl.dot(dy_b_acc, a_f32, allow_tf32=allow_tf32)
# tl.dot requires non-float32 inputs (tensor cores); cast accumulator back to input dtype
lora_dx = tl.dot(dy_b_acc.to(INPUT_DTYPE), a_e, allow_tf32=allow_tf32)
acc += scaling * lora_dx
return acc
@@ -779,17 +939,18 @@ def _scatter2scatter_lora_dX_configs():
The inner loop is over N (not K as in forward). The output dimension is K.
So BLOCK_K tiles the output and BLOCK_N tiles the reduction.
Search space includes smaller tile sizes and fewer pipeline stages to
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
BLOCK_M is now autotunable (was fixed at 128).
Search space:
BLOCK_M: {32, 64, 128} (token tile)
BLOCK_K: {32, 64, 128, 256} (output tile)
BLOCK_N: {32, 64, 128, 256} (reduction tile)
num_warps: {4, 8}
num_stages: {3, 4, 5}
"""
configs = []
for block_k, block_n, warps, stages in product(
for block_m, block_k, block_n, warps, stages in product(
[32, 64, 128], # BLOCK_M
[32, 64, 128, 256], # BLOCK_K (output dimension)
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
[4, 8], # num_warps
@@ -797,7 +958,7 @@ def _scatter2scatter_lora_dX_configs():
):
configs.append(
triton.Config(
{"BLOCK_K": block_k, "BLOCK_N": block_n},
{"BLOCK_M": block_m, "BLOCK_K": block_k, "BLOCK_N": block_n},
num_stages=stages,
num_warps=warps,
)
@@ -806,7 +967,7 @@ def _scatter2scatter_lora_dX_configs():
def _prune_dX_configs(configs, named_args, **kwargs):
"""Prune backward dX configs based on SMEM capacity.
"""Prune backward dX configs based on SMEM capacity and register pressure.
The dX kernel inner loop loads three tiles per pipeline stage:
DY[BLOCK_M, BLOCK_N], W^T[BLOCK_N, BLOCK_K], B[BLOCK_N, BLOCK_R].
@@ -822,23 +983,49 @@ def _prune_dX_configs(configs, named_args, **kwargs):
scored = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_k = config.kwargs["BLOCK_K"]
block_n = config.kwargs["BLOCK_N"]
# Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_k, block_n)
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_k, block_n)
# B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop
smem_lora_loop = config.num_stages * block_n * block_r * 2
# A tile [BLOCK_R, BLOCK_K] loaded once in epilogue
smem_lora_epilogue = block_r * block_k * 2
smem = smem_base + smem_lora_loop + smem_lora_epilogue
# Register pressure: live tiles are acc[M,K], dy_b_acc[M,R],
# dy[M,N], wt[N,K], b[N,R], plus epilogue a[R,K]
est_regs = _estimate_register_pressure(
config.num_warps,
(block_m, block_k), # acc
(block_m, block_r), # dy_b_acc
(block_m, block_n), # dy tile
(block_n, block_k), # wt tile
(block_n, block_r), # b tile
(block_r, block_k), # a tile (epilogue)
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
scored.append((smem, config))
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
if pruned:
return pruned
# All configs exceed SMEM — return the one with smallest estimated usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
if scored:
# All surviving configs exceed SMEM — return the one with smallest usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
# All configs pruned by register pressure — fall back to smallest tiles
return [
min(
configs,
key=lambda c: (
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
),
)
]
@triton.autotune(
@@ -1067,7 +1254,7 @@ def scatter2scatter_lora_dX(
N=N,
E=E,
ACTUAL_R=R,
BLOCK_M=BLOCK_M,
# BLOCK_M is autotuned (injected by triton.autotune from Config kwargs)
BLOCK_R=BLOCK_R,
ACC_TYPE=tl.float32,
scaling=scaling,
@@ -1119,7 +1306,7 @@ def _group_bwd_lora_configs():
def _prune_bwd_lora_configs(configs, named_args, **kwargs):
"""Prune backward configs based on SMEM capacity.
"""Prune backward configs based on SMEM capacity and register pressure.
The backward kernel loads X[BLOCK_M, BLOCK_K] and DY[BLOCK_M, BLOCK_N]
in the inner loop, plus holds A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R]
@@ -1138,14 +1325,40 @@ def _prune_bwd_lora_configs(configs, named_args, **kwargs):
# A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] held for the full expert
smem_lora = (block_r * block_k + block_n * block_r) * 2
smem = smem_base + smem_lora
# Register pressure: dA_acc[R,K], dB_acc[N,R], x[M,K], dy[M,N],
# a[R,K], b[N,R], xa[M,R], dy_b[M,R]
est_regs = _estimate_register_pressure(
config.num_warps,
(block_r, block_k), # dA_acc
(block_n, block_r), # dB_acc
(block_m, block_k), # x tile
(block_m, block_n), # dy tile
(block_r, block_k), # a tile
(block_n, block_r), # b tile
(block_m, block_r), # xa intermediate
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
scored.append((smem, config))
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
if pruned:
return pruned
# All configs exceed SMEM — return the one with smallest estimated usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
if scored:
# All surviving configs exceed SMEM — return the one with smallest usage
scored.sort(key=lambda x: x[0])
return [scored[0][1]]
# All configs pruned by register pressure — fall back to smallest tiles
return [
min(
configs,
key=lambda c: (
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
),
)
]
@triton.autotune(
@@ -1330,6 +1543,279 @@ def _group_bwd_lora(
)
def _group_bwd_split_configs():
"""Autotune configs for split dA/dB kernels."""
configs = []
for block_m, block_dim, warps, stages in product(
[32, 64, 128], # BLOCK_M (token tile)
[32, 64, 128, 256], # BLOCK_DIM (K for dA, N for dB — output tile)
[4, 8], # num_warps
[3, 4, 5], # num_stages
):
configs.append(
triton.Config(
{"BLOCK_M": block_m, "BLOCK_DIM": block_dim},
num_stages=stages,
num_warps=warps,
)
)
return configs
def _prune_split_configs(configs, named_args, **kwargs):
"""Prune split kernel configs based on SMEM capacity and register pressure."""
smem_cap = _get_smem_capacity()
block_r = named_args.get("BLOCK_R", 64)
# Fixed inner tile for reduction dimension
BLOCK_INNER = 64
pruned = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_dim = config.kwargs["BLOCK_DIM"]
# Inner loop loads: input[M, INNER] and other[M, INNER_or_DIM]
smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2
# LoRA weights held in registers: [INNER, R] or [R, DIM]
smem += (block_r * max(block_dim, BLOCK_INNER)) * 2
# Register pressure check
est_regs = _estimate_register_pressure(
config.num_warps,
(block_r, block_dim), # acc
(block_m, BLOCK_INNER), # input tile
(block_m, block_dim), # other tile
(block_r, BLOCK_INNER), # lora weight
)
if est_regs > _MAX_REGS_SOFT_LIMIT:
continue
if smem <= smem_cap - _SMEM_SLACK:
pruned.append(config)
if pruned:
return pruned
configs.sort(key=lambda c: c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_DIM"])
return [configs[0]]
@triton.autotune(
configs=_group_bwd_split_configs(),
key=["M", "K", "N"],
prune_configs_by={"early_config_prune": _prune_split_configs},
)
@triton.heuristics(
{
"NO_DIM_MASK": lambda args: (
(args["K"] % args["BLOCK_DIM"]) == 0
if args["COMPUTE_DA"]
else (args["N"] % args["BLOCK_DIM"]) == 0
),
}
)
@triton.jit
def _group_bwd_lora_split(
# Data tensors (DY and X are always present)
DY_ptr,
stride_dym,
stride_dyn,
X_ptr,
stride_xm,
stride_xk,
# LoRA weight for the inner reduction (B for dA, A for dB)
LW_ptr,
stride_lw0,
stride_lw1,
# Output gradient tensor (dA or dB)
OUT_ptr,
stride_out0,
stride_out1,
# Expert offsets
expert_offsets_ptr,
# Dimensions
M,
K: tl.constexpr,
N: tl.constexpr,
ACTUAL_R: tl.constexpr,
BLOCK_R: tl.constexpr,
INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB)
scaling,
# Mode flag
COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB
# Tile sizes
BLOCK_M: tl.constexpr,
BLOCK_DIM: tl.constexpr,
ACC_TYPE: tl.constexpr,
allow_tf32: tl.constexpr,
NO_DIM_MASK: tl.constexpr,
):
"""
Unified split kernel for LoRA gradient computation.
When COMPUTE_DA=True:
dA[e] = scaling * (dY @ B[e])^T @ X → [R, K]
Grid: (E, cdiv(K, BLOCK_DIM))
- outer_ptr/stride = X (read [M, K_block])
- inner reduction over N using DY and B
- output shape [BLOCK_R, BLOCK_DIM]
When COMPUTE_DA=False:
dB[e] = scaling * dY^T @ (X @ A[e]^T) → [N, R]
Grid: (E, cdiv(N, BLOCK_DIM))
- outer_ptr/stride = DY (read [M, N_block])
- inner reduction over K using X and A
- output shape [BLOCK_DIM, BLOCK_R]
No atomic adds — each (E, dim_block) pair is written by exactly one block.
"""
E_idx = tl.program_id(0)
dim_block_id = tl.program_id(1)
if E_idx == 0:
start_idx = 0
else:
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
num_tokens = end_idx - start_idx
# Output dimension tile (K for dA, N for dB)
if COMPUTE_DA:
OUT_DIM: tl.constexpr = K # type: ignore[no-redef]
else:
OUT_DIM: tl.constexpr = N # type: ignore[no-redef]
dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
dim_mask = dim_block < OUT_DIM
R_block = tl.arange(0, BLOCK_R)
R_mask = R_block < ACTUAL_R
lora_offset = E_idx * ACTUAL_R
# Output pointers — layout differs: dA is [R, K], dB is [N, R]
if COMPUTE_DA:
out_blk_ptrs = (
OUT_ptr
+ (lora_offset + R_block)[:, None] * stride_out0
+ dim_block[None, :] * stride_out1
)
out_mask = R_mask[:, None] & dim_mask[None, :]
else:
out_blk_ptrs = (
OUT_ptr
+ dim_block[:, None] * stride_out0
+ (lora_offset + R_block)[None, :] * stride_out1
)
out_mask = dim_mask[:, None] & R_mask[None, :]
if num_tokens > 0:
M_block = tl.arange(0, BLOCK_M)
INPUT_DTYPE = X_ptr.dtype.element_ty
BLOCK_INNER: tl.constexpr = 64
inner_iters = tl.cdiv(INNER_DIM, BLOCK_INNER)
if COMPUTE_DA:
acc = tl.zeros((BLOCK_R, BLOCK_DIM), dtype=ACC_TYPE)
else:
acc = tl.zeros((BLOCK_DIM, BLOCK_R), dtype=ACC_TYPE)
M_iters = tl.cdiv(num_tokens, BLOCK_M)
for i in range(M_iters):
M_idx = start_idx + i * BLOCK_M + M_block
M_mask = M_idx < end_idx
if COMPUTE_DA:
# Load X[M, K_block] (the "outer" tensor for dA)
outer = tl.load(
X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk,
mask=M_mask[:, None] & dim_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# Reduce DY[M, :] @ B[e][:, R] over N → [M, R]
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
inner_range = tl.arange(0, BLOCK_INNER)
for j in range(inner_iters):
inn_off = j * BLOCK_INNER + inner_range
inn_mask = inn_off < N
dy_tile = tl.load(
DY_ptr
+ M_idx[:, None] * stride_dym
+ inn_off[None, :] * stride_dyn,
mask=M_mask[:, None] & inn_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride
lw_tile = tl.load(
LW_ptr
+ inn_off[:, None] * stride_lw0
+ (lora_offset + R_block)[None, :] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32)
# dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block]
acc += tl.dot(
tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32
)
else:
# Load DY[M, N_block] (the "outer" tensor for dB)
outer = tl.load(
DY_ptr
+ M_idx[:, None] * stride_dym
+ dim_block[None, :] * stride_dyn,
mask=M_mask[:, None] & dim_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# Reduce X[M, :] @ A[e][:, :].T over K → [M, R]
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
inner_range = tl.arange(0, BLOCK_INNER)
for j in range(inner_iters):
inn_off = j * BLOCK_INNER + inner_range
inn_mask = inn_off < K
x_tile = tl.load(
X_ptr
+ M_idx[:, None] * stride_xm
+ inn_off[None, :] * stride_xk,
mask=M_mask[:, None] & inn_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
# A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride
# We want A[e]^T: [K, R], so load as [K_inner, R]
lw_tile = tl.load(
LW_ptr
+ (lora_offset + R_block)[None, :] * stride_lw0
+ inn_off[:, None] * stride_lw1,
mask=inn_mask[:, None] & R_mask[None, :],
other=0.0,
).to(INPUT_DTYPE)
reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32)
# dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R]
acc += tl.dot(
tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32
)
tl.store(
out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask
)
else:
# Zero out this expert's slice — needed because output uses empty_like
if COMPUTE_DA:
tl.store(
out_blk_ptrs,
tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty),
mask=out_mask,
)
else:
tl.store(
out_blk_ptrs,
tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty),
mask=out_mask,
)
def group_bwd_lora(
DY: torch.Tensor,
X: torch.Tensor,
@@ -1344,6 +1830,9 @@ def group_bwd_lora(
"""
Compute LoRA gradients for A and B on expert-grouped data.
Uses split dA/dB kernels that eliminate atomic adds by giving each
(expert, output_block) pair its own thread block.
Args:
DY: Gradient w.r.t. output [M_total, N] (grouped by expert)
X: Input [M_total, K] (grouped by expert)
@@ -1361,19 +1850,46 @@ def group_bwd_lora(
K = X.size(1)
N = DY.size(1)
# Zero-init for atomic accumulation
dA = torch.zeros_like(lora_A)
dB = torch.zeros_like(lora_B)
# No zero-init needed: the split kernels write zeros for experts with
# zero routed tokens directly in the kernel (else branch).
dA = torch.empty_like(lora_A)
dB = torch.empty_like(lora_B)
BLOCK_R = _block_r_for_rank(R)
def grid(META):
return (
E * triton.cdiv(K, META["BLOCK_K"]),
triton.cdiv(N, META["BLOCK_N"]),
)
def grid_dA(META):
return (E, triton.cdiv(K, META["BLOCK_DIM"]))
_group_bwd_lora[grid](
_group_bwd_lora_split[grid_dA](
DY,
DY.stride(0),
DY.stride(1),
X,
X.stride(0),
X.stride(1),
lora_B,
lora_B.stride(0),
lora_B.stride(1),
dA,
dA.stride(0),
dA.stride(1),
expert_offsets,
M=DY.size(0),
K=K,
N=N,
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
INNER_DIM=N,
scaling=scaling,
COMPUTE_DA=True,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)
def grid_dB(META):
return (E, triton.cdiv(N, META["BLOCK_DIM"]))
_group_bwd_lora_split[grid_dB](
DY,
DY.stride(0),
DY.stride(1),
@@ -1383,12 +1899,6 @@ def group_bwd_lora(
lora_A,
lora_A.stride(0),
lora_A.stride(1),
lora_B,
lora_B.stride(0),
lora_B.stride(1),
dA,
dA.stride(0),
dA.stride(1),
dB,
dB.stride(0),
dB.stride(1),
@@ -1396,9 +1906,11 @@ def group_bwd_lora(
M=DY.size(0),
K=K,
N=N,
ACTUAL_R=R, # True LoRA rank
BLOCK_R=BLOCK_R, # Padded tile size
ACTUAL_R=R,
BLOCK_R=BLOCK_R,
INNER_DIM=K,
scaling=scaling,
COMPUTE_DA=False,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)

View File

@@ -489,20 +489,71 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
# ====================================================================
# Selective expert weight dequantization
# ====================================================================
# When experts are BnB-quantized (quantize_moe_experts), dequantize
# only the active experts instead of all E. This saves ~97% memory
# for the transient dequant buffer when few experts are active.
use_selective = (
getattr(self, "_use_selective_dequant", False)
and hasattr(experts, "parametrizations")
and "gate_up_proj" in experts.parametrizations
)
if use_selective:
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import (
get_active_experts,
remap_expert_indices,
selective_expert_weights,
selective_lora_weights,
)
active_experts = get_active_experts(sorted_expert_idxs, num_experts)
remapped_expert_idxs, compact_offsets = remap_expert_indices(
sorted_expert_idxs,
expert_offsets,
active_experts,
num_experts,
)
# Dequantize only active experts' weights
gate_up_W = selective_expert_weights(
experts,
"gate_up_proj",
active_experts,
).transpose(2, 1) # [num_active, hidden, 2*inter]
# Remap LoRA weights to match compact expert indices
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup_A, gup_B = selective_lora_weights(
gup_A,
gup_B,
active_experts,
num_experts,
)
gup_lora = (gup_A, gup_B, gup_scaling)
# Use remapped indices for ScatterMoE kernels
sei_gup = remapped_expert_idxs
eo_gup = compact_offsets
else:
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
sei_gup = sorted_expert_idxs
eo_gup = expert_offsets
# ====================================================================
# Gate + Up projection
# ====================================================================
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup = parallel_linear_lora(
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sei_gup,
sorted_scattered_idxs,
expert_offsets,
eo_gup,
lora_A=gup_A,
lora_B=gup_B,
scaling=gup_scaling,
@@ -516,9 +567,9 @@ class HFScatterMoEGatedMLP(nn.Module):
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sei_gup,
sorted_scattered_idxs,
expert_offsets,
eo_gup,
grouped_in=False,
grouped_out=True,
)
@@ -529,7 +580,29 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
# Down projection
# ====================================================================
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
if use_selective:
down_W = selective_expert_weights(
experts,
"down_proj",
active_experts,
).transpose(2, 1) # [num_active, inter, hidden]
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
down_A, down_B = selective_lora_weights(
down_A,
down_B,
active_experts,
num_experts,
)
down_lora = (down_A, down_B, down_scaling)
sei_down = remapped_expert_idxs
eo_down = compact_offsets
else:
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
sei_down = sorted_expert_idxs
eo_down = expert_offsets
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
@@ -537,9 +610,9 @@ class HFScatterMoEGatedMLP(nn.Module):
h,
down_W,
1,
sorted_expert_idxs,
sei_down,
sorted_scattered_idxs,
expert_offsets,
eo_down,
lora_A=down_A,
lora_B=down_B,
scaling=down_scaling,
@@ -554,9 +627,9 @@ class HFScatterMoEGatedMLP(nn.Module):
h,
down_W,
1,
sorted_expert_idxs,
sei_down,
sorted_scattered_idxs,
expert_offsets,
eo_down,
grouped_in=True,
grouped_out=False,
gates=routing_weights,

View File

@@ -0,0 +1,282 @@
"""
Selective Expert Dequantization
===============================
Instead of dequantizing all E expert weight matrices at once (which creates
a ~1 GB transient buffer for 256 experts), only dequantize the experts that
are actually routed to by the current batch's top-k selection.
For Qwen3.5-35B-A3B (E=256, top_k=8, hidden=2048, intermediate=512):
- Full dequant: [256, 2048, 1024] = 1,074 MB per projection
- Selective (8 active): [8, 2048, 1024] = 33.5 MB per projection
- Savings: ~97% memory reduction per layer
This module provides format-agnostic selective weight extraction:
- BnB 4-bit (nf4/fp4): slice quantized data + absmax per expert
- bf16/fp32: direct indexing (no dequant needed)
- FP8: slice + cast
The ScatterMoE kernel itself doesn't change — we remap expert indices
from global (0..E-1) to compact (0..num_active-1) and pass the smaller
weight tensor.
"""
import torch
import torch.nn as nn
def get_active_experts(sorted_expert_idxs: torch.Tensor, E: int) -> torch.Tensor:
"""Get sorted unique expert indices from the routing output.
Args:
sorted_expert_idxs: Expert assignments sorted by expert id [T*k]
E: Total number of experts
Returns:
active: Sorted unique expert indices [num_active]
"""
return torch.unique(sorted_expert_idxs)
def remap_expert_indices(
sorted_expert_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
active_experts: torch.Tensor,
E: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Remap global expert indices to compact indices.
Maps expert ids from [0..E-1] to [0..num_active-1], preserving the
sort order. Also compacts expert_offsets to only active experts.
Args:
sorted_expert_idxs: [T*k] expert ids in sorted order
expert_offsets: [E] cumulative token counts (original)
active_experts: [num_active] sorted unique expert ids
E: Total number of experts
Returns:
remapped_idxs: [T*k] expert ids in [0..num_active-1]
compact_offsets: [num_active] cumulative token counts
"""
# Build remap table: global_id -> compact_id
remap = torch.empty(E, dtype=torch.long, device=sorted_expert_idxs.device)
remap[active_experts] = torch.arange(
len(active_experts), device=sorted_expert_idxs.device
)
remapped_idxs = remap[sorted_expert_idxs]
# Compact the expert_offsets: only keep active experts' cumulative counts
compact_offsets = expert_offsets[active_experts]
return remapped_idxs, compact_offsets
def _selective_dequant_bnb4(
raw_param: torch.Tensor,
quant_state,
active_experts: torch.Tensor,
expert_shape: tuple[int, int],
) -> torch.Tensor:
"""Dequantize only selected experts from BnB 4-bit packed data.
The raw parameter is a flattened 4-bit packed tensor. Each expert's
data is contiguous (stored in expert-major order), so we can gather
the packed data and absmax blocks for active experts, then dequantize
as one contiguous block.
Args:
raw_param: Flattened uint8 tensor of packed 4-bit weights
quant_state: BnB QuantState with absmax, blocksize, code, etc.
active_experts: [num_active] expert indices to dequantize
expert_shape: (dim1, dim2) shape per expert (e.g. (1024, 2048))
Returns:
Dequantized weights [num_active, dim1, dim2] in original dtype
"""
import bitsandbytes.functional as F # noqa: N812
from bitsandbytes.functional import QuantState
expert_numel = expert_shape[0] * expert_shape[1]
packed_per_expert = expert_numel // 2 # 4-bit = 2 values per byte
blocks_per_expert = expert_numel // quant_state.blocksize
num_active = len(active_experts)
if blocks_per_expert == 0:
# Expert is smaller than one quantization block — blocks span across
# expert boundaries, so per-expert slicing isn't possible.
# Fallback: full dequantize + index.
full = F.dequantize_4bit(raw_param, quant_state)
E_total = full.numel() // expert_numel
return full.reshape(E_total, *expert_shape)[active_experts]
# Use fused Triton kernel for NF4 (handles selective gather + dequant in one pass)
if quant_state.quant_type == "nf4" and raw_param.dtype == torch.uint8:
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import (
selective_dequant_nf4_triton,
)
# Handle nested (double) quantization: dequantize absmax first
# BnB uses dequantize_blockwise (not _4bit) for nested absmax + offset
if quant_state.nested:
absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32:
absmax = absmax.float()
else:
absmax = quant_state.absmax
return selective_dequant_nf4_triton(
packed_data=raw_param,
absmax=absmax,
active_experts=active_experts,
expert_shape=expert_shape,
blocksize=quant_state.blocksize,
dtype=quant_state.dtype,
codebook=quant_state.code,
)
# Fallback: gather + BnB dequant (for fp4 or non-uint8 packed formats)
raw_flat = raw_param.reshape(-1)
offsets_qt = (
active_experts.long()[:, None] * packed_per_expert
+ torch.arange(packed_per_expert, device=raw_param.device)[None, :]
).reshape(-1)
qt_gathered = raw_flat[offsets_qt]
offsets_abs = (
active_experts.long()[:, None] * blocks_per_expert
+ torch.arange(blocks_per_expert, device=raw_param.device)[None, :]
).reshape(-1)
if quant_state.nested:
full_absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
full_absmax += quant_state.offset
if full_absmax.dtype != torch.float32:
full_absmax = full_absmax.float()
absmax_gathered = full_absmax[offsets_abs]
else:
absmax_gathered = quant_state.absmax[offsets_abs]
qt_gathered = qt_gathered.unsqueeze(1) if qt_gathered.dim() == 1 else qt_gathered
gathered_qs = QuantState(
absmax=absmax_gathered,
shape=torch.Size([num_active * expert_numel]),
blocksize=quant_state.blocksize,
quant_type=quant_state.quant_type,
code=quant_state.code,
dtype=quant_state.dtype,
)
deq = F.dequantize_4bit(qt_gathered, gathered_qs)
return deq.reshape(num_active, *expert_shape)
def _selective_index_dense(
param: torch.Tensor,
active_experts: torch.Tensor,
) -> torch.Tensor:
"""Select experts from a dense (bf16/fp32) weight tensor.
Simple indexing — no dequantization needed.
"""
return param[active_experts]
def selective_expert_weights(
experts_module: nn.Module,
param_name: str,
active_experts: torch.Tensor,
) -> torch.Tensor:
"""Extract and dequantize only the active experts' weights.
Format-agnostic: dispatches based on whether the parameter is
BnB 4-bit quantized (via parametrize), FP8, or dense bf16/fp32.
Args:
experts_module: The base experts module (e.g. Qwen3_5MoeExperts)
param_name: "gate_up_proj" or "down_proj"
active_experts: [num_active] sorted unique expert indices
Returns:
Compact weight tensor [num_active, dim1, dim2] ready for ScatterMoE
"""
# Check if the parameter is BnB-quantized via parametrize
if (
hasattr(experts_module, "parametrizations")
and param_name in experts_module.parametrizations
):
param_list = experts_module.parametrizations[param_name]
parametrization = param_list[0]
# BnB 4-bit parametrization
if hasattr(parametrization, "quant_state"):
# The raw quantized data is on the ParametrizationList, not the
# individual Bnb4bitParametrization module
raw_param = param_list.original
qs = parametrization.quant_state
# qs.shape is the original tensor shape before flattening.
# For MoE experts it's [E, d1, d2] (3D) or [total_elements] (1D).
orig_shape = qs.shape
if isinstance(orig_shape, torch.Size) and len(orig_shape) == 3:
expert_shape = (orig_shape[1], orig_shape[2])
elif isinstance(orig_shape, torch.Size) and len(orig_shape) == 1:
# Flattened — need to infer from module attributes
E_total = getattr(experts_module, "num_experts", None)
if E_total is None:
E_total = int(active_experts.max().item()) + 1
expert_numel = orig_shape[0] // E_total
d2 = getattr(experts_module, "hidden_dim", None) or getattr(
experts_module, "intermediate_dim", None
)
if d2 and expert_numel % d2 == 0:
expert_shape = (expert_numel // d2, d2)
else:
full = getattr(experts_module, param_name)
return full[active_experts]
else:
full = getattr(experts_module, param_name)
return full[active_experts]
return _selective_dequant_bnb4(raw_param, qs, active_experts, expert_shape)
# Dense parameter (bf16/fp32) — direct indexing
param = getattr(experts_module, param_name)
if param.dim() == 3:
return param[active_experts]
# Fallback: full access
return param
def selective_lora_weights(
lora_A: torch.Tensor,
lora_B: torch.Tensor,
active_experts: torch.Tensor,
E: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Select LoRA A and B weights for only the active experts.
LoRA layout (scattermoe format):
A: [r*E, K] — expert e occupies rows [e*r : (e+1)*r]
B: [N, r*E] — expert e occupies cols [e*r : (e+1)*r]
Returns compact:
A: [r*num_active, K]
B: [N, r*num_active]
"""
R = lora_A.size(0) // E
# Vectorized gather: active_experts[:, None] * R + arange(R)[None, :]
row_idx = (
active_experts.long()[:, None] * R
+ torch.arange(R, device=lora_A.device)[None, :]
).reshape(-1)
compact_A = lora_A[row_idx] # [r*num_active, K]
compact_B = lora_B[:, row_idx] # [N, r*num_active]
return compact_A, compact_B

View File

@@ -0,0 +1,179 @@
"""
Triton kernel for fused selective expert gather + NF4 dequantization.
Instead of:
1. Gather packed uint8 data for active experts (memory copy)
2. Gather absmax for active experts (memory copy)
3. Call BnB dequantize_4bit CUDA kernel
This kernel does all three in one pass:
- Reads packed NF4 bytes from expert-strided positions
- Looks up the NF4 codebook
- Multiplies by the per-block absmax
- Writes bf16 output directly
This eliminates the intermediate gather buffer entirely.
"""
import torch
import triton
import triton.language as tl
# NF4 codebook (16 values, precomputed by BnB)
# These are the normalized float4 reconstruction values
NF4_CODEBOOK = [
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
]
@triton.jit
def _selective_dequant_nf4_kernel(
# Input: packed NF4 data (flattened, expert-major order)
packed_ptr,
# Input: absmax values (flattened, expert-major order)
absmax_ptr,
# Input: active expert indices
active_experts_ptr,
# Input: NF4 codebook (16 float values)
codebook_ptr,
# Output: dequantized bf16 weights [num_active, expert_numel]
out_ptr,
stride_out_e, # stride for expert dim in output
# Dimensions
num_active,
packed_per_expert, # expert_numel // 2
blocks_per_expert, # expert_numel // blocksize
blocksize: tl.constexpr,
# Tile size
BLOCK_SIZE: tl.constexpr, # elements per thread block (must be multiple of 2)
):
"""
Each program processes BLOCK_SIZE elements from one expert.
Grid: (num_active, cdiv(expert_numel, BLOCK_SIZE))
For each output element:
1. Compute which byte in packed data contains this element
2. Extract the 4-bit nibble (high or low)
3. Look up in NF4 codebook
4. Scale by absmax for this block
"""
expert_local_idx = tl.program_id(0) # which active expert (0..num_active-1)
block_id = tl.program_id(1) # which element block
# Load the global expert index
expert_global = tl.load(active_experts_ptr + expert_local_idx).to(tl.int64)
expert_numel = packed_per_expert * 2 # 2 elements per packed byte
elem_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = elem_offset < expert_numel
# Each element is packed as: byte[i//2], low nibble for even i, high for odd i
byte_idx = elem_offset // 2
is_high = (elem_offset % 2) == 1
# Read packed bytes from the global expert's region
packed_global_offset = expert_global * packed_per_expert + byte_idx
packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(
tl.int32
)
# Extract 4-bit nibble
# BnB packing: high nibble = even element, low nibble = odd element
nibble = tl.where(is_high, packed_bytes & 0xF, (packed_bytes >> 4) & 0xF)
# NF4 codebook lookup
# Load all 16 codebook values (small, fits in registers)
# Use gather from codebook pointer
code_val = tl.load(codebook_ptr + nibble, mask=mask, other=0.0)
# Load absmax for this element's quantization block
block_idx = elem_offset // blocksize
absmax_global_offset = expert_global * blocks_per_expert + block_idx
absmax_val = tl.load(absmax_ptr + absmax_global_offset, mask=mask, other=1.0)
# Dequantize: value = codebook[nibble] * absmax
result = code_val * absmax_val
# Store to output
out_offset = expert_local_idx * stride_out_e + elem_offset
tl.store(out_ptr + out_offset, result.to(out_ptr.dtype.element_ty), mask=mask)
def selective_dequant_nf4_triton(
packed_data: torch.Tensor,
absmax: torch.Tensor,
active_experts: torch.Tensor,
expert_shape: tuple[int, int],
blocksize: int,
dtype: torch.dtype = torch.bfloat16,
codebook: torch.Tensor | None = None,
) -> torch.Tensor:
"""Fused selective gather + NF4 dequantization via Triton kernel.
Args:
packed_data: Flattened packed NF4 data [total_packed] or [total_packed, 1]
absmax: Per-block scaling factors [total_blocks]
active_experts: Sorted indices of experts to dequantize [num_active]
expert_shape: (dim1, dim2) per expert
blocksize: Quantization block size
dtype: Output dtype (default bf16)
codebook: NF4 lookup table [16] (uses default NF4 codebook if None)
Returns:
Dequantized weights [num_active, dim1, dim2]
"""
num_active = active_experts.shape[0]
expert_numel = expert_shape[0] * expert_shape[1]
packed_per_expert = expert_numel // 2
blocks_per_expert = expert_numel // blocksize
# Prepare codebook on device
if codebook is None:
codebook = torch.tensor(
NF4_CODEBOOK, dtype=torch.float32, device=packed_data.device
)
else:
codebook = codebook.to(device=packed_data.device, dtype=torch.float32)
# Flatten inputs
packed_flat = packed_data.reshape(-1)
absmax_flat = absmax.reshape(-1).float() # absmax is usually fp32
# Output buffer
out = torch.empty(num_active, expert_numel, dtype=dtype, device=packed_data.device)
BLOCK_SIZE = 1024 # Process 1024 elements per thread block
grid = (num_active, triton.cdiv(expert_numel, BLOCK_SIZE))
_selective_dequant_nf4_kernel[grid](
packed_flat,
absmax_flat,
active_experts,
codebook,
out,
out.stride(0),
num_active=num_active,
packed_per_expert=packed_per_expert,
blocks_per_expert=blocks_per_expert,
blocksize=blocksize,
BLOCK_SIZE=BLOCK_SIZE,
)
return out.reshape(num_active, *expert_shape)

View File

@@ -61,7 +61,16 @@ class KernelsPlugin(BasePlugin):
return "axolotl.integrations.kernels.KernelsArgs"
def pre_model_load(self, cfg):
from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK
# Prefer text backbone type for VLMs, but fall back to base type
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
if (
moe_model_type not in SPARSE_MOE_BLOCK
and cfg.model_config_type in SPARSE_MOE_BLOCK
):
moe_model_type = cfg.model_config_type
if cfg.use_scattermoe:
self._register_kernels()

View File

@@ -640,7 +640,9 @@ class LoRA_QKV(torch.autograd.Function):
del q_weight
del q_weight_t
if A_q is not None and B_q is not None:
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
# Stay decomposed: dQ @ B^T gives [T, R], then [T, R] @ (s*A) gives [T, in]
# This is 65x fewer FLOPs than materializing B@A into [out, in]
grad_X.addmm_(torch.mm(q_grad, B_q_scaled), A_q_scaled)
# K path
k_weight_t = dequantize(k_weight, k_quant)
@@ -648,7 +650,7 @@ class LoRA_QKV(torch.autograd.Function):
del k_weight
del k_weight_t
if A_k is not None and B_k is not None:
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
grad_X.addmm_(torch.mm(k_grad, B_k_scaled), A_k_scaled)
# V path
v_weight_t = dequantize(v_weight, v_quant)
@@ -656,7 +658,7 @@ class LoRA_QKV(torch.autograd.Function):
del v_weight
del v_weight_t
if A_v is not None and B_v is not None:
grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))
grad_X.addmm_(torch.mm(v_grad, B_v_scaled), A_v_scaled)
# Transpose gradients if needed
if d_A_q is not None:
@@ -819,7 +821,8 @@ class LoRA_O(torch.autograd.Function):
del W
A, B = A.to(dtype), B.to(dtype)
dX += s * dY @ B @ A
# Stay decomposed: dY @ B gives [T, R], then [T, R] @ A gives [T, in]
dX.addmm_(torch.mm(dY, B), A, alpha=s)
# W, b, W_quant, A, B, s
return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None

View File

@@ -505,6 +505,20 @@ class ModelLoader:
elif not is_ds_zero3:
self.model_kwargs["device_map"] = device_map
# quantize_moe_experts quantizes expert weights on-the-fly during loading,
# so the actual VRAM usage is much less than bf16 estimates.
# When device_map is "auto", accelerate's infer_auto_device_map computes
# the device map at bf16 size (before quantization), causing it to offload
# layers to CPU, which BnB then rejects. Force single-GPU placement to
# prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single).
if getattr(self.cfg, "quantize_moe_experts", False) and device_map in (
"auto",
None,
):
self.model_kwargs["device_map"] = {
"": int(os.environ.get("LOCAL_RANK", 0))
}
cur_device = get_device_type()
if "mps" in str(cur_device):
self.model_kwargs["device_map"] = "mps:0"

View File

@@ -51,6 +51,29 @@ QKV_PATCHES = [
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
),
(
"""
query_states, gate = torch.chunk(
self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
)
gate = gate.reshape(*input_shape, -1)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states, gate = torch.chunk(
query_states.view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
)
gate = gate.reshape(*input_shape, -1)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip("\n"),
),
]
ORIGINAL_O_CODE = """
@@ -299,6 +322,8 @@ def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
if hasattr(pretrained_model, "language_model"):
return pretrained_model.language_model.layers
if hasattr(pretrained_model, "model"):
if hasattr(pretrained_model.model, "language_model"):
return pretrained_model.model.language_model.layers
return pretrained_model.model.layers
raise NotImplementedError(

View File

@@ -17,6 +17,8 @@ from transformers import (
class PytorchProfilerCallback(TrainerCallback):
"""
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
Also runs torch.profiler to produce a Chrome trace for timing analysis.
"""
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
@@ -26,9 +28,10 @@ class PytorchProfilerCallback(TrainerCallback):
if profiler_steps_start == 0:
# start recording memory allocations before everything is allocated, because if we start
# at the beginning of step 0, we won't have any memory allocations in the traces
torch.cuda.memory._record_memory_history(enabled="all")
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
profiler_steps_start = -1
self.profiler_steps_start = profiler_steps_start
self._profiler = None
def on_step_begin(
self,
@@ -38,7 +41,21 @@ class PytorchProfilerCallback(TrainerCallback):
**kwargs,
):
if state.global_step == self.profiler_steps_start:
torch.cuda.memory._record_memory_history(enabled="all")
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
# Start torch.profiler on the first profiled step
if state.global_step == max(self.profiler_steps_start, 0):
profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=True,
profile_memory=True,
with_stack=True,
)
profiler.__enter__()
self._profiler = profiler
def on_step_end(
self,
@@ -55,6 +72,13 @@ class PytorchProfilerCallback(TrainerCallback):
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
# Stop and export torch.profiler trace
if self._profiler is not None:
self._profiler.__exit__(None, None, None)
trace_path = Path(args.output_dir) / "profiler_trace.json"
self._profiler.export_chrome_trace(str(trace_path))
self._profiler = None
def on_train_end(
self,
args: TrainingArguments,
@@ -73,3 +97,9 @@ class PytorchProfilerCallback(TrainerCallback):
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
if self._profiler is not None:
self._profiler.__exit__(None, None, None)
trace_path = Path(args.output_dir) / "profiler_trace.json"
self._profiler.export_chrome_trace(str(trace_path))
self._profiler = None

View File

@@ -4,8 +4,7 @@ E2E tests for lora llama
import unittest
import pytest
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.common.datasets import load_datasets
from axolotl.train import train
@@ -68,51 +67,3 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
@with_temp_dir
def test_lora_gptq_packed(self, temp_dir):
cfg = DictDefault(
{
"base_model": "lilmeaty/SmolLM2-135M-Instruct-GPTQ",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"gptq": True,
"gptq_disable_exllama": True,
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"max_steps": 20,
"save_steps": 0.5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -0,0 +1,407 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
Unit tests for ScatterMoE LoRA Triton kernels.
Tests correctness of:
- scatter2scatter_lora (forward)
- scatter2scatter_lora_dX (backward input gradient)
- group_bwd_lora (backward LoRA weight gradients via split dA/dB)
- ScatterMoELoRA autograd function (full forward + backward)
Each kernel is tested against a pure PyTorch per-expert-loop reference
implementation at multiple model shapes and LoRA ranks.
"""
import pytest
import torch
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
lora_ops,
ops as base_ops,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
flatten_sort_count,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
ScatterMoELoRA,
)
DEVICE = "cuda"
DTYPE = torch.bfloat16
def _requires_cuda():
return pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA not available"
)
pytestmark = _requires_cuda()
# ─── Helpers ─────────────────────────────────────────────────────────────────
def _setup(E, K, N, T, top_k, R, seed=42):
"""Create synthetic expert weights, LoRA, routing, and grouped inputs."""
torch.manual_seed(seed)
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
W = torch.randn(E, K, N, device=DEVICE, dtype=DTYPE) * 0.02
lora_A = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) * 0.01
lora_B = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) * 0.01
logits = torch.randn(T, E, device=DEVICE)
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
sei, ssi, eo = flatten_sort_count(top_idx, E)
return x, W, lora_A, lora_B, sei, ssi, eo
def _reference_fwd(x, W, sei, ssi, eo, k, lora_A, lora_B, scaling, E):
"""Per-expert loop reference: Y = X@W + scaling*(X@A^T)@B^T."""
grouped_x = base_ops.group(x, ssi, fan_out=k)
M, N = grouped_x.size(0), W.size(2)
R = lora_A.size(0) // E
out = torch.zeros(M, N, device=DEVICE, dtype=DTYPE)
for e in range(E):
s = eo[e - 1].item() if e > 0 else 0
end = eo[e].item()
if s == end:
continue
xe = grouped_x[s:end].float()
we = W[e].float()
ae = lora_A[e * R : (e + 1) * R].float()
be = lora_B[:, e * R : (e + 1) * R].float()
out[s:end] = (xe @ we + scaling * (xe @ ae.T) @ be.T).to(DTYPE)
result = torch.zeros(M, N, device=DEVICE, dtype=DTYPE)
result[ssi] = out
return result
def _reference_dX(dy_grouped, W, sei, ssi, eo, lora_A, lora_B, scaling, E):
"""Per-expert loop reference: dX = dY@W^T + scaling*(dY@B)@A."""
M, K = dy_grouped.size(0), W.size(1)
R = lora_A.size(0) // E
out = torch.zeros(M, K, device=DEVICE, dtype=DTYPE)
for e in range(E):
s = eo[e - 1].item() if e > 0 else 0
end = eo[e].item()
if s == end:
continue
dye = dy_grouped[s:end].float()
we = W[e].float()
ae = lora_A[e * R : (e + 1) * R].float()
be = lora_B[:, e * R : (e + 1) * R].float()
out[s:end] = (dye @ we.T + scaling * (dye @ be) @ ae).to(DTYPE)
result = torch.zeros(M, K, device=DEVICE, dtype=DTYPE)
result[ssi] = out
return result
def _reference_bwd_lora(dy, grouped_x, lora_A, lora_B, eo, E, scaling):
"""Per-expert loop reference: dA, dB for LoRA weight gradients."""
R = lora_A.size(0) // E
dA = torch.zeros_like(lora_A)
dB = torch.zeros_like(lora_B)
for e in range(E):
s = eo[e - 1].item() if e > 0 else 0
end = eo[e].item()
if s == end:
continue
xe = grouped_x[s:end].float()
dye = dy[s:end].float()
ae = lora_A[e * R : (e + 1) * R].float()
be = lora_B[:, e * R : (e + 1) * R].float()
dA[e * R : (e + 1) * R] = (scaling * (dye @ be).T @ xe).to(DTYPE)
dB[:, e * R : (e + 1) * R] = (scaling * dye.T @ (xe @ ae.T)).to(DTYPE)
return dA, dB
# ─── Model shape configs ────────────────────────────────────────────────────
# (E, K, N, T, top_k, R, description)
CONFIGS_SMALL = [
(32, 128, 64, 64, 2, 4, "tiny"),
(64, 256, 128, 128, 4, 8, "small"),
]
CONFIGS_REAL = [
(256, 2048, 1024, 2048, 8, 16, "qwen35_gate_up"),
(256, 512, 2048, 2048, 8, 16, "qwen35_down"),
(64, 2048, 2048, 2048, 8, 16, "olmoe_gate_up"),
(128, 2048, 1536, 2048, 8, 16, "qwen3_gate_up"),
]
SCALING = 2.0
# ─── Forward tests ──────────────────────────────────────────────────────────
class TestScatter2ScatterLoRAForward:
"""Test scatter2scatter_lora forward kernel vs reference."""
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
def config(self, request):
return request.param
def test_matches_reference(self, config):
E, K, N, T, k, R, desc = config
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
kernel_out = lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=k,
lora_A=lA,
lora_B=lB,
scaling=SCALING,
)
ref_out = _reference_fwd(x, W, sei, ssi, eo, k, lA, lB, SCALING, E)
err = (kernel_out.float() - ref_out.float()).abs().max().item()
assert err < 1.0, f"[{desc}] fwd max_err={err}"
def test_output_shape(self, config):
E, K, N, T, k, R, desc = config
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
out = lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=k,
lora_A=lA,
lora_B=lB,
scaling=SCALING,
)
assert out.shape == (T * k, N)
assert out.dtype == DTYPE
# ─── Backward dX tests ──────────────────────────────────────────────────────
class TestScatter2ScatterLoRADX:
"""Test scatter2scatter_lora_dX backward kernel vs reference."""
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
def config(self, request):
return request.param
def test_matches_reference(self, config):
E, K, N, T, k, R, desc = config
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
gx = base_ops.group(x, ssi, fan_out=k)
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
kernel_dx = lora_ops.scatter2scatter_lora_dX(
DY=dy,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=1,
lora_A=lA,
lora_B=lB,
scaling=SCALING,
dy_grouped=True,
dx_grouped=False,
)
ref_dx = _reference_dX(dy, W, sei, ssi, eo, lA, lB, SCALING, E)
err = (kernel_dx.float() - ref_dx.float()).abs().max().item()
assert err < 1.0, f"[{desc}] dX max_err={err}"
# ─── Backward LoRA gradient tests ───────────────────────────────────────────
class TestGroupBwdLoRA:
"""Test group_bwd_lora (split dA/dB kernel) vs reference."""
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
def config(self, request):
return request.param
def test_matches_reference(self, config):
E, K, N, T, k, R, desc = config
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
gx = base_ops.group(x, ssi, fan_out=k)
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
kern_dA, kern_dB = lora_ops.group_bwd_lora(
DY=dy,
X=gx,
lora_A=lA,
lora_B=lB,
expert_offsets=eo,
E=E,
scaling=SCALING,
)
ref_dA, ref_dB = _reference_bwd_lora(dy, gx, lA, lB, eo, E, SCALING)
# Use norm-relative error: bf16 accumulation order differs between
# kernel (tiled + different reduction order) and reference (per-expert
# fp32 loop), so max absolute error can be large on individual elements
# while the overall tensor is correct.
dA_norm_err = (
(kern_dA.float() - ref_dA.float()).norm() / (ref_dA.float().norm() + 1e-6)
).item()
dB_norm_err = (
(kern_dB.float() - ref_dB.float()).norm() / (ref_dB.float().norm() + 1e-6)
).item()
assert dA_norm_err < 0.01, f"[{desc}] dA norm_rel_err={dA_norm_err}"
assert dB_norm_err < 0.01, f"[{desc}] dB norm_rel_err={dB_norm_err}"
def test_zero_expert_tokens(self):
"""Experts with zero routed tokens produce zero gradients."""
E, K, N, R = 8, 64, 32, 4
torch.manual_seed(42)
# Route all tokens to expert 0 only
T, k = 16, 1
top_idx = torch.zeros(T, k, dtype=torch.long, device=DEVICE)
sei, ssi, eo = flatten_sort_count(top_idx, E)
gx = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
dy = torch.randn(T, N, device=DEVICE, dtype=DTYPE)
lA = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE)
lB = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE)
dA, dB = lora_ops.group_bwd_lora(
DY=dy,
X=gx,
lora_A=lA,
lora_B=lB,
expert_offsets=eo,
E=E,
scaling=2.0,
)
# Experts 1..7 should have zero gradients
for e in range(1, E):
assert dA[e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dA not zero"
assert dB[:, e * R : (e + 1) * R].abs().max() == 0, (
f"Expert {e} dB not zero"
)
# ─── Full autograd tests ────────────────────────────────────────────────────
class TestScatterMoELoRAAutograd:
"""Test full forward + backward through ScatterMoELoRA autograd function."""
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL[:2])
def config(self, request):
return request.param
def test_gradients_exist_and_finite(self, config):
E, K, N, T, k, R, desc = config
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
x = x.requires_grad_(True)
lA = lA.requires_grad_(True)
lB = lB.requires_grad_(True)
out = ScatterMoELoRA.apply(
x,
W,
k,
sei,
ssi,
eo,
lA,
lB,
SCALING,
None,
None,
False,
False,
True,
False,
)
out.sum().backward()
assert x.grad is not None, f"[{desc}] x.grad is None"
assert lA.grad is not None, f"[{desc}] lA.grad is None"
assert lB.grad is not None, f"[{desc}] lB.grad is None"
assert torch.isfinite(x.grad).all(), f"[{desc}] x.grad has non-finite"
assert torch.isfinite(lA.grad).all(), f"[{desc}] lA.grad has non-finite"
assert torch.isfinite(lB.grad).all(), f"[{desc}] lB.grad has non-finite"
assert x.grad.abs().sum() > 0, f"[{desc}] x.grad all zero"
assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero"
def test_split_matches_fused(self):
"""Split dispatch (for few large experts) matches fused kernel."""
# Use a shape where split would be dispatched (large K*N, few E)
E, K, N, T, k, R = 8, 512, 1024, 128, 2, 16
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
# Force fused path
orig = lora_ops._SPLIT_LORA_FWD_THRESHOLD
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 10**18
out_fused = lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=k,
lora_A=lA,
lora_B=lB,
scaling=SCALING,
)
# Force split path
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 0
out_split = lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=k,
lora_A=lA,
lora_B=lB,
scaling=SCALING,
)
lora_ops._SPLIT_LORA_FWD_THRESHOLD = orig
norm_err = (
(out_fused.float() - out_split.float()).norm()
/ (out_fused.float().norm() + 1e-6)
).item()
assert norm_err < 0.01, f"split vs fused norm_err={norm_err}"
def test_scaling_zero_gives_base_only(self):
"""With scaling=0.0, LoRA contribution vanishes. Output = X@W."""
E, K, N, T, k, R = 16, 64, 32, 32, 2, 4
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
out_lora = ScatterMoELoRA.apply(
x,
W,
k,
sei,
ssi,
eo,
lA,
lB,
0.0,
None,
None,
False,
False,
True,
False,
)
out_base = base_ops.scatter2scatter(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=k,
)
err = (out_lora.float() - out_base.float()).abs().max().item()
assert err < 0.01, f"scaling=0 should match base: err={err}"