more scattermoe optims

This commit is contained in:
Wing Lian
2026-03-18 23:36:28 +00:00
parent c5db90aa3f
commit 2dcca15f65
2 changed files with 92 additions and 26 deletions

View File

@@ -327,20 +327,21 @@ def _compute_expert_block_lora(
def _scatter2scatter_lora_configs():
"""Generate forward kernel autotune configs.
Search space includes smaller tile sizes and fewer pipeline stages to
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
Search space includes BLOCK_M to allow trading token-tile size for
larger BLOCK_K/BLOCK_N tiles. On GPUs with ~99KB SMEM, BLOCK_M=128
forces BLOCK_K=32 and BLOCK_N=32; BLOCK_M=64 allows BLOCK_K=128
(4× fewer inner-loop iterations).
Search space:
BLOCK_M: {32, 64, 128}
BLOCK_N: {32, 64, 128, 256}
BLOCK_K: {32, 64, 128}
num_warps: {4, 8}
num_stages: {3, 4, 5}
BLOCK_M is fixed at 128 (module-level constant, not autotuned in the
scatter2scatter pattern).
"""
configs = []
for block_n, block_k, warps, stages in product(
for block_m, block_n, block_k, warps, stages in product(
[32, 64, 128], # BLOCK_M
[32, 64, 128, 256], # BLOCK_N
[32, 64, 128], # BLOCK_K
[4, 8], # num_warps
@@ -348,7 +349,7 @@ def _scatter2scatter_lora_configs():
):
configs.append(
triton.Config(
{"BLOCK_N": block_n, "BLOCK_K": block_k},
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
num_stages=stages,
num_warps=warps,
)
@@ -373,10 +374,11 @@ def _prune_fwd_configs(configs, named_args, **kwargs):
scored = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_n = config.kwargs["BLOCK_N"]
block_k = config.kwargs["BLOCK_K"]
# Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_n, block_k)
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k)
# A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop
smem_lora_loop = config.num_stages * block_r * block_k * 2
# B tile [BLOCK_N, BLOCK_R] loaded once in epilogue
@@ -626,7 +628,7 @@ def scatter2scatter_lora(
N=N,
E=E,
ACTUAL_R=R, # True LoRA rank for weight indexing
BLOCK_M=BLOCK_M,
# BLOCK_M is autotuned (injected by triton.autotune from Config kwargs)
BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16)
ACC_TYPE=tl.float32,
scaling=scaling,
@@ -779,17 +781,18 @@ def _scatter2scatter_lora_dX_configs():
The inner loop is over N (not K as in forward). The output dimension is K.
So BLOCK_K tiles the output and BLOCK_N tiles the reduction.
Search space includes smaller tile sizes and fewer pipeline stages to
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
BLOCK_M is now autotunable (was fixed at 128).
Search space:
BLOCK_M: {32, 64, 128} (token tile)
BLOCK_K: {32, 64, 128, 256} (output tile)
BLOCK_N: {32, 64, 128, 256} (reduction tile)
num_warps: {4, 8}
num_stages: {3, 4, 5}
"""
configs = []
for block_k, block_n, warps, stages in product(
for block_m, block_k, block_n, warps, stages in product(
[32, 64, 128], # BLOCK_M
[32, 64, 128, 256], # BLOCK_K (output dimension)
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
[4, 8], # num_warps
@@ -797,7 +800,7 @@ def _scatter2scatter_lora_dX_configs():
):
configs.append(
triton.Config(
{"BLOCK_K": block_k, "BLOCK_N": block_n},
{"BLOCK_M": block_m, "BLOCK_K": block_k, "BLOCK_N": block_n},
num_stages=stages,
num_warps=warps,
)
@@ -822,10 +825,11 @@ def _prune_dX_configs(configs, named_args, **kwargs):
scored = []
for config in configs:
block_m = config.kwargs["BLOCK_M"]
block_k = config.kwargs["BLOCK_K"]
block_n = config.kwargs["BLOCK_N"]
# Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_k, block_n)
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_k, block_n)
# B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop
smem_lora_loop = config.num_stages * block_n * block_r * 2
# A tile [BLOCK_R, BLOCK_K] loaded once in epilogue
@@ -1067,7 +1071,7 @@ def scatter2scatter_lora_dX(
N=N,
E=E,
ACTUAL_R=R,
BLOCK_M=BLOCK_M,
# BLOCK_M is autotuned (injected by triton.autotune from Config kwargs)
BLOCK_R=BLOCK_R,
ACC_TYPE=tl.float32,
scaling=scaling,

View File

@@ -489,20 +489,65 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
# ====================================================================
# Selective expert weight dequantization
# ====================================================================
# When experts are BnB-quantized (quantize_moe_experts), dequantize
# only the active experts instead of all E. This saves ~97% memory
# for the transient dequant buffer when few experts are active.
use_selective = (
getattr(self, "_use_selective_dequant", False)
and hasattr(experts, "parametrizations")
and "gate_up_proj" in experts.parametrizations
)
if use_selective:
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import (
get_active_experts,
remap_expert_indices,
selective_expert_weights,
selective_lora_weights,
)
active_experts = get_active_experts(sorted_expert_idxs, num_experts)
remapped_expert_idxs, compact_offsets = remap_expert_indices(
sorted_expert_idxs, expert_offsets, active_experts, num_experts,
)
num_active = len(active_experts)
# Dequantize only active experts' weights
gate_up_W = selective_expert_weights(
experts, "gate_up_proj", active_experts,
).transpose(2, 1) # [num_active, hidden, 2*inter]
# Remap LoRA weights to match compact expert indices
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup_A, gup_B = selective_lora_weights(
gup_A, gup_B, active_experts, num_experts,
)
gup_lora = (gup_A, gup_B, gup_scaling)
# Use remapped indices for ScatterMoE kernels
sei_gup = remapped_expert_idxs
eo_gup = compact_offsets
else:
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
sei_gup = sorted_expert_idxs
eo_gup = expert_offsets
# ====================================================================
# Gate + Up projection
# ====================================================================
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup = parallel_linear_lora(
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sei_gup,
sorted_scattered_idxs,
expert_offsets,
eo_gup,
lora_A=gup_A,
lora_B=gup_B,
scaling=gup_scaling,
@@ -516,9 +561,9 @@ class HFScatterMoEGatedMLP(nn.Module):
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sei_gup,
sorted_scattered_idxs,
expert_offsets,
eo_gup,
grouped_in=False,
grouped_out=True,
)
@@ -529,7 +574,24 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
# Down projection
# ====================================================================
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
if use_selective:
down_W = selective_expert_weights(
experts, "down_proj", active_experts,
).transpose(2, 1) # [num_active, inter, hidden]
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
down_A, down_B = selective_lora_weights(
down_A, down_B, active_experts, num_experts,
)
down_lora = (down_A, down_B, down_scaling)
sei_down = remapped_expert_idxs
eo_down = compact_offsets
else:
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
sei_down = sorted_expert_idxs
eo_down = expert_offsets
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
@@ -537,9 +599,9 @@ class HFScatterMoEGatedMLP(nn.Module):
h,
down_W,
1,
sorted_expert_idxs,
sei_down,
sorted_scattered_idxs,
expert_offsets,
eo_down,
lora_A=down_A,
lora_B=down_B,
scaling=down_scaling,
@@ -554,9 +616,9 @@ class HFScatterMoEGatedMLP(nn.Module):
h,
down_W,
1,
sorted_expert_idxs,
sei_down,
sorted_scattered_idxs,
expert_offsets,
eo_down,
grouped_in=True,
grouped_out=False,
gates=routing_weights,