handle base+lora split kernel for older moe models
This commit is contained in:
@@ -134,7 +134,9 @@ def main():
|
||||
_clean()
|
||||
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(E, K, N, T, k, R)
|
||||
|
||||
# Forward with LoRA
|
||||
# Forward with LoRA (auto-dispatched: fused or split)
|
||||
dispatch = "split" if (E <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
|
||||
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD) else "fused"
|
||||
t_fwd = _bench(lambda: lora_ops.scatter2scatter_lora(
|
||||
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||
k=k, lora_A=lA, lora_B=lB, scaling=2.0,
|
||||
@@ -162,7 +164,8 @@ def main():
|
||||
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
|
||||
|
||||
print(f" R={R:>2} {proj:<8} "
|
||||
f"fwd={t_fwd:>6.2f}ms base={t_base:>6.2f}ms "
|
||||
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
|
||||
f"base={t_base:>6.2f}ms "
|
||||
f"(+{overhead*100:.0f}%) "
|
||||
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
|
||||
f"total={total:>6.2f}ms")
|
||||
@@ -184,7 +187,17 @@ def main():
|
||||
lB_ag.grad = None
|
||||
|
||||
t_full = _bench(_run_autograd)
|
||||
print(f" full_fwd_bwd={t_full:>6.2f}ms")
|
||||
|
||||
# Memory measurement
|
||||
_clean()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
mem_before = torch.cuda.memory_allocated()
|
||||
_run_autograd()
|
||||
torch.cuda.synchronize()
|
||||
mem_peak = torch.cuda.max_memory_allocated() - mem_before
|
||||
|
||||
print(f" full_fwd_bwd={t_full:>6.2f}ms "
|
||||
f"peak_delta={mem_peak/1e6:>6.1f}MB")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
@@ -533,6 +533,78 @@ def _scatter2scatter_lora(
|
||||
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
|
||||
|
||||
|
||||
def _scatter2scatter_lora_split(
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
sorted_expert_idxs: torch.Tensor,
|
||||
sorted_scattered_idxs: torch.Tensor,
|
||||
k: int,
|
||||
lora_A: torch.Tensor,
|
||||
lora_B: torch.Tensor,
|
||||
scaling: float,
|
||||
b: Optional[torch.Tensor] = None,
|
||||
x_grouped: bool = False,
|
||||
y_grouped: bool = False,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Split base+LoRA forward: 3 scatter2scatter calls, no fused LoRA kernel.
|
||||
|
||||
Faster for models with few large experts (e.g. Mixtral E=8, I=14336)
|
||||
because the base kernel runs at full speed without LoRA SMEM overhead,
|
||||
and the LoRA matmuls (R=16) are tiny separate passes.
|
||||
|
||||
Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T)
|
||||
"""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import (
|
||||
scatter2scatter,
|
||||
)
|
||||
|
||||
E = W.size(0)
|
||||
R = lora_A.size(0) // E
|
||||
K = W.size(1)
|
||||
N = W.size(2)
|
||||
|
||||
# 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes)
|
||||
output = scatter2scatter(
|
||||
X=X, W=W, b=b,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=k, x_grouped=x_grouped, y_grouped=y_grouped, out=out,
|
||||
)
|
||||
|
||||
# 2. XA = X @ A^T (tiny: output is [M*k, R])
|
||||
# Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter)
|
||||
W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous()
|
||||
XA = scatter2scatter(
|
||||
X=X, W=W_A,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=k, x_grouped=x_grouped, y_grouped=True,
|
||||
)
|
||||
|
||||
# 3. Y_lora = XA @ B^T (R is tiny, so this is very fast)
|
||||
# Reshape B: [N, R*E] → [E, R, N]
|
||||
W_B = lora_B.T.reshape(E, R, N).contiguous()
|
||||
Y_lora = scatter2scatter(
|
||||
X=XA, W=W_B,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=1, x_grouped=True, y_grouped=y_grouped,
|
||||
)
|
||||
|
||||
# 4. Y = Y_base + scaling * Y_lora
|
||||
output.add_(Y_lora, alpha=scaling)
|
||||
return output
|
||||
|
||||
|
||||
# Threshold for switching from fused to split LoRA forward.
|
||||
# Split wins when per-expert matmul is large (bandwidth-bound LoRA tile
|
||||
# loads dominate in the fused kernel's inner loop).
|
||||
# Empirically: split wins for E<=32 with K*N > 20M (e.g. Mixtral, Phi-MoE).
|
||||
_SPLIT_LORA_FWD_THRESHOLD = 20_000_000 # per-expert K*N
|
||||
_SPLIT_LORA_FWD_MAX_EXPERTS = 32
|
||||
|
||||
|
||||
def scatter2scatter_lora(
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
@@ -548,7 +620,13 @@ def scatter2scatter_lora(
|
||||
out: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Fused scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
|
||||
Scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
|
||||
|
||||
Automatically selects between:
|
||||
- Fused kernel: single Triton kernel with LoRA in the inner loop.
|
||||
Best for many small experts (E>=64, small K*N).
|
||||
- Split dispatch: 3 separate scatter2scatter calls (base + XA + lora).
|
||||
Best for few large experts (E<=32, large K*N like Mixtral).
|
||||
|
||||
Args:
|
||||
X: Input [M, K] or [M*k, K] if x_grouped
|
||||
@@ -567,12 +645,23 @@ def scatter2scatter_lora(
|
||||
Returns:
|
||||
Y: Output [M*k, N]
|
||||
"""
|
||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||
|
||||
E = W.size(0)
|
||||
K = W.size(1)
|
||||
N = W.size(2)
|
||||
|
||||
# Dispatch: split for few large experts, fused for many small experts
|
||||
if (
|
||||
E <= _SPLIT_LORA_FWD_MAX_EXPERTS
|
||||
and K * N >= _SPLIT_LORA_FWD_THRESHOLD
|
||||
):
|
||||
return _scatter2scatter_lora_split(
|
||||
X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
|
||||
lora_A, lora_B, scaling, b, x_grouped, y_grouped, out,
|
||||
)
|
||||
|
||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||
|
||||
R = lora_A.size(0) // E
|
||||
|
||||
# Pad R to power of 2 for Triton tile size
|
||||
@@ -612,11 +701,9 @@ def scatter2scatter_lora(
|
||||
b_ptr,
|
||||
stride_be,
|
||||
stride_bn,
|
||||
# A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride
|
||||
lora_A,
|
||||
lora_A.stride(0),
|
||||
lora_A.stride(1),
|
||||
# B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride
|
||||
lora_B,
|
||||
lora_B.stride(0),
|
||||
lora_B.stride(1),
|
||||
@@ -627,9 +714,8 @@ def scatter2scatter_lora(
|
||||
K=K,
|
||||
N=N,
|
||||
E=E,
|
||||
ACTUAL_R=R, # True LoRA rank for weight indexing
|
||||
# BLOCK_M is autotuned (injected by triton.autotune from Config kwargs)
|
||||
BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16)
|
||||
ACTUAL_R=R,
|
||||
BLOCK_R=BLOCK_R,
|
||||
ACC_TYPE=tl.float32,
|
||||
scaling=scaling,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
|
||||
@@ -294,6 +294,34 @@ class TestScatterMoELoRAAutograd:
|
||||
assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero"
|
||||
|
||||
|
||||
def test_split_matches_fused(self):
|
||||
"""Split dispatch (for few large experts) matches fused kernel."""
|
||||
# Use a shape where split would be dispatched (large K*N, few E)
|
||||
E, K, N, T, k, R = 8, 512, 1024, 128, 2, 16
|
||||
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||
|
||||
# Force fused path
|
||||
orig = lora_ops._SPLIT_LORA_FWD_THRESHOLD
|
||||
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 10**18
|
||||
out_fused = lora_ops.scatter2scatter_lora(
|
||||
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
|
||||
)
|
||||
|
||||
# Force split path
|
||||
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 0
|
||||
out_split = lora_ops.scatter2scatter_lora(
|
||||
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
|
||||
)
|
||||
lora_ops._SPLIT_LORA_FWD_THRESHOLD = orig
|
||||
|
||||
norm_err = (
|
||||
(out_fused.float() - out_split.float()).norm()
|
||||
/ (out_fused.float().norm() + 1e-6)
|
||||
).item()
|
||||
assert norm_err < 0.01, f"split vs fused norm_err={norm_err}"
|
||||
|
||||
def test_scaling_zero_gives_base_only(self):
|
||||
"""With scaling=0.0, LoRA contribution vanishes. Output = X@W."""
|
||||
E, K, N, T, k, R = 16, 64, 32, 32, 2, 4
|
||||
|
||||
Reference in New Issue
Block a user