From b3289fd190c99cf08a35d2d71efffe92e0e7c440 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 22 Mar 2026 13:53:19 -0400 Subject: [PATCH] feat: LoRA kernel support for bias, dropout, dora, embeddings (#3528) [skip ci] * feat: LoRA kernel support for bias, dropout, dora, embeddings * chore: lint * chore: lint * address PR feedback, add regression tests, add fsdp2 tests for lora kernels * update tests for new sigs * update tests now that bias and dropout are supported --- src/axolotl/kernels/dora.py | 147 ++ src/axolotl/kernels/lora.py | 1570 +++++++++++++---- src/axolotl/kernels/quantize.py | 4 + src/axolotl/monkeypatch/lora_kernels.py | 75 +- src/axolotl/utils/schemas/config.py | 21 +- src/axolotl/utils/schemas/validation.py | 10 +- tests/core/test_async_grpo.py | 4 +- tests/e2e/kernels/test_lora.py | 6 +- tests/e2e/kernels/test_lora_features.py | 1245 +++++++++++++ tests/e2e/multigpu/test_fsdp2_lora_kernels.py | 120 ++ .../lora_kernels/test_lora_kernel_patching.py | 38 +- .../utils/lora/test_config_validation_lora.py | 30 +- tests/utils/lora/test_freeze_lora.py | 25 +- 13 files changed, 2847 insertions(+), 448 deletions(-) create mode 100644 src/axolotl/kernels/dora.py create mode 100644 tests/e2e/kernels/test_lora_features.py create mode 100644 tests/e2e/multigpu/test_fsdp2_lora_kernels.py diff --git a/src/axolotl/kernels/dora.py b/src/axolotl/kernels/dora.py new file mode 100644 index 000000000..3ed35cf74 --- /dev/null +++ b/src/axolotl/kernels/dora.py @@ -0,0 +1,147 @@ +""" +Triton kernels for DoRA (Weight-Decomposed Low-Rank Adaptation). + +Fuses the weight norm computation and magnitude scaling to avoid +materializing the full [out_features, in_features] combined weight matrix. +The B@A product is computed row-by-row inside the kernel. +""" + +import torch +import triton +import triton.language as tl + +from .quantize import dequantize + + +@triton.jit +def _dora_fused_norm_kernel( + # Pointers + W_ptr, # base weight [out, in] (dequantized, row-major) + B_ptr, # LoRA B [out, rank] (row-major) + A_ptr, # LoRA A [rank, in] (row-major) + mag_ptr, # magnitude vector [out] + out_ptr, # output mag_norm_scale [out] + # Shapes + out_features, + in_features, + rank, + # Scaling + lora_scale, # float scaling factor + # Block sizes + BLOCK_IN: tl.constexpr, + BLOCK_R: tl.constexpr, # >= rank, power of 2 +): + """Compute mag_norm_scale[i] = magnitude[i] / ||W[i,:] + s * (B[i,:] @ A)[:] ||_2 + + Each program handles one output row. B[row,:] is loaded once (small), + then we tile over in_features computing the dot product with A[:,tile] + and accumulating the squared norm. + + This avoids materializing the full [out, in] B@A matrix. 
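+
+    Illustrative example (hypothetical shapes, not taken from this patch):
+    with out_features=4096, in_features=11008, rank=16, and BLOCK_IN=2048,
+    the launch grid is (4096,), one program per output row, and each program
+    walks ceil(11008 / 2048) = 6 tiles of in_features with a static inner
+    loop of BLOCK_R = 16 over the rank dimension.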
+ """ + row = tl.program_id(0) + if row >= out_features: + return + + # Accumulate squared norm across tiles of in_features + norm_sq_acc = tl.zeros([BLOCK_IN], dtype=tl.float32) + + for start in range(0, in_features, BLOCK_IN): + cols = start + tl.arange(0, BLOCK_IN) + col_mask = cols < in_features + + # Load W[row, cols] + w_vals = tl.load( + W_ptr + row * in_features + cols, + mask=col_mask, + other=0.0, + ).to(tl.float32) + + # Compute (B[row,:] @ A[:, cols]) for this tile + # Load B[row, r] as scalar and A[r, cols] as vector for each r + ba_vals = tl.zeros([BLOCK_IN], dtype=tl.float32) + for r in tl.static_range(BLOCK_R): + # Load scalar B[row, r] + b_val = tl.load( + B_ptr + row * rank + r, + mask=(r < rank), + other=0.0, + ).to(tl.float32) + # Load vector A[r, cols] + a_vals = tl.load( + A_ptr + r * in_features + cols, + mask=(col_mask & (r < rank)), + other=0.0, + ).to(tl.float32) + ba_vals += b_val * a_vals + + # Combined: W + s * (B @ A) + combined = w_vals + lora_scale * ba_vals + + # Accumulate squared values + norm_sq_acc += tl.where(col_mask, combined * combined, 0.0) + + # Reduce to scalar norm + norm_sq = tl.sum(norm_sq_acc, axis=0) + norm = tl.sqrt(norm_sq + 1e-12) # epsilon for numerical stability + + # Load magnitude and compute scale + mag = tl.load(mag_ptr + row).to(tl.float32) + scale = mag / norm + + tl.store(out_ptr + row, scale) + + +def triton_dora_scale( + W: torch.Tensor, + W_quant, + A: torch.Tensor, + B: torch.Tensor, + s: float, + magnitude: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + """Compute DoRA mag_norm_scale using fused Triton kernel. + + Computes B@A row-by-row inside the kernel, avoiding the full + [out_features, in_features] materialization. + + Args: + W: base weight [out, in] (possibly quantized) + W_quant: quantization state + A: LoRA A [rank, in] + B: LoRA B [out, rank] + s: LoRA scaling factor + magnitude: learned magnitude [out] + dtype: compute dtype + + Returns: + mag_norm_scale: [out] tensor = magnitude / ||W + s * B @ A||_2 + """ + # Dequantize W to [out, in] + W_full = dequantize(W.t(), W_quant).t().contiguous().to(dtype) + + out_features, in_features = W_full.shape + rank = A.shape[0] + + out = torch.empty(out_features, dtype=dtype, device=W.device) + + # Block sizes + BLOCK_IN = triton.next_power_of_2(min(in_features, 2048)) + BLOCK_R = triton.next_power_of_2(rank) + + _dora_fused_norm_kernel[(out_features,)]( + W_full, + B.contiguous().to(dtype), + A.contiguous().to(dtype), + magnitude.contiguous(), + out, + out_features=out_features, + in_features=in_features, + rank=rank, + lora_scale=s, + BLOCK_IN=BLOCK_IN, + BLOCK_R=BLOCK_R, + ) + + return out.detach() diff --git a/src/axolotl/kernels/lora.py b/src/axolotl/kernels/lora.py index 9dc66a918..1576a10cd 100644 --- a/src/axolotl/kernels/lora.py +++ b/src/axolotl/kernels/lora.py @@ -4,6 +4,9 @@ Module for definition of Low-Rank Adaptation (LoRA) Triton kernels. See "LoRA: Low-Rank Adaptation of Large Language Models" (https://arxiv.org/abs/2106.09685). +Also supports DoRA (Weight-Decomposed Low-Rank Adaptation): +See "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). + Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. """ @@ -29,6 +32,9 @@ def get_lora_parameters( torch.Tensor | None, torch.Tensor | None, float | None, + torch.Tensor | None, + nn.Module | None, + torch.Tensor | None, ]: """ Gets LoRA parameters from a projection module. 
@@ -37,9 +43,16 @@ def get_lora_parameters( proj: The projection module to extract parameters from. Returns: - A tuple containing the base weights, quantization state, LoRA A and B weights, - scaling factor, and base layer bias. Quant state, weights, and bias may be - `None` if not available. + A tuple containing: + - W: base weight tensor + - b: base layer bias (or None) + - quant_state: quantization state (or None) + - A: LoRA A weight (or None) + - B: LoRA B weight (or None) + - s: LoRA scaling factor (or None) + - lora_bias: LoRA B bias (or None) + - dropout: dropout module (or None) + - magnitude: DoRA magnitude vector (or None) """ # For DPO or disabled adapters base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj @@ -50,7 +63,7 @@ def get_lora_parameters( quant_state = getattr(W, "quant_state", None) if quant_state is None and W.dtype == torch.float8_e4m3fn: quant_state = getattr(base_layer, "weight_scale_inv", None) - return W, b, quant_state, None, None, None + return W, b, quant_state, None, None, None, None, None, None quant_state = getattr(W, "quant_state", None) if quant_state is None and W.dtype == torch.float8_e4m3fn: @@ -78,7 +91,136 @@ def get_lora_parameters( B = linear_B.weight s = proj.scaling[active_adapter] - return W, b, quant_state, A, B, s + # LoRA bias from lora_B (when bias="lora_only" or bias="all") + lora_bias = linear_B.bias # None if bias=False + + # Dropout module + dropout = None + if hasattr(proj, "lora_dropout") and active_adapter in proj.lora_dropout: + dropout = proj.lora_dropout[active_adapter] + + # DoRA magnitude vector + magnitude = None + if ( + hasattr(proj, "lora_magnitude_vector") + and proj.lora_magnitude_vector + and active_adapter in proj.lora_magnitude_vector + ): + mag_layer = proj.lora_magnitude_vector[active_adapter] + magnitude = mag_layer.weight + # FSDP2 DTensor unshard for magnitude vector + if isinstance(magnitude, DTensor): + magnitude = magnitude.full_tensor() + + return W, b, quant_state, A, B, s, lora_bias, dropout, magnitude + + +def _apply_dropout( + dropout: nn.Module | None, X: torch.Tensor, training: bool +) -> torch.Tensor | None: + """Apply dropout to X if dropout module exists and is active. + + Returns X_drop (different tensor) or None if no dropout needed. + """ + if dropout is None or isinstance(dropout, nn.Identity) or not training: + return None + return dropout(X) + + +_USE_TRITON_DORA: bool | None = None + + +def _should_use_triton_dora() -> bool: + """Check if Triton DoRA kernel is available.""" + global _USE_TRITON_DORA + if _USE_TRITON_DORA is None: + try: + from .dora import triton_dora_scale # noqa: F401 + + _USE_TRITON_DORA = True + except (ImportError, RuntimeError): + _USE_TRITON_DORA = False + return _USE_TRITON_DORA + + +def _compute_dora_scale( + W: torch.Tensor, + W_quant: QuantState | torch.Tensor | None, + A: torch.Tensor, + B: torch.Tensor, + s: float, + magnitude: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + """Compute DoRA magnitude/norm scaling factor with optional caching. + + Uses Triton kernel when available for better performance. + Caches weight_norm on the magnitude tensor. Cache invalidated when + LoRA A/B data changes (after optimizer step). 
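+    For example, after an optimizer step updates A and B in place, their
+    `_version` counters increment, the cached (a_version, b_version, norm)
+    entry no longer matches, and the norm is recomputed on the next forward.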
+ + Returns: + mag_norm_scale: [out_features] tensor = magnitude / ||W + s * B @ A||_2 + """ + # Check cache on magnitude tensor (avoids expensive norm recomputation) + # Use tensor._version which increments on any in-place modification + # (data_ptr doesn't change when optimizers update params in-place) + cache = getattr(magnitude, "_dora_cache", None) + if cache is not None: + cached_a_ver, cached_b_ver, cached_norm = cache + if cached_a_ver == A._version and cached_b_ver == B._version: + return magnitude.to(dtype) / cached_norm + + # Full recomputation - try Triton first + if _should_use_triton_dora() and W.is_cuda: + from .dora import triton_dora_scale + + result = triton_dora_scale(W, W_quant, A, B, s, magnitude, dtype) + weight_norm = (magnitude.to(dtype) / result).detach() + magnitude._dora_cache = (A._version, B._version, weight_norm) + return result + + # PyTorch fallback + W_full = dequantize(W.t(), W_quant).t().to(dtype) # [out, in] + lora_weight = B.to(dtype) @ A.to(dtype) + combined = W_full + s * lora_weight + weight_norm = torch.linalg.norm(combined, dim=1).to(dtype) + weight_norm = weight_norm.detach() + + magnitude._dora_cache = (A._version, B._version, weight_norm) + + return magnitude.to(dtype) / weight_norm + + +def _compute_dora_scale_cached( + proj: nn.Module, + W: torch.Tensor, + W_quant: QuantState | torch.Tensor | None, + A: torch.Tensor, + B: torch.Tensor, + s: float, + magnitude: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + """Compute DoRA scale with caching. Recomputes only when LoRA params change. + + Caches the weight norm on the projection module. The cache is invalidated + when LoRA A/B data pointers change (indicating an optimizer step occurred). + """ + cache = getattr(proj, "_dora_norm_cache", None) + if cache is not None: + cached_a_ver, cached_b_ver, cached_norm = cache + if cached_a_ver == A._version and cached_b_ver == B._version: + return magnitude.to(dtype) / cached_norm + + # Cache miss - full recomputation + W_full = dequantize(W.t(), W_quant).t().to(dtype) + lora_weight = B.to(dtype) @ A.to(dtype) + combined = W_full + s * lora_weight + weight_norm = torch.linalg.norm(combined, dim=1).to(dtype).detach() + + proj._dora_norm_cache = (A._version, B._version, weight_norm) + + return magnitude.to(dtype) / weight_norm def matmul_lora( @@ -90,6 +232,8 @@ def matmul_lora( B: torch.Tensor | None, s: float | None, out: torch.Tensor | None = None, + X_drop: torch.Tensor | None = None, + lora_bias: torch.Tensor | None = None, ) -> torch.Tensor: """ Efficient fused matmul + LoRA computation. 
@@ -102,9 +246,11 @@ def matmul_lora( B: LoRA B matrix [out_features, rank] s: LoRA scaling factor out: Optional output tensor for inplace operations + X_drop: Optional dropout-applied input for LoRA path (if None, uses X) + lora_bias: Optional LoRA B layer bias [out_features] Returns: - Result of X @ W + X @ A @ B + Result of X @ W + s * X_drop @ A @ B + b + s * lora_bias """ dtype = X.dtype W = dequantize(W.t(), W_quant) @@ -113,6 +259,8 @@ def matmul_lora( if X.dim() == 3: batch, seq_len, _ = X.shape X = X.view(-1, X.shape[-1]) + if X_drop is not None: + X_drop = X_drop.view(-1, X_drop.shape[-1]) reshape = True out = torch.matmul(X, W, out=out) @@ -120,8 +268,11 @@ def matmul_lora( del W if A is not None: + X_lora = X_drop if X_drop is not None else X A, B = A.t().to(dtype), B.t().to(dtype) # type: ignore[union-attr] - out += s * X @ A @ B + out += s * X_lora @ A @ B + if lora_bias is not None: + out += s * lora_bias if b is not None: out += b @@ -130,87 +281,200 @@ def matmul_lora( class LoRA_MLP(torch.autograd.Function): - """Optimized LoRA MLP implementation.""" + """Optimized LoRA MLP implementation. + + Supports bias, dropout, and DoRA. Dropout is applied to the input for + gate/up projections. The down projection uses hidden states (post-activation) + as input, so dropout is not applied there. + """ @staticmethod @torch_amp_custom_fwd def forward( ctx, X: torch.Tensor, + X_drop: torch.Tensor | None, + # Gate params gate_weight: torch.Tensor, gate_bias: torch.Tensor | None, gate_quant: QuantState | None, gate_A: torch.Tensor | None, gate_B: torch.Tensor | None, gate_scale: float, + gate_lora_bias: torch.Tensor | None, + gate_magnitude: torch.Tensor | None, + # Up params up_weight: torch.Tensor, up_bias: torch.Tensor | None, up_quant: QuantState | None, up_A: torch.Tensor | None, up_B: torch.Tensor | None, up_scale: float, + up_lora_bias: torch.Tensor | None, + up_magnitude: torch.Tensor | None, + # Down params down_weight: torch.Tensor, down_bias: torch.Tensor | None, down_quant: QuantState | None, down_A: torch.Tensor | None, down_B: torch.Tensor | None, down_scale: float, + down_lora_bias: torch.Tensor | None, + down_magnitude: torch.Tensor | None, + # Activation and flags activation_fn: Callable, activation_fn_backward: Callable, inplace: bool | None = True, ) -> torch.Tensor: - """ - Forward pass for LoRA MLP. 
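+        """Forward pass for LoRA MLP with optional bias, dropout, and DoRA.
+
+        When gate_magnitude is not None, the DoRA path is taken: base and
+        LoRA outputs are computed separately, summed, and rescaled by
+        magnitude / ||W + s * B @ A|| per output feature before the base
+        bias is added; otherwise the fused matmul_lora path is used.
+        """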
+ has_dropout = X_drop is not None + has_dora = gate_magnitude is not None + dtype = X.dtype + X_lora = X_drop if has_dropout else X - Args: - ctx: Autograd context - X: Input features - gate_weight: Gate projection weight - gate_bias: Gate projection bias - gate_quant: Gate quantization state - gate_A: Gate LoRA A matrix - gate_B: Gate LoRA B matrix - gate_scale: Gate LoRA scale - up_weight: Up projection weight - up_quant: Up projection quantization state - up_A: Up projection LoRA A matrix - up_B: Up projection LoRA B matrix - up_scale: Up projection LoRA scale - down_weight: Down projection weight - down_bias: Down projection bias - down_quant: Down projection quantization state - down_A: Down projection LoRA A matrix - down_B: Down projection LoRA B matrix - down_scale: Down projection LoRA scale - activation_fn: Forward activation function - activation_fn_backward: Backward activation function - inplace: Whether to perform operations in-place + if has_dora: + # Gate with DoRA + gate_base = matmul_lora(X, gate_weight, None, gate_quant, None, None, None) + gate_lora = _lora_only( + X_lora, gate_A, gate_B, gate_scale, gate_lora_bias, dtype + ) + gate_mag_scale = _compute_dora_scale( + gate_weight, + gate_quant, + gate_A, + gate_B, + gate_scale, + gate_magnitude, + dtype, + ) + gate = gate_mag_scale.unsqueeze(0) * (gate_base + gate_lora) + if gate_bias is not None: + gate = gate + gate_bias - Returns: - Output transformed by multi-layer perceptron and activation function - """ - # Compute projections - gate = matmul_lora( - X, gate_weight, gate_bias, gate_quant, gate_A, gate_B, gate_scale - ) - up = matmul_lora(X, up_weight, up_bias, up_quant, up_A, up_B, up_scale) + # Up with DoRA + up_base = matmul_lora(X, up_weight, None, up_quant, None, None, None) + up_lora = _lora_only(X_lora, up_A, up_B, up_scale, up_lora_bias, dtype) + up_mag_scale = _compute_dora_scale( + up_weight, up_quant, up_A, up_B, up_scale, up_magnitude, dtype + ) + up = up_mag_scale.unsqueeze(0) * (up_base + up_lora) + if up_bias is not None: + up = up + up_bias + + gate_combined = gate_base + gate_lora + up_combined = up_base + up_lora + else: + gate = matmul_lora( + X, + gate_weight, + gate_bias, + gate_quant, + gate_A, + gate_B, + gate_scale, + X_drop=X_drop, + lora_bias=gate_lora_bias, + ) + up = matmul_lora( + X, + up_weight, + up_bias, + up_quant, + up_A, + up_B, + up_scale, + X_drop=X_drop, + lora_bias=up_lora_bias, + ) # Activation hidden = activation_fn(gate, up) - # Down projection - output = matmul_lora( - hidden, down_weight, down_bias, down_quant, down_A, down_B, down_scale - ) + # Down projection (no dropout on hidden - it's an intermediate) + if has_dora: + down_base = matmul_lora( + hidden, down_weight, None, down_quant, None, None, None + ) + down_lora = _lora_only( + hidden, down_A, down_B, down_scale, down_lora_bias, dtype + ) + down_mag_scale = _compute_dora_scale( + down_weight, + down_quant, + down_A, + down_B, + down_scale, + down_magnitude, + dtype, + ) + down_combined = down_base + down_lora + output = down_mag_scale.unsqueeze(0) * down_combined + if down_bias is not None: + output = output + down_bias + else: + output = matmul_lora( + hidden, + down_weight, + down_bias, + down_quant, + down_A, + down_B, + down_scale, + lora_bias=down_lora_bias, + ) # Save for backward - ctx.save_for_backward(X, gate, up, gate_A, gate_B, up_A, up_B, down_A, down_B) + if has_dora: + ctx.save_for_backward( + X, + X_drop if has_dropout else X, + gate, + up, + gate_A.to(dtype) if gate_A is not None else gate_A, + 
gate_B.to(dtype) if gate_B is not None else gate_B, + up_A.to(dtype) if up_A is not None else up_A, + up_B.to(dtype) if up_B is not None else up_B, + down_A.to(dtype) if down_A is not None else down_A, + down_B.to(dtype) if down_B is not None else down_B, + gate_magnitude, + up_magnitude, + down_magnitude, + gate_mag_scale, + up_mag_scale, + down_mag_scale, + gate_combined, + up_combined, + down_combined, + gate_lora_bias, + up_lora_bias, + down_lora_bias, + ) + else: + # Pre-convert LoRA matrices to compute dtype for backward + dtype = X.dtype + ctx.save_for_backward( + X, + X_drop if has_dropout else X, + gate, + up, + gate_A.to(dtype) if gate_A is not None else gate_A, + gate_B.to(dtype) if gate_B is not None else gate_B, + up_A.to(dtype) if up_A is not None else up_A, + up_B.to(dtype) if up_B is not None else up_B, + down_A.to(dtype) if down_A is not None else down_A, + down_B.to(dtype) if down_B is not None else down_B, + gate_lora_bias, + up_lora_bias, + down_lora_bias, + ) + ctx.scales = (gate_scale, up_scale, down_scale) ctx.quants = (gate_quant, up_quant, down_quant) ctx.weights = (gate_weight, up_weight, down_weight) ctx.activation_fn = activation_fn ctx.activation_fn_backward = activation_fn_backward ctx.inplace = inplace + ctx.has_dropout = has_dropout + ctx.has_dora = has_dora return output @@ -219,171 +483,225 @@ class LoRA_MLP(torch.autograd.Function): def backward( ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor, - ) -> tuple[ - torch.Tensor | None, - None, - None, - None, - torch.Tensor | None, - torch.Tensor | None, - None, - None, - None, - None, - torch.Tensor | None, - torch.Tensor | None, - None, - None, - None, - None, - torch.Tensor | None, - torch.Tensor | None, - None, - None, - None, - None, - None, - ]: - """ - Performs backward pass computation for LoRA MLP. 
- - Args: - ctx: Context object storing tensors saved during forward pass - grad_output: Gradient of loss with respect to layer output - - Returns: - Tuple containing gradients for all inputs from forward pass: - - Input gradient tensor (or `None`) - - `None` for weights/biases/quantization states - - LoRA A/B matrix gradients (or `None`) - - `None` for scaling factors - - `None` for activation functions and flags - """ - ( - X, - gate, - up, - gate_A, - gate_B, - up_A, - up_B, - down_A, - down_B, - ) = ctx.saved_tensors + ): gate_scale, up_scale, down_scale = ctx.scales gate_quant, up_quant, down_quant = ctx.quants gate_weight, up_weight, down_weight = ctx.weights + has_dropout = ctx.has_dropout + has_dora = ctx.has_dora + + if has_dora: + ( + X, + X_lora, + gate, + up, + gate_A, + gate_B, + up_A, + up_B, + down_A, + down_B, + gate_magnitude, + up_magnitude, + down_magnitude, + gate_mag_scale, + up_mag_scale, + down_mag_scale, + gate_combined, + up_combined, + down_combined, + gate_lora_bias, + up_lora_bias, + down_lora_bias, + ) = ctx.saved_tensors + else: + ( + X, + X_lora, + gate, + up, + gate_A, + gate_B, + up_A, + up_B, + down_A, + down_B, + gate_lora_bias, + up_lora_bias, + down_lora_bias, + ) = ctx.saved_tensors + gate_magnitude = up_magnitude = down_magnitude = None + gate_mag_scale = up_mag_scale = down_mag_scale = None + gate_combined = up_combined = down_combined = None # Transpose all LoRA matrices - gate_A, gate_B = ( - gate_A.t() if gate_A is not None else None, - gate_B.t() if gate_B is not None else None, - ) - up_A, up_B = ( - up_A.t() if up_A is not None else None, - up_B.t() if up_B is not None else None, - ) - down_A, down_B = ( - down_A.t() if down_A is not None else None, - down_B.t() if down_B is not None else None, - ) + gate_A_t = gate_A.t() if gate_A is not None else None + gate_B_t = gate_B.t() if gate_B is not None else None + up_A_t = up_A.t() if up_A is not None else None + up_B_t = up_B.t() if up_B is not None else None + down_A_t = down_A.t() if down_A is not None else None + down_B_t = down_B.t() if down_B is not None else None # Reshape inputs batch, seq_len, hd = X.shape grad_output = grad_output.view(-1, grad_output.shape[-1]) X = X.view(-1, X.shape[-1]) + X_lora = X_lora.view(-1, X_lora.shape[-1]) gate = gate.view(-1, gate.shape[-1]) up = up.view(-1, up.shape[-1]) - dtype = X.dtype - # Down projection + # DoRA magnitude gradients for down projection + d_gate_mag = d_up_mag = d_down_mag = None + d_gate_lora_bias = d_up_lora_bias = d_down_lora_bias = None + + if has_dora: + down_combined_flat = down_combined.view(-1, down_combined.shape[-1]) + d_down_mag = ( + (grad_output * down_combined_flat).sum(dim=0) + * down_mag_scale + / down_magnitude + ) + grad_output = grad_output * down_mag_scale.unsqueeze(0) + + # Down lora bias gradient + if down_lora_bias is not None: + d_down_lora_bias = down_scale * grad_output.sum(dim=0) + + # Down projection backward grad_down = matmul_lora( grad_output, down_weight.t(), None, down_quant, - down_B, - down_A, + down_B_t, + down_A_t, down_scale, ) # Activation backward h, grad_gate, grad_up = ctx.activation_fn_backward(grad_down, gate, up) - # Initialize and compute LoRA gradients + # DoRA magnitude gradients for gate and up + if has_dora: + gate_combined_flat = gate_combined.view(-1, gate_combined.shape[-1]) + up_combined_flat = up_combined.view(-1, up_combined.shape[-1]) + d_gate_mag = ( + (grad_gate * gate_combined_flat).sum(dim=0) + * gate_mag_scale + / gate_magnitude + ) + d_up_mag = ( + (grad_up * 
up_combined_flat).sum(dim=0) * up_mag_scale / up_magnitude + ) + grad_gate = grad_gate * gate_mag_scale.unsqueeze(0) + grad_up = grad_up * up_mag_scale.unsqueeze(0) + + # LoRA bias gradients for gate and up + if gate_lora_bias is not None: + d_gate_lora_bias = gate_scale * grad_gate.sum(dim=0) + if up_lora_bias is not None: + d_up_lora_bias = up_scale * grad_up.sum(dim=0) + + # LoRA parameter gradients (already in compute dtype from forward) + # Compute grad @ B once per projection, reuse for dA and dX_lora + # Note: _t suffix means transposed from saved shape (A_t = A.t(), etc.) d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None + grad_B_up = grad_B_gate = None - if down_A is not None and down_B is not None: - d_down_A = h.t() @ (grad_output @ down_B.t()) - d_down_B = (down_A.t() @ h.t()) @ grad_output - d_down_A *= down_scale - d_down_B *= down_scale + if down_A_t is not None and down_B_t is not None: + grad_B_down = grad_output @ down_B_t.t() # reused in matmul_lora above too + d_down_A = torch.empty_like(down_A_t) + d_down_B = torch.empty_like(down_B_t) + d_down_A.addmm_(h.t(), grad_B_down, alpha=down_scale, beta=0) + d_down_B.addmm_(down_A_t.t() @ h.t(), grad_output, alpha=down_scale, beta=0) - if up_A is not None and up_B is not None: - d_up_A = X.t() @ (grad_up @ up_B.t()) - d_up_B = (up_A.t() @ X.t()) @ grad_up - d_up_A *= up_scale - d_up_B *= up_scale + if up_A_t is not None and up_B_t is not None: + grad_B_up = grad_up @ up_B_t.t() # [T, rank] — reuse for dX + d_up_A = torch.empty_like(up_A_t) + d_up_B = torch.empty_like(up_B_t) + d_up_A.addmm_(X_lora.t(), grad_B_up, alpha=up_scale, beta=0) + d_up_B.addmm_(up_A_t.t() @ X_lora.t(), grad_up, alpha=up_scale, beta=0) - if gate_A is not None and gate_B is not None: - d_gate_A = X.t() @ (grad_gate @ gate_B.t()) - d_gate_B = (gate_A.t() @ X.t()) @ grad_gate - d_gate_A *= gate_scale - d_gate_B *= gate_scale + if gate_A_t is not None and gate_B_t is not None: + grad_B_gate = grad_gate @ gate_B_t.t() # [T, rank] — reuse for dX + d_gate_A = torch.empty_like(gate_A_t) + d_gate_B = torch.empty_like(gate_B_t) + d_gate_A.addmm_(X_lora.t(), grad_B_gate, alpha=gate_scale, beta=0) + d_gate_B.addmm_( + gate_A_t.t() @ X_lora.t(), grad_gate, alpha=gate_scale, beta=0 + ) # Compute input gradients - dX = torch.zeros_like(X) if ctx.needs_input_grad[0] else None + dX = None + dX_drop = None - if dX is not None: - # Up projection gradients - up_weight = dequantize(up_weight.t(), up_quant) + if ctx.needs_input_grad[0]: + # Base path gradients through gate and up + up_weight_deq = dequantize(up_weight.t(), up_quant) if ctx.inplace: - dX = torch.matmul(grad_up, up_weight.t(), out=X) + dX = torch.matmul(grad_up, up_weight_deq.t(), out=X) else: - dX = torch.matmul(grad_up, up_weight.t()) - del up_weight + dX = torch.matmul(grad_up, up_weight_deq.t()) + del up_weight_deq - # Note the .to(dtype) only where mixing LoRA with base weights - if up_A is not None and up_B is not None: - dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t()) + gate_weight_deq = dequantize(gate_weight, gate_quant) + dX += grad_gate @ gate_weight_deq + del gate_weight_deq - # Gate projection gradients - gate_weight = dequantize(gate_weight, gate_quant) - dX += grad_gate @ gate_weight - del gate_weight + # LoRA path: reuse grad_B_up and grad_B_gate from above + if has_dropout: + dX_drop = torch.zeros_like(X_lora) + if grad_B_up is not None: + dX_drop.addmm_(grad_B_up, up_A_t.t(), alpha=up_scale) # type: ignore[union-attr] + if grad_B_gate is not None: + 
dX_drop.addmm_(grad_B_gate, gate_A_t.t(), alpha=gate_scale) # type: ignore[union-attr] + dX_drop = dX_drop.view(batch, seq_len, hd) + else: + if grad_B_up is not None: + dX.addmm_(grad_B_up, up_A_t.t(), alpha=up_scale) # type: ignore[union-attr] + if grad_B_gate is not None: + dX.addmm_(grad_B_gate, gate_A_t.t(), alpha=gate_scale) # type: ignore[union-attr] - if gate_A is not None and gate_B is not None: - dX += ( - grad_gate - @ gate_B.to(dtype).t() - @ (gate_scale * gate_A.to(dtype).t()) - ) - - # Reshape back dX = dX.view(batch, seq_len, hd) - # Return gradients in correct order matching forward inputs + # Return gradients matching forward input order: + # X, X_drop, + # gate: weight, bias, quant, A, B, scale, lora_bias, magnitude + # up: weight, bias, quant, A, B, scale, lora_bias, magnitude + # down: weight, bias, quant, A, B, scale, lora_bias, magnitude + # activation_fn, activation_fn_backward, inplace return ( dX, + dX_drop, + # Gate None, None, None, d_gate_A.t() if d_gate_A is not None else None, d_gate_B.t() if d_gate_B is not None else None, None, + d_gate_lora_bias, + d_gate_mag, + # Up None, None, None, d_up_A.t() if d_up_A is not None else None, d_up_B.t() if d_up_B is not None else None, None, + d_up_lora_bias, + d_up_mag, + # Down None, None, None, d_down_A.t() if d_down_A is not None else None, d_down_B.t() if d_down_B is not None else None, None, - None, + d_down_lora_bias, + d_down_mag, + # Activation fns and flags None, None, None, @@ -391,40 +709,54 @@ class LoRA_MLP(torch.autograd.Function): def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor: - """ - Applies LoRA to MLP layer with SwiGLU activation. + """Applies LoRA to MLP layer with SwiGLU activation. - Args: - X: Input tensor for the MLP layer - inplace: Whether to perform operations in-place to save memory - - Returns: - Output tensor after applying LoRA-adapted MLP with SwiGLU activation + Supports bias, dropout, and DoRA. """ - gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) - upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj) - downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) + gateW, gateb, gateW_quant, gateA, gateB, gateS, gateLB, gateDrop, gateMag = ( + get_lora_parameters(self.gate_proj) + ) + upW, upb, upW_quant, upA, upB, upS, upLB, upDrop, upMag = get_lora_parameters( + self.up_proj + ) + downW, downb, downW_quant, downA, downB, downS, downLB, downDrop, downMag = ( + get_lora_parameters(self.down_proj) + ) + + # Shared dropout mask for gate and up (same input) + X_drop = _apply_dropout(gateDrop, X, self.training) out = LoRA_MLP.apply( X, + X_drop, + # Gate gateW, gateb, gateW_quant, gateA, gateB, gateS, + gateLB, + gateMag, + # Up upW, upb, upW_quant, upA, upB, upS, + upLB, + upMag, + # Down downW, downb, downW_quant, downA, downB, downS, + downLB, + downMag, + # Activation and flags swiglu_forward, swiglu_backward, inplace, @@ -434,39 +766,53 @@ def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch. def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor: - """ - Applies LoRA to MLP layer with GEGLU activation. + """Applies LoRA to MLP layer with GEGLU activation. - Args: - X: Input tensor for the MLP layer - inplace: Whether to perform operations in-place to save memory - - Returns: - Output tensor after applying LoRA-adapted MLP with GEGLU activation + Supports bias, dropout, and DoRA. 
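+
+    A single dropout mask, drawn from gate_proj's lora_dropout module, is
+    shared by the gate and up projections since both consume the same input.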
""" - gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) - upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj) - downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) + gateW, gateb, gateW_quant, gateA, gateB, gateS, gateLB, gateDrop, gateMag = ( + get_lora_parameters(self.gate_proj) + ) + upW, upb, upW_quant, upA, upB, upS, upLB, upDrop, upMag = get_lora_parameters( + self.up_proj + ) + downW, downb, downW_quant, downA, downB, downS, downLB, downDrop, downMag = ( + get_lora_parameters(self.down_proj) + ) + + X_drop = _apply_dropout(gateDrop, X, self.training) + out = LoRA_MLP.apply( X, + X_drop, + # Gate gateW, gateb, gateW_quant, gateA, gateB, gateS, + gateLB, + gateMag, + # Up upW, upb, upW_quant, upA, upB, upS, + upLB, + upMag, + # Down downW, downb, downW_quant, downA, downB, downS, + downLB, + downMag, + # Activation and flags geglu_forward, geglu_backward, inplace, @@ -479,8 +825,8 @@ class LoRA_QKV(torch.autograd.Function): """ Optimized LoRA QKV implementation with quantization support. - Implements efficient computation of query, key, value projections with LoRA, - supporting quantization and memory optimization. + Supports bias, dropout, and DoRA (Weight-Decomposed Low-Rank Adaptation). + Dropout is applied outside this Function so autograd handles its backward. """ @staticmethod @@ -488,65 +834,160 @@ class LoRA_QKV(torch.autograd.Function): def forward( ctx: torch.autograd.function.FunctionCtx, X: torch.Tensor, + X_drop: torch.Tensor | None, + # Q params q_weight: torch.Tensor, q_bias: torch.Tensor | None, q_quant: QuantState | None, q_A: torch.Tensor | None, q_B: torch.Tensor | None, q_scale: float, + q_lora_bias: torch.Tensor | None, + q_magnitude: torch.Tensor | None, + # K params k_weight: torch.Tensor, k_bias: torch.Tensor | None, k_quant: QuantState | None, k_A: torch.Tensor | None, k_B: torch.Tensor | None, k_scale: float, + k_lora_bias: torch.Tensor | None, + k_magnitude: torch.Tensor | None, + # V params v_weight: torch.Tensor, v_bias: torch.Tensor | None, v_quant: QuantState | None, v_A: torch.Tensor | None, v_B: torch.Tensor | None, v_scale: float, + v_lora_bias: torch.Tensor | None, + v_magnitude: torch.Tensor | None, + # Flags inplace: bool = True, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Forward pass computing Q, K, V projections with LoRA. 
+ has_dropout = X_drop is not None + has_dora = q_magnitude is not None - Args: - ctx: Autograd context - X: Input tensor - q_weight: Query projection weight - q_bias: Query projection bias - q_quant: Query quantization state - q_A: Query LoRA A matrix - q_B: Query LoRA B matrix - q_scale: Query LoRA scale - k_weight: Key projection weight - k_bias: Key projection bias - k_quant: Key quantization state - k_A: Key LoRA A matrix - k_B: Key LoRA B matrix - k_scale: Key LoRA scale - v_weight: Value projection weight - v_bias: Value projection bias - v_quant: Value quantization state - v_A: Value LoRA A matrix - v_B: Value LoRA B matrix - v_scale: Value LoRA scale - inplace: Whether to perform operations in-place + if has_dora: + dtype = X.dtype + X_lora = X_drop if has_dropout else X - Returns: - Tuple of (Query, Key, Value) projection tensors - """ - Q = matmul_lora(X, q_weight, q_bias, q_quant, q_A, q_B, q_scale) - K = matmul_lora(X, k_weight, k_bias, k_quant, k_A, k_B, k_scale) - V = matmul_lora(X, v_weight, v_bias, v_quant, v_A, v_B, v_scale) + # Compute Q with DoRA + Q_base = matmul_lora(X, q_weight, None, q_quant, None, None, None) + Q_lora = _lora_only(X_lora, q_A, q_B, q_scale, q_lora_bias, dtype) + q_mag_scale = _compute_dora_scale( + q_weight, q_quant, q_A, q_B, q_scale, q_magnitude, dtype + ) + Q = q_mag_scale.unsqueeze(0) * (Q_base + Q_lora) + if q_bias is not None: + Q = Q + q_bias + + # Compute K with DoRA + K_base = matmul_lora(X, k_weight, None, k_quant, None, None, None) + K_lora = _lora_only(X_lora, k_A, k_B, k_scale, k_lora_bias, dtype) + k_mag_scale = _compute_dora_scale( + k_weight, k_quant, k_A, k_B, k_scale, k_magnitude, dtype + ) + K = k_mag_scale.unsqueeze(0) * (K_base + K_lora) + if k_bias is not None: + K = K + k_bias + + # Compute V with DoRA + V_base = matmul_lora(X, v_weight, None, v_quant, None, None, None) + V_lora = _lora_only(X_lora, v_A, v_B, v_scale, v_lora_bias, dtype) + v_mag_scale = _compute_dora_scale( + v_weight, v_quant, v_A, v_B, v_scale, v_magnitude, dtype + ) + V = v_mag_scale.unsqueeze(0) * (V_base + V_lora) + if v_bias is not None: + V = V + v_bias + + # Save for backward: need combined (base+lora) and mag_scale for DoRA grads + Q_combined = Q_base + Q_lora + K_combined = K_base + K_lora + V_combined = V_base + V_lora + + ctx.save_for_backward( + X, + X_drop if has_dropout else X, + q_A.to(dtype) if q_A is not None else q_A, + q_B.to(dtype) if q_B is not None else q_B, + k_A.to(dtype) if k_A is not None else k_A, + k_B.to(dtype) if k_B is not None else k_B, + v_A.to(dtype) if v_A is not None else v_A, + v_B.to(dtype) if v_B is not None else v_B, + q_magnitude, + k_magnitude, + v_magnitude, + q_mag_scale, + k_mag_scale, + v_mag_scale, + Q_combined, + K_combined, + V_combined, + q_lora_bias, + k_lora_bias, + v_lora_bias, + ) + else: + # Standard LoRA (with optional dropout and bias) + Q = matmul_lora( + X, + q_weight, + q_bias, + q_quant, + q_A, + q_B, + q_scale, + X_drop=X_drop, + lora_bias=q_lora_bias, + ) + K = matmul_lora( + X, + k_weight, + k_bias, + k_quant, + k_A, + k_B, + k_scale, + X_drop=X_drop, + lora_bias=k_lora_bias, + ) + V = matmul_lora( + X, + v_weight, + v_bias, + v_quant, + v_A, + v_B, + v_scale, + X_drop=X_drop, + lora_bias=v_lora_bias, + ) + + # Pre-convert LoRA matrices to compute dtype to avoid + # redundant fp32→bf16 conversion in backward + dtype = X.dtype + ctx.save_for_backward( + X, + X_drop if has_dropout else X, + q_A.to(dtype) if q_A is not None else q_A, + q_B.to(dtype) if q_B is not None else q_B, + k_A.to(dtype) if 
k_A is not None else k_A, + k_B.to(dtype) if k_B is not None else k_B, + v_A.to(dtype) if v_A is not None else v_A, + v_B.to(dtype) if v_B is not None else v_B, + q_lora_bias, + k_lora_bias, + v_lora_bias, + ) - ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B) ctx.scales = (q_scale, k_scale, v_scale) ctx.quants = (q_quant, k_quant, v_quant) ctx.weights = (q_weight, k_weight, v_weight) - ctx.biases = (q_bias, k_bias, v_bias) ctx.inplace = inplace + ctx.has_dropout = has_dropout + ctx.has_dora = has_dora return Q, K, V @@ -557,110 +998,169 @@ class LoRA_QKV(torch.autograd.Function): q_grad: torch.Tensor, k_grad: torch.Tensor, v_grad: torch.Tensor, - ) -> tuple[ - torch.Tensor, - None, - None, - None, - torch.Tensor | None, - torch.Tensor | None, - None, - None, - None, - None, - torch.Tensor | None, - torch.Tensor | None, - None, - None, - None, - None, - torch.Tensor | None, - torch.Tensor | None, - None, - None, - ]: - """ - Backward pass computing gradients for LoRA QKV. - - Args: - ctx: Autograd context - q_grad: Gradient for query projection - k_grad: Gradient for key projection - v_grad: Gradient for value projection - - Returns: - Tuple containing gradients for all forward inputs - """ - X, A_q, B_q, A_k, B_k, A_v, B_v = ctx.saved_tensors + ): q_weight, k_weight, v_weight = ctx.weights q_quant, k_quant, v_quant = ctx.quants q_scale, k_scale, v_scale = ctx.scales - dtype = X.dtype + has_dropout = ctx.has_dropout + has_dora = ctx.has_dora + + if has_dora: + ( + X, + X_lora, + A_q, + B_q, + A_k, + B_k, + A_v, + B_v, + q_magnitude, + k_magnitude, + v_magnitude, + q_mag_scale, + k_mag_scale, + v_mag_scale, + Q_combined, + K_combined, + V_combined, + q_lora_bias, + k_lora_bias, + v_lora_bias, + ) = ctx.saved_tensors + else: + ( + X, + X_lora, + A_q, + B_q, + A_k, + B_k, + A_v, + B_v, + q_lora_bias, + k_lora_bias, + v_lora_bias, + ) = ctx.saved_tensors + q_magnitude = k_magnitude = v_magnitude = None + q_mag_scale = k_mag_scale = v_mag_scale = None + Q_combined = K_combined = V_combined = None - # Reshape gradients batch, seq_len = X.shape[:2] q_grad = q_grad.view(-1, q_grad.shape[-1]) k_grad = k_grad.reshape(-1, k_grad.shape[-1]) v_grad = v_grad.view(-1, v_grad.shape[-1]) X = X.view(-1, X.shape[-1]) + X_lora = X_lora.view(-1, X_lora.shape[-1]) - # Pre-transpose X once - X_t = X.t() + # DoRA: scale gradients through mag_norm_scale + d_q_mag = d_k_mag = d_v_mag = None + d_q_lora_bias = d_k_lora_bias = d_v_lora_bias = None + + if has_dora: + Q_combined = Q_combined.view(-1, Q_combined.shape[-1]) + K_combined = K_combined.view(-1, K_combined.shape[-1]) + V_combined = V_combined.view(-1, V_combined.shape[-1]) + + # Magnitude gradients: d_mag = sum_t(grad * combined) / weight_norm + # Since mag_scale = magnitude / weight_norm, and weight_norm is detached: + # d_magnitude = sum_t(grad * combined) * (1 / weight_norm) + # But we have mag_scale = magnitude / weight_norm + # d_mag_scale_j = grad_j * combined_j (per element) + # d_magnitude_j = d_mag_scale_j / weight_norm_j = sum_t(grad * combined) / weight_norm + # Simpler: d_magnitude = sum(grad * combined, dim=0) * mag_scale / magnitude + # Actually: mag_scale = m/wn, d_output/d_m = combined/wn = combined * mag_scale/m + # d_m = sum(grad * combined * mag_scale / m, dim=0) ... no. 
+ # Let's be precise: output = mag_scale * combined, mag_scale = m / wn (wn detached) + # d_loss/d_m = d_loss/d_output * d_output/d_m = sum_t(grad_t * combined_t / wn) + # = sum_t(grad_t * combined_t) * (1/wn) = sum_t(grad_t * combined_t) * mag_scale / m + # Or just: d_m = sum(grad * combined, dim=0) / weight_norm + # Since we don't have weight_norm saved, use mag_scale/magnitude: + # 1/wn = mag_scale/magnitude + d_q_mag = (q_grad * Q_combined).sum(dim=0) * q_mag_scale / q_magnitude + d_k_mag = (k_grad * K_combined).sum(dim=0) * k_mag_scale / k_magnitude + d_v_mag = (v_grad * V_combined).sum(dim=0) * v_mag_scale / v_magnitude + + # Chain rule: grad through combined = grad * mag_scale + q_grad = q_grad * q_mag_scale.unsqueeze(0) + k_grad = k_grad * k_mag_scale.unsqueeze(0) + v_grad = v_grad * v_mag_scale.unsqueeze(0) + + # LoRA bias gradients + if q_lora_bias is not None: + d_q_lora_bias = q_scale * q_grad.sum(dim=0) + if k_lora_bias is not None: + d_k_lora_bias = k_scale * k_grad.sum(dim=0) + if v_lora_bias is not None: + d_v_lora_bias = v_scale * v_grad.sum(dim=0) + + # Pre-transpose X_lora for LoRA gradients + X_lora_t = X_lora.t() # Initialize LoRA gradients as None d_A_q = d_B_q = d_A_k = d_B_k = d_A_v = d_B_v = None - # Compute q path LoRA gradients if adapters exist + # Compute LoRA gradients using X_lora (before any inplace ops on X) + # A_q, B_q etc. are already in compute dtype (converted in forward) + # Key optimization: compute grad @ B once, reuse for both dA and dX_lora + # A has shape [rank, in], B has shape [out, rank] + grad_B_q = grad_B_k = grad_B_v = None + if A_q is not None and B_q is not None: - A_q_scaled = (q_scale * A_q).to(dtype) - B_q_scaled = B_q.to(dtype) - d_A_q = torch.mm(X_t, torch.mm(q_grad, B_q_scaled)) - d_B_q = torch.mm(torch.mm(A_q_scaled, X_t), q_grad) + grad_B_q = q_grad @ B_q # [T, rank] — reused for dA and dX + d_A_q = torch.empty_like(A_q.t()) + d_B_q = torch.empty_like(B_q.t()) + d_A_q.addmm_(X_lora_t, grad_B_q, alpha=q_scale, beta=0) + d_B_q.addmm_(A_q @ X_lora_t, q_grad, alpha=q_scale, beta=0) - # Compute k path LoRA gradients if adapters exist if A_k is not None and B_k is not None: - A_k_scaled = (k_scale * A_k).to(dtype) - B_k_scaled = B_k.to(dtype) - d_A_k = torch.mm(X_t, torch.mm(k_grad, B_k_scaled)) - d_B_k = torch.mm(torch.mm(A_k_scaled, X_t), k_grad) + grad_B_k = k_grad @ B_k + d_A_k = torch.empty_like(A_k.t()) + d_B_k = torch.empty_like(B_k.t()) + d_A_k.addmm_(X_lora_t, grad_B_k, alpha=k_scale, beta=0) + d_B_k.addmm_(A_k @ X_lora_t, k_grad, alpha=k_scale, beta=0) - # Compute v path LoRA gradients if adapters exist if A_v is not None and B_v is not None: - A_v_scaled = (v_scale * A_v).to(dtype) - B_v_scaled = B_v.to(dtype) - d_A_v = torch.mm(X_t, torch.mm(v_grad, B_v_scaled)) - d_B_v = torch.mm(torch.mm(A_v_scaled, X_t), v_grad) + grad_B_v = v_grad @ B_v + d_A_v = torch.empty_like(A_v.t()) + d_B_v = torch.empty_like(B_v.t()) + d_A_v.addmm_(X_lora_t, grad_B_v, alpha=v_scale, beta=0) + d_B_v.addmm_(A_v @ X_lora_t, v_grad, alpha=v_scale, beta=0) - # Compute input gradient, reusing X memory if possible + # Base path input gradient (can use inplace on X since X_lora refs are done) out_buffer = X if ctx.inplace else None - # Q path q_weight_t = dequantize(q_weight, q_quant) grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer) - del q_weight del q_weight_t - if A_q is not None and B_q is not None: - # Stay decomposed: dQ @ B^T gives [T, R], then [T, R] @ (s*A) gives [T, in] - # This is 65x fewer FLOPs than materializing B@A into [out, in] - 
grad_X.addmm_(torch.mm(q_grad, B_q_scaled), A_q_scaled) - # K path k_weight_t = dequantize(k_weight, k_quant) grad_X.addmm_(k_grad, k_weight_t) - del k_weight del k_weight_t - if A_k is not None and B_k is not None: - grad_X.addmm_(torch.mm(k_grad, B_k_scaled), A_k_scaled) - # V path v_weight_t = dequantize(v_weight, v_quant) grad_X.addmm_(v_grad, v_weight_t) - del v_weight del v_weight_t - if A_v is not None and B_v is not None: - grad_X.addmm_(torch.mm(v_grad, B_v_scaled), A_v_scaled) - # Transpose gradients if needed + # LoRA path input gradient: s * grad @ B @ A (reuses grad_B_* from above) + if has_dropout: + grad_X_drop = torch.zeros_like(X_lora) + if grad_B_q is not None: + grad_X_drop.addmm_(grad_B_q, A_q, alpha=q_scale) + if grad_B_k is not None: + grad_X_drop.addmm_(grad_B_k, A_k, alpha=k_scale) + if grad_B_v is not None: + grad_X_drop.addmm_(grad_B_v, A_v, alpha=v_scale) + else: + grad_X_drop = None + if grad_B_q is not None: + grad_X.addmm_(grad_B_q, A_q, alpha=q_scale) + if grad_B_k is not None: + grad_X.addmm_(grad_B_k, A_k, alpha=k_scale) + if grad_B_v is not None: + grad_X.addmm_(grad_B_v, A_v, alpha=v_scale) + + # Transpose LoRA gradients if d_A_q is not None: d_A_q = d_A_q.t() d_B_q = d_B_q.t() # type: ignore[union-attr] @@ -671,66 +1171,126 @@ class LoRA_QKV(torch.autograd.Function): d_A_v = d_A_v.t() d_B_v = d_B_v.t() # type: ignore[union-attr] + grad_X = grad_X.view(batch, seq_len, -1) + if grad_X_drop is not None: + grad_X_drop = grad_X_drop.view(batch, seq_len, -1) + + # Return gradients for all forward inputs: + # X, X_drop, + # q: weight, bias, quant, A, B, scale, lora_bias, magnitude + # k: weight, bias, quant, A, B, scale, lora_bias, magnitude + # v: weight, bias, quant, A, B, scale, lora_bias, magnitude + # inplace return ( - grad_X.view(batch, seq_len, -1), + grad_X, + grad_X_drop, + # Q None, None, None, d_A_q, d_B_q, None, + d_q_lora_bias, + d_q_mag, + # K None, None, None, d_A_k, d_B_k, None, + d_k_lora_bias, + d_k_mag, + # V None, None, None, d_A_v, d_B_v, None, + d_v_lora_bias, + d_v_mag, + # inplace None, ) +def _lora_only( + X: torch.Tensor, + A: torch.Tensor | None, + B: torch.Tensor | None, + s: float | None, + lora_bias: torch.Tensor | None, + dtype: torch.dtype, +) -> torch.Tensor: + """Compute only the LoRA contribution: s * X @ A^T @ B^T + s * lora_bias.""" + if A is None: + return torch.zeros( + X.shape[:-1] + (B.shape[0] if B is not None else 1,), + device=X.device, + dtype=dtype, + ) + reshape = False + if X.dim() == 3: + batch, seq_len, _ = X.shape + X = X.view(-1, X.shape[-1]) + reshape = True + At, Bt = A.t().to(dtype), B.t().to(dtype) # type: ignore[union-attr] + out = s * X @ At @ Bt + if lora_bias is not None: + out = out + s * lora_bias + return out.view(batch, seq_len, -1) if reshape else out + + def apply_lora_qkv( self, X: torch.Tensor, inplace: bool = True ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Applies LoRA to compute Query, Key, Value projections. - Args: - X: Input tensor - inplace: Whether to perform operations in-place - - Returns: - Tuple of (Query, Key, Value) projection tensors + Supports bias, dropout, and DoRA. Dropout is applied outside the autograd + Function so PyTorch handles its backward automatically. A single shared + dropout mask is used across Q, K, V projections for memory efficiency. 
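+    Note: upstream PEFT applies each projection's lora_dropout independently,
+    so sharing one mask is equivalent in expectation but not sample-for-sample
+    identical to the unpatched path.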
""" - QW, Qb, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj) - KW, Kb, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj) - VW, Vb, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj) + QW, Qb, QW_quant, QA, QB, QS, Qlb, Qdrop, Qmag = get_lora_parameters(self.q_proj) + KW, Kb, KW_quant, KA, KB, KS, Klb, Kdrop, Kmag = get_lora_parameters(self.k_proj) + VW, Vb, VW_quant, VA, VB, VS, Vlb, Vdrop, Vmag = get_lora_parameters(self.v_proj) + + # Apply dropout outside autograd.Function (shared mask for Q, K, V) + X_drop = _apply_dropout(Qdrop, X, self.training) + Q, K, V = LoRA_QKV.apply( X, + X_drop, + # Q QW, Qb, QW_quant, QA, QB, QS, + Qlb, + Qmag, + # K KW, Kb, KW_quant, KA, KB, KS, + Klb, + Kmag, + # V VW, Vb, VW_quant, VA, VB, VS, + Vlb, + Vmag, + # Flags inplace, ) @@ -738,43 +1298,75 @@ def apply_lora_qkv( class LoRA_O(torch.autograd.Function): - """Optimized LoRA implementation for output projection.""" + """Optimized LoRA implementation for output projection. + + Supports bias, dropout, and DoRA. + """ @staticmethod @torch_amp_custom_fwd def forward( ctx: torch.autograd.function.FunctionCtx, X: torch.Tensor, + X_drop: torch.Tensor | None, W: torch.Tensor, - b: torch.Tensor, + b: torch.Tensor | None, W_quant: QuantState | None, - A: torch.Tensor, - B: torch.Tensor, + A: torch.Tensor | None, + B: torch.Tensor | None, s: float, + lora_bias: torch.Tensor | None, + magnitude: torch.Tensor | None, ) -> torch.Tensor: - """ - Forward pass for output projection with LoRA. + has_dropout = X_drop is not None + has_dora = magnitude is not None + dtype = X.dtype - Args: - ctx: Autograd context - X: Input tensor - W: Output projection weight - b: Output projection bias - W_quant: Weight quantization state - A: LoRA A matrix - B: LoRA B matrix - s: LoRA scaling factor + if has_dora: + X_lora = X_drop if has_dropout else X + base_out = matmul_lora(X, W, None, W_quant, None, None, None) + lora_out = _lora_only(X_lora, A, B, s, lora_bias, dtype) + mag_scale = _compute_dora_scale(W, W_quant, A, B, s, magnitude, dtype) + combined = base_out + lora_out + XW = mag_scale.unsqueeze(0) * combined + if b is not None: + XW = XW + b - Returns: - Output projection result - """ - XW = matmul_lora(X, W, b, W_quant, A, B, s) - ctx.custom_saved_tensors = ( - W, - W_quant, - s, - ) - ctx.save_for_backward(A, B, X) + ctx.save_for_backward( + A.to(dtype) if A is not None else A, + B.to(dtype) if B is not None else B, + X, + X_drop if has_dropout else X, + magnitude, + mag_scale, + combined, + lora_bias, + ) + else: + XW = matmul_lora( + X, + W, + b, + W_quant, + A, + B, + s, + X_drop=X_drop, + lora_bias=lora_bias, + ) + # Pre-convert LoRA matrices to compute dtype for backward + dtype = X.dtype + ctx.save_for_backward( + A.to(dtype) if A is not None else A, + B.to(dtype) if B is not None else B, + X, + X_drop if has_dropout else X, + lora_bias, + ) + + ctx.custom_saved_tensors = (W, W_quant, s) + ctx.has_dropout = has_dropout + ctx.has_dora = has_dora return XW @@ -783,62 +1375,330 @@ class LoRA_O(torch.autograd.Function): def backward( ctx: torch.autograd.function.FunctionCtx, dY: torch.Tensor, - ) -> tuple[ - torch.Tensor, - None, - None, - None, - torch.Tensor, - torch.Tensor, - None, - ]: - """ - Backward pass computing gradients for LoRA output projection. 
- - Args: - ctx: Autograd context - dY: Gradient of loss with respect to output - - Returns: - Tuple containing gradients for all forward inputs - """ + ): W, W_quant, s = ctx.custom_saved_tensors - A, B, X = ctx.saved_tensors + has_dropout = ctx.has_dropout + has_dora = ctx.has_dora + + if has_dora: + A, B, X, X_lora, magnitude, mag_scale, combined, lora_bias = ( + ctx.saved_tensors + ) + else: + A, B, X, X_lora, lora_bias = ctx.saved_tensors + magnitude = mag_scale = combined = None batch, seq_len, hd = X.shape dY = dY.reshape(-1, dY.shape[-1]) X = X.reshape(-1, X.shape[-1]) - dtype = X.dtype + X_lora = X_lora.reshape(-1, X_lora.shape[-1]) - # Weight projection - dY_X = X.t() @ dY - d_A = s * dY_X @ B - d_B = s * A @ dY_X + d_mag = d_lora_bias = None - # Get derivative for dX - W = dequantize(W.t(), W_quant) - dX = dY @ W.t() - del W + if has_dora: + combined = combined.view(-1, combined.shape[-1]) + d_mag = (dY * combined).sum(dim=0) * mag_scale / magnitude + dY = dY * mag_scale.unsqueeze(0) - A, B = A.to(dtype), B.to(dtype) - # Stay decomposed: dY @ B gives [T, R], then [T, R] @ A gives [T, in] - dX.addmm_(torch.mm(dY, B), A, alpha=s) + # LoRA bias gradient + if lora_bias is not None: + d_lora_bias = s * dY.sum(dim=0) - # W, b, W_quant, A, B, s - return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None + # LoRA parameter gradients (A, B already in compute dtype from forward) + # Compute dY @ B once, reuse for both dA and dX_lora + d_A = d_B = None + grad_B = None + if A is not None: + grad_B = dY @ B # [T, rank] — reused below + X_lora_t = X_lora.t() + d_A = torch.empty_like(A.t()) + d_B = torch.empty_like(B.t()) + d_A.addmm_(X_lora_t, grad_B, alpha=s, beta=0) + d_B.addmm_(A @ X_lora_t, dY, alpha=s, beta=0) + + # Base path input gradient + W_deq = dequantize(W.t(), W_quant) + dX = dY @ W_deq.t() + del W_deq + + if has_dropout: + dX_drop = None + if grad_B is not None: + dX_drop = (grad_B @ A * s).view(batch, seq_len, hd) + else: + dX_drop = None + if grad_B is not None: + dX.addmm_(grad_B, A, alpha=s) + + # X, X_drop, W, b, W_quant, A, B, s, lora_bias, magnitude + return ( + dX.view(batch, seq_len, hd), + dX_drop, + None, + None, + None, + d_A.t() if d_A is not None else None, + d_B.t() if d_B is not None else None, + None, + d_lora_bias, + d_mag, + ) def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor: """ Applies LoRA to output projection layer. - Args: - X: Input tensor - - Returns: - Transformed output tensor + Supports bias, dropout, and DoRA. 
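+    As with the QKV path, dropout is applied outside the autograd Function
+    (via _apply_dropout) so PyTorch handles its backward automatically.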
""" - OW, Ob, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj) - output = LoRA_O.apply(X, OW, Ob, OW_quant, OA, OB, OS) + OW, Ob, OW_quant, OA, OB, OS, Olb, Odrop, Omag = get_lora_parameters(self.o_proj) + X_drop = _apply_dropout(Odrop, X, self.training) + output = LoRA_O.apply(X, X_drop, OW, Ob, OW_quant, OA, OB, OS, Olb, Omag) return output + + +# ============================================================ +# Embedding LoRA kernel +# ============================================================ + + +def get_embedding_lora_parameters( + embed: nn.Module, +) -> tuple[ + torch.Tensor, # W (base embedding weight) + torch.Tensor | None, # A (lora_embedding_A) + torch.Tensor | None, # B (lora_embedding_B) + float | None, # scaling + nn.Module | None, # dropout + torch.Tensor | None, # magnitude (DoRA) + nn.Module, # base_layer +]: + """Extract LoRA parameters from a PEFT Embedding module.""" + base_layer = embed.base_layer if hasattr(embed, "base_layer") else embed + W = base_layer.weight + + if not hasattr(embed, "disable_adapters") or embed.disable_adapters or embed.merged: + return W, None, None, None, None, None, base_layer + + active_adapter = ( + embed.active_adapters[0] + if hasattr(embed, "active_adapters") + else embed.active_adapter + ) + + A = embed.lora_embedding_A[active_adapter] # nn.Parameter [rank, vocab] + B = embed.lora_embedding_B[active_adapter] # nn.Parameter [hidden_dim, rank] + s = embed.scaling[active_adapter] + + # FSDP2 DTensor unshard (mirrors linear path logic) + if isinstance(A, DTensor): + A = A.full_tensor() + if isinstance(B, DTensor): + B = B.full_tensor() + + dropout = None + if hasattr(embed, "lora_dropout") and active_adapter in embed.lora_dropout: + dropout = embed.lora_dropout[active_adapter] + + magnitude = None + if ( + hasattr(embed, "lora_magnitude_vector") + and embed.lora_magnitude_vector + and active_adapter in embed.lora_magnitude_vector + ): + mag_layer = embed.lora_magnitude_vector[active_adapter] + magnitude = mag_layer.weight + if isinstance(magnitude, DTensor): + magnitude = magnitude.full_tensor() + + return W, A, B, s, dropout, magnitude, base_layer + + +class LoRA_Embedding(torch.autograd.Function): + """Fused LoRA embedding: F.embedding(x, W) + s * F.embedding(x, A^T) @ B^T. + + Supports dropout and DoRA. 
+ """ + + @staticmethod + @torch_amp_custom_fwd + def forward( + ctx, + x: torch.Tensor, + W: torch.Tensor, + A: torch.Tensor | None, + B: torch.Tensor | None, + s: float | None, + magnitude: torch.Tensor | None, + padding_idx: int | None, + # base_layer fields for F.embedding + max_norm: float | None, + norm_type: float, + scale_grad_by_freq: bool, + sparse: bool, + ) -> torch.Tensor: + import torch.nn.functional as F + + has_dora = magnitude is not None + dtype = W.dtype + + # Base embedding lookup + result = F.embedding( + x, + W, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + ) + + if A is not None: + # LoRA: F.embedding(x, A^T) @ B^T * s + A_T = A.t() # type: ignore[union-attr] # [vocab, rank] + B_T = B.t() # type: ignore[union-attr] # [rank, hidden_dim] + after_A = F.embedding( + x, + A_T, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + ) # [batch, seq, rank] + + lora_result = after_A @ B_T # [batch, seq, hidden_dim] + + if has_dora: + mag_scale = _compute_dora_scale(W.t(), None, A, B, s, magnitude, dtype) # type: ignore[arg-type] + # DoRA: mag_scale * (base + s * lora) + # base embedding has no bias + pre_scaled = result + s * lora_result # unscaled combined + result = mag_scale.unsqueeze(0) * pre_scaled + ctx.save_for_backward( + x, + A.to(dtype), + B.to(dtype), # type: ignore[union-attr] + after_A, + magnitude, + mag_scale, + pre_scaled, # save unscaled for correct d_mag + ) + else: + result = result + s * lora_result + ctx.save_for_backward(x, A.to(dtype), B.to(dtype), after_A) # type: ignore[union-attr] + else: + ctx.save_for_backward( + x, + ) + + ctx.s = s + ctx.has_dora = has_dora + ctx.has_lora = A is not None + ctx.padding_idx = padding_idx + ctx.scale_grad_by_freq = scale_grad_by_freq + + return result + + @staticmethod + @torch_amp_custom_bwd + def backward(ctx, grad_output): + s = ctx.s + has_dora = ctx.has_dora + has_lora = ctx.has_lora + + d_A = d_B = d_mag = None + + if not has_lora: + (x,) = ctx.saved_tensors + elif has_dora: + x, A, B, after_A, magnitude, mag_scale, combined = ctx.saved_tensors + # DoRA magnitude gradient + combined_flat = combined.view(-1, combined.shape[-1]) + grad_flat = grad_output.view(-1, grad_output.shape[-1]) + d_mag = (grad_flat * combined_flat).sum(dim=0) * mag_scale / magnitude + # Chain rule through mag_scale + grad_output = grad_output * mag_scale.unsqueeze(0).unsqueeze(0) + else: + x, A, B, after_A = ctx.saved_tensors + + if has_lora: + # Use float32 for gradient computation (LoRA params are fp32) + compute_dtype = torch.float32 + + after_A_flat = after_A.view(-1, after_A.shape[-1]).to(compute_dtype) + grad_flat = grad_output.view(-1, grad_output.shape[-1]).to(compute_dtype) + B_f = B.to(compute_dtype) + + # B is [hidden_dim, rank], B_T = B.t() = [rank, hidden_dim] + # lora_result = after_A @ B_T → d/d(B_T) = s * after_A^T @ grad + B_T = B_f.t() # [rank, hidden_dim] + d_B_T = torch.empty_like(B_T) + d_B_T.addmm_(after_A_flat.t(), grad_flat, alpha=s, beta=0) + d_B = d_B_T.t() # [hidden_dim, rank] + + # d_A: gradient flows through F.embedding lookup + # d_after_A = s * grad @ B = [T, hidden] @ [hidden, rank] = [T, rank] + d_after_A = s * grad_flat @ B_f + + # F.embedding backward: scatter d_after_A into A^T gradient + x_flat = x.view(-1) + + # Zero out padding_idx contributions (matches F.embedding behavior) + if ctx.padding_idx is not None: + pad_mask = x_flat != 
ctx.padding_idx + d_after_A = d_after_A * pad_mask.unsqueeze(1).to(d_after_A.dtype) + + # scale_grad_by_freq: divide each contribution by token frequency + if ctx.scale_grad_by_freq: + counts = torch.bincount(x_flat, minlength=A.shape[1]).clamp(min=1) + freq_scale = 1.0 / counts[x_flat].unsqueeze(1).to(d_after_A.dtype) + d_after_A = d_after_A * freq_scale + + A_f = A.to(compute_dtype) + d_A_T = torch.zeros_like(A_f.t()) # [vocab, rank] + d_A_T.index_add_(0, x_flat, d_after_A) + d_A = d_A_T.t() # [rank, vocab] + + # x, W, A, B, s, magnitude, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse + return ( + None, # x + None, # W (base embedding weight grad handled by PyTorch) + d_A, # A + d_B, # B + None, # s + d_mag, # magnitude + None, + None, + None, + None, + None, # padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse + ) + + +def apply_lora_embedding(self, x: torch.Tensor) -> torch.Tensor: + """Applies LoRA to embedding layer.""" + W, A, B, s, dropout, magnitude, base_layer = get_embedding_lora_parameters(self) + + # Capture base output dtype (bf16 for bf16 models) to cast back at end + output_dtype = W.dtype + + # Note: PEFT's Embedding forward does not apply dropout for embeddings + # (integer indices can't be dropped; PEFT silently ignores lora_dropout here) + result = LoRA_Embedding.apply( + x, + W, + A, + B, + s, + magnitude, + base_layer.padding_idx, + base_layer.max_norm, + base_layer.norm_type, + base_layer.scale_grad_by_freq, + base_layer.sparse, + ) + + # Cast to model dtype (LoRA ops may upcast to float32) + return result.to(output_dtype) diff --git a/src/axolotl/kernels/quantize.py b/src/axolotl/kernels/quantize.py index c9c0f59bd..ff564fecc 100644 --- a/src/axolotl/kernels/quantize.py +++ b/src/axolotl/kernels/quantize.py @@ -105,6 +105,10 @@ def dequantize( # Extract quantization state if not isinstance(quant_state, list): # New style quant_state class + # Non-double-quantized models have offset=None and state2=None + if quant_state.offset is None or quant_state.state2 is None: + # Fall back to bitsandbytes standard dequantize + return bnb.functional.dequantize_4bit(W, quant_state, quant_type="nf4") absmax = quant_state.absmax.to(target_device) shape = quant_state.shape dtype = quant_state.dtype diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 44be5267d..5bb3a32eb 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -12,6 +12,7 @@ from torch import nn from transformers import AutoConfig from axolotl.kernels.lora import ( + apply_lora_embedding, apply_lora_mlp_geglu, apply_lora_mlp_swiglu, apply_lora_o, @@ -370,13 +371,13 @@ def apply_lora_kernel_patches( active_adapter = model.active_adapter lora_config = model.model.peft_config[active_adapter] - # Only patch if conditions are met - can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none" - - if not can_patch: - LOG.warning("Cannot patch layers - requires no dropout and no bias") - LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file") - return model + # Log what features are active + if lora_config.lora_dropout > 0: + LOG.info(f"LoRA kernels: dropout={lora_config.lora_dropout} enabled") + if lora_config.bias != "none": + LOG.info(f"LoRA kernels: bias={lora_config.bias} enabled") + if lora_config.use_dora: + LOG.info("LoRA kernels: DoRA enabled") # This needs to be reset after patching original_level = LOG.getEffectiveLevel() @@ -419,44 +420,33 @@ def 
apply_lora_kernel_patches( for linear_proj in ["q_proj", "k_proj", "v_proj"] ] can_patch_qkv = all( - hasattr(module, "lora_A") - and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 - for module in layer_modules + hasattr(module, "lora_A") for module in layer_modules ) if can_patch_qkv: - # Add optimized implementation self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn) else: LOG.warning_once( - "Cannot patch some attention QKV projections - requires LoRA " - "adapters and no lora_magnitude_vector (DoRA)" + "Cannot patch some attention QKV projections - requires LoRA adapters" ) if cfg.lora_o_kernel: # Output patching layer_modules = [ getattr(self_attn, linear_proj) for linear_proj in ["o_proj"] ] - can_patch_o = all( - hasattr(module, "lora_A") - and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 - for module in layer_modules - ) + can_patch_o = all(hasattr(module, "lora_A") for module in layer_modules) if can_patch_o: self_attn.apply_o = types.MethodType(apply_lora_o, self_attn) else: LOG.warning_once( - "Cannot patch some attention output projection - requires LoRA " - "adapters and no lora_magnitude_vector (DoRA)" + "Cannot patch some attention output projection - requires LoRA adapters" ) for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer): if cfg.lora_mlp_kernel: # MLP patching can_patch_mlp = all( - hasattr(proj, "lora_A") - and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0 - for proj in (gate_proj, up_proj, down_proj) + hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj) ) if can_patch_mlp: @@ -464,15 +454,50 @@ def apply_lora_kernel_patches( layer.mlp.forward = types.MethodType(apply_fn, mlp) else: LOG.warning_once( - "Cannot patch some MLP layers - requires LoRA adapters and no " - "lora_magnitude_vector (DoRA)" + "Cannot patch some MLP layers - requires LoRA adapters" ) + # Patch embedding layers (model-level, not per-layer) + if cfg.lora_embedding_kernel: + _patch_embedding_layers(model, cfg) + LOG.setLevel(original_level) return model +def _patch_embedding_layers(model: PeftModelForCausalLM, cfg: DictDefault): + """Patch embedding layers with fused LoRA kernel. + + Handles both embed_tokens (nn.Embedding with lora_embedding_A/B) and + lm_head (nn.Linear with lora_A/B, used when tied embeddings are untied by PEFT). + """ + pretrained_model = model.model + patched = 0 + + # Find embedding modules - check common locations + for attr_path in [ + ("model", "embed_tokens"), + ("model", "language_model", "embed_tokens"), + ]: + parent = pretrained_model + for attr in attr_path: + parent = getattr(parent, attr, None) + if parent is None: + break + if parent is not None and hasattr(parent, "lora_embedding_A"): + LOG.info(f"Patching embedding layer: {'.'.join(attr_path)}") + parent.forward = types.MethodType(apply_lora_embedding, parent) + patched += 1 + + # lm_head with LoRA is a Linear layer - already handled by LoRA_O/LoRA_W kernels + # when included in target_modules. No special embedding handling needed since + # PEFT wraps it as a Linear (not Embedding) even for tied models. 
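+    #
+    # Embedding modules that live at attribute paths other than the ones
+    # checked above are left unpatched; they keep PEFT's standard LoRA
+    # Embedding forward (the debug log below fires only when nothing was
+    # patched).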
+ + if not patched: + LOG.debug("No embedding layers with LoRA found to patch") + + class FakeMLP(nn.Module): """ placeholder MLP for triton patching diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 67dea4958..2f269b78e 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -703,6 +703,12 @@ class AxolotlInputConfig( "description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html" }, ) + lora_embedding_kernel: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Apply custom LoRA autograd function for embedding layers. See: https://docs.axolotl.ai/docs/lora_optims.html" + }, + ) chunked_cross_entropy: bool | None = Field( default=None, @@ -1313,6 +1319,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): data.get("lora_mlp_kernel") or data.get("lora_qkv_kernel") or data.get("lora_o_kernel") + or data.get("lora_embedding_kernel") ): capabilities = data.get("capabilities") is_fsdp = data.get("fsdp_config") is not None @@ -1360,7 +1367,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): if data.get("adapter") in ["lora", "qlora"]: # Skip if already set, using unsloth optimizations, or using 8-bit unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"] - kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"] + kernel_fields = [ + "lora_mlp_kernel", + "lora_qkv_kernel", + "lora_o_kernel", + "lora_embedding_kernel", + ] if ( any(data.get(k) is not None for k in kernel_fields) or any(data.get(k) for k in unsloth_fields) @@ -1373,10 +1385,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): if data.get("trust_remote_code"): return data - # Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks - if data.get("lora_dropout") != 0: - return data - # Check multi-GPU compatibility capabilities = data.get("capabilities") is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1 @@ -1398,6 +1406,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): if data.get("lora_o_kernel") is None: data["lora_o_kernel"] = True + if data.get("lora_embedding_kernel") is None: + data["lora_embedding_kernel"] = True + LOG.warning( "Auto-enabling LoRA kernel optimizations for faster training. " + "Please explicitly set `lora_*_kernel` config values to `false` to disable. " diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 8ff61b370..c902d8703 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -681,15 +681,7 @@ class LoRAValidationMixin: @model_validator(mode="before") @classmethod def check_lora_kernels_dora(cls, data): - if ( - data.get("lora_mlp_kernel") - or data.get("lora_qkv_kernel") - or data.get("lora_o_kernel") - ) and data.get("peft_use_dora"): - raise ValueError( - "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not " - "compatible with DoRA at the moment." 
- ) + # DoRA is now supported by lora kernels return data @model_validator(mode="before") diff --git a/tests/core/test_async_grpo.py b/tests/core/test_async_grpo.py index eb83be1b6..14c38df29 100644 --- a/tests/core/test_async_grpo.py +++ b/tests/core/test_async_grpo.py @@ -153,7 +153,7 @@ class TestLoraFP8Guard(unittest.TestCase): proj.base_layer = base_layer - W, b, quant_state, A, B, s = get_lora_parameters(proj) + W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj) # quant_state should be None since weight is bf16, not FP8 self.assertIsNone(quant_state) @@ -174,7 +174,7 @@ class TestLoraFP8Guard(unittest.TestCase): scale_inv = torch.ones(1) base_layer.weight_scale_inv = scale_inv - W, b, quant_state, A, B, s = get_lora_parameters(proj) + W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj) self.assertIs(quant_state, scale_inv) diff --git a/tests/e2e/kernels/test_lora.py b/tests/e2e/kernels/test_lora.py index 9baceb668..568524557 100644 --- a/tests/e2e/kernels/test_lora.py +++ b/tests/e2e/kernels/test_lora.py @@ -102,7 +102,7 @@ def mock_proj(): def test_get_lora_parameters(mock_proj): """Tests get_lora_parameters function""" # Test with LoRA enabled - W, b, _, A, B, s = get_lora_parameters(mock_proj) + W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj) assert isinstance(W, torch.Tensor) assert W.shape == (128, 64) @@ -113,13 +113,13 @@ def test_get_lora_parameters(mock_proj): # Test with LoRA disabled mock_proj.disable_adapters = True - W, b, _, A, B, s = get_lora_parameters(mock_proj) + W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj) assert A is None and B is None and s is None # Test with merged state mock_proj.disable_adapters = False mock_proj.merged = True - W, b, _, A, B, s = get_lora_parameters(mock_proj) + W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj) assert A is None and B is None and s is None diff --git a/tests/e2e/kernels/test_lora_features.py b/tests/e2e/kernels/test_lora_features.py new file mode 100644 index 000000000..80495c68d --- /dev/null +++ b/tests/e2e/kernels/test_lora_features.py @@ -0,0 +1,1245 @@ +""" +Tests for LoRA kernel correctness with bias, dropout, and DoRA support. + +Compares fused kernel outputs and gradients against PEFT's reference implementation. 
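+
+Covered combinations: bias="lora_only", non-zero lora_dropout (eval-mode
+parity), DoRA (use_dora=True), LoRA on embed_tokens (including tied
+embed_tokens/lm_head), NF4-quantized base weights, the Triton DoRA scale
+kernel, and regression cases for the review fixes noted below.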
+""" + +import pytest +import torch +from peft import LoraConfig, get_peft_model +from torch import nn +from transformers import AutoConfig, AutoModelForCausalLM + +from axolotl.kernels.lora import ( + _compute_dora_scale, + apply_lora_mlp_swiglu, + apply_lora_o, + apply_lora_qkv, + get_lora_parameters, + matmul_lora, +) +from axolotl.monkeypatch.lora_kernels import ( + apply_lora_kernel_patches, + patch_self_attn_lora, +) +from axolotl.utils.dict import DictDefault + +MODEL_NAME = "Qwen/Qwen3-0.6B" +DEVICE = "cuda" +DTYPE = torch.bfloat16 + + +@pytest.fixture(scope="module") +def model_config(): + return AutoConfig.from_pretrained(MODEL_NAME) + + +def _make_peft_model( + lora_dropout=0.0, + bias="none", + use_dora=False, + target_modules=None, +): + """Create a PEFT model with given config.""" + if target_modules is None: + target_modules = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=DTYPE, + attn_implementation="eager", + ).to(DEVICE) + lora_config = LoraConfig( + r=8, + lora_alpha=16, + lora_dropout=lora_dropout, + bias=bias, + use_dora=use_dora, + target_modules=target_modules, + ) + peft_model = get_peft_model(model, lora_config) + return peft_model + + +def _get_layer(peft_model, layer_idx=0): + """Get a specific transformer layer from the model.""" + return peft_model.model.model.layers[layer_idx] + + +def _make_input(batch=2, seq_len=16, hidden_size=1024): + """Create random input tensor.""" + return torch.randn( + batch, seq_len, hidden_size, dtype=DTYPE, device=DEVICE, requires_grad=True + ) + + +def _compare_tensors(a, b, name="", atol=1e-2, rtol=1e-2): + """Compare two tensors with informative error messages.""" + if a is None and b is None: + return + assert a is not None and b is not None, f"{name}: one is None, other is not" + assert a.shape == b.shape, f"{name}: shape mismatch {a.shape} vs {b.shape}" + diff = (a - b).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + assert torch.allclose(a, b, atol=atol, rtol=rtol), ( + f"{name}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}" + ) + + +class TestGetLoraParameters: + """Test the extended get_lora_parameters function.""" + + def test_returns_9_values(self): + model = _make_peft_model() + layer = _get_layer(model) + params = get_lora_parameters(layer.self_attn.q_proj) + assert len(params) == 9 + W, b, quant, A, B, s, lora_bias, dropout, magnitude = params + assert W is not None + assert A is not None + assert B is not None + assert s is not None + assert lora_bias is None # bias="none" + assert dropout is not None # should be nn.Identity + assert magnitude is None # no DoRA + del model + + def test_with_bias(self): + """Qwen3 has no base bias, so PEFT doesn't add lora_bias even with bias='lora_only'. 
+ This test verifies get_lora_parameters handles this correctly.""" + model = _make_peft_model(bias="lora_only") + layer = _get_layer(model) + params = get_lora_parameters(layer.self_attn.q_proj) + _, _, _, _, _, _, lora_bias, _, _ = params + # Qwen3 q_proj has no base bias, so PEFT sets lora_bias=False + assert lora_bias is None + del model + + def test_with_bias_on_biased_layer(self): + """Test with manually added bias to verify lora_bias extraction.""" + model = _make_peft_model(bias="lora_only") + layer = _get_layer(model) + q_proj = layer.self_attn.q_proj + adapter = q_proj.active_adapters[0] + # Manually add bias to lora_B to test extraction + old_B = q_proj.lora_B[adapter] + q_proj.lora_B[adapter] = torch.nn.Linear( + old_B.in_features, old_B.out_features, bias=True, device=DEVICE, dtype=DTYPE + ) + params = get_lora_parameters(q_proj) + _, _, _, _, _, _, lora_bias, _, _ = params + assert lora_bias is not None + assert lora_bias.shape[0] == old_B.out_features + del model + + def test_with_dropout(self): + model = _make_peft_model(lora_dropout=0.1) + layer = _get_layer(model) + params = get_lora_parameters(layer.self_attn.q_proj) + _, _, _, _, _, _, _, dropout, _ = params + assert dropout is not None + assert isinstance(dropout, nn.Dropout) + del model + + def test_with_dora(self): + model = _make_peft_model(use_dora=True) + layer = _get_layer(model) + params = get_lora_parameters(layer.self_attn.q_proj) + _, _, _, _, _, _, _, _, magnitude = params + assert magnitude is not None + del model + + +class TestMatmulLora: + """Test matmul_lora with new lora_bias and X_drop parameters.""" + + def test_basic(self): + X = torch.randn(4, 8, dtype=DTYPE, device=DEVICE) + W = torch.randn(16, 8, dtype=DTYPE, device=DEVICE) + A = torch.randn(4, 8, dtype=DTYPE, device=DEVICE) # [rank, in] + B = torch.randn(16, 4, dtype=DTYPE, device=DEVICE) # [out, rank] + s = 2.0 + + result = matmul_lora(X, W, None, None, A, B, s) + expected = X @ W.t() + s * X @ A.t() @ B.t() + _compare_tensors(result, expected, "basic matmul_lora") + + def test_with_lora_bias(self): + X = torch.randn(4, 8, dtype=DTYPE, device=DEVICE) + W = torch.randn(16, 8, dtype=DTYPE, device=DEVICE) + A = torch.randn(4, 8, dtype=DTYPE, device=DEVICE) + B = torch.randn(16, 4, dtype=DTYPE, device=DEVICE) + lora_bias = torch.randn(16, dtype=DTYPE, device=DEVICE) + s = 2.0 + + result = matmul_lora(X, W, None, None, A, B, s, lora_bias=lora_bias) + expected = X @ W.t() + s * X @ A.t() @ B.t() + s * lora_bias + _compare_tensors(result, expected, "matmul_lora with lora_bias") + + def test_with_x_drop(self): + X = torch.randn(4, 8, dtype=DTYPE, device=DEVICE) + X_drop = X * 0.5 # simulated dropout + W = torch.randn(16, 8, dtype=DTYPE, device=DEVICE) + A = torch.randn(4, 8, dtype=DTYPE, device=DEVICE) + B = torch.randn(16, 4, dtype=DTYPE, device=DEVICE) + s = 2.0 + + result = matmul_lora(X, W, None, None, A, B, s, X_drop=X_drop) + expected = X @ W.t() + s * X_drop @ A.t() @ B.t() + _compare_tensors(result, expected, "matmul_lora with X_drop") + + +class TestDoraScale: + """Test DoRA magnitude/norm scaling computation.""" + + def test_basic(self): + W = torch.randn(16, 8, dtype=DTYPE, device=DEVICE) + A = torch.randn(4, 8, dtype=DTYPE, device=DEVICE) + B = torch.randn(16, 4, dtype=DTYPE, device=DEVICE) + magnitude = torch.randn(16, dtype=DTYPE, device=DEVICE).abs() + 0.1 + s = 2.0 + + scale = _compute_dora_scale(W, None, A, B, s, magnitude, DTYPE) + + # Manual computation + combined = W + s * B @ A + weight_norm = torch.linalg.norm(combined, dim=1) + 
expected = magnitude / weight_norm + + _compare_tensors(scale, expected, "dora_scale") + + +# ============================================================ +# Integration tests: compare kernel outputs against PEFT reference +# ============================================================ + + +def _run_peft_qkv(layer, X): + """Run Q, K, V projections through PEFT's standard forward.""" + Q = layer.self_attn.q_proj(X) + K = layer.self_attn.k_proj(X) + V = layer.self_attn.v_proj(X) + return Q, K, V + + +def _run_kernel_qkv(layer, X): + """Run Q, K, V projections through our fused kernel.""" + return apply_lora_qkv(layer.self_attn, X, inplace=False) + + +def _run_peft_o(layer, X): + """Run O projection through PEFT's standard forward.""" + return layer.self_attn.o_proj(X) + + +def _run_kernel_o(layer, X): + """Run O projection through our fused kernel.""" + return apply_lora_o(layer.self_attn, X) + + +def _run_peft_mlp(layer, X): + """Run MLP through PEFT's standard forward.""" + return layer.mlp(X) + + +def _run_kernel_mlp(layer, X): + """Run MLP through our fused kernel.""" + return apply_lora_mlp_swiglu(layer.mlp, X, inplace=False) + + +class TestQKVKernel: + """Test LoRA_QKV kernel against PEFT reference.""" + + @pytest.mark.parametrize("bias", ["none", "lora_only"]) + def test_forward_bias(self, bias): + model = _make_peft_model(bias=bias) + model.eval() + layer = _get_layer(model) + X = _make_input(hidden_size=model.config.hidden_size) + + with torch.no_grad(): + peft_Q, peft_K, peft_V = _run_peft_qkv(layer, X) + kern_Q, kern_K, kern_V = _run_kernel_qkv(layer, X) + + _compare_tensors(kern_Q, peft_Q, f"QKV Q (bias={bias})") + _compare_tensors(kern_K, peft_K, f"QKV K (bias={bias})") + _compare_tensors(kern_V, peft_V, f"QKV V (bias={bias})") + del model + + def test_forward_dropout_eval(self): + """Dropout disabled in eval - should match exactly.""" + model = _make_peft_model(lora_dropout=0.1) + model.eval() + layer = _get_layer(model) + X = _make_input(hidden_size=model.config.hidden_size) + + with torch.no_grad(): + peft_Q, peft_K, peft_V = _run_peft_qkv(layer, X) + kern_Q, kern_K, kern_V = _run_kernel_qkv(layer, X) + + _compare_tensors(kern_Q, peft_Q, "QKV Q (dropout eval)") + _compare_tensors(kern_K, peft_K, "QKV K (dropout eval)") + _compare_tensors(kern_V, peft_V, "QKV V (dropout eval)") + del model + + def test_forward_dora(self): + model = _make_peft_model(use_dora=True) + model.eval() + layer = _get_layer(model) + X = _make_input(hidden_size=model.config.hidden_size) + + with torch.no_grad(): + peft_Q, peft_K, peft_V = _run_peft_qkv(layer, X) + kern_Q, kern_K, kern_V = _run_kernel_qkv(layer, X) + + _compare_tensors(kern_Q, peft_Q, "QKV Q (DoRA)") + _compare_tensors(kern_K, peft_K, "QKV K (DoRA)") + _compare_tensors(kern_V, peft_V, "QKV V (DoRA)") + del model + + def test_forward_dora_bias(self): + model = _make_peft_model(use_dora=True, bias="lora_only") + model.eval() + layer = _get_layer(model) + X = _make_input(hidden_size=model.config.hidden_size) + + with torch.no_grad(): + peft_Q, peft_K, peft_V = _run_peft_qkv(layer, X) + kern_Q, kern_K, kern_V = _run_kernel_qkv(layer, X) + + _compare_tensors(kern_Q, peft_Q, "QKV Q (DoRA+bias)") + _compare_tensors(kern_K, peft_K, "QKV K (DoRA+bias)") + _compare_tensors(kern_V, peft_V, "QKV V (DoRA+bias)") + del model + + @pytest.mark.parametrize("bias", ["none", "lora_only"]) + def test_backward_bias(self, bias): + """Test that gradients match between kernel and PEFT.""" + model = _make_peft_model(bias=bias) + model.train() + layer = 
_get_layer(model) + + # PEFT reference + X1 = _make_input(hidden_size=model.config.hidden_size) + pQ, pK, pV = _run_peft_qkv(layer, X1) + loss_peft = pQ.sum() + pK.sum() + pV.sum() + loss_peft.backward() + + peft_grads = {} + for name, param in layer.self_attn.named_parameters(): + if param.grad is not None: + peft_grads[name] = param.grad.clone() + layer.self_attn.zero_grad() + + # Kernel + X2 = X1.detach().clone().requires_grad_(True) + kQ, kK, kV = _run_kernel_qkv(layer, X2) + loss_kern = kQ.sum() + kK.sum() + kV.sum() + loss_kern.backward() + + kern_grads = {} + for name, param in layer.self_attn.named_parameters(): + if param.grad is not None: + kern_grads[name] = param.grad.clone() + layer.self_attn.zero_grad() + + # Compare LoRA parameter gradients + for name in peft_grads: + if "lora_" in name: + _compare_tensors( + kern_grads.get(name), + peft_grads[name], + f"grad {name} (bias={bias})", + atol=5e-2, + rtol=5e-2, + ) + del model + + def test_backward_dora(self): + """Test DoRA backward pass gradients.""" + model = _make_peft_model(use_dora=True) + model.train() + layer = _get_layer(model) + + X1 = _make_input(hidden_size=model.config.hidden_size) + pQ, pK, pV = _run_peft_qkv(layer, X1) + loss_peft = pQ.sum() + pK.sum() + pV.sum() + loss_peft.backward() + + peft_grads = {} + for name, param in layer.self_attn.named_parameters(): + if param.grad is not None: + peft_grads[name] = param.grad.clone() + layer.self_attn.zero_grad() + + X2 = X1.detach().clone().requires_grad_(True) + kQ, kK, kV = _run_kernel_qkv(layer, X2) + loss_kern = kQ.sum() + kK.sum() + kV.sum() + loss_kern.backward() + + kern_grads = {} + for name, param in layer.self_attn.named_parameters(): + if param.grad is not None: + kern_grads[name] = param.grad.clone() + layer.self_attn.zero_grad() + + for name in peft_grads: + if "lora_" in name or "magnitude" in name: + _compare_tensors( + kern_grads.get(name), + peft_grads[name], + f"grad {name} (DoRA)", + atol=5e-2, + rtol=5e-2, + ) + del model + + +class TestOKernel: + """Test LoRA_O kernel against PEFT reference.""" + + @staticmethod + def _o_input_dim(model): + """o_proj input is num_heads * head_dim (may differ from hidden_size with GQA).""" + cfg = model.config + text_cfg = cfg.get_text_config() if hasattr(cfg, "get_text_config") else cfg + return text_cfg.num_attention_heads * text_cfg.head_dim + + @pytest.mark.parametrize("bias", ["none", "lora_only"]) + def test_forward_bias(self, bias): + model = _make_peft_model(bias=bias) + model.eval() + layer = _get_layer(model) + X = _make_input(hidden_size=self._o_input_dim(model)) + + with torch.no_grad(): + peft_out = _run_peft_o(layer, X) + kern_out = _run_kernel_o(layer, X) + + _compare_tensors(kern_out, peft_out, f"O (bias={bias})") + del model + + def test_forward_dora(self): + model = _make_peft_model(use_dora=True) + model.eval() + layer = _get_layer(model) + X = _make_input(hidden_size=self._o_input_dim(model)) + + with torch.no_grad(): + peft_out = _run_peft_o(layer, X) + kern_out = _run_kernel_o(layer, X) + + _compare_tensors(kern_out, peft_out, "O (DoRA)") + del model + + @pytest.mark.parametrize("bias", ["none", "lora_only"]) + def test_backward_bias(self, bias): + model = _make_peft_model(bias=bias) + model.train() + layer = _get_layer(model) + + X1 = _make_input(hidden_size=self._o_input_dim(model)) + peft_out = _run_peft_o(layer, X1) + peft_out.sum().backward() + peft_grads = { + n: p.grad.clone() + for n, p in layer.self_attn.o_proj.named_parameters() + if p.grad is not None + } + 
layer.self_attn.o_proj.zero_grad() + + X2 = X1.detach().clone().requires_grad_(True) + kern_out = _run_kernel_o(layer, X2) + kern_out.sum().backward() + kern_grads = { + n: p.grad.clone() + for n, p in layer.self_attn.o_proj.named_parameters() + if p.grad is not None + } + layer.self_attn.o_proj.zero_grad() + + for name in peft_grads: + if "lora_" in name: + _compare_tensors( + kern_grads.get(name), + peft_grads[name], + f"O grad {name} (bias={bias})", + atol=5e-2, + rtol=5e-2, + ) + del model + + +class TestMLPKernel: + """Test LoRA_MLP kernel against PEFT reference.""" + + @pytest.mark.parametrize("bias", ["none", "lora_only"]) + def test_forward_bias(self, bias): + model = _make_peft_model(bias=bias) + model.eval() + layer = _get_layer(model) + X = _make_input(hidden_size=model.config.hidden_size) + + with torch.no_grad(): + peft_out = _run_peft_mlp(layer, X) + kern_out = _run_kernel_mlp(layer, X) + + _compare_tensors(kern_out, peft_out, f"MLP (bias={bias})") + del model + + def test_forward_dropout_eval(self): + model = _make_peft_model(lora_dropout=0.1) + model.eval() + layer = _get_layer(model) + X = _make_input(hidden_size=model.config.hidden_size) + + with torch.no_grad(): + peft_out = _run_peft_mlp(layer, X) + kern_out = _run_kernel_mlp(layer, X) + + _compare_tensors(kern_out, peft_out, "MLP (dropout eval)") + del model + + def test_forward_dora(self): + model = _make_peft_model(use_dora=True) + model.eval() + layer = _get_layer(model) + X = _make_input(hidden_size=model.config.hidden_size) + + with torch.no_grad(): + peft_out = _run_peft_mlp(layer, X) + kern_out = _run_kernel_mlp(layer, X) + + # Relaxed tolerance for MLP DoRA: 3 projections + activation + DoRA + # causes bf16 accumulation differences + _compare_tensors(kern_out, peft_out, "MLP (DoRA)", atol=0.3, rtol=0.05) + del model + + def test_forward_dora_bias(self): + model = _make_peft_model(use_dora=True, bias="lora_only") + model.eval() + layer = _get_layer(model) + X = _make_input(hidden_size=model.config.hidden_size) + + with torch.no_grad(): + peft_out = _run_peft_mlp(layer, X) + kern_out = _run_kernel_mlp(layer, X) + + _compare_tensors(kern_out, peft_out, "MLP (DoRA+bias)", atol=0.3, rtol=0.05) + del model + + @pytest.mark.parametrize("bias", ["none", "lora_only"]) + def test_backward_bias(self, bias): + model = _make_peft_model(bias=bias) + model.train() + layer = _get_layer(model) + hidden_size = model.config.hidden_size + + X1 = _make_input(hidden_size=hidden_size) + peft_out = _run_peft_mlp(layer, X1) + peft_out.sum().backward() + peft_grads = { + n: p.grad.clone() + for n, p in layer.mlp.named_parameters() + if p.grad is not None + } + layer.mlp.zero_grad() + + X2 = X1.detach().clone().requires_grad_(True) + kern_out = _run_kernel_mlp(layer, X2) + kern_out.sum().backward() + kern_grads = { + n: p.grad.clone() + for n, p in layer.mlp.named_parameters() + if p.grad is not None + } + layer.mlp.zero_grad() + + # MLP backward has longer chain (3 projections + activation) = more bf16 accumulation error + for name in peft_grads: + if "lora_" in name: + _compare_tensors( + kern_grads.get(name), + peft_grads[name], + f"MLP grad {name} (bias={bias})", + atol=0.5, + rtol=0.1, + ) + del model + + def test_backward_dora(self): + model = _make_peft_model(use_dora=True) + model.train() + layer = _get_layer(model) + + X1 = _make_input(hidden_size=model.config.hidden_size) + peft_out = _run_peft_mlp(layer, X1) + peft_out.sum().backward() + peft_grads = { + n: p.grad.clone() + for n, p in layer.mlp.named_parameters() + if p.grad 
is not None + } + layer.mlp.zero_grad() + + X2 = X1.detach().clone().requires_grad_(True) + kern_out = _run_kernel_mlp(layer, X2) + kern_out.sum().backward() + kern_grads = { + n: p.grad.clone() + for n, p in layer.mlp.named_parameters() + if p.grad is not None + } + layer.mlp.zero_grad() + + for name in peft_grads: + if "lora_" in name or "magnitude" in name: + _compare_tensors( + kern_grads.get(name), + peft_grads[name], + f"MLP grad {name} (DoRA)", + atol=0.5, + rtol=0.1, + ) + del model + + +class TestFullModelPatch: + """Test applying kernel patches to a full model.""" + + def test_patched_forward_basic(self): + """Test that patched model forward matches unpatched PEFT model (bias=none, no DoRA).""" + from peft import PeftModelForCausalLM + + base_model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=DTYPE, + attn_implementation="eager", + ).to(DEVICE) + lora_config = LoraConfig( + r=8, + lora_alpha=16, + bias="none", + use_dora=False, + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + ) + model = PeftModelForCausalLM(base_model, lora_config) + model.eval() + + # Get PEFT reference output + input_ids = torch.randint(0, 1000, (1, 32), device=DEVICE) + with torch.no_grad(): + peft_out = model(input_ids).logits + + # Apply kernel patches + cfg = DictDefault( + { + "base_model": MODEL_NAME, + "lora_qkv_kernel": True, + "lora_o_kernel": True, + "lora_mlp_kernel": True, + } + ) + patch_self_attn_lora(cfg) + apply_lora_kernel_patches(model, cfg) + + # Get kernel output + with torch.no_grad(): + kern_out = model(input_ids).logits + + _compare_tensors(kern_out, peft_out, "Full model (basic)", atol=5e-1, rtol=1e-1) + del model + + +class TestEmbeddingKernel: + """Test LoRA embedding kernel against PEFT reference.""" + + def _make_embedding_model(self, use_dora=False): + from peft import PeftModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=DTYPE, + attn_implementation="eager", + ).to(DEVICE) + lora_config = LoraConfig( + r=8, + lora_alpha=16, + use_dora=use_dora, + target_modules=["embed_tokens"], + ) + return PeftModelForCausalLM(model, lora_config) + + def test_forward_basic(self): + from axolotl.kernels.lora import apply_lora_embedding + + model = self._make_embedding_model() + model.eval() + + embed = model.model.model.embed_tokens + input_ids = torch.randint(0, 1000, (2, 16), device=DEVICE) + + with torch.no_grad(): + peft_out = embed(input_ids) + kern_out = apply_lora_embedding(embed, input_ids) + + # Cast to same dtype for comparison (PEFT may return float32) + _compare_tensors(kern_out.to(peft_out.dtype), peft_out, "Embedding basic") + del model + + def test_forward_dora(self): + from axolotl.kernels.lora import apply_lora_embedding + + model = self._make_embedding_model(use_dora=True) + model.eval() + + embed = model.model.model.embed_tokens + input_ids = torch.randint(0, 1000, (2, 16), device=DEVICE) + + with torch.no_grad(): + peft_out = embed(input_ids) + kern_out = apply_lora_embedding(embed, input_ids) + + _compare_tensors( + kern_out.to(peft_out.dtype), peft_out, "Embedding DoRA", atol=0.3, rtol=0.05 + ) + del model + + def test_backward(self): + from axolotl.kernels.lora import apply_lora_embedding + + model = self._make_embedding_model() + model.train() + + embed = model.model.model.embed_tokens + input_ids = torch.randint(0, 1000, (2, 16), device=DEVICE) + + # PEFT reference + peft_out = embed(input_ids) + peft_out.sum().backward() + peft_grads = 
{} + for n, p in embed.named_parameters(): + if p.grad is not None and "lora" in n: + peft_grads[n] = p.grad.clone() + embed.zero_grad() + + # Kernel + kern_out = apply_lora_embedding(embed, input_ids) + kern_out.sum().backward() + kern_grads = {} + for n, p in embed.named_parameters(): + if p.grad is not None and "lora" in n: + kern_grads[n] = p.grad.clone() + embed.zero_grad() + + for name in peft_grads: + _compare_tensors( + kern_grads.get(name), + peft_grads[name], + f"Embedding grad {name}", + atol=5e-2, + rtol=5e-2, + ) + del model + + +class TestTiedEmbeddings: + """Test that tied embeddings work correctly with kernel patching.""" + + def test_tied_embed_and_lm_head(self): + """When both embed_tokens and lm_head have LoRA, PEFT unties them. + Verify patched model produces valid output (no crashes, finite values).""" + from peft import PeftModelForCausalLM + + base = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=DTYPE, + attn_implementation="eager", + ).to(DEVICE) + lora_config = LoraConfig( + r=8, + lora_alpha=16, + target_modules=[ + "embed_tokens", + "lm_head", + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + ) + model = PeftModelForCausalLM(base, lora_config) + model.eval() + + cfg = DictDefault( + { + "base_model": MODEL_NAME, + "lora_qkv_kernel": True, + "lora_o_kernel": True, + "lora_mlp_kernel": True, + "lora_embedding_kernel": True, + } + ) + + # Apply all kernel patches (class + instance level) + patch_self_attn_lora(cfg) + apply_lora_kernel_patches(model, cfg) + + input_ids = torch.randint(0, 1000, (1, 32), device=DEVICE) + with torch.no_grad(): + out = model(input_ids).logits + + # Verify output is valid + assert out.shape == (1, 32, model.config.vocab_size) + assert torch.isfinite(out).all(), "Output contains non-finite values" + assert out.abs().max() > 0, "Output is all zeros" + + # Verify backward works + model.train() + out = model(input_ids).logits + out.sum().backward() + # Check that LoRA params got gradients + embed = model.model.model.embed_tokens + has_embed_grad = any( + p.grad is not None and p.grad.abs().sum() > 0 + for n, p in embed.named_parameters() + if "lora" in n + ) + assert has_embed_grad, "Embedding LoRA params got no gradients" + del model + + +class TestQuantizedModels: + """Test kernels with quantized base weights.""" + + def test_nf4_qlora_forward_backward(self): + """NF4 QLoRA with kernel patches.""" + from peft import PeftModelForCausalLM + from transformers import BitsAndBytesConfig + + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=DTYPE, + bnb_4bit_use_double_quant=True, + ) + base = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + quantization_config=bnb_config, + attn_implementation="eager", + ) + lora_config = LoraConfig( + r=8, + lora_alpha=16, + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + ) + model = PeftModelForCausalLM(base, lora_config) + + cfg = DictDefault( + { + "base_model": MODEL_NAME, + "lora_qkv_kernel": True, + "lora_o_kernel": True, + "lora_mlp_kernel": True, + } + ) + patch_self_attn_lora(cfg) + apply_lora_kernel_patches(model, cfg) + model.train() + + ids = torch.randint(0, 1000, (1, 32), device=DEVICE) + out = model(ids).logits + assert torch.isfinite(out).all() + out.sum().backward() + has_grads = sum( + 1 for n, p in model.named_parameters() if p.grad is not None and "lora" in n + ) + assert has_grads > 0, "No LoRA 
gradients" + del model + + def test_nf4_single_quant(self): + """NF4 without double quantization.""" + from peft import PeftModelForCausalLM + from transformers import BitsAndBytesConfig + + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=DTYPE, + ) + base = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + quantization_config=bnb_config, + attn_implementation="eager", + ) + lora_config = LoraConfig( + r=8, + lora_alpha=16, + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + ) + model = PeftModelForCausalLM(base, lora_config) + + cfg = DictDefault( + { + "base_model": MODEL_NAME, + "lora_qkv_kernel": True, + "lora_o_kernel": True, + "lora_mlp_kernel": True, + } + ) + patch_self_attn_lora(cfg) + apply_lora_kernel_patches(model, cfg) + model.train() + + ids = torch.randint(0, 1000, (1, 32), device=DEVICE) + out = model(ids).logits + assert torch.isfinite(out).all() + out.sum().backward() + has_grads = sum( + 1 for n, p in model.named_parameters() if p.grad is not None and "lora" in n + ) + assert has_grads > 0 + del model + + +class TestTritonDoRA: + """Test Triton DoRA kernel against reference implementation.""" + + def test_triton_dora_scale(self): + from axolotl.kernels.dora import triton_dora_scale + from axolotl.kernels.lora import _compute_dora_scale + + # Random weights matching Qwen3-1.7B dimensions + out_feat, in_feat, rank = 1024, 1024, 8 + W = torch.randn(out_feat, in_feat, dtype=DTYPE, device=DEVICE) + A = torch.randn(rank, in_feat, dtype=DTYPE, device=DEVICE) + B = torch.randn(out_feat, rank, dtype=DTYPE, device=DEVICE) + magnitude = torch.randn(out_feat, dtype=DTYPE, device=DEVICE).abs() + 0.1 + s = 2.0 + + # Clear cache to force recomputation + if hasattr(magnitude, "_dora_cache"): + del magnitude._dora_cache + + ref = _compute_dora_scale(W, None, A, B, s, magnitude, DTYPE) + tri = triton_dora_scale(W, None, A, B, s, magnitude, DTYPE) + + _compare_tensors(tri, ref, "Triton DoRA scale", atol=1e-2, rtol=1e-2) + + def test_triton_dora_scale_small(self): + """Test with K/V projection dimensions (smaller out_features).""" + from axolotl.kernels.dora import triton_dora_scale + from axolotl.kernels.lora import _compute_dora_scale + + out_feat, in_feat, rank = 128, 1024, 8 + W = torch.randn(out_feat, in_feat, dtype=DTYPE, device=DEVICE) + A = torch.randn(rank, in_feat, dtype=DTYPE, device=DEVICE) + B = torch.randn(out_feat, rank, dtype=DTYPE, device=DEVICE) + magnitude = torch.randn(out_feat, dtype=DTYPE, device=DEVICE).abs() + 0.1 + s = 2.0 + + if hasattr(magnitude, "_dora_cache"): + del magnitude._dora_cache + + ref = _compute_dora_scale(W, None, A, B, s, magnitude, DTYPE) + tri = triton_dora_scale(W, None, A, B, s, magnitude, DTYPE) + + _compare_tensors(tri, ref, "Triton DoRA scale (small)", atol=1e-2, rtol=1e-2) + + +# ============================================================ +# Regression tests for review fixes +# ============================================================ + + +class TestDoRAEmbeddingNoDoubleScale: + """Regression: DoRA embedding forward must save the pre-scaled combined + tensor, not the already-scaled result, so backward computes d_mag correctly.""" + + def test_dora_magnitude_gradient_magnitude(self): + """d_mag should be O(1) relative to the gradient, not O(mag_scale^2).""" + from peft import PeftModelForCausalLM + + from axolotl.kernels.lora import apply_lora_embedding + + base = AutoModelForCausalLM.from_pretrained( + 
MODEL_NAME, + torch_dtype=DTYPE, + attn_implementation="eager", + ).to(DEVICE) + lora_config = LoraConfig( + r=8, + lora_alpha=16, + use_dora=True, + target_modules=["embed_tokens"], + ) + model = PeftModelForCausalLM(base, lora_config) + model.train() + + embed = model.model.model.embed_tokens + ids = torch.randint(0, 1000, (2, 16), device=DEVICE) + + # Run PEFT reference to get reference d_mag + peft_out = embed(ids) + peft_out.sum().backward() + peft_mag_grad = None + for n, p in embed.named_parameters(): + if "magnitude" in n and p.grad is not None: + peft_mag_grad = p.grad.clone() + embed.zero_grad() + + # Run kernel + kern_out = apply_lora_embedding(embed, ids) + kern_out.to(peft_out.dtype).sum().backward() + kern_mag_grad = None + for n, p in embed.named_parameters(): + if "magnitude" in n and p.grad is not None: + kern_mag_grad = p.grad.clone() + embed.zero_grad() + + assert peft_mag_grad is not None, "PEFT should produce magnitude gradients" + assert kern_mag_grad is not None, "Kernel should produce magnitude gradients" + + # Key check: gradients should be same order of magnitude + # Double-scaling would make kern_mag_grad ~mag_scale times too large + ratio = kern_mag_grad.abs().mean() / peft_mag_grad.abs().mean() + assert 0.5 < ratio < 2.0, ( + f"Magnitude gradient ratio kernel/peft = {ratio:.3f}, " + f"expected ~1.0 (double-scaling would give >> 1)" + ) + del model + + +class TestDoraCacheInvalidation: + """Regression: DoRA weight norm cache must invalidate after in-place + param updates (optimizer steps), not just pointer changes.""" + + def test_cache_invalidates_on_inplace_update(self): + W = torch.randn(64, 64, dtype=DTYPE, device=DEVICE) + A = torch.randn(8, 64, dtype=DTYPE, device=DEVICE) + B = torch.randn(64, 8, dtype=DTYPE, device=DEVICE) + magnitude = torch.randn(64, dtype=DTYPE, device=DEVICE).abs() + 0.1 + s = 2.0 + + # Clear any existing cache + if hasattr(magnitude, "_dora_cache"): + del magnitude._dora_cache + + # First call populates cache + result1 = _compute_dora_scale(W, None, A, B, s, magnitude, DTYPE) + + # Simulate optimizer in-place update (pointer stays same, content changes) + old_ptr = A.data_ptr() + A.data.add_(torch.randn_like(A) * 0.1) + assert A.data_ptr() == old_ptr, "Pointer should not change for in-place ops" + + # Second call must detect the change and recompute + result2 = _compute_dora_scale(W, None, A, B, s, magnitude, DTYPE) + + # Results should differ since A changed + assert not torch.allclose(result1, result2, atol=1e-4), ( + "DoRA scale should change after in-place param update — cache not invalidated!" 
+ ) + + +class TestEmbeddingPaddingIdxGrad: + """Regression: custom embedding backward must zero out gradients at + padding_idx positions, matching F.embedding behavior.""" + + def test_padding_idx_gradient_is_zero(self): + from axolotl.kernels.lora import LoRA_Embedding + + vocab, hidden, rank = 100, 32, 4 + W = torch.randn( + vocab, hidden, dtype=torch.float32, device=DEVICE, requires_grad=False + ) + A = torch.randn( + rank, vocab, dtype=torch.float32, device=DEVICE, requires_grad=True + ) + B = torch.randn( + hidden, rank, dtype=torch.float32, device=DEVICE, requires_grad=True + ) + s = 2.0 + padding_idx = 0 + + # Input containing the padding token + x = torch.tensor([[padding_idx, 1, 2, padding_idx, 3]], device=DEVICE) + + out = LoRA_Embedding.apply( + x, + W, + A, + B, + s, + None, + padding_idx, + None, + 2.0, + False, + False, # max_norm, norm_type, scale_grad_by_freq, sparse + ) + out.sum().backward() + + # The gradient for A at the padding_idx column should be zero + # A is [rank, vocab], so A.grad[:, padding_idx] should be zero + assert A.grad is not None + pad_grad = A.grad[:, padding_idx] + assert torch.all(pad_grad == 0), ( + f"Gradient at padding_idx={padding_idx} should be zero, got {pad_grad}" + ) + + # Non-padding positions should have non-zero gradients + non_pad_grad = A.grad[:, 1] + assert non_pad_grad.abs().sum() > 0, "Non-padding gradients should be non-zero" + + +class TestEmbeddingScaleGradByFreq: + """Regression: custom embedding backward must scale gradients by + inverse frequency when scale_grad_by_freq=True.""" + + def test_repeated_tokens_get_scaled_gradients(self): + from axolotl.kernels.lora import LoRA_Embedding + + vocab, hidden, rank = 100, 32, 4 + W = torch.randn( + vocab, hidden, dtype=torch.float32, device=DEVICE, requires_grad=False + ) + + # Run WITHOUT scale_grad_by_freq + A1 = torch.randn( + rank, vocab, dtype=torch.float32, device=DEVICE, requires_grad=True + ) + B1 = torch.randn( + hidden, rank, dtype=torch.float32, device=DEVICE, requires_grad=True + ) + # Token 5 appears 3 times + x = torch.tensor([[5, 5, 5, 10, 20]], device=DEVICE) + + out1 = LoRA_Embedding.apply( + x, + W, + A1, + B1, + 2.0, + None, + None, + None, + 2.0, + False, + False, + ) + out1.sum().backward() + grad_no_scale = A1.grad[:, 5].clone() + + # Run WITH scale_grad_by_freq + A2 = A1.data.clone().requires_grad_(True) + B2 = B1.data.clone().requires_grad_(True) + out2 = LoRA_Embedding.apply( + x, + W, + A2, + B2, + 2.0, + None, + None, + None, + 2.0, + True, + False, + ) + out2.sum().backward() + grad_with_scale = A2.grad[:, 5].clone() + + # With scale_grad_by_freq, token 5 (count=3) should have grad / 3 + expected_ratio = 1.0 / 3.0 + actual_ratio = grad_with_scale.abs().mean() / grad_no_scale.abs().mean() + assert abs(actual_ratio - expected_ratio) < 0.01, ( + f"scale_grad_by_freq ratio for count=3 token: expected {expected_ratio:.3f}, " + f"got {actual_ratio:.3f}" + ) + + +class TestEmbeddingDropoutNotAppliedToBase: + """Regression: embedding dropout must NOT be applied to the base embedding + output — PEFT's Embedding.forward does not use lora_dropout.""" + + def test_kernel_matches_peft_with_dropout_config(self): + """Even with lora_dropout>0, embedding output should match PEFT exactly.""" + from peft import PeftModelForCausalLM + + from axolotl.kernels.lora import apply_lora_embedding + + base = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=DTYPE, + attn_implementation="eager", + ).to(DEVICE) + lora_config = LoraConfig( + r=8, + lora_alpha=16, + 
lora_dropout=0.5, # high dropout + target_modules=["embed_tokens"], + ) + model = PeftModelForCausalLM(base, lora_config) + model.train() # training mode — dropout would be active if applied + + embed = model.model.model.embed_tokens + ids = torch.randint(0, 1000, (2, 16), device=DEVICE) + + # Run both multiple times — if dropout were applied, results would vary + with torch.no_grad(): + peft_out = embed(ids) + kern1 = apply_lora_embedding(embed, ids) + kern2 = apply_lora_embedding(embed, ids) + + # Kernel should be deterministic (no dropout) + _compare_tensors( + kern1.to(peft_out.dtype), + kern2.to(peft_out.dtype), + "Embedding deterministic (no dropout)", + atol=0, + rtol=0, + ) + + # And should match PEFT + _compare_tensors( + kern1.to(peft_out.dtype), + peft_out, + "Embedding matches PEFT with dropout config", + ) + del model diff --git a/tests/e2e/multigpu/test_fsdp2_lora_kernels.py b/tests/e2e/multigpu/test_fsdp2_lora_kernels.py new file mode 100644 index 000000000..27ad2b8e9 --- /dev/null +++ b/tests/e2e/multigpu/test_fsdp2_lora_kernels.py @@ -0,0 +1,120 @@ +"""Test LoRA kernels under FSDP2 multi-GPU training. + +Verifies that lora_qkv_kernel, lora_o_kernel, lora_mlp_kernel, and +lora_embedding_kernel work correctly with FSDP2 sharding, including +with bias, dropout, and DoRA enabled. +""" + +from pathlib import Path + +import yaml +from accelerate.test_utils import execute_subprocess_async +from transformers.testing_utils import get_torch_dist_unique_port + +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import require_torch_2_7_0 + +AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent + + +def _run_training(temp_dir, cfg): + """Write config and launch multi-GPU training.""" + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + +def _base_lora_fsdp2_config(temp_dir, **overrides): + """Base config for LoRA + FSDP2 + kernel tests.""" + cfg = { + "base_model": "Qwen/Qwen3-0.6B", + "sequence_len": 512, + "val_set_size": 0.0, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + "split": "train[:1%]", + }, + ], + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_target_linear": True, + "num_epochs": 1, + "max_steps": 3, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 1e-4, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "flash_attention": True, + "bf16": True, + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "Qwen3DecoderLayer", + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + # Enable all LoRA kernels + "lora_mlp_kernel": True, + "lora_qkv_kernel": True, + "lora_o_kernel": True, + "lora_embedding_kernel": True, + "save_safetensors": True, + } + cfg.update(overrides) + return DictDefault(cfg) + + +class TestFSDP2LoRAKernels: + """Test LoRA kernels under FSDP2.""" + + @require_torch_2_7_0 + def test_lora_kernels_basic(self, temp_dir): + """Basic LoRA + kernels + FSDP2: no dropout, no bias, no DoRA.""" + cfg = _base_lora_fsdp2_config(temp_dir) + 
_run_training(temp_dir, cfg) + assert (Path(temp_dir) / "adapter_model.safetensors").exists() + + @require_torch_2_7_0 + def test_lora_kernels_with_dropout(self, temp_dir): + """LoRA kernels + dropout + FSDP2.""" + cfg = _base_lora_fsdp2_config(temp_dir, lora_dropout=0.1) + _run_training(temp_dir, cfg) + assert (Path(temp_dir) / "adapter_model.safetensors").exists() + + @require_torch_2_7_0 + def test_lora_kernels_with_dora(self, temp_dir): + """LoRA kernels + DoRA + FSDP2.""" + cfg = _base_lora_fsdp2_config(temp_dir, peft_use_dora=True) + _run_training(temp_dir, cfg) + assert (Path(temp_dir) / "adapter_model.safetensors").exists() + + @require_torch_2_7_0 + def test_lora_kernels_with_dora_and_dropout(self, temp_dir): + """LoRA kernels + DoRA + dropout + FSDP2.""" + cfg = _base_lora_fsdp2_config( + temp_dir, + peft_use_dora=True, + lora_dropout=0.05, + ) + _run_training(temp_dir, cfg) + assert (Path(temp_dir) / "adapter_model.safetensors").exists() diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py index 73f883858..2865a80f9 100644 --- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py +++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py @@ -222,9 +222,9 @@ def test_model_specific_activation(model_name, expected_activation): def test_kernel_patch_conditions(): - """Test various conditions that should prevent kernel patching.""" + """Test that kernels ARE patched even with dropout and bias (now supported).""" test_configs = [ - # Dropout prevents patching + # Dropout — kernels now support this { "peft_type": "LORA", "task_type": "CAUSAL_LM", @@ -234,7 +234,7 @@ def test_kernel_patch_conditions(): "lora_dropout": 0.1, "bias": "none", }, - # Bias prevents patching + # Bias — kernels now support this { "peft_type": "LORA", "task_type": "CAUSAL_LM", @@ -252,13 +252,14 @@ def test_kernel_patch_conditions(): model = PeftModelForCausalLM(model, peft_config) cfg = DictDefault({"lora_mlp_kernel": True}) - # Should not patch patched_model = apply_lora_kernel_patches(model, cfg) layer = patched_model.model.model.layers[0].mlp - # Verify no patches applied - assert layer.forward.__func__ is not apply_lora_mlp_swiglu - assert layer.forward.__func__ is not apply_lora_mlp_geglu + # Verify patches ARE applied (dropout and bias are now supported) + assert ( + layer.forward.__func__ is apply_lora_mlp_swiglu + or layer.forward.__func__ is apply_lora_mlp_geglu + ) def test_kernel_config_options(): @@ -511,7 +512,7 @@ def test_kernel_training_integration_auto_enable(temp_dir): def test_kernel_training_integration_dropout_non_zero(temp_dir): - """Test model loading with dropout non-zero should not patch.""" + """Test model loading with dropout non-zero DOES patch (now supported).""" from axolotl.cli.utils import load_model_and_tokenizer @@ -546,31 +547,18 @@ def test_kernel_training_integration_dropout_non_zero(temp_dir): # Load config cfg = load_cfg(str(path)) - # Get original attention class - attention_cls = get_attention_cls_from_config(cfg) - - # Store original state before patching - original_forward_method = attention_cls.forward - # Load model model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg) - # We call modelloader as that's where the patches are applied - # despite the fact that we're not using it to load the model model_loader = ModelLoader(cfg, tokenizer) - # Apply patch + # Apply patches — should succeed even with dropout > 0 model_loader.patch_manager._apply_self_attention_lora_patch() 
- - # Verify patch was not applied - assert attention_cls.forward == original_forward_method - - # Apply apply_lora_kernel_patches model_loader.patch_manager._apply_lora_kernel_patch(model) - # Verify patch was not applied + # Verify patches WERE applied (dropout is now supported by kernels) layers = get_layers(model) for layer in layers: for self_attn in find_self_attn_in_layer(layer): - assert not hasattr(self_attn, "apply_qkv") - assert not hasattr(self_attn, "apply_o") + assert hasattr(self_attn, "apply_qkv") + assert hasattr(self_attn, "apply_o") diff --git a/tests/utils/lora/test_config_validation_lora.py b/tests/utils/lora/test_config_validation_lora.py index 9d97288b6..45a848c65 100644 --- a/tests/utils/lora/test_config_validation_lora.py +++ b/tests/utils/lora/test_config_validation_lora.py @@ -28,20 +28,22 @@ class TestLoRAConfigValidation: result = validate_config(valid_config) assert result["adapter"] == "lora" - with pytest.raises(ValueError, match="not compatible with DoRA"): - invalid_config = DictDefault( - { - "adapter": "lora", - "lora_mlp_kernel": True, - "peft_use_dora": True, - "datasets": [{"path": "dummy_dataset", "type": "alpaca"}], - "micro_batch_size": 1, - "gradient_accumulation_steps": 1, - "learning_rate": 1e-5, - "base_model": "dummy_model", - } - ) - validate_config(invalid_config) + # DoRA is now compatible with lora kernels + dora_kernel_config = DictDefault( + { + "adapter": "lora", + "lora_mlp_kernel": True, + "peft_use_dora": True, + "datasets": [{"path": "dummy_dataset", "type": "alpaca"}], + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 1e-5, + "base_model": "dummy_model", + } + ) + result = validate_config(dora_kernel_config) + assert result["lora_mlp_kernel"] is True + assert result["peft_use_dora"] is True def test_qlora_4bit_validation(self): """Test QLoRA 4-bit configuration validation""" diff --git a/tests/utils/lora/test_freeze_lora.py b/tests/utils/lora/test_freeze_lora.py index da90c1826..7c5ec8fb3 100644 --- a/tests/utils/lora/test_freeze_lora.py +++ b/tests/utils/lora/test_freeze_lora.py @@ -38,6 +38,11 @@ class TestLoRAParameterFreezing: mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype) mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype) + mock_layer.lora_B["default"].bias = None + + # Required by get_lora_parameters for dropout/DoRA extraction + mock_layer.lora_dropout = {} + mock_layer.lora_magnitude_vector = None else: mock_layer.weight = base_layer.weight mock_layer.bias = base_layer.bias @@ -48,7 +53,7 @@ class TestLoRAParameterFreezing: """Test that LoRA parameters are None when adapters are disabled.""" layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True) - W, b, quant_state, A, B, s = get_lora_parameters(layer) + W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer) # Base parameters should be returned assert W is not None @@ -62,7 +67,7 @@ class TestLoRAParameterFreezing: """Test that LoRA parameters are None when adapters are merged.""" layer = self.create_mock_lora_layer(has_adapters=True, merged=True) - W, b, quant_state, A, B, s = get_lora_parameters(layer) + W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer) # Base parameters should be returned assert W is not None @@ -77,7 +82,7 @@ class TestLoRAParameterFreezing: """Test parameter behavior when no adapters are present.""" layer = self.create_mock_lora_layer(has_adapters=False) - W, b, quant_state, A, B, s = get_lora_parameters(layer) + W, b, 
quant_state, A, B, s, *_ = get_lora_parameters(layer) # Base parameters should be returned assert W is not None @@ -94,7 +99,7 @@ class TestLoRAParameterFreezing: has_adapters=True, adapters_disabled=False, merged=False ) - W, b, quant_state, A, B, s = get_lora_parameters(layer) + W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer) # All parameters should be returned assert W is not None @@ -110,7 +115,7 @@ class TestLoRAParameterFreezing: has_adapters=True, adapters_disabled=False, merged=False ) - W, b, quant_state, A, B, s = get_lora_parameters(layer) + W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer) # Check shape consistency assert W.shape == (512, 256) @@ -124,7 +129,7 @@ class TestLoRAParameterFreezing: has_adapters=True, adapters_disabled=False, merged=False ) - W, b, quant_state, A, B, s = get_lora_parameters(layer) + W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer) assert W.dtype == self.dtype assert b.dtype == self.dtype @@ -138,7 +143,7 @@ class TestLoRAParameterFreezing: quant_state_mock = Mock() layer.base_layer.weight.quant_state = quant_state_mock - W, b, quant_state, A, B, s = get_lora_parameters(layer) + W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer) assert quant_state == quant_state_mock @@ -157,7 +162,7 @@ class TestLoRAParameterFreezing: layer.active_adapters = ["adapter2"] - W, b, quant_state, A, B, s = get_lora_parameters(layer) + W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer) assert s == 0.2 assert torch.equal(A, layer.lora_A["adapter2"].weight) @@ -192,13 +197,13 @@ class TestLoRAParameterFreezingIntegration: model = get_peft_model(base_model, lora_config) lora_layer = model.base_model.model.linear # Test with adapters enabled - W, b, quant_state, A, B, s = get_lora_parameters(lora_layer) + W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer) assert A is not None assert B is not None assert s is not None # Test with adapters disabled model.disable_adapter_layers() - W, b, quant_state, A, B, s = get_lora_parameters(lora_layer) + W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer) assert A is None assert B is None assert s is None
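
A minimal axolotl config sketch exercising the new kernel paths; the keys
mirror the FSDP2 test config above (when left unset, the kernel flags may be
auto-enabled by the config validation in this patch):

    base_model: Qwen/Qwen3-0.6B
    adapter: lora
    lora_r: 8
    lora_alpha: 16
    lora_target_linear: true
    lora_dropout: 0.05           # now supported by the fused kernels
    peft_use_dora: true          # DoRA is now supported as well
    lora_mlp_kernel: true
    lora_qkv_kernel: true
    lora_o_kernel: true
    lora_embedding_kernel: true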