Lora kernels bias support (#3025)

* lora kernels bias support

* revert rename

* nit

* lint, tests

* satisfying the rabbit
Dan Saunders
2025-08-06 20:20:08 -04:00
committed by GitHub
parent e442ff22aa
commit d09290f2f4
3 changed files with 156 additions and 90 deletions


@@ -26,6 +26,7 @@ def get_lora_parameters(
proj: nn.Module,
) -> tuple[
torch.Tensor,
torch.Tensor | None,
QuantState | None,
torch.Tensor | None,
torch.Tensor | None,
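With the added entry, get_lora_parameters now returns a 6-tuple with the base-layer bias in the second position, as the updated call sites below unpack it. A minimal illustrative unpacking (variable names are ours, not from the patch):

    # (weight, bias, quant_state, lora_A, lora_B, scale)
    W, bias, quant_state, lora_A, lora_B, scale = get_lora_parameters(proj)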
@@ -38,17 +39,20 @@ def get_lora_parameters(
proj: The projection module to extract parameters from.
Returns:
A tuple containing the base weight matrix, quantization state, LoRA A matrix,
LoRA B matrix, and scaling factor. States and matrices may be None if not
available.
A tuple containing the base weights, quantization state, LoRA A and B weights,
scaling factor, and base layer bias. Quant state, weights, and bias may be
`None` if not available.
"""
# For DPO or disabled adapters
base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
W = base_layer.weight
b = base_layer.bias
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
quant_state = getattr(W, "quant_state", None)
return W, quant_state, None, None, None
return W, b, quant_state, None, None, None
quant_state = getattr(W, "quant_state", None)
active_adapter = (
proj.active_adapters[0]
@@ -72,18 +76,17 @@ def get_lora_parameters(
B = linear_B.weight
s = proj.scaling[active_adapter]
quant_state = getattr(W, "quant_state", None)
return W, quant_state, A, B, s
return W, b, quant_state, A, B, s
def matmul_lora(
X: torch.Tensor,
W: torch.Tensor,
W_quant: QuantState,
A: torch.Tensor,
B: torch.Tensor,
s: float,
b: torch.Tensor | None,
W_quant: QuantState | None,
A: torch.Tensor | None,
B: torch.Tensor | None,
s: float | None,
out: torch.Tensor | None = None,
) -> torch.Tensor:
"""
@@ -104,21 +107,23 @@ def matmul_lora(
dtype = X.dtype
W = dequantize(W.t(), W_quant)
reshape = False
if X.dim() == 3:
batch, seq_len, _ = X.shape
X = X.view(-1, X.shape[-1])
reshape = True
else:
reshape = False
out = torch.matmul(X, W, out=out)
if W_quant is not None:
del W
if A is not None:
A, B = A.t().to(dtype), B.t().to(dtype)
A, B = A.t().to(dtype), B.t().to(dtype) # type: ignore[union-attr]
out += s * X @ A @ B
if b is not None:
out += b
return out.view(batch, seq_len, -1) if reshape else out
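For reference, the bias-aware matmul_lora above is numerically equivalent to the following plain PyTorch sketch for the unquantized case, assuming W has shape [out, in], A has shape [r, in], B has shape [out, r], and b has shape [out]. This is an illustrative reference implementation, not the fused kernel itself:

    import torch

    def lora_linear_reference(X, W, b, A, B, s):
        # Base projection: X @ W^T
        out = X @ W.t()
        # Low-rank LoRA update: s * (X @ A^T) @ B^T
        if A is not None and B is not None:
            out = out + s * (X @ A.t()) @ B.t()
        # New in this commit: add the base layer's bias when present
        if b is not None:
            out = out + b
        return out

This matches the expectations in the updated tests below, where the expected output is the base matmul plus the scaled LoRA term plus the bias.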
@@ -131,17 +136,20 @@ class LoRA_MLP(torch.autograd.Function):
ctx,
X: torch.Tensor,
gate_weight: torch.Tensor,
gate_quant: object | None,
gate_bias: torch.Tensor | None,
gate_quant: QuantState | None,
gate_A: torch.Tensor | None,
gate_B: torch.Tensor | None,
gate_scale: float,
up_weight: torch.Tensor,
up_quant: object | None,
up_bias: torch.Tensor | None,
up_quant: QuantState | None,
up_A: torch.Tensor | None,
up_B: torch.Tensor | None,
up_scale: float,
down_weight: torch.Tensor,
down_quant: object | None,
down_bias: torch.Tensor | None,
down_quant: QuantState | None,
down_A: torch.Tensor | None,
down_B: torch.Tensor | None,
down_scale: float,
@@ -156,20 +164,22 @@ class LoRA_MLP(torch.autograd.Function):
ctx: Autograd context
X: Input features
gate_weight: Gate projection weight
gate_bias: Gate projection bias
gate_quant: Gate quantization state
gate_A: Gate LoRA A matrix
gate_B: Gate LoRA B matrix
gate_scale: Gate LoRA scale
up_weight: Up-projection weight
up_quant: Up-projection quantization state
up_A: Up-projection LoRA A matrix
up_B: Up-projection LoRA B matrix
up_scale: Up-projection LoRA scale
down_weight: Down-projection weight
down_quant: Down-projection quantization state
down_A: Down-projection LoRA A matrix
down_B: Down-projection LoRA B matrix
down_scale: Down-projection LoRA scale
up_weight: Up projection weight
up_quant: Up projection quantization state
up_A: Up projection LoRA A matrix
up_B: Up projection LoRA B matrix
up_scale: Up projection LoRA scale
down_weight: Down projection weight
down_bias: Down projection bias
down_quant: Down projection quantization state
down_A: Down projection LoRA A matrix
down_B: Down projection LoRA B matrix
down_scale: Down projection LoRA scale
activation_fn: Forward activation function
activation_fn_backward: Backward activation function
inplace: Whether to perform operations in-place
@@ -178,15 +188,17 @@ class LoRA_MLP(torch.autograd.Function):
Output transformed by multi-layer perceptron and activation function
"""
# Compute projections
gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)
gate = matmul_lora(
X, gate_weight, gate_bias, gate_quant, gate_A, gate_B, gate_scale
)
up = matmul_lora(X, up_weight, up_bias, up_quant, up_A, up_B, up_scale)
# Activation
hidden = activation_fn(gate, up)
# Down projection
output = matmul_lora(
hidden, down_weight, down_quant, down_A, down_B, down_scale
hidden, down_weight, down_bias, down_quant, down_A, down_B, down_scale
)
# Save for backward
@@ -209,22 +221,26 @@ class LoRA_MLP(torch.autograd.Function):
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
None,
]:
"""
Performs backward pass computation for LoRA MLP.
@@ -236,7 +252,7 @@ class LoRA_MLP(torch.autograd.Function):
Returns:
Tuple containing gradients for all inputs from forward pass:
- Input gradient tensor (or `None`)
- `None` for weights/quantization states
- `None` for weights/biases/quantization states
- LoRA A/B matrix gradients (or `None`)
- `None` for scaling factors
- `None` for activation functions and flags
@@ -279,9 +295,10 @@ class LoRA_MLP(torch.autograd.Function):
dtype = X.dtype
# Down projection
DW = matmul_lora(
grad_down = matmul_lora(
grad_output,
down_weight.t(),
None,
down_quant,
down_B,
down_A,
@@ -289,7 +306,7 @@ class LoRA_MLP(torch.autograd.Function):
)
# Activation backward
h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)
h, grad_gate, grad_up = ctx.activation_fn_backward(grad_down, gate, up)
# Initialize and compute LoRA gradients
d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
@@ -329,8 +346,8 @@ class LoRA_MLP(torch.autograd.Function):
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
# Gate projection gradients
gate_weight = dequantize(gate_weight.t(), gate_quant)
dX += grad_gate @ gate_weight.t()
gate_weight = dequantize(gate_weight, gate_quant)
dX += grad_gate @ gate_weight
del gate_weight
if gate_A is not None and gate_B is not None:
@@ -348,22 +365,26 @@ class LoRA_MLP(torch.autograd.Function):
dX,
None,
None,
None,
d_gate_A.t() if d_gate_A is not None else None,
d_gate_B.t() if d_gate_B is not None else None,
None,
None,
None,
None,
d_up_A.t() if d_up_A is not None else None,
d_up_B.t() if d_up_B is not None else None,
None,
None,
None,
None,
d_down_A.t() if d_down_A is not None else None,
d_down_B.t() if d_down_B is not None else None,
None,
None,
None,
None,
None,
)
@@ -378,23 +399,26 @@ def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.
Returns:
Output tensor after applying LoRA-adapted MLP with SwiGLU activation
"""
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(
X,
gateW,
gateb,
gateW_quant,
gateA,
gateB,
gateS,
upW,
upb,
upW_quant,
upA,
upB,
upS,
downW,
downb,
downW_quant,
downA,
downB,
@@ -418,22 +442,25 @@ def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.T
Returns:
Output tensor after applying LoRA-adapted MLP with GEGLU activation
"""
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(
X,
gateW,
gateb,
gateW_quant,
gateA,
gateB,
gateS,
upW,
upb,
upW_quant,
upA,
upB,
upS,
downW,
downb,
downW_quant,
downA,
downB,
@@ -460,16 +487,19 @@ class LoRA_QKV(torch.autograd.Function):
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
q_weight: torch.Tensor,
q_bias: torch.Tensor | None,
q_quant: QuantState | None,
q_A: torch.Tensor | None,
q_B: torch.Tensor | None,
q_scale: float,
k_weight: torch.Tensor,
k_bias: torch.Tensor | None,
k_quant: QuantState | None,
k_A: torch.Tensor | None,
k_B: torch.Tensor | None,
k_scale: float,
v_weight: torch.Tensor,
v_bias: torch.Tensor | None,
v_quant: QuantState | None,
v_A: torch.Tensor | None,
v_B: torch.Tensor | None,
@@ -483,16 +513,19 @@ class LoRA_QKV(torch.autograd.Function):
ctx: Autograd context
X: Input tensor
q_weight: Query projection weight
q_bias: Query projection bias
q_quant: Query quantization state
q_A: Query LoRA A matrix
q_B: Query LoRA B matrix
q_scale: Query LoRA scale
k_weight: Key projection weight
k_bias: Key projection bias
k_quant: Key quantization state
k_A: Key LoRA A matrix
k_B: Key LoRA B matrix
k_scale: Key LoRA scale
v_weight: Value projection weight
v_bias: Value projection bias
v_quant: Value quantization state
v_A: Value LoRA A matrix
v_B: Value LoRA B matrix
@@ -502,20 +535,21 @@ class LoRA_QKV(torch.autograd.Function):
Returns:
Tuple of (Query, Key, Value) projection tensors
"""
Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)
Q = matmul_lora(X, q_weight, q_bias, q_quant, q_A, q_B, q_scale)
K = matmul_lora(X, k_weight, k_bias, k_quant, k_A, k_B, k_scale)
V = matmul_lora(X, v_weight, v_bias, v_quant, v_A, v_B, v_scale)
ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
ctx.scales = (q_scale, k_scale, v_scale)
ctx.quants = (q_quant, k_quant, v_quant)
ctx.weights = (q_weight, k_weight, v_weight)
ctx.biases = (q_bias, k_bias, v_bias)
ctx.inplace = inplace
return Q, K, V
@staticmethod
@torch_amp_custom_fwd
@torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
q_grad: torch.Tensor,
@@ -525,16 +559,19 @@ class LoRA_QKV(torch.autograd.Function):
torch.Tensor,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
@@ -622,31 +659,31 @@ class LoRA_QKV(torch.autograd.Function):
# Transpose gradients if needed
if d_A_q is not None:
d_A_q = d_A_q.t()
if d_B_q is not None:
d_B_q = d_B_q.t()
d_B_q = d_B_q.t() # type: ignore[union-attr]
if d_A_k is not None:
d_A_k = d_A_k.t()
if d_B_k is not None:
d_B_k = d_B_k.t()
d_B_k = d_B_k.t() # type: ignore[union-attr]
if d_A_v is not None:
d_A_v = d_A_v.t()
if d_B_v is not None:
d_B_v = d_B_v.t()
d_B_v = d_B_v.t() # type: ignore[union-attr]
return (
grad_X.view(batch, seq_len, -1),
None,
None,
None,
d_A_q,
d_B_q,
None,
None,
None,
None,
d_A_k,
d_B_k,
None,
None,
None,
None,
d_A_v,
d_B_v,
None,
@@ -667,22 +704,25 @@ def apply_lora_qkv(
Returns:
Tuple of (Query, Key, Value) projection tensors
"""
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
QW, Qb, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, Kb, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, Vb, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
Q, K, V = LoRA_QKV.apply(
X,
QW,
Qb,
QW_quant,
QA,
QB,
QS,
KW,
Kb,
KW_quant,
KA,
KB,
KS,
VW,
Vb,
VW_quant,
VA,
VB,
@@ -702,10 +742,11 @@ class LoRA_O(torch.autograd.Function):
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
W: torch.Tensor,
b: torch.Tensor,
W_quant: QuantState | None,
A: torch.Tensor | None,
B: torch.Tensor | None,
S: float,
A: torch.Tensor,
B: torch.Tensor,
s: float,
) -> torch.Tensor:
"""
Forward pass for output projection with LoRA.
@@ -714,19 +755,20 @@ class LoRA_O(torch.autograd.Function):
ctx: Autograd context
X: Input tensor
W: Output projection weight
b: Output projection bias
W_quant: Weight quantization state
A: LoRA A matrix
B: LoRA B matrix
S: LoRA scaling factor
s: LoRA scaling factor
Returns:
Output projection tensor
Output projection result
"""
XW = matmul_lora(X, W, W_quant, A, B, S)
XW = matmul_lora(X, W, b, W_quant, A, B, s)
ctx.custom_saved_tensors = (
W,
W_quant,
S,
s,
)
ctx.save_for_backward(A, B, X)
@@ -741,8 +783,9 @@ class LoRA_O(torch.autograd.Function):
torch.Tensor,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
torch.Tensor,
torch.Tensor,
None,
]:
"""
@@ -755,7 +798,7 @@ class LoRA_O(torch.autograd.Function):
Returns:
Tuple containing gradients for all forward inputs
"""
W, W_quant, S = ctx.custom_saved_tensors
W, W_quant, s = ctx.custom_saved_tensors
A, B, X = ctx.saved_tensors
batch, seq_len, hd = X.shape
@@ -765,17 +808,19 @@ class LoRA_O(torch.autograd.Function):
# Weight projection
dY_X = X.t() @ dY
d_A = S * dY_X @ B
d_B = S * A @ dY_X
d_A = s * dY_X @ B
d_B = s * A @ dY_X
# Get derivative for dX
W = dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
dX += dY @ B.to(dtype) @ (S * A.to(dtype))
# W, W_quant, A, B, S
return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
A, B = A.to(dtype), B.to(dtype)
dX += s * dY @ B @ A
# W, b, W_quant, A, B, s
return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None
def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
@@ -788,7 +833,7 @@ def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
Returns:
Transformed output tensor
"""
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS)
OW, Ob, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
output = LoRA_O.apply(X, OW, Ob, OW_quant, OA, OB, OS)
return output


@@ -390,7 +390,6 @@ def apply_lora_kernel_patches(
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
@@ -400,7 +399,8 @@ def apply_lora_kernel_patches(
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
"Cannot patch some attention QKV projections - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
)
if cfg.lora_o_kernel:
# Output patching
@@ -409,7 +409,6 @@ def apply_lora_kernel_patches(
]
can_patch_o = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
@@ -418,14 +417,14 @@ def apply_lora_kernel_patches(
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
"Cannot patch some attention output projection - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
)
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
if cfg.lora_mlp_kernel:
# MLP patching
can_patch_mlp = all(
hasattr(proj, "lora_A")
and getattr(proj, "base_layer", proj).bias is None
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
)
@@ -435,7 +434,8 @@ def apply_lora_kernel_patches(
layer.mlp.forward = types.MethodType(apply_fn, mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
"Cannot patch some MLP layers - requires LoRA adapters and no "
"lora_magnitude_vector (DoRA)"
)
LOG.setLevel(original_level)
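In short, after this change a base-layer bias no longer disqualifies a module from kernel patching; only missing LoRA adapters or DoRA (lora_magnitude_vector) do. A minimal sketch of the relaxed eligibility check, with an assumed helper name:

    def can_patch_module(module) -> bool:
        # Module must carry LoRA adapters
        has_lora = hasattr(module, "lora_A")
        # DoRA (lora_magnitude_vector) remains unsupported by the fused kernels
        uses_dora = len(getattr(module, "lora_magnitude_vector", []) or []) > 0
        return has_lora and not uses_dora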


@@ -64,6 +64,7 @@ def sample_tensors():
batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
),
"W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
"b": torch.randn(out_dim, device="cuda", dtype=torch.float16),
"scale": 0.5,
"shapes": {
"batch": batch_size,
@@ -103,23 +104,24 @@ def mock_proj():
def test_get_lora_parameters(mock_proj):
"""Tests get_lora_parameters function"""
# Test with LoRA enabled
W, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s = get_lora_parameters(mock_proj)
assert isinstance(W, torch.Tensor)
assert W.shape == (128, 64)
assert b.shape == (128,)
assert A.shape == (8, 64)
assert B.shape == (128, 8)
assert s == 0.5
# Test with LoRA disabled
mock_proj.disable_adapters = True
W, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
# Test with merged state
mock_proj.disable_adapters = False
mock_proj.merged = True
W, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
@@ -127,6 +129,7 @@ def test_matmul_lora(sample_tensors):
"""Tests matmul_lora function"""
X = sample_tensors["X"]
W = sample_tensors["W"]
b = sample_tensors["b"]
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
@@ -138,19 +141,20 @@ def test_matmul_lora(sample_tensors):
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test base matmul
out1 = matmul_lora(X, W, None, None, None, None)
expected1 = torch.matmul(X, W.t())
out1 = matmul_lora(X, W, b, None, None, None, None)
matmul = torch.matmul(X, W.t())
expected1 = matmul + b
assert torch.allclose(out1, expected1, rtol=1e-3)
# Test with LoRA
out2 = matmul_lora(X, W, None, A, B, scale)
out2 = matmul_lora(X, W, b, None, A, B, scale)
lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
expected2 = expected1 + lora_term
expected2 = matmul + lora_term + b
assert torch.allclose(out2, expected2, rtol=1e-3)
# Test 3D input reshaping
X_3d = X.clone()
out3 = matmul_lora(X_3d, W, None, A, B, scale)
out3 = matmul_lora(X_3d, W, b, None, A, B, scale)
assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
@@ -175,16 +179,19 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
output = LoRA_MLP.apply(
X,
gate_proj.weight,
gate_proj.bias,
None, # gate_quant
None, # gate_A
None, # gate_B
None, # gate_scale
up_proj.weight,
up_proj.bias,
None, # up_quant
None, # up_A
None, # up_B
None, # up_scale
down_proj.weight,
down_proj.bias,
None, # down_quant
None, # down_A
None, # down_B
@@ -243,16 +250,19 @@ def test_lora_mlp_with_adapters(
output = LoRA_MLP.apply(
X,
gate_proj.weight,
gate_proj.bias,
None,
gate_A,
gate_B,
scale,
up_proj.weight,
up_proj.bias,
None,
up_A,
up_B,
scale,
down_proj.weight,
down_proj.bias,
None,
down_A,
down_B,
@@ -323,6 +333,7 @@ def test_lora_qkv(sample_tensors):
X.requires_grad = True
# Test without LoRA adapters
# pylint: disable=duplicate-code
Q1, K1, V1 = LoRA_QKV.apply(
X,
q_weight,
@@ -330,16 +341,19 @@ def test_lora_qkv(sample_tensors):
None,
None,
None,
None,
k_weight,
None,
None,
None,
None,
None,
v_weight,
None,
None,
None,
None,
None,
True,
)
@@ -356,16 +370,19 @@ def test_lora_qkv(sample_tensors):
X,
q_weight,
None,
None,
q_A,
q_B,
scale,
k_weight,
None,
None,
k_A,
k_B,
scale,
v_weight,
None,
None,
v_A,
v_B,
scale,
@@ -399,6 +416,7 @@ def test_lora_o(sample_tensors):
"""Tests LoRA output projection"""
X = sample_tensors["X"]
W = sample_tensors["W"]
b = sample_tensors["b"]
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
@@ -411,7 +429,7 @@ def test_lora_o(sample_tensors):
# Test forward pass
X.requires_grad = True
output = LoRA_O.apply(X, W, None, A, B, scale)
output = LoRA_O.apply(X, W, b, None, A, B, scale)
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
@@ -425,6 +443,7 @@ def test_with_quantization(sample_tensors, mock_quantstate):
"""Tests LoRA with quantized weights"""
X = sample_tensors["X"] # [batch, seq, hidden]
W = sample_tensors["W"] # [out, hidden]
b = sample_tensors["b"] # [out]
scale = 0.5
shapes = sample_tensors["shapes"]
@@ -436,13 +455,13 @@ def test_with_quantization(sample_tensors, mock_quantstate):
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test matmul with quantization
out = matmul_lora(X, W, mock_quantstate, A, B, scale)
out = matmul_lora(X, W, b, mock_quantstate, A, B, scale)
assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
assert not torch.isnan(out).any()
# Test with different batch sizes
X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale)
out2 = matmul_lora(X2, W, b, mock_quantstate, A, B, scale)
assert out2.shape == (4, 6, W.shape[0])
assert not torch.isnan(out2).any()
@@ -459,11 +478,12 @@ def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
"""Tests various input shapes and dimensions"""
X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
b = torch.randn(out, device="cuda", dtype=torch.float16)
A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
scale = 0.5
result = matmul_lora(X, W, None, A, B, scale)
result = matmul_lora(X, W, b, None, A, B, scale)
assert result.shape == (batch, seq, out)
@@ -471,6 +491,7 @@ def test_gradient_flow(sample_tensors):
"""Tests gradient flow through LoRA layers"""
X = sample_tensors["X"].clone()
W = sample_tensors["W"].clone()
b = sample_tensors["b"].clone()
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
@@ -486,7 +507,7 @@ def test_gradient_flow(sample_tensors):
B.requires_grad = True
# Forward pass
out = matmul_lora(X, W, None, A, B, scale)
out = matmul_lora(X, W, b, None, A, B, scale)
loss = out.sum()
# Backward pass