Lora kernels bias support (#3025)

* lora kernels bias support

* revert rename

* nit

* lint, tests

* satisfying the rabbit
This commit is contained in:
Dan Saunders
2025-08-06 20:20:08 -04:00
committed by GitHub
parent e442ff22aa
commit d09290f2f4
3 changed files with 156 additions and 90 deletions

View File

@@ -64,6 +64,7 @@ def sample_tensors():
batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
),
"W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
"b": torch.randn(out_dim, device="cuda", dtype=torch.float16),
"scale": 0.5,
"shapes": {
"batch": batch_size,
@@ -103,23 +104,24 @@ def mock_proj():
def test_get_lora_parameters(mock_proj):
"""Tests get_lora_parameters function"""
# Test with LoRA enabled
W, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s = get_lora_parameters(mock_proj)
assert isinstance(W, torch.Tensor)
assert W.shape == (128, 64)
assert b.shape == (128,)
assert A.shape == (8, 64)
assert B.shape == (128, 8)
assert s == 0.5
# Test with LoRA disabled
mock_proj.disable_adapters = True
W, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
# Test with merged state
mock_proj.disable_adapters = False
mock_proj.merged = True
W, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
@@ -127,6 +129,7 @@ def test_matmul_lora(sample_tensors):
"""Tests matmul_lora function"""
X = sample_tensors["X"]
W = sample_tensors["W"]
b = sample_tensors["b"]
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
@@ -138,19 +141,20 @@ def test_matmul_lora(sample_tensors):
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test base matmul
out1 = matmul_lora(X, W, None, None, None, None)
expected1 = torch.matmul(X, W.t())
out1 = matmul_lora(X, W, b, None, None, None, None)
matmul = torch.matmul(X, W.t())
expected1 = matmul + b
assert torch.allclose(out1, expected1, rtol=1e-3)
# Test with LoRA
out2 = matmul_lora(X, W, None, A, B, scale)
out2 = matmul_lora(X, W, b, None, A, B, scale)
lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
expected2 = expected1 + lora_term
expected2 = matmul + lora_term + b
assert torch.allclose(out2, expected2, rtol=1e-3)
# Test 3D input reshaping
X_3d = X.clone()
out3 = matmul_lora(X_3d, W, None, A, B, scale)
out3 = matmul_lora(X_3d, W, b, None, A, B, scale)
assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
@@ -175,16 +179,19 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
output = LoRA_MLP.apply(
X,
gate_proj.weight,
gate_proj.bias,
None, # gate_quant
None, # gate_A
None, # gate_B
None, # gate_scale
up_proj.weight,
up_proj.bias,
None, # up_quant
None, # up_A
None, # up_B
None, # up_scale
down_proj.weight,
down_proj.bias,
None, # down_quant
None, # down_A
None, # down_B
@@ -243,16 +250,19 @@ def test_lora_mlp_with_adapters(
output = LoRA_MLP.apply(
X,
gate_proj.weight,
gate_proj.bias,
None,
gate_A,
gate_B,
scale,
up_proj.weight,
up_proj.bias,
None,
up_A,
up_B,
scale,
down_proj.weight,
down_proj.bias,
None,
down_A,
down_B,
@@ -323,6 +333,7 @@ def test_lora_qkv(sample_tensors):
X.requires_grad = True
# Test without LoRA adapters
# pylint: disable=duplicate-code
Q1, K1, V1 = LoRA_QKV.apply(
X,
q_weight,
@@ -330,16 +341,19 @@ def test_lora_qkv(sample_tensors):
None,
None,
None,
None,
k_weight,
None,
None,
None,
None,
None,
v_weight,
None,
None,
None,
None,
None,
True,
)
@@ -356,16 +370,19 @@ def test_lora_qkv(sample_tensors):
X,
q_weight,
None,
None,
q_A,
q_B,
scale,
k_weight,
None,
None,
k_A,
k_B,
scale,
v_weight,
None,
None,
v_A,
v_B,
scale,
@@ -399,6 +416,7 @@ def test_lora_o(sample_tensors):
"""Tests LoRA output projection"""
X = sample_tensors["X"]
W = sample_tensors["W"]
b = sample_tensors["b"]
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
@@ -411,7 +429,7 @@ def test_lora_o(sample_tensors):
# Test forward pass
X.requires_grad = True
output = LoRA_O.apply(X, W, None, A, B, scale)
output = LoRA_O.apply(X, W, b, None, A, B, scale)
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
@@ -425,6 +443,7 @@ def test_with_quantization(sample_tensors, mock_quantstate):
"""Tests LoRA with quantized weights"""
X = sample_tensors["X"] # [batch, seq, hidden]
W = sample_tensors["W"] # [out, hidden]
b = sample_tensors["b"] # [out]
scale = 0.5
shapes = sample_tensors["shapes"]
@@ -436,13 +455,13 @@ def test_with_quantization(sample_tensors, mock_quantstate):
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test matmul with quantization
out = matmul_lora(X, W, mock_quantstate, A, B, scale)
out = matmul_lora(X, W, b, mock_quantstate, A, B, scale)
assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
assert not torch.isnan(out).any()
# Test with different batch sizes
X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale)
out2 = matmul_lora(X2, W, b, mock_quantstate, A, B, scale)
assert out2.shape == (4, 6, W.shape[0])
assert not torch.isnan(out2).any()
@@ -459,11 +478,12 @@ def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
"""Tests various input shapes and dimensions"""
X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
b = torch.randn(out, device="cuda", dtype=torch.float16)
A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
scale = 0.5
result = matmul_lora(X, W, None, A, B, scale)
result = matmul_lora(X, W, b, None, A, B, scale)
assert result.shape == (batch, seq, out)
@@ -471,6 +491,7 @@ def test_gradient_flow(sample_tensors):
"""Tests gradient flow through LoRA layers"""
X = sample_tensors["X"].clone()
W = sample_tensors["W"].clone()
b = sample_tensors["b"].clone()
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
@@ -486,7 +507,7 @@ def test_gradient_flow(sample_tensors):
B.requires_grad = True
# Forward pass
out = matmul_lora(X, W, None, A, B, scale)
out = matmul_lora(X, W, b, None, A, B, scale)
loss = out.sum()
# Backward pass