From 2dcca15f6545d9ac80664ebdc51683ccfdfe00fa Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 18 Mar 2026 23:36:28 +0000 Subject: [PATCH] more scattermoe optims --- .../libs/scattermoe_lora/kernels/lora_ops.py | 34 ++++---- .../kernels/libs/scattermoe_lora/layers.py | 84 ++++++++++++++++--- 2 files changed, 92 insertions(+), 26 deletions(-) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py index 16af5fcdc..f858077c7 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py @@ -327,20 +327,21 @@ def _compute_expert_block_lora( def _scatter2scatter_lora_configs(): """Generate forward kernel autotune configs. - Search space includes smaller tile sizes and fewer pipeline stages to - support GPUs with limited shared memory (e.g. ~99KB on some GPUs). + Search space includes BLOCK_M to allow trading token-tile size for + larger BLOCK_K/BLOCK_N tiles. On GPUs with ~99KB SMEM, BLOCK_M=128 + forces BLOCK_K=32 and BLOCK_N=32; BLOCK_M=64 allows BLOCK_K=128 + (4× fewer inner-loop iterations). Search space: + BLOCK_M: {32, 64, 128} BLOCK_N: {32, 64, 128, 256} BLOCK_K: {32, 64, 128} num_warps: {4, 8} num_stages: {3, 4, 5} - - BLOCK_M is fixed at 128 (module-level constant, not autotuned in the - scatter2scatter pattern). """ configs = [] - for block_n, block_k, warps, stages in product( + for block_m, block_n, block_k, warps, stages in product( + [32, 64, 128], # BLOCK_M [32, 64, 128, 256], # BLOCK_N [32, 64, 128], # BLOCK_K [4, 8], # num_warps @@ -348,7 +349,7 @@ def _scatter2scatter_lora_configs(): ): configs.append( triton.Config( - {"BLOCK_N": block_n, "BLOCK_K": block_k}, + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k}, num_stages=stages, num_warps=warps, ) @@ -373,10 +374,11 @@ def _prune_fwd_configs(configs, named_args, **kwargs): scored = [] for config in configs: + block_m = config.kwargs["BLOCK_M"] block_n = config.kwargs["BLOCK_N"] block_k = config.kwargs["BLOCK_K"] # Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N - smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_n, block_k) + smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k) # A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop smem_lora_loop = config.num_stages * block_r * block_k * 2 # B tile [BLOCK_N, BLOCK_R] loaded once in epilogue @@ -626,7 +628,7 @@ def scatter2scatter_lora( N=N, E=E, ACTUAL_R=R, # True LoRA rank for weight indexing - BLOCK_M=BLOCK_M, + # BLOCK_M is autotuned (injected by triton.autotune from Config kwargs) BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16) ACC_TYPE=tl.float32, scaling=scaling, @@ -779,17 +781,18 @@ def _scatter2scatter_lora_dX_configs(): The inner loop is over N (not K as in forward). The output dimension is K. So BLOCK_K tiles the output and BLOCK_N tiles the reduction. - Search space includes smaller tile sizes and fewer pipeline stages to - support GPUs with limited shared memory (e.g. ~99KB on some GPUs). + BLOCK_M is now autotunable (was fixed at 128). Search space: + BLOCK_M: {32, 64, 128} (token tile) BLOCK_K: {32, 64, 128, 256} (output tile) BLOCK_N: {32, 64, 128, 256} (reduction tile) num_warps: {4, 8} num_stages: {3, 4, 5} """ configs = [] - for block_k, block_n, warps, stages in product( + for block_m, block_k, block_n, warps, stages in product( + [32, 64, 128], # BLOCK_M [32, 64, 128, 256], # BLOCK_K (output dimension) [32, 64, 128, 256], # BLOCK_N (reduction dimension) [4, 8], # num_warps @@ -797,7 +800,7 @@ def _scatter2scatter_lora_dX_configs(): ): configs.append( triton.Config( - {"BLOCK_K": block_k, "BLOCK_N": block_n}, + {"BLOCK_M": block_m, "BLOCK_K": block_k, "BLOCK_N": block_n}, num_stages=stages, num_warps=warps, ) @@ -822,10 +825,11 @@ def _prune_dX_configs(configs, named_args, **kwargs): scored = [] for config in configs: + block_m = config.kwargs["BLOCK_M"] block_k = config.kwargs["BLOCK_K"] block_n = config.kwargs["BLOCK_N"] # Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K - smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_k, block_n) + smem_base = _estimate_smem_usage(config.num_stages, block_m, block_k, block_n) # B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop smem_lora_loop = config.num_stages * block_n * block_r * 2 # A tile [BLOCK_R, BLOCK_K] loaded once in epilogue @@ -1067,7 +1071,7 @@ def scatter2scatter_lora_dX( N=N, E=E, ACTUAL_R=R, - BLOCK_M=BLOCK_M, + # BLOCK_M is autotuned (injected by triton.autotune from Config kwargs) BLOCK_R=BLOCK_R, ACC_TYPE=tl.float32, scaling=scaling, diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py index 5125e8801..453c8c318 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py @@ -489,20 +489,65 @@ class HFScatterMoEGatedMLP(nn.Module): # ==================================================================== experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts) + # ==================================================================== + # Selective expert weight dequantization + # ==================================================================== + # When experts are BnB-quantized (quantize_moe_experts), dequantize + # only the active experts instead of all E. This saves ~97% memory + # for the transient dequant buffer when few experts are active. + use_selective = ( + getattr(self, "_use_selective_dequant", False) + and hasattr(experts, "parametrizations") + and "gate_up_proj" in experts.parametrizations + ) + + if use_selective: + from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import ( + get_active_experts, + remap_expert_indices, + selective_expert_weights, + selective_lora_weights, + ) + + active_experts = get_active_experts(sorted_expert_idxs, num_experts) + remapped_expert_idxs, compact_offsets = remap_expert_indices( + sorted_expert_idxs, expert_offsets, active_experts, num_experts, + ) + num_active = len(active_experts) + + # Dequantize only active experts' weights + gate_up_W = selective_expert_weights( + experts, "gate_up_proj", active_experts, + ).transpose(2, 1) # [num_active, hidden, 2*inter] + + # Remap LoRA weights to match compact expert indices + if gup_lora is not None: + gup_A, gup_B, gup_scaling = gup_lora + gup_A, gup_B = selective_lora_weights( + gup_A, gup_B, active_experts, num_experts, + ) + gup_lora = (gup_A, gup_B, gup_scaling) + + # Use remapped indices for ScatterMoE kernels + sei_gup = remapped_expert_idxs + eo_gup = compact_offsets + else: + gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter] + sei_gup = sorted_expert_idxs + eo_gup = expert_offsets + # ==================================================================== # Gate + Up projection # ==================================================================== - gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter] - if gup_lora is not None: gup_A, gup_B, gup_scaling = gup_lora gup = parallel_linear_lora( hidden_states_flat, gate_up_W, top_k, - sorted_expert_idxs, + sei_gup, sorted_scattered_idxs, - expert_offsets, + eo_gup, lora_A=gup_A, lora_B=gup_B, scaling=gup_scaling, @@ -516,9 +561,9 @@ class HFScatterMoEGatedMLP(nn.Module): hidden_states_flat, gate_up_W, top_k, - sorted_expert_idxs, + sei_gup, sorted_scattered_idxs, - expert_offsets, + eo_gup, grouped_in=False, grouped_out=True, ) @@ -529,7 +574,24 @@ class HFScatterMoEGatedMLP(nn.Module): # ==================================================================== # Down projection # ==================================================================== - down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden] + if use_selective: + down_W = selective_expert_weights( + experts, "down_proj", active_experts, + ).transpose(2, 1) # [num_active, inter, hidden] + + if down_lora is not None: + down_A, down_B, down_scaling = down_lora + down_A, down_B = selective_lora_weights( + down_A, down_B, active_experts, num_experts, + ) + down_lora = (down_A, down_B, down_scaling) + + sei_down = remapped_expert_idxs + eo_down = compact_offsets + else: + down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden] + sei_down = sorted_expert_idxs + eo_down = expert_offsets if down_lora is not None: down_A, down_B, down_scaling = down_lora @@ -537,9 +599,9 @@ class HFScatterMoEGatedMLP(nn.Module): h, down_W, 1, - sorted_expert_idxs, + sei_down, sorted_scattered_idxs, - expert_offsets, + eo_down, lora_A=down_A, lora_B=down_B, scaling=down_scaling, @@ -554,9 +616,9 @@ class HFScatterMoEGatedMLP(nn.Module): h, down_W, 1, - sorted_expert_idxs, + sei_down, sorted_scattered_idxs, - expert_offsets, + eo_down, grouped_in=True, grouped_out=False, gates=routing_weights,