diff --git a/requirements.txt b/requirements.txt index 09b1f625b..710e24d71 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,7 +18,7 @@ datasets==4.5.0 deepspeed>=0.18.3 trl==0.28.0 hf_xet==1.2.0 -kernels==0.11.5 +kernels==0.12.1 trackio>=0.16.1 typing-extensions>=4.15.0 diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 77e7b573b..414abeb4d 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -719,6 +719,8 @@ class AxolotlTrainer( output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) LOG.info(f"Saving model checkpoint to {output_dir}") + + # fix for Context Parallel save if state_dict is None: state_dict = self.accelerator.get_state_dict(self.model) if state_dict is not None: @@ -726,6 +728,7 @@ class AxolotlTrainer( k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() } + supported_classes = ( (PreTrainedModel,) if not is_peft_available() @@ -736,6 +739,7 @@ class AxolotlTrainer( if not isinstance(self.model, supported_classes): if state_dict is None: state_dict = self.model.state_dict() + if isinstance( self.accelerator.unwrap_model(self.model, keep_torch_compile=False), supported_classes, @@ -745,6 +749,7 @@ class AxolotlTrainer( ).save_pretrained( output_dir, state_dict=state_dict, + is_main_process=self.accelerator.is_main_process, ) else: LOG.info( @@ -756,11 +761,7 @@ class AxolotlTrainer( metadata={"format": "pt"}, ) else: - self.model.save_pretrained( - output_dir, - state_dict=state_dict, - is_main_process=self.accelerator.is_main_process, - ) + self.model.save_pretrained(output_dir, state_dict=state_dict) if self.processing_class is not None: self.processing_class.save_pretrained(output_dir) @@ -772,11 +773,7 @@ class AxolotlTrainer( LOG.info( "Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`" ) - save_jinja_files = True - if self.axolotl_cfg: - save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files - self.data_collator.tokenizer.save_pretrained( - output_dir, save_jinja_files=save_jinja_files - ) + self.data_collator.tokenizer.save_pretrained(output_dir) + # Good practice: save your training arguments together with the trained model torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) diff --git a/src/axolotl/integrations/kernels/args.py b/src/axolotl/integrations/kernels/args.py index 66d6b6d53..e8cf7208a 100644 --- a/src/axolotl/integrations/kernels/args.py +++ b/src/axolotl/integrations/kernels/args.py @@ -33,3 +33,16 @@ class KernelsArgs(BaseModel): data["experts_implementation"] = "eager" return data + + @model_validator(mode="before") + @classmethod + def disable_mlp_kernel_scattermoe(cls, data): + if data.get("use_scattermoe") is True: + if data.get("lora_mlp_kernel") is True: + LOG.warning( + "Disabling lora_mlp_kernel when using scattermoe due to compatibility issues." 
+ ) + data["lora_mlp_kernel"] = False + data["mlp_kernel"] = False + + return data diff --git a/src/axolotl/integrations/kernels/libs/__init__.py b/src/axolotl/integrations/kernels/libs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/__init__.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/__init__.py new file mode 100644 index 000000000..f5148634e --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/__init__.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +from . import layers +from .lora_ops import ParallelExperts +from .parallel_experts import flatten_sort_count, parallel_linear +from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora + +__all__ = [ + "layers", + "ParallelExperts", + "flatten_sort_count", + "parallel_linear", + "ScatterMoELoRA", + "parallel_linear_lora", + "lora_ops", +] diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/__init__.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/__init__.py new file mode 100644 index 000000000..eb502db71 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors +# Adapted from https://github.com/shawntan/scattermoe +# See https://github.com/shawntan/scattermoe/blob/main/LICENSE +# +# Modifications and LoRA adaptation Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +from . import lora_ops, ops + +__all__ = ["ops", "lora_ops"] diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py new file mode 100644 index 000000000..5d47c2040 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py @@ -0,0 +1,1731 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +""" +Fused ScatterMoE + LoRA Triton Kernels +======================================= + +Provides fused forward and backward kernels for ScatterMoE with LoRA adapters. + +Forward: Y = X @ W + scaling * (X @ A^T) @ B^T +Backward (LoRA training, W frozen): + - dX = dY @ W^T + scaling * (dY @ B) @ A (input gradient) + - dA = scaling * (dY @ B)^T @ X (LoRA A gradient) + - dB = scaling * dY^T @ (X @ A^T) (LoRA B gradient) + +LoRA weight layout (from PEFT ParamWrapper): + - A: [r*E, K] -- for expert e, rows [e*r : (e+1)*r] give A_e of shape [r, K] + - B: [N, r*E] -- for expert e, cols [e*r : (e+1)*r] give B_e of shape [N, r] + +Key design decisions: + - The forward kernel fuses X@W and X@A^T in the same K-loop for data reuse on X, + then computes (X@A^T) @ B^T in the epilogue. + - The backward dA/dB kernel operates on grouped (expert-contiguous) data and + iterates over tokens per expert, accumulating gradients in registers. + - R (LoRA rank) is a tl.constexpr, allowing tl.arange(0, R). We pad R to a + power-of-2 for Triton tile compatibility; typical ranks (4, 8, 16, 32, 64) + already satisfy this. 
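+
+Illustrative eager-mode equivalent for a single expert e (a reference sketch only;
+the fused kernels never materialize these slices):
+
+    A_e = lora_A[e * r:(e + 1) * r, :]      # [r, K]
+    B_e = lora_B[:, e * r:(e + 1) * r]      # [N, r]
+    Y_e = X_e @ W[e] + scaling * (X_e @ A_e.T) @ B_e.T
+
+where X_e holds the rows of X routed to expert e and W is [E, K, N].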
+""" + +from itertools import product +from typing import Optional + +import torch +import triton +import triton.language as tl + +# ============================================================================= +# Configuration +# ============================================================================= + +BLOCK_M = 128 +ALLOW_TF32 = True + + +def _next_power_of_2(n: int) -> int: + """Round up to next power of 2.""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + return n + 1 + + +# Triton tl.dot requires minimum tile dimensions of 16 on modern GPUs. +MIN_TRITON_DOT_SIZE = 16 + + +def _block_r_for_rank(r: int) -> int: + """Compute BLOCK_R: next power-of-2 >= max(r, MIN_TRITON_DOT_SIZE).""" + return _next_power_of_2(max(r, MIN_TRITON_DOT_SIZE)) + + +# ============================================================================= +# Token Rounding: pad expert counts to BLOCK_M multiples +# ============================================================================= + + +def round_expert_counts( + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + expert_offsets: torch.Tensor, + E: int, + block_m: int = BLOCK_M, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Pad each expert's token count to a multiple of block_m to eliminate + partial-tile waste in the backward kernel. + + Padding is done by duplicating the last valid token index for each expert. + The kernel's M_mask = M_idx < real_end_idx masks these padding entries, so + correctness is preserved (they contribute 0 to the accumulation via other=0.0). + + This only helps the backward dA/dB kernel where per-expert iteration is + explicit. The forward scatter2scatter kernel handles partial tiles via masking. + + Args: + sorted_expert_idxs: Expert assignments sorted [M*k] + sorted_scattered_idxs: Original indices sorted [M*k] + expert_offsets: Cumulative token counts per expert [E] + E: Number of experts + block_m: Block size for token dimension (default: BLOCK_M) + + Returns: + padded_expert_idxs: [M_padded] expert assignments with padding + padded_scattered_idxs: [M_padded] original indices with padding + padded_offsets: [E] cumulative padded counts (for kernel iteration range) + real_offsets: [E] original cumulative counts (for M_mask in kernel) + """ + device = sorted_expert_idxs.device + + # Compute per-expert counts + counts = torch.zeros(E, dtype=torch.int64, device=device) + prev = 0 + for e in range(E): + curr = expert_offsets[e].item() + counts[e] = curr - prev + prev = curr + + # Round up each count to multiple of block_m + padded_counts = ((counts + block_m - 1) // block_m) * block_m + # Experts with 0 tokens stay at 0 + padded_counts = torch.where( + counts > 0, padded_counts, torch.zeros_like(padded_counts) + ) + total_padded = padded_counts.sum().item() + + padded_expert_idxs = torch.empty( + total_padded, dtype=sorted_expert_idxs.dtype, device=device + ) + padded_scattered_idxs = torch.empty( + total_padded, dtype=sorted_scattered_idxs.dtype, device=device + ) + + src_offset = 0 + dst_offset = 0 + for e in range(E): + count = counts[e].item() + padded_count = padded_counts[e].item() + + if count > 0: + # Copy original tokens + padded_expert_idxs[dst_offset : dst_offset + count] = sorted_expert_idxs[ + src_offset : src_offset + count + ] + padded_scattered_idxs[dst_offset : dst_offset + count] = ( + sorted_scattered_idxs[src_offset : src_offset + count] + ) + + # Pad with last valid token (masked out by kernel via M_mask) + if padded_count 
> count: + padded_expert_idxs[dst_offset + count : dst_offset + padded_count] = ( + sorted_expert_idxs[src_offset + count - 1] + ) + padded_scattered_idxs[ + dst_offset + count : dst_offset + padded_count + ] = sorted_scattered_idxs[src_offset + count - 1] + + src_offset += count + dst_offset += padded_count + + # Padded offsets: cumulative padded counts (for iteration range in kernel) + padded_offsets = padded_counts.cumsum(-1).to(expert_offsets.dtype) + # Real offsets: original cumulative counts (for M_mask in kernel) + real_offsets = expert_offsets.clone() + + return padded_expert_idxs, padded_scattered_idxs, padded_offsets, real_offsets + + +# ============================================================================= +# Autotuning: SMEM estimation and config pruning +# ============================================================================= + +_SMEM_CAPACITY: int | None = None + + +def _get_smem_capacity() -> int: + """Get device shared memory capacity (bytes). Cached after first call.""" + global _SMEM_CAPACITY + if _SMEM_CAPACITY is None: + props = triton.runtime.driver.active.utils.get_device_properties( + torch.cuda.current_device() + ) + _SMEM_CAPACITY = props["max_shared_mem"] + return _SMEM_CAPACITY + + +def _estimate_smem_usage( + num_stages: int, BLOCK_M: int, BLOCK_N: int, BLOCK_K: int, dtype_bytes: int = 2 +) -> int: + """Estimate shared memory in bytes for a GEMM-style tile. + + Formula: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N + Multiply by dtype_bytes (2 for fp16/bf16). + """ + return ( + num_stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N + ) * dtype_bytes + + +# Conservative margin (bytes) subtracted from SMEM capacity to account for +# estimation inaccuracies and kernel overhead (registers spilled to SMEM, etc.) +_SMEM_SLACK = 10_000 + + +# ============================================================================= +# Forward Kernel: scatter2scatter with fused LoRA +# ============================================================================= + + +@triton.jit +def _compute_expert_block_lora( + E_idx, + E_mask, + M_in_idx, + N_block, + N_mask, + # Base weight + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_we, + stride_wk, + stride_wn, + # LoRA weights + A_ptr, + stride_ar, + stride_ak, # A: [r*E, K], stride_ar = stride for r*E dim, stride_ak = stride for K dim + B_ptr, + stride_bn, + stride_br, # B: [N, r*E], stride_bn = stride for N dim, stride_br = stride for r*E dim + # Dimensions + K, + ACTUAL_R: tl.constexpr, # True LoRA rank (for indexing into weight arrays) + acc, + no_k_mask, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_R: tl.constexpr, # Padded tile size >= max(ACTUAL_R, 16) + scaling, + allow_tf32: tl.constexpr, +): + """ + Compute Y_block = X_block @ W_e + scaling * (X_block @ A_e^T) @ B_e^T + + for tokens in this M-block assigned to expert E_idx. + + ACTUAL_R is the true LoRA rank used for indexing into A[e*r:(e+1)*r, :]. + BLOCK_R >= ACTUAL_R is the padded tile dimension (must be >= 16 for tl.dot). + When BLOCK_R > ACTUAL_R, loads are masked on the R dimension. 
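+    For example, ACTUAL_R = 8 is padded to BLOCK_R = 16 and the extra rows of the
+    A/B tiles are masked to zero, so the padding does not change the result.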
+ """ + K_block = tl.arange(0, BLOCK_K) + R_block = tl.arange(0, BLOCK_R) + R_mask = R_block < ACTUAL_R # Mask for padding when BLOCK_R > ACTUAL_R + + # Base weight pointers: W[E_idx, :, :] is [K, N], load [BLOCK_K, BLOCK_N] + X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + W_blk_ptrs = ( + W_ptr + + E_idx * stride_we + + K_block[:, None] * stride_wk + + N_block[None, :] * stride_wn + ) + + # LoRA A pointers: A[e*ACTUAL_R:(e+1)*ACTUAL_R, :] for expert e, shape [r, K] + A_expert_offset = E_idx * ACTUAL_R + A_blk_ptrs = ( + A_ptr + + (A_expert_offset + R_block)[:, None] * stride_ar + + K_block[None, :] * stride_ak + ) + + iters = tl.cdiv(K, BLOCK_K) + + # Accumulator for X @ A^T: [BLOCK_M, BLOCK_R] + xa_acc = tl.zeros((BLOCK_M, BLOCK_R), dtype=tl.float32) + + # Determine the input element type for consistent casting. + # Masked tl.load with other=0.0 can upcast bf16->fp32 in some Triton versions, + # causing dtype mismatches in tl.dot. We cast all tiles to the same type. + INPUT_DTYPE = X_ptr.dtype.element_ty + + for i in range(iters): + if no_k_mask: + x = tl.load(X_blk_ptrs, mask=E_mask[:, None], other=0.0).to(INPUT_DTYPE) + w = tl.load(W_blk_ptrs, mask=N_mask[None, :], other=0.0).to(INPUT_DTYPE) + a = tl.load(A_blk_ptrs, mask=R_mask[:, None], other=0.0).to(INPUT_DTYPE) + else: + K_mask = (i * BLOCK_K + K_block) < K + x = tl.load( + X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + w = tl.load( + W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + a = tl.load( + A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + + # Base: acc += X @ W ([M, K] @ [K, N] -> [M, N]) + acc += tl.dot(x, w, allow_tf32=allow_tf32).to(tl.float32) + + # LoRA: xa_acc += X @ A^T ([M, K] @ [K, R] -> [M, R]) + xa_acc += tl.dot(x, tl.trans(a), allow_tf32=allow_tf32).to(tl.float32) + + X_blk_ptrs += BLOCK_K * stride_xk + W_blk_ptrs += BLOCK_K * stride_wk + A_blk_ptrs += BLOCK_K * stride_ak + + # Epilogue: load B[e] and compute (X @ A^T) @ B^T + # B[e] is B[:, e*ACTUAL_R:(e+1)*ACTUAL_R], shape [N, r]. Load [BLOCK_N, BLOCK_R]. + B_expert_offset = E_idx * ACTUAL_R + B_blk_ptrs = ( + B_ptr + + N_block[:, None] * stride_bn + + (B_expert_offset + R_block)[None, :] * stride_br + ) + b = tl.load( + B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0 + ) # [BLOCK_N, BLOCK_R] + + # Cast xa_acc and b to same dtype for tl.dot (required when input is bf16/fp16) + # Both operands must match; cast to float32 (accumulator type) for precision. + b_f32 = b.to(tl.float32) + + # (X @ A^T) @ B^T: [M, R] @ [R, N] -> [M, N] + lora_out = tl.dot(xa_acc, tl.trans(b_f32), allow_tf32=allow_tf32) + + acc += scaling * lora_out + return acc + + +def _scatter2scatter_lora_configs(): + """Generate forward kernel autotune configs. + + Search space includes smaller tile sizes and fewer pipeline stages to + support GPUs with limited shared memory (e.g. ~99KB on some GPUs). + + Search space: + BLOCK_N: {32, 64, 128, 256} + BLOCK_K: {32, 64, 128} + num_warps: {4, 8} + num_stages: {3, 4, 5} + + BLOCK_M is fixed at 128 (module-level constant, not autotuned in the + scatter2scatter pattern). 
+ """ + configs = [] + for block_n, block_k, warps, stages in product( + [32, 64, 128, 256], # BLOCK_N + [32, 64, 128], # BLOCK_K + [4, 8], # num_warps + [3, 4, 5], # num_stages + ): + configs.append( + triton.Config( + {"BLOCK_N": block_n, "BLOCK_K": block_k}, + num_stages=stages, + num_warps=warps, + ) + ) + return configs + + +def _prune_fwd_configs(configs, named_args, **kwargs): + """Prune forward configs based on SMEM capacity. + + The forward kernel inner loop loads three tiles per pipeline stage: + X[BLOCK_M, BLOCK_K], W[BLOCK_K, BLOCK_N], A[BLOCK_R, BLOCK_K]. + The base estimate only accounts for X and W. We add: + - A tile [BLOCK_R, BLOCK_K] per pipeline stage (loaded in the inner loop) + - B tile [BLOCK_N, BLOCK_R] loaded once in the epilogue + - Extra headroom for compiler overhead (register spills, metadata) + """ + smem_cap = _get_smem_capacity() + + # Get BLOCK_R from named_args if available, else assume worst case + block_r = named_args.get("BLOCK_R", 64) + + scored = [] + for config in configs: + block_n = config.kwargs["BLOCK_N"] + block_k = config.kwargs["BLOCK_K"] + # Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N + smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_n, block_k) + # A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop + smem_lora_loop = config.num_stages * block_r * block_k * 2 + # B tile [BLOCK_N, BLOCK_R] loaded once in epilogue + smem_lora_epilogue = block_n * block_r * 2 + smem = smem_base + smem_lora_loop + smem_lora_epilogue + scored.append((smem, config)) + + pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK] + if pruned: + return pruned + # All configs exceed SMEM — return the one with smallest estimated usage + scored.sort(key=lambda x: x[0]) + return [scored[0][1]] + + +@triton.autotune( + configs=_scatter2scatter_lora_configs(), + key=["M", "N", "K"], + prune_configs_by={"early_config_prune": _prune_fwd_configs}, +) +@triton.heuristics( + { + "NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0, + "NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0, + } +) +@triton.jit +def _scatter2scatter_lora( + # Input/Output + X_ptr, + stride_xm: tl.constexpr, + stride_xk: tl.constexpr, + W_ptr, + stride_we, + stride_wk: tl.constexpr, + stride_wn: tl.constexpr, + Y_ptr, + stride_ym: tl.constexpr, + stride_yn: tl.constexpr, + # Bias + Bias_ptr, + stride_bias_e: tl.constexpr, + stride_bias_n: tl.constexpr, + # LoRA weights + LA_ptr, + stride_la_r, + stride_la_k, # A: [r*E, K] + LB_ptr, + stride_lb_n, + stride_lb_r, # B: [N, r*E] + # Routing + grouped_idx_ptr, + expert_idxs_ptr, + # Dimensions + FAN_OUT: tl.constexpr, + M, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + ACTUAL_R: tl.constexpr, # True LoRA rank (for weight indexing) + # Block sizes + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_R: tl.constexpr, # Padded tile size >= max(ACTUAL_R, 16) + # Config + ACC_TYPE: tl.constexpr, + scaling, + allow_tf32: tl.constexpr, + x_grouped: tl.constexpr, + y_grouped: tl.constexpr, + NO_K_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, +): + """ + Fused scatter2scatter with LoRA: Y = X @ W + scaling * (X @ A^T) @ B^T + bias + """ + pid = tl.program_id(axis=0) + + N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N) + M_block_id = pid // N_BLOCK_COUNT + N_block_id = pid % N_BLOCK_COUNT + + M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M) + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + M_boundary_mask = M_block < (FAN_OUT * M) 
+ + E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E) + + no_k_mask = NO_K_MASK + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + E_first_idx = tl.min(E_idxs) + E_last_idx = tl.minimum(tl.max(E_idxs), E - 1) + M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32) + + for E_idx in range(E_first_idx, E_last_idx + 1): + E_mask = E_idxs == E_idx + if x_grouped: + M_in_idx = M_block + else: + M_in_idx = M_idx // FAN_OUT + + acc = _compute_expert_block_lora( + E_idx, + E_mask, + M_in_idx, + N_block, + N_mask, + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_we, + stride_wk, + stride_wn, + LA_ptr, + stride_la_r, + stride_la_k, + LB_ptr, + stride_lb_n, + stride_lb_r, + K, + ACTUAL_R, + acc, + no_k_mask, + BLOCK_M, + BLOCK_K, + BLOCK_N, + BLOCK_R, + scaling, + allow_tf32=allow_tf32, + ) + + # Add bias if present + if Bias_ptr is not None: + B_blk_ptrs = ( + Bias_ptr + + E_idxs[:, None] * stride_bias_e + + N_block[None, :] * stride_bias_n + ) + acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :]) + + # Store output + if y_grouped: + M_out_idx = M_block + else: + M_out_idx = M_idx + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :]) + + +def scatter2scatter_lora( + X: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + k: int, + lora_A: torch.Tensor, + lora_B: torch.Tensor, + scaling: float, + b: Optional[torch.Tensor] = None, + x_grouped: bool = False, + y_grouped: bool = False, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Fused scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e] + + Args: + X: Input [M, K] or [M*k, K] if x_grouped + W: Expert weights [E, K, N] + sorted_expert_idxs: Expert assignments sorted [M*k] + sorted_scattered_idxs: Original indices sorted [M*k] + k: Fan-out (top-k) + lora_A: LoRA A weights [r*E, K] + lora_B: LoRA B weights [N, r*E] + scaling: LoRA scaling factor (alpha/r) + b: Optional bias [E, N] + x_grouped: Input pre-grouped by expert + y_grouped: Keep output grouped + out: Optional pre-allocated output buffer + + Returns: + Y: Output [M*k, N] + """ + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * k + + E = W.size(0) + K = W.size(1) + N = W.size(2) + R = lora_A.size(0) // E + + # Pad R to power of 2 for Triton tile size + BLOCK_R = _block_r_for_rank(R) + + L_scattered = sorted_expert_idxs.size(0) + + if out is None: + output = torch.empty((L_scattered, N), device=X.device, dtype=X.dtype) + else: + assert out.size(0) == L_scattered and out.size(1) == N + output = out + + def grid(META): + return ( + triton.cdiv(L_scattered, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + ) + + if b is None: + stride_be = stride_bn = 0 + b_ptr = None + else: + stride_be, stride_bn = b.stride() + b_ptr = b + + _scatter2scatter_lora[grid]( + X, + X.stride(0), + X.stride(1), + W, + W.stride(0), + W.stride(1), + W.stride(2), + output, + output.stride(0), + output.stride(1), + b_ptr, + stride_be, + stride_bn, + # A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride + lora_A, + lora_A.stride(0), + lora_A.stride(1), + # B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride + lora_B, + lora_B.stride(0), + lora_B.stride(1), + sorted_scattered_idxs, + sorted_expert_idxs, + FAN_OUT=k, 
+ M=X.size(0), + K=K, + N=N, + E=E, + ACTUAL_R=R, # True LoRA rank for weight indexing + BLOCK_M=BLOCK_M, + BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16) + ACC_TYPE=tl.float32, + scaling=scaling, + allow_tf32=ALLOW_TF32, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + return output + + +# ============================================================================= +# Backward Kernel: Fused dX = dY @ W^T + scaling * (dY @ B) @ A +# ============================================================================= + + +@triton.jit +def _compute_expert_block_lora_dX( + E_idx, + E_mask, + M_in_idx, + K_block, + K_mask, + # Input: DY (gradient w.r.t. output) + DY_ptr, + stride_dym, + stride_dyn, + # Base weight W^T: we load W[e] as [K, N] and index as W^T[e] = [N, K] + W_ptr, + stride_we, + stride_wk, + stride_wn, + # LoRA weights + A_ptr, + stride_ar, + stride_ak, # A: [r*E, K] + B_ptr, + stride_bn, + stride_br, # B: [N, r*E] + # Dimensions + N, + ACTUAL_R: tl.constexpr, + acc, + no_n_mask, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_R: tl.constexpr, + scaling, + allow_tf32: tl.constexpr, +): + """ + Compute dX_block = DY_block @ W_e^T + scaling * (DY_block @ B_e) @ A_e + + for tokens in this M-block assigned to expert E_idx. + + Inner loop over N dimension (reduction dim for dY @ W^T and dY @ B). + Output dimension is K. + Epilogue computes (dY @ B) @ A. + + Transpose mapping from forward: + Forward: X@W (K-loop), X@A^T (K-loop), (X@A^T)@B^T (epilogue) + Backward: DY@W^T (N-loop), DY@B (N-loop), (DY@B)@A (epilogue) + """ + N_block = tl.arange(0, BLOCK_N) + R_block = tl.arange(0, BLOCK_R) + R_mask = R_block < ACTUAL_R + + # DY pointers: DY is [M_total, N], load [BLOCK_M, BLOCK_N] + DY_blk_ptrs = ( + DY_ptr + M_in_idx[:, None] * stride_dym + N_block[None, :] * stride_dyn + ) + + # W^T pointers: W[e] is [K, N], W^T[e] is [N, K]. We load W^T as [BLOCK_N, BLOCK_K]. + # W stored as [E, K, N], so W^T[e][n, k] = W[e][k, n] = W_ptr + e*stride_we + k*stride_wk + n*stride_wn + # As [BLOCK_N, BLOCK_K] tile: row=n, col=k + WT_blk_ptrs = ( + W_ptr + + E_idx * stride_we + + N_block[:, None] * stride_wn # row = n dimension + + K_block[None, :] * stride_wk + ) # col = k dimension + + # B pointers: B[e] is B[:, e*R:(e+1)*R], shape [N, R]. Load [BLOCK_N, BLOCK_R]. + B_expert_offset = E_idx * ACTUAL_R + B_blk_ptrs = ( + B_ptr + + N_block[:, None] * stride_bn + + (B_expert_offset + R_block)[None, :] * stride_br + ) + + iters = tl.cdiv(N, BLOCK_N) + + # Accumulator for DY @ B: [BLOCK_M, BLOCK_R] + dy_b_acc = tl.zeros((BLOCK_M, BLOCK_R), dtype=tl.float32) + + # Determine the input element type for consistent casting. 
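+    # As in the forward kernel, masked tl.load with other=0.0 can upcast bf16 to
+    # fp32 in some Triton versions; casting keeps tl.dot operand dtypes matched.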
+ INPUT_DTYPE = DY_ptr.dtype.element_ty + + for i in range(iters): + if no_n_mask: + dy = tl.load(DY_blk_ptrs, mask=E_mask[:, None], other=0.0).to(INPUT_DTYPE) + wt = tl.load(WT_blk_ptrs, mask=K_mask[None, :], other=0.0).to(INPUT_DTYPE) + b = tl.load(B_blk_ptrs, mask=R_mask[None, :], other=0.0).to(INPUT_DTYPE) + else: + N_mask_iter = (i * BLOCK_N + N_block) < N + dy = tl.load( + DY_blk_ptrs, mask=E_mask[:, None] & N_mask_iter[None, :], other=0.0 + ).to(INPUT_DTYPE) + wt = tl.load( + WT_blk_ptrs, mask=N_mask_iter[:, None] & K_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + b = tl.load( + B_blk_ptrs, mask=N_mask_iter[:, None] & R_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + + # Base: acc += DY @ W^T ([M, N] @ [N, K] -> [M, K]) + acc += tl.dot(dy, wt, allow_tf32=allow_tf32).to(tl.float32) + + # LoRA: dy_b_acc += DY @ B ([M, N] @ [N, R] -> [M, R]) + dy_b_acc += tl.dot(dy, b, allow_tf32=allow_tf32).to(tl.float32) + + DY_blk_ptrs += BLOCK_N * stride_dyn + WT_blk_ptrs += BLOCK_N * stride_wn + B_blk_ptrs += BLOCK_N * stride_bn + + # Epilogue: load A[e] and compute (DY @ B) @ A + # A[e] is A[e*R:(e+1)*R, :], shape [R, K]. Load [BLOCK_R, BLOCK_K]. + A_expert_offset = E_idx * ACTUAL_R + A_blk_ptrs = ( + A_ptr + + (A_expert_offset + R_block)[:, None] * stride_ar + + K_block[None, :] * stride_ak + ) + a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0) + + # Cast to float32 for precision + a_f32 = a_e.to(tl.float32) + + # (DY @ B) @ A: [M, R] @ [R, K] -> [M, K] + lora_dx = tl.dot(dy_b_acc, a_f32, allow_tf32=allow_tf32) + + acc += scaling * lora_dx + return acc + + +def _scatter2scatter_lora_dX_configs(): + """Generate backward dX kernel autotune configs. + + The inner loop is over N (not K as in forward). The output dimension is K. + So BLOCK_K tiles the output and BLOCK_N tiles the reduction. + + Search space includes smaller tile sizes and fewer pipeline stages to + support GPUs with limited shared memory (e.g. ~99KB on some GPUs). + + Search space: + BLOCK_K: {32, 64, 128, 256} (output tile) + BLOCK_N: {32, 64, 128, 256} (reduction tile) + num_warps: {4, 8} + num_stages: {3, 4, 5} + """ + configs = [] + for block_k, block_n, warps, stages in product( + [32, 64, 128, 256], # BLOCK_K (output dimension) + [32, 64, 128, 256], # BLOCK_N (reduction dimension) + [4, 8], # num_warps + [3, 4, 5], # num_stages + ): + configs.append( + triton.Config( + {"BLOCK_K": block_k, "BLOCK_N": block_n}, + num_stages=stages, + num_warps=warps, + ) + ) + return configs + + +def _prune_dX_configs(configs, named_args, **kwargs): + """Prune backward dX configs based on SMEM capacity. + + The dX kernel inner loop loads three tiles per pipeline stage: + DY[BLOCK_M, BLOCK_N], W^T[BLOCK_N, BLOCK_K], B[BLOCK_N, BLOCK_R]. + The base estimate only accounts for DY and W^T. 
We add: + - B tile [BLOCK_N, BLOCK_R] per pipeline stage (loaded in the inner loop) + - A tile [BLOCK_R, BLOCK_K] loaded once in the epilogue + - Extra headroom for compiler overhead (register spills, metadata) + """ + smem_cap = _get_smem_capacity() + + # Get BLOCK_R from named_args if available, else assume worst case + block_r = named_args.get("BLOCK_R", 64) + + scored = [] + for config in configs: + block_k = config.kwargs["BLOCK_K"] + block_n = config.kwargs["BLOCK_N"] + # Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K + smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_k, block_n) + # B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop + smem_lora_loop = config.num_stages * block_n * block_r * 2 + # A tile [BLOCK_R, BLOCK_K] loaded once in epilogue + smem_lora_epilogue = block_r * block_k * 2 + smem = smem_base + smem_lora_loop + smem_lora_epilogue + scored.append((smem, config)) + + pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK] + if pruned: + return pruned + # All configs exceed SMEM — return the one with smallest estimated usage + scored.sort(key=lambda x: x[0]) + return [scored[0][1]] + + +@triton.autotune( + configs=_scatter2scatter_lora_dX_configs(), + key=["M", "N", "K"], + prune_configs_by={"early_config_prune": _prune_dX_configs}, +) +@triton.heuristics( + { + "NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0, + "NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0, + } +) +@triton.jit +def _scatter2scatter_lora_dX( + # Input: DY (gradient w.r.t. output, grouped) + DY_ptr, + stride_dym: tl.constexpr, + stride_dyn: tl.constexpr, + # Base weight: W [E, K, N] (we compute DY @ W^T) + W_ptr, + stride_we, + stride_wk: tl.constexpr, + stride_wn: tl.constexpr, + # Output: dX + DX_ptr, + stride_dxm: tl.constexpr, + stride_dxk: tl.constexpr, + # LoRA weights + LA_ptr, + stride_la_r, + stride_la_k, # A: [r*E, K] + LB_ptr, + stride_lb_n, + stride_lb_r, # B: [N, r*E] + # Routing + grouped_idx_ptr, + expert_idxs_ptr, + # Dimensions + FAN_OUT: tl.constexpr, + M, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + ACTUAL_R: tl.constexpr, + # Block sizes + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_R: tl.constexpr, + # Config + ACC_TYPE: tl.constexpr, + scaling, + allow_tf32: tl.constexpr, + dy_grouped: tl.constexpr, + dx_grouped: tl.constexpr, + NO_K_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, +): + """ + Fused backward dX = DY @ W^T + scaling * (DY @ B) @ A + + DY is in expert-grouped order (x_grouped=True). + dX is output in ungrouped or grouped order based on dx_grouped. 
+ + Grid: (cdiv(M_total, BLOCK_M) * cdiv(K, BLOCK_K),) + """ + pid = tl.program_id(axis=0) + + K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) + M_block_id = pid // K_BLOCK_COUNT + K_block_id = pid % K_BLOCK_COUNT + + M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M) + K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + K_mask = K_block < K + M_boundary_mask = M_block < (FAN_OUT * M) + + E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E) + + no_n_mask = NO_N_MASK + + acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=ACC_TYPE) + + E_first_idx = tl.min(E_idxs) + E_last_idx = tl.minimum(tl.max(E_idxs), E - 1) + M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32) + + for E_idx in range(E_first_idx, E_last_idx + 1): + E_mask = E_idxs == E_idx + if dy_grouped: + M_in_idx = M_block + else: + M_in_idx = M_idx // FAN_OUT + + acc = _compute_expert_block_lora_dX( + E_idx, + E_mask, + M_in_idx, + K_block, + K_mask, + DY_ptr, + stride_dym, + stride_dyn, + W_ptr, + stride_we, + stride_wk, + stride_wn, + LA_ptr, + stride_la_r, + stride_la_k, + LB_ptr, + stride_lb_n, + stride_lb_r, + N, + ACTUAL_R, + acc, + no_n_mask, + BLOCK_M, + BLOCK_N, + BLOCK_K, + BLOCK_R, + scaling, + allow_tf32=allow_tf32, + ) + + # Store output + if dx_grouped: + M_out_idx = M_block + else: + M_out_idx = M_idx + DX_blk_ptrs = DX_ptr + ( + M_out_idx[:, None] * stride_dxm + K_block[None, :] * stride_dxk + ) + tl.store(DX_blk_ptrs, acc, mask=M_boundary_mask[:, None] & K_mask[None, :]) + + +def scatter2scatter_lora_dX( + DY: torch.Tensor, + W: torch.Tensor, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + k: int, + lora_A: torch.Tensor, + lora_B: torch.Tensor, + scaling: float, + dy_grouped: bool = True, + dx_grouped: bool = False, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Fused backward dX = DY @ W^T + scaling * (DY @ B) @ A + + Replaces the separate: + 1. base_ops.scatter2scatter(DY, W^T, x_grouped=True, ...) + 2. _compute_lora_input_grad(DY, A, B, ...) + + Args: + DY: Gradient w.r.t. 
output [M*k, N] (grouped by expert) + W: Expert weights [E, K, N] (NOT transposed — kernel handles W^T internally) + sorted_expert_idxs: Expert assignments sorted [M*k] + sorted_scattered_idxs: Original indices sorted [M*k] + k: Fan-out (top-k) + lora_A: LoRA A weights [r*E, K] + lora_B: LoRA B weights [N, r*E] + scaling: LoRA scaling factor + dy_grouped: Whether DY is in grouped (expert-sorted) order (default True) + dx_grouped: Whether to output dX in grouped order (default False) + out: Optional pre-allocated output buffer + + Returns: + dX: Input gradient [M*k, K] + """ + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + + E = W.size(0) + K = W.size(1) + N = W.size(2) + R = lora_A.size(0) // E + + BLOCK_R = _block_r_for_rank(R) + + L_scattered = sorted_expert_idxs.size(0) + + # M for the kernel is DY.size(0) when dy_grouped, else the original M + if dy_grouped: + M = DY.size(0) + fan_out = 1 # DY is already expanded + else: + M = DY.size(0) + fan_out = k + + if out is None: + output = torch.empty((L_scattered, K), device=DY.device, dtype=DY.dtype) + else: + assert out.size(0) == L_scattered and out.size(1) == K + output = out + + def grid(META): + return ( + triton.cdiv(L_scattered, META["BLOCK_M"]) * triton.cdiv(K, META["BLOCK_K"]), + ) + + _scatter2scatter_lora_dX[grid]( + DY, + DY.stride(0), + DY.stride(1), + W, + W.stride(0), + W.stride(1), + W.stride(2), + output, + output.stride(0), + output.stride(1), + lora_A, + lora_A.stride(0), + lora_A.stride(1), + lora_B, + lora_B.stride(0), + lora_B.stride(1), + sorted_scattered_idxs, + sorted_expert_idxs, + FAN_OUT=fan_out, + M=M, + K=K, + N=N, + E=E, + ACTUAL_R=R, + BLOCK_M=BLOCK_M, + BLOCK_R=BLOCK_R, + ACC_TYPE=tl.float32, + scaling=scaling, + allow_tf32=ALLOW_TF32, + dy_grouped=dy_grouped, + dx_grouped=dx_grouped, + ) + + return output + + +# ============================================================================= +# Backward Kernel: LoRA gradient computation (dA, dB) +# ============================================================================= + + +def _group_bwd_lora_configs(): + """Generate backward (dA/dB) kernel autotune configs. + + Search space includes smaller tile sizes and fewer pipeline stages to + support GPUs with limited shared memory (e.g. ~99KB on some GPUs). + + Search space: + BLOCK_M: {32, 64, 128, 256} (token-loop tile) + BLOCK_K: {32, 64, 128, 256} + BLOCK_N: {32, 64, 128, 256} + num_warps: {4, 8} + num_stages: {3, 4, 5} + + The backward kernel also uses BLOCK_R (from LoRA rank), but that is + determined by the rank and not autotunable. + """ + configs = [] + for block_m, block_k, block_n, warps, stages in product( + [32, 64, 128, 256], # BLOCK_M + [32, 64, 128, 256], # BLOCK_K + [32, 64, 128, 256], # BLOCK_N + [4, 8], # num_warps + [3, 4, 5], # num_stages + ): + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_K": block_k, "BLOCK_N": block_n}, + num_stages=stages, + num_warps=warps, + ) + ) + return configs + + +def _prune_bwd_lora_configs(configs, named_args, **kwargs): + """Prune backward configs based on SMEM capacity. + + The backward kernel loads X[BLOCK_M, BLOCK_K] and DY[BLOCK_M, BLOCK_N] + in the inner loop, plus holds A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] + for the full expert. We estimate SMEM based on the dominant terms. 
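+
+    On top of the base estimate we add the A[BLOCK_R, BLOCK_K] and
+    B[BLOCK_N, BLOCK_R] tiles held for the full expert, then prune against the
+    device SMEM capacity minus _SMEM_SLACK.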
+ """ + smem_cap = _get_smem_capacity() + block_r = named_args.get("BLOCK_R", 64) + + scored = [] + for config in configs: + block_m = config.kwargs["BLOCK_M"] + block_k = config.kwargs["BLOCK_K"] + block_n = config.kwargs["BLOCK_N"] + # Inner loop loads X[M,K] and DY[M,N], pipeline over M iterations + smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k) + # A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] held for the full expert + smem_lora = (block_r * block_k + block_n * block_r) * 2 + smem = smem_base + smem_lora + scored.append((smem, config)) + + pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK] + if pruned: + return pruned + # All configs exceed SMEM — return the one with smallest estimated usage + scored.sort(key=lambda x: x[0]) + return [scored[0][1]] + + +@triton.autotune( + configs=_group_bwd_lora_configs(), + key=["M", "N", "K"], + prune_configs_by={"early_config_prune": _prune_bwd_lora_configs}, + reset_to_zero=["DLA_ptr", "DLB_ptr"], +) +@triton.heuristics( + { + "NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0, + "NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0, + } +) +@triton.jit +def _group_bwd_lora( + # Inputs + DY_ptr, + stride_dym, + stride_dyn, + X_ptr, + stride_xm, + stride_xk, + # LoRA weights (needed for cross-terms) + LA_ptr, + stride_la_r, + stride_la_k, # A: [r*E, K] + LB_ptr, + stride_lb_n, + stride_lb_r, # B: [N, r*E] + # Gradient outputs + DLA_ptr, + stride_dla_r, + stride_dla_k, + DLB_ptr, + stride_dlb_n, + stride_dlb_r, + # Expert offsets + expert_offsets_ptr, + # Dimensions + M, + K: tl.constexpr, + N: tl.constexpr, + ACTUAL_R: tl.constexpr, # True LoRA rank (for weight indexing) + BLOCK_R: tl.constexpr, # Padded tile size >= max(ACTUAL_R, 16) + scaling, + # Block sizes + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, + NO_K_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, +): + """ + Compute LoRA gradients for each expert on grouped data. + + Grid: (E * cdiv(K, BLOCK_K), cdiv(N, BLOCK_N)) + + For expert e: + dA[e] = scaling * (dY @ B[e])^T @ X -> [r, K], accumulate over M tokens + dB[e] = scaling * dY^T @ (X @ A[e]^T) -> [N, r], accumulate over M tokens + + ACTUAL_R is the true LoRA rank. BLOCK_R >= ACTUAL_R is padded for tl.dot min size. + """ + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + + K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) + E_idx = pid0 // K_BLOCK_COUNT + K_block_id = pid0 % K_BLOCK_COUNT + N_block_id = pid1 + + # Get expert's token range from cumulative offsets + if E_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + num_tokens = end_idx - start_idx + + if num_tokens > 0: + M_block = tl.arange(0, BLOCK_M) + K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + K_mask = K_block < K + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + R_block = tl.arange(0, BLOCK_R) + R_mask = R_block < ACTUAL_R # Mask for padding + + lora_offset = E_idx * ACTUAL_R + + # Determine input element type for consistent casting. 
+ INPUT_DTYPE = X_ptr.dtype.element_ty + + # Load B[e]: [BLOCK_N, BLOCK_R] (masked on R and N, other=0 for padding) + B_blk_ptrs = ( + LB_ptr + + N_block[:, None] * stride_lb_n + + (lora_offset + R_block)[None, :] * stride_lb_r + ) + b_e = tl.load(B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0).to( + INPUT_DTYPE + ) + + # Load A[e]: [BLOCK_R, BLOCK_K] (masked on R and K, other=0 for padding) + A_blk_ptrs = ( + LA_ptr + + (lora_offset + R_block)[:, None] * stride_la_r + + K_block[None, :] * stride_la_k + ) + a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to( + INPUT_DTYPE + ) + + # Accumulators + dA_acc = tl.zeros((BLOCK_R, BLOCK_K), dtype=ACC_TYPE) + dB_acc = tl.zeros((BLOCK_N, BLOCK_R), dtype=ACC_TYPE) + + iters = tl.cdiv(num_tokens, BLOCK_M) + for i in range(iters): + M_idx = start_idx + i * BLOCK_M + M_block + M_mask = M_idx < end_idx + + # Load X: [BLOCK_M, BLOCK_K] + X_blk_ptrs = ( + X_ptr + M_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + ) + x = tl.load( + X_blk_ptrs, mask=M_mask[:, None] & K_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + + # Load dY: [BLOCK_M, BLOCK_N] + DY_blk_ptrs = ( + DY_ptr + M_idx[:, None] * stride_dym + N_block[None, :] * stride_dyn + ) + dy = tl.load( + DY_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + + # X @ A[e]^T: [M, K] @ [K, R] -> [M, R] + xa = tl.dot(x, tl.trans(a_e), allow_tf32=allow_tf32) + + # dY @ B[e]: [M, N] @ [N, R] -> [M, R] + dy_b = tl.dot(dy, b_e, allow_tf32=allow_tf32) + + # Cast intermediates to input dtype for subsequent tl.dot calls + # (tl.dot requires both operands to have the same dtype) + dy_b_cast = dy_b.to(INPUT_DTYPE) + xa_cast = xa.to(INPUT_DTYPE) + + # dA += (dY @ B)^T @ X: [R, M] @ [M, K] -> [R, K] + dA_acc += tl.dot(tl.trans(dy_b_cast), x, allow_tf32=allow_tf32) + + # dB += dY^T @ (X @ A^T): [N, M] @ [M, R] -> [N, R] + dB_acc += tl.dot(tl.trans(dy), xa_cast, allow_tf32=allow_tf32) + + # Store dA with scaling (atomic add since multiple N_blocks contribute) + # Only store the actual R rows, not the padded ones + DLA_blk_ptrs = ( + DLA_ptr + + (lora_offset + R_block)[:, None] * stride_dla_r + + K_block[None, :] * stride_dla_k + ) + tl.atomic_add( + DLA_blk_ptrs, + (dA_acc * scaling).to(DLA_ptr.dtype.element_ty), + mask=R_mask[:, None] & K_mask[None, :], + ) + + # Store dB with scaling (atomic add since multiple K_blocks contribute) + DLB_blk_ptrs = ( + DLB_ptr + + N_block[:, None] * stride_dlb_n + + (lora_offset + R_block)[None, :] * stride_dlb_r + ) + tl.atomic_add( + DLB_blk_ptrs, + (dB_acc * scaling).to(DLB_ptr.dtype.element_ty), + mask=N_mask[:, None] & R_mask[None, :], + ) + + +def group_bwd_lora( + DY: torch.Tensor, + X: torch.Tensor, + lora_A: torch.Tensor, + lora_B: torch.Tensor, + expert_offsets: torch.Tensor, + E: int, + scaling: float, + sorted_scattered_idxs: Optional[torch.Tensor] = None, + k: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute LoRA gradients for A and B on expert-grouped data. + + Args: + DY: Gradient w.r.t. 
output [M_total, N] (grouped by expert) + X: Input [M_total, K] (grouped by expert) + lora_A: LoRA A weights [r*E, K] + lora_B: LoRA B weights [N, r*E] + expert_offsets: Cumulative token counts per expert [E] + E: Number of experts + scaling: LoRA scaling factor + + Returns: + dA: Gradient for A [r*E, K] + dB: Gradient for B [N, r*E] + """ + R = lora_A.size(0) // E + K = X.size(1) + N = DY.size(1) + + # Zero-init for atomic accumulation + dA = torch.zeros_like(lora_A) + dB = torch.zeros_like(lora_B) + + BLOCK_R = _block_r_for_rank(R) + + def grid(META): + return ( + E * triton.cdiv(K, META["BLOCK_K"]), + triton.cdiv(N, META["BLOCK_N"]), + ) + + _group_bwd_lora[grid]( + DY, + DY.stride(0), + DY.stride(1), + X, + X.stride(0), + X.stride(1), + lora_A, + lora_A.stride(0), + lora_A.stride(1), + lora_B, + lora_B.stride(0), + lora_B.stride(1), + dA, + dA.stride(0), + dA.stride(1), + dB, + dB.stride(0), + dB.stride(1), + expert_offsets, + M=DY.size(0), + K=K, + N=N, + ACTUAL_R=R, # True LoRA rank + BLOCK_R=BLOCK_R, # Padded tile size + scaling=scaling, + ACC_TYPE=tl.float32, + allow_tf32=ALLOW_TF32, + ) + + return dA, dB + + +# ============================================================================= +# Backward Kernel: Fused gather + LoRA gradient (dA, dB) — eliminates group() +# ============================================================================= + + +@triton.autotune( + configs=_group_bwd_lora_configs(), + key=["M", "N", "K"], + prune_configs_by={"early_config_prune": _prune_bwd_lora_configs}, + reset_to_zero=["DLA_ptr", "DLB_ptr"], +) +@triton.heuristics( + { + "NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0, + "NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0, + } +) +@triton.jit +def _group_bwd_lora_fused( + # Inputs (ungrouped or grouped) + DY_ptr, + stride_dym, + stride_dyn, + X_ptr, + stride_xm, + stride_xk, + # Scatter indices for gather-on-load + sorted_scattered_idxs_ptr, + FAN_OUT: tl.constexpr, + # LoRA weights (needed for cross-terms) + LA_ptr, + stride_la_r, + stride_la_k, # A: [r*E, K] + LB_ptr, + stride_lb_n, + stride_lb_r, # B: [N, r*E] + # Gradient outputs + DLA_ptr, + stride_dla_r, + stride_dla_k, + DLB_ptr, + stride_dlb_n, + stride_dlb_r, + # Expert offsets + expert_offsets_ptr, + # Real expert offsets (for M_mask when using token rounding, else same as expert_offsets_ptr) + real_expert_offsets_ptr, + # Dimensions + M, + K: tl.constexpr, + N: tl.constexpr, + ACTUAL_R: tl.constexpr, + BLOCK_R: tl.constexpr, + scaling, + # Block sizes + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, + NO_K_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, + # Whether DY is already in grouped (expert-sorted) order + dy_grouped: tl.constexpr = False, +): + """ + Fused gather + LoRA gradient computation. Same as _group_bwd_lora but + reads X from ungrouped buffers using sorted_scattered_idxs for indirect + indexing, eliminating the need for a separate group(X) call. + + When dy_grouped=False (default): both X and DY are read via indirect + indexing through sorted_scattered_idxs. This eliminates both group() + calls entirely. + + When dy_grouped=True: DY is already in grouped order (e.g. gate_up_proj + backward where grouped_out=True) and is read directly. Only X uses + indirect indexing. This avoids the group(X) allocation while + still supporting the grouped DY case. 
+ + Grid: (E * cdiv(K, BLOCK_K), cdiv(N, BLOCK_N)) + + For expert e: + dA[e] = scaling * (dY @ B[e])^T @ X -> [r, K] + dB[e] = scaling * dY^T @ (X @ A[e]^T) -> [N, r] + + Supports token rounding: expert_offsets_ptr gives the iteration range + (padded to BLOCK_M multiples), real_expert_offsets_ptr gives the real + token count for M_mask (to exclude padding tokens). + """ + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + + K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) + E_idx = pid0 // K_BLOCK_COUNT + K_block_id = pid0 % K_BLOCK_COUNT + N_block_id = pid1 + + # Get expert's token range from cumulative offsets + # start_idx/end_idx from expert_offsets_ptr: iteration range (possibly padded) + # real_end_idx from real_expert_offsets_ptr: for M_mask (real token count) + if E_idx == 0: + start_idx = 0 + real_start_idx = 0 + else: + start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) + real_start_idx = tl.load(real_expert_offsets_ptr + E_idx - 1).to(tl.int32) + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + real_end_idx = tl.load(real_expert_offsets_ptr + E_idx).to(tl.int32) + num_tokens = end_idx - start_idx + + if num_tokens > 0: + M_block = tl.arange(0, BLOCK_M) + K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + K_mask = K_block < K + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + R_block = tl.arange(0, BLOCK_R) + R_mask = R_block < ACTUAL_R + + lora_offset = E_idx * ACTUAL_R + + # Determine input element type for consistent casting. + INPUT_DTYPE = X_ptr.dtype.element_ty + + # Load B[e] and A[e] — same as non-fused kernel + B_blk_ptrs = ( + LB_ptr + + N_block[:, None] * stride_lb_n + + (lora_offset + R_block)[None, :] * stride_lb_r + ) + b_e = tl.load(B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0).to( + INPUT_DTYPE + ) + + A_blk_ptrs = ( + LA_ptr + + (lora_offset + R_block)[:, None] * stride_la_r + + K_block[None, :] * stride_la_k + ) + a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to( + INPUT_DTYPE + ) + + # Accumulators + dA_acc = tl.zeros((BLOCK_R, BLOCK_K), dtype=ACC_TYPE) + dB_acc = tl.zeros((BLOCK_N, BLOCK_R), dtype=ACC_TYPE) + + real_num_tokens = real_end_idx - real_start_idx + iters = tl.cdiv(num_tokens, BLOCK_M) + for i in range(iters): + M_idx = start_idx + i * BLOCK_M + M_block + # Use real token count for masking (excludes padding tokens) + M_local = i * BLOCK_M + M_block + M_mask = M_local < real_num_tokens + + # Fused gather: load scatter indices for indirect X access + scatter_idx = tl.load( + sorted_scattered_idxs_ptr + M_idx, mask=M_mask, other=0 + ).to(tl.int32) + X_token_idx = scatter_idx // FAN_OUT # X is [M, K], not expanded by k + + # Load X via indirect index: [BLOCK_M, BLOCK_K] + X_blk_ptrs = ( + X_ptr + X_token_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + ) + x = tl.load( + X_blk_ptrs, mask=M_mask[:, None] & K_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + + # Load DY: indirect via scatter_idx when ungrouped, direct via M_idx when grouped + if dy_grouped: + DY_blk_ptrs = ( + DY_ptr + M_idx[:, None] * stride_dym + N_block[None, :] * stride_dyn + ) + else: + DY_blk_ptrs = ( + DY_ptr + + scatter_idx[:, None] * stride_dym + + N_block[None, :] * stride_dyn + ) + dy = tl.load( + DY_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + + # X @ A[e]^T: [M, K] @ [K, R] -> [M, R] + xa = tl.dot(x, tl.trans(a_e), allow_tf32=allow_tf32) + + # dY @ B[e]: [M, N] @ [N, R] -> [M, R] + dy_b = tl.dot(dy, b_e, allow_tf32=allow_tf32) 
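+            # As in _group_bwd_lora: cast the intermediates to the input dtype below
+            # so both operands of the subsequent tl.dot calls have matching dtypes.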
+ + dy_b_cast = dy_b.to(INPUT_DTYPE) + xa_cast = xa.to(INPUT_DTYPE) + + # dA += (dY @ B)^T @ X: [R, M] @ [M, K] -> [R, K] + dA_acc += tl.dot(tl.trans(dy_b_cast), x, allow_tf32=allow_tf32) + + # dB += dY^T @ (X @ A^T): [N, M] @ [M, R] -> [N, R] + dB_acc += tl.dot(tl.trans(dy), xa_cast, allow_tf32=allow_tf32) + + # Store dA with scaling (atomic add since multiple N_blocks contribute) + DLA_blk_ptrs = ( + DLA_ptr + + (lora_offset + R_block)[:, None] * stride_dla_r + + K_block[None, :] * stride_dla_k + ) + tl.atomic_add( + DLA_blk_ptrs, + (dA_acc * scaling).to(DLA_ptr.dtype.element_ty), + mask=R_mask[:, None] & K_mask[None, :], + ) + + # Store dB with scaling (atomic add since multiple K_blocks contribute) + DLB_blk_ptrs = ( + DLB_ptr + + N_block[:, None] * stride_dlb_n + + (lora_offset + R_block)[None, :] * stride_dlb_r + ) + tl.atomic_add( + DLB_blk_ptrs, + (dB_acc * scaling).to(DLB_ptr.dtype.element_ty), + mask=N_mask[:, None] & R_mask[None, :], + ) + + +def group_bwd_lora_fused( + DY: torch.Tensor, + X: torch.Tensor, + lora_A: torch.Tensor, + lora_B: torch.Tensor, + expert_offsets: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + E: int, + k: int, + scaling: float, + real_expert_offsets: Optional[torch.Tensor] = None, + dy_grouped: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Fused gather + LoRA gradient computation. Same result as + group(X) + group(DY) + group_bwd_lora(DY, X, ...) but without + the intermediate grouped buffers. + + Args: + DY: Gradient w.r.t. output [M*k, N]. + If dy_grouped=False: ungrouped (original token order), read via + indirect indexing through sorted_scattered_idxs. + If dy_grouped=True: already in grouped (expert-sorted) order, + read directly. + X: Input [M, K] (ungrouped, original token order). Always read via + indirect indexing through sorted_scattered_idxs. + lora_A: LoRA A weights [r*E, K] + lora_B: LoRA B weights [N, r*E] + expert_offsets: Cumulative token counts per expert [E] + (or padded offsets if using token rounding) + sorted_scattered_idxs: Maps grouped position -> original position [M*k] + (or padded version if using token rounding) + E: Number of experts + k: Fan-out (top-k) + scaling: LoRA scaling factor + real_expert_offsets: Original cumulative counts for M_mask when using + token rounding. If None, expert_offsets is used for both. + dy_grouped: Whether DY is already in grouped order (default False). + When True, avoids indirect indexing for DY, used for gate_up_proj + backward where grouped_out=True. 
+ + Returns: + dA: Gradient for A [r*E, K] + dB: Gradient for B [N, r*E] + """ + R = lora_A.size(0) // E + K = X.size(1) + N = DY.size(1) + + # Zero-init for atomic accumulation + dA = torch.zeros_like(lora_A) + dB = torch.zeros_like(lora_B) + + BLOCK_R = _block_r_for_rank(R) + + if real_expert_offsets is None: + real_expert_offsets = expert_offsets + + def grid(META): + return ( + E * triton.cdiv(K, META["BLOCK_K"]), + triton.cdiv(N, META["BLOCK_N"]), + ) + + _group_bwd_lora_fused[grid]( + DY, + DY.stride(0), + DY.stride(1), + X, + X.stride(0), + X.stride(1), + sorted_scattered_idxs, + FAN_OUT=k, + LA_ptr=lora_A, + stride_la_r=lora_A.stride(0), + stride_la_k=lora_A.stride(1), + LB_ptr=lora_B, + stride_lb_n=lora_B.stride(0), + stride_lb_r=lora_B.stride(1), + DLA_ptr=dA, + stride_dla_r=dA.stride(0), + stride_dla_k=dA.stride(1), + DLB_ptr=dB, + stride_dlb_n=dB.stride(0), + stride_dlb_r=dB.stride(1), + expert_offsets_ptr=expert_offsets, + real_expert_offsets_ptr=real_expert_offsets, + M=sorted_scattered_idxs.size(0), + K=K, + N=N, + ACTUAL_R=R, + BLOCK_R=BLOCK_R, + scaling=scaling, + ACC_TYPE=tl.float32, + allow_tf32=ALLOW_TF32, + dy_grouped=dy_grouped, + ) + + return dA, dB diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/ops.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/ops.py new file mode 100644 index 000000000..6aa432770 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/ops.py @@ -0,0 +1,645 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/shawntan/scattermoe +# Copyright (c) Shawn Tan and ScatterMoE Contributors +# Licensed under the Apache License, Version 2.0 +# See https://github.com/shawntan/scattermoe/blob/main/LICENSE + +from typing import Optional + +import torch +import triton +import triton.language as tl + +BLOCK_M = 128 +ALLOW_TF32 = True + + +@triton.jit +def _compute_expert_block( + E_idx, + E_mask, + M_in_idx, + N_block, + N_mask, + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_we, + stride_wk, + stride_wn, + K, + acc, + no_k_mask, + BLOCK_K, + allow_tf32=True, +): + K_block = tl.arange(0, BLOCK_K) + X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk + W_blk_ptrs = ( + W_ptr + + K_block[:, None] * stride_wk + + N_block[None, :] * stride_wn + + E_idx * stride_we + ) + iters = tl.cdiv(K, BLOCK_K) + + for K_block_id in range(iters): + if no_k_mask: + x = tl.load(X_blk_ptrs, mask=E_mask[:, None]) + w = tl.load(W_blk_ptrs, mask=N_mask[None, :]) + else: + K_mask = (K_block_id * BLOCK_K + K_block) < K + x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :]) + w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :]) + + X_blk_ptrs += BLOCK_K * stride_xk + W_blk_ptrs += BLOCK_K * stride_wk + acc = tl.dot(x, w, acc, allow_tf32=allow_tf32) + return acc + + +def _scatter2scatter_configs(): + return [ + triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4), + ] + + +@triton.autotune( + configs=_scatter2scatter_configs(), + key=["M", "N", "K"], +) +@triton.heuristics( + { + "NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0, + "NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0, + } +) +@triton.jit +def _scatter2scatter( + X_ptr, + stride_xm: tl.constexpr, + stride_xk: tl.constexpr, + W_ptr, + stride_we, + stride_wk: tl.constexpr, + stride_wn: tl.constexpr, + Y_ptr, + stride_ym: tl.constexpr, + stride_yn: tl.constexpr, + B_ptr, + stride_be: tl.constexpr, + stride_bn: tl.constexpr, + 
grouped_idx_ptr, + expert_idxs_ptr, + # block_start_idx_ptr, + FAN_OUT: tl.constexpr, + M, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + # OUT_M, + allow_tf32: tl.constexpr, + x_grouped: tl.constexpr, + y_grouped: tl.constexpr, + NO_K_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, +): + pid = tl.program_id(axis=0) + + N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N) + M_block_id = pid // N_BLOCK_COUNT + N_block_id = pid % N_BLOCK_COUNT + + M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M) + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + M_boundary_mask = M_block < (FAN_OUT * M) + E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E) + + no_k_mask = K % BLOCK_K == 0 + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + E_first_idx = tl.min(E_idxs) + E_last_idx = tl.minimum(tl.max(E_idxs), E - 1) + M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32) + for E_idx in range(E_first_idx, E_last_idx + 1): + E_mask = E_idxs == E_idx + E_M_idx = M_idx + if x_grouped: + M_in_idx = M_block + else: + M_in_idx = E_M_idx // FAN_OUT + acc = _compute_expert_block( + E_idx, + E_mask, + M_in_idx, + N_block, + N_mask, + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_we, + stride_wk, + stride_wn, + K, + acc, + no_k_mask, + BLOCK_K, + allow_tf32=allow_tf32, + ) + + if B_ptr is not None: + B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn + acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :]) + + if y_grouped: + M_out_idx = M_block + else: + M_out_idx = M_idx + Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn) + tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :]) + + +def scatter2scatter( + X, + W, + sorted_expert_idxs, + sorted_scattered_idxs, + k, + b=None, + x_grouped=False, + y_grouped=False, + out=None, +): + assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) + assert sorted_scattered_idxs.size(0) == X.size(0) * k + # Pre-kernel setup + y_dim = W.size(-1) + L_scattered = sorted_expert_idxs.size(0) + if out is None: + output = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype) + else: + assert out.size(0) == L_scattered and out.size(1) == y_dim + output = out + + scatter2scatter_compileable( + output, + W, + X, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + b, + x_grouped, + y_grouped, + ) + return output + + +@torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"}) +def scatter2scatter_compileable( + output: torch.Tensor, + W: torch.Tensor, + X: torch.Tensor, + k: int, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + b: Optional[torch.Tensor], + x_grouped: bool, + y_grouped: bool, +) -> None: + def grid(META): + grid_num = ( + triton.cdiv(sorted_expert_idxs.size(0), META["BLOCK_M"]) + * triton.cdiv(META["N"], META["BLOCK_N"]), + ) + return grid_num + + if b is None: + b = None + stride_be = stride_bn = 0 + else: + stride_be, stride_bn = b.stride() + + _scatter2scatter[grid]( + # X_ptr, stride_xm, stride_xk, + X, + X.stride(0), + X.stride(1), + # W_ptr, stride_we, stride_wk, stride_wn, + W, + W.stride(0), + W.stride(1), + W.stride(2), + # Y_ptr, stride_ym, stride_yn, + output, + output.stride(0), + output.stride(1), + # B_ptr, stride_be, stride_bn + b, + stride_be, + stride_bn, + grouped_idx_ptr=sorted_scattered_idxs, + 
expert_idxs_ptr=sorted_expert_idxs, + # block_start_idx_ptr=padded_block_idxs, + FAN_OUT=k, + M=X.size(0), + K=X.size(1), + N=output.size(1), + E=W.size(0), + BLOCK_M=BLOCK_M, + ACC_TYPE=tl.float32, + allow_tf32=ALLOW_TF32, + x_grouped=x_grouped, + y_grouped=y_grouped, + ) + + +def _config_XtY(): + return [ + triton.Config( + {"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4 + ), + ] + + +def group_bwd_W(DY, X, expert_offsets, E, has_bias=False): + DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype) + DW = DWt.permute(0, 2, 1) + if has_bias: + Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype) + else: + Db = None + groupXtY_compileable(E, DW, Db, DY, X, expert_offsets) + return DW, Db + + +@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW", "Db"}) +def groupXtY_compileable( + E: int, + DW: torch.Tensor, + Db: Optional[torch.Tensor], + DY: torch.Tensor, + X: torch.Tensor, + expert_offsets: torch.Tensor, +) -> None: + def grid(META): + grid = ( + E * triton.cdiv(META["K"], META["BLOCK_K"]), + triton.cdiv(META["N"], META["BLOCK_N"]), + ) + return grid + + if Db is None: + stride_dbe = 0 + stride_dbn = 0 + else: + stride_dbe, stride_dbn = Db.stride() + + _groupXtY[grid]( + # DY_ptr, stride_dym, stride_dyk, + DY, + DY.stride(0), + DY.stride(1), + # X_ptr, stride_xm, stride_xn, + X, + X.stride(0), + X.stride(1), + # DW_ptr, stride_dwe, stride_dwk, stride_dwn, + DW, + DW.stride(0), + DW.stride(1), + DW.stride(2), + # Db_ptr, stride_dwe, stride_dbn, + Db, + stride_dbe, + stride_dbn, + # expert_offsets_ptr, + expert_offsets, + # K: tl.constexpr, N: tl.constexpr, + M=DY.size(0), + N=DY.size(-1), + K=X.size(-1), + # ACC_TYPE: tl.constexpr, + ACC_TYPE=tl.float32, + allow_tf32=ALLOW_TF32, + ) + + +@triton.autotune( + configs=_config_XtY(), + key=["M", "N", "K"], +) +@triton.heuristics( + { + "NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0, + "NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0, + } +) +@triton.jit +def _groupXtY( + DY_ptr, + stride_dym, + stride_dyk, + X_ptr, + stride_xm, + stride_xn, + DW_ptr, + stride_dwe, + stride_dwk, + stride_dwn, + Db_ptr, + stride_dbe, + stride_dbn, + expert_offsets_ptr, + M, + K: tl.constexpr, + N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, + NO_K_MASK: tl.constexpr, + NO_N_MASK: tl.constexpr, +): + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + num0 = tl.num_programs(0) + num1 = tl.num_programs(1) + # pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128) + pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4) + + K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K) + E_idx = pid0 // K_BLOCK_COUNT + K_block_id = pid0 % K_BLOCK_COUNT + N_block_id = pid1 + + if E_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + + if end_idx > start_idx: + M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M) + + K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + K_mask = K_block < K + K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K) + + N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_block < N + N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N) + + M_idxs = M_block + xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm + dy_blk_ptrs = ( + 
DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk + ) + if (Db_ptr is not None) and (K_block_id == 0): + _xty_and_bias( + E_idx, + start_idx, + end_idx, + M_block, + K_block, + K_mask, + N_block, + N_mask, + dy_blk_ptrs, + stride_dym, + xt_blk_ptrs, + stride_xm, + DW_ptr, + stride_dwe, + stride_dwk, + stride_dwn, + Db_ptr, + stride_dbe, + stride_dbn, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ACC_TYPE, + allow_tf32, + NO_K_MASK, + NO_N_MASK, + compute_bias=True, + ) + else: + _xty_and_bias( + E_idx, + start_idx, + end_idx, + M_block, + K_block, + K_mask, + N_block, + N_mask, + dy_blk_ptrs, + stride_dym, + xt_blk_ptrs, + stride_xm, + DW_ptr, + stride_dwe, + stride_dwk, + stride_dwn, + Db_ptr, + stride_dbe, + stride_dbn, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ACC_TYPE, + allow_tf32, + NO_K_MASK, + NO_N_MASK, + compute_bias=False, + ) + + +@triton.jit +def _xty_and_bias( + E_idx, + start_idx, + end_idx, + M_block, + K_block, + K_mask, + N_block, + N_mask, + dy_blk_ptrs, + stride_dym, + xt_blk_ptrs, + stride_xm, + DW_ptr, + stride_dwe, + stride_dwk, + stride_dwn, + Db_ptr, + stride_dbe, + stride_dbn, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ACC_TYPE, + allow_tf32, + NO_K_MASK, + NO_N_MASK, + compute_bias: tl.constexpr, +): + if compute_bias: + db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE) + else: + db_acc = None + + acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE) + iters = tl.cdiv(end_idx - start_idx, BLOCK_M) + for i in range(0, iters): + M_mask = (i * BLOCK_M + M_block) < end_idx + if NO_K_MASK: + xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :]) + else: + xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :]) + if NO_N_MASK: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None]) + else: + dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :]) + + acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32) + + xt_blk_ptrs += BLOCK_M * stride_xm + dy_blk_ptrs += BLOCK_M * stride_dym + + if compute_bias: + db_acc += tl.sum(dy, axis=0) + + DW_blk_ptrs = ( + DW_ptr + + E_idx * stride_dwe + + K_block[:, None] * stride_dwk + + N_block[None, :] * stride_dwn + ) + acc = acc.to(DW_blk_ptrs.dtype.element_ty) + tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :]) + if compute_bias: + Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn + tl.store(Db_blk_ptrs, db_acc, mask=N_mask) + + +def _config_grouping(): + return [ + triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4), + ] + + +def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None): + N = sorted_expert_idxs.size(0) + K = A.size(1) + assert A.size(0) * fan_out == N + if out is not None: + Y = out + else: + Y = torch.empty((N, K), dtype=A.dtype, device=A.device) + group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs) + return Y + + +@torch.library.custom_op("scattermoe::group", mutates_args={"Y"}) +def group_compileable( + A: torch.Tensor, + K: int, + N: int, + Y: torch.Tensor, + coeff: Optional[torch.Tensor], + has_coeff: bool, + fan_out: int, + sorted_expert_idxs: torch.Tensor, +) -> None: + def grid(META): + grid_num = (triton.cdiv(META["N"], META["BLOCK_N"]),) + return grid_num + + _group[grid]( + # A_ptr, stride_an, stride_ai, + A, + A.stride(0), + A.stride(1), + has_coeff, + coeff, + fan_out, + # Y_ptr, stride_yn, stride_yk, + Y, + Y.stride(0), + Y.stride(1), + # 
grouped_idx_ptr, + sorted_expert_idxs, + # N: tl.constexpr, K: tl.constexpr, + N, + K, + ) + + +@triton.autotune(configs=_config_grouping(), key=["K"]) +@triton.heuristics({"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0}) +@triton.jit +def _group( + src_ptr, + stride_sn, + stride_sk, + has_coeff: tl.constexpr, + coeff_ptr, + FAN_OUT: tl.constexpr, + tgt_ptr, + stride_tn, + stride_ti, + grouped_idx_ptr, + N, + K: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + NO_K_MASK: tl.constexpr, +): + pid = tl.program_id(axis=0) + + N_block_id = pid + N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N) + N_mask = N_blk < N + N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N) + N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0) + + K_blk = tl.arange(0, BLOCK_K) + src_blk_ptrs = ( + src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk + ) + tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti + + if has_coeff: + c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None] + + iters = tl.cdiv(K, BLOCK_K) + for i in range(0, iters): + if NO_K_MASK or i < iters - 1: + block = tl.load(src_blk_ptrs, mask=N_mask[:, None]) + if has_coeff: + block *= c + tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None]) + + else: + K_mask = (i * BLOCK_K + K_blk) < K + mask = N_mask[:, None] & K_mask[None, :] + block = tl.load(src_blk_ptrs, mask=mask) + if has_coeff: + block *= c + tl.store(tgt_blk_ptrs, block, mask=mask) + src_blk_ptrs += BLOCK_K * stride_sk + tgt_blk_ptrs += BLOCK_K * stride_ti diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/single.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/single.py new file mode 100644 index 000000000..9f0270aa6 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/single.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/shawntan/scattermoe +# Copyright (c) Shawn Tan and ScatterMoE Contributors +# Licensed under the Apache License, Version 2.0 +# See https://github.com/shawntan/scattermoe/blob/main/LICENSE + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _single2scatter( + X_ptr, + stride_xm, + stride_xk, + W_ptr, + stride_we, + stride_wk, + stride_wn, + Y_ptr, + stride_ym, + stride_yn, + expert_idxs_ptr, + FAN_OUT: tl.constexpr, + K: tl.constexpr, + N: tl.constexpr, + E: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, +): + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + + N_block_id = pid0 + if FAN_OUT == 1: + in_idx = pid1 + else: + in_idx = 0 + out_idx = pid1 + + K_block = tl.arange(0, BLOCK_K) + N_block = tl.max_contiguous( + tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), + BLOCK_N, + ) + E_idx = tl.load(expert_idxs_ptr + pid1) + X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk + W_blk_ptrs = ( + W_ptr + + E_idx * stride_we + + K_block[:, None] * stride_wk + + N_block[None, :] * stride_wn + ) + N_mask = N_block < N + acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE) + for _K_block_id in range(0, tl.cdiv(K, BLOCK_K)): + K_mask = K_block < K + x = tl.load(X_blk_ptrs, mask=K_mask[:, None], other=0.0) + w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0) + acc += tl.sum(x * w, axis=0)[None, :] + X_blk_ptrs += BLOCK_K * stride_xk + W_blk_ptrs += BLOCK_K * stride_wk + K_block += BLOCK_K + Y_blk_ptrs = 
Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn + tl.store(Y_blk_ptrs, acc, mask=N_mask[None, :]) + + +def single2scatter(X, W, expert_idxs): + E, xdim, ydim = W.size() + k = expert_idxs.size(1) + assert X.size(0) == k or X.size(0) == 1 + Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype) + BLOCK_N = 128 + BLOCK_K = 128 + grid = triton.cdiv(ydim, BLOCK_N), k + _single2scatter[grid]( + X, + X.stride(0), + X.stride(1), + W, + W.stride(0), + W.stride(1), + W.stride(2), + Y, + Y.stride(0), + Y.stride(1), + expert_idxs, + FAN_OUT=Y.size(0) // X.size(0), + K=xdim, + N=ydim, + E=E, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + ACC_TYPE=tl.float32, + ) + return Y diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py new file mode 100644 index 000000000..a42577483 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py @@ -0,0 +1,439 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors +# Adapted from https://github.com/shawntan/scattermoe +# See https://github.com/shawntan/scattermoe/blob/main/LICENSE +# +# Modifications and LoRA adaptation Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +""" +ScatterMoE layer replacements for HuggingFace MoE architectures. + +Provides drop-in forward replacements that use ScatterMoE kernels for +acceleration. When used via the HF ``kernels`` library +(``replace_kernel_forward_from_hub``), these classes replace the forward +method of the original MoE block. + +LoRA support +------------ +When peft wraps parameters via ``target_parameters``, the ``self.experts`` +submodule becomes a chain of ``ParamWrapper`` objects and the ``self.gate`` +router may also become a ``ParamWrapper``. The ``HFScatterMoEGatedMLP`` +forward detects this and automatically: + +1. Unwraps ``self.gate`` to the base router, applying gate LoRA delta +2. Unwraps ``self.experts`` to the base ``OlmoeExperts`` module +3. Extracts LoRA A/B weights and scaling from each wrapper +4. Converts B layout from peft rank-major to scattermoe expert-major +5. Routes to ``parallel_linear_lora`` for fused LoRA computation +6. Passes through ``self.shared_expert`` / ``self.shared_expert_gate`` + (peft wraps their linear layers with standard LoRA, no special handling) +""" + +import torch +from torch import nn +from torch.nn import functional as F + +from .parallel_experts import flatten_sort_count, parallel_linear +from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora + +# ============================================================================= +# LoRA layout conversion utilities (peft <-> scattermoe) +# ============================================================================= + + +def peft_lora_B_to_scattermoe(peft_B, num_experts, rank): + """Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe + expert-major ``[N, r*E]``. + + peft reshapes B to ``[out, r, E]`` (rank-major). + scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major). + """ + N = peft_B.shape[0] + return ( + peft_B.reshape(N, rank, num_experts) + .permute(0, 2, 1) + .contiguous() + .reshape(N, num_experts * rank) + ) + + +def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): + """Convert peft LoRA weights to scattermoe layout (with A<->B swap). 
+ + peft operates on the parameter in its native storage layout ``[E, dim1, dim2]`` + where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the + parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with + ``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles + are swapped relative to scattermoe's convention. + + peft gives: + lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]`` + + scattermoe needs: + lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]`` + + This function swaps A<->B and converts B from rank-major to expert-major. + Uses vectorized tensor operations (no Python loop over experts). + + Works for **both** gate_up_proj and down_proj since the transposition + issue is the same for any parameter. + """ + peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank) + + dim1 = peft_A.shape[1] # peft in_features -> scattermoe N + dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K + + # smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2] + # [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2] + smoe_A = ( + peft_B_em.reshape(dim2, num_experts, rank) + .permute(1, 2, 0) + .contiguous() + .reshape(rank * num_experts, dim2) + ) + + # smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r] + # [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r] + smoe_B = ( + peft_A.reshape(num_experts, rank, dim1) + .permute(2, 0, 1) + .contiguous() + .reshape(dim1, num_experts * rank) + ) + + return smoe_A, smoe_B + + +def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank): + """Deprecated alias for :func:`peft_lora_to_scattermoe`.""" + return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank) + + +# ============================================================================= +# ParamWrapper unwrapping +# ============================================================================= + + +def _unwrap_gate_lora(gate_module): + """Unwrap peft ``ParamWrapper`` on the router gate. + + When peft targets ``gate.weight``, ``self.gate`` becomes:: + + ParamWrapper(weight) + -> base_layer: OlmoeTopKRouter (the real module) + + This function detects the wrapping and returns the base router, its + weight tensor, and an optional LoRA delta tensor. + + Returns: + (base_gate, gate_weight, gate_lora_delta_or_None) + + ``base_gate`` is the original router module (with ``.top_k``, + ``.num_experts``, ``.norm_topk_prob``). + ``gate_weight`` is the base router weight (may be a DTensor under FSDP). + ``gate_lora_delta_or_None`` is the LoRA delta tensor if LoRA is active, + else ``None``. Kept separate to avoid mixing DTensor + Tensor in an add. 
+ """ + if hasattr(gate_module, "base_layer") and hasattr(gate_module, "lora_A"): + base_gate = gate_module.base_layer + lora_A, lora_B, scaling = get_lora_params_from_wrapper(gate_module) + if lora_A is not None: + # gate weight: [num_experts, hidden_size] + # lora_A: [r, hidden_size], lora_B: [num_experts, r] + # delta = scaling * B @ A = [num_experts, hidden_size] + delta = scaling * (lora_B @ lora_A) + return base_gate, base_gate.weight, delta + else: + return base_gate, base_gate.weight, None + else: + # No wrapping — gate is the original module + return gate_module, gate_module.weight, None + + +def _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling): + """Convert peft LoRA weights to scattermoe layout.""" + smoe_A, smoe_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank) + return (smoe_A, smoe_B, scaling) + + +def _unwrap_experts_lora(experts_module): + """Walk a peft ``ParamWrapper`` chain on ``self.experts``. + + When peft targets ``experts.gate_up_proj`` and ``experts.down_proj`` via + ``target_parameters``, ``self.experts`` becomes a nested chain:: + + ParamWrapper(down_proj) + -> base_layer: ParamWrapper(gate_up_proj) + -> base_layer: OlmoeExperts (the real module) + + This function walks the chain, collects LoRA params keyed by + ``parameter_name``, and returns the base experts module. + + Returns: + (base_experts, gup_lora, down_lora) + + Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``. + A/B are already in scattermoe layout. + """ + # Collect ParamWrapper layers by their parameter_name + wrappers = {} + module = experts_module + while hasattr(module, "base_layer") and hasattr(module, "lora_A"): + param_name = getattr(module, "parameter_name", None) + if param_name is not None: + wrappers[param_name] = module + module = module.base_layer + + base_experts = module + + if not wrappers: + return base_experts, None, None + + # Determine num_experts from base module + num_experts = getattr(base_experts, "num_experts", None) + if num_experts is None: + # Fallback: infer from parameter shape + gup = getattr(base_experts, "gate_up_proj", None) + if gup is not None: + num_experts = gup.shape[0] + + # Extract gate_up_proj LoRA (needs A<->B swap due to transposition) + gup_lora = None + gup_wrapper = wrappers.get("gate_up_proj") + if gup_wrapper is not None: + lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper) + if lora_A is not None: + rank = lora_A.shape[0] // num_experts + gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling) + + # Extract down_proj LoRA (needs A<->B swap due to transposition) + down_lora = None + down_wrapper = wrappers.get("down_proj") + if down_wrapper is not None: + lora_A, lora_B, scaling = get_lora_params_from_wrapper(down_wrapper) + if lora_A is not None: + rank = lora_A.shape[0] // num_experts + down_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling) + + return base_experts, gup_lora, down_lora + + +# ============================================================================= +# Layer classes +# ============================================================================= + + +class ScatterMoEGatedMLP(nn.Module): + def forward(self, layer_input): + """ + Forward pass of the mixture of experts layer. + + Args: + layer_input (Tensor): + Input tensor. + + Returns: + Tensor: + Output tensor. 
+ """ + bsz, length, emb_size = layer_input.size() + layer_input = layer_input.reshape(-1, emb_size) + # compute the top_k routing decision + router_logits = self.router.layer(layer_input) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk( + routing_weights, self.router.top_k, dim=-1 + ) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(layer_input.dtype) + sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count( + selected_experts, num_experts=self.router.num_experts + ) + + # compute experts + gates, h = parallel_linear( + layer_input, + self.input_linear.weight.transpose(2, 1), + self.router.top_k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + grouped_in=False, + grouped_out=True, + ).chunk(2, dim=-1) + h = self.activation(gates) * h + layer_output = parallel_linear( + h, + self.output_linear.weight.transpose(2, 1), + 1, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + grouped_in=True, + grouped_out=False, + gates=routing_weights, + ) + layer_output = layer_output.view(bsz, length, emb_size) + return layer_output + + +class HFScatterMoEGatedMLP(nn.Module): + """ + ScatterMoE-accelerated forward pass for HF MoEs (OLMoE / Qwen2MoE). + + Used as a kernel layer via the HF ``kernels`` library. The ``forward`` + method replaces the original ``OlmoeSparseMoeBlock.forward``. + + Supports both full-parameter training and LoRA fine-tuning: + + * **Full-param**: uses ``parallel_linear`` (base ScatterMoE kernel) + * **LoRA**: detects peft ``ParamWrapper`` on ``self.experts``, extracts + adapter weights, and uses ``parallel_linear_lora`` (fused kernel) + """ + + @staticmethod + def forward(self: nn.Module, layer_input: torch.Tensor): + """ + Forward pass using ScatterMoE kernels. + + Args: + self: The MoeSparseMoeBlock module containing: + - self.gate: Router (or peft ParamWrapper wrapping it) + - self.experts: Experts module (or peft ParamWrapper chain) + - self.shared_expert: Optional shared expert (e.g. Qwen2MoE) + - self.shared_expert_gate: Optional shared expert gate + layer_input: Input tensor [batch_size, seq_len, hidden_size] + + Returns: + Tensor: [batch_size, seq_len, hidden_size] + """ + batch_size, sequence_length, hidden_dim = layer_input.shape + hidden_states_flat = layer_input.view(-1, hidden_dim) + + # ==================================================================== + # Shared Expert (if present, e.g. Qwen2MoE) + # ==================================================================== + # peft wraps individual linear layers inside shared_expert with + # standard LoRA — calling forward() handles this transparently. + if hasattr(self, "shared_expert") and self.shared_expert is not None: + shared_expert_output = self.shared_expert(hidden_states_flat) + # shared_expert_gate may also be peft-wrapped (standard LoRA + # on nn.Linear), its forward() applies LoRA automatically. 
+ shared_expert_gate_output = F.sigmoid( + self.shared_expert_gate(hidden_states_flat) + ) + shared_expert_output = shared_expert_output * shared_expert_gate_output + else: + shared_expert_output = None + + # ==================================================================== + # Router Computation (with optional gate LoRA) + # ==================================================================== + base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate) + router_logits = F.linear(hidden_states_flat, gate_weight) + if gate_lora_delta is not None: + router_logits = router_logits + F.linear( + hidden_states_flat, gate_lora_delta + ) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + + top_k = base_gate.top_k + num_experts = base_gate.num_experts + routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + if base_gate.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states_flat.dtype) + + sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count( + selected_experts, num_experts=num_experts + ) + + # ==================================================================== + # Detect LoRA (peft ParamWrapper) and extract adapter weights + # ==================================================================== + experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts) + + # ==================================================================== + # Gate + Up projection + # ==================================================================== + gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter] + + if gup_lora is not None: + gup_A, gup_B, gup_scaling = gup_lora + gup = parallel_linear_lora( + hidden_states_flat, + gate_up_W, + top_k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + lora_A=gup_A, + lora_B=gup_B, + scaling=gup_scaling, + grouped_in=False, + grouped_out=True, + use_fused_dX=True, + use_fused_gather=True, + ) + else: + gup = parallel_linear( + hidden_states_flat, + gate_up_W, + top_k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + grouped_in=False, + grouped_out=True, + ) + + gates, h = gup.chunk(2, dim=-1) + h = experts.act_fn(gates) * h + + # ==================================================================== + # Down projection + # ==================================================================== + down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden] + + if down_lora is not None: + down_A, down_B, down_scaling = down_lora + expert_output = parallel_linear_lora( + h, + down_W, + 1, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + lora_A=down_A, + lora_B=down_B, + scaling=down_scaling, + gates=routing_weights, + grouped_in=True, + grouped_out=False, + use_fused_dX=True, + use_fused_gather=True, + ) + else: + expert_output = parallel_linear( + h, + down_W, + 1, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + grouped_in=True, + grouped_out=False, + gates=routing_weights, + ) + + # ==================================================================== + # Combine with shared expert and reshape + # ==================================================================== + if shared_expert_output is not None: + expert_output = expert_output + shared_expert_output + + expert_output = expert_output.view(batch_size, sequence_length, hidden_dim) + return expert_output diff --git 
a/src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py new file mode 100644 index 000000000..aec68311b --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/lora_ops.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +""" +ParallelExperts module with LoRA support. + +Provides a drop-in replacement for ScatterMoE's ParallelExperts that +uses the fused LoRA kernel when adapter weights are attached. +""" + +from typing import Optional + +import torch +import torch.nn as nn + +from .parallel_linear_lora import parallel_linear_lora + + +class ParallelExperts(nn.Module): + """ + Parallel Experts with fused LoRA support. + + Drop-in replacement for the original ParallelExperts. When LoRA parameters + are attached via set_lora(), the forward pass uses a fused kernel: + Y = X @ W + scaling * (X @ A^T) @ B^T + """ + + def __init__( + self, + num_experts: int, + input_size: int, + output_size: int, + bias: bool = False, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size)) + if bias: + self.bias = nn.Parameter(torch.empty(num_experts, output_size)) + else: + self.bias = None + self.num_experts = num_experts + self.input_size = input_size + self.output_size = output_size + self._lora_A: torch.Tensor | None = None + self._lora_B: torch.Tensor | None = None + self._lora_scaling: float | None = None + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, std=0.02) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def extra_repr(self) -> str: + return ( + f"num_experts={self.num_experts}, " + f"input_size={self.input_size}, " + f"output_size={self.output_size}" + ) + + def set_lora(self, lora_A: torch.Tensor, lora_B: torch.Tensor, scaling: float): + """Attach LoRA parameters for fused computation.""" + self._lora_A = lora_A + self._lora_B = lora_B + self._lora_scaling = scaling + + def clear_lora(self): + """Remove LoRA parameters.""" + self._lora_A = None + self._lora_B = None + self._lora_scaling = None + + def forward( + self, + inputs: torch.Tensor, + k: int, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + expert_offsets: torch.Tensor, + gates: Optional[torch.Tensor] = None, + grouped_in: bool = False, + grouped_out: bool = False, + ) -> torch.Tensor: + return parallel_linear_lora( + inputs, + self.weight.permute(0, 2, 1), # [E, input, output] + k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + lora_A=self._lora_A, + lora_B=self._lora_B, + scaling=self._lora_scaling if self._lora_scaling is not None else 1.0, + expert_biases=self.bias, + gates=gates, + grouped_in=grouped_in, + grouped_out=grouped_out, + ) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py new file mode 100644 index 000000000..7a1eef472 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_experts.py @@ -0,0 +1,253 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/shawntan/scattermoe +# Copyright (c) Shawn Tan and ScatterMoE Contributors +# Licensed under the Apache License, Version 2.0 +# See https://github.com/shawntan/scattermoe/blob/main/LICENSE + +from typing import Optional + +import torch +import torch.nn as nn + +from . 
import kernels + + +@torch.library.custom_op("scattermoe::bincount", mutates_args={}) +def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor: + return x.bincount(minlength=minlength) + + +@compileable_bincount.register_fake +def _(x: torch.Tensor, minlength: int) -> torch.Tensor: + return torch.empty(minlength, dtype=torch.long, device=x.device) + + +@torch.compile +def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int): + with torch.no_grad(): + flattened_expert_idxs = expert_idxs.flatten() + sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs) + expert_counts = compileable_bincount( + flattened_expert_idxs, minlength=num_experts + ) + expert_offsets = expert_counts.cumsum(-1) + return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets + + +class ParallelLinear(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: torch.Tensor, + expert_weights: torch.Tensor, + k: int, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + expert_offsets: torch.Tensor, + expert_biases: Optional[torch.Tensor] = None, + gates: Optional[torch.Tensor] = None, + grouped_in: bool = False, + grouped_out: bool = False, + ): + with torch.device(x.device): + output = kernels.ops.scatter2scatter( + X=x, + W=expert_weights, + b=expert_biases, + k=k, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + x_grouped=grouped_in, + y_grouped=grouped_out, + ) + if gates is not None: + output_expanded = output.view( + gates.size(0), gates.size(1), output.size(-1) + ) + output = (gates.unsqueeze(1) @ output_expanded).squeeze(1) + else: + output_expanded = None + + ctx.save_for_backward( + x, + expert_weights, + expert_biases, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + gates, + output_expanded, + ) + ctx.grouped_in = grouped_in + ctx.grouped_out = grouped_out + ctx.k = k + return output + + @staticmethod + def backward(ctx, grad_out: torch.Tensor): + with torch.device(grad_out.device): + ( + x, + expert_weights, + expert_biases, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + gates, + output_expanded, + ) = ctx.saved_tensors + k = ctx.k + grouped_in = ctx.grouped_in + grouped_out = ctx.grouped_out + + if gates is not None: + # calculate gates gradient + # d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1) + d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1) + gates_flat = gates.flatten() + gate_fan = gates.size(1) + grouped_grad_out = output_expanded.flatten( + 0, 1 + ) # reuse expanded buffer later + else: + d_gates = None + gates_flat = None + gate_fan = 1 + grouped_grad_out = None + + if grouped_out: + grouped_grad_out = grad_out + else: + grouped_grad_out = kernels.ops.group( + grad_out, + sorted_scattered_idxs, + fan_out=gate_fan, + coeff=gates_flat, + out=grouped_grad_out, + ) + if grouped_in: + grouped_x = x + d_expanded_input = None + else: + grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k) + d_expanded_input = grouped_x + + d_weights, d_biases = kernels.ops.group_bwd_W( + DY=grouped_grad_out, + X=grouped_x, + expert_offsets=expert_offsets, + E=expert_weights.size(0), + has_bias=expert_biases is not None, + ) + + d_expanded_input = kernels.ops.scatter2scatter( + X=grouped_grad_out, + x_grouped=True, + W=expert_weights.permute(0, 2, 1), + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + k=1, + y_grouped=grouped_in, + out=d_expanded_input, # Reuse 
grouped_x buffer + ) + + if k == 1: + d_input = d_expanded_input + else: + d_input = d_expanded_input.view( + x.size(0), k, d_expanded_input.size(-1) + ).sum(-2) + return ( + # x, expert_weights, + d_input, + d_weights, + # k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets, + None, + None, + None, + None, + # bias, gates + d_biases, + d_gates, + # grouped_in, grouped_out, + None, + None, + ) + + +def parallel_linear( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + expert_biases=None, + gates=None, + grouped_in=False, + grouped_out=False, +): + results = ParallelLinear.apply( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + expert_biases, + gates, + grouped_in, + grouped_out, + ) + return results + + +class ParallelExperts(nn.Module): + def __init__(self, num_experts, input_size, output_size, bias=False) -> None: + super().__init__() + self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size)) + + if bias: + self.bias = nn.Parameter(torch.empty(num_experts, output_size)) + else: + self.bias = None + + self.num_experts = num_experts + self.input_size = input_size + self.output_size = output_size + self.reset_parameters() + + def extra_repr(self): + return "num_experts={}, input_size={}, output_size={}".format( + self.num_experts, self.input_size, self.output_size + ) + + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, std=0.02) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward( + self, + inputs, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + gates=None, + grouped_in=False, + grouped_out=False, + ): + results = parallel_linear( + inputs, + self.weight.permute(0, 2, 1), + k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + expert_biases=self.bias, + gates=gates, + grouped_in=grouped_in, + grouped_out=grouped_out, + ) + return results diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py new file mode 100644 index 000000000..5d00e1230 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py @@ -0,0 +1,480 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +""" +ScatterMoE + LoRA Autograd Function +==================================== + +Provides the autograd function and Python interface for fused ScatterMoE + LoRA. + +Key design for LoRA training: + - Expert weights W are FROZEN (no gradient computed for W). + - Only LoRA adapter weights (A, B) receive gradients. + - The input gradient dX is still computed (needed for upstream layers). + - This avoids the expensive group_bwd_W computation entirely. + +Forward: + Y = X @ W + scaling * (X @ A^T) @ B^T + +Backward (W frozen): + dX = dY @ W^T + scaling * (dY @ B) @ A (via scatter2scatter for base, separate for LoRA) + dA = scaling * (dY @ B)^T @ X (per-expert, on grouped data) + dB = scaling * dY^T @ (X @ A^T) (per-expert, on grouped data) +""" + +from typing import Optional + +import torch + +from .kernels import ops as base_ops +from .kernels.lora_ops import ( + group_bwd_lora, + group_bwd_lora_fused, + scatter2scatter_lora, + scatter2scatter_lora_dX, +) + + +class ScatterMoELoRA(torch.autograd.Function): + """ + Autograd function for fused ScatterMoE + LoRA with frozen expert weights. 
+ + This function is optimized for the LoRA fine-tuning scenario where: + - Expert weights W are frozen (requires_grad=False) + - Only LoRA A and B matrices receive gradients + - Input gradients are computed for upstream layer backprop + """ + + @staticmethod + def forward( + ctx, + x: torch.Tensor, + expert_weights: torch.Tensor, + k: int, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + expert_offsets: torch.Tensor, + lora_A: torch.Tensor, + lora_B: torch.Tensor, + scaling: float, + expert_biases: Optional[torch.Tensor] = None, + gates: Optional[torch.Tensor] = None, + grouped_in: bool = False, + grouped_out: bool = False, + use_fused_dX: bool = False, + use_fused_gather: bool = False, + ): + with torch.device(x.device): + # Fused forward: Y = X @ W + scaling * (X @ A^T) @ B^T + output = scatter2scatter_lora( + X=x, + W=expert_weights, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + k=k, + lora_A=lora_A, + lora_B=lora_B, + scaling=scaling, + b=expert_biases, + x_grouped=grouped_in, + y_grouped=grouped_out, + ) + + # Handle gating (weighted combination of top-k expert outputs) + if gates is not None: + output_expanded = output.view( + gates.size(0), gates.size(1), output.size(-1) + ) + output = (gates.unsqueeze(1) @ output_expanded).squeeze(1) + else: + output_expanded = None + + ctx.save_for_backward( + x, + lora_A, + lora_B, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + gates, + output_expanded, + ) + # Store frozen weights as plain Python attributes instead of + # save_for_backward. This avoids: + # 1. Version-check conflicts with FSDP unshard/reshard + # 2. Pinning all-gathered parameters via saved_tensors hooks + # 3. Interfering with activation offloading pack/unpack hooks + # Safe because expert_weights are frozen (requires_grad=False). + ctx.expert_weights = expert_weights + ctx.expert_biases = expert_biases + ctx.grouped_in = grouped_in + ctx.grouped_out = grouped_out + ctx.k = k + ctx.scaling = scaling + ctx.use_fused_dX = use_fused_dX + ctx.use_fused_gather = use_fused_gather + + return output + + @staticmethod + def backward(ctx, grad_out: torch.Tensor): + with torch.device(grad_out.device): + ( + x, + lora_A, + lora_B, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + gates, + output_expanded, + ) = ctx.saved_tensors + expert_weights = ctx.expert_weights + + k = ctx.k + scaling = ctx.scaling + grouped_in = ctx.grouped_in + grouped_out = ctx.grouped_out + E = expert_weights.size(0) + + # ------------------------------------------------------------------ + # Gate gradients (if using top-k gating with routing weights) + # ------------------------------------------------------------------ + if gates is not None: + # d_gates[t, j] = output_expanded[t, j, :] . grad_out[t, :] + d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1) + gates_flat = gates.flatten() + gate_fan = gates.size(1) + # Reuse output_expanded buffer for grouped_grad_out + grouped_grad_out = output_expanded.flatten(0, 1) + else: + d_gates = None + gates_flat = None + gate_fan = 1 + grouped_grad_out = None + + # ------------------------------------------------------------------ + # LoRA gradients (dA, dB) and setup for dX + # ------------------------------------------------------------------ + # Fused gather uses sorted_scattered_idxs for indirect X access + # in the Triton kernel, avoiding the group(x) allocation. 
+ # + # can_fuse_gather: X is ungrouped and not too large for scatter loads + # - When gates is None and grouped_out=False: both DY and X ungrouped + # - When grouped_out=True (gate_up_proj): DY already grouped, X ungrouped + # -> use dy_grouped=True in the fused kernel + M_total = sorted_scattered_idxs.size(0) + K_dim = x.size(-1) + N_dim = expert_weights.size(-1) + fuse_gather_workload = M_total * max(K_dim, N_dim) + _FUSE_GATHER_THRESHOLD = 2**24 # ~16M elements + + can_fuse_gather = ( + ctx.use_fused_gather + and not grouped_in # X must be ungrouped for scatter access + and gates is None # gate coeff requires multiplicative gather + and fuse_gather_workload < _FUSE_GATHER_THRESHOLD + ) + + if can_fuse_gather: + # ------------------------------------------------------------------ + # Fused path: skip group(x) entirely + # ------------------------------------------------------------------ + d_expanded_input = None + + d_lora_A, d_lora_B = group_bwd_lora_fused( + DY=grad_out, + X=x, + lora_A=lora_A, + lora_B=lora_B, + expert_offsets=expert_offsets, + sorted_scattered_idxs=sorted_scattered_idxs, + E=E, + k=k, + scaling=scaling, + dy_grouped=grouped_out, + ) + + # Prepare grouped_grad_out for the dX path (needed by both + # the fused dX kernel when grouped_out=True, and the non-fused path) + if grouped_out: + grouped_grad_out = grad_out + elif not ctx.use_fused_dX: + grouped_grad_out = base_ops.group( + grad_out, + sorted_scattered_idxs, + fan_out=gate_fan, + coeff=gates_flat, + out=grouped_grad_out, + ) + else: + # ------------------------------------------------------------------ + # Original path: explicit group() calls + # ------------------------------------------------------------------ + if grouped_out: + grouped_grad_out = grad_out + else: + grouped_grad_out = base_ops.group( + grad_out, + sorted_scattered_idxs, + fan_out=gate_fan, + coeff=gates_flat, + out=grouped_grad_out, + ) + + if grouped_in: + grouped_x = x + d_expanded_input = None + else: + grouped_x = base_ops.group(x, sorted_scattered_idxs, fan_out=k) + d_expanded_input = grouped_x # Will be overwritten; reuse buffer + + d_lora_A, d_lora_B = group_bwd_lora( + DY=grouped_grad_out, + X=grouped_x, + lora_A=lora_A, + lora_B=lora_B, + expert_offsets=expert_offsets, + E=E, + scaling=scaling, + ) + + # ------------------------------------------------------------------ + # Input gradient: dX = dY @ W^T + scaling * (dY @ B) @ A + # ------------------------------------------------------------------ + if ctx.use_fused_dX: + if can_fuse_gather and not grouped_out: + # Fully fused: read ungrouped DY via scatter pattern + d_expanded_input = scatter2scatter_lora_dX( + DY=grad_out, + W=expert_weights, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + k=1, + lora_A=lora_A, + lora_B=lora_B, + scaling=scaling, + dy_grouped=False, + dx_grouped=grouped_in, + out=d_expanded_input, + ) + else: + # Fused dX only: read from pre-grouped DY + d_expanded_input = scatter2scatter_lora_dX( + DY=grouped_grad_out, + W=expert_weights, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + k=1, + lora_A=lora_A, + lora_B=lora_B, + scaling=scaling, + dy_grouped=True, + dx_grouped=grouped_in, + out=d_expanded_input, + ) + else: + # Original path: separate base scatter2scatter + LoRA Python loop + d_expanded_input = base_ops.scatter2scatter( + X=grouped_grad_out, + x_grouped=True, + W=expert_weights.permute(0, 2, 1), # [E, N, K] + sorted_expert_idxs=sorted_expert_idxs, + 
sorted_scattered_idxs=sorted_scattered_idxs, + k=1, + y_grouped=grouped_in, + out=d_expanded_input, + ) + + # LoRA part: dX_lora = scaling * (dY @ B) @ A + if scaling != 0.0: + d_input_lora_grouped = _compute_lora_input_grad( + grouped_grad_out, + lora_A, + lora_B, + expert_offsets, + E, + scaling, + ) + if grouped_in: + d_expanded_input.add_(d_input_lora_grouped) + else: + # Scatter-add LoRA gradient directly into d_expanded_input. + # Avoids allocating a zeros_like + add result + d_expanded_input[sorted_scattered_idxs] += d_input_lora_grouped + + # Reduce over top-k if k > 1 + if k == 1: + d_input = d_expanded_input + else: + d_input = d_expanded_input.view( + x.size(0), k, d_expanded_input.size(-1) + ).sum(-2) + + # W is frozen during LoRA training -- skip weight gradient + d_weights = ( + torch.zeros_like(expert_weights) + if expert_weights.requires_grad + else None + ) + d_biases = None + + return ( + d_input, + d_weights, + None, + None, + None, + None, # k, sorted indices, offsets + d_lora_A, + d_lora_B, + None, # lora_A, lora_B, scaling + d_biases, + d_gates, + None, + None, # grouped_in, grouped_out + None, # use_fused_dX + None, # use_fused_gather + ) + + +def _compute_lora_input_grad( + grouped_grad_out: torch.Tensor, + lora_A: torch.Tensor, + lora_B: torch.Tensor, + expert_offsets: torch.Tensor, + E: int, + scaling: float, +) -> torch.Tensor: + """ + Compute the LoRA contribution to the input gradient: + dX_lora = scaling * (dY @ B) @ A + + Uses PyTorch ops on expert-grouped data. + Each expert e: dX_e = scaling * (dY_e @ B_e) @ A_e + """ + R = lora_A.size(0) // E + K = lora_A.size(1) + M_total = grouped_grad_out.size(0) + + d_input_lora = torch.zeros( + (M_total, K), device=grouped_grad_out.device, dtype=grouped_grad_out.dtype + ) + + compute_dtype = grouped_grad_out.dtype + + prev_offset = 0 + for e in range(E): + curr_offset = expert_offsets[e].item() + if curr_offset > prev_offset: + dy_e = grouped_grad_out[prev_offset:curr_offset] # [M_e, N] + a_e = lora_A[e * R : (e + 1) * R, :].to(compute_dtype) # [r, K] + b_e = lora_B[:, e * R : (e + 1) * R].to(compute_dtype) # [N, r] + + # dX_e = scaling * (dY_e @ B_e) @ A_e + dy_b = dy_e @ b_e # [M_e, r] + dx_e = scaling * (dy_b @ a_e) # [M_e, K] + d_input_lora[prev_offset:curr_offset] = dx_e + + prev_offset = curr_offset + + return d_input_lora + + +# ============================================================================= +# Helper: Extract LoRA params from PEFT ParamWrapper +# ============================================================================= + + +def get_lora_params_from_wrapper(module) -> tuple: + """ + Extract LoRA parameters from a PEFT ParamWrapper. 
+ + Returns: + (lora_A, lora_B, scaling) if LoRA is active, else (None, None, None) + """ + if not hasattr(module, "lora_A") or not hasattr(module, "lora_B"): + return None, None, None + + active_adapters = getattr(module, "active_adapters", ["default"]) + if not active_adapters: + return None, None, None + + adapter_name = active_adapters[0] + + lora_A_dict = getattr(module, "lora_A", {}) + lora_B_dict = getattr(module, "lora_B", {}) + scaling_dict = getattr(module, "scaling", {}) + + if adapter_name not in lora_A_dict: + return None, None, None + + lora_A = lora_A_dict[adapter_name].weight + lora_B = lora_B_dict[adapter_name].weight + scaling = scaling_dict[adapter_name] + + return lora_A, lora_B, scaling + + +# ============================================================================= +# Drop-in replacement for parallel_linear +# ============================================================================= + + +def parallel_linear_lora( + inputs: torch.Tensor, + expert_weights: torch.Tensor, + k: int, + sorted_expert_idxs: torch.Tensor, + sorted_scattered_idxs: torch.Tensor, + expert_offsets: torch.Tensor, + lora_A: Optional[torch.Tensor] = None, + lora_B: Optional[torch.Tensor] = None, + scaling: float = 1.0, + expert_biases: Optional[torch.Tensor] = None, + gates: Optional[torch.Tensor] = None, + grouped_in: bool = False, + grouped_out: bool = False, + use_fused_dX: bool = False, + use_fused_gather: bool = False, +): + """ + Drop-in replacement for parallel_linear that supports LoRA. + + If lora_A and lora_B are provided, uses fused LoRA kernel. + Otherwise falls back to standard scatter2scatter. + """ + if lora_A is not None and lora_B is not None: + return ScatterMoELoRA.apply( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + lora_A, + lora_B, + scaling, + expert_biases, + gates, + grouped_in, + grouped_out, + use_fused_dX, + use_fused_gather, + ) + else: + from .parallel_experts import ParallelLinear + + return ParallelLinear.apply( + inputs, + expert_weights, + k, + sorted_expert_idxs, + sorted_scattered_idxs, + expert_offsets, + expert_biases, + gates, + grouped_in, + grouped_out, + ) diff --git a/src/axolotl/integrations/kernels/plugin.py b/src/axolotl/integrations/kernels/plugin.py index c7fb79ff6..56d0448d5 100644 --- a/src/axolotl/integrations/kernels/plugin.py +++ b/src/axolotl/integrations/kernels/plugin.py @@ -1,5 +1,7 @@ +from pathlib import Path + from kernels import ( - LayerRepository, + LocalLayerRepository, Mode, register_kernel_mapping, replace_kernel_forward_from_hub, @@ -19,16 +21,19 @@ class KernelsPlugin(BasePlugin): self._kernelize_model(cfg.model_config_type) def _register_kernels(self): + plugin_root = Path(__file__).parent register_kernel_mapping( { "HFScatterMoEParallelExperts": { "cuda": { - Mode.TRAINING: LayerRepository( - repo_id="axolotl-ai-co/scattermoe", + Mode.TRAINING: LocalLayerRepository( + repo_path=plugin_root / "libs" / "scattermoe_lora", + package_name="scattermoe_lora", layer_name="HFScatterMoEGatedMLP", ), - Mode.INFERENCE: LayerRepository( - repo_id="axolotl-ai-co/scattermoe", + Mode.INFERENCE: LocalLayerRepository( + repo_path=plugin_root / "libs" / "scattermoe_lora", + package_name="scattermoe_lora", layer_name="HFScatterMoEGatedMLP", ), }, diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 222260020..62dcbde7a 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -329,7 +329,7 @@ class PatchManager: 
else: has_remote_code = False - if has_remote_code and self.cfg.trust_remote_code is False: + if has_remote_code and self.cfg.trust_remote_code is not None: # If explicitly set in YAML, prefer that has_remote_code = self.cfg.trust_remote_code diff --git a/src/axolotl/utils/data/lock.py b/src/axolotl/utils/data/lock.py index afd1547af..9699f60e8 100644 --- a/src/axolotl/utils/data/lock.py +++ b/src/axolotl/utils/data/lock.py @@ -54,15 +54,19 @@ class FileLockLoader: def cleanup(self): """Clean up ready flag when last process is done.""" - with FileLock(str(self.lock_file_path)): - counter_content = self.counter_path.read_text().strip() - count = int(counter_content) if counter_content else 0 - count -= 1 + try: + with FileLock(str(self.lock_file_path)): + counter_content = self.counter_path.read_text().strip() + count = int(counter_content) if counter_content else 0 + count -= 1 - if count <= 0: - # Last process cleans everything up - self.ready_flag_path.unlink(missing_ok=True) - self.counter_path.unlink(missing_ok=True) - else: - # Still have active processes - self.counter_path.write_text(str(count)) + if count <= 0: + # Last process cleans everything up + self.ready_flag_path.unlink(missing_ok=True) + self.counter_path.unlink(missing_ok=True) + else: + # Still have active processes + self.counter_path.write_text(str(count)) + except FileNotFoundError: + # Lock file might have already been deleted by another process + pass diff --git a/tests/integrations/test_scattermoe_lora.py b/tests/integrations/test_scattermoe_lora.py new file mode 100644 index 000000000..859119c81 --- /dev/null +++ b/tests/integrations/test_scattermoe_lora.py @@ -0,0 +1,323 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Axolotl AI +# Licensed under the Apache License, Version 2.0 + +""" +Unit tests for scattermoe-lora code-review fixes. + +Tests cover: +- KernelsArgs validator: disable_mlp_kernel_scattermoe +- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward +- ParallelExperts: scaling=0.0 not treated as falsy +- single2scatter: non-aligned K/N dimensions +- group_compileable: coeff=None accepted +- HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract +""" + +from unittest.mock import patch + +import pytest +import torch + +# ============================================================================ +# 1. KernelsArgs: disable_mlp_kernel_scattermoe validator +# ============================================================================ + + +class TestKernelsArgsValidator: + """Test that disable_mlp_kernel_scattermoe sets both flags correctly. + + These tests call the validator classmethod directly on raw dicts, + since lora_mlp_kernel / mlp_kernel are not declared model fields. 
+ """ + + def test_disables_lora_mlp_kernel_when_scattermoe(self): + """lora_mlp_kernel=True gets set to False when use_scattermoe=True.""" + from axolotl.integrations.kernels.args import KernelsArgs + + data = { + "use_kernels": True, + "use_scattermoe": True, + "lora_mlp_kernel": True, + } + result = KernelsArgs.disable_mlp_kernel_scattermoe(data) + assert result["lora_mlp_kernel"] is False + assert result["mlp_kernel"] is False + + def test_mlp_kernel_disabled_without_lora(self): + """Even without lora_mlp_kernel, mlp_kernel should be disabled.""" + from axolotl.integrations.kernels.args import KernelsArgs + + data = { + "use_kernels": True, + "use_scattermoe": True, + } + result = KernelsArgs.disable_mlp_kernel_scattermoe(data) + assert result["mlp_kernel"] is False + # lora_mlp_kernel was not in data, should not be added + assert "lora_mlp_kernel" not in result + + def test_lora_mlp_kernel_false_unchanged(self): + """lora_mlp_kernel=False should stay False (no warning, no change).""" + from axolotl.integrations.kernels.args import KernelsArgs + + data = { + "use_kernels": True, + "use_scattermoe": True, + "lora_mlp_kernel": False, + } + result = KernelsArgs.disable_mlp_kernel_scattermoe(data) + assert result["lora_mlp_kernel"] is False + + def test_no_change_when_scattermoe_disabled(self): + """When use_scattermoe is not True, nothing should be changed.""" + from axolotl.integrations.kernels.args import KernelsArgs + + data = { + "use_kernels": True, + "use_scattermoe": False, + "lora_mlp_kernel": True, + } + result = KernelsArgs.disable_mlp_kernel_scattermoe(data) + assert result["lora_mlp_kernel"] is True + + +class TestParallelExpertsScaling: + """Test that scaling=0.0 is preserved and not overridden to 1.0.""" + + def test_scaling_zero_preserved(self): + """scaling=0.0 should be passed as 0.0, not replaced with 1.0.""" + pytest.importorskip("triton") + from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import ( + ParallelExperts, + ) + + pe = ParallelExperts(num_experts=2, input_size=4, output_size=4) + pe.set_lora( + lora_A=torch.randn(4, 4), + lora_B=torch.randn(4, 4), + scaling=0.0, + ) + assert pe._lora_scaling == 0.0 + + # Patch parallel_linear_lora to capture the scaling arg + with patch( + "axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora" + ) as mock_pll: + mock_pll.return_value = torch.randn(4, 4) + # Create dummy routing tensors + pe.forward( + inputs=torch.randn(2, 4), + k=1, + sorted_expert_idxs=torch.tensor([0, 0, 1, 1]), + sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]), + expert_offsets=torch.tensor([2, 4]), + ) + # Check that scaling=0.0 was passed, not 1.0 + call_kwargs = mock_pll.call_args + assert ( + call_kwargs.kwargs.get("scaling") == 0.0 + or call_kwargs[1].get("scaling") == 0.0 + ), f"Expected scaling=0.0 but got {call_kwargs}" + + def test_scaling_none_defaults_to_one(self): + """scaling=None (no LoRA attached) should default to 1.0.""" + pytest.importorskip("triton") + from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import ( + ParallelExperts, + ) + + pe = ParallelExperts(num_experts=2, input_size=4, output_size=4) + # No set_lora called, so _lora_scaling is None + + with patch( + "axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora" + ) as mock_pll: + mock_pll.return_value = torch.randn(4, 4) + pe.forward( + inputs=torch.randn(2, 4), + k=1, + sorted_expert_idxs=torch.tensor([0, 0, 1, 1]), + sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]), + 
expert_offsets=torch.tensor([2, 4]), + ) + call_kwargs = mock_pll.call_args + scaling_val = call_kwargs.kwargs.get("scaling") or call_kwargs[1].get( + "scaling" + ) + assert scaling_val == 1.0, ( + f"Expected scaling=1.0 for None but got {scaling_val}" + ) + + def test_scaling_positive_preserved(self): + """Normal positive scaling should be preserved.""" + pytest.importorskip("triton") + from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import ( + ParallelExperts, + ) + + pe = ParallelExperts(num_experts=2, input_size=4, output_size=4) + pe.set_lora( + lora_A=torch.randn(4, 4), + lora_B=torch.randn(4, 4), + scaling=0.5, + ) + + with patch( + "axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora" + ) as mock_pll: + mock_pll.return_value = torch.randn(4, 4) + pe.forward( + inputs=torch.randn(2, 4), + k=1, + sorted_expert_idxs=torch.tensor([0, 0, 1, 1]), + sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]), + expert_offsets=torch.tensor([2, 4]), + ) + call_kwargs = mock_pll.call_args + scaling_val = call_kwargs.kwargs.get("scaling") or call_kwargs[1].get( + "scaling" + ) + assert scaling_val == 0.5 + + +# ============================================================================ +# 4. single2scatter: non-aligned K/N dimensions (GPU only) +# ============================================================================ + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestSingle2ScatterBounds: + """Test single2scatter with non-aligned dimensions.""" + + def test_non_aligned_k(self): + """K not a multiple of BLOCK_K should produce correct results.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import ( + single2scatter, + ) + + E, K, N = 2, 100, 128 # K=100 not a multiple of 128 + W = torch.randn(E, K, N, device="cuda", dtype=torch.float32) + X = torch.randn(1, K, device="cuda", dtype=torch.float32) + expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long) + + Y = single2scatter(X, W, expert_idxs) + assert Y.shape == (2, N) + + # Verify against manual computation + Y_ref_0 = X[0] @ W[0] + Y_ref_1 = X[0] @ W[1] + torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2) + + def test_non_aligned_n(self): + """N not a multiple of BLOCK_N should produce correct results.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import ( + single2scatter, + ) + + E, K, N = 2, 128, 100 # N=100 not a multiple of 128 + W = torch.randn(E, K, N, device="cuda", dtype=torch.float32) + X = torch.randn(1, K, device="cuda", dtype=torch.float32) + expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long) + + Y = single2scatter(X, W, expert_idxs) + assert Y.shape == (2, N) + + Y_ref_0 = X[0] @ W[0] + Y_ref_1 = X[0] @ W[1] + torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2) + + def test_non_aligned_both(self): + """Both K and N not aligned should produce correct results.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import ( + single2scatter, + ) + + E, K, N = 2, 100, 100 # Neither aligned to 128 + W = torch.randn(E, K, N, device="cuda", dtype=torch.float32) + X = torch.randn(1, K, device="cuda", dtype=torch.float32) + expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long) + + Y = single2scatter(X, W, expert_idxs) + assert Y.shape == (2, N) + + Y_ref_0 = X[0] @ W[0] + 
Y_ref_1 = X[0] @ W[1] + torch.testing.assert_close(Y[0], Y_ref_0, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(Y[1], Y_ref_1, atol=1e-2, rtol=1e-2) + + +# ============================================================================ +# 5. group_compileable: coeff=None accepted +# ============================================================================ + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestGroupCoeffNone: + """Test that group() works with coeff=None.""" + + def test_group_with_none_coeff(self): + """group() should accept coeff=None without errors.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group + + M, K = 4, 32 + A = torch.randn(M, K, device="cuda", dtype=torch.float32) + sorted_expert_idxs = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long) + + # This should not raise a TypeError + Y = group(A, sorted_expert_idxs, coeff=None, fan_out=1) + assert Y.shape == (M, K) + + def test_group_with_coeff(self): + """group() should also work with actual coeff values.""" + from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group + + M, K = 4, 32 + A = torch.randn(M, K, device="cuda", dtype=torch.float32) + sorted_expert_idxs = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long) + coeff = torch.ones(M, device="cuda", dtype=torch.float32) * 0.5 + + Y = group(A, sorted_expert_idxs, coeff=coeff, fan_out=1) + assert Y.shape == (M, K) + + +# ============================================================================ +# 6. Layer return value contracts +# ============================================================================ + + +class TestLayerReturnValues: + """Test that layer forward methods return the correct types.""" + + def test_hf_scatter_moe_returns_single_tensor(self): + """HFScatterMoEGatedMLP.forward should return a single tensor, not a tuple.""" + pytest.importorskip("triton") + # Verify the forward method signature and return annotation + import inspect + + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + HFScatterMoEGatedMLP, + ) + + sig = inspect.signature(HFScatterMoEGatedMLP.forward) + # It's a staticmethod taking (self, layer_input) + params = list(sig.parameters.keys()) + assert "self" in params + assert "layer_input" in params + + def test_scatter_moe_gated_mlp_docstring_no_router_logits(self): + """ScatterMoEGatedMLP.forward docstring should not mention router logits as return.""" + pytest.importorskip("triton") + from axolotl.integrations.kernels.libs.scattermoe_lora.layers import ( + ScatterMoEGatedMLP, + ) + + docstring = ScatterMoEGatedMLP.forward.__doc__ + assert docstring is not None + # The docstring should mention output tensor but NOT router logits + assert "Output tensor" in docstring or "output tensor" in docstring.lower() + assert "Router logits" not in docstring, ( + "Docstring should not mention 'Router logits' in Returns section" + )
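
For reference, the forward contract these kernels and tests exercise — Y = X @ W + scaling * (X @ A^T) @ B^T, applied per expert with the expert-major LoRA layout used above (A rows e*r:(e+1)*r and B columns e*r:(e+1)*r belong to expert e) — can be reproduced in a few lines of plain PyTorch. The sketch below is illustrative only: toy shapes, CPU tensors, a Python loop over experts instead of the Triton kernels, and expert-grouped tokens assumed.

    # Minimal CPU reference for the fused ScatterMoE + LoRA forward:
    #   Y = X @ W + scaling * (X @ A^T) @ B^T   (per expert)
    # Layouts match the kernels above: W is [E, K, N], lora_A is [r*E, K],
    # lora_B is [N, r*E]. Shapes and names here are toy values for illustration.
    import torch

    E, K, N, r, tokens_per_expert = 4, 32, 64, 8, 5
    scaling = 0.5

    W = torch.randn(E, K, N)
    A = torch.randn(r * E, K)
    B = torch.randn(N, r * E)

    # Expert-grouped input: tokens for expert e occupy a contiguous row slice.
    X = torch.randn(E * tokens_per_expert, K)

    Y = torch.empty(E * tokens_per_expert, N)
    for e in range(E):
        rows = slice(e * tokens_per_expert, (e + 1) * tokens_per_expert)
        A_e = A[e * r:(e + 1) * r, :]        # [r, K] slice for expert e
        B_e = B[:, e * r:(e + 1) * r]        # [N, r] slice for expert e
        base = X[rows] @ W[e]                # frozen expert weight path
        lora = (X[rows] @ A_e.T) @ B_e.T     # low-rank adapter path
        Y[rows] = base + scaling * lora

Comparing a loop like this against parallel_linear_lora on expert-grouped inputs (grouped_in=True, grouped_out=True, no gates) should give matching results up to TF32/accumulation tolerance, mirroring how the single2scatter tests above check the base kernel against a manual matmul.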