handle base+lora split kernel for older moe models
This commit is contained in:
@@ -533,6 +533,78 @@ def _scatter2scatter_lora(
|
||||
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
|
||||
|
||||
|
||||
def _scatter2scatter_lora_split(
    X: torch.Tensor,
    W: torch.Tensor,
    sorted_expert_idxs: torch.Tensor,
    sorted_scattered_idxs: torch.Tensor,
    k: int,
    lora_A: torch.Tensor,
    lora_B: torch.Tensor,
    scaling: float,
    b: Optional[torch.Tensor] = None,
    x_grouped: bool = False,
    y_grouped: bool = False,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Unfused base+LoRA forward built from three ``scatter2scatter`` launches.

    Computes

        Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T)

    without going through the fused LoRA kernel. Intended for models with
    few, large experts (e.g. Mixtral, E=8, I=14336): the base kernel then
    runs without the fused kernel's LoRA shared-memory overhead, and the
    low-rank matmuls (e.g. R=16) are small separate passes.

    Args:
        X: Input activations, forwarded unchanged to ``scatter2scatter``.
        W: Base expert weights; dims 0, 1, 2 are read as E, K, N.
        sorted_expert_idxs: Routing tensor forwarded to every launch.
        sorted_scattered_idxs: Routing tensor forwarded to every launch.
        k: Forwarded to the base and ``X @ A^T`` launches; the final
            ``XA @ B^T`` launch always uses ``k=1``.
        lora_A: LoRA A, viewed here as ``[E, R, K]`` (stored as ``[R*E, K]``).
        lora_B: LoRA B, transposed and viewed here as ``[E, R, N]``
            (stored as ``[N, R*E]``).
        scaling: Multiplier applied to the LoRA term when it is accumulated.
        b: Optional bias; only passed to the base launch.
        x_grouped: Forwarded to the launches that read ``X``.
        y_grouped: Forwarded to the launches that produce an ``N``-wide output.
        out: Optional output buffer; only passed to the base launch.

    Returns:
        The tensor returned by the base launch, updated in place with
        ``scaling * Y_lora``.
    """
    from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import (
        scatter2scatter,
    )

    num_experts = W.size(0)
    rank = lora_A.size(0) // num_experts
    in_features = W.size(1)
    out_features = W.size(2)

    # Both routing tensors are identical across all three launches.
    routing = {
        "sorted_expert_idxs": sorted_expert_idxs,
        "sorted_scattered_idxs": sorted_scattered_idxs,
    }

    # Step 1 - base projection X @ W; the bias and the caller-supplied output
    # buffer are consumed here only.
    base_out = scatter2scatter(
        X=X,
        W=W,
        b=b,
        k=k,
        x_grouped=x_grouped,
        y_grouped=y_grouped,
        out=out,
        **routing,
    )

    # Step 2 - low-rank down projection X @ A^T.
    # A is viewed per expert as [E, R, K] and transposed to [E, K, R] so it can
    # be used as an expert weight tensor. The result is always requested in
    # grouped layout (y_grouped=True) so that step 3 can read it as grouped input.
    a_as_weight = (
        lora_A.reshape(num_experts, rank, in_features).transpose(1, 2).contiguous()
    )
    down_proj = scatter2scatter(
        X=X,
        W=a_as_weight,
        k=k,
        x_grouped=x_grouped,
        y_grouped=True,
        **routing,
    )

    # Step 3 - low-rank up projection XA @ B^T.
    # B^T is viewed per expert as [E, R, N]. The input is the grouped output of
    # step 2, hence k=1 and x_grouped=True.
    b_as_weight = lora_B.T.reshape(num_experts, rank, out_features).contiguous()
    lora_out = scatter2scatter(
        X=down_proj,
        W=b_as_weight,
        k=1,
        x_grouped=True,
        y_grouped=y_grouped,
        **routing,
    )

    # Step 4 - accumulate the scaled LoRA term into the base output in place.
    return base_out.add_(lora_out, alpha=scaling)
|
||||
|
||||
|
||||
# Threshold for switching from fused to split LoRA forward.
|
||||
# Split wins when per-expert matmul is large (bandwidth-bound LoRA tile
|
||||
# loads dominate in the fused kernel's inner loop).
|
||||
# Empirically: split wins for E<=32 with K*N > 20M (e.g. Mixtral, Phi-MoE).
|
||||
# Minimum per-expert weight size, in elements (K * N with K = W.size(1) and
# N = W.size(2)), for the dispatcher to take the split path; compared with >=.
_SPLIT_LORA_FWD_THRESHOLD = 20_000_000 # per-expert K*N
|
||||
# Largest expert count E (W.size(0)) for which the dispatcher still considers
# the split path; compared with <=.
_SPLIT_LORA_FWD_MAX_EXPERTS = 32
|
||||
|
||||
|
||||
def scatter2scatter_lora(
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
@@ -548,7 +620,13 @@ def scatter2scatter_lora(
|
||||
out: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Fused scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
|
||||
Scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
|
||||
|
||||
Automatically selects between:
|
||||
- Fused kernel: single Triton kernel with LoRA in the inner loop.
|
||||
Best for many small experts (E>=64, small K*N).
|
||||
- Split dispatch: 3 separate scatter2scatter calls (base + XA + lora).
|
||||
Best for few large experts (E<=32, large K*N like Mixtral).
|
||||
|
||||
Args:
|
||||
X: Input [M, K] or [M*k, K] if x_grouped
|
||||
@@ -567,12 +645,23 @@ def scatter2scatter_lora(
|
||||
Returns:
|
||||
Y: Output [M*k, N]
|
||||
"""
|
||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||
|
||||
E = W.size(0)
|
||||
K = W.size(1)
|
||||
N = W.size(2)
|
||||
|
||||
# Dispatch: split for few large experts, fused for many small experts
|
||||
if (
|
||||
E <= _SPLIT_LORA_FWD_MAX_EXPERTS
|
||||
and K * N >= _SPLIT_LORA_FWD_THRESHOLD
|
||||
):
|
||||
return _scatter2scatter_lora_split(
|
||||
X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
|
||||
lora_A, lora_B, scaling, b, x_grouped, y_grouped, out,
|
||||
)
|
||||
|
||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||
|
||||
R = lora_A.size(0) // E
|
||||
|
||||
# Pad R to power of 2 for Triton tile size
|
||||
@@ -612,11 +701,9 @@ def scatter2scatter_lora(
|
||||
b_ptr,
|
||||
stride_be,
|
||||
stride_bn,
|
||||
# A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride
|
||||
lora_A,
|
||||
lora_A.stride(0),
|
||||
lora_A.stride(1),
|
||||
# B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride
|
||||
lora_B,
|
||||
lora_B.stride(0),
|
||||
lora_B.stride(1),
|
||||
@@ -627,9 +714,8 @@ def scatter2scatter_lora(
|
||||
K=K,
|
||||
N=N,
|
||||
E=E,
|
||||
ACTUAL_R=R, # True LoRA rank for weight indexing
|
||||
# BLOCK_M is autotuned (injected by triton.autotune from Config kwargs)
|
||||
BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16)
|
||||
ACTUAL_R=R,
|
||||
BLOCK_R=BLOCK_R,
|
||||
ACC_TYPE=tl.float32,
|
||||
scaling=scaling,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
|
||||
Reference in New Issue
Block a user