diff --git a/benchmarks/bench_scattermoe_lora.py b/benchmarks/bench_scattermoe_lora.py new file mode 100644 index 000000000..0fb3ba68c --- /dev/null +++ b/benchmarks/bench_scattermoe_lora.py @@ -0,0 +1,284 @@ +"""Benchmark for ScatterMoE LoRA Triton kernels. + +Measures forward, backward dX, and backward dA/dB kernels at common MoE +model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter, +and full fwd+bwd autograd throughput. + +Usage: + CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py + CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64 + CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B +""" + +import argparse +import gc +import time +from functools import partial + +import torch + +from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import ( + lora_ops, + ops as base_ops, +) +from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import ( + flatten_sort_count, +) +from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import ( + ScatterMoELoRA, +) + +DEVICE = "cuda" +DTYPE = torch.bfloat16 +WARMUP = 5 +ITERS = 20 + +# ─── Model configs ────────────────────────────────────────────────────────── + +BUILTIN_CONFIGS = { + "Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k + "Qwen3-30B-A3B": (128, 2048, 768, 8), + "OLMoE-1B-7B": (64, 2048, 1024, 8), + "Mixtral-8x7B": (8, 4096, 14336, 2), +} + + +def _resolve_config(spec): + """Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs.""" + key = spec.lower().replace("/", "-") + for name, cfg in BUILTIN_CONFIGS.items(): + if key in name.lower() or name.lower() in key: + return name, cfg + + from transformers import AutoConfig + + hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True) + if callable(getattr(hf_cfg, "get_text_config", None)): + tc = hf_cfg.get_text_config() + if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type: + hf_cfg = tc + hidden = hf_cfg.hidden_size + inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size + experts = ( + getattr(hf_cfg, "num_experts", None) + or getattr(hf_cfg, "num_local_experts", None) + or getattr(hf_cfg, "n_routed_experts", None) + ) + top_k = ( + getattr(hf_cfg, "num_experts_per_tok", None) + or getattr(hf_cfg, "num_experts_per_token", None) + or 2 + ) + name = spec.split("/")[-1] + return name, (experts, hidden, inter, top_k) + + +# ─── Benchmark helpers ────────────────────────────────────────────────────── + + +def _clean(): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +def _bench(fn, warmup=WARMUP, iters=ITERS): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + times = [] + for _ in range(iters): + torch.cuda.synchronize() + t0 = time.perf_counter() + fn() + torch.cuda.synchronize() + times.append((time.perf_counter() - t0) * 1000) + times.sort() + return times[len(times) // 2] + + +def _setup(num_experts, K, N, T, top_k, R): + torch.manual_seed(42) + x = torch.randn(T, K, device=DEVICE, dtype=DTYPE) + W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02 + lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01 + lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01 + logits = torch.randn(T, num_experts, device=DEVICE) + _, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1) + sei, ssi, eo = flatten_sort_count(top_idx, num_experts) + gx = base_ops.group(x, ssi, fan_out=top_k) + dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE) + return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy + + +# ─── Kernel wrappers (avoid B023 loop-variable capture) ────────────────────── + + +def _call_fwd(x, W, sei, ssi, top_k, lA, lB): + return lora_ops.scatter2scatter_lora( + X=x, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=top_k, + lora_A=lA, + lora_B=lB, + scaling=2.0, + ) + + +def _call_base(x, W, sei, ssi, top_k): + return base_ops.scatter2scatter( + X=x, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=top_k, + ) + + +def _call_dx(dy, W, sei, ssi, lA, lB): + return lora_ops.scatter2scatter_lora_dX( + DY=dy, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=1, + lora_A=lA, + lora_B=lB, + scaling=2.0, + dy_grouped=True, + dx_grouped=False, + ) + + +def _call_bwd(dy, gx, lA, lB, eo, num_experts): + return lora_ops.group_bwd_lora( + DY=dy, + X=gx, + lora_A=lA, + lora_B=lB, + expert_offsets=eo, + E=num_experts, + scaling=2.0, + ) + + +# ─── Main ──────────────────────────────────────────────────────────────────── + + +def main(): + parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark") + parser.add_argument( + "--models", + "-m", + nargs="+", + help="Model names or HF IDs (default: all builtins)", + ) + parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64]) + parser.add_argument("--seq-len", "-T", type=int, default=2048) + args = parser.parse_args() + + T = args.seq_len + print(f"GPU: {torch.cuda.get_device_name()}") + print(f"T={T}, ranks={args.ranks}\n") + + if args.models: + configs = [_resolve_config(m) for m in args.models] + else: + configs = list(BUILTIN_CONFIGS.items()) + + for model_name, (num_experts, hidden, inter, top_k) in configs: + print(f"{'=' * 70}") + print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}") + print(f"{'=' * 70}") + + for R in args.ranks: + for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]: + _clean() + x, W, lA, lB, sei, ssi, eo, gx, dy = _setup( + num_experts, K, N, T, top_k, R + ) + + # Forward with LoRA (auto-dispatched: fused or split) + dispatch = ( + "split" + if ( + num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS + and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD + ) + else "fused" + ) + t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB)) + t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k)) + t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB)) + t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts)) + + total = t_fwd + t_dx + t_bwd + overhead = t_fwd / t_base - 1 if t_base > 0 else 0 + + print( + f" R={R:>2} {proj:<8} " + f"fwd={t_fwd:>6.2f}ms [{dispatch}] " + f"base={t_base:>6.2f}ms " + f"(+{overhead * 100:.0f}%) " + f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms " + f"total={total:>6.2f}ms" + ) + + # Full autograd fwd+bwd with memory measurement + x_ag = x.clone().requires_grad_(True) + lA_ag = lA.clone().requires_grad_(True) + lB_ag = lB.clone().requires_grad_(True) + + def _run_autograd( + _x=x_ag, + _W=W, + _k=top_k, + _sei=sei, + _ssi=ssi, + _eo=eo, + _lA=lA_ag, + _lB=lB_ag, + ): + out = ScatterMoELoRA.apply( + _x, + _W, + _k, + _sei, + _ssi, + _eo, + _lA, + _lB, + 2.0, + None, + None, + False, + False, + True, + False, + ) + out.sum().backward() + _x.grad = None + _lA.grad = None + _lB.grad = None + + t_full = _bench(_run_autograd) + + _clean() + torch.cuda.reset_peak_memory_stats() + mem_before = torch.cuda.memory_allocated() + _run_autograd() + torch.cuda.synchronize() + mem_peak = torch.cuda.max_memory_allocated() - mem_before + + print( + f" full_fwd_bwd={t_full:>6.2f}ms " + f"peak_delta={mem_peak / 1e6:>6.1f}MB" + ) + + print() + + +if __name__ == "__main__": + main() diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index 49a45cdc6..7a9feaa03 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -40,7 +40,7 @@ "%%capture\n", "# This step can take ~5-10 minutes to install dependencies\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", - "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe\"" + "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\"" ] }, { diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index 771a5adb2..bd92a3630 100644 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else "" print( UNINSTALL_PREFIX - + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"' + + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"' ) diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index 5a3a73d34..220fb4d2b 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh - If you are installing from pip ```bash -pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe" +pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6" ``` ## Usage diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index 808aff662..758c5406c 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -35,7 +35,7 @@ LOG = get_logger(__name__) _CCE_INSTALL_MESSAGE = ( "Please install Axolotl's fork of cut_cross_entropy with transformers support using " - '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"`' + '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`' ) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py index 5d47c2040..16f6da73b 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py @@ -195,6 +195,36 @@ def _estimate_smem_usage( _SMEM_SLACK = 10_000 +def _estimate_register_pressure( + num_warps: int, + *tile_sizes: tuple[int, int], +) -> float: + """Rough estimate of per-thread register footprint from live tile sizes. + + This is a heuristic, NOT an accurate register count. Triton uses tensor + core MMA fragments that pack multiple elements per register, and can spill + to local memory when the hardware limit (255 regs/thread) is exceeded. + + The estimate is used to prune only truly extreme configs that would cause + excessive spilling or compilation failures. The threshold is set high + (``_MAX_REGS_SOFT_LIMIT``) because the heuristic overestimates — it + doesn't account for MMA fragment packing. Configs like M=64,N=64,K=64 + (est ~520) work fine in practice via spilling. + + Returns estimated registers per thread. + """ + # Each thread in a warp holds ~1/32 of the tile elements + tile_regs = sum(r * c for r, c in tile_sizes) / 32 + scalar_overhead = 40 + return tile_regs + scalar_overhead + + +# Soft limit for register pressure pruning. Only prune configs with extreme +# tile products (e.g. M=128,K=256,N=256) that reliably crash on Blackwell. +# Moderate configs (M=64,N=64,K=64, est ~520) work via register spilling. +_MAX_REGS_SOFT_LIMIT = 1024 + + # ============================================================================= # Forward Kernel: scatter2scatter with fused LoRA # ============================================================================= @@ -313,12 +343,11 @@ def _compute_expert_block_lora( B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0 ) # [BLOCK_N, BLOCK_R] - # Cast xa_acc and b to same dtype for tl.dot (required when input is bf16/fp16) - # Both operands must match; cast to float32 (accumulator type) for precision. - b_f32 = b.to(tl.float32) + # tl.dot requires non-float32 inputs (tensor cores); cast back to input dtype + b_inp = b.to(INPUT_DTYPE) # (X @ A^T) @ B^T: [M, R] @ [R, N] -> [M, N] - lora_out = tl.dot(xa_acc, tl.trans(b_f32), allow_tf32=allow_tf32) + lora_out = tl.dot(xa_acc.to(INPUT_DTYPE), tl.trans(b_inp), allow_tf32=allow_tf32) acc += scaling * lora_out return acc @@ -327,20 +356,21 @@ def _compute_expert_block_lora( def _scatter2scatter_lora_configs(): """Generate forward kernel autotune configs. - Search space includes smaller tile sizes and fewer pipeline stages to - support GPUs with limited shared memory (e.g. ~99KB on some GPUs). + Search space includes BLOCK_M to allow trading token-tile size for + larger BLOCK_K/BLOCK_N tiles. On GPUs with ~99KB SMEM, BLOCK_M=128 + forces BLOCK_K=32 and BLOCK_N=32; BLOCK_M=64 allows BLOCK_K=128 + (4× fewer inner-loop iterations). Search space: + BLOCK_M: {32, 64, 128} BLOCK_N: {32, 64, 128, 256} BLOCK_K: {32, 64, 128} num_warps: {4, 8} num_stages: {3, 4, 5} - - BLOCK_M is fixed at 128 (module-level constant, not autotuned in the - scatter2scatter pattern). """ configs = [] - for block_n, block_k, warps, stages in product( + for block_m, block_n, block_k, warps, stages in product( + [32, 64, 128], # BLOCK_M [32, 64, 128, 256], # BLOCK_N [32, 64, 128], # BLOCK_K [4, 8], # num_warps @@ -348,7 +378,7 @@ def _scatter2scatter_lora_configs(): ): configs.append( triton.Config( - {"BLOCK_N": block_n, "BLOCK_K": block_k}, + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k}, num_stages=stages, num_warps=warps, ) @@ -357,7 +387,7 @@ def _scatter2scatter_lora_configs(): def _prune_fwd_configs(configs, named_args, **kwargs): - """Prune forward configs based on SMEM capacity. + """Prune forward configs based on SMEM capacity and register pressure. The forward kernel inner loop loads three tiles per pipeline stage: X[BLOCK_M, BLOCK_K], W[BLOCK_K, BLOCK_N], A[BLOCK_R, BLOCK_K]. @@ -373,23 +403,49 @@ def _prune_fwd_configs(configs, named_args, **kwargs): scored = [] for config in configs: + block_m = config.kwargs["BLOCK_M"] block_n = config.kwargs["BLOCK_N"] block_k = config.kwargs["BLOCK_K"] # Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N - smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_n, block_k) + smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k) # A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop smem_lora_loop = config.num_stages * block_r * block_k * 2 # B tile [BLOCK_N, BLOCK_R] loaded once in epilogue smem_lora_epilogue = block_n * block_r * 2 smem = smem_base + smem_lora_loop + smem_lora_epilogue + + # Register pressure: live tiles are acc[M,N], xa_acc[M,R], + # x[M,K], w[K,N], a[R,K], plus epilogue b[N,R] + est_regs = _estimate_register_pressure( + config.num_warps, + (block_m, block_n), # acc + (block_m, block_r), # xa_acc + (block_m, block_k), # x tile + (block_k, block_n), # w tile + (block_r, block_k), # a tile + (block_n, block_r), # b tile (epilogue) + ) + if est_regs > _MAX_REGS_SOFT_LIMIT: + continue + scored.append((smem, config)) pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK] if pruned: return pruned - # All configs exceed SMEM — return the one with smallest estimated usage - scored.sort(key=lambda x: x[0]) - return [scored[0][1]] + if scored: + # All surviving configs exceed SMEM — return the one with smallest usage + scored.sort(key=lambda x: x[0]) + return [scored[0][1]] + # All configs pruned by register pressure — fall back to smallest tiles + return [ + min( + configs, + key=lambda c: ( + c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_N"] * c.kwargs["BLOCK_K"] + ), + ) + ] @triton.autotune( @@ -531,6 +587,89 @@ def _scatter2scatter_lora( tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :]) +def _scatter2scatter_lora_split( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + k: int, + lora_A: torch.Tensor, + lora_B: torch.Tensor, + scaling: float, + b: Optional[torch.Tensor] = None, + x_grouped: bool = False, + y_grouped: bool = False, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Split base+LoRA forward: 3 scatter2scatter calls, no fused LoRA kernel. + + Faster for models with few large experts (e.g. Mixtral E=8, I=14336) + because the base kernel runs at full speed without LoRA SMEM overhead, + and the LoRA matmuls (R=16) are tiny separate passes. + + Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T) + """ + from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import ( + scatter2scatter, + ) + + E = W.size(0) + R = lora_A.size(0) // E + K = W.size(1) + N = W.size(2) + + # 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes) + output = scatter2scatter( + X=X, + W=W, + b=b, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + k=k, + x_grouped=x_grouped, + y_grouped=y_grouped, + out=out, + ) + + # 2. XA = X @ A^T (tiny: output is [M*k, R]) + # Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter) + W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous() + XA = scatter2scatter( + X=X, + W=W_A, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + k=k, + x_grouped=x_grouped, + y_grouped=True, + ) + + # 3. Y_lora = XA @ B^T (R is tiny, so this is very fast) + # Reshape B: [N, R*E] → [E, R, N] + W_B = lora_B.T.reshape(E, R, N).contiguous() + Y_lora = scatter2scatter( + X=XA, + W=W_B, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + k=1, + x_grouped=True, + y_grouped=y_grouped, + ) + + # 4. Y = Y_base + scaling * Y_lora + output.add_(Y_lora, alpha=scaling) + return output + + +# Threshold for switching from fused to split LoRA forward. +# Split wins when per-expert matmul is large (bandwidth-bound LoRA tile +# loads dominate in the fused kernel's inner loop). +# Empirically: split wins for E<=32 with K*N > 20M (e.g. Mixtral, Phi-MoE). +_SPLIT_LORA_FWD_THRESHOLD = 20_000_000 # per-expert K*N +_SPLIT_LORA_FWD_MAX_EXPERTS = 32 + + def scatter2scatter_lora( X: torch.Tensor, W: torch.Tensor, @@ -546,7 +685,13 @@ def scatter2scatter_lora( out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ - Fused scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e] + Scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e] + + Automatically selects between: + - Fused kernel: single Triton kernel with LoRA in the inner loop. + Best for many small experts (E>=64, small K*N). + - Split dispatch: 3 separate scatter2scatter calls (base + XA + lora). + Best for few large experts (E<=32, large K*N like Mixtral). Args: X: Input [M, K] or [M*k, K] if x_grouped @@ -565,12 +710,30 @@ def scatter2scatter_lora( Returns: Y: Output [M*k, N] """ - assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) - assert sorted_scattered_idxs.size(0) == X.size(0) * k - E = W.size(0) K = W.size(1) N = W.size(2) + + # Dispatch: split for few large experts, fused for many small experts + if E <= _SPLIT_LORA_FWD_MAX_EXPERTS and K * N >= _SPLIT_LORA_FWD_THRESHOLD: + return _scatter2scatter_lora_split( + X, + W, + sorted_expert_idxs, + sorted_scattered_idxs, + k, + lora_A, + lora_B, + scaling, + b, + x_grouped, + y_grouped, + out, + ) + + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * k + R = lora_A.size(0) // E # Pad R to power of 2 for Triton tile size @@ -610,11 +773,9 @@ def scatter2scatter_lora( b_ptr, stride_be, stride_bn, - # A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride lora_A, lora_A.stride(0), lora_A.stride(1), - # B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride lora_B, lora_B.stride(0), lora_B.stride(1), @@ -625,9 +786,8 @@ def scatter2scatter_lora( K=K, N=N, E=E, - ACTUAL_R=R, # True LoRA rank for weight indexing - BLOCK_M=BLOCK_M, - BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16) + ACTUAL_R=R, + BLOCK_R=BLOCK_R, ACC_TYPE=tl.float32, scaling=scaling, allow_tf32=ALLOW_TF32, @@ -761,13 +921,13 @@ def _compute_expert_block_lora_dX( + (A_expert_offset + R_block)[:, None] * stride_ar + K_block[None, :] * stride_ak ) - a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0) - - # Cast to float32 for precision - a_f32 = a_e.to(tl.float32) + a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to( + INPUT_DTYPE + ) # (DY @ B) @ A: [M, R] @ [R, K] -> [M, K] - lora_dx = tl.dot(dy_b_acc, a_f32, allow_tf32=allow_tf32) + # tl.dot requires non-float32 inputs (tensor cores); cast accumulator back to input dtype + lora_dx = tl.dot(dy_b_acc.to(INPUT_DTYPE), a_e, allow_tf32=allow_tf32) acc += scaling * lora_dx return acc @@ -779,17 +939,18 @@ def _scatter2scatter_lora_dX_configs(): The inner loop is over N (not K as in forward). The output dimension is K. So BLOCK_K tiles the output and BLOCK_N tiles the reduction. - Search space includes smaller tile sizes and fewer pipeline stages to - support GPUs with limited shared memory (e.g. ~99KB on some GPUs). + BLOCK_M is now autotunable (was fixed at 128). Search space: + BLOCK_M: {32, 64, 128} (token tile) BLOCK_K: {32, 64, 128, 256} (output tile) BLOCK_N: {32, 64, 128, 256} (reduction tile) num_warps: {4, 8} num_stages: {3, 4, 5} """ configs = [] - for block_k, block_n, warps, stages in product( + for block_m, block_k, block_n, warps, stages in product( + [32, 64, 128], # BLOCK_M [32, 64, 128, 256], # BLOCK_K (output dimension) [32, 64, 128, 256], # BLOCK_N (reduction dimension) [4, 8], # num_warps @@ -797,7 +958,7 @@ def _scatter2scatter_lora_dX_configs(): ): configs.append( triton.Config( - {"BLOCK_K": block_k, "BLOCK_N": block_n}, + {"BLOCK_M": block_m, "BLOCK_K": block_k, "BLOCK_N": block_n}, num_stages=stages, num_warps=warps, ) @@ -806,7 +967,7 @@ def _scatter2scatter_lora_dX_configs(): def _prune_dX_configs(configs, named_args, **kwargs): - """Prune backward dX configs based on SMEM capacity. + """Prune backward dX configs based on SMEM capacity and register pressure. The dX kernel inner loop loads three tiles per pipeline stage: DY[BLOCK_M, BLOCK_N], W^T[BLOCK_N, BLOCK_K], B[BLOCK_N, BLOCK_R]. @@ -822,23 +983,49 @@ def _prune_dX_configs(configs, named_args, **kwargs): scored = [] for config in configs: + block_m = config.kwargs["BLOCK_M"] block_k = config.kwargs["BLOCK_K"] block_n = config.kwargs["BLOCK_N"] # Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K - smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_k, block_n) + smem_base = _estimate_smem_usage(config.num_stages, block_m, block_k, block_n) # B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop smem_lora_loop = config.num_stages * block_n * block_r * 2 # A tile [BLOCK_R, BLOCK_K] loaded once in epilogue smem_lora_epilogue = block_r * block_k * 2 smem = smem_base + smem_lora_loop + smem_lora_epilogue + + # Register pressure: live tiles are acc[M,K], dy_b_acc[M,R], + # dy[M,N], wt[N,K], b[N,R], plus epilogue a[R,K] + est_regs = _estimate_register_pressure( + config.num_warps, + (block_m, block_k), # acc + (block_m, block_r), # dy_b_acc + (block_m, block_n), # dy tile + (block_n, block_k), # wt tile + (block_n, block_r), # b tile + (block_r, block_k), # a tile (epilogue) + ) + if est_regs > _MAX_REGS_SOFT_LIMIT: + continue + scored.append((smem, config)) pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK] if pruned: return pruned - # All configs exceed SMEM — return the one with smallest estimated usage - scored.sort(key=lambda x: x[0]) - return [scored[0][1]] + if scored: + # All surviving configs exceed SMEM — return the one with smallest usage + scored.sort(key=lambda x: x[0]) + return [scored[0][1]] + # All configs pruned by register pressure — fall back to smallest tiles + return [ + min( + configs, + key=lambda c: ( + c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"] + ), + ) + ] @triton.autotune( @@ -1067,7 +1254,7 @@ def scatter2scatter_lora_dX( N=N, E=E, ACTUAL_R=R, - BLOCK_M=BLOCK_M, + # BLOCK_M is autotuned (injected by triton.autotune from Config kwargs) BLOCK_R=BLOCK_R, ACC_TYPE=tl.float32, scaling=scaling, @@ -1119,7 +1306,7 @@ def _group_bwd_lora_configs(): def _prune_bwd_lora_configs(configs, named_args, **kwargs): - """Prune backward configs based on SMEM capacity. + """Prune backward configs based on SMEM capacity and register pressure. The backward kernel loads X[BLOCK_M, BLOCK_K] and DY[BLOCK_M, BLOCK_N] in the inner loop, plus holds A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] @@ -1138,14 +1325,40 @@ def _prune_bwd_lora_configs(configs, named_args, **kwargs): # A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] held for the full expert smem_lora = (block_r * block_k + block_n * block_r) * 2 smem = smem_base + smem_lora + + # Register pressure: dA_acc[R,K], dB_acc[N,R], x[M,K], dy[M,N], + # a[R,K], b[N,R], xa[M,R], dy_b[M,R] + est_regs = _estimate_register_pressure( + config.num_warps, + (block_r, block_k), # dA_acc + (block_n, block_r), # dB_acc + (block_m, block_k), # x tile + (block_m, block_n), # dy tile + (block_r, block_k), # a tile + (block_n, block_r), # b tile + (block_m, block_r), # xa intermediate + ) + if est_regs > _MAX_REGS_SOFT_LIMIT: + continue + scored.append((smem, config)) pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK] if pruned: return pruned - # All configs exceed SMEM — return the one with smallest estimated usage - scored.sort(key=lambda x: x[0]) - return [scored[0][1]] + if scored: + # All surviving configs exceed SMEM — return the one with smallest usage + scored.sort(key=lambda x: x[0]) + return [scored[0][1]] + # All configs pruned by register pressure — fall back to smallest tiles + return [ + min( + configs, + key=lambda c: ( + c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"] + ), + ) + ] @triton.autotune( @@ -1330,6 +1543,279 @@ def _group_bwd_lora( ) +def _group_bwd_split_configs(): + """Autotune configs for split dA/dB kernels.""" + configs = [] + for block_m, block_dim, warps, stages in product( + [32, 64, 128], # BLOCK_M (token tile) + [32, 64, 128, 256], # BLOCK_DIM (K for dA, N for dB — output tile) + [4, 8], # num_warps + [3, 4, 5], # num_stages + ): + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_DIM": block_dim}, + num_stages=stages, + num_warps=warps, + ) + ) + return configs + + +def _prune_split_configs(configs, named_args, **kwargs): + """Prune split kernel configs based on SMEM capacity and register pressure.""" + smem_cap = _get_smem_capacity() + block_r = named_args.get("BLOCK_R", 64) + + # Fixed inner tile for reduction dimension + BLOCK_INNER = 64 + + pruned = [] + for config in configs: + block_m = config.kwargs["BLOCK_M"] + block_dim = config.kwargs["BLOCK_DIM"] + # Inner loop loads: input[M, INNER] and other[M, INNER_or_DIM] + smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2 + # LoRA weights held in registers: [INNER, R] or [R, DIM] + smem += (block_r * max(block_dim, BLOCK_INNER)) * 2 + + # Register pressure check + est_regs = _estimate_register_pressure( + config.num_warps, + (block_r, block_dim), # acc + (block_m, BLOCK_INNER), # input tile + (block_m, block_dim), # other tile + (block_r, BLOCK_INNER), # lora weight + ) + if est_regs > _MAX_REGS_SOFT_LIMIT: + continue + + if smem <= smem_cap - _SMEM_SLACK: + pruned.append(config) + + if pruned: + return pruned + configs.sort(key=lambda c: c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_DIM"]) + return [configs[0]] + + +@triton.autotune( + configs=_group_bwd_split_configs(), + key=["M", "K", "N"], + prune_configs_by={"early_config_prune": _prune_split_configs}, +) +@triton.heuristics( + { + "NO_DIM_MASK": lambda args: ( + (args["K"] % args["BLOCK_DIM"]) == 0 + if args["COMPUTE_DA"] + else (args["N"] % args["BLOCK_DIM"]) == 0 + ), + } +) +@triton.jit +def _group_bwd_lora_split( + # Data tensors (DY and X are always present) + DY_ptr, + stride_dym, + stride_dyn, + X_ptr, + stride_xm, + stride_xk, + # LoRA weight for the inner reduction (B for dA, A for dB) + LW_ptr, + stride_lw0, + stride_lw1, + # Output gradient tensor (dA or dB) + OUT_ptr, + stride_out0, + stride_out1, + # Expert offsets + expert_offsets_ptr, + # Dimensions + M, + K: tl.constexpr, + N: tl.constexpr, + ACTUAL_R: tl.constexpr, + BLOCK_R: tl.constexpr, + INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB) + scaling, + # Mode flag + COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB + # Tile sizes + BLOCK_M: tl.constexpr, + BLOCK_DIM: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, + NO_DIM_MASK: tl.constexpr, +): + """ + Unified split kernel for LoRA gradient computation. + + When COMPUTE_DA=True: + dA[e] = scaling * (dY @ B[e])^T @ X → [R, K] + Grid: (E, cdiv(K, BLOCK_DIM)) + - outer_ptr/stride = X (read [M, K_block]) + - inner reduction over N using DY and B + - output shape [BLOCK_R, BLOCK_DIM] + + When COMPUTE_DA=False: + dB[e] = scaling * dY^T @ (X @ A[e]^T) → [N, R] + Grid: (E, cdiv(N, BLOCK_DIM)) + - outer_ptr/stride = DY (read [M, N_block]) + - inner reduction over K using X and A + - output shape [BLOCK_DIM, BLOCK_R] + + No atomic adds — each (E, dim_block) pair is written by exactly one block. + """ + E_idx = tl.program_id(0) + dim_block_id = tl.program_id(1) + + if E_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + num_tokens = end_idx - start_idx + + # Output dimension tile (K for dA, N for dB) + if COMPUTE_DA: + OUT_DIM: tl.constexpr = K # type: ignore[no-redef] + else: + OUT_DIM: tl.constexpr = N # type: ignore[no-redef] + dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM) + dim_mask = dim_block < OUT_DIM + R_block = tl.arange(0, BLOCK_R) + R_mask = R_block < ACTUAL_R + lora_offset = E_idx * ACTUAL_R + + # Output pointers — layout differs: dA is [R, K], dB is [N, R] + if COMPUTE_DA: + out_blk_ptrs = ( + OUT_ptr + + (lora_offset + R_block)[:, None] * stride_out0 + + dim_block[None, :] * stride_out1 + ) + out_mask = R_mask[:, None] & dim_mask[None, :] + else: + out_blk_ptrs = ( + OUT_ptr + + dim_block[:, None] * stride_out0 + + (lora_offset + R_block)[None, :] * stride_out1 + ) + out_mask = dim_mask[:, None] & R_mask[None, :] + + if num_tokens > 0: + M_block = tl.arange(0, BLOCK_M) + INPUT_DTYPE = X_ptr.dtype.element_ty + BLOCK_INNER: tl.constexpr = 64 + inner_iters = tl.cdiv(INNER_DIM, BLOCK_INNER) + + if COMPUTE_DA: + acc = tl.zeros((BLOCK_R, BLOCK_DIM), dtype=ACC_TYPE) + else: + acc = tl.zeros((BLOCK_DIM, BLOCK_R), dtype=ACC_TYPE) + + M_iters = tl.cdiv(num_tokens, BLOCK_M) + for i in range(M_iters): + M_idx = start_idx + i * BLOCK_M + M_block + M_mask = M_idx < end_idx + + if COMPUTE_DA: + # Load X[M, K_block] (the "outer" tensor for dA) + outer = tl.load( + X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk, + mask=M_mask[:, None] & dim_mask[None, :], + other=0.0, + ).to(INPUT_DTYPE) + + # Reduce DY[M, :] @ B[e][:, R] over N → [M, R] + reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE) + inner_range = tl.arange(0, BLOCK_INNER) + for j in range(inner_iters): + inn_off = j * BLOCK_INNER + inner_range + inn_mask = inn_off < N + + dy_tile = tl.load( + DY_ptr + + M_idx[:, None] * stride_dym + + inn_off[None, :] * stride_dyn, + mask=M_mask[:, None] & inn_mask[None, :], + other=0.0, + ).to(INPUT_DTYPE) + # B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride + lw_tile = tl.load( + LW_ptr + + inn_off[:, None] * stride_lw0 + + (lora_offset + R_block)[None, :] * stride_lw1, + mask=inn_mask[:, None] & R_mask[None, :], + other=0.0, + ).to(INPUT_DTYPE) + reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32) + + # dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block] + acc += tl.dot( + tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32 + ) + else: + # Load DY[M, N_block] (the "outer" tensor for dB) + outer = tl.load( + DY_ptr + + M_idx[:, None] * stride_dym + + dim_block[None, :] * stride_dyn, + mask=M_mask[:, None] & dim_mask[None, :], + other=0.0, + ).to(INPUT_DTYPE) + + # Reduce X[M, :] @ A[e][:, :].T over K → [M, R] + reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE) + inner_range = tl.arange(0, BLOCK_INNER) + for j in range(inner_iters): + inn_off = j * BLOCK_INNER + inner_range + inn_mask = inn_off < K + + x_tile = tl.load( + X_ptr + + M_idx[:, None] * stride_xm + + inn_off[None, :] * stride_xk, + mask=M_mask[:, None] & inn_mask[None, :], + other=0.0, + ).to(INPUT_DTYPE) + # A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride + # We want A[e]^T: [K, R], so load as [K_inner, R] + lw_tile = tl.load( + LW_ptr + + (lora_offset + R_block)[None, :] * stride_lw0 + + inn_off[:, None] * stride_lw1, + mask=inn_mask[:, None] & R_mask[None, :], + other=0.0, + ).to(INPUT_DTYPE) + reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32) + + # dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R] + acc += tl.dot( + tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32 + ) + + tl.store( + out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask + ) + else: + # Zero out this expert's slice — needed because output uses empty_like + if COMPUTE_DA: + tl.store( + out_blk_ptrs, + tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty), + mask=out_mask, + ) + else: + tl.store( + out_blk_ptrs, + tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty), + mask=out_mask, + ) + + def group_bwd_lora( DY: torch.Tensor, X: torch.Tensor, @@ -1344,6 +1830,9 @@ def group_bwd_lora( """ Compute LoRA gradients for A and B on expert-grouped data. + Uses split dA/dB kernels that eliminate atomic adds by giving each + (expert, output_block) pair its own thread block. + Args: DY: Gradient w.r.t. output [M_total, N] (grouped by expert) X: Input [M_total, K] (grouped by expert) @@ -1361,19 +1850,46 @@ def group_bwd_lora( K = X.size(1) N = DY.size(1) - # Zero-init for atomic accumulation - dA = torch.zeros_like(lora_A) - dB = torch.zeros_like(lora_B) + # No zero-init needed: the split kernels write zeros for experts with + # zero routed tokens directly in the kernel (else branch). + dA = torch.empty_like(lora_A) + dB = torch.empty_like(lora_B) BLOCK_R = _block_r_for_rank(R) - def grid(META): - return ( - E * triton.cdiv(K, META["BLOCK_K"]), - triton.cdiv(N, META["BLOCK_N"]), - ) + def grid_dA(META): + return (E, triton.cdiv(K, META["BLOCK_DIM"])) - _group_bwd_lora[grid]( + _group_bwd_lora_split[grid_dA]( + DY, + DY.stride(0), + DY.stride(1), + X, + X.stride(0), + X.stride(1), + lora_B, + lora_B.stride(0), + lora_B.stride(1), + dA, + dA.stride(0), + dA.stride(1), + expert_offsets, + M=DY.size(0), + K=K, + N=N, + ACTUAL_R=R, + BLOCK_R=BLOCK_R, + INNER_DIM=N, + scaling=scaling, + COMPUTE_DA=True, + ACC_TYPE=tl.float32, + allow_tf32=ALLOW_TF32, + ) + + def grid_dB(META): + return (E, triton.cdiv(N, META["BLOCK_DIM"])) + + _group_bwd_lora_split[grid_dB]( DY, DY.stride(0), DY.stride(1), @@ -1383,12 +1899,6 @@ def group_bwd_lora( lora_A, lora_A.stride(0), lora_A.stride(1), - lora_B, - lora_B.stride(0), - lora_B.stride(1), - dA, - dA.stride(0), - dA.stride(1), dB, dB.stride(0), dB.stride(1), @@ -1396,9 +1906,11 @@ def group_bwd_lora( M=DY.size(0), K=K, N=N, - ACTUAL_R=R, # True LoRA rank - BLOCK_R=BLOCK_R, # Padded tile size + ACTUAL_R=R, + BLOCK_R=BLOCK_R, + INNER_DIM=K, scaling=scaling, + COMPUTE_DA=False, ACC_TYPE=tl.float32, allow_tf32=ALLOW_TF32, ) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py index 5125e8801..c6c01e255 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py @@ -489,20 +489,71 @@ class HFScatterMoEGatedMLP(nn.Module): # ==================================================================== experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts) + # ==================================================================== + # Selective expert weight dequantization + # ==================================================================== + # When experts are BnB-quantized (quantize_moe_experts), dequantize + # only the active experts instead of all E. This saves ~97% memory + # for the transient dequant buffer when few experts are active. + use_selective = ( + getattr(self, "_use_selective_dequant", False) + and hasattr(experts, "parametrizations") + and "gate_up_proj" in experts.parametrizations + ) + + if use_selective: + from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import ( + get_active_experts, + remap_expert_indices, + selective_expert_weights, + selective_lora_weights, + ) + + active_experts = get_active_experts(sorted_expert_idxs, num_experts) + remapped_expert_idxs, compact_offsets = remap_expert_indices( + sorted_expert_idxs, + expert_offsets, + active_experts, + num_experts, + ) + # Dequantize only active experts' weights + gate_up_W = selective_expert_weights( + experts, + "gate_up_proj", + active_experts, + ).transpose(2, 1) # [num_active, hidden, 2*inter] + + # Remap LoRA weights to match compact expert indices + if gup_lora is not None: + gup_A, gup_B, gup_scaling = gup_lora + gup_A, gup_B = selective_lora_weights( + gup_A, + gup_B, + active_experts, + num_experts, + ) + gup_lora = (gup_A, gup_B, gup_scaling) + + # Use remapped indices for ScatterMoE kernels + sei_gup = remapped_expert_idxs + eo_gup = compact_offsets + else: + gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter] + sei_gup = sorted_expert_idxs + eo_gup = expert_offsets + # ==================================================================== # Gate + Up projection # ==================================================================== - gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter] - if gup_lora is not None: gup_A, gup_B, gup_scaling = gup_lora gup = parallel_linear_lora( hidden_states_flat, gate_up_W, top_k, - sorted_expert_idxs, + sei_gup, sorted_scattered_idxs, - expert_offsets, + eo_gup, lora_A=gup_A, lora_B=gup_B, scaling=gup_scaling, @@ -516,9 +567,9 @@ class HFScatterMoEGatedMLP(nn.Module): hidden_states_flat, gate_up_W, top_k, - sorted_expert_idxs, + sei_gup, sorted_scattered_idxs, - expert_offsets, + eo_gup, grouped_in=False, grouped_out=True, ) @@ -529,7 +580,29 @@ class HFScatterMoEGatedMLP(nn.Module): # ==================================================================== # Down projection # ==================================================================== - down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden] + if use_selective: + down_W = selective_expert_weights( + experts, + "down_proj", + active_experts, + ).transpose(2, 1) # [num_active, inter, hidden] + + if down_lora is not None: + down_A, down_B, down_scaling = down_lora + down_A, down_B = selective_lora_weights( + down_A, + down_B, + active_experts, + num_experts, + ) + down_lora = (down_A, down_B, down_scaling) + + sei_down = remapped_expert_idxs + eo_down = compact_offsets + else: + down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden] + sei_down = sorted_expert_idxs + eo_down = expert_offsets if down_lora is not None: down_A, down_B, down_scaling = down_lora @@ -537,9 +610,9 @@ class HFScatterMoEGatedMLP(nn.Module): h, down_W, 1, - sorted_expert_idxs, + sei_down, sorted_scattered_idxs, - expert_offsets, + eo_down, lora_A=down_A, lora_B=down_B, scaling=down_scaling, @@ -554,9 +627,9 @@ class HFScatterMoEGatedMLP(nn.Module): h, down_W, 1, - sorted_expert_idxs, + sei_down, sorted_scattered_idxs, - expert_offsets, + eo_down, grouped_in=True, grouped_out=False, gates=routing_weights, diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant.py new file mode 100644 index 000000000..1df8b2f68 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant.py @@ -0,0 +1,282 @@ +""" +Selective Expert Dequantization +=============================== + +Instead of dequantizing all E expert weight matrices at once (which creates +a ~1 GB transient buffer for 256 experts), only dequantize the experts that +are actually routed to by the current batch's top-k selection. + +For Qwen3.5-35B-A3B (E=256, top_k=8, hidden=2048, intermediate=512): + - Full dequant: [256, 2048, 1024] = 1,074 MB per projection + - Selective (8 active): [8, 2048, 1024] = 33.5 MB per projection + - Savings: ~97% memory reduction per layer + +This module provides format-agnostic selective weight extraction: + - BnB 4-bit (nf4/fp4): slice quantized data + absmax per expert + - bf16/fp32: direct indexing (no dequant needed) + - FP8: slice + cast + +The ScatterMoE kernel itself doesn't change — we remap expert indices +from global (0..E-1) to compact (0..num_active-1) and pass the smaller +weight tensor. +""" + +import torch +import torch.nn as nn + + +def get_active_experts(sorted_expert_idxs: torch.Tensor, E: int) -> torch.Tensor: + """Get sorted unique expert indices from the routing output. + + Args: + sorted_expert_idxs: Expert assignments sorted by expert id [T*k] + E: Total number of experts + + Returns: + active: Sorted unique expert indices [num_active] + """ + return torch.unique(sorted_expert_idxs) + + +def remap_expert_indices( + sorted_expert_idxs: torch.Tensor, + expert_offsets: torch.Tensor, + active_experts: torch.Tensor, + E: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Remap global expert indices to compact indices. + + Maps expert ids from [0..E-1] to [0..num_active-1], preserving the + sort order. Also compacts expert_offsets to only active experts. + + Args: + sorted_expert_idxs: [T*k] expert ids in sorted order + expert_offsets: [E] cumulative token counts (original) + active_experts: [num_active] sorted unique expert ids + E: Total number of experts + + Returns: + remapped_idxs: [T*k] expert ids in [0..num_active-1] + compact_offsets: [num_active] cumulative token counts + """ + # Build remap table: global_id -> compact_id + remap = torch.empty(E, dtype=torch.long, device=sorted_expert_idxs.device) + remap[active_experts] = torch.arange( + len(active_experts), device=sorted_expert_idxs.device + ) + + remapped_idxs = remap[sorted_expert_idxs] + + # Compact the expert_offsets: only keep active experts' cumulative counts + compact_offsets = expert_offsets[active_experts] + + return remapped_idxs, compact_offsets + + +def _selective_dequant_bnb4( + raw_param: torch.Tensor, + quant_state, + active_experts: torch.Tensor, + expert_shape: tuple[int, int], +) -> torch.Tensor: + """Dequantize only selected experts from BnB 4-bit packed data. + + The raw parameter is a flattened 4-bit packed tensor. Each expert's + data is contiguous (stored in expert-major order), so we can gather + the packed data and absmax blocks for active experts, then dequantize + as one contiguous block. + + Args: + raw_param: Flattened uint8 tensor of packed 4-bit weights + quant_state: BnB QuantState with absmax, blocksize, code, etc. + active_experts: [num_active] expert indices to dequantize + expert_shape: (dim1, dim2) shape per expert (e.g. (1024, 2048)) + + Returns: + Dequantized weights [num_active, dim1, dim2] in original dtype + """ + import bitsandbytes.functional as F # noqa: N812 + from bitsandbytes.functional import QuantState + + expert_numel = expert_shape[0] * expert_shape[1] + packed_per_expert = expert_numel // 2 # 4-bit = 2 values per byte + blocks_per_expert = expert_numel // quant_state.blocksize + num_active = len(active_experts) + + if blocks_per_expert == 0: + # Expert is smaller than one quantization block — blocks span across + # expert boundaries, so per-expert slicing isn't possible. + # Fallback: full dequantize + index. + full = F.dequantize_4bit(raw_param, quant_state) + E_total = full.numel() // expert_numel + return full.reshape(E_total, *expert_shape)[active_experts] + + # Use fused Triton kernel for NF4 (handles selective gather + dequant in one pass) + if quant_state.quant_type == "nf4" and raw_param.dtype == torch.uint8: + from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import ( + selective_dequant_nf4_triton, + ) + + # Handle nested (double) quantization: dequantize absmax first + # BnB uses dequantize_blockwise (not _4bit) for nested absmax + offset + if quant_state.nested: + absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2) + absmax += quant_state.offset + if absmax.dtype != torch.float32: + absmax = absmax.float() + else: + absmax = quant_state.absmax + + return selective_dequant_nf4_triton( + packed_data=raw_param, + absmax=absmax, + active_experts=active_experts, + expert_shape=expert_shape, + blocksize=quant_state.blocksize, + dtype=quant_state.dtype, + codebook=quant_state.code, + ) + + # Fallback: gather + BnB dequant (for fp4 or non-uint8 packed formats) + raw_flat = raw_param.reshape(-1) + + offsets_qt = ( + active_experts.long()[:, None] * packed_per_expert + + torch.arange(packed_per_expert, device=raw_param.device)[None, :] + ).reshape(-1) + qt_gathered = raw_flat[offsets_qt] + + offsets_abs = ( + active_experts.long()[:, None] * blocks_per_expert + + torch.arange(blocks_per_expert, device=raw_param.device)[None, :] + ).reshape(-1) + + if quant_state.nested: + full_absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2) + full_absmax += quant_state.offset + if full_absmax.dtype != torch.float32: + full_absmax = full_absmax.float() + absmax_gathered = full_absmax[offsets_abs] + else: + absmax_gathered = quant_state.absmax[offsets_abs] + + qt_gathered = qt_gathered.unsqueeze(1) if qt_gathered.dim() == 1 else qt_gathered + + gathered_qs = QuantState( + absmax=absmax_gathered, + shape=torch.Size([num_active * expert_numel]), + blocksize=quant_state.blocksize, + quant_type=quant_state.quant_type, + code=quant_state.code, + dtype=quant_state.dtype, + ) + + deq = F.dequantize_4bit(qt_gathered, gathered_qs) + return deq.reshape(num_active, *expert_shape) + + +def _selective_index_dense( + param: torch.Tensor, + active_experts: torch.Tensor, +) -> torch.Tensor: + """Select experts from a dense (bf16/fp32) weight tensor. + + Simple indexing — no dequantization needed. + """ + return param[active_experts] + + +def selective_expert_weights( + experts_module: nn.Module, + param_name: str, + active_experts: torch.Tensor, +) -> torch.Tensor: + """Extract and dequantize only the active experts' weights. + + Format-agnostic: dispatches based on whether the parameter is + BnB 4-bit quantized (via parametrize), FP8, or dense bf16/fp32. + + Args: + experts_module: The base experts module (e.g. Qwen3_5MoeExperts) + param_name: "gate_up_proj" or "down_proj" + active_experts: [num_active] sorted unique expert indices + + Returns: + Compact weight tensor [num_active, dim1, dim2] ready for ScatterMoE + """ + # Check if the parameter is BnB-quantized via parametrize + if ( + hasattr(experts_module, "parametrizations") + and param_name in experts_module.parametrizations + ): + param_list = experts_module.parametrizations[param_name] + parametrization = param_list[0] + + # BnB 4-bit parametrization + if hasattr(parametrization, "quant_state"): + # The raw quantized data is on the ParametrizationList, not the + # individual Bnb4bitParametrization module + raw_param = param_list.original + qs = parametrization.quant_state + # qs.shape is the original tensor shape before flattening. + # For MoE experts it's [E, d1, d2] (3D) or [total_elements] (1D). + orig_shape = qs.shape + if isinstance(orig_shape, torch.Size) and len(orig_shape) == 3: + expert_shape = (orig_shape[1], orig_shape[2]) + elif isinstance(orig_shape, torch.Size) and len(orig_shape) == 1: + # Flattened — need to infer from module attributes + E_total = getattr(experts_module, "num_experts", None) + if E_total is None: + E_total = int(active_experts.max().item()) + 1 + expert_numel = orig_shape[0] // E_total + d2 = getattr(experts_module, "hidden_dim", None) or getattr( + experts_module, "intermediate_dim", None + ) + if d2 and expert_numel % d2 == 0: + expert_shape = (expert_numel // d2, d2) + else: + full = getattr(experts_module, param_name) + return full[active_experts] + else: + full = getattr(experts_module, param_name) + return full[active_experts] + + return _selective_dequant_bnb4(raw_param, qs, active_experts, expert_shape) + + # Dense parameter (bf16/fp32) — direct indexing + param = getattr(experts_module, param_name) + if param.dim() == 3: + return param[active_experts] + + # Fallback: full access + return param + + +def selective_lora_weights( + lora_A: torch.Tensor, + lora_B: torch.Tensor, + active_experts: torch.Tensor, + E: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Select LoRA A and B weights for only the active experts. + + LoRA layout (scattermoe format): + A: [r*E, K] — expert e occupies rows [e*r : (e+1)*r] + B: [N, r*E] — expert e occupies cols [e*r : (e+1)*r] + + Returns compact: + A: [r*num_active, K] + B: [N, r*num_active] + """ + R = lora_A.size(0) // E + + # Vectorized gather: active_experts[:, None] * R + arange(R)[None, :] + row_idx = ( + active_experts.long()[:, None] * R + + torch.arange(R, device=lora_A.device)[None, :] + ).reshape(-1) + + compact_A = lora_A[row_idx] # [r*num_active, K] + compact_B = lora_B[:, row_idx] # [N, r*num_active] + + return compact_A, compact_B diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant_kernel.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant_kernel.py new file mode 100644 index 000000000..aa9f0278a --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant_kernel.py @@ -0,0 +1,179 @@ +""" +Triton kernel for fused selective expert gather + NF4 dequantization. + +Instead of: + 1. Gather packed uint8 data for active experts (memory copy) + 2. Gather absmax for active experts (memory copy) + 3. Call BnB dequantize_4bit CUDA kernel + +This kernel does all three in one pass: + - Reads packed NF4 bytes from expert-strided positions + - Looks up the NF4 codebook + - Multiplies by the per-block absmax + - Writes bf16 output directly + +This eliminates the intermediate gather buffer entirely. +""" + +import torch +import triton +import triton.language as tl + +# NF4 codebook (16 values, precomputed by BnB) +# These are the normalized float4 reconstruction values +NF4_CODEBOOK = [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, +] + + +@triton.jit +def _selective_dequant_nf4_kernel( + # Input: packed NF4 data (flattened, expert-major order) + packed_ptr, + # Input: absmax values (flattened, expert-major order) + absmax_ptr, + # Input: active expert indices + active_experts_ptr, + # Input: NF4 codebook (16 float values) + codebook_ptr, + # Output: dequantized bf16 weights [num_active, expert_numel] + out_ptr, + stride_out_e, # stride for expert dim in output + # Dimensions + num_active, + packed_per_expert, # expert_numel // 2 + blocks_per_expert, # expert_numel // blocksize + blocksize: tl.constexpr, + # Tile size + BLOCK_SIZE: tl.constexpr, # elements per thread block (must be multiple of 2) +): + """ + Each program processes BLOCK_SIZE elements from one expert. + + Grid: (num_active, cdiv(expert_numel, BLOCK_SIZE)) + + For each output element: + 1. Compute which byte in packed data contains this element + 2. Extract the 4-bit nibble (high or low) + 3. Look up in NF4 codebook + 4. Scale by absmax for this block + """ + expert_local_idx = tl.program_id(0) # which active expert (0..num_active-1) + block_id = tl.program_id(1) # which element block + + # Load the global expert index + expert_global = tl.load(active_experts_ptr + expert_local_idx).to(tl.int64) + + expert_numel = packed_per_expert * 2 # 2 elements per packed byte + elem_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = elem_offset < expert_numel + + # Each element is packed as: byte[i//2], low nibble for even i, high for odd i + byte_idx = elem_offset // 2 + is_high = (elem_offset % 2) == 1 + + # Read packed bytes from the global expert's region + packed_global_offset = expert_global * packed_per_expert + byte_idx + packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to( + tl.int32 + ) + + # Extract 4-bit nibble + # BnB packing: high nibble = even element, low nibble = odd element + nibble = tl.where(is_high, packed_bytes & 0xF, (packed_bytes >> 4) & 0xF) + + # NF4 codebook lookup + # Load all 16 codebook values (small, fits in registers) + # Use gather from codebook pointer + code_val = tl.load(codebook_ptr + nibble, mask=mask, other=0.0) + + # Load absmax for this element's quantization block + block_idx = elem_offset // blocksize + absmax_global_offset = expert_global * blocks_per_expert + block_idx + absmax_val = tl.load(absmax_ptr + absmax_global_offset, mask=mask, other=1.0) + + # Dequantize: value = codebook[nibble] * absmax + result = code_val * absmax_val + + # Store to output + out_offset = expert_local_idx * stride_out_e + elem_offset + tl.store(out_ptr + out_offset, result.to(out_ptr.dtype.element_ty), mask=mask) + + +def selective_dequant_nf4_triton( + packed_data: torch.Tensor, + absmax: torch.Tensor, + active_experts: torch.Tensor, + expert_shape: tuple[int, int], + blocksize: int, + dtype: torch.dtype = torch.bfloat16, + codebook: torch.Tensor | None = None, +) -> torch.Tensor: + """Fused selective gather + NF4 dequantization via Triton kernel. + + Args: + packed_data: Flattened packed NF4 data [total_packed] or [total_packed, 1] + absmax: Per-block scaling factors [total_blocks] + active_experts: Sorted indices of experts to dequantize [num_active] + expert_shape: (dim1, dim2) per expert + blocksize: Quantization block size + dtype: Output dtype (default bf16) + codebook: NF4 lookup table [16] (uses default NF4 codebook if None) + + Returns: + Dequantized weights [num_active, dim1, dim2] + """ + num_active = active_experts.shape[0] + expert_numel = expert_shape[0] * expert_shape[1] + packed_per_expert = expert_numel // 2 + blocks_per_expert = expert_numel // blocksize + + # Prepare codebook on device + if codebook is None: + codebook = torch.tensor( + NF4_CODEBOOK, dtype=torch.float32, device=packed_data.device + ) + else: + codebook = codebook.to(device=packed_data.device, dtype=torch.float32) + + # Flatten inputs + packed_flat = packed_data.reshape(-1) + absmax_flat = absmax.reshape(-1).float() # absmax is usually fp32 + + # Output buffer + out = torch.empty(num_active, expert_numel, dtype=dtype, device=packed_data.device) + + BLOCK_SIZE = 1024 # Process 1024 elements per thread block + + grid = (num_active, triton.cdiv(expert_numel, BLOCK_SIZE)) + + _selective_dequant_nf4_kernel[grid]( + packed_flat, + absmax_flat, + active_experts, + codebook, + out, + out.stride(0), + num_active=num_active, + packed_per_expert=packed_per_expert, + blocks_per_expert=blocks_per_expert, + blocksize=blocksize, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return out.reshape(num_active, *expert_shape) diff --git a/src/axolotl/integrations/kernels/plugin.py b/src/axolotl/integrations/kernels/plugin.py index 351db5ef2..939bdb790 100644 --- a/src/axolotl/integrations/kernels/plugin.py +++ b/src/axolotl/integrations/kernels/plugin.py @@ -61,7 +61,16 @@ class KernelsPlugin(BasePlugin): return "axolotl.integrations.kernels.KernelsArgs" def pre_model_load(self, cfg): + from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK + + # Prefer text backbone type for VLMs, but fall back to base type + # when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text) moe_model_type = cfg.model_config_type_text or cfg.model_config_type + if ( + moe_model_type not in SPARSE_MOE_BLOCK + and cfg.model_config_type in SPARSE_MOE_BLOCK + ): + moe_model_type = cfg.model_config_type if cfg.use_scattermoe: self._register_kernels() diff --git a/src/axolotl/kernels/lora.py b/src/axolotl/kernels/lora.py index 28ef75acc..9dc66a918 100644 --- a/src/axolotl/kernels/lora.py +++ b/src/axolotl/kernels/lora.py @@ -640,7 +640,9 @@ class LoRA_QKV(torch.autograd.Function): del q_weight del q_weight_t if A_q is not None and B_q is not None: - grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled)) + # Stay decomposed: dQ @ B^T gives [T, R], then [T, R] @ (s*A) gives [T, in] + # This is 65x fewer FLOPs than materializing B@A into [out, in] + grad_X.addmm_(torch.mm(q_grad, B_q_scaled), A_q_scaled) # K path k_weight_t = dequantize(k_weight, k_quant) @@ -648,7 +650,7 @@ class LoRA_QKV(torch.autograd.Function): del k_weight del k_weight_t if A_k is not None and B_k is not None: - grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled)) + grad_X.addmm_(torch.mm(k_grad, B_k_scaled), A_k_scaled) # V path v_weight_t = dequantize(v_weight, v_quant) @@ -656,7 +658,7 @@ class LoRA_QKV(torch.autograd.Function): del v_weight del v_weight_t if A_v is not None and B_v is not None: - grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled)) + grad_X.addmm_(torch.mm(v_grad, B_v_scaled), A_v_scaled) # Transpose gradients if needed if d_A_q is not None: @@ -819,7 +821,8 @@ class LoRA_O(torch.autograd.Function): del W A, B = A.to(dtype), B.to(dtype) - dX += s * dY @ B @ A + # Stay decomposed: dY @ B gives [T, R], then [T, R] @ A gives [T, in] + dX.addmm_(torch.mm(dY, B), A, alpha=s) # W, b, W_quant, A, B, s return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 37c112337..dd3f4ddfa 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -505,6 +505,20 @@ class ModelLoader: elif not is_ds_zero3: self.model_kwargs["device_map"] = device_map + # quantize_moe_experts quantizes expert weights on-the-fly during loading, + # so the actual VRAM usage is much less than bf16 estimates. + # When device_map is "auto", accelerate's infer_auto_device_map computes + # the device map at bf16 size (before quantization), causing it to offload + # layers to CPU, which BnB then rejects. Force single-GPU placement to + # prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single). + if getattr(self.cfg, "quantize_moe_experts", False) and device_map in ( + "auto", + None, + ): + self.model_kwargs["device_map"] = { + "": int(os.environ.get("LOCAL_RANK", 0)) + } + cur_device = get_device_type() if "mps" in str(cur_device): self.model_kwargs["device_map"] = "mps:0" diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 2972c6285..44be5267d 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -51,6 +51,29 @@ QKV_PATCHES = [ value_states = value_states.view(hidden_shape).transpose(1, 2) """.lstrip("\n"), ), + ( + """ + query_states, gate = torch.chunk( + self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1 + ) + gate = gate.reshape(*input_shape, -1) + + query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) +""".lstrip("\n"), + """ + query_states, key_states, value_states = self.apply_qkv(hidden_states) + query_states, gate = torch.chunk( + query_states.view(*input_shape, -1, self.head_dim * 2), 2, dim=-1 + ) + gate = gate.reshape(*input_shape, -1) + + query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) +""".lstrip("\n"), + ), ] ORIGINAL_O_CODE = """ @@ -299,6 +322,8 @@ def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]: if hasattr(pretrained_model, "language_model"): return pretrained_model.language_model.layers if hasattr(pretrained_model, "model"): + if hasattr(pretrained_model.model, "language_model"): + return pretrained_model.model.language_model.layers return pretrained_model.model.layers raise NotImplementedError( diff --git a/src/axolotl/utils/callbacks/profiler.py b/src/axolotl/utils/callbacks/profiler.py index 2cf5e0f4f..49e8c5388 100644 --- a/src/axolotl/utils/callbacks/profiler.py +++ b/src/axolotl/utils/callbacks/profiler.py @@ -17,6 +17,8 @@ from transformers import ( class PytorchProfilerCallback(TrainerCallback): """ PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps. + + Also runs torch.profiler to produce a Chrome trace for timing analysis. """ def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0): @@ -26,9 +28,10 @@ class PytorchProfilerCallback(TrainerCallback): if profiler_steps_start == 0: # start recording memory allocations before everything is allocated, because if we start # at the beginning of step 0, we won't have any memory allocations in the traces - torch.cuda.memory._record_memory_history(enabled="all") + torch.cuda.memory._record_memory_history(enabled="all", stacks="all") profiler_steps_start = -1 self.profiler_steps_start = profiler_steps_start + self._profiler = None def on_step_begin( self, @@ -38,7 +41,21 @@ class PytorchProfilerCallback(TrainerCallback): **kwargs, ): if state.global_step == self.profiler_steps_start: - torch.cuda.memory._record_memory_history(enabled="all") + torch.cuda.memory._record_memory_history(enabled="all", stacks="all") + + # Start torch.profiler on the first profiled step + if state.global_step == max(self.profiler_steps_start, 0): + profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + profile_memory=True, + with_stack=True, + ) + profiler.__enter__() + self._profiler = profiler def on_step_end( self, @@ -55,6 +72,13 @@ class PytorchProfilerCallback(TrainerCallback): # tell CUDA to stop recording memory allocations now torch.cuda.memory._record_memory_history(enabled=None) + # Stop and export torch.profiler trace + if self._profiler is not None: + self._profiler.__exit__(None, None, None) + trace_path = Path(args.output_dir) / "profiler_trace.json" + self._profiler.export_chrome_trace(str(trace_path)) + self._profiler = None + def on_train_end( self, args: TrainingArguments, @@ -73,3 +97,9 @@ class PytorchProfilerCallback(TrainerCallback): # tell CUDA to stop recording memory allocations now torch.cuda.memory._record_memory_history(enabled=None) + + if self._profiler is not None: + self._profiler.__exit__(None, None, None) + trace_path = Path(args.output_dir) / "profiler_trace.json" + self._profiler.export_chrome_trace(str(trace_path)) + self._profiler = None diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py index 1833c750b..5033cabc9 100644 --- a/tests/e2e/patched/test_lora_llama_multipack.py +++ b/tests/e2e/patched/test_lora_llama_multipack.py @@ -4,8 +4,7 @@ E2E tests for lora llama import unittest -import pytest -from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available +from transformers.utils import is_torch_bf16_gpu_available from axolotl.common.datasets import load_datasets from axolotl.train import train @@ -68,51 +67,3 @@ class TestLoraLlama(unittest.TestCase): train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) - - @pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available") - @with_temp_dir - def test_lora_gptq_packed(self, temp_dir): - cfg = DictDefault( - { - "base_model": "lilmeaty/SmolLM2-135M-Instruct-GPTQ", - "model_type": "AutoModelForCausalLM", - "tokenizer_type": "AutoTokenizer", - "sequence_len": 1024, - "sample_packing": True, - "flash_attention": True, - "load_in_8bit": True, - "adapter": "lora", - "gptq": True, - "gptq_disable_exllama": True, - "lora_r": 32, - "lora_alpha": 64, - "lora_dropout": 0.05, - "lora_target_linear": True, - "val_set_size": 0.02, - "special_tokens": { - "pad_token": "<|endoftext|>", - }, - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - ], - "num_epochs": 2, - "max_steps": 20, - "save_steps": 0.5, - "micro_batch_size": 8, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "adamw_torch_fused", - "lr_scheduler": "cosine", - "save_first_step": False, - } - ) - cfg = validate_config(cfg) - normalize_config(cfg) - dataset_meta = load_datasets(cfg=cfg) - - train(cfg=cfg, dataset_meta=dataset_meta) - check_model_output_exists(temp_dir, cfg) diff --git a/tests/integrations/test_scattermoe_lora_kernels.py b/tests/integrations/test_scattermoe_lora_kernels.py new file mode 100644 index 000000000..fc783fa1d --- /dev/null +++ b/tests/integrations/test_scattermoe_lora_kernels.py @@ -0,0 +1,407 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +""" +Unit tests for ScatterMoE LoRA Triton kernels. + +Tests correctness of: + - scatter2scatter_lora (forward) + - scatter2scatter_lora_dX (backward input gradient) + - group_bwd_lora (backward LoRA weight gradients via split dA/dB) + - ScatterMoELoRA autograd function (full forward + backward) + +Each kernel is tested against a pure PyTorch per-expert-loop reference +implementation at multiple model shapes and LoRA ranks. +""" + +import pytest +import torch + +from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import ( + lora_ops, + ops as base_ops, +) +from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import ( + flatten_sort_count, +) +from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import ( + ScatterMoELoRA, +) + +DEVICE = "cuda" +DTYPE = torch.bfloat16 + + +def _requires_cuda(): + return pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ) + + +pytestmark = _requires_cuda() + + +# ─── Helpers ───────────────────────────────────────────────────────────────── + + +def _setup(E, K, N, T, top_k, R, seed=42): + """Create synthetic expert weights, LoRA, routing, and grouped inputs.""" + torch.manual_seed(seed) + x = torch.randn(T, K, device=DEVICE, dtype=DTYPE) + W = torch.randn(E, K, N, device=DEVICE, dtype=DTYPE) * 0.02 + lora_A = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) * 0.01 + lora_B = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) * 0.01 + logits = torch.randn(T, E, device=DEVICE) + _, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1) + sei, ssi, eo = flatten_sort_count(top_idx, E) + return x, W, lora_A, lora_B, sei, ssi, eo + + +def _reference_fwd(x, W, sei, ssi, eo, k, lora_A, lora_B, scaling, E): + """Per-expert loop reference: Y = X@W + scaling*(X@A^T)@B^T.""" + grouped_x = base_ops.group(x, ssi, fan_out=k) + M, N = grouped_x.size(0), W.size(2) + R = lora_A.size(0) // E + out = torch.zeros(M, N, device=DEVICE, dtype=DTYPE) + for e in range(E): + s = eo[e - 1].item() if e > 0 else 0 + end = eo[e].item() + if s == end: + continue + xe = grouped_x[s:end].float() + we = W[e].float() + ae = lora_A[e * R : (e + 1) * R].float() + be = lora_B[:, e * R : (e + 1) * R].float() + out[s:end] = (xe @ we + scaling * (xe @ ae.T) @ be.T).to(DTYPE) + result = torch.zeros(M, N, device=DEVICE, dtype=DTYPE) + result[ssi] = out + return result + + +def _reference_dX(dy_grouped, W, sei, ssi, eo, lora_A, lora_B, scaling, E): + """Per-expert loop reference: dX = dY@W^T + scaling*(dY@B)@A.""" + M, K = dy_grouped.size(0), W.size(1) + R = lora_A.size(0) // E + out = torch.zeros(M, K, device=DEVICE, dtype=DTYPE) + for e in range(E): + s = eo[e - 1].item() if e > 0 else 0 + end = eo[e].item() + if s == end: + continue + dye = dy_grouped[s:end].float() + we = W[e].float() + ae = lora_A[e * R : (e + 1) * R].float() + be = lora_B[:, e * R : (e + 1) * R].float() + out[s:end] = (dye @ we.T + scaling * (dye @ be) @ ae).to(DTYPE) + result = torch.zeros(M, K, device=DEVICE, dtype=DTYPE) + result[ssi] = out + return result + + +def _reference_bwd_lora(dy, grouped_x, lora_A, lora_B, eo, E, scaling): + """Per-expert loop reference: dA, dB for LoRA weight gradients.""" + R = lora_A.size(0) // E + dA = torch.zeros_like(lora_A) + dB = torch.zeros_like(lora_B) + for e in range(E): + s = eo[e - 1].item() if e > 0 else 0 + end = eo[e].item() + if s == end: + continue + xe = grouped_x[s:end].float() + dye = dy[s:end].float() + ae = lora_A[e * R : (e + 1) * R].float() + be = lora_B[:, e * R : (e + 1) * R].float() + dA[e * R : (e + 1) * R] = (scaling * (dye @ be).T @ xe).to(DTYPE) + dB[:, e * R : (e + 1) * R] = (scaling * dye.T @ (xe @ ae.T)).to(DTYPE) + return dA, dB + + +# ─── Model shape configs ──────────────────────────────────────────────────── + +# (E, K, N, T, top_k, R, description) +CONFIGS_SMALL = [ + (32, 128, 64, 64, 2, 4, "tiny"), + (64, 256, 128, 128, 4, 8, "small"), +] + +CONFIGS_REAL = [ + (256, 2048, 1024, 2048, 8, 16, "qwen35_gate_up"), + (256, 512, 2048, 2048, 8, 16, "qwen35_down"), + (64, 2048, 2048, 2048, 8, 16, "olmoe_gate_up"), + (128, 2048, 1536, 2048, 8, 16, "qwen3_gate_up"), +] + +SCALING = 2.0 + + +# ─── Forward tests ────────────────────────────────────────────────────────── + + +class TestScatter2ScatterLoRAForward: + """Test scatter2scatter_lora forward kernel vs reference.""" + + @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL) + def config(self, request): + return request.param + + def test_matches_reference(self, config): + E, K, N, T, k, R, desc = config + x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) + + kernel_out = lora_ops.scatter2scatter_lora( + X=x, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=k, + lora_A=lA, + lora_B=lB, + scaling=SCALING, + ) + ref_out = _reference_fwd(x, W, sei, ssi, eo, k, lA, lB, SCALING, E) + + err = (kernel_out.float() - ref_out.float()).abs().max().item() + assert err < 1.0, f"[{desc}] fwd max_err={err}" + + def test_output_shape(self, config): + E, K, N, T, k, R, desc = config + x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) + + out = lora_ops.scatter2scatter_lora( + X=x, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=k, + lora_A=lA, + lora_B=lB, + scaling=SCALING, + ) + assert out.shape == (T * k, N) + assert out.dtype == DTYPE + + +# ─── Backward dX tests ────────────────────────────────────────────────────── + + +class TestScatter2ScatterLoRADX: + """Test scatter2scatter_lora_dX backward kernel vs reference.""" + + @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL) + def config(self, request): + return request.param + + def test_matches_reference(self, config): + E, K, N, T, k, R, desc = config + x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) + gx = base_ops.group(x, ssi, fan_out=k) + dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE) + + kernel_dx = lora_ops.scatter2scatter_lora_dX( + DY=dy, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=1, + lora_A=lA, + lora_B=lB, + scaling=SCALING, + dy_grouped=True, + dx_grouped=False, + ) + ref_dx = _reference_dX(dy, W, sei, ssi, eo, lA, lB, SCALING, E) + + err = (kernel_dx.float() - ref_dx.float()).abs().max().item() + assert err < 1.0, f"[{desc}] dX max_err={err}" + + +# ─── Backward LoRA gradient tests ─────────────────────────────────────────── + + +class TestGroupBwdLoRA: + """Test group_bwd_lora (split dA/dB kernel) vs reference.""" + + @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL) + def config(self, request): + return request.param + + def test_matches_reference(self, config): + E, K, N, T, k, R, desc = config + x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) + gx = base_ops.group(x, ssi, fan_out=k) + dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE) + + kern_dA, kern_dB = lora_ops.group_bwd_lora( + DY=dy, + X=gx, + lora_A=lA, + lora_B=lB, + expert_offsets=eo, + E=E, + scaling=SCALING, + ) + ref_dA, ref_dB = _reference_bwd_lora(dy, gx, lA, lB, eo, E, SCALING) + + # Use norm-relative error: bf16 accumulation order differs between + # kernel (tiled + different reduction order) and reference (per-expert + # fp32 loop), so max absolute error can be large on individual elements + # while the overall tensor is correct. + dA_norm_err = ( + (kern_dA.float() - ref_dA.float()).norm() / (ref_dA.float().norm() + 1e-6) + ).item() + dB_norm_err = ( + (kern_dB.float() - ref_dB.float()).norm() / (ref_dB.float().norm() + 1e-6) + ).item() + assert dA_norm_err < 0.01, f"[{desc}] dA norm_rel_err={dA_norm_err}" + assert dB_norm_err < 0.01, f"[{desc}] dB norm_rel_err={dB_norm_err}" + + def test_zero_expert_tokens(self): + """Experts with zero routed tokens produce zero gradients.""" + E, K, N, R = 8, 64, 32, 4 + torch.manual_seed(42) + # Route all tokens to expert 0 only + T, k = 16, 1 + top_idx = torch.zeros(T, k, dtype=torch.long, device=DEVICE) + sei, ssi, eo = flatten_sort_count(top_idx, E) + gx = torch.randn(T, K, device=DEVICE, dtype=DTYPE) + dy = torch.randn(T, N, device=DEVICE, dtype=DTYPE) + lA = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) + lB = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) + + dA, dB = lora_ops.group_bwd_lora( + DY=dy, + X=gx, + lora_A=lA, + lora_B=lB, + expert_offsets=eo, + E=E, + scaling=2.0, + ) + + # Experts 1..7 should have zero gradients + for e in range(1, E): + assert dA[e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dA not zero" + assert dB[:, e * R : (e + 1) * R].abs().max() == 0, ( + f"Expert {e} dB not zero" + ) + + +# ─── Full autograd tests ──────────────────────────────────────────────────── + + +class TestScatterMoELoRAAutograd: + """Test full forward + backward through ScatterMoELoRA autograd function.""" + + @pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL[:2]) + def config(self, request): + return request.param + + def test_gradients_exist_and_finite(self, config): + E, K, N, T, k, R, desc = config + x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) + + x = x.requires_grad_(True) + lA = lA.requires_grad_(True) + lB = lB.requires_grad_(True) + + out = ScatterMoELoRA.apply( + x, + W, + k, + sei, + ssi, + eo, + lA, + lB, + SCALING, + None, + None, + False, + False, + True, + False, + ) + out.sum().backward() + + assert x.grad is not None, f"[{desc}] x.grad is None" + assert lA.grad is not None, f"[{desc}] lA.grad is None" + assert lB.grad is not None, f"[{desc}] lB.grad is None" + assert torch.isfinite(x.grad).all(), f"[{desc}] x.grad has non-finite" + assert torch.isfinite(lA.grad).all(), f"[{desc}] lA.grad has non-finite" + assert torch.isfinite(lB.grad).all(), f"[{desc}] lB.grad has non-finite" + assert x.grad.abs().sum() > 0, f"[{desc}] x.grad all zero" + assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero" + + def test_split_matches_fused(self): + """Split dispatch (for few large experts) matches fused kernel.""" + # Use a shape where split would be dispatched (large K*N, few E) + E, K, N, T, k, R = 8, 512, 1024, 128, 2, 16 + x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) + + # Force fused path + orig = lora_ops._SPLIT_LORA_FWD_THRESHOLD + lora_ops._SPLIT_LORA_FWD_THRESHOLD = 10**18 + out_fused = lora_ops.scatter2scatter_lora( + X=x, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=k, + lora_A=lA, + lora_B=lB, + scaling=SCALING, + ) + + # Force split path + lora_ops._SPLIT_LORA_FWD_THRESHOLD = 0 + out_split = lora_ops.scatter2scatter_lora( + X=x, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=k, + lora_A=lA, + lora_B=lB, + scaling=SCALING, + ) + lora_ops._SPLIT_LORA_FWD_THRESHOLD = orig + + norm_err = ( + (out_fused.float() - out_split.float()).norm() + / (out_fused.float().norm() + 1e-6) + ).item() + assert norm_err < 0.01, f"split vs fused norm_err={norm_err}" + + def test_scaling_zero_gives_base_only(self): + """With scaling=0.0, LoRA contribution vanishes. Output = X@W.""" + E, K, N, T, k, R = 16, 64, 32, 32, 2, 4 + x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) + + out_lora = ScatterMoELoRA.apply( + x, + W, + k, + sei, + ssi, + eo, + lA, + lB, + 0.0, + None, + None, + False, + False, + True, + False, + ) + out_base = base_ops.scatter2scatter( + X=x, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=k, + ) + err = (out_lora.float() - out_base.float()).abs().max().item() + assert err < 0.01, f"scaling=0 should match base: err={err}"