def _scatter2scatter_lora_split(
    X: torch.Tensor,
    W: torch.Tensor,
    sorted_expert_idxs: torch.Tensor,
    sorted_scattered_idxs: torch.Tensor,
    k: int,
    lora_A: torch.Tensor,
    lora_B: torch.Tensor,
    scaling: float,
    b: Optional[torch.Tensor] = None,
    x_grouped: bool = False,
    y_grouped: bool = False,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Split base+LoRA forward: 3 scatter2scatter calls, no fused LoRA kernel.

    Computes ``Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T)``.

    Intended for models with few large experts (e.g. Mixtral E=8, I=14336):
    the base kernel runs without the fused kernel's LoRA SMEM overhead, and
    the two rank-R matmuls are tiny separate passes.

    Args:
        X: Input, [M, K] or [M*k, K] if ``x_grouped``.
        W: Base expert weights, [E, K, N].
        sorted_expert_idxs: Expert index per routed row, sorted by expert.
        sorted_scattered_idxs: Scatter index per routed row, same length as
            ``sorted_expert_idxs``.
        k: Number of experts each input row is routed to.
        lora_A: LoRA A weights, [R*E, K].
        lora_B: LoRA B weights, [N, R*E].
        scaling: Multiplier applied to the LoRA term before it is added.
        b: Optional bias, forwarded to the base ``scatter2scatter`` call only.
        x_grouped: Forwarded to ``scatter2scatter`` for the calls that read X.
        y_grouped: Forwarded to ``scatter2scatter`` for the calls that produce
            the final-layout outputs (base and LoRA-B).
        out: Optional output buffer forwarded to the base call. The LoRA term
            is accumulated in place into whatever tensor the base call returns.

    Returns:
        The tensor returned by the base call with the scaled LoRA term added
        in place; [M*k, N] per the dispatcher's contract.

    Raises:
        AssertionError: If the index tensors disagree in length, if their
            length is not ``X.size(0) * k``, or if ``lora_A.size(0)`` is not a
            multiple of the number of experts.
    """
    # Same sanity checks the fused path performs. The dispatcher returns into
    # this function before reaching its own asserts, so they are repeated here
    # (and run before any kernel work) to keep both paths equally strict.
    num_routed = sorted_scattered_idxs.size(0)
    assert num_routed == sorted_expert_idxs.size(0), (
        "sorted_scattered_idxs and sorted_expert_idxs must have the same length"
    )
    assert num_routed == X.size(0) * k, (
        "sorted_scattered_idxs length must equal X.size(0) * k"
    )

    E = W.size(0)
    K = W.size(1)
    N = W.size(2)

    # Fail with a clear message instead of an opaque reshape error below.
    assert lora_A.size(0) % E == 0, (
        "lora_A.size(0) must be a multiple of the number of experts"
    )
    R = lora_A.size(0) // E

    # Imported at call time, as in the original; NOTE(review): presumably to
    # avoid an import cycle between the kernel modules -- confirm.
    from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import (
        scatter2scatter,
    )

    # 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes)
    output = scatter2scatter(
        X=X,
        W=W,
        b=b,
        sorted_expert_idxs=sorted_expert_idxs,
        sorted_scattered_idxs=sorted_scattered_idxs,
        k=k,
        x_grouped=x_grouped,
        y_grouped=y_grouped,
        out=out,
    )

    # 2. XA = X @ A^T (tiny: output is [M*k, R])
    # Reshape A: [R*E, K] -> [E, K, R] (expert weights for scatter2scatter)
    W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous()
    XA = scatter2scatter(
        X=X,
        W=W_A,
        sorted_expert_idxs=sorted_expert_idxs,
        sorted_scattered_idxs=sorted_scattered_idxs,
        k=k,
        x_grouped=x_grouped,
        # Keep the intermediate grouped so step 3 can consume it directly.
        y_grouped=True,
    )

    # 3. Y_lora = XA @ B^T (R is tiny, so this is very fast)
    # Reshape B: [N, R*E] -> [E, R, N]
    W_B = lora_B.T.reshape(E, R, N).contiguous()
    Y_lora = scatter2scatter(
        X=XA,
        W=W_B,
        sorted_expert_idxs=sorted_expert_idxs,
        sorted_scattered_idxs=sorted_scattered_idxs,
        # XA already has one row per routed (token, expert) pair, so k=1.
        k=1,
        x_grouped=True,
        y_grouped=y_grouped,
    )

    # 4. Y = Y_base + scaling * Y_lora (in place on the base output)
    output.add_(Y_lora, alpha=scaling)
    return output


# Thresholds for switching from the fused to the split LoRA forward.
# Split wins when the per-expert matmul is large (bandwidth-bound LoRA tile
# loads dominate in the fused kernel's inner loop).
# Empirically: split wins for E<=32 with K*N >= 20M (e.g. Mixtral, Phi-MoE).
_SPLIT_LORA_FWD_THRESHOLD = 20_000_000  # per-expert K*N
_SPLIT_LORA_FWD_MAX_EXPERTS = 32
def test_split_matches_fused(self):
    """Split dispatch and the fused kernel agree on the same inputs.

    The shape used here (K*N = 512*1024) is below the default split
    threshold, so neither path is chosen "naturally": each one is forced by
    temporarily overriding ``lora_ops._SPLIT_LORA_FWD_THRESHOLD``. E=8 stays
    within ``_SPLIT_LORA_FWD_MAX_EXPERTS``, so a threshold of 0 selects split.
    """
    E, K, N, T, k, R = 8, 512, 1024, 128, 2, 16
    x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)

    def _forward():
        return lora_ops.scatter2scatter_lora(
            X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
            k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
        )

    orig = lora_ops._SPLIT_LORA_FWD_THRESHOLD
    try:
        # Force fused path: no K*N can reach this threshold.
        lora_ops._SPLIT_LORA_FWD_THRESHOLD = 10**18
        out_fused = _forward()

        # Force split path: every K*N is >= 0.
        lora_ops._SPLIT_LORA_FWD_THRESHOLD = 0
        out_split = _forward()
    finally:
        # Always restore the module global, even if a kernel call raises;
        # otherwise the override leaks into every test that runs afterwards.
        lora_ops._SPLIT_LORA_FWD_THRESHOLD = orig

    norm_err = (
        (out_fused.float() - out_split.float()).norm()
        / (out_fused.float().norm() + 1e-6)
    ).item()
    assert norm_err < 0.01, f"split vs fused norm_err={norm_err}"