selective dequant
This commit is contained in:
@@ -0,0 +1,284 @@
|
|||||||
|
"""
|
||||||
|
Selective Expert Dequantization
|
||||||
|
===============================
|
||||||
|
|
||||||
|
Instead of dequantizing all E expert weight matrices at once (which creates
|
||||||
|
a ~1 GB transient buffer for 256 experts), only dequantize the experts that
|
||||||
|
are actually routed to by the current batch's top-k selection.
|
||||||
|
|
||||||
|
For Qwen3.5-35B-A3B (E=256, top_k=8, hidden=2048, intermediate=512):
|
||||||
|
- Full dequant: [256, 2048, 1024] = 1,074 MB per projection
|
||||||
|
- Selective (8 active): [8, 2048, 1024] = 33.5 MB per projection
|
||||||
|
- Savings: ~97% memory reduction per layer
|
||||||
|
|
||||||
|
This module provides format-agnostic selective weight extraction:
|
||||||
|
- BnB 4-bit (nf4/fp4): slice quantized data + absmax per expert
|
||||||
|
- bf16/fp32: direct indexing (no dequant needed)
|
||||||
|
- FP8: slice + cast
|
||||||
|
|
||||||
|
The ScatterMoE kernel itself doesn't change — we remap expert indices
|
||||||
|
from global (0..E-1) to compact (0..num_active-1) and pass the smaller
|
||||||
|
weight tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_experts(sorted_expert_idxs: torch.Tensor, E: int) -> torch.Tensor:
|
||||||
|
"""Get sorted unique expert indices from the routing output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sorted_expert_idxs: Expert assignments sorted by expert id [T*k]
|
||||||
|
E: Total number of experts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
active: Sorted unique expert indices [num_active]
|
||||||
|
"""
|
||||||
|
return torch.unique(sorted_expert_idxs)
|
||||||
|
|
||||||
|
|
||||||
|
def remap_expert_indices(
|
||||||
|
sorted_expert_idxs: torch.Tensor,
|
||||||
|
expert_offsets: torch.Tensor,
|
||||||
|
active_experts: torch.Tensor,
|
||||||
|
E: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Remap global expert indices to compact indices.
|
||||||
|
|
||||||
|
Maps expert ids from [0..E-1] to [0..num_active-1], preserving the
|
||||||
|
sort order. Also compacts expert_offsets to only active experts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sorted_expert_idxs: [T*k] expert ids in sorted order
|
||||||
|
expert_offsets: [E] cumulative token counts (original)
|
||||||
|
active_experts: [num_active] sorted unique expert ids
|
||||||
|
E: Total number of experts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
remapped_idxs: [T*k] expert ids in [0..num_active-1]
|
||||||
|
compact_offsets: [num_active] cumulative token counts
|
||||||
|
"""
|
||||||
|
# Build remap table: global_id -> compact_id
|
||||||
|
remap = torch.empty(E, dtype=torch.long, device=sorted_expert_idxs.device)
|
||||||
|
remap[active_experts] = torch.arange(
|
||||||
|
len(active_experts), device=sorted_expert_idxs.device
|
||||||
|
)
|
||||||
|
|
||||||
|
remapped_idxs = remap[sorted_expert_idxs]
|
||||||
|
|
||||||
|
# Compact the expert_offsets: only keep active experts' cumulative counts
|
||||||
|
compact_offsets = expert_offsets[active_experts]
|
||||||
|
|
||||||
|
return remapped_idxs, compact_offsets
|
||||||
|
|
||||||
|
|
||||||
|
def _selective_dequant_bnb4(
|
||||||
|
raw_param: torch.Tensor,
|
||||||
|
quant_state,
|
||||||
|
active_experts: torch.Tensor,
|
||||||
|
expert_shape: tuple[int, ...],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Dequantize only selected experts from BnB 4-bit packed data.
|
||||||
|
|
||||||
|
The raw parameter is a flattened 4-bit packed tensor. Each expert's
|
||||||
|
data is contiguous (stored in expert-major order), so we can gather
|
||||||
|
the packed data and absmax blocks for active experts, then dequantize
|
||||||
|
as one contiguous block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_param: Flattened uint8 tensor of packed 4-bit weights
|
||||||
|
quant_state: BnB QuantState with absmax, blocksize, code, etc.
|
||||||
|
active_experts: [num_active] expert indices to dequantize
|
||||||
|
expert_shape: (dim1, dim2) shape per expert (e.g. (1024, 2048))
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dequantized weights [num_active, dim1, dim2] in original dtype
|
||||||
|
"""
|
||||||
|
import bitsandbytes.functional as F # noqa: N812
|
||||||
|
from bitsandbytes.functional import QuantState
|
||||||
|
|
||||||
|
expert_numel = expert_shape[0] * expert_shape[1]
|
||||||
|
packed_per_expert = expert_numel // 2 # 4-bit = 2 values per byte
|
||||||
|
blocks_per_expert = expert_numel // quant_state.blocksize
|
||||||
|
num_active = len(active_experts)
|
||||||
|
|
||||||
|
if blocks_per_expert == 0:
|
||||||
|
# Expert is smaller than one quantization block — blocks span across
|
||||||
|
# expert boundaries, so per-expert slicing isn't possible.
|
||||||
|
# Fallback: full dequantize + index.
|
||||||
|
full = F.dequantize_4bit(raw_param, quant_state)
|
||||||
|
E_total = full.numel() // expert_numel
|
||||||
|
return full.reshape(E_total, *expert_shape)[active_experts]
|
||||||
|
|
||||||
|
# Use fused Triton kernel for NF4 (handles selective gather + dequant in one pass)
|
||||||
|
if quant_state.quant_type == "nf4" and raw_param.dtype == torch.uint8:
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import (
|
||||||
|
selective_dequant_nf4_triton,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle nested (double) quantization: dequantize absmax first
|
||||||
|
# BnB uses dequantize_blockwise (not _4bit) for nested absmax + offset
|
||||||
|
if quant_state.nested:
|
||||||
|
absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
|
||||||
|
absmax += quant_state.offset
|
||||||
|
if absmax.dtype != torch.float32:
|
||||||
|
absmax = absmax.float()
|
||||||
|
else:
|
||||||
|
absmax = quant_state.absmax
|
||||||
|
|
||||||
|
return selective_dequant_nf4_triton(
|
||||||
|
packed_data=raw_param,
|
||||||
|
absmax=absmax,
|
||||||
|
active_experts=active_experts,
|
||||||
|
expert_shape=expert_shape,
|
||||||
|
blocksize=quant_state.blocksize,
|
||||||
|
dtype=quant_state.dtype,
|
||||||
|
codebook=quant_state.code,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fallback: gather + BnB dequant (for fp4 or non-uint8 packed formats)
|
||||||
|
raw_flat = raw_param.reshape(-1)
|
||||||
|
|
||||||
|
offsets_qt = (
|
||||||
|
active_experts.long()[:, None] * packed_per_expert
|
||||||
|
+ torch.arange(packed_per_expert, device=raw_param.device)[None, :]
|
||||||
|
).reshape(-1)
|
||||||
|
qt_gathered = raw_flat[offsets_qt]
|
||||||
|
|
||||||
|
offsets_abs = (
|
||||||
|
active_experts.long()[:, None] * blocks_per_expert
|
||||||
|
+ torch.arange(blocks_per_expert, device=raw_param.device)[None, :]
|
||||||
|
).reshape(-1)
|
||||||
|
|
||||||
|
if quant_state.nested:
|
||||||
|
full_absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
|
||||||
|
full_absmax += quant_state.offset
|
||||||
|
if full_absmax.dtype != torch.float32:
|
||||||
|
full_absmax = full_absmax.float()
|
||||||
|
absmax_gathered = full_absmax[offsets_abs]
|
||||||
|
else:
|
||||||
|
absmax_gathered = quant_state.absmax[offsets_abs]
|
||||||
|
|
||||||
|
qt_gathered = qt_gathered.unsqueeze(1) if qt_gathered.dim() == 1 else qt_gathered
|
||||||
|
|
||||||
|
gathered_qs = QuantState(
|
||||||
|
absmax=absmax_gathered,
|
||||||
|
shape=torch.Size([num_active * expert_numel]),
|
||||||
|
blocksize=quant_state.blocksize,
|
||||||
|
quant_type=quant_state.quant_type,
|
||||||
|
code=quant_state.code,
|
||||||
|
dtype=quant_state.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
deq = F.dequantize_4bit(qt_gathered, gathered_qs)
|
||||||
|
return deq.reshape(num_active, *expert_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def _selective_index_dense(
|
||||||
|
param: torch.Tensor,
|
||||||
|
active_experts: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Select experts from a dense (bf16/fp32) weight tensor.
|
||||||
|
|
||||||
|
Simple indexing — no dequantization needed.
|
||||||
|
"""
|
||||||
|
return param[active_experts]
|
||||||
|
|
||||||
|
|
||||||
|
def selective_expert_weights(
|
||||||
|
experts_module: nn.Module,
|
||||||
|
param_name: str,
|
||||||
|
active_experts: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Extract and dequantize only the active experts' weights.
|
||||||
|
|
||||||
|
Format-agnostic: dispatches based on whether the parameter is
|
||||||
|
BnB 4-bit quantized (via parametrize), FP8, or dense bf16/fp32.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
experts_module: The base experts module (e.g. Qwen3_5MoeExperts)
|
||||||
|
param_name: "gate_up_proj" or "down_proj"
|
||||||
|
active_experts: [num_active] sorted unique expert indices
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compact weight tensor [num_active, dim1, dim2] ready for ScatterMoE
|
||||||
|
"""
|
||||||
|
# Check if the parameter is BnB-quantized via parametrize
|
||||||
|
if (
|
||||||
|
hasattr(experts_module, "parametrizations")
|
||||||
|
and param_name in experts_module.parametrizations
|
||||||
|
):
|
||||||
|
param_list = experts_module.parametrizations[param_name]
|
||||||
|
parametrization = param_list[0]
|
||||||
|
|
||||||
|
# BnB 4-bit parametrization
|
||||||
|
if hasattr(parametrization, "quant_state"):
|
||||||
|
# The raw quantized data is on the ParametrizationList, not the
|
||||||
|
# individual Bnb4bitParametrization module
|
||||||
|
raw_param = param_list.original
|
||||||
|
qs = parametrization.quant_state
|
||||||
|
# qs.shape is the original tensor shape before flattening.
|
||||||
|
# For MoE experts it's [E, d1, d2] (3D) or [total_elements] (1D).
|
||||||
|
orig_shape = qs.shape
|
||||||
|
if isinstance(orig_shape, torch.Size) and len(orig_shape) == 3:
|
||||||
|
expert_shape = (orig_shape[1], orig_shape[2])
|
||||||
|
elif isinstance(orig_shape, torch.Size) and len(orig_shape) == 1:
|
||||||
|
# Flattened — need to infer from module attributes
|
||||||
|
E_total = getattr(experts_module, "num_experts", None)
|
||||||
|
if E_total is None:
|
||||||
|
E_total = int(active_experts.max().item()) + 1
|
||||||
|
expert_numel = orig_shape[0] // E_total
|
||||||
|
d2 = getattr(experts_module, "hidden_dim", None) or getattr(experts_module, "intermediate_dim", None)
|
||||||
|
if d2 and expert_numel % d2 == 0:
|
||||||
|
expert_shape = (expert_numel // d2, d2)
|
||||||
|
else:
|
||||||
|
full = getattr(experts_module, param_name)
|
||||||
|
return full[active_experts]
|
||||||
|
else:
|
||||||
|
full = getattr(experts_module, param_name)
|
||||||
|
return full[active_experts]
|
||||||
|
|
||||||
|
return _selective_dequant_bnb4(
|
||||||
|
raw_param, qs, active_experts, expert_shape
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dense parameter (bf16/fp32) — direct indexing
|
||||||
|
param = getattr(experts_module, param_name)
|
||||||
|
if param.dim() == 3:
|
||||||
|
return param[active_experts]
|
||||||
|
|
||||||
|
# Fallback: full access
|
||||||
|
return param
|
||||||
|
|
||||||
|
|
||||||
|
def selective_lora_weights(
|
||||||
|
lora_A: torch.Tensor,
|
||||||
|
lora_B: torch.Tensor,
|
||||||
|
active_experts: torch.Tensor,
|
||||||
|
E: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Select LoRA A and B weights for only the active experts.
|
||||||
|
|
||||||
|
LoRA layout (scattermoe format):
|
||||||
|
A: [r*E, K] — expert e occupies rows [e*r : (e+1)*r]
|
||||||
|
B: [N, r*E] — expert e occupies cols [e*r : (e+1)*r]
|
||||||
|
|
||||||
|
Returns compact:
|
||||||
|
A: [r*num_active, K]
|
||||||
|
B: [N, r*num_active]
|
||||||
|
"""
|
||||||
|
R = lora_A.size(0) // E
|
||||||
|
|
||||||
|
# Vectorized gather: active_experts[:, None] * R + arange(R)[None, :]
|
||||||
|
row_idx = (
|
||||||
|
active_experts.long()[:, None] * R
|
||||||
|
+ torch.arange(R, device=lora_A.device)[None, :]
|
||||||
|
).reshape(-1)
|
||||||
|
|
||||||
|
compact_A = lora_A[row_idx] # [r*num_active, K]
|
||||||
|
compact_B = lora_B[:, row_idx] # [N, r*num_active]
|
||||||
|
|
||||||
|
return compact_A, compact_B
|
||||||
@@ -0,0 +1,167 @@
|
|||||||
|
"""
|
||||||
|
Triton kernel for fused selective expert gather + NF4 dequantization.
|
||||||
|
|
||||||
|
Instead of:
|
||||||
|
1. Gather packed uint8 data for active experts (memory copy)
|
||||||
|
2. Gather absmax for active experts (memory copy)
|
||||||
|
3. Call BnB dequantize_4bit CUDA kernel
|
||||||
|
|
||||||
|
This kernel does all three in one pass:
|
||||||
|
- Reads packed NF4 bytes from expert-strided positions
|
||||||
|
- Looks up the NF4 codebook
|
||||||
|
- Multiplies by the per-block absmax
|
||||||
|
- Writes bf16 output directly
|
||||||
|
|
||||||
|
This eliminates the intermediate gather buffer entirely.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
# NF4 codebook (16 values, precomputed by BnB)
|
||||||
|
# These are the normalized float4 reconstruction values
|
||||||
|
NF4_CODEBOOK = [
|
||||||
|
-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
|
||||||
|
-0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
|
||||||
|
0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
|
||||||
|
0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
|
||||||
|
0.7229568362236023, 1.0,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _selective_dequant_nf4_kernel(
|
||||||
|
# Input: packed NF4 data (flattened, expert-major order)
|
||||||
|
packed_ptr,
|
||||||
|
# Input: absmax values (flattened, expert-major order)
|
||||||
|
absmax_ptr,
|
||||||
|
# Input: active expert indices
|
||||||
|
active_experts_ptr,
|
||||||
|
# Input: NF4 codebook (16 float values)
|
||||||
|
codebook_ptr,
|
||||||
|
# Output: dequantized bf16 weights [num_active, expert_numel]
|
||||||
|
out_ptr,
|
||||||
|
stride_out_e, # stride for expert dim in output
|
||||||
|
# Dimensions
|
||||||
|
num_active,
|
||||||
|
packed_per_expert, # expert_numel // 2
|
||||||
|
blocks_per_expert, # expert_numel // blocksize
|
||||||
|
blocksize: tl.constexpr,
|
||||||
|
# Tile size
|
||||||
|
BLOCK_SIZE: tl.constexpr, # elements per thread block (must be multiple of 2)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Each program processes BLOCK_SIZE elements from one expert.
|
||||||
|
|
||||||
|
Grid: (num_active, cdiv(expert_numel, BLOCK_SIZE))
|
||||||
|
|
||||||
|
For each output element:
|
||||||
|
1. Compute which byte in packed data contains this element
|
||||||
|
2. Extract the 4-bit nibble (high or low)
|
||||||
|
3. Look up in NF4 codebook
|
||||||
|
4. Scale by absmax for this block
|
||||||
|
"""
|
||||||
|
expert_local_idx = tl.program_id(0) # which active expert (0..num_active-1)
|
||||||
|
block_id = tl.program_id(1) # which element block
|
||||||
|
|
||||||
|
# Load the global expert index
|
||||||
|
expert_global = tl.load(active_experts_ptr + expert_local_idx).to(tl.int64)
|
||||||
|
|
||||||
|
expert_numel = packed_per_expert * 2 # 2 elements per packed byte
|
||||||
|
elem_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = elem_offset < expert_numel
|
||||||
|
|
||||||
|
# Each element is packed as: byte[i//2], low nibble for even i, high for odd i
|
||||||
|
byte_idx = elem_offset // 2
|
||||||
|
is_high = (elem_offset % 2) == 1
|
||||||
|
|
||||||
|
# Read packed bytes from the global expert's region
|
||||||
|
packed_global_offset = expert_global * packed_per_expert + byte_idx
|
||||||
|
packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(tl.int32)
|
||||||
|
|
||||||
|
# Extract 4-bit nibble
|
||||||
|
# BnB packing: high nibble = even element, low nibble = odd element
|
||||||
|
nibble = tl.where(is_high, packed_bytes & 0xF, (packed_bytes >> 4) & 0xF)
|
||||||
|
|
||||||
|
# NF4 codebook lookup
|
||||||
|
# Load all 16 codebook values (small, fits in registers)
|
||||||
|
# Use gather from codebook pointer
|
||||||
|
code_val = tl.load(codebook_ptr + nibble, mask=mask, other=0.0)
|
||||||
|
|
||||||
|
# Load absmax for this element's quantization block
|
||||||
|
block_idx = elem_offset // blocksize
|
||||||
|
absmax_global_offset = expert_global * blocks_per_expert + block_idx
|
||||||
|
absmax_val = tl.load(absmax_ptr + absmax_global_offset, mask=mask, other=1.0)
|
||||||
|
|
||||||
|
# Dequantize: value = codebook[nibble] * absmax
|
||||||
|
result = code_val * absmax_val
|
||||||
|
|
||||||
|
# Store to output
|
||||||
|
out_offset = expert_local_idx * stride_out_e + elem_offset
|
||||||
|
tl.store(out_ptr + out_offset, result.to(out_ptr.dtype.element_ty), mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def selective_dequant_nf4_triton(
|
||||||
|
packed_data: torch.Tensor,
|
||||||
|
absmax: torch.Tensor,
|
||||||
|
active_experts: torch.Tensor,
|
||||||
|
expert_shape: tuple[int, int],
|
||||||
|
blocksize: int,
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
codebook: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Fused selective gather + NF4 dequantization via Triton kernel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
packed_data: Flattened packed NF4 data [total_packed] or [total_packed, 1]
|
||||||
|
absmax: Per-block scaling factors [total_blocks]
|
||||||
|
active_experts: Sorted indices of experts to dequantize [num_active]
|
||||||
|
expert_shape: (dim1, dim2) per expert
|
||||||
|
blocksize: Quantization block size
|
||||||
|
dtype: Output dtype (default bf16)
|
||||||
|
codebook: NF4 lookup table [16] (uses default NF4 codebook if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dequantized weights [num_active, dim1, dim2]
|
||||||
|
"""
|
||||||
|
num_active = active_experts.shape[0]
|
||||||
|
expert_numel = expert_shape[0] * expert_shape[1]
|
||||||
|
packed_per_expert = expert_numel // 2
|
||||||
|
blocks_per_expert = expert_numel // blocksize
|
||||||
|
|
||||||
|
# Prepare codebook on device
|
||||||
|
if codebook is None:
|
||||||
|
codebook = torch.tensor(NF4_CODEBOOK, dtype=torch.float32,
|
||||||
|
device=packed_data.device)
|
||||||
|
else:
|
||||||
|
codebook = codebook.to(device=packed_data.device, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Flatten inputs
|
||||||
|
packed_flat = packed_data.reshape(-1)
|
||||||
|
absmax_flat = absmax.reshape(-1).float() # absmax is usually fp32
|
||||||
|
|
||||||
|
# Output buffer
|
||||||
|
out = torch.empty(num_active, expert_numel, dtype=dtype,
|
||||||
|
device=packed_data.device)
|
||||||
|
|
||||||
|
BLOCK_SIZE = 1024 # Process 1024 elements per thread block
|
||||||
|
|
||||||
|
grid = (num_active, triton.cdiv(expert_numel, BLOCK_SIZE))
|
||||||
|
|
||||||
|
_selective_dequant_nf4_kernel[grid](
|
||||||
|
packed_flat,
|
||||||
|
absmax_flat,
|
||||||
|
active_experts,
|
||||||
|
codebook,
|
||||||
|
out,
|
||||||
|
out.stride(0),
|
||||||
|
num_active=num_active,
|
||||||
|
packed_per_expert=packed_per_expert,
|
||||||
|
blocks_per_expert=blocks_per_expert,
|
||||||
|
blocksize=blocksize,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
return out.reshape(num_active, *expert_shape)
|
||||||
Reference in New Issue
Block a user