Compare commits
8 Commits
torch-211-
...
scattermoe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
42922f8f8b | ||
|
|
7041592ca7 | ||
|
|
fec0c3a99e | ||
|
|
31d8d068bb | ||
|
|
66fea258c7 | ||
|
|
07ff389be8 | ||
|
|
2dcca15f65 | ||
|
|
c5db90aa3f |
284
benchmarks/bench_scattermoe_lora.py
Normal file
284
benchmarks/bench_scattermoe_lora.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""Benchmark for ScatterMoE LoRA Triton kernels.
|
||||
|
||||
Measures forward, backward dX, and backward dA/dB kernels at common MoE
|
||||
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
|
||||
and full fwd+bwd autograd throughput.
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
|
||||
lora_ops,
|
||||
ops as base_ops,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
|
||||
flatten_sort_count,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
|
||||
ScatterMoELoRA,
|
||||
)
|
||||
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.bfloat16
|
||||
WARMUP = 5
|
||||
ITERS = 20
|
||||
|
||||
# ─── Model configs ──────────────────────────────────────────────────────────
|
||||
|
||||
BUILTIN_CONFIGS = {
|
||||
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
|
||||
"Qwen3-30B-A3B": (128, 2048, 768, 8),
|
||||
"OLMoE-1B-7B": (64, 2048, 1024, 8),
|
||||
"Mixtral-8x7B": (8, 4096, 14336, 2),
|
||||
}
|
||||
|
||||
|
||||
def _resolve_config(spec):
|
||||
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
|
||||
key = spec.lower().replace("/", "-")
|
||||
for name, cfg in BUILTIN_CONFIGS.items():
|
||||
if key in name.lower() or name.lower() in key:
|
||||
return name, cfg
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
|
||||
if callable(getattr(hf_cfg, "get_text_config", None)):
|
||||
tc = hf_cfg.get_text_config()
|
||||
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
|
||||
hf_cfg = tc
|
||||
hidden = hf_cfg.hidden_size
|
||||
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
|
||||
experts = (
|
||||
getattr(hf_cfg, "num_experts", None)
|
||||
or getattr(hf_cfg, "num_local_experts", None)
|
||||
or getattr(hf_cfg, "n_routed_experts", None)
|
||||
)
|
||||
top_k = (
|
||||
getattr(hf_cfg, "num_experts_per_tok", None)
|
||||
or getattr(hf_cfg, "num_experts_per_token", None)
|
||||
or 2
|
||||
)
|
||||
name = spec.split("/")[-1]
|
||||
return name, (experts, hidden, inter, top_k)
|
||||
|
||||
|
||||
# ─── Benchmark helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _clean():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def _bench(fn, warmup=WARMUP, iters=ITERS):
|
||||
for _ in range(warmup):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
times = []
|
||||
for _ in range(iters):
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
times.append((time.perf_counter() - t0) * 1000)
|
||||
times.sort()
|
||||
return times[len(times) // 2]
|
||||
|
||||
|
||||
def _setup(num_experts, K, N, T, top_k, R):
|
||||
torch.manual_seed(42)
|
||||
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
|
||||
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
|
||||
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
logits = torch.randn(T, num_experts, device=DEVICE)
|
||||
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
|
||||
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
|
||||
gx = base_ops.group(x, ssi, fan_out=top_k)
|
||||
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
|
||||
|
||||
|
||||
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
|
||||
|
||||
|
||||
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
|
||||
return lora_ops.scatter2scatter_lora(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=top_k,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
|
||||
def _call_base(x, W, sei, ssi, top_k):
|
||||
return base_ops.scatter2scatter(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=top_k,
|
||||
)
|
||||
|
||||
|
||||
def _call_dx(dy, W, sei, ssi, lA, lB):
|
||||
return lora_ops.scatter2scatter_lora_dX(
|
||||
DY=dy,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=1,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=2.0,
|
||||
dy_grouped=True,
|
||||
dx_grouped=False,
|
||||
)
|
||||
|
||||
|
||||
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
|
||||
return lora_ops.group_bwd_lora(
|
||||
DY=dy,
|
||||
X=gx,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
expert_offsets=eo,
|
||||
E=num_experts,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
|
||||
# ─── Main ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
"-m",
|
||||
nargs="+",
|
||||
help="Model names or HF IDs (default: all builtins)",
|
||||
)
|
||||
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
|
||||
parser.add_argument("--seq-len", "-T", type=int, default=2048)
|
||||
args = parser.parse_args()
|
||||
|
||||
T = args.seq_len
|
||||
print(f"GPU: {torch.cuda.get_device_name()}")
|
||||
print(f"T={T}, ranks={args.ranks}\n")
|
||||
|
||||
if args.models:
|
||||
configs = [_resolve_config(m) for m in args.models]
|
||||
else:
|
||||
configs = list(BUILTIN_CONFIGS.items())
|
||||
|
||||
for model_name, (num_experts, hidden, inter, top_k) in configs:
|
||||
print(f"{'=' * 70}")
|
||||
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
|
||||
print(f"{'=' * 70}")
|
||||
|
||||
for R in args.ranks:
|
||||
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
|
||||
_clean()
|
||||
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
|
||||
num_experts, K, N, T, top_k, R
|
||||
)
|
||||
|
||||
# Forward with LoRA (auto-dispatched: fused or split)
|
||||
dispatch = (
|
||||
"split"
|
||||
if (
|
||||
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
|
||||
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
|
||||
)
|
||||
else "fused"
|
||||
)
|
||||
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
|
||||
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
|
||||
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
|
||||
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
|
||||
|
||||
total = t_fwd + t_dx + t_bwd
|
||||
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
|
||||
|
||||
print(
|
||||
f" R={R:>2} {proj:<8} "
|
||||
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
|
||||
f"base={t_base:>6.2f}ms "
|
||||
f"(+{overhead * 100:.0f}%) "
|
||||
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
|
||||
f"total={total:>6.2f}ms"
|
||||
)
|
||||
|
||||
# Full autograd fwd+bwd with memory measurement
|
||||
x_ag = x.clone().requires_grad_(True)
|
||||
lA_ag = lA.clone().requires_grad_(True)
|
||||
lB_ag = lB.clone().requires_grad_(True)
|
||||
|
||||
def _run_autograd(
|
||||
_x=x_ag,
|
||||
_W=W,
|
||||
_k=top_k,
|
||||
_sei=sei,
|
||||
_ssi=ssi,
|
||||
_eo=eo,
|
||||
_lA=lA_ag,
|
||||
_lB=lB_ag,
|
||||
):
|
||||
out = ScatterMoELoRA.apply(
|
||||
_x,
|
||||
_W,
|
||||
_k,
|
||||
_sei,
|
||||
_ssi,
|
||||
_eo,
|
||||
_lA,
|
||||
_lB,
|
||||
2.0,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
out.sum().backward()
|
||||
_x.grad = None
|
||||
_lA.grad = None
|
||||
_lB.grad = None
|
||||
|
||||
t_full = _bench(_run_autograd)
|
||||
|
||||
_clean()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
mem_before = torch.cuda.memory_allocated()
|
||||
_run_autograd()
|
||||
torch.cuda.synchronize()
|
||||
mem_peak = torch.cuda.max_memory_allocated() - mem_before
|
||||
|
||||
print(
|
||||
f" full_fwd_bwd={t_full:>6.2f}ms "
|
||||
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -195,6 +195,30 @@ def _estimate_smem_usage(
|
||||
_SMEM_SLACK = 10_000
|
||||
|
||||
|
||||
def _estimate_register_pressure(
|
||||
num_warps: int,
|
||||
*tile_sizes: tuple[int, int],
|
||||
) -> float:
|
||||
"""Estimate per-thread register count from live tile sizes.
|
||||
|
||||
Each tile of shape (rows, cols) requires rows*cols elements distributed
|
||||
across 32 threads per warp, but each thread in the warp holds a fragment.
|
||||
For Triton GEMM-style kernels, the register footprint per thread is
|
||||
approximately sum(rows * cols) / 32 for each live tile, plus ~40 for
|
||||
scalar overhead (loop counters, pointers, masks, etc.).
|
||||
|
||||
Returns estimated registers per thread.
|
||||
"""
|
||||
# Each thread in a warp holds 1/32 of the tile elements
|
||||
tile_regs = sum(r * c for r, c in tile_sizes) / 32
|
||||
scalar_overhead = 40
|
||||
return tile_regs + scalar_overhead
|
||||
|
||||
|
||||
# Maximum registers per thread on NVIDIA GPUs
|
||||
_MAX_REGS_PER_THREAD = 255
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Forward Kernel: scatter2scatter with fused LoRA
|
||||
# =============================================================================
|
||||
@@ -313,12 +337,11 @@ def _compute_expert_block_lora(
|
||||
B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0
|
||||
) # [BLOCK_N, BLOCK_R]
|
||||
|
||||
# Cast xa_acc and b to same dtype for tl.dot (required when input is bf16/fp16)
|
||||
# Both operands must match; cast to float32 (accumulator type) for precision.
|
||||
b_f32 = b.to(tl.float32)
|
||||
# tl.dot requires non-float32 inputs (tensor cores); cast back to input dtype
|
||||
b_inp = b.to(INPUT_DTYPE)
|
||||
|
||||
# (X @ A^T) @ B^T: [M, R] @ [R, N] -> [M, N]
|
||||
lora_out = tl.dot(xa_acc, tl.trans(b_f32), allow_tf32=allow_tf32)
|
||||
lora_out = tl.dot(xa_acc.to(INPUT_DTYPE), tl.trans(b_inp), allow_tf32=allow_tf32)
|
||||
|
||||
acc += scaling * lora_out
|
||||
return acc
|
||||
@@ -327,20 +350,21 @@ def _compute_expert_block_lora(
|
||||
def _scatter2scatter_lora_configs():
|
||||
"""Generate forward kernel autotune configs.
|
||||
|
||||
Search space includes smaller tile sizes and fewer pipeline stages to
|
||||
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
||||
Search space includes BLOCK_M to allow trading token-tile size for
|
||||
larger BLOCK_K/BLOCK_N tiles. On GPUs with ~99KB SMEM, BLOCK_M=128
|
||||
forces BLOCK_K=32 and BLOCK_N=32; BLOCK_M=64 allows BLOCK_K=128
|
||||
(4× fewer inner-loop iterations).
|
||||
|
||||
Search space:
|
||||
BLOCK_M: {32, 64, 128}
|
||||
BLOCK_N: {32, 64, 128, 256}
|
||||
BLOCK_K: {32, 64, 128}
|
||||
num_warps: {4, 8}
|
||||
num_stages: {3, 4, 5}
|
||||
|
||||
BLOCK_M is fixed at 128 (module-level constant, not autotuned in the
|
||||
scatter2scatter pattern).
|
||||
"""
|
||||
configs = []
|
||||
for block_n, block_k, warps, stages in product(
|
||||
for block_m, block_n, block_k, warps, stages in product(
|
||||
[32, 64, 128], # BLOCK_M
|
||||
[32, 64, 128, 256], # BLOCK_N
|
||||
[32, 64, 128], # BLOCK_K
|
||||
[4, 8], # num_warps
|
||||
@@ -348,7 +372,7 @@ def _scatter2scatter_lora_configs():
|
||||
):
|
||||
configs.append(
|
||||
triton.Config(
|
||||
{"BLOCK_N": block_n, "BLOCK_K": block_k},
|
||||
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
|
||||
num_stages=stages,
|
||||
num_warps=warps,
|
||||
)
|
||||
@@ -357,7 +381,7 @@ def _scatter2scatter_lora_configs():
|
||||
|
||||
|
||||
def _prune_fwd_configs(configs, named_args, **kwargs):
|
||||
"""Prune forward configs based on SMEM capacity.
|
||||
"""Prune forward configs based on SMEM capacity and register pressure.
|
||||
|
||||
The forward kernel inner loop loads three tiles per pipeline stage:
|
||||
X[BLOCK_M, BLOCK_K], W[BLOCK_K, BLOCK_N], A[BLOCK_R, BLOCK_K].
|
||||
@@ -373,23 +397,49 @@ def _prune_fwd_configs(configs, named_args, **kwargs):
|
||||
|
||||
scored = []
|
||||
for config in configs:
|
||||
block_m = config.kwargs["BLOCK_M"]
|
||||
block_n = config.kwargs["BLOCK_N"]
|
||||
block_k = config.kwargs["BLOCK_K"]
|
||||
# Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N
|
||||
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_n, block_k)
|
||||
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k)
|
||||
# A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop
|
||||
smem_lora_loop = config.num_stages * block_r * block_k * 2
|
||||
# B tile [BLOCK_N, BLOCK_R] loaded once in epilogue
|
||||
smem_lora_epilogue = block_n * block_r * 2
|
||||
smem = smem_base + smem_lora_loop + smem_lora_epilogue
|
||||
|
||||
# Register pressure: live tiles are acc[M,N], xa_acc[M,R],
|
||||
# x[M,K], w[K,N], a[R,K], plus epilogue b[N,R]
|
||||
est_regs = _estimate_register_pressure(
|
||||
config.num_warps,
|
||||
(block_m, block_n), # acc
|
||||
(block_m, block_r), # xa_acc
|
||||
(block_m, block_k), # x tile
|
||||
(block_k, block_n), # w tile
|
||||
(block_r, block_k), # a tile
|
||||
(block_n, block_r), # b tile (epilogue)
|
||||
)
|
||||
if est_regs > _MAX_REGS_PER_THREAD:
|
||||
continue
|
||||
|
||||
scored.append((smem, config))
|
||||
|
||||
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
|
||||
if pruned:
|
||||
return pruned
|
||||
# All configs exceed SMEM — return the one with smallest estimated usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
if scored:
|
||||
# All surviving configs exceed SMEM — return the one with smallest usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
# All configs pruned by register pressure — fall back to smallest tiles
|
||||
return [
|
||||
min(
|
||||
configs,
|
||||
key=lambda c: (
|
||||
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_N"] * c.kwargs["BLOCK_K"]
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
@@ -531,6 +581,89 @@ def _scatter2scatter_lora(
|
||||
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
|
||||
|
||||
|
||||
def _scatter2scatter_lora_split(
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
sorted_expert_idxs: torch.Tensor,
|
||||
sorted_scattered_idxs: torch.Tensor,
|
||||
k: int,
|
||||
lora_A: torch.Tensor,
|
||||
lora_B: torch.Tensor,
|
||||
scaling: float,
|
||||
b: Optional[torch.Tensor] = None,
|
||||
x_grouped: bool = False,
|
||||
y_grouped: bool = False,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Split base+LoRA forward: 3 scatter2scatter calls, no fused LoRA kernel.
|
||||
|
||||
Faster for models with few large experts (e.g. Mixtral E=8, I=14336)
|
||||
because the base kernel runs at full speed without LoRA SMEM overhead,
|
||||
and the LoRA matmuls (R=16) are tiny separate passes.
|
||||
|
||||
Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T)
|
||||
"""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import (
|
||||
scatter2scatter,
|
||||
)
|
||||
|
||||
E = W.size(0)
|
||||
R = lora_A.size(0) // E
|
||||
K = W.size(1)
|
||||
N = W.size(2)
|
||||
|
||||
# 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes)
|
||||
output = scatter2scatter(
|
||||
X=X,
|
||||
W=W,
|
||||
b=b,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=k,
|
||||
x_grouped=x_grouped,
|
||||
y_grouped=y_grouped,
|
||||
out=out,
|
||||
)
|
||||
|
||||
# 2. XA = X @ A^T (tiny: output is [M*k, R])
|
||||
# Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter)
|
||||
W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous()
|
||||
XA = scatter2scatter(
|
||||
X=X,
|
||||
W=W_A,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=k,
|
||||
x_grouped=x_grouped,
|
||||
y_grouped=True,
|
||||
)
|
||||
|
||||
# 3. Y_lora = XA @ B^T (R is tiny, so this is very fast)
|
||||
# Reshape B: [N, R*E] → [E, R, N]
|
||||
W_B = lora_B.T.reshape(E, R, N).contiguous()
|
||||
Y_lora = scatter2scatter(
|
||||
X=XA,
|
||||
W=W_B,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=1,
|
||||
x_grouped=True,
|
||||
y_grouped=y_grouped,
|
||||
)
|
||||
|
||||
# 4. Y = Y_base + scaling * Y_lora
|
||||
output.add_(Y_lora, alpha=scaling)
|
||||
return output
|
||||
|
||||
|
||||
# Threshold for switching from fused to split LoRA forward.
|
||||
# Split wins when per-expert matmul is large (bandwidth-bound LoRA tile
|
||||
# loads dominate in the fused kernel's inner loop).
|
||||
# Empirically: split wins for E<=32 with K*N > 20M (e.g. Mixtral, Phi-MoE).
|
||||
_SPLIT_LORA_FWD_THRESHOLD = 20_000_000 # per-expert K*N
|
||||
_SPLIT_LORA_FWD_MAX_EXPERTS = 32
|
||||
|
||||
|
||||
def scatter2scatter_lora(
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
@@ -546,7 +679,13 @@ def scatter2scatter_lora(
|
||||
out: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Fused scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
|
||||
Scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
|
||||
|
||||
Automatically selects between:
|
||||
- Fused kernel: single Triton kernel with LoRA in the inner loop.
|
||||
Best for many small experts (E>=64, small K*N).
|
||||
- Split dispatch: 3 separate scatter2scatter calls (base + XA + lora).
|
||||
Best for few large experts (E<=32, large K*N like Mixtral).
|
||||
|
||||
Args:
|
||||
X: Input [M, K] or [M*k, K] if x_grouped
|
||||
@@ -565,12 +704,30 @@ def scatter2scatter_lora(
|
||||
Returns:
|
||||
Y: Output [M*k, N]
|
||||
"""
|
||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||
|
||||
E = W.size(0)
|
||||
K = W.size(1)
|
||||
N = W.size(2)
|
||||
|
||||
# Dispatch: split for few large experts, fused for many small experts
|
||||
if E <= _SPLIT_LORA_FWD_MAX_EXPERTS and K * N >= _SPLIT_LORA_FWD_THRESHOLD:
|
||||
return _scatter2scatter_lora_split(
|
||||
X,
|
||||
W,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
k,
|
||||
lora_A,
|
||||
lora_B,
|
||||
scaling,
|
||||
b,
|
||||
x_grouped,
|
||||
y_grouped,
|
||||
out,
|
||||
)
|
||||
|
||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||
|
||||
R = lora_A.size(0) // E
|
||||
|
||||
# Pad R to power of 2 for Triton tile size
|
||||
@@ -610,11 +767,9 @@ def scatter2scatter_lora(
|
||||
b_ptr,
|
||||
stride_be,
|
||||
stride_bn,
|
||||
# A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride
|
||||
lora_A,
|
||||
lora_A.stride(0),
|
||||
lora_A.stride(1),
|
||||
# B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride
|
||||
lora_B,
|
||||
lora_B.stride(0),
|
||||
lora_B.stride(1),
|
||||
@@ -625,9 +780,8 @@ def scatter2scatter_lora(
|
||||
K=K,
|
||||
N=N,
|
||||
E=E,
|
||||
ACTUAL_R=R, # True LoRA rank for weight indexing
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16)
|
||||
ACTUAL_R=R,
|
||||
BLOCK_R=BLOCK_R,
|
||||
ACC_TYPE=tl.float32,
|
||||
scaling=scaling,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
@@ -761,13 +915,13 @@ def _compute_expert_block_lora_dX(
|
||||
+ (A_expert_offset + R_block)[:, None] * stride_ar
|
||||
+ K_block[None, :] * stride_ak
|
||||
)
|
||||
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0)
|
||||
|
||||
# Cast to float32 for precision
|
||||
a_f32 = a_e.to(tl.float32)
|
||||
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to(
|
||||
INPUT_DTYPE
|
||||
)
|
||||
|
||||
# (DY @ B) @ A: [M, R] @ [R, K] -> [M, K]
|
||||
lora_dx = tl.dot(dy_b_acc, a_f32, allow_tf32=allow_tf32)
|
||||
# tl.dot requires non-float32 inputs (tensor cores); cast accumulator back to input dtype
|
||||
lora_dx = tl.dot(dy_b_acc.to(INPUT_DTYPE), a_e, allow_tf32=allow_tf32)
|
||||
|
||||
acc += scaling * lora_dx
|
||||
return acc
|
||||
@@ -779,17 +933,18 @@ def _scatter2scatter_lora_dX_configs():
|
||||
The inner loop is over N (not K as in forward). The output dimension is K.
|
||||
So BLOCK_K tiles the output and BLOCK_N tiles the reduction.
|
||||
|
||||
Search space includes smaller tile sizes and fewer pipeline stages to
|
||||
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
||||
BLOCK_M is now autotunable (was fixed at 128).
|
||||
|
||||
Search space:
|
||||
BLOCK_M: {32, 64, 128} (token tile)
|
||||
BLOCK_K: {32, 64, 128, 256} (output tile)
|
||||
BLOCK_N: {32, 64, 128, 256} (reduction tile)
|
||||
num_warps: {4, 8}
|
||||
num_stages: {3, 4, 5}
|
||||
"""
|
||||
configs = []
|
||||
for block_k, block_n, warps, stages in product(
|
||||
for block_m, block_k, block_n, warps, stages in product(
|
||||
[32, 64, 128], # BLOCK_M
|
||||
[32, 64, 128, 256], # BLOCK_K (output dimension)
|
||||
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
|
||||
[4, 8], # num_warps
|
||||
@@ -797,7 +952,7 @@ def _scatter2scatter_lora_dX_configs():
|
||||
):
|
||||
configs.append(
|
||||
triton.Config(
|
||||
{"BLOCK_K": block_k, "BLOCK_N": block_n},
|
||||
{"BLOCK_M": block_m, "BLOCK_K": block_k, "BLOCK_N": block_n},
|
||||
num_stages=stages,
|
||||
num_warps=warps,
|
||||
)
|
||||
@@ -806,7 +961,7 @@ def _scatter2scatter_lora_dX_configs():
|
||||
|
||||
|
||||
def _prune_dX_configs(configs, named_args, **kwargs):
|
||||
"""Prune backward dX configs based on SMEM capacity.
|
||||
"""Prune backward dX configs based on SMEM capacity and register pressure.
|
||||
|
||||
The dX kernel inner loop loads three tiles per pipeline stage:
|
||||
DY[BLOCK_M, BLOCK_N], W^T[BLOCK_N, BLOCK_K], B[BLOCK_N, BLOCK_R].
|
||||
@@ -822,23 +977,49 @@ def _prune_dX_configs(configs, named_args, **kwargs):
|
||||
|
||||
scored = []
|
||||
for config in configs:
|
||||
block_m = config.kwargs["BLOCK_M"]
|
||||
block_k = config.kwargs["BLOCK_K"]
|
||||
block_n = config.kwargs["BLOCK_N"]
|
||||
# Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K
|
||||
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_k, block_n)
|
||||
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_k, block_n)
|
||||
# B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop
|
||||
smem_lora_loop = config.num_stages * block_n * block_r * 2
|
||||
# A tile [BLOCK_R, BLOCK_K] loaded once in epilogue
|
||||
smem_lora_epilogue = block_r * block_k * 2
|
||||
smem = smem_base + smem_lora_loop + smem_lora_epilogue
|
||||
|
||||
# Register pressure: live tiles are acc[M,K], dy_b_acc[M,R],
|
||||
# dy[M,N], wt[N,K], b[N,R], plus epilogue a[R,K]
|
||||
est_regs = _estimate_register_pressure(
|
||||
config.num_warps,
|
||||
(block_m, block_k), # acc
|
||||
(block_m, block_r), # dy_b_acc
|
||||
(block_m, block_n), # dy tile
|
||||
(block_n, block_k), # wt tile
|
||||
(block_n, block_r), # b tile
|
||||
(block_r, block_k), # a tile (epilogue)
|
||||
)
|
||||
if est_regs > _MAX_REGS_PER_THREAD:
|
||||
continue
|
||||
|
||||
scored.append((smem, config))
|
||||
|
||||
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
|
||||
if pruned:
|
||||
return pruned
|
||||
# All configs exceed SMEM — return the one with smallest estimated usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
if scored:
|
||||
# All surviving configs exceed SMEM — return the one with smallest usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
# All configs pruned by register pressure — fall back to smallest tiles
|
||||
return [
|
||||
min(
|
||||
configs,
|
||||
key=lambda c: (
|
||||
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
@@ -1067,7 +1248,7 @@ def scatter2scatter_lora_dX(
|
||||
N=N,
|
||||
E=E,
|
||||
ACTUAL_R=R,
|
||||
BLOCK_M=BLOCK_M,
|
||||
# BLOCK_M is autotuned (injected by triton.autotune from Config kwargs)
|
||||
BLOCK_R=BLOCK_R,
|
||||
ACC_TYPE=tl.float32,
|
||||
scaling=scaling,
|
||||
@@ -1119,7 +1300,7 @@ def _group_bwd_lora_configs():
|
||||
|
||||
|
||||
def _prune_bwd_lora_configs(configs, named_args, **kwargs):
|
||||
"""Prune backward configs based on SMEM capacity.
|
||||
"""Prune backward configs based on SMEM capacity and register pressure.
|
||||
|
||||
The backward kernel loads X[BLOCK_M, BLOCK_K] and DY[BLOCK_M, BLOCK_N]
|
||||
in the inner loop, plus holds A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R]
|
||||
@@ -1138,14 +1319,40 @@ def _prune_bwd_lora_configs(configs, named_args, **kwargs):
|
||||
# A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] held for the full expert
|
||||
smem_lora = (block_r * block_k + block_n * block_r) * 2
|
||||
smem = smem_base + smem_lora
|
||||
|
||||
# Register pressure: dA_acc[R,K], dB_acc[N,R], x[M,K], dy[M,N],
|
||||
# a[R,K], b[N,R], xa[M,R], dy_b[M,R]
|
||||
est_regs = _estimate_register_pressure(
|
||||
config.num_warps,
|
||||
(block_r, block_k), # dA_acc
|
||||
(block_n, block_r), # dB_acc
|
||||
(block_m, block_k), # x tile
|
||||
(block_m, block_n), # dy tile
|
||||
(block_r, block_k), # a tile
|
||||
(block_n, block_r), # b tile
|
||||
(block_m, block_r), # xa intermediate
|
||||
)
|
||||
if est_regs > _MAX_REGS_PER_THREAD:
|
||||
continue
|
||||
|
||||
scored.append((smem, config))
|
||||
|
||||
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
|
||||
if pruned:
|
||||
return pruned
|
||||
# All configs exceed SMEM — return the one with smallest estimated usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
if scored:
|
||||
# All surviving configs exceed SMEM — return the one with smallest usage
|
||||
scored.sort(key=lambda x: x[0])
|
||||
return [scored[0][1]]
|
||||
# All configs pruned by register pressure — fall back to smallest tiles
|
||||
return [
|
||||
min(
|
||||
configs,
|
||||
key=lambda c: (
|
||||
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
@@ -1330,6 +1537,279 @@ def _group_bwd_lora(
|
||||
)
|
||||
|
||||
|
||||
def _group_bwd_split_configs():
|
||||
"""Autotune configs for split dA/dB kernels."""
|
||||
configs = []
|
||||
for block_m, block_dim, warps, stages in product(
|
||||
[32, 64, 128], # BLOCK_M (token tile)
|
||||
[32, 64, 128, 256], # BLOCK_DIM (K for dA, N for dB — output tile)
|
||||
[4, 8], # num_warps
|
||||
[3, 4, 5], # num_stages
|
||||
):
|
||||
configs.append(
|
||||
triton.Config(
|
||||
{"BLOCK_M": block_m, "BLOCK_DIM": block_dim},
|
||||
num_stages=stages,
|
||||
num_warps=warps,
|
||||
)
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
def _prune_split_configs(configs, named_args, **kwargs):
|
||||
"""Prune split kernel configs based on SMEM capacity and register pressure."""
|
||||
smem_cap = _get_smem_capacity()
|
||||
block_r = named_args.get("BLOCK_R", 64)
|
||||
|
||||
# Fixed inner tile for reduction dimension
|
||||
BLOCK_INNER = 64
|
||||
|
||||
pruned = []
|
||||
for config in configs:
|
||||
block_m = config.kwargs["BLOCK_M"]
|
||||
block_dim = config.kwargs["BLOCK_DIM"]
|
||||
# Inner loop loads: input[M, INNER] and other[M, INNER_or_DIM]
|
||||
smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2
|
||||
# LoRA weights held in registers: [INNER, R] or [R, DIM]
|
||||
smem += (block_r * max(block_dim, BLOCK_INNER)) * 2
|
||||
|
||||
# Register pressure check
|
||||
est_regs = _estimate_register_pressure(
|
||||
config.num_warps,
|
||||
(block_r, block_dim), # acc
|
||||
(block_m, BLOCK_INNER), # input tile
|
||||
(block_m, block_dim), # other tile
|
||||
(block_r, BLOCK_INNER), # lora weight
|
||||
)
|
||||
if est_regs > _MAX_REGS_PER_THREAD:
|
||||
continue
|
||||
|
||||
if smem <= smem_cap - _SMEM_SLACK:
|
||||
pruned.append(config)
|
||||
|
||||
if pruned:
|
||||
return pruned
|
||||
configs.sort(key=lambda c: c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_DIM"])
|
||||
return [configs[0]]
|
||||
|
||||
|
||||
@triton.autotune(
    configs=_group_bwd_split_configs(),
    key=["M", "K", "N"],
    prune_configs_by={"early_config_prune": _prune_split_configs},
)
@triton.heuristics(
    {
        # True when BLOCK_DIM evenly divides the output dimension (K for dA,
        # N for dB), i.e. no edge masking would be required along BLOCK_DIM.
        "NO_DIM_MASK": lambda args: (
            (args["K"] % args["BLOCK_DIM"]) == 0
            if args["COMPUTE_DA"]
            else (args["N"] % args["BLOCK_DIM"]) == 0
        ),
    }
)
@triton.jit
def _group_bwd_lora_split(
    # Data tensors (DY and X are always present)
    DY_ptr,
    stride_dym,
    stride_dyn,
    X_ptr,
    stride_xm,
    stride_xk,
    # LoRA weight for the inner reduction (B for dA, A for dB)
    LW_ptr,
    stride_lw0,
    stride_lw1,
    # Output gradient tensor (dA or dB)
    OUT_ptr,
    stride_out0,
    stride_out1,
    # Expert offsets (cumulative token counts per expert, length E)
    expert_offsets_ptr,
    # Dimensions
    M,
    K: tl.constexpr,
    N: tl.constexpr,
    ACTUAL_R: tl.constexpr,
    BLOCK_R: tl.constexpr,
    INNER_DIM: tl.constexpr,  # reduction dimension (N for dA, K for dB)
    scaling,
    # Mode flag
    COMPUTE_DA: tl.constexpr,  # True = compute dA, False = compute dB
    # Tile sizes
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    ACC_TYPE: tl.constexpr,
    allow_tf32: tl.constexpr,
    # NOTE(review): set by the @triton.heuristics above but never read in the
    # body — dim_mask is always applied. Confirm intent before removing.
    NO_DIM_MASK: tl.constexpr,
):
    """
    Unified split kernel for LoRA gradient computation.

    When COMPUTE_DA=True:
        dA[e] = scaling * (dY @ B[e])^T @ X   → [R, K]
        Grid: (E, cdiv(K, BLOCK_DIM))
        - outer_ptr/stride = X (read [M, K_block])
        - inner reduction over N using DY and B
        - output shape [BLOCK_R, BLOCK_DIM]

    When COMPUTE_DA=False:
        dB[e] = scaling * dY^T @ (X @ A[e]^T) → [N, R]
        Grid: (E, cdiv(N, BLOCK_DIM))
        - outer_ptr/stride = DY (read [M, N_block])
        - inner reduction over K using X and A
        - output shape [BLOCK_DIM, BLOCK_R]

    No atomic adds — each (E, dim_block) pair is written by exactly one block,
    so experts with zero routed tokens must (and do) write zeros themselves;
    the host allocates the output with ``empty_like``.
    """
    E_idx = tl.program_id(0)
    dim_block_id = tl.program_id(1)

    # expert_offsets is an inclusive cumulative count; expert e's token range
    # is [offsets[e-1], offsets[e]), with an implicit 0 before the first.
    if E_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
    end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
    num_tokens = end_idx - start_idx

    # Output dimension tile (K for dA, N for dB)
    if COMPUTE_DA:
        OUT_DIM: tl.constexpr = K  # type: ignore[no-redef]
    else:
        OUT_DIM: tl.constexpr = N  # type: ignore[no-redef]
    dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
    dim_mask = dim_block < OUT_DIM
    # BLOCK_R is the padded (power-of-two) tile; R_mask trims to the true rank.
    R_block = tl.arange(0, BLOCK_R)
    R_mask = R_block < ACTUAL_R
    # Expert e's LoRA slice starts at row/col e*ACTUAL_R in the stacked layout.
    lora_offset = E_idx * ACTUAL_R

    # Output pointers — layout differs: dA is [R, K], dB is [N, R]
    if COMPUTE_DA:
        out_blk_ptrs = (
            OUT_ptr
            + (lora_offset + R_block)[:, None] * stride_out0
            + dim_block[None, :] * stride_out1
        )
        out_mask = R_mask[:, None] & dim_mask[None, :]
    else:
        out_blk_ptrs = (
            OUT_ptr
            + dim_block[:, None] * stride_out0
            + (lora_offset + R_block)[None, :] * stride_out1
        )
        out_mask = dim_mask[:, None] & R_mask[None, :]

    if num_tokens > 0:
        M_block = tl.arange(0, BLOCK_M)
        INPUT_DTYPE = X_ptr.dtype.element_ty
        # Fixed inner tile; _prune_split_configs assumes the same value.
        BLOCK_INNER: tl.constexpr = 64
        inner_iters = tl.cdiv(INNER_DIM, BLOCK_INNER)

        if COMPUTE_DA:
            acc = tl.zeros((BLOCK_R, BLOCK_DIM), dtype=ACC_TYPE)
        else:
            acc = tl.zeros((BLOCK_DIM, BLOCK_R), dtype=ACC_TYPE)

        M_iters = tl.cdiv(num_tokens, BLOCK_M)
        for i in range(M_iters):
            M_idx = start_idx + i * BLOCK_M + M_block
            M_mask = M_idx < end_idx

            if COMPUTE_DA:
                # Load X[M, K_block] (the "outer" tensor for dA)
                outer = tl.load(
                    X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk,
                    mask=M_mask[:, None] & dim_mask[None, :],
                    other=0.0,
                ).to(INPUT_DTYPE)

                # Reduce DY[M, :] @ B[e][:, R] over N → [M, R]
                reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
                inner_range = tl.arange(0, BLOCK_INNER)
                for j in range(inner_iters):
                    inn_off = j * BLOCK_INNER + inner_range
                    inn_mask = inn_off < N

                    dy_tile = tl.load(
                        DY_ptr
                        + M_idx[:, None] * stride_dym
                        + inn_off[None, :] * stride_dyn,
                        mask=M_mask[:, None] & inn_mask[None, :],
                        other=0.0,
                    ).to(INPUT_DTYPE)
                    # B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride
                    lw_tile = tl.load(
                        LW_ptr
                        + inn_off[:, None] * stride_lw0
                        + (lora_offset + R_block)[None, :] * stride_lw1,
                        mask=inn_mask[:, None] & R_mask[None, :],
                        other=0.0,
                    ).to(INPUT_DTYPE)
                    reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32)

                # dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block]
                acc += tl.dot(
                    tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32
                )
            else:
                # Load DY[M, N_block] (the "outer" tensor for dB)
                outer = tl.load(
                    DY_ptr
                    + M_idx[:, None] * stride_dym
                    + dim_block[None, :] * stride_dyn,
                    mask=M_mask[:, None] & dim_mask[None, :],
                    other=0.0,
                ).to(INPUT_DTYPE)

                # Reduce X[M, :] @ A[e][:, :].T over K → [M, R]
                reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
                inner_range = tl.arange(0, BLOCK_INNER)
                for j in range(inner_iters):
                    inn_off = j * BLOCK_INNER + inner_range
                    inn_mask = inn_off < K

                    x_tile = tl.load(
                        X_ptr
                        + M_idx[:, None] * stride_xm
                        + inn_off[None, :] * stride_xk,
                        mask=M_mask[:, None] & inn_mask[None, :],
                        other=0.0,
                    ).to(INPUT_DTYPE)
                    # A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride
                    # We want A[e]^T: [K, R], so load as [K_inner, R]
                    lw_tile = tl.load(
                        LW_ptr
                        + (lora_offset + R_block)[None, :] * stride_lw0
                        + inn_off[:, None] * stride_lw1,
                        mask=inn_mask[:, None] & R_mask[None, :],
                        other=0.0,
                    ).to(INPUT_DTYPE)
                    reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32)

                # dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R]
                acc += tl.dot(
                    tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32
                )

        tl.store(
            out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask
        )
    else:
        # Zero out this expert's slice — needed because output uses empty_like
        if COMPUTE_DA:
            tl.store(
                out_blk_ptrs,
                tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty),
                mask=out_mask,
            )
        else:
            tl.store(
                out_blk_ptrs,
                tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty),
                mask=out_mask,
            )
|
||||
|
||||
|
||||
def group_bwd_lora(
|
||||
DY: torch.Tensor,
|
||||
X: torch.Tensor,
|
||||
@@ -1344,6 +1824,9 @@ def group_bwd_lora(
|
||||
"""
|
||||
Compute LoRA gradients for A and B on expert-grouped data.
|
||||
|
||||
Uses split dA/dB kernels that eliminate atomic adds by giving each
|
||||
(expert, output_block) pair its own thread block.
|
||||
|
||||
Args:
|
||||
DY: Gradient w.r.t. output [M_total, N] (grouped by expert)
|
||||
X: Input [M_total, K] (grouped by expert)
|
||||
@@ -1361,19 +1844,46 @@ def group_bwd_lora(
|
||||
K = X.size(1)
|
||||
N = DY.size(1)
|
||||
|
||||
# Zero-init for atomic accumulation
|
||||
dA = torch.zeros_like(lora_A)
|
||||
dB = torch.zeros_like(lora_B)
|
||||
# No zero-init needed: the split kernels write zeros for experts with
|
||||
# zero routed tokens directly in the kernel (else branch).
|
||||
dA = torch.empty_like(lora_A)
|
||||
dB = torch.empty_like(lora_B)
|
||||
|
||||
BLOCK_R = _block_r_for_rank(R)
|
||||
|
||||
def grid(META):
|
||||
return (
|
||||
E * triton.cdiv(K, META["BLOCK_K"]),
|
||||
triton.cdiv(N, META["BLOCK_N"]),
|
||||
)
|
||||
def grid_dA(META):
|
||||
return (E, triton.cdiv(K, META["BLOCK_DIM"]))
|
||||
|
||||
_group_bwd_lora[grid](
|
||||
_group_bwd_lora_split[grid_dA](
|
||||
DY,
|
||||
DY.stride(0),
|
||||
DY.stride(1),
|
||||
X,
|
||||
X.stride(0),
|
||||
X.stride(1),
|
||||
lora_B,
|
||||
lora_B.stride(0),
|
||||
lora_B.stride(1),
|
||||
dA,
|
||||
dA.stride(0),
|
||||
dA.stride(1),
|
||||
expert_offsets,
|
||||
M=DY.size(0),
|
||||
K=K,
|
||||
N=N,
|
||||
ACTUAL_R=R,
|
||||
BLOCK_R=BLOCK_R,
|
||||
INNER_DIM=N,
|
||||
scaling=scaling,
|
||||
COMPUTE_DA=True,
|
||||
ACC_TYPE=tl.float32,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
)
|
||||
|
||||
def grid_dB(META):
|
||||
return (E, triton.cdiv(N, META["BLOCK_DIM"]))
|
||||
|
||||
_group_bwd_lora_split[grid_dB](
|
||||
DY,
|
||||
DY.stride(0),
|
||||
DY.stride(1),
|
||||
@@ -1383,12 +1893,6 @@ def group_bwd_lora(
|
||||
lora_A,
|
||||
lora_A.stride(0),
|
||||
lora_A.stride(1),
|
||||
lora_B,
|
||||
lora_B.stride(0),
|
||||
lora_B.stride(1),
|
||||
dA,
|
||||
dA.stride(0),
|
||||
dA.stride(1),
|
||||
dB,
|
||||
dB.stride(0),
|
||||
dB.stride(1),
|
||||
@@ -1396,9 +1900,11 @@ def group_bwd_lora(
|
||||
M=DY.size(0),
|
||||
K=K,
|
||||
N=N,
|
||||
ACTUAL_R=R, # True LoRA rank
|
||||
BLOCK_R=BLOCK_R, # Padded tile size
|
||||
ACTUAL_R=R,
|
||||
BLOCK_R=BLOCK_R,
|
||||
INNER_DIM=K,
|
||||
scaling=scaling,
|
||||
COMPUTE_DA=False,
|
||||
ACC_TYPE=tl.float32,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
)
|
||||
|
||||
@@ -489,20 +489,71 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
# ====================================================================
|
||||
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
|
||||
|
||||
# ====================================================================
|
||||
# Selective expert weight dequantization
|
||||
# ====================================================================
|
||||
# When experts are BnB-quantized (quantize_moe_experts), dequantize
|
||||
# only the active experts instead of all E. This saves ~97% memory
|
||||
# for the transient dequant buffer when few experts are active.
|
||||
use_selective = (
|
||||
getattr(self, "_use_selective_dequant", False)
|
||||
and hasattr(experts, "parametrizations")
|
||||
and "gate_up_proj" in experts.parametrizations
|
||||
)
|
||||
|
||||
if use_selective:
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import (
|
||||
get_active_experts,
|
||||
remap_expert_indices,
|
||||
selective_expert_weights,
|
||||
selective_lora_weights,
|
||||
)
|
||||
|
||||
active_experts = get_active_experts(sorted_expert_idxs, num_experts)
|
||||
remapped_expert_idxs, compact_offsets = remap_expert_indices(
|
||||
sorted_expert_idxs,
|
||||
expert_offsets,
|
||||
active_experts,
|
||||
num_experts,
|
||||
)
|
||||
# Dequantize only active experts' weights
|
||||
gate_up_W = selective_expert_weights(
|
||||
experts,
|
||||
"gate_up_proj",
|
||||
active_experts,
|
||||
).transpose(2, 1) # [num_active, hidden, 2*inter]
|
||||
|
||||
# Remap LoRA weights to match compact expert indices
|
||||
if gup_lora is not None:
|
||||
gup_A, gup_B, gup_scaling = gup_lora
|
||||
gup_A, gup_B = selective_lora_weights(
|
||||
gup_A,
|
||||
gup_B,
|
||||
active_experts,
|
||||
num_experts,
|
||||
)
|
||||
gup_lora = (gup_A, gup_B, gup_scaling)
|
||||
|
||||
# Use remapped indices for ScatterMoE kernels
|
||||
sei_gup = remapped_expert_idxs
|
||||
eo_gup = compact_offsets
|
||||
else:
|
||||
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
|
||||
sei_gup = sorted_expert_idxs
|
||||
eo_gup = expert_offsets
|
||||
|
||||
# ====================================================================
|
||||
# Gate + Up projection
|
||||
# ====================================================================
|
||||
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
|
||||
|
||||
if gup_lora is not None:
|
||||
gup_A, gup_B, gup_scaling = gup_lora
|
||||
gup = parallel_linear_lora(
|
||||
hidden_states_flat,
|
||||
gate_up_W,
|
||||
top_k,
|
||||
sorted_expert_idxs,
|
||||
sei_gup,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
eo_gup,
|
||||
lora_A=gup_A,
|
||||
lora_B=gup_B,
|
||||
scaling=gup_scaling,
|
||||
@@ -516,9 +567,9 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
hidden_states_flat,
|
||||
gate_up_W,
|
||||
top_k,
|
||||
sorted_expert_idxs,
|
||||
sei_gup,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
eo_gup,
|
||||
grouped_in=False,
|
||||
grouped_out=True,
|
||||
)
|
||||
@@ -529,7 +580,29 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
# ====================================================================
|
||||
# Down projection
|
||||
# ====================================================================
|
||||
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
|
||||
if use_selective:
|
||||
down_W = selective_expert_weights(
|
||||
experts,
|
||||
"down_proj",
|
||||
active_experts,
|
||||
).transpose(2, 1) # [num_active, inter, hidden]
|
||||
|
||||
if down_lora is not None:
|
||||
down_A, down_B, down_scaling = down_lora
|
||||
down_A, down_B = selective_lora_weights(
|
||||
down_A,
|
||||
down_B,
|
||||
active_experts,
|
||||
num_experts,
|
||||
)
|
||||
down_lora = (down_A, down_B, down_scaling)
|
||||
|
||||
sei_down = remapped_expert_idxs
|
||||
eo_down = compact_offsets
|
||||
else:
|
||||
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
|
||||
sei_down = sorted_expert_idxs
|
||||
eo_down = expert_offsets
|
||||
|
||||
if down_lora is not None:
|
||||
down_A, down_B, down_scaling = down_lora
|
||||
@@ -537,9 +610,9 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
h,
|
||||
down_W,
|
||||
1,
|
||||
sorted_expert_idxs,
|
||||
sei_down,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
eo_down,
|
||||
lora_A=down_A,
|
||||
lora_B=down_B,
|
||||
scaling=down_scaling,
|
||||
@@ -554,9 +627,9 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
h,
|
||||
down_W,
|
||||
1,
|
||||
sorted_expert_idxs,
|
||||
sei_down,
|
||||
sorted_scattered_idxs,
|
||||
expert_offsets,
|
||||
eo_down,
|
||||
grouped_in=True,
|
||||
grouped_out=False,
|
||||
gates=routing_weights,
|
||||
|
||||
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Selective Expert Dequantization
|
||||
===============================
|
||||
|
||||
Instead of dequantizing all E expert weight matrices at once (which creates
|
||||
a ~1 GB transient buffer for 256 experts), only dequantize the experts that
|
||||
are actually routed to by the current batch's top-k selection.
|
||||
|
||||
For Qwen3.5-35B-A3B (E=256, top_k=8, hidden=2048, intermediate=512):
|
||||
- Full dequant: [256, 2048, 1024] = 1,074 MB per projection
|
||||
- Selective (8 active): [8, 2048, 1024] = 33.5 MB per projection
|
||||
- Savings: ~97% memory reduction per layer
|
||||
|
||||
This module provides format-agnostic selective weight extraction:
|
||||
- BnB 4-bit (nf4/fp4): slice quantized data + absmax per expert
|
||||
- bf16/fp32: direct indexing (no dequant needed)
|
||||
- FP8: slice + cast
|
||||
|
||||
The ScatterMoE kernel itself doesn't change — we remap expert indices
|
||||
from global (0..E-1) to compact (0..num_active-1) and pass the smaller
|
||||
weight tensor.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def get_active_experts(sorted_expert_idxs: torch.Tensor, E: int) -> torch.Tensor:
    """Return the sorted unique expert ids that received at least one token.

    Args:
        sorted_expert_idxs: Expert assignments sorted by expert id [T*k].
        E: Total number of experts (unused; kept for interface symmetry
            with the other helpers in this module).

    Returns:
        Sorted unique expert indices [num_active].
    """
    # Tensor.unique() returns sorted unique values by default.
    return sorted_expert_idxs.unique()
|
||||
|
||||
|
||||
def remap_expert_indices(
    sorted_expert_idxs: torch.Tensor,
    expert_offsets: torch.Tensor,
    active_experts: torch.Tensor,
    E: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Translate global expert ids into compact ids over the active set.

    Each id in ``sorted_expert_idxs`` (range 0..E-1) is replaced with its
    position in ``active_experts`` (range 0..num_active-1); the sort order
    is preserved. ``expert_offsets`` is likewise reduced to only the active
    experts' cumulative counts.

    Args:
        sorted_expert_idxs: [T*k] expert ids in sorted order.
        expert_offsets: [E] cumulative token counts (original).
        active_experts: [num_active] sorted unique expert ids.
        E: Total number of experts.

    Returns:
        remapped_idxs: [T*k] expert ids in [0..num_active-1].
        compact_offsets: [num_active] cumulative token counts.
    """
    device = sorted_expert_idxs.device

    # Dense global->compact lookup table: scatter each active expert's
    # compact position into its global slot, then gather per token.
    lookup = torch.empty(E, dtype=torch.long, device=device)
    lookup[active_experts] = torch.arange(active_experts.numel(), device=device)

    return lookup[sorted_expert_idxs], expert_offsets[active_experts]
|
||||
|
||||
|
||||
def _selective_dequant_bnb4(
    raw_param: torch.Tensor,
    quant_state,
    active_experts: torch.Tensor,
    expert_shape: tuple[int, int],
) -> torch.Tensor:
    """Dequantize only selected experts from BnB 4-bit packed data.

    The raw parameter is a flattened 4-bit packed tensor. Each expert's
    data is contiguous (stored in expert-major order), so we can gather
    the packed data and absmax blocks for active experts, then dequantize
    as one contiguous block.

    Args:
        raw_param: Flattened uint8 tensor of packed 4-bit weights
        quant_state: BnB QuantState with absmax, blocksize, code, etc.
        active_experts: [num_active] expert indices to dequantize
        expert_shape: (dim1, dim2) shape per expert (e.g. (1024, 2048))

    Returns:
        Dequantized weights [num_active, dim1, dim2] in original dtype
    """
    import bitsandbytes.functional as F  # noqa: N812
    from bitsandbytes.functional import QuantState

    expert_numel = expert_shape[0] * expert_shape[1]
    packed_per_expert = expert_numel // 2  # 4-bit = 2 values per byte
    blocks_per_expert = expert_numel // quant_state.blocksize
    num_active = len(active_experts)

    if blocks_per_expert == 0:
        # Expert is smaller than one quantization block — blocks span across
        # expert boundaries, so per-expert slicing isn't possible.
        # Fallback: full dequantize + index.
        full = F.dequantize_4bit(raw_param, quant_state)
        E_total = full.numel() // expert_numel
        return full.reshape(E_total, *expert_shape)[active_experts]

    # Use fused Triton kernel for NF4 (handles selective gather + dequant in one pass)
    if quant_state.quant_type == "nf4" and raw_param.dtype == torch.uint8:
        from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import (
            selective_dequant_nf4_triton,
        )

        # Handle nested (double) quantization: dequantize absmax first.
        # BnB uses dequantize_blockwise (not _4bit) for nested absmax + offset.
        if quant_state.nested:
            absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
            absmax += quant_state.offset
            # Triton kernel reads absmax as fp32
            if absmax.dtype != torch.float32:
                absmax = absmax.float()
        else:
            absmax = quant_state.absmax

        # The kernel indexes the full packed/absmax buffers directly, so no
        # intermediate gather copy is materialized here.
        return selective_dequant_nf4_triton(
            packed_data=raw_param,
            absmax=absmax,
            active_experts=active_experts,
            expert_shape=expert_shape,
            blocksize=quant_state.blocksize,
            dtype=quant_state.dtype,
            codebook=quant_state.code,
        )

    # Fallback: gather + BnB dequant (for fp4 or non-uint8 packed formats)
    raw_flat = raw_param.reshape(-1)

    # Byte offsets of every active expert's packed region, flattened so a
    # single advanced-index gather produces one contiguous buffer.
    offsets_qt = (
        active_experts.long()[:, None] * packed_per_expert
        + torch.arange(packed_per_expert, device=raw_param.device)[None, :]
    ).reshape(-1)
    qt_gathered = raw_flat[offsets_qt]

    # Matching absmax block offsets for the gathered experts.
    offsets_abs = (
        active_experts.long()[:, None] * blocks_per_expert
        + torch.arange(blocks_per_expert, device=raw_param.device)[None, :]
    ).reshape(-1)

    if quant_state.nested:
        # Nested (double) quantization: recover fp32 absmax before gathering.
        full_absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
        full_absmax += quant_state.offset
        if full_absmax.dtype != torch.float32:
            full_absmax = full_absmax.float()
        absmax_gathered = full_absmax[offsets_abs]
    else:
        absmax_gathered = quant_state.absmax[offsets_abs]

    # dequantize_4bit expects a 2-D packed tensor; restore the column dim
    # if the gather flattened it away.
    qt_gathered = qt_gathered.unsqueeze(1) if qt_gathered.dim() == 1 else qt_gathered

    # Synthesize a QuantState describing just the gathered (compact) buffer.
    gathered_qs = QuantState(
        absmax=absmax_gathered,
        shape=torch.Size([num_active * expert_numel]),
        blocksize=quant_state.blocksize,
        quant_type=quant_state.quant_type,
        code=quant_state.code,
        dtype=quant_state.dtype,
    )

    deq = F.dequantize_4bit(qt_gathered, gathered_qs)
    return deq.reshape(num_active, *expert_shape)
|
||||
|
||||
|
||||
def _selective_index_dense(
|
||||
param: torch.Tensor,
|
||||
active_experts: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Select experts from a dense (bf16/fp32) weight tensor.
|
||||
|
||||
Simple indexing — no dequantization needed.
|
||||
"""
|
||||
return param[active_experts]
|
||||
|
||||
|
||||
def selective_expert_weights(
    experts_module: nn.Module,
    param_name: str,
    active_experts: torch.Tensor,
) -> torch.Tensor:
    """Extract and dequantize only the active experts' weights.

    Format-agnostic: dispatches based on whether the parameter is
    BnB 4-bit quantized (via parametrize), FP8, or dense bf16/fp32.

    Args:
        experts_module: The base experts module (e.g. Qwen3_5MoeExperts)
        param_name: "gate_up_proj" or "down_proj"
        active_experts: [num_active] sorted unique expert indices

    Returns:
        Compact weight tensor [num_active, dim1, dim2] ready for ScatterMoE
        (or the full parameter, untouched, when the layout can't be sliced).
    """
    # Check if the parameter is BnB-quantized via parametrize
    if (
        hasattr(experts_module, "parametrizations")
        and param_name in experts_module.parametrizations
    ):
        param_list = experts_module.parametrizations[param_name]
        parametrization = param_list[0]

        # BnB 4-bit parametrization
        if hasattr(parametrization, "quant_state"):
            # The raw quantized data is on the ParametrizationList, not the
            # individual Bnb4bitParametrization module
            raw_param = param_list.original
            qs = parametrization.quant_state
            # qs.shape is the original tensor shape before flattening.
            # For MoE experts it's [E, d1, d2] (3D) or [total_elements] (1D).
            orig_shape = qs.shape
            if isinstance(orig_shape, torch.Size) and len(orig_shape) == 3:
                expert_shape = (orig_shape[1], orig_shape[2])
            elif isinstance(orig_shape, torch.Size) and len(orig_shape) == 1:
                # Flattened — need to infer from module attributes
                E_total = getattr(experts_module, "num_experts", None)
                if E_total is None:
                    # NOTE(review): lower-bound guess from the routed ids;
                    # only exact when the highest-numbered expert is active.
                    E_total = int(active_experts.max().item()) + 1
                expert_numel = orig_shape[0] // E_total
                d2 = getattr(experts_module, "hidden_dim", None) or getattr(
                    experts_module, "intermediate_dim", None
                )
                if d2 and expert_numel % d2 == 0:
                    expert_shape = (expert_numel // d2, d2)
                else:
                    # Can't infer a per-expert shape — fall back to the
                    # parametrized (dequantized) full tensor and index it.
                    full = getattr(experts_module, param_name)
                    return full[active_experts]
            else:
                # Unexpected quant-state shape — same full-tensor fallback.
                full = getattr(experts_module, param_name)
                return full[active_experts]

            return _selective_dequant_bnb4(raw_param, qs, active_experts, expert_shape)

    # Dense parameter (bf16/fp32) — direct indexing
    param = getattr(experts_module, param_name)
    if param.dim() == 3:
        return param[active_experts]

    # Fallback: full access (non-3D layout; caller gets the whole tensor)
    return param
|
||||
|
||||
|
||||
def selective_lora_weights(
    lora_A: torch.Tensor,
    lora_B: torch.Tensor,
    active_experts: torch.Tensor,
    E: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Slice out the LoRA A and B factors belonging to the active experts.

    LoRA layout (scattermoe format):
        A: [r*E, K] — expert e owns rows    [e*r : (e+1)*r]
        B: [N, r*E] — expert e owns columns [e*r : (e+1)*r]

    Returns compact:
        A: [r*num_active, K]
        B: [N, r*num_active]
    """
    R = lora_A.size(0) // E

    # Expand each active expert's base row into its R consecutive indices,
    # then flatten into one gather index.
    base = active_experts.long() * R
    gather_idx = (
        base[:, None] + torch.arange(R, device=lora_A.device)[None, :]
    ).reshape(-1)

    return lora_A[gather_idx], lora_B[:, gather_idx]
|
||||
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Triton kernel for fused selective expert gather + NF4 dequantization.
|
||||
|
||||
Instead of:
|
||||
1. Gather packed uint8 data for active experts (memory copy)
|
||||
2. Gather absmax for active experts (memory copy)
|
||||
3. Call BnB dequantize_4bit CUDA kernel
|
||||
|
||||
This kernel does all three in one pass:
|
||||
- Reads packed NF4 bytes from expert-strided positions
|
||||
- Looks up the NF4 codebook
|
||||
- Multiplies by the per-block absmax
|
||||
- Writes bf16 output directly
|
||||
|
||||
This eliminates the intermediate gather buffer entirely.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# NF4 codebook (16 values, precomputed by BnB).
# These are the normalized float4 reconstruction values: index i (the 4-bit
# nibble) maps to NF4_CODEBOOK[i], which is then scaled by the per-block
# absmax to recover the weight.
NF4_CODEBOOK = [
    -1.0,
    -0.6961928009986877,
    -0.5250730514526367,
    -0.39491748809814453,
    -0.28444138169288635,
    -0.18477343022823334,
    -0.09105003625154495,
    0.0,
    0.07958029955625534,
    0.16093020141124725,
    0.24611230194568634,
    0.33791524171829224,
    0.44070982933044434,
    0.5626170039176941,
    0.7229568362236023,
    1.0,
]
|
||||
|
||||
|
||||
@triton.jit
def _selective_dequant_nf4_kernel(
    # Input: packed NF4 data (flattened, expert-major order)
    packed_ptr,
    # Input: absmax values (flattened, expert-major order)
    absmax_ptr,
    # Input: active expert indices
    active_experts_ptr,
    # Input: NF4 codebook (16 float values)
    codebook_ptr,
    # Output: dequantized bf16 weights [num_active, expert_numel]
    out_ptr,
    stride_out_e,  # stride for expert dim in output
    # Dimensions
    num_active,
    packed_per_expert,  # expert_numel // 2
    blocks_per_expert,  # expert_numel // blocksize
    blocksize: tl.constexpr,
    # Tile size
    BLOCK_SIZE: tl.constexpr,  # elements per thread block (must be multiple of 2)
):
    """
    Each program processes BLOCK_SIZE elements from one expert.

    Grid: (num_active, cdiv(expert_numel, BLOCK_SIZE))

    For each output element:
        1. Compute which byte in packed data contains this element
        2. Extract the 4-bit nibble (high or low)
        3. Look up in NF4 codebook
        4. Scale by absmax for this block
    """
    expert_local_idx = tl.program_id(0)  # which active expert (0..num_active-1)
    block_id = tl.program_id(1)  # which element block

    # Load the global expert index (int64 to avoid overflow in the
    # expert-strided offsets below for large tensors)
    expert_global = tl.load(active_experts_ptr + expert_local_idx).to(tl.int64)

    expert_numel = packed_per_expert * 2  # 2 elements per packed byte
    elem_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = elem_offset < expert_numel

    # Each element is packed as: byte[i//2], low nibble for even i, high for odd i
    byte_idx = elem_offset // 2
    # NOTE: despite the name, is_high marks ODD elements — which occupy the
    # LOW nibble per the BnB packing comment below.
    is_high = (elem_offset % 2) == 1

    # Read packed bytes from the global expert's region
    packed_global_offset = expert_global * packed_per_expert + byte_idx
    packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(
        tl.int32
    )

    # Extract 4-bit nibble
    # BnB packing: high nibble = even element, low nibble = odd element
    nibble = tl.where(is_high, packed_bytes & 0xF, (packed_bytes >> 4) & 0xF)

    # NF4 codebook lookup
    # Load all 16 codebook values (small, fits in registers)
    # Use gather from codebook pointer
    code_val = tl.load(codebook_ptr + nibble, mask=mask, other=0.0)

    # Load absmax for this element's quantization block
    block_idx = elem_offset // blocksize
    absmax_global_offset = expert_global * blocks_per_expert + block_idx
    absmax_val = tl.load(absmax_ptr + absmax_global_offset, mask=mask, other=1.0)

    # Dequantize: value = codebook[nibble] * absmax
    result = code_val * absmax_val

    # Store to output (compact expert-local row, element offset within it)
    out_offset = expert_local_idx * stride_out_e + elem_offset
    tl.store(out_ptr + out_offset, result.to(out_ptr.dtype.element_ty), mask=mask)
|
||||
|
||||
|
||||
def selective_dequant_nf4_triton(
    packed_data: torch.Tensor,
    absmax: torch.Tensor,
    active_experts: torch.Tensor,
    expert_shape: tuple[int, int],
    blocksize: int,
    dtype: torch.dtype = torch.bfloat16,
    codebook: torch.Tensor | None = None,
) -> torch.Tensor:
    """Fused selective gather + NF4 dequantization via Triton kernel.

    Args:
        packed_data: Flattened packed NF4 data [total_packed] or [total_packed, 1]
        absmax: Per-block scaling factors [total_blocks]
        active_experts: Sorted indices of experts to dequantize [num_active]
        expert_shape: (dim1, dim2) per expert
        blocksize: Quantization block size
        dtype: Output dtype (default bf16)
        codebook: NF4 lookup table [16] (uses default NF4 codebook if None)

    Returns:
        Dequantized weights [num_active, dim1, dim2]
    """
    device = packed_data.device
    num_active = active_experts.shape[0]
    dim1, dim2 = expert_shape
    expert_numel = dim1 * dim2
    packed_per_expert = expert_numel // 2
    blocks_per_expert = expert_numel // blocksize

    # The kernel's table lookup needs the codebook on-device as fp32.
    if codebook is None:
        lut = torch.tensor(NF4_CODEBOOK, dtype=torch.float32, device=device)
    else:
        lut = codebook.to(device=device, dtype=torch.float32)

    # Kernel indexes flat, expert-major buffers; absmax is read as fp32.
    packed_flat = packed_data.reshape(-1)
    absmax_flat = absmax.reshape(-1).float()

    out = torch.empty(num_active, expert_numel, dtype=dtype, device=device)

    BLOCK_SIZE = 1024  # elements handled per program instance

    _selective_dequant_nf4_kernel[
        (num_active, triton.cdiv(expert_numel, BLOCK_SIZE))
    ](
        packed_flat,
        absmax_flat,
        active_experts,
        lut,
        out,
        out.stride(0),
        num_active=num_active,
        packed_per_expert=packed_per_expert,
        blocks_per_expert=blocks_per_expert,
        blocksize=blocksize,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return out.reshape(num_active, dim1, dim2)
|
||||
@@ -61,7 +61,16 @@ class KernelsPlugin(BasePlugin):
|
||||
return "axolotl.integrations.kernels.KernelsArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK
|
||||
|
||||
# Prefer text backbone type for VLMs, but fall back to base type
|
||||
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
|
||||
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
|
||||
if (
|
||||
moe_model_type not in SPARSE_MOE_BLOCK
|
||||
and cfg.model_config_type in SPARSE_MOE_BLOCK
|
||||
):
|
||||
moe_model_type = cfg.model_config_type
|
||||
|
||||
if cfg.use_scattermoe:
|
||||
self._register_kernels()
|
||||
|
||||
@@ -505,6 +505,20 @@ class ModelLoader:
|
||||
elif not is_ds_zero3:
|
||||
self.model_kwargs["device_map"] = device_map
|
||||
|
||||
# quantize_moe_experts quantizes expert weights on-the-fly during loading,
|
||||
# so the actual VRAM usage is much less than bf16 estimates.
|
||||
# When device_map is "auto", accelerate's infer_auto_device_map computes
|
||||
# the device map at bf16 size (before quantization), causing it to offload
|
||||
# layers to CPU, which BnB then rejects. Force single-GPU placement to
|
||||
# prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single).
|
||||
if getattr(self.cfg, "quantize_moe_experts", False) and device_map in (
|
||||
"auto",
|
||||
None,
|
||||
):
|
||||
self.model_kwargs["device_map"] = {
|
||||
"": int(os.environ.get("LOCAL_RANK", 0))
|
||||
}
|
||||
|
||||
cur_device = get_device_type()
|
||||
if "mps" in str(cur_device):
|
||||
self.model_kwargs["device_map"] = "mps:0"
|
||||
|
||||
@@ -17,6 +17,8 @@ from transformers import (
|
||||
class PytorchProfilerCallback(TrainerCallback):
|
||||
"""
|
||||
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
|
||||
|
||||
Also runs torch.profiler to produce a Chrome trace for timing analysis.
|
||||
"""
|
||||
|
||||
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
|
||||
@@ -26,9 +28,10 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
if profiler_steps_start == 0:
|
||||
# start recording memory allocations before everything is allocated, because if we start
|
||||
# at the beginning of step 0, we won't have any memory allocations in the traces
|
||||
torch.cuda.memory._record_memory_history(enabled="all")
|
||||
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
|
||||
profiler_steps_start = -1
|
||||
self.profiler_steps_start = profiler_steps_start
|
||||
self._profiler = None
|
||||
|
||||
def on_step_begin(
|
||||
self,
|
||||
@@ -38,7 +41,21 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
**kwargs,
|
||||
):
|
||||
if state.global_step == self.profiler_steps_start:
|
||||
torch.cuda.memory._record_memory_history(enabled="all")
|
||||
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
|
||||
|
||||
# Start torch.profiler on the first profiled step
|
||||
if state.global_step == max(self.profiler_steps_start, 0):
|
||||
profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=True,
|
||||
)
|
||||
profiler.__enter__()
|
||||
self._profiler = profiler
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
@@ -55,6 +72,13 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
# tell CUDA to stop recording memory allocations now
|
||||
torch.cuda.memory._record_memory_history(enabled=None)
|
||||
|
||||
# Stop and export torch.profiler trace
|
||||
if self._profiler is not None:
|
||||
self._profiler.__exit__(None, None, None)
|
||||
trace_path = Path(args.output_dir) / "profiler_trace.json"
|
||||
self._profiler.export_chrome_trace(str(trace_path))
|
||||
self._profiler = None
|
||||
|
||||
def on_train_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
@@ -73,3 +97,9 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
|
||||
# tell CUDA to stop recording memory allocations now
|
||||
torch.cuda.memory._record_memory_history(enabled=None)
|
||||
|
||||
if self._profiler is not None:
|
||||
self._profiler.__exit__(None, None, None)
|
||||
trace_path = Path(args.output_dir) / "profiler_trace.json"
|
||||
self._profiler.export_chrome_trace(str(trace_path))
|
||||
self._profiler = None
|
||||
|
||||
407
tests/integrations/test_scattermoe_lora_kernels.py
Normal file
407
tests/integrations/test_scattermoe_lora_kernels.py
Normal file
@@ -0,0 +1,407 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) Axolotl AI
|
||||
# Licensed under the Apache License, Version 2.0
|
||||
|
||||
"""
|
||||
Unit tests for ScatterMoE LoRA Triton kernels.
|
||||
|
||||
Tests correctness of:
|
||||
- scatter2scatter_lora (forward)
|
||||
- scatter2scatter_lora_dX (backward input gradient)
|
||||
- group_bwd_lora (backward LoRA weight gradients via split dA/dB)
|
||||
- ScatterMoELoRA autograd function (full forward + backward)
|
||||
|
||||
Each kernel is tested against a pure PyTorch per-expert-loop reference
|
||||
implementation at multiple model shapes and LoRA ranks.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
|
||||
lora_ops,
|
||||
ops as base_ops,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
|
||||
flatten_sort_count,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
|
||||
ScatterMoELoRA,
|
||||
)
|
||||
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.bfloat16
|
||||
|
||||
|
||||
def _requires_cuda():
    """Build a pytest skip marker for hosts that lack a CUDA device."""
    has_cuda = torch.cuda.is_available()
    return pytest.mark.skipif(not has_cuda, reason="CUDA not available")


# Applies the CUDA requirement to every test in this module.
pytestmark = _requires_cuda()
|
||||
|
||||
|
||||
# ─── Helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _setup(E, K, N, T, top_k, R, seed=42):
    """Create synthetic expert weights, LoRA matrices, routing, and inputs.

    Returns ``(x, W, lora_A, lora_B, sorted_expert_idxs,
    sorted_scattered_idxs, expert_offsets)`` on ``DEVICE`` in ``DTYPE``.
    """
    torch.manual_seed(seed)
    tokens = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
    experts = torch.randn(E, K, N, device=DEVICE, dtype=DTYPE) * 0.02
    adapter_a = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) * 0.01
    adapter_b = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) * 0.01
    # Route each token to its top_k experts by softmaxed random logits.
    router_logits = torch.randn(T, E, device=DEVICE)
    top_idx = torch.topk(torch.softmax(router_logits, dim=-1), top_k, dim=-1).indices
    sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
        top_idx, E
    )
    return (
        tokens,
        experts,
        adapter_a,
        adapter_b,
        sorted_expert_idxs,
        sorted_scattered_idxs,
        expert_offsets,
    )
|
||||
|
||||
|
||||
def _reference_fwd(x, W, sei, ssi, eo, k, lora_A, lora_B, scaling, E):
    """Per-expert loop reference: Y = X@W + scaling*(X@A^T)@B^T."""
    grouped_x = base_ops.group(x, ssi, fan_out=k)
    total_rows = grouped_x.size(0)
    out_dim = W.size(2)
    rank = lora_A.size(0) // E
    grouped_out = torch.zeros(total_rows, out_dim, device=DEVICE, dtype=DTYPE)
    # eo holds cumulative token counts, so consecutive entries bound each
    # expert's contiguous slice of the grouped input.
    start = 0
    for expert in range(E):
        stop = eo[expert].item()
        if start != stop:
            xe = grouped_x[start:stop].float()
            we = W[expert].float()
            ae = lora_A[expert * rank : (expert + 1) * rank].float()
            be = lora_B[:, expert * rank : (expert + 1) * rank].float()
            grouped_out[start:stop] = (
                xe @ we + scaling * (xe @ ae.T) @ be.T
            ).to(DTYPE)
        start = stop
    # Scatter grouped rows back to their original (token, slot) positions.
    result = torch.zeros(total_rows, out_dim, device=DEVICE, dtype=DTYPE)
    result[ssi] = grouped_out
    return result
|
||||
|
||||
|
||||
def _reference_dX(dy_grouped, W, sei, ssi, eo, lora_A, lora_B, scaling, E):
    """Per-expert loop reference: dX = dY@W^T + scaling*(dY@B)@A."""
    total_rows = dy_grouped.size(0)
    in_dim = W.size(1)
    rank = lora_A.size(0) // E
    grouped_dx = torch.zeros(total_rows, in_dim, device=DEVICE, dtype=DTYPE)
    # Walk the cumulative offsets to visit each expert's row slice once.
    start = 0
    for expert in range(E):
        stop = eo[expert].item()
        if start != stop:
            dye = dy_grouped[start:stop].float()
            we = W[expert].float()
            ae = lora_A[expert * rank : (expert + 1) * rank].float()
            be = lora_B[:, expert * rank : (expert + 1) * rank].float()
            grouped_dx[start:stop] = (
                dye @ we.T + scaling * (dye @ be) @ ae
            ).to(DTYPE)
        start = stop
    # Scatter grouped gradients back to original row positions.
    result = torch.zeros(total_rows, in_dim, device=DEVICE, dtype=DTYPE)
    result[ssi] = grouped_dx
    return result
|
||||
|
||||
|
||||
def _reference_bwd_lora(dy, grouped_x, lora_A, lora_B, eo, E, scaling):
    """Per-expert loop reference: dA, dB for LoRA weight gradients."""
    rank = lora_A.size(0) // E
    grad_a = torch.zeros_like(lora_A)
    grad_b = torch.zeros_like(lora_B)
    start = 0
    for expert in range(E):
        stop = eo[expert].item()
        if start != stop:
            lo, hi = expert * rank, (expert + 1) * rank
            xe = grouped_x[start:stop].float()
            dye = dy[start:stop].float()
            ae = lora_A[lo:hi].float()
            be = lora_B[:, lo:hi].float()
            # dA = scaling * B^T dY^T X ; dB = scaling * dY^T X A^T
            grad_a[lo:hi] = (scaling * (dye @ be).T @ xe).to(DTYPE)
            grad_b[:, lo:hi] = (scaling * dye.T @ (xe @ ae.T)).to(DTYPE)
        start = stop
    return grad_a, grad_b
|
||||
|
||||
|
||||
# ─── Model shape configs ────────────────────────────────────────────────────

# Each tuple is (E, K, N, T, top_k, R, description):
#   E = number of experts, K = input dim, N = output dim,
#   T = token count, top_k = experts routed per token, R = LoRA rank.
# Small shapes keep the test suite fast on any GPU.
CONFIGS_SMALL = [
    (32, 128, 64, 64, 2, 4, "tiny"),
    (64, 256, 128, 128, 4, 8, "small"),
]

# Shapes modeled on real MoE projection matrices (gate/up and down).
CONFIGS_REAL = [
    (256, 2048, 1024, 2048, 8, 16, "qwen35_gate_up"),
    (256, 512, 2048, 2048, 8, 16, "qwen35_down"),
    (64, 2048, 2048, 2048, 8, 16, "olmoe_gate_up"),
    (128, 2048, 1536, 2048, 8, 16, "qwen3_gate_up"),
]

# LoRA scaling factor applied to the low-rank contribution in all tests.
SCALING = 2.0
||||
|
||||
# ─── Forward tests ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestScatter2ScatterLoRAForward:
    """Test scatter2scatter_lora forward kernel vs reference."""

    @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
    def config(self, request):
        return request.param

    def test_matches_reference(self, config):
        # Kernel output must agree elementwise with the per-expert loop.
        E, K, N, T, k, R, desc = config
        x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)

        actual = lora_ops.scatter2scatter_lora(
            X=x,
            W=W,
            sorted_expert_idxs=sei,
            sorted_scattered_idxs=ssi,
            k=k,
            lora_A=lA,
            lora_B=lB,
            scaling=SCALING,
        )
        expected = _reference_fwd(x, W, sei, ssi, eo, k, lA, lB, SCALING, E)

        err = torch.max(torch.abs(actual.float() - expected.float())).item()
        assert err < 1.0, f"[{desc}] fwd max_err={err}"

    def test_output_shape(self, config):
        # One output row per (token, expert) assignment: (T*k, N) in DTYPE.
        E, K, N, T, k, R, desc = config
        x, W, lA, lB, sei, ssi, _ = _setup(E, K, N, T, k, R)

        out = lora_ops.scatter2scatter_lora(
            X=x,
            W=W,
            sorted_expert_idxs=sei,
            sorted_scattered_idxs=ssi,
            k=k,
            lora_A=lA,
            lora_B=lB,
            scaling=SCALING,
        )
        assert out.shape == (T * k, N)
        assert out.dtype == DTYPE
|
||||
|
||||
|
||||
# ─── Backward dX tests ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestScatter2ScatterLoRADX:
    """Test scatter2scatter_lora_dX backward kernel vs reference."""

    @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
    def config(self, request):
        return request.param

    def test_matches_reference(self, config):
        E, K, N, T, k, R, desc = config
        x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
        # Upstream gradient comes in already grouped, hence k=1 below.
        grouped = base_ops.group(x, ssi, fan_out=k)
        dy = torch.randn(grouped.size(0), N, device=DEVICE, dtype=DTYPE)

        actual_dx = lora_ops.scatter2scatter_lora_dX(
            DY=dy,
            W=W,
            sorted_expert_idxs=sei,
            sorted_scattered_idxs=ssi,
            k=1,
            lora_A=lA,
            lora_B=lB,
            scaling=SCALING,
            dy_grouped=True,
            dx_grouped=False,
        )
        expected_dx = _reference_dX(dy, W, sei, ssi, eo, lA, lB, SCALING, E)

        err = torch.max(torch.abs(actual_dx.float() - expected_dx.float())).item()
        assert err < 1.0, f"[{desc}] dX max_err={err}"
|
||||
|
||||
|
||||
# ─── Backward LoRA gradient tests ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGroupBwdLoRA:
    """Test group_bwd_lora (split dA/dB kernel) vs reference."""

    @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
    def config(self, request):
        return request.param

    def test_matches_reference(self, config):
        E, K, N, T, k, R, desc = config
        x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
        grouped = base_ops.group(x, ssi, fan_out=k)
        dy = torch.randn(grouped.size(0), N, device=DEVICE, dtype=DTYPE)

        kern_dA, kern_dB = lora_ops.group_bwd_lora(
            DY=dy,
            X=grouped,
            lora_A=lA,
            lora_B=lB,
            expert_offsets=eo,
            E=E,
            scaling=SCALING,
        )
        ref_dA, ref_dB = _reference_bwd_lora(dy, grouped, lA, lB, eo, E, SCALING)

        # Use norm-relative error: bf16 accumulation order differs between
        # kernel (tiled + different reduction order) and reference (per-expert
        # fp32 loop), so max absolute error can be large on individual elements
        # while the overall tensor is correct.
        def _rel_err(got, want):
            diff = got.float() - want.float()
            return (diff.norm() / (want.float().norm() + 1e-6)).item()

        dA_norm_err = _rel_err(kern_dA, ref_dA)
        dB_norm_err = _rel_err(kern_dB, ref_dB)
        assert dA_norm_err < 0.01, f"[{desc}] dA norm_rel_err={dA_norm_err}"
        assert dB_norm_err < 0.01, f"[{desc}] dB norm_rel_err={dB_norm_err}"

    def test_zero_expert_tokens(self):
        """Experts with zero routed tokens produce zero gradients."""
        E, K, N, R = 8, 64, 32, 4
        torch.manual_seed(42)
        # Route all tokens to expert 0 only, leaving experts 1..E-1 empty.
        T, k = 16, 1
        top_idx = torch.zeros(T, k, dtype=torch.long, device=DEVICE)
        sei, ssi, eo = flatten_sort_count(top_idx, E)
        grouped = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
        dy = torch.randn(T, N, device=DEVICE, dtype=DTYPE)
        lA = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE)
        lB = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE)

        dA, dB = lora_ops.group_bwd_lora(
            DY=dy,
            X=grouped,
            lora_A=lA,
            lora_B=lB,
            expert_offsets=eo,
            E=E,
            scaling=2.0,
        )

        # Experts 1..7 should have zero gradients
        for e in range(1, E):
            lo, hi = e * R, (e + 1) * R
            assert dA[lo:hi].abs().max() == 0, f"Expert {e} dA not zero"
            assert dB[:, lo:hi].abs().max() == 0, (
                f"Expert {e} dB not zero"
            )
|
||||
|
||||
|
||||
# ─── Full autograd tests ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestScatterMoELoRAAutograd:
    """Test full forward + backward through ScatterMoELoRA autograd function."""

    @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL[:2])
    def config(self, request):
        return request.param

    def test_gradients_exist_and_finite(self, config):
        """Backward yields finite, non-trivial grads for x, lora_A, lora_B."""
        E, K, N, T, k, R, desc = config
        x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)

        x = x.requires_grad_(True)
        lA = lA.requires_grad_(True)
        lB = lB.requires_grad_(True)

        out = ScatterMoELoRA.apply(
            x,
            W,
            k,
            sei,
            ssi,
            eo,
            lA,
            lB,
            SCALING,
            None,
            None,
            False,
            False,
            True,
            False,
        )
        out.sum().backward()

        assert x.grad is not None, f"[{desc}] x.grad is None"
        assert lA.grad is not None, f"[{desc}] lA.grad is None"
        assert lB.grad is not None, f"[{desc}] lB.grad is None"
        assert torch.isfinite(x.grad).all(), f"[{desc}] x.grad has non-finite"
        assert torch.isfinite(lA.grad).all(), f"[{desc}] lA.grad has non-finite"
        assert torch.isfinite(lB.grad).all(), f"[{desc}] lB.grad has non-finite"
        assert x.grad.abs().sum() > 0, f"[{desc}] x.grad all zero"
        assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero"

    def test_split_matches_fused(self):
        """Split dispatch (for few large experts) matches fused kernel."""
        # Use a shape where split would be dispatched (large K*N, few E)
        E, K, N, T, k, R = 8, 512, 1024, 128, 2, 16
        x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)

        orig = lora_ops._SPLIT_LORA_FWD_THRESHOLD
        try:
            # Force fused path
            lora_ops._SPLIT_LORA_FWD_THRESHOLD = 10**18
            out_fused = lora_ops.scatter2scatter_lora(
                X=x,
                W=W,
                sorted_expert_idxs=sei,
                sorted_scattered_idxs=ssi,
                k=k,
                lora_A=lA,
                lora_B=lB,
                scaling=SCALING,
            )

            # Force split path
            lora_ops._SPLIT_LORA_FWD_THRESHOLD = 0
            out_split = lora_ops.scatter2scatter_lora(
                X=x,
                W=W,
                sorted_expert_idxs=sei,
                sorted_scattered_idxs=ssi,
                k=k,
                lora_A=lA,
                lora_B=lB,
                scaling=SCALING,
            )
        finally:
            # Restore the module global even if a kernel call raises, so a
            # failure here cannot silently change dispatch for later tests.
            lora_ops._SPLIT_LORA_FWD_THRESHOLD = orig

        norm_err = (
            (out_fused.float() - out_split.float()).norm()
            / (out_fused.float().norm() + 1e-6)
        ).item()
        assert norm_err < 0.01, f"split vs fused norm_err={norm_err}"

    def test_scaling_zero_gives_base_only(self):
        """With scaling=0.0, LoRA contribution vanishes. Output = X@W."""
        E, K, N, T, k, R = 16, 64, 32, 32, 2, 4
        x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)

        out_lora = ScatterMoELoRA.apply(
            x,
            W,
            k,
            sei,
            ssi,
            eo,
            lA,
            lB,
            0.0,
            None,
            None,
            False,
            False,
            True,
            False,
        )
        out_base = base_ops.scatter2scatter(
            X=x,
            W=W,
            sorted_expert_idxs=sei,
            sorted_scattered_idxs=ssi,
            k=k,
        )
        err = (out_lora.float() - out_base.float()).abs().max().item()
        assert err < 0.01, f"scaling=0 should match base: err={err}"
|
||||
Reference in New Issue
Block a user