handle base+lora split kernel for older moe models
This commit is contained in:
@@ -134,7 +134,9 @@ def main():
|
|||||||
_clean()
|
_clean()
|
||||||
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(E, K, N, T, k, R)
|
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(E, K, N, T, k, R)
|
||||||
|
|
||||||
# Forward with LoRA
|
# Forward with LoRA (auto-dispatched: fused or split)
|
||||||
|
dispatch = "split" if (E <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
|
||||||
|
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD) else "fused"
|
||||||
t_fwd = _bench(lambda: lora_ops.scatter2scatter_lora(
|
t_fwd = _bench(lambda: lora_ops.scatter2scatter_lora(
|
||||||
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||||
k=k, lora_A=lA, lora_B=lB, scaling=2.0,
|
k=k, lora_A=lA, lora_B=lB, scaling=2.0,
|
||||||
@@ -162,7 +164,8 @@ def main():
|
|||||||
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
|
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
|
||||||
|
|
||||||
print(f" R={R:>2} {proj:<8} "
|
print(f" R={R:>2} {proj:<8} "
|
||||||
f"fwd={t_fwd:>6.2f}ms base={t_base:>6.2f}ms "
|
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
|
||||||
|
f"base={t_base:>6.2f}ms "
|
||||||
f"(+{overhead*100:.0f}%) "
|
f"(+{overhead*100:.0f}%) "
|
||||||
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
|
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
|
||||||
f"total={total:>6.2f}ms")
|
f"total={total:>6.2f}ms")
|
||||||
@@ -184,7 +187,17 @@ def main():
|
|||||||
lB_ag.grad = None
|
lB_ag.grad = None
|
||||||
|
|
||||||
t_full = _bench(_run_autograd)
|
t_full = _bench(_run_autograd)
|
||||||
print(f" full_fwd_bwd={t_full:>6.2f}ms")
|
|
||||||
|
# Memory measurement
|
||||||
|
_clean()
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
mem_before = torch.cuda.memory_allocated()
|
||||||
|
_run_autograd()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
mem_peak = torch.cuda.max_memory_allocated() - mem_before
|
||||||
|
|
||||||
|
print(f" full_fwd_bwd={t_full:>6.2f}ms "
|
||||||
|
f"peak_delta={mem_peak/1e6:>6.1f}MB")
|
||||||
|
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|||||||
@@ -533,6 +533,78 @@ def _scatter2scatter_lora(
|
|||||||
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
|
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
|
||||||
|
|
||||||
|
|
||||||
|
def _scatter2scatter_lora_split(
|
||||||
|
X: torch.Tensor,
|
||||||
|
W: torch.Tensor,
|
||||||
|
sorted_expert_idxs: torch.Tensor,
|
||||||
|
sorted_scattered_idxs: torch.Tensor,
|
||||||
|
k: int,
|
||||||
|
lora_A: torch.Tensor,
|
||||||
|
lora_B: torch.Tensor,
|
||||||
|
scaling: float,
|
||||||
|
b: Optional[torch.Tensor] = None,
|
||||||
|
x_grouped: bool = False,
|
||||||
|
y_grouped: bool = False,
|
||||||
|
out: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Split base+LoRA forward: 3 scatter2scatter calls, no fused LoRA kernel.
|
||||||
|
|
||||||
|
Faster for models with few large experts (e.g. Mixtral E=8, I=14336)
|
||||||
|
because the base kernel runs at full speed without LoRA SMEM overhead,
|
||||||
|
and the LoRA matmuls (R=16) are tiny separate passes.
|
||||||
|
|
||||||
|
Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T)
|
||||||
|
"""
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import (
|
||||||
|
scatter2scatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
E = W.size(0)
|
||||||
|
R = lora_A.size(0) // E
|
||||||
|
K = W.size(1)
|
||||||
|
N = W.size(2)
|
||||||
|
|
||||||
|
# 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes)
|
||||||
|
output = scatter2scatter(
|
||||||
|
X=X, W=W, b=b,
|
||||||
|
sorted_expert_idxs=sorted_expert_idxs,
|
||||||
|
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||||
|
k=k, x_grouped=x_grouped, y_grouped=y_grouped, out=out,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. XA = X @ A^T (tiny: output is [M*k, R])
|
||||||
|
# Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter)
|
||||||
|
W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous()
|
||||||
|
XA = scatter2scatter(
|
||||||
|
X=X, W=W_A,
|
||||||
|
sorted_expert_idxs=sorted_expert_idxs,
|
||||||
|
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||||
|
k=k, x_grouped=x_grouped, y_grouped=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Y_lora = XA @ B^T (R is tiny, so this is very fast)
|
||||||
|
# Reshape B: [N, R*E] → [E, R, N]
|
||||||
|
W_B = lora_B.T.reshape(E, R, N).contiguous()
|
||||||
|
Y_lora = scatter2scatter(
|
||||||
|
X=XA, W=W_B,
|
||||||
|
sorted_expert_idxs=sorted_expert_idxs,
|
||||||
|
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||||
|
k=1, x_grouped=True, y_grouped=y_grouped,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Y = Y_base + scaling * Y_lora
|
||||||
|
output.add_(Y_lora, alpha=scaling)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
# Threshold for switching from fused to split LoRA forward.
|
||||||
|
# Split wins when per-expert matmul is large (bandwidth-bound LoRA tile
|
||||||
|
# loads dominate in the fused kernel's inner loop).
|
||||||
|
# Empirically: split wins for E<=32 with K*N > 20M (e.g. Mixtral, Phi-MoE).
|
||||||
|
_SPLIT_LORA_FWD_THRESHOLD = 20_000_000 # per-expert K*N
|
||||||
|
_SPLIT_LORA_FWD_MAX_EXPERTS = 32
|
||||||
|
|
||||||
|
|
||||||
def scatter2scatter_lora(
|
def scatter2scatter_lora(
|
||||||
X: torch.Tensor,
|
X: torch.Tensor,
|
||||||
W: torch.Tensor,
|
W: torch.Tensor,
|
||||||
@@ -548,7 +620,13 @@ def scatter2scatter_lora(
|
|||||||
out: Optional[torch.Tensor] = None,
|
out: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Fused scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
|
Scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
|
||||||
|
|
||||||
|
Automatically selects between:
|
||||||
|
- Fused kernel: single Triton kernel with LoRA in the inner loop.
|
||||||
|
Best for many small experts (E>=64, small K*N).
|
||||||
|
- Split dispatch: 3 separate scatter2scatter calls (base + XA + lora).
|
||||||
|
Best for few large experts (E<=32, large K*N like Mixtral).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
X: Input [M, K] or [M*k, K] if x_grouped
|
X: Input [M, K] or [M*k, K] if x_grouped
|
||||||
@@ -567,12 +645,23 @@ def scatter2scatter_lora(
|
|||||||
Returns:
|
Returns:
|
||||||
Y: Output [M*k, N]
|
Y: Output [M*k, N]
|
||||||
"""
|
"""
|
||||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
|
||||||
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
|
||||||
|
|
||||||
E = W.size(0)
|
E = W.size(0)
|
||||||
K = W.size(1)
|
K = W.size(1)
|
||||||
N = W.size(2)
|
N = W.size(2)
|
||||||
|
|
||||||
|
# Dispatch: split for few large experts, fused for many small experts
|
||||||
|
if (
|
||||||
|
E <= _SPLIT_LORA_FWD_MAX_EXPERTS
|
||||||
|
and K * N >= _SPLIT_LORA_FWD_THRESHOLD
|
||||||
|
):
|
||||||
|
return _scatter2scatter_lora_split(
|
||||||
|
X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
|
||||||
|
lora_A, lora_B, scaling, b, x_grouped, y_grouped, out,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||||
|
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||||
|
|
||||||
R = lora_A.size(0) // E
|
R = lora_A.size(0) // E
|
||||||
|
|
||||||
# Pad R to power of 2 for Triton tile size
|
# Pad R to power of 2 for Triton tile size
|
||||||
@@ -612,11 +701,9 @@ def scatter2scatter_lora(
|
|||||||
b_ptr,
|
b_ptr,
|
||||||
stride_be,
|
stride_be,
|
||||||
stride_bn,
|
stride_bn,
|
||||||
# A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride
|
|
||||||
lora_A,
|
lora_A,
|
||||||
lora_A.stride(0),
|
lora_A.stride(0),
|
||||||
lora_A.stride(1),
|
lora_A.stride(1),
|
||||||
# B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride
|
|
||||||
lora_B,
|
lora_B,
|
||||||
lora_B.stride(0),
|
lora_B.stride(0),
|
||||||
lora_B.stride(1),
|
lora_B.stride(1),
|
||||||
@@ -627,9 +714,8 @@ def scatter2scatter_lora(
|
|||||||
K=K,
|
K=K,
|
||||||
N=N,
|
N=N,
|
||||||
E=E,
|
E=E,
|
||||||
ACTUAL_R=R, # True LoRA rank for weight indexing
|
ACTUAL_R=R,
|
||||||
# BLOCK_M is autotuned (injected by triton.autotune from Config kwargs)
|
BLOCK_R=BLOCK_R,
|
||||||
BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16)
|
|
||||||
ACC_TYPE=tl.float32,
|
ACC_TYPE=tl.float32,
|
||||||
scaling=scaling,
|
scaling=scaling,
|
||||||
allow_tf32=ALLOW_TF32,
|
allow_tf32=ALLOW_TF32,
|
||||||
|
|||||||
@@ -294,6 +294,34 @@ class TestScatterMoELoRAAutograd:
|
|||||||
assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero"
|
assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero"
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_matches_fused(self):
|
||||||
|
"""Split dispatch (for few large experts) matches fused kernel."""
|
||||||
|
# Use a shape where split would be dispatched (large K*N, few E)
|
||||||
|
E, K, N, T, k, R = 8, 512, 1024, 128, 2, 16
|
||||||
|
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||||
|
|
||||||
|
# Force fused path
|
||||||
|
orig = lora_ops._SPLIT_LORA_FWD_THRESHOLD
|
||||||
|
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 10**18
|
||||||
|
out_fused = lora_ops.scatter2scatter_lora(
|
||||||
|
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||||
|
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Force split path
|
||||||
|
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 0
|
||||||
|
out_split = lora_ops.scatter2scatter_lora(
|
||||||
|
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||||
|
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
|
||||||
|
)
|
||||||
|
lora_ops._SPLIT_LORA_FWD_THRESHOLD = orig
|
||||||
|
|
||||||
|
norm_err = (
|
||||||
|
(out_fused.float() - out_split.float()).norm()
|
||||||
|
/ (out_fused.float().norm() + 1e-6)
|
||||||
|
).item()
|
||||||
|
assert norm_err < 0.01, f"split vs fused norm_err={norm_err}"
|
||||||
|
|
||||||
def test_scaling_zero_gives_base_only(self):
|
def test_scaling_zero_gives_base_only(self):
|
||||||
"""With scaling=0.0, LoRA contribution vanishes. Output = X@W."""
|
"""With scaling=0.0, LoRA contribution vanishes. Output = X@W."""
|
||||||
E, K, N, T, k, R = 16, 64, 32, 32, 2, 4
|
E, K, N, T, k, R = 16, 64, 32, 32, 2, 4
|
||||||
|
|||||||
Reference in New Issue
Block a user