diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant.py new file mode 100644 index 000000000..97910bcc3 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant.py @@ -0,0 +1,284 @@ +""" +Selective Expert Dequantization +=============================== + +Instead of dequantizing all E expert weight matrices at once (which creates +a ~1 GB transient buffer for 256 experts), only dequantize the experts that +are actually routed to by the current batch's top-k selection. + +For Qwen3.5-35B-A3B (E=256, top_k=8, hidden=2048, intermediate=512): + - Full dequant: [256, 2048, 1024] = 1,074 MB per projection + - Selective (8 active): [8, 2048, 1024] = 33.5 MB per projection + - Savings: ~97% memory reduction per layer + +This module provides format-agnostic selective weight extraction: + - BnB 4-bit (nf4/fp4): slice quantized data + absmax per expert + - bf16/fp32: direct indexing (no dequant needed) + - FP8: slice + cast + +The ScatterMoE kernel itself doesn't change — we remap expert indices +from global (0..E-1) to compact (0..num_active-1) and pass the smaller +weight tensor. +""" + +from typing import Optional + +import torch +import torch.nn as nn + + +def get_active_experts(sorted_expert_idxs: torch.Tensor, E: int) -> torch.Tensor: + """Get sorted unique expert indices from the routing output. + + Args: + sorted_expert_idxs: Expert assignments sorted by expert id [T*k] + E: Total number of experts + + Returns: + active: Sorted unique expert indices [num_active] + """ + return torch.unique(sorted_expert_idxs) + + +def remap_expert_indices( + sorted_expert_idxs: torch.Tensor, + expert_offsets: torch.Tensor, + active_experts: torch.Tensor, + E: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Remap global expert indices to compact indices. + + Maps expert ids from [0..E-1] to [0..num_active-1], preserving the + sort order. Also compacts expert_offsets to only active experts. + + Args: + sorted_expert_idxs: [T*k] expert ids in sorted order + expert_offsets: [E] cumulative token counts (original) + active_experts: [num_active] sorted unique expert ids + E: Total number of experts + + Returns: + remapped_idxs: [T*k] expert ids in [0..num_active-1] + compact_offsets: [num_active] cumulative token counts + """ + # Build remap table: global_id -> compact_id + remap = torch.empty(E, dtype=torch.long, device=sorted_expert_idxs.device) + remap[active_experts] = torch.arange( + len(active_experts), device=sorted_expert_idxs.device + ) + + remapped_idxs = remap[sorted_expert_idxs] + + # Compact the expert_offsets: only keep active experts' cumulative counts + compact_offsets = expert_offsets[active_experts] + + return remapped_idxs, compact_offsets + + +def _selective_dequant_bnb4( + raw_param: torch.Tensor, + quant_state, + active_experts: torch.Tensor, + expert_shape: tuple[int, ...], +) -> torch.Tensor: + """Dequantize only selected experts from BnB 4-bit packed data. + + The raw parameter is a flattened 4-bit packed tensor. Each expert's + data is contiguous (stored in expert-major order), so we can gather + the packed data and absmax blocks for active experts, then dequantize + as one contiguous block. + + Args: + raw_param: Flattened uint8 tensor of packed 4-bit weights + quant_state: BnB QuantState with absmax, blocksize, code, etc. + active_experts: [num_active] expert indices to dequantize + expert_shape: (dim1, dim2) shape per expert (e.g. (1024, 2048)) + + Returns: + Dequantized weights [num_active, dim1, dim2] in original dtype + """ + import bitsandbytes.functional as F # noqa: N812 + from bitsandbytes.functional import QuantState + + expert_numel = expert_shape[0] * expert_shape[1] + packed_per_expert = expert_numel // 2 # 4-bit = 2 values per byte + blocks_per_expert = expert_numel // quant_state.blocksize + num_active = len(active_experts) + + if blocks_per_expert == 0: + # Expert is smaller than one quantization block — blocks span across + # expert boundaries, so per-expert slicing isn't possible. + # Fallback: full dequantize + index. + full = F.dequantize_4bit(raw_param, quant_state) + E_total = full.numel() // expert_numel + return full.reshape(E_total, *expert_shape)[active_experts] + + # Use fused Triton kernel for NF4 (handles selective gather + dequant in one pass) + if quant_state.quant_type == "nf4" and raw_param.dtype == torch.uint8: + from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import ( + selective_dequant_nf4_triton, + ) + + # Handle nested (double) quantization: dequantize absmax first + # BnB uses dequantize_blockwise (not _4bit) for nested absmax + offset + if quant_state.nested: + absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2) + absmax += quant_state.offset + if absmax.dtype != torch.float32: + absmax = absmax.float() + else: + absmax = quant_state.absmax + + return selective_dequant_nf4_triton( + packed_data=raw_param, + absmax=absmax, + active_experts=active_experts, + expert_shape=expert_shape, + blocksize=quant_state.blocksize, + dtype=quant_state.dtype, + codebook=quant_state.code, + ) + + # Fallback: gather + BnB dequant (for fp4 or non-uint8 packed formats) + raw_flat = raw_param.reshape(-1) + + offsets_qt = ( + active_experts.long()[:, None] * packed_per_expert + + torch.arange(packed_per_expert, device=raw_param.device)[None, :] + ).reshape(-1) + qt_gathered = raw_flat[offsets_qt] + + offsets_abs = ( + active_experts.long()[:, None] * blocks_per_expert + + torch.arange(blocks_per_expert, device=raw_param.device)[None, :] + ).reshape(-1) + + if quant_state.nested: + full_absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2) + full_absmax += quant_state.offset + if full_absmax.dtype != torch.float32: + full_absmax = full_absmax.float() + absmax_gathered = full_absmax[offsets_abs] + else: + absmax_gathered = quant_state.absmax[offsets_abs] + + qt_gathered = qt_gathered.unsqueeze(1) if qt_gathered.dim() == 1 else qt_gathered + + gathered_qs = QuantState( + absmax=absmax_gathered, + shape=torch.Size([num_active * expert_numel]), + blocksize=quant_state.blocksize, + quant_type=quant_state.quant_type, + code=quant_state.code, + dtype=quant_state.dtype, + ) + + deq = F.dequantize_4bit(qt_gathered, gathered_qs) + return deq.reshape(num_active, *expert_shape) + + +def _selective_index_dense( + param: torch.Tensor, + active_experts: torch.Tensor, +) -> torch.Tensor: + """Select experts from a dense (bf16/fp32) weight tensor. + + Simple indexing — no dequantization needed. + """ + return param[active_experts] + + +def selective_expert_weights( + experts_module: nn.Module, + param_name: str, + active_experts: torch.Tensor, +) -> torch.Tensor: + """Extract and dequantize only the active experts' weights. + + Format-agnostic: dispatches based on whether the parameter is + BnB 4-bit quantized (via parametrize), FP8, or dense bf16/fp32. + + Args: + experts_module: The base experts module (e.g. Qwen3_5MoeExperts) + param_name: "gate_up_proj" or "down_proj" + active_experts: [num_active] sorted unique expert indices + + Returns: + Compact weight tensor [num_active, dim1, dim2] ready for ScatterMoE + """ + # Check if the parameter is BnB-quantized via parametrize + if ( + hasattr(experts_module, "parametrizations") + and param_name in experts_module.parametrizations + ): + param_list = experts_module.parametrizations[param_name] + parametrization = param_list[0] + + # BnB 4-bit parametrization + if hasattr(parametrization, "quant_state"): + # The raw quantized data is on the ParametrizationList, not the + # individual Bnb4bitParametrization module + raw_param = param_list.original + qs = parametrization.quant_state + # qs.shape is the original tensor shape before flattening. + # For MoE experts it's [E, d1, d2] (3D) or [total_elements] (1D). + orig_shape = qs.shape + if isinstance(orig_shape, torch.Size) and len(orig_shape) == 3: + expert_shape = (orig_shape[1], orig_shape[2]) + elif isinstance(orig_shape, torch.Size) and len(orig_shape) == 1: + # Flattened — need to infer from module attributes + E_total = getattr(experts_module, "num_experts", None) + if E_total is None: + E_total = int(active_experts.max().item()) + 1 + expert_numel = orig_shape[0] // E_total + d2 = getattr(experts_module, "hidden_dim", None) or getattr(experts_module, "intermediate_dim", None) + if d2 and expert_numel % d2 == 0: + expert_shape = (expert_numel // d2, d2) + else: + full = getattr(experts_module, param_name) + return full[active_experts] + else: + full = getattr(experts_module, param_name) + return full[active_experts] + + return _selective_dequant_bnb4( + raw_param, qs, active_experts, expert_shape + ) + + # Dense parameter (bf16/fp32) — direct indexing + param = getattr(experts_module, param_name) + if param.dim() == 3: + return param[active_experts] + + # Fallback: full access + return param + + +def selective_lora_weights( + lora_A: torch.Tensor, + lora_B: torch.Tensor, + active_experts: torch.Tensor, + E: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Select LoRA A and B weights for only the active experts. + + LoRA layout (scattermoe format): + A: [r*E, K] — expert e occupies rows [e*r : (e+1)*r] + B: [N, r*E] — expert e occupies cols [e*r : (e+1)*r] + + Returns compact: + A: [r*num_active, K] + B: [N, r*num_active] + """ + R = lora_A.size(0) // E + + # Vectorized gather: active_experts[:, None] * R + arange(R)[None, :] + row_idx = ( + active_experts.long()[:, None] * R + + torch.arange(R, device=lora_A.device)[None, :] + ).reshape(-1) + + compact_A = lora_A[row_idx] # [r*num_active, K] + compact_B = lora_B[:, row_idx] # [N, r*num_active] + + return compact_A, compact_B diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant_kernel.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant_kernel.py new file mode 100644 index 000000000..e29799c05 --- /dev/null +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant_kernel.py @@ -0,0 +1,167 @@ +""" +Triton kernel for fused selective expert gather + NF4 dequantization. + +Instead of: + 1. Gather packed uint8 data for active experts (memory copy) + 2. Gather absmax for active experts (memory copy) + 3. Call BnB dequantize_4bit CUDA kernel + +This kernel does all three in one pass: + - Reads packed NF4 bytes from expert-strided positions + - Looks up the NF4 codebook + - Multiplies by the per-block absmax + - Writes bf16 output directly + +This eliminates the intermediate gather buffer entirely. +""" + +import torch +import triton +import triton.language as tl + + +# NF4 codebook (16 values, precomputed by BnB) +# These are the normalized float4 reconstruction values +NF4_CODEBOOK = [ + -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, + -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, + 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, + 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, + 0.7229568362236023, 1.0, +] + + +@triton.jit +def _selective_dequant_nf4_kernel( + # Input: packed NF4 data (flattened, expert-major order) + packed_ptr, + # Input: absmax values (flattened, expert-major order) + absmax_ptr, + # Input: active expert indices + active_experts_ptr, + # Input: NF4 codebook (16 float values) + codebook_ptr, + # Output: dequantized bf16 weights [num_active, expert_numel] + out_ptr, + stride_out_e, # stride for expert dim in output + # Dimensions + num_active, + packed_per_expert, # expert_numel // 2 + blocks_per_expert, # expert_numel // blocksize + blocksize: tl.constexpr, + # Tile size + BLOCK_SIZE: tl.constexpr, # elements per thread block (must be multiple of 2) +): + """ + Each program processes BLOCK_SIZE elements from one expert. + + Grid: (num_active, cdiv(expert_numel, BLOCK_SIZE)) + + For each output element: + 1. Compute which byte in packed data contains this element + 2. Extract the 4-bit nibble (high or low) + 3. Look up in NF4 codebook + 4. Scale by absmax for this block + """ + expert_local_idx = tl.program_id(0) # which active expert (0..num_active-1) + block_id = tl.program_id(1) # which element block + + # Load the global expert index + expert_global = tl.load(active_experts_ptr + expert_local_idx).to(tl.int64) + + expert_numel = packed_per_expert * 2 # 2 elements per packed byte + elem_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = elem_offset < expert_numel + + # Each element is packed as: byte[i//2], low nibble for even i, high for odd i + byte_idx = elem_offset // 2 + is_high = (elem_offset % 2) == 1 + + # Read packed bytes from the global expert's region + packed_global_offset = expert_global * packed_per_expert + byte_idx + packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(tl.int32) + + # Extract 4-bit nibble + # BnB packing: high nibble = even element, low nibble = odd element + nibble = tl.where(is_high, packed_bytes & 0xF, (packed_bytes >> 4) & 0xF) + + # NF4 codebook lookup + # Load all 16 codebook values (small, fits in registers) + # Use gather from codebook pointer + code_val = tl.load(codebook_ptr + nibble, mask=mask, other=0.0) + + # Load absmax for this element's quantization block + block_idx = elem_offset // blocksize + absmax_global_offset = expert_global * blocks_per_expert + block_idx + absmax_val = tl.load(absmax_ptr + absmax_global_offset, mask=mask, other=1.0) + + # Dequantize: value = codebook[nibble] * absmax + result = code_val * absmax_val + + # Store to output + out_offset = expert_local_idx * stride_out_e + elem_offset + tl.store(out_ptr + out_offset, result.to(out_ptr.dtype.element_ty), mask=mask) + + +def selective_dequant_nf4_triton( + packed_data: torch.Tensor, + absmax: torch.Tensor, + active_experts: torch.Tensor, + expert_shape: tuple[int, int], + blocksize: int, + dtype: torch.dtype = torch.bfloat16, + codebook: torch.Tensor | None = None, +) -> torch.Tensor: + """Fused selective gather + NF4 dequantization via Triton kernel. + + Args: + packed_data: Flattened packed NF4 data [total_packed] or [total_packed, 1] + absmax: Per-block scaling factors [total_blocks] + active_experts: Sorted indices of experts to dequantize [num_active] + expert_shape: (dim1, dim2) per expert + blocksize: Quantization block size + dtype: Output dtype (default bf16) + codebook: NF4 lookup table [16] (uses default NF4 codebook if None) + + Returns: + Dequantized weights [num_active, dim1, dim2] + """ + num_active = active_experts.shape[0] + expert_numel = expert_shape[0] * expert_shape[1] + packed_per_expert = expert_numel // 2 + blocks_per_expert = expert_numel // blocksize + + # Prepare codebook on device + if codebook is None: + codebook = torch.tensor(NF4_CODEBOOK, dtype=torch.float32, + device=packed_data.device) + else: + codebook = codebook.to(device=packed_data.device, dtype=torch.float32) + + # Flatten inputs + packed_flat = packed_data.reshape(-1) + absmax_flat = absmax.reshape(-1).float() # absmax is usually fp32 + + # Output buffer + out = torch.empty(num_active, expert_numel, dtype=dtype, + device=packed_data.device) + + BLOCK_SIZE = 1024 # Process 1024 elements per thread block + + grid = (num_active, triton.cdiv(expert_numel, BLOCK_SIZE)) + + _selective_dequant_nf4_kernel[grid]( + packed_flat, + absmax_flat, + active_experts, + codebook, + out, + out.stride(0), + num_active=num_active, + packed_per_expert=packed_per_expert, + blocks_per_expert=blocks_per_expert, + blocksize=blocksize, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return out.reshape(num_active, *expert_shape)