diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py index 2a221c13c..53aa861ea 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py @@ -195,6 +195,30 @@ def _estimate_smem_usage( _SMEM_SLACK = 10_000 +def _estimate_register_pressure( + num_warps: int, + *tile_sizes: tuple[int, int], +) -> float: + """Estimate per-thread register count from live tile sizes. + + Each tile of shape (rows, cols) requires rows*cols elements distributed + across 32 threads per warp, but each thread in the warp holds a fragment. + For Triton GEMM-style kernels, the register footprint per thread is + approximately sum(rows * cols) / 32 for each live tile, plus ~40 for + scalar overhead (loop counters, pointers, masks, etc.). + + Returns estimated registers per thread. + """ + # Each thread in a warp holds 1/32 of the tile elements + tile_regs = sum(r * c for r, c in tile_sizes) / 32 + scalar_overhead = 40 + return tile_regs + scalar_overhead + + +# Maximum registers per thread on NVIDIA GPUs +_MAX_REGS_PER_THREAD = 255 + + # ============================================================================= # Forward Kernel: scatter2scatter with fused LoRA # ============================================================================= @@ -357,7 +381,7 @@ def _scatter2scatter_lora_configs(): def _prune_fwd_configs(configs, named_args, **kwargs): - """Prune forward configs based on SMEM capacity. + """Prune forward configs based on SMEM capacity and register pressure. The forward kernel inner loop loads three tiles per pipeline stage: X[BLOCK_M, BLOCK_K], W[BLOCK_K, BLOCK_N], A[BLOCK_R, BLOCK_K]. @@ -383,14 +407,39 @@ def _prune_fwd_configs(configs, named_args, **kwargs): # B tile [BLOCK_N, BLOCK_R] loaded once in epilogue smem_lora_epilogue = block_n * block_r * 2 smem = smem_base + smem_lora_loop + smem_lora_epilogue + + # Register pressure: live tiles are acc[M,N], xa_acc[M,R], + # x[M,K], w[K,N], a[R,K], plus epilogue b[N,R] + est_regs = _estimate_register_pressure( + config.num_warps, + (block_m, block_n), # acc + (block_m, block_r), # xa_acc + (block_m, block_k), # x tile + (block_k, block_n), # w tile + (block_r, block_k), # a tile + (block_n, block_r), # b tile (epilogue) + ) + if est_regs > _MAX_REGS_PER_THREAD: + continue + scored.append((smem, config)) pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK] if pruned: return pruned - # All configs exceed SMEM — return the one with smallest estimated usage - scored.sort(key=lambda x: x[0]) - return [scored[0][1]] + if scored: + # All surviving configs exceed SMEM — return the one with smallest usage + scored.sort(key=lambda x: x[0]) + return [scored[0][1]] + # All configs pruned by register pressure — fall back to smallest tiles + return [ + min( + configs, + key=lambda c: ( + c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_N"] * c.kwargs["BLOCK_K"] + ), + ) + ] @triton.autotune( @@ -912,7 +961,7 @@ def _scatter2scatter_lora_dX_configs(): def _prune_dX_configs(configs, named_args, **kwargs): - """Prune backward dX configs based on SMEM capacity. + """Prune backward dX configs based on SMEM capacity and register pressure. The dX kernel inner loop loads three tiles per pipeline stage: DY[BLOCK_M, BLOCK_N], W^T[BLOCK_N, BLOCK_K], B[BLOCK_N, BLOCK_R]. @@ -938,14 +987,39 @@ def _prune_dX_configs(configs, named_args, **kwargs): # A tile [BLOCK_R, BLOCK_K] loaded once in epilogue smem_lora_epilogue = block_r * block_k * 2 smem = smem_base + smem_lora_loop + smem_lora_epilogue + + # Register pressure: live tiles are acc[M,K], dy_b_acc[M,R], + # dy[M,N], wt[N,K], b[N,R], plus epilogue a[R,K] + est_regs = _estimate_register_pressure( + config.num_warps, + (block_m, block_k), # acc + (block_m, block_r), # dy_b_acc + (block_m, block_n), # dy tile + (block_n, block_k), # wt tile + (block_n, block_r), # b tile + (block_r, block_k), # a tile (epilogue) + ) + if est_regs > _MAX_REGS_PER_THREAD: + continue + scored.append((smem, config)) pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK] if pruned: return pruned - # All configs exceed SMEM — return the one with smallest estimated usage - scored.sort(key=lambda x: x[0]) - return [scored[0][1]] + if scored: + # All surviving configs exceed SMEM — return the one with smallest usage + scored.sort(key=lambda x: x[0]) + return [scored[0][1]] + # All configs pruned by register pressure — fall back to smallest tiles + return [ + min( + configs, + key=lambda c: ( + c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"] + ), + ) + ] @triton.autotune( @@ -1226,7 +1300,7 @@ def _group_bwd_lora_configs(): def _prune_bwd_lora_configs(configs, named_args, **kwargs): - """Prune backward configs based on SMEM capacity. + """Prune backward configs based on SMEM capacity and register pressure. The backward kernel loads X[BLOCK_M, BLOCK_K] and DY[BLOCK_M, BLOCK_N] in the inner loop, plus holds A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] @@ -1245,14 +1319,40 @@ def _prune_bwd_lora_configs(configs, named_args, **kwargs): # A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] held for the full expert smem_lora = (block_r * block_k + block_n * block_r) * 2 smem = smem_base + smem_lora + + # Register pressure: dA_acc[R,K], dB_acc[N,R], x[M,K], dy[M,N], + # a[R,K], b[N,R], xa[M,R], dy_b[M,R] + est_regs = _estimate_register_pressure( + config.num_warps, + (block_r, block_k), # dA_acc + (block_n, block_r), # dB_acc + (block_m, block_k), # x tile + (block_m, block_n), # dy tile + (block_r, block_k), # a tile + (block_n, block_r), # b tile + (block_m, block_r), # xa intermediate + ) + if est_regs > _MAX_REGS_PER_THREAD: + continue + scored.append((smem, config)) pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK] if pruned: return pruned - # All configs exceed SMEM — return the one with smallest estimated usage - scored.sort(key=lambda x: x[0]) - return [scored[0][1]] + if scored: + # All surviving configs exceed SMEM — return the one with smallest usage + scored.sort(key=lambda x: x[0]) + return [scored[0][1]] + # All configs pruned by register pressure — fall back to smallest tiles + return [ + min( + configs, + key=lambda c: ( + c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"] + ), + ) + ] @triton.autotune( @@ -1457,7 +1557,7 @@ def _group_bwd_split_configs(): def _prune_split_configs(configs, named_args, **kwargs): - """Prune split kernel configs based on SMEM capacity.""" + """Prune split kernel configs based on SMEM capacity and register pressure.""" smem_cap = _get_smem_capacity() block_r = named_args.get("BLOCK_R", 64) @@ -1472,6 +1572,18 @@ def _prune_split_configs(configs, named_args, **kwargs): smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2 # LoRA weights held in registers: [INNER, R] or [R, DIM] smem += (block_r * max(block_dim, BLOCK_INNER)) * 2 + + # Register pressure check + est_regs = _estimate_register_pressure( + config.num_warps, + (block_r, block_dim), # acc + (block_m, BLOCK_INNER), # input tile + (block_m, block_dim), # other tile + (block_r, BLOCK_INNER), # lora weight + ) + if est_regs > _MAX_REGS_PER_THREAD: + continue + if smem <= smem_cap - _SMEM_SLACK: pruned.append(config)