diff --git a/benchmarks/bench_scattermoe_lora.py b/benchmarks/bench_scattermoe_lora.py index a1d80f598..0fb3ba68c 100644 --- a/benchmarks/bench_scattermoe_lora.py +++ b/benchmarks/bench_scattermoe_lora.py @@ -12,14 +12,14 @@ Usage: import argparse import gc -import statistics import time +from functools import partial import torch from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import ( - ops as base_ops, lora_ops, + ops as base_ops, ) from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import ( flatten_sort_count, @@ -36,7 +36,7 @@ ITERS = 20 # ─── Model configs ────────────────────────────────────────────────────────── BUILTIN_CONFIGS = { - "Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k + "Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k "Qwen3-30B-A3B": (128, 2048, 768, 8), "OLMoE-1B-7B": (64, 2048, 1024, 8), "Mixtral-8x7B": (8, 4096, 14336, 2), @@ -50,26 +50,32 @@ def _resolve_config(spec): if key in name.lower() or name.lower() in key: return name, cfg - # Try HuggingFace AutoConfig from transformers import AutoConfig + hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True) if callable(getattr(hf_cfg, "get_text_config", None)): tc = hf_cfg.get_text_config() if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type: hf_cfg = tc - H = hf_cfg.hidden_size - I = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size - E = (getattr(hf_cfg, "num_experts", None) - or getattr(hf_cfg, "num_local_experts", None) - or getattr(hf_cfg, "n_routed_experts", None)) - k = (getattr(hf_cfg, "num_experts_per_tok", None) - or getattr(hf_cfg, "num_experts_per_token", None) or 2) + hidden = hf_cfg.hidden_size + inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size + experts = ( + getattr(hf_cfg, "num_experts", None) + or getattr(hf_cfg, "num_local_experts", None) + or getattr(hf_cfg, "n_routed_experts", None) + ) + top_k = ( + getattr(hf_cfg, "num_experts_per_tok", None) + or getattr(hf_cfg, "num_experts_per_token", None) + or 2 + ) name = spec.split("/")[-1] - return name, (E, H, I, k) + return name, (experts, hidden, inter, top_k) # ─── Benchmark helpers ────────────────────────────────────────────────────── + def _clean(): gc.collect() torch.cuda.empty_cache() @@ -87,29 +93,88 @@ def _bench(fn, warmup=WARMUP, iters=ITERS): fn() torch.cuda.synchronize() times.append((time.perf_counter() - t0) * 1000) - return statistics.median(times) + times.sort() + return times[len(times) // 2] -def _setup(E, K, N, T, top_k, R): +def _setup(num_experts, K, N, T, top_k, R): torch.manual_seed(42) x = torch.randn(T, K, device=DEVICE, dtype=DTYPE) - W = torch.randn(E, K, N, device=DEVICE, dtype=DTYPE) * 0.02 - lora_A = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) * 0.01 - lora_B = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) * 0.01 - logits = torch.randn(T, E, device=DEVICE) + W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02 + lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01 + lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01 + logits = torch.randn(T, num_experts, device=DEVICE) _, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1) - sei, ssi, eo = flatten_sort_count(top_idx, E) + sei, ssi, eo = flatten_sort_count(top_idx, num_experts) gx = base_ops.group(x, ssi, fan_out=top_k) dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE) return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy +# ─── Kernel wrappers (avoid B023 loop-variable capture) ────────────────────── + + +def _call_fwd(x, W, sei, ssi, top_k, lA, lB): + return lora_ops.scatter2scatter_lora( + X=x, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=top_k, + lora_A=lA, + lora_B=lB, + scaling=2.0, + ) + + +def _call_base(x, W, sei, ssi, top_k): + return base_ops.scatter2scatter( + X=x, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=top_k, + ) + + +def _call_dx(dy, W, sei, ssi, lA, lB): + return lora_ops.scatter2scatter_lora_dX( + DY=dy, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=1, + lora_A=lA, + lora_B=lB, + scaling=2.0, + dy_grouped=True, + dx_grouped=False, + ) + + +def _call_bwd(dy, gx, lA, lB, eo, num_experts): + return lora_ops.group_bwd_lora( + DY=dy, + X=gx, + lora_A=lA, + lora_B=lB, + expert_offsets=eo, + E=num_experts, + scaling=2.0, + ) + + # ─── Main ──────────────────────────────────────────────────────────────────── + def main(): parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark") - parser.add_argument("--models", "-m", nargs="+", - help="Model names or HF IDs (default: all builtins)") + parser.add_argument( + "--models", + "-m", + nargs="+", + help="Model names or HF IDs (default: all builtins)", + ) parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64]) parser.add_argument("--seq-len", "-T", type=int, default=2048) args = parser.parse_args() @@ -122,73 +187,84 @@ def main(): configs = [_resolve_config(m) for m in args.models] else: configs = list(BUILTIN_CONFIGS.items()) - configs = [(n, c) for n, c in configs] - for model_name, (E, H, I, k) in configs: + for model_name, (num_experts, hidden, inter, top_k) in configs: print(f"{'=' * 70}") - print(f" {model_name}: E={E}, H={H}, I={I}, k={k}") + print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}") print(f"{'=' * 70}") for R in args.ranks: - for proj, K, N in [("gate_up", H, 2 * I), ("down", I, H)]: + for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]: _clean() - x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(E, K, N, T, k, R) + x, W, lA, lB, sei, ssi, eo, gx, dy = _setup( + num_experts, K, N, T, top_k, R + ) # Forward with LoRA (auto-dispatched: fused or split) - dispatch = "split" if (E <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS - and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD) else "fused" - t_fwd = _bench(lambda: lora_ops.scatter2scatter_lora( - X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, - k=k, lora_A=lA, lora_B=lB, scaling=2.0, - )) - - # Forward without LoRA (base) - t_base = _bench(lambda: base_ops.scatter2scatter( - X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, k=k, - )) - - # Backward dX - t_dx = _bench(lambda: lora_ops.scatter2scatter_lora_dX( - DY=dy, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, - k=1, lora_A=lA, lora_B=lB, scaling=2.0, - dy_grouped=True, dx_grouped=False, - )) - - # Backward dA/dB - t_bwd = _bench(lambda: lora_ops.group_bwd_lora( - DY=dy, X=gx, lora_A=lA, lora_B=lB, - expert_offsets=eo, E=E, scaling=2.0, - )) + dispatch = ( + "split" + if ( + num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS + and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD + ) + else "fused" + ) + t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB)) + t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k)) + t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB)) + t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts)) total = t_fwd + t_dx + t_bwd overhead = t_fwd / t_base - 1 if t_base > 0 else 0 - print(f" R={R:>2} {proj:<8} " - f"fwd={t_fwd:>6.2f}ms [{dispatch}] " - f"base={t_base:>6.2f}ms " - f"(+{overhead*100:.0f}%) " - f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms " - f"total={total:>6.2f}ms") + print( + f" R={R:>2} {proj:<8} " + f"fwd={t_fwd:>6.2f}ms [{dispatch}] " + f"base={t_base:>6.2f}ms " + f"(+{overhead * 100:.0f}%) " + f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms " + f"total={total:>6.2f}ms" + ) - # Full autograd fwd+bwd + # Full autograd fwd+bwd with memory measurement x_ag = x.clone().requires_grad_(True) lA_ag = lA.clone().requires_grad_(True) lB_ag = lB.clone().requires_grad_(True) - def _run_autograd(): + def _run_autograd( + _x=x_ag, + _W=W, + _k=top_k, + _sei=sei, + _ssi=ssi, + _eo=eo, + _lA=lA_ag, + _lB=lB_ag, + ): out = ScatterMoELoRA.apply( - x_ag, W, k, sei, ssi, eo, - lA_ag, lB_ag, 2.0, - None, None, False, False, True, False, + _x, + _W, + _k, + _sei, + _ssi, + _eo, + _lA, + _lB, + 2.0, + None, + None, + False, + False, + True, + False, ) out.sum().backward() - x_ag.grad = None - lA_ag.grad = None - lB_ag.grad = None + _x.grad = None + _lA.grad = None + _lB.grad = None t_full = _bench(_run_autograd) - # Memory measurement _clean() torch.cuda.reset_peak_memory_stats() mem_before = torch.cuda.memory_allocated() @@ -196,8 +272,10 @@ def main(): torch.cuda.synchronize() mem_peak = torch.cuda.max_memory_allocated() - mem_before - print(f" full_fwd_bwd={t_full:>6.2f}ms " - f"peak_delta={mem_peak/1e6:>6.1f}MB") + print( + f" full_fwd_bwd={t_full:>6.2f}ms " + f"peak_delta={mem_peak / 1e6:>6.1f}MB" + ) print() diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py index 731b36645..d605b652d 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py @@ -566,30 +566,41 @@ def _scatter2scatter_lora_split( # 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes) output = scatter2scatter( - X=X, W=W, b=b, + X=X, + W=W, + b=b, sorted_expert_idxs=sorted_expert_idxs, sorted_scattered_idxs=sorted_scattered_idxs, - k=k, x_grouped=x_grouped, y_grouped=y_grouped, out=out, + k=k, + x_grouped=x_grouped, + y_grouped=y_grouped, + out=out, ) # 2. XA = X @ A^T (tiny: output is [M*k, R]) # Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter) W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous() XA = scatter2scatter( - X=X, W=W_A, + X=X, + W=W_A, sorted_expert_idxs=sorted_expert_idxs, sorted_scattered_idxs=sorted_scattered_idxs, - k=k, x_grouped=x_grouped, y_grouped=True, + k=k, + x_grouped=x_grouped, + y_grouped=True, ) # 3. Y_lora = XA @ B^T (R is tiny, so this is very fast) # Reshape B: [N, R*E] → [E, R, N] W_B = lora_B.T.reshape(E, R, N).contiguous() Y_lora = scatter2scatter( - X=XA, W=W_B, + X=XA, + W=W_B, sorted_expert_idxs=sorted_expert_idxs, sorted_scattered_idxs=sorted_scattered_idxs, - k=1, x_grouped=True, y_grouped=y_grouped, + k=1, + x_grouped=True, + y_grouped=y_grouped, ) # 4. Y = Y_base + scaling * Y_lora @@ -650,13 +661,20 @@ def scatter2scatter_lora( N = W.size(2) # Dispatch: split for few large experts, fused for many small experts - if ( - E <= _SPLIT_LORA_FWD_MAX_EXPERTS - and K * N >= _SPLIT_LORA_FWD_THRESHOLD - ): + if E <= _SPLIT_LORA_FWD_MAX_EXPERTS and K * N >= _SPLIT_LORA_FWD_THRESHOLD: return _scatter2scatter_lora_split( - X, W, sorted_expert_idxs, sorted_scattered_idxs, k, - lora_A, lora_B, scaling, b, x_grouped, y_grouped, out, + X, + W, + sorted_expert_idxs, + sorted_scattered_idxs, + k, + lora_A, + lora_B, + scaling, + b, + x_grouped, + y_grouped, + out, ) assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0) @@ -1443,7 +1461,6 @@ def _prune_split_configs(configs, named_args, **kwargs): """Prune split kernel configs based on SMEM capacity.""" smem_cap = _get_smem_capacity() block_r = named_args.get("BLOCK_R", 64) - inner_dim = named_args.get("INNER_DIM", 2048) # Fixed inner tile for reduction dimension BLOCK_INNER = 64 @@ -1470,33 +1487,47 @@ def _prune_split_configs(configs, named_args, **kwargs): key=["M", "K", "N"], prune_configs_by={"early_config_prune": _prune_split_configs}, ) -@triton.heuristics({ - "NO_DIM_MASK": lambda args: ( - (args["K"] % args["BLOCK_DIM"]) == 0 - if args["COMPUTE_DA"] - else (args["N"] % args["BLOCK_DIM"]) == 0 - ), -}) +@triton.heuristics( + { + "NO_DIM_MASK": lambda args: ( + (args["K"] % args["BLOCK_DIM"]) == 0 + if args["COMPUTE_DA"] + else (args["N"] % args["BLOCK_DIM"]) == 0 + ), + } +) @triton.jit def _group_bwd_lora_split( # Data tensors (DY and X are always present) - DY_ptr, stride_dym, stride_dyn, - X_ptr, stride_xm, stride_xk, + DY_ptr, + stride_dym, + stride_dyn, + X_ptr, + stride_xm, + stride_xk, # LoRA weight for the inner reduction (B for dA, A for dB) - LW_ptr, stride_lw0, stride_lw1, + LW_ptr, + stride_lw0, + stride_lw1, # Output gradient tensor (dA or dB) - OUT_ptr, stride_out0, stride_out1, + OUT_ptr, + stride_out0, + stride_out1, # Expert offsets expert_offsets_ptr, # Dimensions - M, K: tl.constexpr, N: tl.constexpr, - ACTUAL_R: tl.constexpr, BLOCK_R: tl.constexpr, + M, + K: tl.constexpr, + N: tl.constexpr, + ACTUAL_R: tl.constexpr, + BLOCK_R: tl.constexpr, INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB) scaling, # Mode flag COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB # Tile sizes - BLOCK_M: tl.constexpr, BLOCK_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DIM: tl.constexpr, ACC_TYPE: tl.constexpr, allow_tf32: tl.constexpr, NO_DIM_MASK: tl.constexpr, @@ -1532,9 +1563,9 @@ def _group_bwd_lora_split( # Output dimension tile (K for dA, N for dB) if COMPUTE_DA: - OUT_DIM: tl.constexpr = K + OUT_DIM: tl.constexpr = K # type: ignore[no-redef] else: - OUT_DIM: tl.constexpr = N + OUT_DIM: tl.constexpr = N # type: ignore[no-redef] dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM) dim_mask = dim_block < OUT_DIM R_block = tl.arange(0, BLOCK_R) @@ -1577,7 +1608,8 @@ def _group_bwd_lora_split( # Load X[M, K_block] (the "outer" tensor for dA) outer = tl.load( X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk, - mask=M_mask[:, None] & dim_mask[None, :], other=0.0 + mask=M_mask[:, None] & dim_mask[None, :], + other=0.0, ).to(INPUT_DTYPE) # Reduce DY[M, :] @ B[e][:, R] over N → [M, R] @@ -1588,23 +1620,34 @@ def _group_bwd_lora_split( inn_mask = inn_off < N dy_tile = tl.load( - DY_ptr + M_idx[:, None] * stride_dym + inn_off[None, :] * stride_dyn, - mask=M_mask[:, None] & inn_mask[None, :], other=0.0 + DY_ptr + + M_idx[:, None] * stride_dym + + inn_off[None, :] * stride_dyn, + mask=M_mask[:, None] & inn_mask[None, :], + other=0.0, ).to(INPUT_DTYPE) # B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride lw_tile = tl.load( - LW_ptr + inn_off[:, None] * stride_lw0 + (lora_offset + R_block)[None, :] * stride_lw1, - mask=inn_mask[:, None] & R_mask[None, :], other=0.0 + LW_ptr + + inn_off[:, None] * stride_lw0 + + (lora_offset + R_block)[None, :] * stride_lw1, + mask=inn_mask[:, None] & R_mask[None, :], + other=0.0, ).to(INPUT_DTYPE) reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32) # dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block] - acc += tl.dot(tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32) + acc += tl.dot( + tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32 + ) else: # Load DY[M, N_block] (the "outer" tensor for dB) outer = tl.load( - DY_ptr + M_idx[:, None] * stride_dym + dim_block[None, :] * stride_dyn, - mask=M_mask[:, None] & dim_mask[None, :], other=0.0 + DY_ptr + + M_idx[:, None] * stride_dym + + dim_block[None, :] * stride_dyn, + mask=M_mask[:, None] & dim_mask[None, :], + other=0.0, ).to(INPUT_DTYPE) # Reduce X[M, :] @ A[e][:, :].T over K → [M, R] @@ -1615,27 +1658,45 @@ def _group_bwd_lora_split( inn_mask = inn_off < K x_tile = tl.load( - X_ptr + M_idx[:, None] * stride_xm + inn_off[None, :] * stride_xk, - mask=M_mask[:, None] & inn_mask[None, :], other=0.0 + X_ptr + + M_idx[:, None] * stride_xm + + inn_off[None, :] * stride_xk, + mask=M_mask[:, None] & inn_mask[None, :], + other=0.0, ).to(INPUT_DTYPE) # A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride # We want A[e]^T: [K, R], so load as [K_inner, R] lw_tile = tl.load( - LW_ptr + (lora_offset + R_block)[None, :] * stride_lw0 + inn_off[:, None] * stride_lw1, - mask=inn_mask[:, None] & R_mask[None, :], other=0.0 + LW_ptr + + (lora_offset + R_block)[None, :] * stride_lw0 + + inn_off[:, None] * stride_lw1, + mask=inn_mask[:, None] & R_mask[None, :], + other=0.0, ).to(INPUT_DTYPE) reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32) # dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R] - acc += tl.dot(tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32) + acc += tl.dot( + tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32 + ) - tl.store(out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask) + tl.store( + out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask + ) else: # Zero out this expert's slice — needed because output uses empty_like if COMPUTE_DA: - tl.store(out_blk_ptrs, tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty), mask=out_mask) + tl.store( + out_blk_ptrs, + tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty), + mask=out_mask, + ) else: - tl.store(out_blk_ptrs, tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty), mask=out_mask) + tl.store( + out_blk_ptrs, + tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty), + mask=out_mask, + ) def group_bwd_lora( @@ -1683,34 +1744,58 @@ def group_bwd_lora( return (E, triton.cdiv(K, META["BLOCK_DIM"])) _group_bwd_lora_split[grid_dA]( - DY, DY.stride(0), DY.stride(1), - X, X.stride(0), X.stride(1), - lora_B, lora_B.stride(0), lora_B.stride(1), - dA, dA.stride(0), dA.stride(1), + DY, + DY.stride(0), + DY.stride(1), + X, + X.stride(0), + X.stride(1), + lora_B, + lora_B.stride(0), + lora_B.stride(1), + dA, + dA.stride(0), + dA.stride(1), expert_offsets, - M=DY.size(0), K=K, N=N, - ACTUAL_R=R, BLOCK_R=BLOCK_R, + M=DY.size(0), + K=K, + N=N, + ACTUAL_R=R, + BLOCK_R=BLOCK_R, INNER_DIM=N, scaling=scaling, COMPUTE_DA=True, - ACC_TYPE=tl.float32, allow_tf32=ALLOW_TF32, + ACC_TYPE=tl.float32, + allow_tf32=ALLOW_TF32, ) def grid_dB(META): return (E, triton.cdiv(N, META["BLOCK_DIM"])) _group_bwd_lora_split[grid_dB]( - DY, DY.stride(0), DY.stride(1), - X, X.stride(0), X.stride(1), - lora_A, lora_A.stride(0), lora_A.stride(1), - dB, dB.stride(0), dB.stride(1), + DY, + DY.stride(0), + DY.stride(1), + X, + X.stride(0), + X.stride(1), + lora_A, + lora_A.stride(0), + lora_A.stride(1), + dB, + dB.stride(0), + dB.stride(1), expert_offsets, - M=DY.size(0), K=K, N=N, - ACTUAL_R=R, BLOCK_R=BLOCK_R, + M=DY.size(0), + K=K, + N=N, + ACTUAL_R=R, + BLOCK_R=BLOCK_R, INNER_DIM=K, scaling=scaling, COMPUTE_DA=False, - ACC_TYPE=tl.float32, allow_tf32=ALLOW_TF32, + ACC_TYPE=tl.float32, + allow_tf32=ALLOW_TF32, ) return dA, dB diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py index 453c8c318..c6c01e255 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py @@ -511,20 +511,26 @@ class HFScatterMoEGatedMLP(nn.Module): active_experts = get_active_experts(sorted_expert_idxs, num_experts) remapped_expert_idxs, compact_offsets = remap_expert_indices( - sorted_expert_idxs, expert_offsets, active_experts, num_experts, + sorted_expert_idxs, + expert_offsets, + active_experts, + num_experts, ) - num_active = len(active_experts) - # Dequantize only active experts' weights gate_up_W = selective_expert_weights( - experts, "gate_up_proj", active_experts, + experts, + "gate_up_proj", + active_experts, ).transpose(2, 1) # [num_active, hidden, 2*inter] # Remap LoRA weights to match compact expert indices if gup_lora is not None: gup_A, gup_B, gup_scaling = gup_lora gup_A, gup_B = selective_lora_weights( - gup_A, gup_B, active_experts, num_experts, + gup_A, + gup_B, + active_experts, + num_experts, ) gup_lora = (gup_A, gup_B, gup_scaling) @@ -576,13 +582,18 @@ class HFScatterMoEGatedMLP(nn.Module): # ==================================================================== if use_selective: down_W = selective_expert_weights( - experts, "down_proj", active_experts, + experts, + "down_proj", + active_experts, ).transpose(2, 1) # [num_active, inter, hidden] if down_lora is not None: down_A, down_B, down_scaling = down_lora down_A, down_B = selective_lora_weights( - down_A, down_B, active_experts, num_experts, + down_A, + down_B, + active_experts, + num_experts, ) down_lora = (down_A, down_B, down_scaling) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant.py index 97910bcc3..1df8b2f68 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant.py @@ -21,8 +21,6 @@ from global (0..E-1) to compact (0..num_active-1) and pass the smaller weight tensor. """ -from typing import Optional - import torch import torch.nn as nn @@ -79,7 +77,7 @@ def _selective_dequant_bnb4( raw_param: torch.Tensor, quant_state, active_experts: torch.Tensor, - expert_shape: tuple[int, ...], + expert_shape: tuple[int, int], ) -> torch.Tensor: """Dequantize only selected experts from BnB 4-bit packed data. @@ -231,7 +229,9 @@ def selective_expert_weights( if E_total is None: E_total = int(active_experts.max().item()) + 1 expert_numel = orig_shape[0] // E_total - d2 = getattr(experts_module, "hidden_dim", None) or getattr(experts_module, "intermediate_dim", None) + d2 = getattr(experts_module, "hidden_dim", None) or getattr( + experts_module, "intermediate_dim", None + ) if d2 and expert_numel % d2 == 0: expert_shape = (expert_numel // d2, d2) else: @@ -241,9 +241,7 @@ def selective_expert_weights( full = getattr(experts_module, param_name) return full[active_experts] - return _selective_dequant_bnb4( - raw_param, qs, active_experts, expert_shape - ) + return _selective_dequant_bnb4(raw_param, qs, active_experts, expert_shape) # Dense parameter (bf16/fp32) — direct indexing param = getattr(experts_module, param_name) diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant_kernel.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant_kernel.py index e29799c05..aa9f0278a 100644 --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant_kernel.py +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant_kernel.py @@ -19,15 +19,25 @@ import torch import triton import triton.language as tl - # NF4 codebook (16 values, precomputed by BnB) # These are the normalized float4 reconstruction values NF4_CODEBOOK = [ - -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, - -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, - 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, - 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, - 0.7229568362236023, 1.0, + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, ] @@ -46,8 +56,8 @@ def _selective_dequant_nf4_kernel( stride_out_e, # stride for expert dim in output # Dimensions num_active, - packed_per_expert, # expert_numel // 2 - blocks_per_expert, # expert_numel // blocksize + packed_per_expert, # expert_numel // 2 + blocks_per_expert, # expert_numel // blocksize blocksize: tl.constexpr, # Tile size BLOCK_SIZE: tl.constexpr, # elements per thread block (must be multiple of 2) @@ -79,7 +89,9 @@ def _selective_dequant_nf4_kernel( # Read packed bytes from the global expert's region packed_global_offset = expert_global * packed_per_expert + byte_idx - packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(tl.int32) + packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to( + tl.int32 + ) # Extract 4-bit nibble # BnB packing: high nibble = even element, low nibble = odd element @@ -133,8 +145,9 @@ def selective_dequant_nf4_triton( # Prepare codebook on device if codebook is None: - codebook = torch.tensor(NF4_CODEBOOK, dtype=torch.float32, - device=packed_data.device) + codebook = torch.tensor( + NF4_CODEBOOK, dtype=torch.float32, device=packed_data.device + ) else: codebook = codebook.to(device=packed_data.device, dtype=torch.float32) @@ -143,8 +156,7 @@ def selective_dequant_nf4_triton( absmax_flat = absmax.reshape(-1).float() # absmax is usually fp32 # Output buffer - out = torch.empty(num_active, expert_numel, dtype=dtype, - device=packed_data.device) + out = torch.empty(num_active, expert_numel, dtype=dtype, device=packed_data.device) BLOCK_SIZE = 1024 # Process 1024 elements per thread block diff --git a/src/axolotl/integrations/kernels/plugin.py b/src/axolotl/integrations/kernels/plugin.py index c3b0360ac..939bdb790 100644 --- a/src/axolotl/integrations/kernels/plugin.py +++ b/src/axolotl/integrations/kernels/plugin.py @@ -66,7 +66,10 @@ class KernelsPlugin(BasePlugin): # Prefer text backbone type for VLMs, but fall back to base type # when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text) moe_model_type = cfg.model_config_type_text or cfg.model_config_type - if moe_model_type not in SPARSE_MOE_BLOCK and cfg.model_config_type in SPARSE_MOE_BLOCK: + if ( + moe_model_type not in SPARSE_MOE_BLOCK + and cfg.model_config_type in SPARSE_MOE_BLOCK + ): moe_model_type = cfg.model_config_type if cfg.use_scattermoe: diff --git a/src/axolotl/utils/callbacks/profiler.py b/src/axolotl/utils/callbacks/profiler.py index 2598eac94..49e8c5388 100644 --- a/src/axolotl/utils/callbacks/profiler.py +++ b/src/axolotl/utils/callbacks/profiler.py @@ -28,9 +28,7 @@ class PytorchProfilerCallback(TrainerCallback): if profiler_steps_start == 0: # start recording memory allocations before everything is allocated, because if we start # at the beginning of step 0, we won't have any memory allocations in the traces - torch.cuda.memory._record_memory_history( - enabled="all", stacks="all" - ) + torch.cuda.memory._record_memory_history(enabled="all", stacks="all") profiler_steps_start = -1 self.profiler_steps_start = profiler_steps_start self._profiler = None @@ -43,13 +41,11 @@ class PytorchProfilerCallback(TrainerCallback): **kwargs, ): if state.global_step == self.profiler_steps_start: - torch.cuda.memory._record_memory_history( - enabled="all", stacks="all" - ) + torch.cuda.memory._record_memory_history(enabled="all", stacks="all") # Start torch.profiler on the first profiled step if state.global_step == max(self.profiler_steps_start, 0): - self._profiler = torch.profiler.profile( + profiler = torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, @@ -58,7 +54,8 @@ class PytorchProfilerCallback(TrainerCallback): profile_memory=True, with_stack=True, ) - self._profiler.__enter__() + profiler.__enter__() + self._profiler = profiler def on_step_end( self, diff --git a/tests/integrations/test_scattermoe_lora_kernels.py b/tests/integrations/test_scattermoe_lora_kernels.py index 708bf6e56..fc783fa1d 100644 --- a/tests/integrations/test_scattermoe_lora_kernels.py +++ b/tests/integrations/test_scattermoe_lora_kernels.py @@ -19,8 +19,8 @@ import pytest import torch from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import ( - ops as base_ops, lora_ops, + ops as base_ops, ) from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import ( flatten_sort_count, @@ -151,8 +151,14 @@ class TestScatter2ScatterLoRAForward: x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) kernel_out = lora_ops.scatter2scatter_lora( - X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, - k=k, lora_A=lA, lora_B=lB, scaling=SCALING, + X=x, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=k, + lora_A=lA, + lora_B=lB, + scaling=SCALING, ) ref_out = _reference_fwd(x, W, sei, ssi, eo, k, lA, lB, SCALING, E) @@ -164,8 +170,14 @@ class TestScatter2ScatterLoRAForward: x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) out = lora_ops.scatter2scatter_lora( - X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, - k=k, lora_A=lA, lora_B=lB, scaling=SCALING, + X=x, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=k, + lora_A=lA, + lora_B=lB, + scaling=SCALING, ) assert out.shape == (T * k, N) assert out.dtype == DTYPE @@ -188,9 +200,16 @@ class TestScatter2ScatterLoRADX: dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE) kernel_dx = lora_ops.scatter2scatter_lora_dX( - DY=dy, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, - k=1, lora_A=lA, lora_B=lB, scaling=SCALING, - dy_grouped=True, dx_grouped=False, + DY=dy, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=1, + lora_A=lA, + lora_B=lB, + scaling=SCALING, + dy_grouped=True, + dx_grouped=False, ) ref_dx = _reference_dX(dy, W, sei, ssi, eo, lA, lB, SCALING, E) @@ -215,8 +234,13 @@ class TestGroupBwdLoRA: dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE) kern_dA, kern_dB = lora_ops.group_bwd_lora( - DY=dy, X=gx, lora_A=lA, lora_B=lB, - expert_offsets=eo, E=E, scaling=SCALING, + DY=dy, + X=gx, + lora_A=lA, + lora_B=lB, + expert_offsets=eo, + E=E, + scaling=SCALING, ) ref_dA, ref_dB = _reference_bwd_lora(dy, gx, lA, lB, eo, E, SCALING) @@ -225,12 +249,10 @@ class TestGroupBwdLoRA: # fp32 loop), so max absolute error can be large on individual elements # while the overall tensor is correct. dA_norm_err = ( - (kern_dA.float() - ref_dA.float()).norm() - / (ref_dA.float().norm() + 1e-6) + (kern_dA.float() - ref_dA.float()).norm() / (ref_dA.float().norm() + 1e-6) ).item() dB_norm_err = ( - (kern_dB.float() - ref_dB.float()).norm() - / (ref_dB.float().norm() + 1e-6) + (kern_dB.float() - ref_dB.float()).norm() / (ref_dB.float().norm() + 1e-6) ).item() assert dA_norm_err < 0.01, f"[{desc}] dA norm_rel_err={dA_norm_err}" assert dB_norm_err < 0.01, f"[{desc}] dB norm_rel_err={dB_norm_err}" @@ -249,14 +271,21 @@ class TestGroupBwdLoRA: lB = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) dA, dB = lora_ops.group_bwd_lora( - DY=dy, X=gx, lora_A=lA, lora_B=lB, - expert_offsets=eo, E=E, scaling=2.0, + DY=dy, + X=gx, + lora_A=lA, + lora_B=lB, + expert_offsets=eo, + E=E, + scaling=2.0, ) # Experts 1..7 should have zero gradients for e in range(1, E): assert dA[e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dA not zero" - assert dB[:, e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dB not zero" + assert dB[:, e * R : (e + 1) * R].abs().max() == 0, ( + f"Expert {e} dB not zero" + ) # ─── Full autograd tests ──────────────────────────────────────────────────── @@ -278,9 +307,21 @@ class TestScatterMoELoRAAutograd: lB = lB.requires_grad_(True) out = ScatterMoELoRA.apply( - x, W, k, sei, ssi, eo, - lA, lB, SCALING, - None, None, False, False, True, False, + x, + W, + k, + sei, + ssi, + eo, + lA, + lB, + SCALING, + None, + None, + False, + False, + True, + False, ) out.sum().backward() @@ -293,7 +334,6 @@ class TestScatterMoELoRAAutograd: assert x.grad.abs().sum() > 0, f"[{desc}] x.grad all zero" assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero" - def test_split_matches_fused(self): """Split dispatch (for few large experts) matches fused kernel.""" # Use a shape where split would be dispatched (large K*N, few E) @@ -304,15 +344,27 @@ class TestScatterMoELoRAAutograd: orig = lora_ops._SPLIT_LORA_FWD_THRESHOLD lora_ops._SPLIT_LORA_FWD_THRESHOLD = 10**18 out_fused = lora_ops.scatter2scatter_lora( - X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, - k=k, lora_A=lA, lora_B=lB, scaling=SCALING, + X=x, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=k, + lora_A=lA, + lora_B=lB, + scaling=SCALING, ) # Force split path lora_ops._SPLIT_LORA_FWD_THRESHOLD = 0 out_split = lora_ops.scatter2scatter_lora( - X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, - k=k, lora_A=lA, lora_B=lB, scaling=SCALING, + X=x, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=k, + lora_A=lA, + lora_B=lB, + scaling=SCALING, ) lora_ops._SPLIT_LORA_FWD_THRESHOLD = orig @@ -328,12 +380,28 @@ class TestScatterMoELoRAAutograd: x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R) out_lora = ScatterMoELoRA.apply( - x, W, k, sei, ssi, eo, - lA, lB, 0.0, - None, None, False, False, True, False, + x, + W, + k, + sei, + ssi, + eo, + lA, + lB, + 0.0, + None, + None, + False, + False, + True, + False, ) out_base = base_ops.scatter2scatter( - X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, k=k, + X=x, + W=W, + sorted_expert_idxs=sei, + sorted_scattered_idxs=ssi, + k=k, ) err = (out_lora.float() - out_base.float()).abs().max().item() assert err < 0.01, f"scaling=0 should match base: err={err}"