From c5db90aa3f8e2daa9ced22194524224eddc05a59 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 18 Mar 2026 21:42:13 +0000 Subject: [PATCH] optimize moe + lora --- .../libs/scattermoe_lora/kernels/lora_ops.py | 288 +++++++++++++++--- src/axolotl/integrations/kernels/plugin.py | 6 + src/axolotl/loaders/model.py | 14 + src/axolotl/utils/callbacks/profiler.py | 37 ++- 4 files changed, 309 insertions(+), 36 deletions(-) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py index 5d47c2040..16af5fcdc 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py @@ -1330,6 +1330,224 @@ def _group_bwd_lora( ) +def _group_bwd_split_configs(): + """Autotune configs for split dA/dB kernels.""" + configs = [] + for block_m, block_dim, warps, stages in product( + [32, 64, 128], # BLOCK_M (token tile) + [32, 64, 128, 256], # BLOCK_DIM (K for dA, N for dB — output tile) + [4, 8], # num_warps + [3, 4, 5], # num_stages + ): + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_DIM": block_dim}, + num_stages=stages, + num_warps=warps, + ) + ) + return configs + + +def _prune_split_configs(configs, named_args, **kwargs): + """Prune split kernel configs based on SMEM capacity.""" + smem_cap = _get_smem_capacity() + block_r = named_args.get("BLOCK_R", 64) + inner_dim = named_args.get("INNER_DIM", 2048) + + # Fixed inner tile for reduction dimension + BLOCK_INNER = 64 + + pruned = [] + for config in configs: + block_m = config.kwargs["BLOCK_M"] + block_dim = config.kwargs["BLOCK_DIM"] + # Inner loop loads: input[M, INNER] and other[M, INNER_or_DIM] + smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2 + # LoRA weights held in registers: [INNER, R] or [R, DIM] + smem += (block_r * max(block_dim, BLOCK_INNER)) * 2 + if smem <= smem_cap - _SMEM_SLACK: + pruned.append(config) + + if pruned: + return pruned + configs.sort(key=lambda c: c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_DIM"]) + return [configs[0]] + + +@triton.autotune( + configs=_group_bwd_split_configs(), + key=["M", "K", "N"], + prune_configs_by={"early_config_prune": _prune_split_configs}, +) +@triton.heuristics({ + "NO_DIM_MASK": lambda args: ( + (args["K"] % args["BLOCK_DIM"]) == 0 + if args["COMPUTE_DA"] + else (args["N"] % args["BLOCK_DIM"]) == 0 + ), +}) +@triton.jit +def _group_bwd_lora_split( + # Data tensors (DY and X are always present) + DY_ptr, stride_dym, stride_dyn, + X_ptr, stride_xm, stride_xk, + # LoRA weight for the inner reduction (B for dA, A for dB) + LW_ptr, stride_lw0, stride_lw1, + # Output gradient tensor (dA or dB) + OUT_ptr, stride_out0, stride_out1, + # Expert offsets + expert_offsets_ptr, + # Dimensions + M, K: tl.constexpr, N: tl.constexpr, + ACTUAL_R: tl.constexpr, BLOCK_R: tl.constexpr, + INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB) + scaling, + # Mode flag + COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB + # Tile sizes + BLOCK_M: tl.constexpr, BLOCK_DIM: tl.constexpr, + ACC_TYPE: tl.constexpr, + allow_tf32: tl.constexpr, + NO_DIM_MASK: tl.constexpr, +): + """ + Unified split kernel for LoRA gradient computation. 
+ + When COMPUTE_DA=True: + dA[e] = scaling * (dY @ B[e])^T @ X → [R, K] + Grid: (E, cdiv(K, BLOCK_DIM)) + - outer_ptr/stride = X (read [M, K_block]) + - inner reduction over N using DY and B + - output shape [BLOCK_R, BLOCK_DIM] + + When COMPUTE_DA=False: + dB[e] = scaling * dY^T @ (X @ A[e]^T) → [N, R] + Grid: (E, cdiv(N, BLOCK_DIM)) + - outer_ptr/stride = DY (read [M, N_block]) + - inner reduction over K using X and A + - output shape [BLOCK_DIM, BLOCK_R] + + No atomic adds — each (E, dim_block) pair is written by exactly one block. + """ + E_idx = tl.program_id(0) + dim_block_id = tl.program_id(1) + + if E_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32) + end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32) + num_tokens = end_idx - start_idx + + # Output dimension tile (K for dA, N for dB) + if COMPUTE_DA: + OUT_DIM: tl.constexpr = K + else: + OUT_DIM: tl.constexpr = N + dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM) + dim_mask = dim_block < OUT_DIM + R_block = tl.arange(0, BLOCK_R) + R_mask = R_block < ACTUAL_R + lora_offset = E_idx * ACTUAL_R + + # Output pointers — layout differs: dA is [R, K], dB is [N, R] + if COMPUTE_DA: + out_blk_ptrs = ( + OUT_ptr + + (lora_offset + R_block)[:, None] * stride_out0 + + dim_block[None, :] * stride_out1 + ) + out_mask = R_mask[:, None] & dim_mask[None, :] + else: + out_blk_ptrs = ( + OUT_ptr + + dim_block[:, None] * stride_out0 + + (lora_offset + R_block)[None, :] * stride_out1 + ) + out_mask = dim_mask[:, None] & R_mask[None, :] + + if num_tokens > 0: + M_block = tl.arange(0, BLOCK_M) + INPUT_DTYPE = X_ptr.dtype.element_ty + BLOCK_INNER: tl.constexpr = 64 + inner_iters = tl.cdiv(INNER_DIM, BLOCK_INNER) + + if COMPUTE_DA: + acc = tl.zeros((BLOCK_R, BLOCK_DIM), dtype=ACC_TYPE) + else: + acc = tl.zeros((BLOCK_DIM, BLOCK_R), dtype=ACC_TYPE) + + M_iters = tl.cdiv(num_tokens, BLOCK_M) + for i in range(M_iters): + M_idx = start_idx + i * BLOCK_M + M_block + M_mask = M_idx < end_idx + + if COMPUTE_DA: + # Load X[M, K_block] (the "outer" tensor for dA) + outer = tl.load( + X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk, + mask=M_mask[:, None] & dim_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + + # Reduce DY[M, :] @ B[e][:, R] over N → [M, R] + reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE) + inner_range = tl.arange(0, BLOCK_INNER) + for j in range(inner_iters): + inn_off = j * BLOCK_INNER + inner_range + inn_mask = inn_off < N + + dy_tile = tl.load( + DY_ptr + M_idx[:, None] * stride_dym + inn_off[None, :] * stride_dyn, + mask=M_mask[:, None] & inn_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + # B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride + lw_tile = tl.load( + LW_ptr + inn_off[:, None] * stride_lw0 + (lora_offset + R_block)[None, :] * stride_lw1, + mask=inn_mask[:, None] & R_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32) + + # dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block] + acc += tl.dot(tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32) + else: + # Load DY[M, N_block] (the "outer" tensor for dB) + outer = tl.load( + DY_ptr + M_idx[:, None] * stride_dym + dim_block[None, :] * stride_dyn, + mask=M_mask[:, None] & dim_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + + # Reduce X[M, :] @ A[e][:, :].T over K → [M, R] + reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE) + inner_range = tl.arange(0, BLOCK_INNER) + for j in range(inner_iters): + inn_off 
= j * BLOCK_INNER + inner_range + inn_mask = inn_off < K + + x_tile = tl.load( + X_ptr + M_idx[:, None] * stride_xm + inn_off[None, :] * stride_xk, + mask=M_mask[:, None] & inn_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + # A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride + # We want A[e]^T: [K, R], so load as [K_inner, R] + lw_tile = tl.load( + LW_ptr + (lora_offset + R_block)[None, :] * stride_lw0 + inn_off[:, None] * stride_lw1, + mask=inn_mask[:, None] & R_mask[None, :], other=0.0 + ).to(INPUT_DTYPE) + reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32) + + # dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R] + acc += tl.dot(tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32) + + tl.store(out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask) + else: + # Zero out this expert's slice — needed because output uses empty_like + if COMPUTE_DA: + tl.store(out_blk_ptrs, tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty), mask=out_mask) + else: + tl.store(out_blk_ptrs, tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty), mask=out_mask) + + def group_bwd_lora( DY: torch.Tensor, X: torch.Tensor, @@ -1344,6 +1562,9 @@ def group_bwd_lora( """ Compute LoRA gradients for A and B on expert-grouped data. + Uses split dA/dB kernels that eliminate atomic adds by giving each + (expert, output_block) pair its own thread block. + Args: DY: Gradient w.r.t. output [M_total, N] (grouped by expert) X: Input [M_total, K] (grouped by expert) @@ -1361,46 +1582,45 @@ def group_bwd_lora( K = X.size(1) N = DY.size(1) - # Zero-init for atomic accumulation - dA = torch.zeros_like(lora_A) - dB = torch.zeros_like(lora_B) + # No zero-init needed: the split kernels write zeros for experts with + # zero routed tokens directly in the kernel (else branch). 
+ dA = torch.empty_like(lora_A) + dB = torch.empty_like(lora_B) BLOCK_R = _block_r_for_rank(R) - def grid(META): - return ( - E * triton.cdiv(K, META["BLOCK_K"]), - triton.cdiv(N, META["BLOCK_N"]), - ) + def grid_dA(META): + return (E, triton.cdiv(K, META["BLOCK_DIM"])) - _group_bwd_lora[grid]( - DY, - DY.stride(0), - DY.stride(1), - X, - X.stride(0), - X.stride(1), - lora_A, - lora_A.stride(0), - lora_A.stride(1), - lora_B, - lora_B.stride(0), - lora_B.stride(1), - dA, - dA.stride(0), - dA.stride(1), - dB, - dB.stride(0), - dB.stride(1), + _group_bwd_lora_split[grid_dA]( + DY, DY.stride(0), DY.stride(1), + X, X.stride(0), X.stride(1), + lora_B, lora_B.stride(0), lora_B.stride(1), + dA, dA.stride(0), dA.stride(1), expert_offsets, - M=DY.size(0), - K=K, - N=N, - ACTUAL_R=R, # True LoRA rank - BLOCK_R=BLOCK_R, # Padded tile size + M=DY.size(0), K=K, N=N, + ACTUAL_R=R, BLOCK_R=BLOCK_R, + INNER_DIM=N, scaling=scaling, - ACC_TYPE=tl.float32, - allow_tf32=ALLOW_TF32, + COMPUTE_DA=True, + ACC_TYPE=tl.float32, allow_tf32=ALLOW_TF32, + ) + + def grid_dB(META): + return (E, triton.cdiv(N, META["BLOCK_DIM"])) + + _group_bwd_lora_split[grid_dB]( + DY, DY.stride(0), DY.stride(1), + X, X.stride(0), X.stride(1), + lora_A, lora_A.stride(0), lora_A.stride(1), + dB, dB.stride(0), dB.stride(1), + expert_offsets, + M=DY.size(0), K=K, N=N, + ACTUAL_R=R, BLOCK_R=BLOCK_R, + INNER_DIM=K, + scaling=scaling, + COMPUTE_DA=False, + ACC_TYPE=tl.float32, allow_tf32=ALLOW_TF32, ) return dA, dB diff --git a/src/axolotl/integrations/kernels/plugin.py b/src/axolotl/integrations/kernels/plugin.py index 351db5ef2..c3b0360ac 100644 --- a/src/axolotl/integrations/kernels/plugin.py +++ b/src/axolotl/integrations/kernels/plugin.py @@ -61,7 +61,13 @@ class KernelsPlugin(BasePlugin): return "axolotl.integrations.kernels.KernelsArgs" def pre_model_load(self, cfg): + from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK + + # Prefer text backbone type for VLMs, but fall back to base type + # when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text) moe_model_type = cfg.model_config_type_text or cfg.model_config_type + if moe_model_type not in SPARSE_MOE_BLOCK and cfg.model_config_type in SPARSE_MOE_BLOCK: + moe_model_type = cfg.model_config_type if cfg.use_scattermoe: self._register_kernels() diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 37c112337..dd3f4ddfa 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -505,6 +505,20 @@ class ModelLoader: elif not is_ds_zero3: self.model_kwargs["device_map"] = device_map + # quantize_moe_experts quantizes expert weights on-the-fly during loading, + # so the actual VRAM usage is much less than bf16 estimates. + # When device_map is "auto", accelerate's infer_auto_device_map computes + # the device map at bf16 size (before quantization), causing it to offload + # layers to CPU, which BnB then rejects. Force single-GPU placement to + # prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single). 
+ if getattr(self.cfg, "quantize_moe_experts", False) and device_map in ( + "auto", + None, + ): + self.model_kwargs["device_map"] = { + "": int(os.environ.get("LOCAL_RANK", 0)) + } + cur_device = get_device_type() if "mps" in str(cur_device): self.model_kwargs["device_map"] = "mps:0" diff --git a/src/axolotl/utils/callbacks/profiler.py b/src/axolotl/utils/callbacks/profiler.py index 2cf5e0f4f..2598eac94 100644 --- a/src/axolotl/utils/callbacks/profiler.py +++ b/src/axolotl/utils/callbacks/profiler.py @@ -17,6 +17,8 @@ from transformers import ( class PytorchProfilerCallback(TrainerCallback): """ PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps. + + Also runs torch.profiler to produce a Chrome trace for timing analysis. """ def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0): @@ -26,9 +28,12 @@ class PytorchProfilerCallback(TrainerCallback): if profiler_steps_start == 0: # start recording memory allocations before everything is allocated, because if we start # at the beginning of step 0, we won't have any memory allocations in the traces - torch.cuda.memory._record_memory_history(enabled="all") + torch.cuda.memory._record_memory_history( + enabled="all", stacks="all" + ) profiler_steps_start = -1 self.profiler_steps_start = profiler_steps_start + self._profiler = None def on_step_begin( self, @@ -38,7 +43,22 @@ class PytorchProfilerCallback(TrainerCallback): **kwargs, ): if state.global_step == self.profiler_steps_start: - torch.cuda.memory._record_memory_history(enabled="all") + torch.cuda.memory._record_memory_history( + enabled="all", stacks="all" + ) + + # Start torch.profiler on the first profiled step + if state.global_step == max(self.profiler_steps_start, 0): + self._profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + profile_memory=True, + with_stack=True, + ) + self._profiler.__enter__() def on_step_end( self, @@ -55,6 +75,13 @@ class PytorchProfilerCallback(TrainerCallback): # tell CUDA to stop recording memory allocations now torch.cuda.memory._record_memory_history(enabled=None) + # Stop and export torch.profiler trace + if self._profiler is not None: + self._profiler.__exit__(None, None, None) + trace_path = Path(args.output_dir) / "profiler_trace.json" + self._profiler.export_chrome_trace(str(trace_path)) + self._profiler = None + def on_train_end( self, args: TrainingArguments, @@ -73,3 +100,9 @@ class PytorchProfilerCallback(TrainerCallback): # tell CUDA to stop recording memory allocations now torch.cuda.memory._record_memory_history(enabled=None) + + if self._profiler is not None: + self._profiler.__exit__(None, None, None) + trace_path = Path(args.output_dir) / "profiler_trace.json" + self._profiler.export_chrome_trace(str(trace_path)) + self._profiler = None
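
A minimal pure-PyTorch sketch of the per-expert math the new split kernels implement, useful as a correctness reference when reviewing this patch. It assumes, from the kernel strides and the docstring above, that lora_A is laid out [r*E, K], lora_B is [N, r*E], and expert_offsets holds cumulative per-expert row counts over the grouped tokens; the function name and the way R and E are recovered from the tensors are illustrative and not part of the patch.

import torch

def group_bwd_lora_reference(DY, X, lora_A, lora_B, expert_offsets, scaling):
    """Loop-over-experts reference for dA[e] = scaling * (DY_e @ B_e)^T @ X_e
    and dB[e] = scaling * DY_e^T @ (X_e @ A_e^T), mirroring the split kernels."""
    E = expert_offsets.numel()          # assumed: one cumulative end offset per expert
    R = lora_A.size(0) // E             # assumed layouts: lora_A [r*E, K], lora_B [N, r*E]
    dA = torch.zeros_like(lora_A)
    dB = torch.zeros_like(lora_B)
    start = 0
    for e in range(E):
        end = int(expert_offsets[e])
        dy, x = DY[start:end], X[start:end]       # this expert's routed rows
        A_e = lora_A[e * R:(e + 1) * R]           # [R, K]
        B_e = lora_B[:, e * R:(e + 1) * R]        # [N, R]
        dA[e * R:(e + 1) * R] = scaling * (dy @ B_e).T @ x        # [R, K]
        dB[:, e * R:(e + 1) * R] = scaling * dy.T @ (x @ A_e.T)   # [N, R]
        start = end                               # zero-token experts leave zeros, matching the kernel's else branch
    return dA, dB

Comparing this reference against group_bwd_lora with torch.testing.assert_close (within a small tolerance, since the kernels accumulate in fp32 and may use TF32) is a straightforward way to validate the no-atomics split-kernel rewrite, including experts that receive zero routed tokens.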