axolotl/src/axolotl/kernels/lora.py
Dan Saunders 3d8425fa91 Activation function Triton kernels, LoRA custom autograd functions (#2324)
2025-02-17 14:23:15 -05:00


"""
Module for definition of Low-Rank Adaptation (LoRA) Triton kernels.
See "LoRA: Low-Rank Adaptation of Large Language Models"
(https://arxiv.org/abs/2106.09685).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name
from typing import Callable

import torch
from bitsandbytes.functional import QuantState
from torch import nn

from .geglu import geglu_backward, geglu_forward
from .quantize import dequantize
from .swiglu import swiglu_backward, swiglu_forward
from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd


def get_lora_parameters(
proj: nn.Module,
) -> tuple[
torch.Tensor,
QuantState | None,
torch.Tensor | None,
torch.Tensor | None,
float | None,
]:
"""
Gets LoRA parameters from a projection module.
Args:
proj: The projection module to extract parameters from.
Returns:
A tuple containing the base weight matrix, quantization state, LoRA A matrix,
LoRA B matrix, and scaling factor. States and matrices may be None if not
available.
"""
# For DPO or disabled adapters
base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
W = base_layer.weight
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
quant_state = getattr(W, "quant_state", None)
return W, quant_state, None, None, None
active_adapter = (
proj.active_adapters[0]
if hasattr(proj, "active_adapters")
else proj.active_adapter
)
A = proj.lora_A[active_adapter].weight
B = proj.lora_B[active_adapter].weight
s = proj.scaling[active_adapter]
quant_state = getattr(W, "quant_state", None)
return W, quant_state, A, B, s
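

# Illustration (a sketch, not executed by this module): for a typical PEFT
# `lora.Linear` wrapping `nn.Linear(in_features, out_features)` with rank-r
# adapters, the returned tuple is expected to look like:
#
#   W           -> base weight, shape [out_features, in_features]
#   quant_state -> bitsandbytes QuantState, or None for unquantized layers
#   A           -> lora_A weight, shape [r, in_features]
#   B           -> lora_B weight, shape [out_features, r]
#   s           -> scaling factor (typically lora_alpha / r in PEFT)
#
# The attribute names (`lora_A`, `lora_B`, `scaling`, `active_adapters`) follow
# PEFT's LoRA layers; other adapter implementations are an assumption here.
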
def matmul_lora(
X: torch.Tensor,
W: torch.Tensor,
W_quant: QuantState,
A: torch.Tensor,
B: torch.Tensor,
s: float,
out: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Efficient fused matmul + LoRA computation.
Args:
X: Input tensor [*, in_features]
W: Base weight matrix [out_features, in_features]
W_quant: Quantization state for W
A: LoRA A matrix [rank, in_features]
B: LoRA B matrix [out_features, rank]
s: LoRA scaling factor
out: Optional output tensor for inplace operations
Returns:
Result of X @ W + X @ A @ B
"""
dtype = X.dtype
W = dequantize(W.t(), W_quant)
if X.dim() == 3:
batch, seq_len, _ = X.shape
X = X.view(-1, X.shape[-1])
reshape = True
else:
reshape = False
out = torch.matmul(X, W, out=out)
if W_quant is not None:
del W
if A is not None:
A, B = A.t(), B.t()
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
return out.view(batch, seq_len, -1) if reshape else out
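

# Sanity-check sketch (assumption: `dequantize` passes tensors through unchanged
# when no quant state is given). With plain fp32 tensors, `matmul_lora` should
# match the reference X @ W.T + (X @ A.T) @ (s * B.T):
#
#   X = torch.randn(2, 3, 16)   # [batch, seq_len, in_features]
#   W = torch.randn(32, 16)     # [out_features, in_features]
#   A = torch.randn(4, 16)      # [rank, in_features]
#   B = torch.randn(32, 4)      # [out_features, rank]
#   out = matmul_lora(X, W, None, A, B, 0.5)
#   ref = X @ W.T + (X @ A.T) @ (0.5 * B.T)
#   torch.testing.assert_close(out, ref)
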
class LoRA_MLP(torch.autograd.Function):
"""Optimized LoRA MLP implementation."""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx,
X: torch.Tensor,
gate_weight: torch.Tensor,
gate_quant: object | None,
gate_A: torch.Tensor | None,
gate_B: torch.Tensor | None,
gate_scale: float,
up_weight: torch.Tensor,
up_quant: object | None,
up_A: torch.Tensor | None,
up_B: torch.Tensor | None,
up_scale: float,
down_weight: torch.Tensor,
down_quant: object | None,
down_A: torch.Tensor | None,
down_B: torch.Tensor | None,
down_scale: float,
activation_fn: Callable,
activation_fn_backward: Callable,
inplace: bool | None = True,
) -> torch.Tensor:
"""
Forward pass for LoRA MLP.
Args:
ctx: Autograd context
X: Input features
gate_weight: Gate projection weight
gate_quant: Gate quantization state
gate_A: Gate LoRA A matrix
gate_B: Gate LoRA B matrix
gate_scale: Gate LoRA scale
up_weight: Up-projection weight
up_quant: Up-projection quantization state
up_A: Up-projection LoRA A matrix
up_B: Up-projection LoRA B matrix
up_scale: Up-projection LoRA scale
down_weight: Down-projection weight
down_quant: Down-projection quantization state
down_A: Down-projection LoRA A matrix
down_B: Down-projection LoRA B matrix
down_scale: Down-projection LoRA scale
activation_fn: Forward activation function
activation_fn_backward: Backward activation function
inplace: Whether to perform operations in-place
Returns:
Output transformed by multi-layer perceptron and activation function
"""
# Compute projections
gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)
# Activation
hidden = activation_fn(gate, up)
# Down projection
output = matmul_lora(
hidden, down_weight, down_quant, down_A, down_B, down_scale
)
# Save for backward
ctx.save_for_backward(X, gate, up, gate_A, gate_B, up_A, up_B, down_A, down_B)
ctx.scales = (gate_scale, up_scale, down_scale)
ctx.quants = (gate_quant, up_quant, down_quant)
ctx.weights = (gate_weight, up_weight, down_weight)
ctx.activation_fn = activation_fn
ctx.activation_fn_backward = activation_fn_backward
ctx.inplace = inplace
return output
@staticmethod
@torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_output: torch.Tensor,
) -> tuple[
torch.Tensor | None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
]:
"""
Performs backward pass computation for LoRA MLP.
Args:
ctx: Context object storing tensors saved during forward pass
grad_output: Gradient of loss with respect to layer output
Returns:
Tuple containing gradients for all inputs from forward pass:
- Input gradient tensor (or `None`)
- `None` for weights/quantization states
- LoRA A/B matrix gradients (or `None`)
- `None` for scaling factors
- `None` for activation functions and flags
"""
(
X,
gate,
up,
gate_A,
gate_B,
up_A,
up_B,
down_A,
down_B,
) = ctx.saved_tensors
gate_scale, up_scale, down_scale = ctx.scales
gate_quant, up_quant, down_quant = ctx.quants
gate_weight, up_weight, down_weight = ctx.weights
# Transpose all LoRA matrices
gate_A, gate_B = (
gate_A.t() if gate_A is not None else None,
gate_B.t() if gate_B is not None else None,
)
up_A, up_B = (
up_A.t() if up_A is not None else None,
up_B.t() if up_B is not None else None,
)
down_A, down_B = (
down_A.t() if down_A is not None else None,
down_B.t() if down_B is not None else None,
)
# Reshape inputs
batch, seq_len, hd = X.shape
grad_output = grad_output.view(-1, grad_output.shape[-1])
X = X.view(-1, X.shape[-1])
gate = gate.view(-1, gate.shape[-1])
up = up.view(-1, up.shape[-1])
dtype = X.dtype
# Down projection
DW = matmul_lora(
grad_output,
down_weight.t(),
down_quant,
down_B,
down_A,
down_scale,
)
# Activation backward
h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)
# Initialize and compute LoRA gradients
d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
if down_A is not None:
d_down_A = h.t() @ (grad_output @ down_B.t())
d_down_B = (down_A.t() @ h.t()) @ grad_output
d_down_A *= down_scale
d_down_B *= down_scale
if up_A is not None:
d_up_A = X.t() @ (grad_up @ up_B.t())
d_up_B = (up_A.t() @ X.t()) @ grad_up
d_up_A *= up_scale
d_up_B *= up_scale
if gate_A is not None:
d_gate_A = X.t() @ (grad_gate @ gate_B.t())
d_gate_B = (gate_A.t() @ X.t()) @ grad_gate
d_gate_A *= gate_scale
d_gate_B *= gate_scale
# Compute input gradients
dX = torch.zeros_like(X) if ctx.needs_input_grad[0] else None
if dX is not None:
# Up projection gradients
up_weight = dequantize(up_weight.t(), up_quant)
if ctx.inplace:
dX = torch.matmul(grad_up, up_weight.t(), out=X)
else:
dX = torch.matmul(grad_up, up_weight.t())
del up_weight
# Note the .to(dtype) only where mixing LoRA with base weights
if up_A is not None:
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
# Gate projection gradients
gate_weight = dequantize(gate_weight.t(), gate_quant)
dX += grad_gate @ gate_weight.t()
del gate_weight
if gate_A is not None:
dX += (
grad_gate
@ gate_B.to(dtype).t()
@ (gate_scale * gate_A.to(dtype).t())
)
# Reshape back
dX = dX.view(batch, seq_len, hd)
# Return gradients in correct order matching forward inputs
return (
dX,
None,
None,
d_gate_A.t() if d_gate_A is not None else None,
d_gate_B.t() if d_gate_B is not None else None,
None,
None,
None,
d_up_A.t() if d_up_A is not None else None,
d_up_B.t() if d_up_B is not None else None,
None,
None,
None,
d_down_A.t() if d_down_A is not None else None,
d_down_B.t() if d_down_B is not None else None,
None,
None,
None,
None,
)
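

# For reference, the adapter gradients computed above follow from the forward
# convention Y = X @ W.T + s * (X @ A.T) @ B.T (sketch, per projection):
#
#   dL/dA = s * (dY @ B).T @ X       # shape [rank, in_features]
#   dL/dB = s * dY.T @ (X @ A.T)     # shape [out_features, rank]
#   dL/dX = dY @ W + s * dY @ B @ A  # shape [*, in_features]
#
# with dY for the gate / up projections produced by the fused activation
# backward, and dY for the down projection being `grad_output` itself.
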
def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
"""
Applies LoRA to MLP layer with SwiGLU activation.
Args:
X: Input tensor for the MLP layer
inplace: Whether to perform operations in-place to save memory
Returns:
Output tensor after applying LoRA-adapted MLP with SwiGLU activation
"""
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(
X,
gateW,
gateW_quant,
gateA,
gateB,
gateS,
upW,
upW_quant,
upA,
upB,
upS,
downW,
downW_quant,
downA,
downB,
downS,
swiglu_forward,
swiglu_backward,
inplace,
)
return out
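

# Usage sketch (assumptions: a Llama-style MLP module exposing `gate_proj` /
# `up_proj` / `down_proj` PEFT LoRA layers; the loop below is illustrative, in
# practice axolotl wires this up through its monkey-patching logic):
#
#   import types
#
#   for layer in model.model.layers:
#       mlp = layer.mlp
#       mlp.forward = types.MethodType(apply_lora_mlp_swiglu, mlp)
#
# After patching, `mlp(X)` runs the fused LoRA + SwiGLU autograd function with
# the module itself bound as `self`.
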
def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
"""
Applies LoRA to MLP layer with GEGLU activation.
Args:
X: Input tensor for the MLP layer
inplace: Whether to perform operations in-place to save memory
Returns:
Output tensor after applying LoRA-adapted MLP with GEGLU activation
"""
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(
X,
gateW,
gateW_quant,
gateA,
gateB,
gateS,
upW,
upW_quant,
upA,
upB,
upS,
downW,
downW_quant,
downA,
downB,
downS,
geglu_forward,
geglu_backward,
inplace,
)
    return out


class LoRA_QKV(torch.autograd.Function):
"""
Optimized LoRA QKV implementation with quantization support.
Implements efficient computation of query, key, value projections with LoRA,
supporting quantization and memory optimization.
"""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
q_weight: torch.Tensor,
q_quant: QuantState | None,
q_A: torch.Tensor | None,
q_B: torch.Tensor | None,
q_scale: float,
k_weight: torch.Tensor,
k_quant: QuantState | None,
k_A: torch.Tensor | None,
k_B: torch.Tensor | None,
k_scale: float,
v_weight: torch.Tensor,
v_quant: QuantState | None,
v_A: torch.Tensor | None,
v_B: torch.Tensor | None,
v_scale: float,
inplace: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass computing Q, K, V projections with LoRA.
Args:
ctx: Autograd context
X: Input tensor
q_weight: Query projection weight
q_quant: Query quantization state
q_A: Query LoRA A matrix
q_B: Query LoRA B matrix
q_scale: Query LoRA scale
k_weight: Key projection weight
k_quant: Key quantization state
k_A: Key LoRA A matrix
k_B: Key LoRA B matrix
k_scale: Key LoRA scale
v_weight: Value projection weight
v_quant: Value quantization state
v_A: Value LoRA A matrix
v_B: Value LoRA B matrix
v_scale: Value LoRA scale
inplace: Whether to perform operations in-place
Returns:
Tuple of (Query, Key, Value) projection tensors
"""
Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)
ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
ctx.scales = (q_scale, k_scale, v_scale)
ctx.quants = (q_quant, k_quant, v_quant)
ctx.weights = (q_weight, k_weight, v_weight)
ctx.inplace = inplace
return Q, K, V
@staticmethod
    @torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
q_grad: torch.Tensor,
k_grad: torch.Tensor,
v_grad: torch.Tensor,
) -> tuple[
torch.Tensor,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
]:
"""
Backward pass computing gradients for LoRA QKV.
Args:
ctx: Autograd context
q_grad: Gradient for query projection
k_grad: Gradient for key projection
v_grad: Gradient for value projection
Returns:
Tuple containing gradients for all forward inputs
"""
X, A_q, B_q, A_k, B_k, A_v, B_v = ctx.saved_tensors
q_weight, k_weight, v_weight = ctx.weights
q_quant, k_quant, v_quant = ctx.quants
q_scale, k_scale, v_scale = ctx.scales
dtype = X.dtype
# Reshape gradients
batch, seq_len = X.shape[:2]
q_grad = q_grad.view(-1, q_grad.shape[-1])
k_grad = k_grad.reshape(-1, k_grad.shape[-1])
v_grad = v_grad.view(-1, v_grad.shape[-1])
X = X.view(-1, X.shape[-1])
# Pre-transpose X once
X_t = X.t()
# Initialize LoRA gradients as None
d_A_q = d_B_q = d_A_k = d_B_k = d_A_v = d_B_v = None
# Compute q path LoRA gradients if adapters exist
if A_q is not None and B_q is not None:
A_q_scaled = (q_scale * A_q).to(dtype)
B_q_scaled = B_q.to(dtype)
d_A_q = torch.mm(X_t, torch.mm(q_grad, B_q_scaled))
d_B_q = torch.mm(torch.mm(A_q_scaled, X_t), q_grad)
# Compute k path LoRA gradients if adapters exist
if A_k is not None and B_k is not None:
A_k_scaled = (k_scale * A_k).to(dtype)
B_k_scaled = B_k.to(dtype)
d_A_k = torch.mm(X_t, torch.mm(k_grad, B_k_scaled))
d_B_k = torch.mm(torch.mm(A_k_scaled, X_t), k_grad)
# Compute v path LoRA gradients if adapters exist
if A_v is not None and B_v is not None:
A_v_scaled = (v_scale * A_v).to(dtype)
B_v_scaled = B_v.to(dtype)
d_A_v = torch.mm(X_t, torch.mm(v_grad, B_v_scaled))
d_B_v = torch.mm(torch.mm(A_v_scaled, X_t), v_grad)
# Compute input gradient, reusing X memory if possible
out_buffer = X if ctx.inplace else None
# Q path
q_weight_t = dequantize(q_weight, q_quant)
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
del q_weight
del q_weight_t
if A_q is not None and B_q is not None:
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
# K path
k_weight_t = dequantize(k_weight, k_quant)
grad_X.addmm_(k_grad, k_weight_t)
del k_weight
del k_weight_t
if A_k is not None and B_k is not None:
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
# V path
v_weight_t = dequantize(v_weight, v_quant)
grad_X.addmm_(v_grad, v_weight_t)
del v_weight
del v_weight_t
if A_v is not None and B_v is not None:
grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))
# Transpose gradients if needed
if d_A_q is not None:
d_A_q = d_A_q.t()
if d_B_q is not None:
d_B_q = d_B_q.t()
if d_A_k is not None:
d_A_k = d_A_k.t()
if d_B_k is not None:
d_B_k = d_B_k.t()
if d_A_v is not None:
d_A_v = d_A_v.t()
if d_B_v is not None:
d_B_v = d_B_v.t()
return (
grad_X.view(batch, seq_len, -1),
None,
None,
d_A_q,
d_B_q,
None,
None,
None,
d_A_k,
d_B_k,
None,
None,
None,
d_A_v,
d_B_v,
None,
None,
)
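

# Note on memory: with `inplace=True`, the backward pass above reuses the saved
# activation buffer `X` as the output of the first `torch.mm`, then accumulates
# the remaining terms with `addmm_`, so the input gradient
#
#   dX = dQ @ Wq + dK @ Wk + dV @ Wv (+ the corresponding LoRA terms)
#
# is built up without allocating an extra [batch * seq_len, hidden] tensor.
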
def apply_lora_qkv(
self, X: torch.Tensor, inplace: bool = True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Applies LoRA to compute Query, Key, Value projections.
Args:
X: Input tensor
inplace: Whether to perform operations in-place
Returns:
Tuple of (Query, Key, Value) projection tensors
"""
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
Q, K, V = LoRA_QKV.apply(
X,
QW,
QW_quant,
QA,
QB,
QS,
KW,
KW_quant,
KA,
KB,
KS,
VW,
VW_quant,
VA,
VB,
VS,
inplace,
)
return Q, K, V
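

# Equivalence sketch (assumptions: adapters are active and lora_dropout is 0, so
# the fused path should match the unpatched PEFT projections up to numerics):
#
#   q_ref, k_ref, v_ref = attn.q_proj(X), attn.k_proj(X), attn.v_proj(X)
#   q, k, v = apply_lora_qkv(attn, X)
#   torch.testing.assert_close(q, q_ref, rtol=1e-3, atol=1e-3)
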
class LoRA_O(torch.autograd.Function):
"""Optimized LoRA implementation for output projection."""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
W: torch.Tensor,
W_quant: QuantState | None,
A: torch.Tensor | None,
B: torch.Tensor | None,
S: float,
) -> torch.Tensor:
"""
Forward pass for output projection with LoRA.
Args:
ctx: Autograd context
X: Input tensor
W: Output projection weight
W_quant: Weight quantization state
A: LoRA A matrix
B: LoRA B matrix
S: LoRA scaling factor
Returns:
Output projection tensor
"""
XW = matmul_lora(X, W, W_quant, A, B, S)
ctx.custom_saved_tensors = (
W,
W_quant,
S,
)
ctx.save_for_backward(A, B, X)
return XW
@staticmethod
@torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
dY: torch.Tensor,
) -> tuple[
torch.Tensor,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
]:
"""
Backward pass computing gradients for LoRA output projection.
Args:
ctx: Autograd context
dY: Gradient of loss with respect to output
Returns:
Tuple containing gradients for all forward inputs
"""
W, W_quant, S = ctx.custom_saved_tensors
A, B, X = ctx.saved_tensors
batch, seq_len, hd = X.shape
dY = dY.reshape(-1, dY.shape[-1])
X = X.reshape(-1, X.shape[-1])
dtype = X.dtype
# Weight projection
dY_X = X.t() @ dY
d_A = S * dY_X @ B
d_B = S * A @ dY_X
# Get derivative for dX
W = dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
dX += dY @ B.to(dtype) @ (S * A.to(dtype))
# W, W_quant, A, B, S
        return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None


def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
"""
Applies LoRA to output projection layer.
Args:
X: Input tensor
Returns:
Transformed output tensor
"""
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS)
return output
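

# As with the QKV path, the fused output projection should match the PEFT layer
# it replaces when adapters are active and lora_dropout is 0 (sketch; `attn` and
# `attn_output` are illustrative):
#
#   torch.testing.assert_close(
#       apply_lora_o(attn, attn_output),
#       attn.o_proj(attn_output),
#       rtol=1e-3,
#       atol=1e-3,
#   )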