LoRA kernels bias support (#3025)

* lora kernels bias support

* revert rename

* nit

* lint, tests

* satisfying the rabbit
This commit is contained in:
Dan Saunders
2025-08-06 20:20:08 -04:00
committed by GitHub
parent e442ff22aa
commit d09290f2f4
3 changed files with 156 additions and 90 deletions

View File

@@ -26,6 +26,7 @@ def get_lora_parameters(
proj: nn.Module, proj: nn.Module,
) -> tuple[ ) -> tuple[
torch.Tensor, torch.Tensor,
torch.Tensor | None,
QuantState | None, QuantState | None,
torch.Tensor | None, torch.Tensor | None,
torch.Tensor | None, torch.Tensor | None,
@@ -38,17 +39,20 @@ def get_lora_parameters(
proj: The projection module to extract parameters from. proj: The projection module to extract parameters from.
Returns: Returns:
A tuple containing the base weight matrix, quantization state, LoRA A matrix, A tuple containing the base weights, base layer bias, quantization state, LoRA A
LoRA B matrix, and scaling factor. States and matrices may be None if not and B weights, and scaling factor. Bias, quant state, LoRA weights, and scaling
available. factor may be `None` if not available.
""" """
# For DPO or disabled adapters # For DPO or disabled adapters
base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
W = base_layer.weight W = base_layer.weight
b = base_layer.bias
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
quant_state = getattr(W, "quant_state", None) quant_state = getattr(W, "quant_state", None)
return W, quant_state, None, None, None return W, b, quant_state, None, None, None
quant_state = getattr(W, "quant_state", None)
active_adapter = ( active_adapter = (
proj.active_adapters[0] proj.active_adapters[0]
@@ -72,18 +76,17 @@ def get_lora_parameters(
B = linear_B.weight B = linear_B.weight
s = proj.scaling[active_adapter] s = proj.scaling[active_adapter]
quant_state = getattr(W, "quant_state", None) return W, b, quant_state, A, B, s
return W, quant_state, A, B, s
def matmul_lora( def matmul_lora(
X: torch.Tensor, X: torch.Tensor,
W: torch.Tensor, W: torch.Tensor,
W_quant: QuantState, b: torch.Tensor | None,
A: torch.Tensor, W_quant: QuantState | None,
B: torch.Tensor, A: torch.Tensor | None,
s: float, B: torch.Tensor | None,
s: float | None,
out: torch.Tensor | None = None, out: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
@@ -104,21 +107,23 @@ def matmul_lora(
dtype = X.dtype dtype = X.dtype
W = dequantize(W.t(), W_quant) W = dequantize(W.t(), W_quant)
reshape = False
if X.dim() == 3: if X.dim() == 3:
batch, seq_len, _ = X.shape batch, seq_len, _ = X.shape
X = X.view(-1, X.shape[-1]) X = X.view(-1, X.shape[-1])
reshape = True reshape = True
else:
reshape = False
out = torch.matmul(X, W, out=out) out = torch.matmul(X, W, out=out)
if W_quant is not None: if W_quant is not None:
del W del W
if A is not None: if A is not None:
A, B = A.t().to(dtype), B.t().to(dtype) A, B = A.t().to(dtype), B.t().to(dtype) # type: ignore[union-attr]
out += s * X @ A @ B out += s * X @ A @ B
if b is not None:
out += b
return out.view(batch, seq_len, -1) if reshape else out return out.view(batch, seq_len, -1) if reshape else out
@@ -131,17 +136,20 @@ class LoRA_MLP(torch.autograd.Function):
ctx, ctx,
X: torch.Tensor, X: torch.Tensor,
gate_weight: torch.Tensor, gate_weight: torch.Tensor,
gate_quant: object | None, gate_bias: torch.Tensor | None,
gate_quant: QuantState | None,
gate_A: torch.Tensor | None, gate_A: torch.Tensor | None,
gate_B: torch.Tensor | None, gate_B: torch.Tensor | None,
gate_scale: float, gate_scale: float,
up_weight: torch.Tensor, up_weight: torch.Tensor,
up_quant: object | None, up_bias: torch.Tensor | None,
up_quant: QuantState | None,
up_A: torch.Tensor | None, up_A: torch.Tensor | None,
up_B: torch.Tensor | None, up_B: torch.Tensor | None,
up_scale: float, up_scale: float,
down_weight: torch.Tensor, down_weight: torch.Tensor,
down_quant: object | None, down_bias: torch.Tensor | None,
down_quant: QuantState | None,
down_A: torch.Tensor | None, down_A: torch.Tensor | None,
down_B: torch.Tensor | None, down_B: torch.Tensor | None,
down_scale: float, down_scale: float,
@@ -156,20 +164,22 @@ class LoRA_MLP(torch.autograd.Function):
ctx: Autograd context ctx: Autograd context
X: Input features X: Input features
gate_weight: Gate projection weight gate_weight: Gate projection weight
gate_bias: Gate projection bias
gate_quant: Gate quantization state gate_quant: Gate quantization state
gate_A: Gate LoRA A matrix gate_A: Gate LoRA A matrix
gate_B: Gate LoRA B matrix gate_B: Gate LoRA B matrix
gate_scale: Gate LoRA scale gate_scale: Gate LoRA scale
up_weight: Up-projection weight up_weight: Up projection weight
up_quant: Up-projection quantization state up_quant: Up projection quantization state
up_A: Up-projection LoRA A matrix up_A: Up projection LoRA A matrix
up_B: Up-projection LoRA B matrix up_B: Up projection LoRA B matrix
up_scale: Up-projection LoRA scale up_scale: Up projection LoRA scale
down_weight: Down-projection weight down_weight: Down projection weight
down_quant: Down-projection quantization state down_bias: Down projection bias
down_A: Down-projection LoRA A matrix down_quant: Down projection quantization state
down_B: Down-projection LoRA B matrix down_A: Down projection LoRA A matrix
down_scale: Down-projection LoRA scale down_B: Down projection LoRA B matrix
down_scale: Down projection LoRA scale
activation_fn: Forward activation function activation_fn: Forward activation function
activation_fn_backward: Backward activation function activation_fn_backward: Backward activation function
inplace: Whether to perform operations in-place inplace: Whether to perform operations in-place
@@ -178,15 +188,17 @@ class LoRA_MLP(torch.autograd.Function):
Output transformed by multi-layer perceptron and activation function Output transformed by multi-layer perceptron and activation function
""" """
# Compute projections # Compute projections
gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale) gate = matmul_lora(
up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale) X, gate_weight, gate_bias, gate_quant, gate_A, gate_B, gate_scale
)
up = matmul_lora(X, up_weight, up_bias, up_quant, up_A, up_B, up_scale)
# Activation # Activation
hidden = activation_fn(gate, up) hidden = activation_fn(gate, up)
# Down projection # Down projection
output = matmul_lora( output = matmul_lora(
hidden, down_weight, down_quant, down_A, down_B, down_scale hidden, down_weight, down_bias, down_quant, down_A, down_B, down_scale
) )
# Save for backward # Save for backward
@@ -209,22 +221,26 @@ class LoRA_MLP(torch.autograd.Function):
torch.Tensor | None, torch.Tensor | None,
None, None,
None, None,
None,
torch.Tensor | None, torch.Tensor | None,
torch.Tensor | None, torch.Tensor | None,
None, None,
None, None,
None, None,
None,
torch.Tensor | None, torch.Tensor | None,
torch.Tensor | None, torch.Tensor | None,
None, None,
None, None,
None, None,
None,
torch.Tensor | None, torch.Tensor | None,
torch.Tensor | None, torch.Tensor | None,
None, None,
None, None,
None, None,
None, None,
None,
]: ]:
""" """
Performs backward pass computation for LoRA MLP. Performs backward pass computation for LoRA MLP.
@@ -236,7 +252,7 @@ class LoRA_MLP(torch.autograd.Function):
Returns: Returns:
Tuple containing gradients for all inputs from forward pass: Tuple containing gradients for all inputs from forward pass:
- Input gradient tensor (or `None`) - Input gradient tensor (or `None`)
- `None` for weights/quantization states - `None` for weights/biases/quantization states
- LoRA A/B matrix gradients (or `None`) - LoRA A/B matrix gradients (or `None`)
- `None` for scaling factors - `None` for scaling factors
- `None` for activation functions and flags - `None` for activation functions and flags
@@ -279,9 +295,10 @@ class LoRA_MLP(torch.autograd.Function):
dtype = X.dtype dtype = X.dtype
# Down projection # Down projection
DW = matmul_lora( grad_down = matmul_lora(
grad_output, grad_output,
down_weight.t(), down_weight.t(),
None,
down_quant, down_quant,
down_B, down_B,
down_A, down_A,
@@ -289,7 +306,7 @@ class LoRA_MLP(torch.autograd.Function):
) )
# Activation backward # Activation backward
h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up) h, grad_gate, grad_up = ctx.activation_fn_backward(grad_down, gate, up)
# Initialize and compute LoRA gradients # Initialize and compute LoRA gradients
d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
@@ -329,8 +346,8 @@ class LoRA_MLP(torch.autograd.Function):
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t()) dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
# Gate projection gradients # Gate projection gradients
gate_weight = dequantize(gate_weight.t(), gate_quant) gate_weight = dequantize(gate_weight, gate_quant)
dX += grad_gate @ gate_weight.t() dX += grad_gate @ gate_weight
del gate_weight del gate_weight
if gate_A is not None and gate_B is not None: if gate_A is not None and gate_B is not None:
@@ -348,22 +365,26 @@ class LoRA_MLP(torch.autograd.Function):
dX, dX,
None, None,
None, None,
None,
d_gate_A.t() if d_gate_A is not None else None, d_gate_A.t() if d_gate_A is not None else None,
d_gate_B.t() if d_gate_B is not None else None, d_gate_B.t() if d_gate_B is not None else None,
None, None,
None, None,
None, None,
None,
d_up_A.t() if d_up_A is not None else None, d_up_A.t() if d_up_A is not None else None,
d_up_B.t() if d_up_B is not None else None, d_up_B.t() if d_up_B is not None else None,
None, None,
None, None,
None, None,
None,
d_down_A.t() if d_down_A is not None else None, d_down_A.t() if d_down_A is not None else None,
d_down_B.t() if d_down_B is not None else None, d_down_B.t() if d_down_B is not None else None,
None, None,
None, None,
None, None,
None, None,
None,
) )
@@ -378,23 +399,26 @@ def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.
Returns: Returns:
Output tensor after applying LoRA-adapted MLP with SwiGLU activation Output tensor after applying LoRA-adapted MLP with SwiGLU activation
""" """
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj) upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply( out = LoRA_MLP.apply(
X, X,
gateW, gateW,
gateb,
gateW_quant, gateW_quant,
gateA, gateA,
gateB, gateB,
gateS, gateS,
upW, upW,
upb,
upW_quant, upW_quant,
upA, upA,
upB, upB,
upS, upS,
downW, downW,
downb,
downW_quant, downW_quant,
downA, downA,
downB, downB,
@@ -418,22 +442,25 @@ def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.T
Returns: Returns:
Output tensor after applying LoRA-adapted MLP with GEGLU activation Output tensor after applying LoRA-adapted MLP with GEGLU activation
""" """
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj) upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply( out = LoRA_MLP.apply(
X, X,
gateW, gateW,
gateb,
gateW_quant, gateW_quant,
gateA, gateA,
gateB, gateB,
gateS, gateS,
upW, upW,
upb,
upW_quant, upW_quant,
upA, upA,
upB, upB,
upS, upS,
downW, downW,
downb,
downW_quant, downW_quant,
downA, downA,
downB, downB,
@@ -460,16 +487,19 @@ class LoRA_QKV(torch.autograd.Function):
ctx: torch.autograd.function.FunctionCtx, ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor, X: torch.Tensor,
q_weight: torch.Tensor, q_weight: torch.Tensor,
q_bias: torch.Tensor | None,
q_quant: QuantState | None, q_quant: QuantState | None,
q_A: torch.Tensor | None, q_A: torch.Tensor | None,
q_B: torch.Tensor | None, q_B: torch.Tensor | None,
q_scale: float, q_scale: float,
k_weight: torch.Tensor, k_weight: torch.Tensor,
k_bias: torch.Tensor | None,
k_quant: QuantState | None, k_quant: QuantState | None,
k_A: torch.Tensor | None, k_A: torch.Tensor | None,
k_B: torch.Tensor | None, k_B: torch.Tensor | None,
k_scale: float, k_scale: float,
v_weight: torch.Tensor, v_weight: torch.Tensor,
v_bias: torch.Tensor | None,
v_quant: QuantState | None, v_quant: QuantState | None,
v_A: torch.Tensor | None, v_A: torch.Tensor | None,
v_B: torch.Tensor | None, v_B: torch.Tensor | None,
@@ -483,16 +513,19 @@ class LoRA_QKV(torch.autograd.Function):
ctx: Autograd context ctx: Autograd context
X: Input tensor X: Input tensor
q_weight: Query projection weight q_weight: Query projection weight
q_bias: Query projection bias
q_quant: Query quantization state q_quant: Query quantization state
q_A: Query LoRA A matrix q_A: Query LoRA A matrix
q_B: Query LoRA B matrix q_B: Query LoRA B matrix
q_scale: Query LoRA scale q_scale: Query LoRA scale
k_weight: Key projection weight k_weight: Key projection weight
k_bias: Key projection bias
k_quant: Key quantization state k_quant: Key quantization state
k_A: Key LoRA A matrix k_A: Key LoRA A matrix
k_B: Key LoRA B matrix k_B: Key LoRA B matrix
k_scale: Key LoRA scale k_scale: Key LoRA scale
v_weight: Value projection weight v_weight: Value projection weight
v_bias: Value projection bias
v_quant: Value quantization state v_quant: Value quantization state
v_A: Value LoRA A matrix v_A: Value LoRA A matrix
v_B: Value LoRA B matrix v_B: Value LoRA B matrix
@@ -502,20 +535,21 @@ class LoRA_QKV(torch.autograd.Function):
Returns: Returns:
Tuple of (Query, Key, Value) projection tensors Tuple of (Query, Key, Value) projection tensors
""" """
Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale) Q = matmul_lora(X, q_weight, q_bias, q_quant, q_A, q_B, q_scale)
K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale) K = matmul_lora(X, k_weight, k_bias, k_quant, k_A, k_B, k_scale)
V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale) V = matmul_lora(X, v_weight, v_bias, v_quant, v_A, v_B, v_scale)
ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B) ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
ctx.scales = (q_scale, k_scale, v_scale) ctx.scales = (q_scale, k_scale, v_scale)
ctx.quants = (q_quant, k_quant, v_quant) ctx.quants = (q_quant, k_quant, v_quant)
ctx.weights = (q_weight, k_weight, v_weight) ctx.weights = (q_weight, k_weight, v_weight)
ctx.biases = (q_bias, k_bias, v_bias)
ctx.inplace = inplace ctx.inplace = inplace
return Q, K, V return Q, K, V
@staticmethod @staticmethod
@torch_amp_custom_fwd @torch_amp_custom_bwd
def backward( def backward(
ctx: torch.autograd.function.FunctionCtx, ctx: torch.autograd.function.FunctionCtx,
q_grad: torch.Tensor, q_grad: torch.Tensor,
@@ -525,16 +559,19 @@ class LoRA_QKV(torch.autograd.Function):
torch.Tensor, torch.Tensor,
None, None,
None, None,
None,
torch.Tensor | None, torch.Tensor | None,
torch.Tensor | None, torch.Tensor | None,
None, None,
None, None,
None, None,
None,
torch.Tensor | None, torch.Tensor | None,
torch.Tensor | None, torch.Tensor | None,
None, None,
None, None,
None, None,
None,
torch.Tensor | None, torch.Tensor | None,
torch.Tensor | None, torch.Tensor | None,
None, None,
@@ -622,31 +659,31 @@ class LoRA_QKV(torch.autograd.Function):
# Transpose gradients if needed # Transpose gradients if needed
if d_A_q is not None: if d_A_q is not None:
d_A_q = d_A_q.t() d_A_q = d_A_q.t()
if d_B_q is not None: d_B_q = d_B_q.t() # type: ignore[union-attr]
d_B_q = d_B_q.t()
if d_A_k is not None: if d_A_k is not None:
d_A_k = d_A_k.t() d_A_k = d_A_k.t()
if d_B_k is not None: d_B_k = d_B_k.t() # type: ignore[union-attr]
d_B_k = d_B_k.t()
if d_A_v is not None: if d_A_v is not None:
d_A_v = d_A_v.t() d_A_v = d_A_v.t()
if d_B_v is not None: d_B_v = d_B_v.t() # type: ignore[union-attr]
d_B_v = d_B_v.t()
return ( return (
grad_X.view(batch, seq_len, -1), grad_X.view(batch, seq_len, -1),
None, None,
None, None,
None,
d_A_q, d_A_q,
d_B_q, d_B_q,
None, None,
None, None,
None, None,
None,
d_A_k, d_A_k,
d_B_k, d_B_k,
None, None,
None, None,
None, None,
None,
d_A_v, d_A_v,
d_B_v, d_B_v,
None, None,
@@ -667,22 +704,25 @@ def apply_lora_qkv(
Returns: Returns:
Tuple of (Query, Key, Value) projection tensors Tuple of (Query, Key, Value) projection tensors
""" """
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj) QW, Qb, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj) KW, Kb, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj) VW, Vb, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
Q, K, V = LoRA_QKV.apply( Q, K, V = LoRA_QKV.apply(
X, X,
QW, QW,
Qb,
QW_quant, QW_quant,
QA, QA,
QB, QB,
QS, QS,
KW, KW,
Kb,
KW_quant, KW_quant,
KA, KA,
KB, KB,
KS, KS,
VW, VW,
Vb,
VW_quant, VW_quant,
VA, VA,
VB, VB,
@@ -702,10 +742,11 @@ class LoRA_O(torch.autograd.Function):
ctx: torch.autograd.function.FunctionCtx, ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor, X: torch.Tensor,
W: torch.Tensor, W: torch.Tensor,
b: torch.Tensor,
W_quant: QuantState | None, W_quant: QuantState | None,
A: torch.Tensor | None, A: torch.Tensor,
B: torch.Tensor | None, B: torch.Tensor,
S: float, s: float,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Forward pass for output projection with LoRA. Forward pass for output projection with LoRA.
@@ -714,19 +755,20 @@ class LoRA_O(torch.autograd.Function):
ctx: Autograd context ctx: Autograd context
X: Input tensor X: Input tensor
W: Output projection weight W: Output projection weight
b: Output projection bias
W_quant: Weight quantization state W_quant: Weight quantization state
A: LoRA A matrix A: LoRA A matrix
B: LoRA B matrix B: LoRA B matrix
S: LoRA scaling factor s: LoRA scaling factor
Returns: Returns:
Output projection tensor Output projection result
""" """
XW = matmul_lora(X, W, W_quant, A, B, S) XW = matmul_lora(X, W, b, W_quant, A, B, s)
ctx.custom_saved_tensors = ( ctx.custom_saved_tensors = (
W, W,
W_quant, W_quant,
S, s,
) )
ctx.save_for_backward(A, B, X) ctx.save_for_backward(A, B, X)
@@ -741,8 +783,9 @@ class LoRA_O(torch.autograd.Function):
torch.Tensor, torch.Tensor,
None, None,
None, None,
torch.Tensor | None, None,
torch.Tensor | None, torch.Tensor,
torch.Tensor,
None, None,
]: ]:
""" """
@@ -755,7 +798,7 @@ class LoRA_O(torch.autograd.Function):
Returns: Returns:
Tuple containing gradients for all forward inputs Tuple containing gradients for all forward inputs
""" """
W, W_quant, S = ctx.custom_saved_tensors W, W_quant, s = ctx.custom_saved_tensors
A, B, X = ctx.saved_tensors A, B, X = ctx.saved_tensors
batch, seq_len, hd = X.shape batch, seq_len, hd = X.shape
@@ -765,17 +808,19 @@ class LoRA_O(torch.autograd.Function):
# Weight projection # Weight projection
dY_X = X.t() @ dY dY_X = X.t() @ dY
d_A = S * dY_X @ B d_A = s * dY_X @ B
d_B = S * A @ dY_X d_B = s * A @ dY_X
# Get derivative for dX # Get derivative for dX
W = dequantize(W.t(), W_quant) W = dequantize(W.t(), W_quant)
dX = dY @ W.t() dX = dY @ W.t()
del W del W
dX += dY @ B.to(dtype) @ (S * A.to(dtype))
# W, W_quant, A, B, S A, B = A.to(dtype), B.to(dtype)
return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None dX += s * dY @ B @ A
# W, b, W_quant, A, B, s
return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None
def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor: def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
@@ -788,7 +833,7 @@ def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
Returns: Returns:
Transformed output tensor Transformed output tensor
""" """
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj) OW, Ob, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS) output = LoRA_O.apply(X, OW, Ob, OW_quant, OA, OB, OS)
return output return output

View File

@@ -390,7 +390,6 @@ def apply_lora_kernel_patches(
] ]
can_patch_qkv = all( can_patch_qkv = all(
hasattr(module, "lora_A") hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules for module in layer_modules
) )
@@ -400,7 +399,8 @@ def apply_lora_kernel_patches(
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn) self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
else: else:
LOG.warning_once( LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias" "Cannot patch some attention QKV projections - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
) )
if cfg.lora_o_kernel: if cfg.lora_o_kernel:
# Output patching # Output patching
@@ -409,7 +409,6 @@ def apply_lora_kernel_patches(
] ]
can_patch_o = all( can_patch_o = all(
hasattr(module, "lora_A") hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules for module in layer_modules
) )
@@ -418,14 +417,14 @@ def apply_lora_kernel_patches(
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn) self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
else: else:
LOG.warning_once( LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA adapters with no bias" "Cannot patch some attention output projection - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
) )
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer): for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
if cfg.lora_mlp_kernel: if cfg.lora_mlp_kernel:
# MLP patching # MLP patching
can_patch_mlp = all( can_patch_mlp = all(
hasattr(proj, "lora_A") hasattr(proj, "lora_A")
and getattr(proj, "base_layer", proj).bias is None
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0 and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj) for proj in (gate_proj, up_proj, down_proj)
) )
@@ -435,7 +434,8 @@ def apply_lora_kernel_patches(
layer.mlp.forward = types.MethodType(apply_fn, mlp) layer.mlp.forward = types.MethodType(apply_fn, mlp)
else: else:
LOG.warning_once( LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters with no bias" "Cannot patch some MLP layers - requires LoRA adapters and no "
"lora_magnitude_vector (DoRA)"
) )
LOG.setLevel(original_level) LOG.setLevel(original_level)

View File

@@ -64,6 +64,7 @@ def sample_tensors():
batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16 batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
), ),
"W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16), "W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
"b": torch.randn(out_dim, device="cuda", dtype=torch.float16),
"scale": 0.5, "scale": 0.5,
"shapes": { "shapes": {
"batch": batch_size, "batch": batch_size,
@@ -103,23 +104,24 @@ def mock_proj():
def test_get_lora_parameters(mock_proj): def test_get_lora_parameters(mock_proj):
"""Tests get_lora_parameters function""" """Tests get_lora_parameters function"""
# Test with LoRA enabled # Test with LoRA enabled
W, _, A, B, s = get_lora_parameters(mock_proj) W, b, _, A, B, s = get_lora_parameters(mock_proj)
assert isinstance(W, torch.Tensor) assert isinstance(W, torch.Tensor)
assert W.shape == (128, 64) assert W.shape == (128, 64)
assert b.shape == (128,)
assert A.shape == (8, 64) assert A.shape == (8, 64)
assert B.shape == (128, 8) assert B.shape == (128, 8)
assert s == 0.5 assert s == 0.5
# Test with LoRA disabled # Test with LoRA disabled
mock_proj.disable_adapters = True mock_proj.disable_adapters = True
W, _, A, B, s = get_lora_parameters(mock_proj) W, b, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None assert A is None and B is None and s is None
# Test with merged state # Test with merged state
mock_proj.disable_adapters = False mock_proj.disable_adapters = False
mock_proj.merged = True mock_proj.merged = True
W, _, A, B, s = get_lora_parameters(mock_proj) W, b, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None assert A is None and B is None and s is None
@@ -127,6 +129,7 @@ def test_matmul_lora(sample_tensors):
"""Tests matmul_lora function""" """Tests matmul_lora function"""
X = sample_tensors["X"] X = sample_tensors["X"]
W = sample_tensors["W"] W = sample_tensors["W"]
b = sample_tensors["b"]
scale = sample_tensors["scale"] scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"] shapes = sample_tensors["shapes"]
@@ -138,19 +141,20 @@ def test_matmul_lora(sample_tensors):
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16) B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test base matmul # Test base matmul
out1 = matmul_lora(X, W, None, None, None, None) out1 = matmul_lora(X, W, b, None, None, None, None)
expected1 = torch.matmul(X, W.t()) matmul = torch.matmul(X, W.t())
expected1 = matmul + b
assert torch.allclose(out1, expected1, rtol=1e-3) assert torch.allclose(out1, expected1, rtol=1e-3)
# Test with LoRA # Test with LoRA
out2 = matmul_lora(X, W, None, A, B, scale) out2 = matmul_lora(X, W, b, None, A, B, scale)
lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t()) lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
expected2 = expected1 + lora_term expected2 = matmul + lora_term + b
assert torch.allclose(out2, expected2, rtol=1e-3) assert torch.allclose(out2, expected2, rtol=1e-3)
# Test 3D input reshaping # Test 3D input reshaping
X_3d = X.clone() X_3d = X.clone()
out3 = matmul_lora(X_3d, W, None, A, B, scale) out3 = matmul_lora(X_3d, W, b, None, A, B, scale)
assert out3.shape == (X.shape[0], X.shape[1], W.shape[0]) assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
@@ -175,16 +179,19 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
output = LoRA_MLP.apply( output = LoRA_MLP.apply(
X, X,
gate_proj.weight, gate_proj.weight,
gate_proj.bias,
None, # gate_quant None, # gate_quant
None, # gate_A None, # gate_A
None, # gate_B None, # gate_B
None, # gate_scale None, # gate_scale
up_proj.weight, up_proj.weight,
up_proj.bias,
None, # up_quant None, # up_quant
None, # up_A None, # up_A
None, # up_B None, # up_B
None, # up_scale None, # up_scale
down_proj.weight, down_proj.weight,
down_proj.bias,
None, # down_quant None, # down_quant
None, # down_A None, # down_A
None, # down_B None, # down_B
@@ -243,16 +250,19 @@ def test_lora_mlp_with_adapters(
output = LoRA_MLP.apply( output = LoRA_MLP.apply(
X, X,
gate_proj.weight, gate_proj.weight,
gate_proj.bias,
None, None,
gate_A, gate_A,
gate_B, gate_B,
scale, scale,
up_proj.weight, up_proj.weight,
up_proj.bias,
None, None,
up_A, up_A,
up_B, up_B,
scale, scale,
down_proj.weight, down_proj.weight,
down_proj.bias,
None, None,
down_A, down_A,
down_B, down_B,
@@ -323,6 +333,7 @@ def test_lora_qkv(sample_tensors):
X.requires_grad = True X.requires_grad = True
# Test without LoRA adapters # Test without LoRA adapters
# pylint: disable=duplicate-code
Q1, K1, V1 = LoRA_QKV.apply( Q1, K1, V1 = LoRA_QKV.apply(
X, X,
q_weight, q_weight,
@@ -330,16 +341,19 @@ def test_lora_qkv(sample_tensors):
None, None,
None, None,
None, None,
None,
k_weight, k_weight,
None, None,
None, None,
None, None,
None, None,
None,
v_weight, v_weight,
None, None,
None, None,
None, None,
None, None,
None,
True, True,
) )
@@ -356,16 +370,19 @@ def test_lora_qkv(sample_tensors):
X, X,
q_weight, q_weight,
None, None,
None,
q_A, q_A,
q_B, q_B,
scale, scale,
k_weight, k_weight,
None, None,
None,
k_A, k_A,
k_B, k_B,
scale, scale,
v_weight, v_weight,
None, None,
None,
v_A, v_A,
v_B, v_B,
scale, scale,
@@ -399,6 +416,7 @@ def test_lora_o(sample_tensors):
"""Tests LoRA output projection""" """Tests LoRA output projection"""
X = sample_tensors["X"] X = sample_tensors["X"]
W = sample_tensors["W"] W = sample_tensors["W"]
b = sample_tensors["b"]
scale = sample_tensors["scale"] scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"] shapes = sample_tensors["shapes"]
@@ -411,7 +429,7 @@ def test_lora_o(sample_tensors):
# Test forward pass # Test forward pass
X.requires_grad = True X.requires_grad = True
output = LoRA_O.apply(X, W, None, A, B, scale) output = LoRA_O.apply(X, W, b, None, A, B, scale)
assert output.shape == (X.shape[0], X.shape[1], W.shape[0]) assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
@@ -425,6 +443,7 @@ def test_with_quantization(sample_tensors, mock_quantstate):
"""Tests LoRA with quantized weights""" """Tests LoRA with quantized weights"""
X = sample_tensors["X"] # [batch, seq, hidden] X = sample_tensors["X"] # [batch, seq, hidden]
W = sample_tensors["W"] # [out, hidden] W = sample_tensors["W"] # [out, hidden]
b = sample_tensors["b"] # [out]
scale = 0.5 scale = 0.5
shapes = sample_tensors["shapes"] shapes = sample_tensors["shapes"]
@@ -436,13 +455,13 @@ def test_with_quantization(sample_tensors, mock_quantstate):
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16) B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test matmul with quantization # Test matmul with quantization
out = matmul_lora(X, W, mock_quantstate, A, B, scale) out = matmul_lora(X, W, b, mock_quantstate, A, B, scale)
assert out.shape == (X.shape[0], X.shape[1], W.shape[0]) assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
assert not torch.isnan(out).any() assert not torch.isnan(out).any()
# Test with different batch sizes # Test with different batch sizes
X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16) X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale) out2 = matmul_lora(X2, W, b, mock_quantstate, A, B, scale)
assert out2.shape == (4, 6, W.shape[0]) assert out2.shape == (4, 6, W.shape[0])
assert not torch.isnan(out2).any() assert not torch.isnan(out2).any()
@@ -459,11 +478,12 @@ def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
"""Tests various input shapes and dimensions""" """Tests various input shapes and dimensions"""
X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16) X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
W = torch.randn(out, hidden, device="cuda", dtype=torch.float16) W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
b = torch.randn(out, device="cuda", dtype=torch.float16)
A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16) A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
B = torch.randn(out, rank, device="cuda", dtype=torch.float16) B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
scale = 0.5 scale = 0.5
result = matmul_lora(X, W, None, A, B, scale) result = matmul_lora(X, W, b, None, A, B, scale)
assert result.shape == (batch, seq, out) assert result.shape == (batch, seq, out)
@@ -471,6 +491,7 @@ def test_gradient_flow(sample_tensors):
"""Tests gradient flow through LoRA layers""" """Tests gradient flow through LoRA layers"""
X = sample_tensors["X"].clone() X = sample_tensors["X"].clone()
W = sample_tensors["W"].clone() W = sample_tensors["W"].clone()
b = sample_tensors["b"].clone()
scale = sample_tensors["scale"] scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"] shapes = sample_tensors["shapes"]
@@ -486,7 +507,7 @@ def test_gradient_flow(sample_tensors):
B.requires_grad = True B.requires_grad = True
# Forward pass # Forward pass
out = matmul_lora(X, W, None, A, B, scale) out = matmul_lora(X, W, b, None, A, B, scale)
loss = out.sum() loss = out.sum()
# Backward pass # Backward pass