diff --git a/src/axolotl/kernels/lora.py b/src/axolotl/kernels/lora.py
index 82ec91107..fb45f2aa7 100644
--- a/src/axolotl/kernels/lora.py
+++ b/src/axolotl/kernels/lora.py
@@ -26,6 +26,7 @@ def get_lora_parameters(
     proj: nn.Module,
 ) -> tuple[
     torch.Tensor,
+    torch.Tensor | None,
     QuantState | None,
     torch.Tensor | None,
     torch.Tensor | None,
@@ -38,17 +39,20 @@ def get_lora_parameters(
         proj: The projection module to extract parameters from.
 
     Returns:
-        A tuple containing the base weight matrix, quantization state, LoRA A matrix,
-        LoRA B matrix, and scaling factor. States and matrices may be None if not
-        available.
+        A tuple containing the base weights, quantization state, LoRA A and B weights,
+        scaling factor, and base layer bias. Quant state, weights, and bias may be
+        `None` if not available.
     """
     # For DPO or disabled adapters
     base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
     W = base_layer.weight
+    b = base_layer.bias
 
     if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
         quant_state = getattr(W, "quant_state", None)
-        return W, quant_state, None, None, None
+        return W, b, quant_state, None, None, None
+
+    quant_state = getattr(W, "quant_state", None)
 
     active_adapter = (
         proj.active_adapters[0]
@@ -72,18 +76,17 @@
     B = linear_B.weight
     s = proj.scaling[active_adapter]
 
-    quant_state = getattr(W, "quant_state", None)
-
-    return W, quant_state, A, B, s
+    return W, b, quant_state, A, B, s
 
 
 def matmul_lora(
     X: torch.Tensor,
     W: torch.Tensor,
-    W_quant: QuantState,
-    A: torch.Tensor,
-    B: torch.Tensor,
-    s: float,
+    b: torch.Tensor | None,
+    W_quant: QuantState | None,
+    A: torch.Tensor | None,
+    B: torch.Tensor | None,
+    s: float | None,
     out: torch.Tensor | None = None,
 ) -> torch.Tensor:
     """
@@ -104,21 +107,23 @@
     dtype = X.dtype
     W = dequantize(W.t(), W_quant)
 
+    reshape = False
     if X.dim() == 3:
         batch, seq_len, _ = X.shape
         X = X.view(-1, X.shape[-1])
         reshape = True
-    else:
-        reshape = False
 
     out = torch.matmul(X, W, out=out)
     if W_quant is not None:
         del W
 
     if A is not None:
-        A, B = A.t().to(dtype), B.t().to(dtype)
+        A, B = A.t().to(dtype), B.t().to(dtype)  # type: ignore[union-attr]
         out += s * X @ A @ B
 
+    if b is not None:
+        out += b
+
     return out.view(batch, seq_len, -1) if reshape else out
 
 
@@ -131,17 +136,20 @@ class LoRA_MLP(torch.autograd.Function):
         ctx,
         X: torch.Tensor,
         gate_weight: torch.Tensor,
-        gate_quant: object | None,
+        gate_bias: torch.Tensor | None,
+        gate_quant: QuantState | None,
         gate_A: torch.Tensor | None,
         gate_B: torch.Tensor | None,
         gate_scale: float,
         up_weight: torch.Tensor,
-        up_quant: object | None,
+        up_bias: torch.Tensor | None,
+        up_quant: QuantState | None,
         up_A: torch.Tensor | None,
         up_B: torch.Tensor | None,
         up_scale: float,
         down_weight: torch.Tensor,
-        down_quant: object | None,
+        down_bias: torch.Tensor | None,
+        down_quant: QuantState | None,
         down_A: torch.Tensor | None,
         down_B: torch.Tensor | None,
         down_scale: float,
@@ -156,20 +164,22 @@ class LoRA_MLP(torch.autograd.Function):
             ctx: Autograd context
             X: Input features
             gate_weight: Gate projection weight
+            gate_bias: Gate projection bias
             gate_quant: Gate quantization state
             gate_A: Gate LoRA A matrix
             gate_B: Gate LoRA B matrix
             gate_scale: Gate LoRA scale
-            up_weight: Up-projection weight
-            up_quant: Up-projection quantization state
-            up_A: Up-projection LoRA A matrix
-            up_B: Up-projection LoRA B matrix
-            up_scale: Up-projection LoRA scale
-            down_weight: Down-projection weight
-            down_quant: Down-projection quantization state
-            down_A: Down-projection LoRA A matrix
-            down_B: Down-projection LoRA B matrix
-            down_scale: Down-projection LoRA scale
+            up_weight: Up projection weight
+            up_quant: Up projection quantization state
+            up_A: Up projection LoRA A matrix
+            up_B: Up projection LoRA B matrix
+            up_scale: Up projection LoRA scale
+            down_weight: Down projection weight
+            down_bias: Down projection bias
+            down_quant: Down projection quantization state
+            down_A: Down projection LoRA A matrix
+            down_B: Down projection LoRA B matrix
+            down_scale: Down projection LoRA scale
             activation_fn: Forward activation function
             activation_fn_backward: Backward activation function
             inplace: Whether to perform operations in-place
@@ -178,15 +188,17 @@ class LoRA_MLP(torch.autograd.Function):
             Output transformed by multi-layer perceptron and activation function
         """
         # Compute projections
-        gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
-        up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)
+        gate = matmul_lora(
+            X, gate_weight, gate_bias, gate_quant, gate_A, gate_B, gate_scale
+        )
+        up = matmul_lora(X, up_weight, up_bias, up_quant, up_A, up_B, up_scale)
 
         # Activation
         hidden = activation_fn(gate, up)
 
         # Down projection
         output = matmul_lora(
-            hidden, down_weight, down_quant, down_A, down_B, down_scale
+            hidden, down_weight, down_bias, down_quant, down_A, down_B, down_scale
         )
 
         # Save for backward
@@ -209,22 +221,26 @@
         torch.Tensor | None,
         None,
         None,
+        None,
         torch.Tensor | None,
         torch.Tensor | None,
         None,
         None,
         None,
+        None,
         torch.Tensor | None,
         torch.Tensor | None,
         None,
         None,
         None,
+        None,
         torch.Tensor | None,
         torch.Tensor | None,
         None,
         None,
         None,
         None,
+        None,
     ]:
         """
         Performs backward pass computation for LoRA MLP.
@@ -236,7 +252,7 @@
         Returns:
             Tuple containing gradients for all inputs from forward pass:
             - Input gradient tensor (or `None`)
-            - `None` for weights/quantization states
+            - `None` for weights/biases/quantization states
             - LoRA A/B matrix gradients (or `None`)
             - `None` for scaling factors
             - `None` for activation functions and flags
@@ -279,9 +295,10 @@
         dtype = X.dtype
 
         # Down projection
-        DW = matmul_lora(
+        grad_down = matmul_lora(
             grad_output,
             down_weight.t(),
+            None,
             down_quant,
             down_B,
             down_A,
@@ -289,7 +306,7 @@
         )
 
         # Activation backward
-        h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)
+        h, grad_gate, grad_up = ctx.activation_fn_backward(grad_down, gate, up)
 
         # Initialize and compute LoRA gradients
         d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
@@ -329,8 +346,8 @@
             dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
 
         # Gate projection gradients
-        gate_weight = dequantize(gate_weight.t(), gate_quant)
-        dX += grad_gate @ gate_weight.t()
+        gate_weight = dequantize(gate_weight, gate_quant)
+        dX += grad_gate @ gate_weight
         del gate_weight
 
         if gate_A is not None and gate_B is not None:
@@ -348,22 +365,26 @@
             dX,
             None,
             None,
+            None,
             d_gate_A.t() if d_gate_A is not None else None,
             d_gate_B.t() if d_gate_B is not None else None,
             None,
             None,
             None,
+            None,
             d_up_A.t() if d_up_A is not None else None,
             d_up_B.t() if d_up_B is not None else None,
             None,
             None,
             None,
+            None,
             d_down_A.t() if d_down_A is not None else None,
             d_down_B.t() if d_down_B is not None else None,
             None,
             None,
             None,
             None,
+            None,
         )
 
 
@@ -378,23 +399,26 @@ def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.
     Returns:
         Output tensor after applying LoRA-adapted MLP with SwiGLU activation
     """
-    gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
-    upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
-    downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
+    gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
+    upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
+    downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
 
     out = LoRA_MLP.apply(
         X,
         gateW,
+        gateb,
         gateW_quant,
         gateA,
         gateB,
         gateS,
         upW,
+        upb,
         upW_quant,
         upA,
         upB,
         upS,
         downW,
+        downb,
         downW_quant,
         downA,
         downB,
@@ -418,22 +442,25 @@ def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.T
     Returns:
         Output tensor after applying LoRA-adapted MLP with GEGLU activation
     """
-    gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
-    upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
-    downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
+    gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
+    upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
+    downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
     out = LoRA_MLP.apply(
         X,
         gateW,
+        gateb,
         gateW_quant,
         gateA,
         gateB,
         gateS,
         upW,
+        upb,
         upW_quant,
         upA,
         upB,
         upS,
         downW,
+        downb,
         downW_quant,
         downA,
         downB,
@@ -460,16 +487,19 @@ class LoRA_QKV(torch.autograd.Function):
         ctx: torch.autograd.function.FunctionCtx,
         X: torch.Tensor,
         q_weight: torch.Tensor,
+        q_bias: torch.Tensor | None,
         q_quant: QuantState | None,
         q_A: torch.Tensor | None,
         q_B: torch.Tensor | None,
         q_scale: float,
         k_weight: torch.Tensor,
+        k_bias: torch.Tensor | None,
         k_quant: QuantState | None,
         k_A: torch.Tensor | None,
         k_B: torch.Tensor | None,
         k_scale: float,
         v_weight: torch.Tensor,
+        v_bias: torch.Tensor | None,
         v_quant: QuantState | None,
         v_A: torch.Tensor | None,
         v_B: torch.Tensor | None,
@@ -483,16 +513,19 @@ class LoRA_QKV(torch.autograd.Function):
             ctx: Autograd context
             X: Input tensor
             q_weight: Query projection weight
+            q_bias: Query projection bias
             q_quant: Query quantization state
             q_A: Query LoRA A matrix
             q_B: Query LoRA B matrix
             q_scale: Query LoRA scale
             k_weight: Key projection weight
+            k_bias: Key projection bias
             k_quant: Key quantization state
             k_A: Key LoRA A matrix
             k_B: Key LoRA B matrix
             k_scale: Key LoRA scale
             v_weight: Value projection weight
+            v_bias: Value projection bias
             v_quant: Value quantization state
             v_A: Value LoRA A matrix
             v_B: Value LoRA B matrix
@@ -502,20 +535,21 @@
         Returns:
             Tuple of (Query, Key, Value) projection tensors
         """
-        Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
-        K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
-        V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)
+        Q = matmul_lora(X, q_weight, q_bias, q_quant, q_A, q_B, q_scale)
+        K = matmul_lora(X, k_weight, k_bias, k_quant, k_A, k_B, k_scale)
+        V = matmul_lora(X, v_weight, v_bias, v_quant, v_A, v_B, v_scale)
 
         ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
         ctx.scales = (q_scale, k_scale, v_scale)
         ctx.quants = (q_quant, k_quant, v_quant)
         ctx.weights = (q_weight, k_weight, v_weight)
+        ctx.biases = (q_bias, k_bias, v_bias)
         ctx.inplace = inplace
 
         return Q, K, V
 
     @staticmethod
-    @torch_amp_custom_fwd
+    @torch_amp_custom_bwd
     def backward(
         ctx: torch.autograd.function.FunctionCtx,
         q_grad: torch.Tensor,
@@ -525,16 +559,19 @@ class LoRA_QKV(torch.autograd.Function):
         torch.Tensor,
         None,
         None,
+        None,
         torch.Tensor | None,
         torch.Tensor | None,
         None,
         None,
         None,
+        None,
         torch.Tensor | None,
         torch.Tensor | None,
         None,
         None,
         None,
+        None,
         torch.Tensor | None,
         torch.Tensor | None,
         None,
@@ -622,31 +659,31 @@ class LoRA_QKV(torch.autograd.Function):
         # Transpose gradients if needed
         if d_A_q is not None:
             d_A_q = d_A_q.t()
-        if d_B_q is not None:
-            d_B_q = d_B_q.t()
+            d_B_q = d_B_q.t()  # type: ignore[union-attr]
         if d_A_k is not None:
             d_A_k = d_A_k.t()
-        if d_B_k is not None:
-            d_B_k = d_B_k.t()
+            d_B_k = d_B_k.t()  # type: ignore[union-attr]
         if d_A_v is not None:
             d_A_v = d_A_v.t()
-        if d_B_v is not None:
-            d_B_v = d_B_v.t()
+            d_B_v = d_B_v.t()  # type: ignore[union-attr]
 
         return (
             grad_X.view(batch, seq_len, -1),
             None,
             None,
+            None,
             d_A_q,
             d_B_q,
             None,
             None,
             None,
+            None,
             d_A_k,
             d_B_k,
             None,
             None,
             None,
+            None,
             d_A_v,
             d_B_v,
             None,
@@ -667,22 +704,25 @@ def apply_lora_qkv(
     Returns:
         Tuple of (Query, Key, Value) projection tensors
     """
-    QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
-    KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
-    VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
+    QW, Qb, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
+    KW, Kb, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
+    VW, Vb, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
     Q, K, V = LoRA_QKV.apply(
         X,
         QW,
+        Qb,
         QW_quant,
         QA,
         QB,
         QS,
         KW,
+        Kb,
         KW_quant,
         KA,
         KB,
         KS,
         VW,
+        Vb,
         VW_quant,
         VA,
         VB,
@@ -702,10 +742,11 @@ class LoRA_O(torch.autograd.Function):
         ctx: torch.autograd.function.FunctionCtx,
         X: torch.Tensor,
         W: torch.Tensor,
+        b: torch.Tensor,
         W_quant: QuantState | None,
-        A: torch.Tensor | None,
-        B: torch.Tensor | None,
-        S: float,
+        A: torch.Tensor,
+        B: torch.Tensor,
+        s: float,
     ) -> torch.Tensor:
         """
         Forward pass for output projection with LoRA.
@@ -714,19 +755,20 @@ class LoRA_O(torch.autograd.Function):
             ctx: Autograd context
             X: Input tensor
             W: Output projection weight
+            b: Output projection bias
             W_quant: Weight quantization state
             A: LoRA A matrix
             B: LoRA B matrix
-            S: LoRA scaling factor
+            s: LoRA scaling factor
 
         Returns:
-            Output projection tensor
+            Output projection result
         """
-        XW = matmul_lora(X, W, W_quant, A, B, S)
+        XW = matmul_lora(X, W, b, W_quant, A, B, s)
         ctx.custom_saved_tensors = (
             W,
             W_quant,
-            S,
+            s,
         )
         ctx.save_for_backward(A, B, X)
@@ -741,8 +783,9 @@
         torch.Tensor,
         None,
         None,
-        torch.Tensor | None,
-        torch.Tensor | None,
+        None,
+        torch.Tensor,
+        torch.Tensor,
         None,
     ]:
         """
@@ -755,7 +798,7 @@
         Returns:
             Tuple containing gradients for all forward inputs
         """
-        W, W_quant, S = ctx.custom_saved_tensors
+        W, W_quant, s = ctx.custom_saved_tensors
         A, B, X = ctx.saved_tensors
 
         batch, seq_len, hd = X.shape
@@ -765,17 +808,19 @@
 
         # Weight projection
         dY_X = X.t() @ dY
-        d_A = S * dY_X @ B
-        d_B = S * A @ dY_X
+        d_A = s * dY_X @ B
+        d_B = s * A @ dY_X
 
         # Get derivative for dX
        W = dequantize(W.t(), W_quant)
         dX = dY @ W.t()
         del W
 
-        dX += dY @ B.to(dtype) @ (S * A.to(dtype))
-        # W, W_quant, A, B, S
-        return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
+        A, B = A.to(dtype), B.to(dtype)
+        dX += s * dY @ B @ A
+
+        # W, b, W_quant, A, B, s
+        return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None
 
 
 def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
@@ -788,7 +833,7 @@ def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
     Returns:
         Transformed output tensor
     """
-    OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
-    output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS)
+    OW, Ob, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
+    output = LoRA_O.apply(X, OW, Ob, OW_quant, OA, OB, OS)
 
     return output
diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py
index 48bc10c0b..b420a965c 100644
--- a/src/axolotl/monkeypatch/lora_kernels.py
+++ b/src/axolotl/monkeypatch/lora_kernels.py
@@ -390,7 +390,6 @@ def apply_lora_kernel_patches(
             ]
             can_patch_qkv = all(
                 hasattr(module, "lora_A")
-                and getattr(module, "base_layer", module).bias is None
                 and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
                 for module in layer_modules
             )
@@ -400,7 +399,8 @@ def apply_lora_kernel_patches(
                 self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
             else:
                 LOG.warning_once(
-                    "Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
+                    "Cannot patch some attention QKV projections - requires LoRA "
+                    "adapters and no lora_magnitude_vector (DoRA)"
                 )
 
         if cfg.lora_o_kernel:  # Output patching
@@ -409,7 +409,6 @@
             ]
             can_patch_o = all(
                 hasattr(module, "lora_A")
-                and getattr(module, "base_layer", module).bias is None
                 and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
                 for module in layer_modules
             )
@@ -418,14 +417,14 @@
                 self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
             else:
                 LOG.warning_once(
-                    "Cannot patch some attention output projection - requires LoRA adapters with no bias"
+                    "Cannot patch some attention output projection - requires LoRA "
+                    "adapters and no lora_magnitude_vector (DoRA)"
                 )
 
         for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
             if cfg.lora_mlp_kernel:  # MLP patching
                 can_patch_mlp = all(
                     hasattr(proj, "lora_A")
-                    and getattr(proj, "base_layer", proj).bias is None
                     and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
                     for proj in (gate_proj, up_proj, down_proj)
                 )
@@ -435,7 +434,8 @@
                     layer.mlp.forward = types.MethodType(apply_fn, mlp)
                 else:
                     LOG.warning_once(
-                        "Cannot patch some MLP layers - requires LoRA adapters with no bias"
+                        "Cannot patch some MLP layers - requires LoRA adapters and no "
+                        "lora_magnitude_vector (DoRA)"
                     )
 
     LOG.setLevel(original_level)
diff --git a/tests/e2e/kernels/test_lora.py b/tests/e2e/kernels/test_lora.py
index 5ad186cbf..cd6131ff1 100644
--- a/tests/e2e/kernels/test_lora.py
+++ b/tests/e2e/kernels/test_lora.py
@@ -64,6 +64,7 @@ def sample_tensors():
             batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
         ),
         "W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
+        "b": torch.randn(out_dim, device="cuda", dtype=torch.float16),
         "scale": 0.5,
         "shapes": {
             "batch": batch_size,
@@ -103,23 +104,24 @@ def mock_proj():
 def test_get_lora_parameters(mock_proj):
     """Tests get_lora_parameters function"""
     # Test with LoRA enabled
-    W, _, A, B, s = get_lora_parameters(mock_proj)
+    W, b, _, A, B, s = get_lora_parameters(mock_proj)
 
     assert isinstance(W, torch.Tensor)
     assert W.shape == (128, 64)
+    assert b.shape == (128,)
     assert A.shape == (8, 64)
     assert B.shape == (128, 8)
     assert s == 0.5
 
     # Test with LoRA disabled
     mock_proj.disable_adapters = True
-    W, _, A, B, s = get_lora_parameters(mock_proj)
+    W, b, _, A, B, s = get_lora_parameters(mock_proj)
     assert A is None and B is None and s is None
 
     # Test with merged state
     mock_proj.disable_adapters = False
     mock_proj.merged = True
-    W, _, A, B, s = get_lora_parameters(mock_proj)
+    W, b, _, A, B, s = get_lora_parameters(mock_proj)
     assert A is None and B is None and s is None
 
 
@@ -127,6 +129,7 @@ def test_matmul_lora(sample_tensors):
     """Tests matmul_lora function"""
     X = sample_tensors["X"]
     W = sample_tensors["W"]
+    b = sample_tensors["b"]
     scale = sample_tensors["scale"]
     shapes = sample_tensors["shapes"]
 
@@ -138,19 +141,20 @@
     B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
 
     # Test base matmul
-    out1 = matmul_lora(X, W, None, None, None, None)
-    expected1 = torch.matmul(X, W.t())
+    out1 = matmul_lora(X, W, b, None, None, None, None)
+    matmul = torch.matmul(X, W.t())
+    expected1 = matmul + b
     assert torch.allclose(out1, expected1, rtol=1e-3)
 
     # Test with LoRA
-    out2 = matmul_lora(X, W, None, A, B, scale)
+    out2 = matmul_lora(X, W, b, None, A, B, scale)
     lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
-    expected2 = expected1 + lora_term
+    expected2 = matmul + lora_term + b
     assert torch.allclose(out2, expected2, rtol=1e-3)
 
     # Test 3D input reshaping
     X_3d = X.clone()
-    out3 = matmul_lora(X_3d, W, None, A, B, scale)
+    out3 = matmul_lora(X_3d, W, b, None, A, B, scale)
     assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
 
 
@@ -175,16 +179,19 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
     output = LoRA_MLP.apply(
         X,
         gate_proj.weight,
+        gate_proj.bias,
         None,  # gate_quant
         None,  # gate_A
         None,  # gate_B
         None,  # gate_scale
         up_proj.weight,
+        up_proj.bias,
         None,  # up_quant
         None,  # up_A
         None,  # up_B
         None,  # up_scale
         down_proj.weight,
+        down_proj.bias,
         None,  # down_quant
         None,  # down_A
         None,  # down_B
@@ -243,16 +250,19 @@ def test_lora_mlp_with_adapters(
     output = LoRA_MLP.apply(
         X,
         gate_proj.weight,
+        gate_proj.bias,
         None,
         gate_A,
         gate_B,
         scale,
         up_proj.weight,
+        up_proj.bias,
         None,
         up_A,
         up_B,
         scale,
         down_proj.weight,
+        down_proj.bias,
         None,
         down_A,
         down_B,
@@ -323,6 +333,7 @@ def test_lora_qkv(sample_tensors):
     X.requires_grad = True
 
     # Test without LoRA adapters
+    # pylint: disable=duplicate-code
     Q1, K1, V1 = LoRA_QKV.apply(
         X,
         q_weight,
@@ -330,16 +341,19 @@
         None,
         None,
         None,
+        None,
         k_weight,
         None,
         None,
         None,
         None,
+        None,
         v_weight,
         None,
         None,
         None,
         None,
+        None,
         True,
     )
 
@@ -356,16 +370,19 @@
         X,
         q_weight,
         None,
+        None,
         q_A,
         q_B,
         scale,
         k_weight,
         None,
+        None,
         k_A,
         k_B,
         scale,
         v_weight,
         None,
+        None,
         v_A,
         v_B,
         scale,
@@ -399,6 +416,7 @@ def test_lora_o(sample_tensors):
     """Tests LoRA output projection"""
     X = sample_tensors["X"]
     W = sample_tensors["W"]
+    b = sample_tensors["b"]
     scale = sample_tensors["scale"]
     shapes = sample_tensors["shapes"]
 
@@ -411,7 +429,7 @@
 
     # Test forward pass
     X.requires_grad = True
-    output = LoRA_O.apply(X, W, None, A, B, scale)
+    output = LoRA_O.apply(X, W, b, None, A, B, scale)
 
     assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
 
@@ -425,6 +443,7 @@ def test_with_quantization(sample_tensors, mock_quantstate):
     """Tests LoRA with quantized weights"""
     X = sample_tensors["X"]  # [batch, seq, hidden]
     W = sample_tensors["W"]  # [out, hidden]
+    b = sample_tensors["b"]  # [out]
     scale = 0.5
     shapes = sample_tensors["shapes"]
 
@@ -436,13 +455,13 @@
     B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
 
     # Test matmul with quantization
-    out = matmul_lora(X, W, mock_quantstate, A, B, scale)
+    out = matmul_lora(X, W, b, mock_quantstate, A, B, scale)
     assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
     assert not torch.isnan(out).any()
 
     # Test with different batch sizes
     X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
-    out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale)
+    out2 = matmul_lora(X2, W, b, mock_quantstate, A, B, scale)
     assert out2.shape == (4, 6, W.shape[0])
     assert not torch.isnan(out2).any()
 
@@ -459,11 +478,12 @@ def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
     """Tests various input shapes and dimensions"""
     X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
     W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
+    b = torch.randn(out, device="cuda", dtype=torch.float16)
     A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
     B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
     scale = 0.5
 
-    result = matmul_lora(X, W, None, A, B, scale)
+    result = matmul_lora(X, W, b, None, A, B, scale)
     assert result.shape == (batch, seq, out)
 
 
@@ -471,6 +491,7 @@ def test_gradient_flow(sample_tensors):
     """Tests gradient flow through LoRA layers"""
     X = sample_tensors["X"].clone()
     W = sample_tensors["W"].clone()
+    b = sample_tensors["b"].clone()
     scale = sample_tensors["scale"]
     shapes = sample_tensors["shapes"]
 
@@ -486,7 +507,7 @@
     B.requires_grad = True
 
     # Forward pass
-    out = matmul_lora(X, W, None, A, B, scale)
+    out = matmul_lora(X, W, b, None, A, B, scale)
     loss = out.sum()
 
     # Backward pass
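
For illustration only (not part of the patch): a minimal sketch of the calling convention the diff introduces, assuming a PEFT-style projection module with base_layer, lora_A/lora_B, and scaling attributes as exercised in the tests above. The helper name lora_project is hypothetical.

# Illustrative sketch -- not part of the patch above.
# get_lora_parameters() now returns the base-layer bias in second position,
# and matmul_lora() adds that bias after the base matmul and the scaled
# LoRA term, which is why the "bias is None" checks were dropped in
# src/axolotl/monkeypatch/lora_kernels.py.
import torch

from axolotl.kernels.lora import get_lora_parameters, matmul_lora


def lora_project(X: torch.Tensor, proj: torch.nn.Module) -> torch.Tensor:
    # Unpack the new six-element tuple; b may be None for bias-free layers.
    W, b, W_quant, A, B, s = get_lora_parameters(proj)
    return matmul_lora(X, W, b, W_quant, A, B, s)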