handle base+lora split kernel for older moe models

This commit is contained in:
Wing Lian
2026-03-19 07:11:30 +00:00
parent 66fea258c7
commit 31d8d068bb
3 changed files with 139 additions and 12 deletions

View File

@@ -134,7 +134,9 @@ def main():
_clean() _clean()
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(E, K, N, T, k, R) x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(E, K, N, T, k, R)
# Forward with LoRA # Forward with LoRA (auto-dispatched: fused or split)
dispatch = "split" if (E <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD) else "fused"
t_fwd = _bench(lambda: lora_ops.scatter2scatter_lora( t_fwd = _bench(lambda: lora_ops.scatter2scatter_lora(
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
k=k, lora_A=lA, lora_B=lB, scaling=2.0, k=k, lora_A=lA, lora_B=lB, scaling=2.0,
@@ -162,7 +164,8 @@ def main():
overhead = t_fwd / t_base - 1 if t_base > 0 else 0 overhead = t_fwd / t_base - 1 if t_base > 0 else 0
print(f" R={R:>2} {proj:<8} " print(f" R={R:>2} {proj:<8} "
f"fwd={t_fwd:>6.2f}ms base={t_base:>6.2f}ms " f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
f"base={t_base:>6.2f}ms "
f"(+{overhead*100:.0f}%) " f"(+{overhead*100:.0f}%) "
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms " f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
f"total={total:>6.2f}ms") f"total={total:>6.2f}ms")
@@ -184,7 +187,17 @@ def main():
lB_ag.grad = None lB_ag.grad = None
t_full = _bench(_run_autograd) t_full = _bench(_run_autograd)
print(f" full_fwd_bwd={t_full:>6.2f}ms")
# Memory measurement
_clean()
torch.cuda.reset_peak_memory_stats()
mem_before = torch.cuda.memory_allocated()
_run_autograd()
torch.cuda.synchronize()
mem_peak = torch.cuda.max_memory_allocated() - mem_before
print(f" full_fwd_bwd={t_full:>6.2f}ms "
f"peak_delta={mem_peak/1e6:>6.1f}MB")
print() print()

View File

@@ -533,6 +533,78 @@ def _scatter2scatter_lora(
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :]) tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
def _scatter2scatter_lora_split(
    X: torch.Tensor,
    W: torch.Tensor,
    sorted_expert_idxs: torch.Tensor,
    sorted_scattered_idxs: torch.Tensor,
    k: int,
    lora_A: torch.Tensor,
    lora_B: torch.Tensor,
    scaling: float,
    b: Optional[torch.Tensor] = None,
    x_grouped: bool = False,
    y_grouped: bool = False,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute the base + LoRA forward as three plain ``scatter2scatter`` calls.

    Instead of the fused LoRA kernel, the result is assembled as

        Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T)

    The author's stated motivation is models with few, large experts
    (e.g. Mixtral, E=8, I=14336), where the base kernel is not slowed down
    by LoRA tiles and the two rank-R passes are comparatively small.

    Args:
        X: Input activations, forwarded unchanged to the base and ``X @ A^T``
            passes.
        W: Expert weights; dimensions 0, 1, 2 are read as E, K, N.
        sorted_expert_idxs: Routing tensor forwarded to every pass.
        sorted_scattered_idxs: Routing tensor forwarded to every pass.
        k: Top-k value used for the base and ``X @ A^T`` passes.
        lora_A: LoRA A weights, reshaped here from ``[R*E, K]``.
        lora_B: LoRA B weights, transposed and reshaped here from ``[N, R*E]``.
        scaling: Multiplier applied to the LoRA term when it is accumulated.
        b: Optional bias, passed to the base pass only.
        x_grouped: Forwarded to the base and ``X @ A^T`` passes.
        y_grouped: Forwarded to the base and final LoRA passes.
        out: Optional output buffer, passed to the base pass only.

    Returns:
        The tensor returned by the base pass, with ``scaling * Y_lora`` added
        to it in place.
    """
    # Imported lazily inside the function, exactly as the original does.
    from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import (
        scatter2scatter,
    )

    num_experts = W.size(0)
    rank = lora_A.size(0) // num_experts
    in_features = W.size(1)
    out_features = W.size(2)

    # The same routing tensors drive all three passes.
    routing = {
        "sorted_expert_idxs": sorted_expert_idxs,
        "sorted_scattered_idxs": sorted_scattered_idxs,
    }

    # Pass 1 -- base projection X @ W; the only pass that sees `b` and `out`.
    base_out = scatter2scatter(
        X=X,
        W=W,
        b=b,
        k=k,
        x_grouped=x_grouped,
        y_grouped=y_grouped,
        out=out,
        **routing,
    )

    # Pass 2 -- X @ A^T. A is viewed per expert as [E, R, K] and transposed to
    # [E, K, R] so it can stand in as the expert weight. The result is kept
    # grouped (y_grouped=True) because pass 3 consumes it as grouped input.
    a_as_weight = (
        lora_A.reshape(num_experts, rank, in_features)
        .permute(0, 2, 1)
        .contiguous()
    )
    x_times_a = scatter2scatter(
        X=X,
        W=a_as_weight,
        k=k,
        x_grouped=x_grouped,
        y_grouped=True,
        **routing,
    )

    # Pass 3 -- (X @ A^T) @ B^T. B is [N, R*E]; its transpose is viewed per
    # expert as [E, R, N]. The input is already expanded and grouped, hence
    # k=1 and x_grouped=True.
    b_as_weight = lora_B.T.reshape(num_experts, rank, out_features).contiguous()
    lora_out = scatter2scatter(
        X=x_times_a,
        W=b_as_weight,
        k=1,
        x_grouped=True,
        y_grouped=y_grouped,
        **routing,
    )

    # Accumulate the scaled LoRA term into the base output in place.
    base_out.add_(lora_out, alpha=scaling)
    return base_out
# Thresholds for switching from the fused to the split LoRA forward.
# Rationale (from the original author): split wins when the per-expert matmul
# is large, because bandwidth-bound LoRA tile loads dominate in the fused
# kernel's inner loop. Empirically this was observed for few-expert models
# with large K*N (e.g. Mixtral, Phi-MoE).
# The dispatcher takes the split path when
#   E <= _SPLIT_LORA_FWD_MAX_EXPERTS and K * N >= _SPLIT_LORA_FWD_THRESHOLD
# and falls back to the fused kernel otherwise.
_SPLIT_LORA_FWD_THRESHOLD = 20_000_000  # minimum per-expert K*N (inclusive)
_SPLIT_LORA_FWD_MAX_EXPERTS = 32  # maximum expert count E (inclusive)
def scatter2scatter_lora( def scatter2scatter_lora(
X: torch.Tensor, X: torch.Tensor,
W: torch.Tensor, W: torch.Tensor,
@@ -548,7 +620,13 @@ def scatter2scatter_lora(
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Fused scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e] Scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
Automatically selects between:
- Fused kernel: single Triton kernel with LoRA in the inner loop.
Best for many small experts (E>=64, small K*N).
- Split dispatch: 3 separate scatter2scatter calls (base + XA + lora).
Best for few large experts (E<=32, large K*N like Mixtral).
Args: Args:
X: Input [M, K] or [M*k, K] if x_grouped X: Input [M, K] or [M*k, K] if x_grouped
@@ -567,12 +645,23 @@ def scatter2scatter_lora(
Returns: Returns:
Y: Output [M*k, N] Y: Output [M*k, N]
""" """
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
E = W.size(0) E = W.size(0)
K = W.size(1) K = W.size(1)
N = W.size(2) N = W.size(2)
# Dispatch: split for few large experts, fused for many small experts
if (
E <= _SPLIT_LORA_FWD_MAX_EXPERTS
and K * N >= _SPLIT_LORA_FWD_THRESHOLD
):
return _scatter2scatter_lora_split(
X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
lora_A, lora_B, scaling, b, x_grouped, y_grouped, out,
)
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
R = lora_A.size(0) // E R = lora_A.size(0) // E
# Pad R to power of 2 for Triton tile size # Pad R to power of 2 for Triton tile size
@@ -612,11 +701,9 @@ def scatter2scatter_lora(
b_ptr, b_ptr,
stride_be, stride_be,
stride_bn, stride_bn,
# A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride
lora_A, lora_A,
lora_A.stride(0), lora_A.stride(0),
lora_A.stride(1), lora_A.stride(1),
# B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride
lora_B, lora_B,
lora_B.stride(0), lora_B.stride(0),
lora_B.stride(1), lora_B.stride(1),
@@ -627,9 +714,8 @@ def scatter2scatter_lora(
K=K, K=K,
N=N, N=N,
E=E, E=E,
ACTUAL_R=R, # True LoRA rank for weight indexing ACTUAL_R=R,
# BLOCK_M is autotuned (injected by triton.autotune from Config kwargs) BLOCK_R=BLOCK_R,
BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16)
ACC_TYPE=tl.float32, ACC_TYPE=tl.float32,
scaling=scaling, scaling=scaling,
allow_tf32=ALLOW_TF32, allow_tf32=ALLOW_TF32,

View File

@@ -294,6 +294,34 @@ class TestScatterMoELoRAAutograd:
assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero" assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero"
def test_split_matches_fused(self):
    """Split dispatch and the fused kernel agree on the same inputs.

    Both paths are reached through ``lora_ops.scatter2scatter_lora`` by
    temporarily overriding the module-level ``_SPLIT_LORA_FWD_THRESHOLD``.
    The override is restored in a ``finally`` block so that a failure in
    either kernel call cannot leak the patched threshold into other tests.
    """
    # K*N here (512 * 1024) is below the real dispatch threshold, so the
    # path is selected purely by patching the threshold; E=8 satisfies the
    # expert-count condition of the split dispatch.
    E, K, N, T, k, R = 8, 512, 1024, 128, 2, 16
    x, W, lA, lB, sei, ssi, _eo = _setup(E, K, N, T, k, R)

    def _forward():
        # Identical call for both paths; only the patched threshold differs.
        return lora_ops.scatter2scatter_lora(
            X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
            k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
        )

    orig = lora_ops._SPLIT_LORA_FWD_THRESHOLD
    try:
        # Force fused path: no K*N can reach this threshold.
        lora_ops._SPLIT_LORA_FWD_THRESHOLD = 10**18
        out_fused = _forward()

        # Force split path: every K*N satisfies a threshold of zero.
        lora_ops._SPLIT_LORA_FWD_THRESHOLD = 0
        out_split = _forward()
    finally:
        # Always restore the module global, even if a kernel call raised.
        lora_ops._SPLIT_LORA_FWD_THRESHOLD = orig

    # Relative Frobenius-norm error, computed in float32.
    norm_err = (
        (out_fused.float() - out_split.float()).norm()
        / (out_fused.float().norm() + 1e-6)
    ).item()
    assert norm_err < 0.01, f"split vs fused norm_err={norm_err}"
def test_scaling_zero_gives_base_only(self): def test_scaling_zero_gives_base_only(self):
"""With scaling=0.0, LoRA contribution vanishes. Output = X@W.""" """With scaling=0.0, LoRA contribution vanishes. Output = X@W."""
E, K, N, T, k, R = 16, 64, 32, 32, 2, 4 E, K, N, T, k, R = 16, 64, 32, 32, 2, 4