Activation function Triton kernels, LoRA custom autograd functions (#2324)

* LoRA + activation fn Triton kernels: initial commit

* implementing optims

* finalizing MLP LoRA kernels and progress on QKV / W kernels

* updates

* O projection optim

* adding monkey patching logic

* doc strings, typing, pre-commit fixes

* updates

* adding lora 8b kernels example

* working on fsdp support

* tests and fixes

* small fixes, getting tests to pass, adding doc strings

* integration tests for LoRA patching

* config.qmd

* remove unneeded pytest fixture

* fix

* review comments first pass

* improving tests, attention class agnostic patching

* adding support for more archs

* wip SiLU / GELU impls

* improved testing, small updates, etc.

* slightly updating docs

* rebase

* fixing test_attention_patching_integration

* additional review comments, fixing test in CI (hopefully)

* isolating problematic patching test

* relaxing allclose threshold to reduce flakiness

* fixing accidental change

* adding model arch agnostic attention class fetching

* removing unused activations
This commit is contained in:
Dan Saunders
2025-02-17 14:23:15 -05:00
committed by GitHub
parent 97a2fa2781
commit 3d8425fa91
22 changed files with 3102 additions and 22 deletions

View File

@@ -0,0 +1,76 @@
"""Tests for GEGLU activation function Triton kernels."""
# pylint: disable=duplicate-code
import torch
import torch.nn.functional as F
from axolotl.kernels.geglu import geglu_backward, geglu_forward
def test_geglu_forward_shape():
"""Test that GEGLU forward pass preserves expected shapes."""
batch, seq_len, hidden_dim = 2, 3, 64
gate = torch.randn(batch, seq_len, hidden_dim, device="cuda")
up = torch.randn(batch, seq_len, hidden_dim, device="cuda")
out = geglu_forward(gate, up)
assert out.shape == (batch, seq_len, hidden_dim)
assert out.dtype == gate.dtype
assert out.device == gate.device
def test_geglu_forward_values():
"""Test GEGLU forward pass matches PyTorch reference implementation."""
gate = torch.randn(2, 3, 64, device="cuda")
up = torch.randn(2, 3, 64, device="cuda")
# Custom implementation
triton_out = geglu_forward(gate.clone(), up.clone())
# PyTorch reference
torch_out = F.gelu(gate) * up
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
def test_geglu_backward():
"""Test GEGLU backward pass matches PyTorch autograd."""
gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
grad_output = torch.randn(2, 3, 64, device="cuda")
# PyTorch reference - compute intermediates
gelu_gate = F.gelu(gate)
torch_out = gelu_gate * up
torch_out.backward(grad_output)
# Custom backward pass
gate_clone = gate.clone().detach()
up_clone = up.clone().detach()
grad_output_clone = grad_output.clone()
h, grad_gate, grad_up = geglu_backward(grad_output_clone, gate_clone, up_clone)
# Compare outputs and gradients
assert torch.allclose(h, torch_out, rtol=1e-3)
assert torch.allclose(grad_gate, gate.grad, rtol=1e-3)
assert torch.allclose(grad_up, up.grad, rtol=1e-3)
def test_geglu_inplace_preservation():
"""Test that GEGLU backward doesn't modify original tensors unexpectedly."""
gate = torch.randn(2, 3, 64, device="cuda")
up = torch.randn(2, 3, 64, device="cuda")
grad_output = torch.randn(2, 3, 64, device="cuda")
gate_copy = gate.clone()
up_copy = up.clone()
grad_copy = grad_output.clone()
geglu_backward(grad_output, gate, up)
assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
assert not torch.equal(up, up_copy), "Up should be modified in-place"
assert not torch.equal(
grad_output, grad_copy
), "Grad output should be modified in-place"

View File

@@ -0,0 +1,531 @@
"""Tests for LoRA custom autograd."""
# pylint: disable=invalid-name,redefined-outer-name
import pytest
import torch
from bitsandbytes.functional import QuantState
from torch import nn
from axolotl.kernels.geglu import geglu_backward, geglu_forward
from axolotl.kernels.lora import (
LoRA_MLP,
LoRA_O,
LoRA_QKV,
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
get_lora_parameters,
matmul_lora,
)
from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward
@pytest.fixture
def mock_quantstate():
"""Creates a mock QuantState for testing"""
shape = (64, 64)
n_blocks = shape[0] # Assuming blockwise quantization along first dimension
# Create nested state first
nested_state = QuantState(
absmax=torch.ones(n_blocks, device="cuda"), # One value per block
shape=shape,
code=torch.randint(0, 15, shape, device="cuda"), # NF4 range is 0-15
dtype=torch.float16,
blocksize=64,
quant_type="nf4",
offset=None,
state2=None,
)
# Create main state with nested state
return QuantState(
absmax=torch.ones(n_blocks, device="cuda"),
shape=shape,
code=torch.randint(0, 15, shape, device="cuda"),
dtype=torch.float16,
blocksize=64,
quant_type="nf4",
offset=torch.zeros(n_blocks, dtype=torch.int32, device="cuda"),
state2=nested_state,
)
@pytest.fixture
def sample_tensors():
"""Creates sample tensors for testing"""
torch.manual_seed(42)
batch_size, seq_len, hidden_dim = 2, 3, 64
rank = 8
out_dim = hidden_dim
return {
"X": torch.randn(
batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
),
"W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
"scale": 0.5,
"shapes": {
"batch": batch_size,
"seq": seq_len,
"hidden": hidden_dim,
"out": out_dim,
"rank": rank,
},
}
@pytest.fixture
def mock_proj():
"""Creates a mock projection module for testing."""
class MockProj(nn.Module):
"""Mock projection class."""
def __init__(self, in_features=64, out_features=128, rank=8):
super().__init__()
self.base_layer = nn.Linear(in_features, out_features)
self.base_layer.to("cuda")
self.lora_A = nn.ModuleDict(
{"default": nn.Linear(in_features, rank, bias=False).to("cuda")}
)
self.lora_B = nn.ModuleDict(
{"default": nn.Linear(rank, out_features, bias=False).to("cuda")}
)
self.scaling = {"default": 0.5}
self.active_adapter = "default"
self.disable_adapters = False
self.merged = False
return MockProj()
def test_get_lora_parameters(mock_proj):
"""Tests get_lora_parameters function"""
# Test with LoRA enabled
W, _, A, B, s = get_lora_parameters(mock_proj)
assert isinstance(W, torch.Tensor)
assert W.shape == (128, 64)
assert A.shape == (8, 64)
assert B.shape == (128, 8)
assert s == 0.5
# Test with LoRA disabled
mock_proj.disable_adapters = True
W, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
# Test with merged state
mock_proj.disable_adapters = False
mock_proj.merged = True
W, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
def test_matmul_lora(sample_tensors):
"""Tests matmul_lora function"""
X = sample_tensors["X"]
W = sample_tensors["W"]
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
rank = shapes["rank"]
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test base matmul
out1 = matmul_lora(X, W, None, None, None, None)
expected1 = torch.matmul(X, W.t())
assert torch.allclose(out1, expected1, rtol=1e-3)
# Test with LoRA
out2 = matmul_lora(X, W, None, A, B, scale)
lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
expected2 = expected1 + lora_term
assert torch.allclose(out2, expected2, rtol=1e-3)
# Test 3D input reshaping
X_3d = X.clone()
out3 = matmul_lora(X_3d, W, None, A, B, scale)
assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
@pytest.mark.parametrize(
"activation_forward,activation_backward",
[(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
)
def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward):
"""Tests LoRA_MLP directly with different activation functions"""
X = sample_tensors["X"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
# Create linear layers
gate_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
up_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
down_proj = nn.Linear(out_dim, hidden_dim).to(device="cuda", dtype=torch.float16)
# Test SwiGLU path
X.requires_grad = True
output = LoRA_MLP.apply(
X,
gate_proj.weight,
None, # gate_quant
None, # gate_A
None, # gate_B
None, # gate_scale
up_proj.weight,
None, # up_quant
None, # up_A
None, # up_B
None, # up_scale
down_proj.weight,
None, # down_quant
None, # down_A
None, # down_B
None, # down_scale
activation_forward,
activation_backward,
True, # inplace
)
assert output.shape == X.shape
assert not torch.isnan(output).any()
# Test backward pass
loss = output.sum()
loss.backward()
assert X.grad is not None
assert not torch.isnan(X.grad).any()
@pytest.mark.parametrize(
"activation_forward,activation_backward",
[(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
)
def test_lora_mlp_with_adapters(
sample_tensors, activation_forward, activation_backward
):
"""Tests LoRA_MLP with LoRA adapters"""
X = sample_tensors["X"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
rank = shapes["rank"]
# Create LoRA components
gate_A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
gate_B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
up_A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
up_B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
down_A = torch.randn(rank, out_dim, device="cuda", dtype=torch.float16)
down_B = torch.randn(hidden_dim, rank, device="cuda", dtype=torch.float16)
scale = 0.5
gate_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
up_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
down_proj = nn.Linear(out_dim, hidden_dim).to(device="cuda", dtype=torch.float16)
X.requires_grad = True
gate_A.requires_grad = True
gate_B.requires_grad = True
up_A.requires_grad = True
up_B.requires_grad = True
down_A.requires_grad = True
down_B.requires_grad = True
# Forward pass with adapters
output = LoRA_MLP.apply(
X,
gate_proj.weight,
None,
gate_A,
gate_B,
scale,
up_proj.weight,
None,
up_A,
up_B,
scale,
down_proj.weight,
None,
down_A,
down_B,
scale,
activation_forward,
activation_backward,
True,
)
assert output.shape == X.shape
assert not torch.isnan(output).any()
# Test backward pass
loss = output.sum()
loss.backward()
# Check all gradients
assert X.grad is not None
assert gate_A.grad is not None
assert gate_B.grad is not None
assert up_A.grad is not None
assert up_B.grad is not None
assert down_A.grad is not None
assert down_B.grad is not None
assert not torch.isnan(X.grad).any()
assert not torch.isnan(gate_A.grad).any()
assert not torch.isnan(gate_B.grad).any()
assert not torch.isnan(up_A.grad).any()
assert not torch.isnan(up_B.grad).any()
assert not torch.isnan(down_A.grad).any()
assert not torch.isnan(down_B.grad).any()
def test_lora_qkv(sample_tensors):
"""Tests LoRA QKV implementation with and without adapters"""
X = sample_tensors["X"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
rank = shapes["rank"]
# Create base weights
q_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)
k_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)
v_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)
# Create LoRA matrices
q_A = torch.randn(
rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True
)
q_B = torch.randn(
hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True
)
k_A = torch.randn(
rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True
)
k_B = torch.randn(
hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True
)
v_A = torch.randn(
rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True
)
v_B = torch.randn(
hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True
)
scale = 0.5
X.requires_grad = True
# Test without LoRA adapters
Q1, K1, V1 = LoRA_QKV.apply(
X,
q_weight,
None,
None,
None,
None,
k_weight,
None,
None,
None,
None,
v_weight,
None,
None,
None,
None,
True,
)
assert Q1.shape == K1.shape == V1.shape == X.shape
loss1 = (Q1 + K1 + V1).sum()
loss1.backward()
assert X.grad is not None
# Clear gradients
X.grad = None
# Test with LoRA adapters
Q2, K2, V2 = LoRA_QKV.apply(
X,
q_weight,
None,
q_A,
q_B,
scale,
k_weight,
None,
k_A,
k_B,
scale,
v_weight,
None,
v_A,
v_B,
scale,
True,
)
assert Q2.shape == K2.shape == V2.shape == X.shape
loss2 = (Q2 + K2 + V2).sum()
loss2.backward()
# Check gradients
assert X.grad is not None
assert q_A.grad is not None
assert q_B.grad is not None
assert k_A.grad is not None
assert k_B.grad is not None
assert v_A.grad is not None
assert v_B.grad is not None
# Check for NaN values
assert not torch.isnan(X.grad).any()
assert not torch.isnan(q_A.grad).any()
assert not torch.isnan(q_B.grad).any()
assert not torch.isnan(k_A.grad).any()
assert not torch.isnan(k_B.grad).any()
assert not torch.isnan(v_A.grad).any()
assert not torch.isnan(v_B.grad).any()
def test_lora_o(sample_tensors):
"""Tests LoRA output projection"""
X = sample_tensors["X"]
W = sample_tensors["W"]
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
rank = shapes["rank"]
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test forward pass
X.requires_grad = True
output = LoRA_O.apply(X, W, None, A, B, scale)
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
# Test backward pass
loss = output.sum()
loss.backward()
assert X.grad is not None
def test_with_quantization(sample_tensors, mock_quantstate):
"""Tests LoRA with quantized weights"""
X = sample_tensors["X"] # [batch, seq, hidden]
W = sample_tensors["W"] # [out, hidden]
scale = 0.5
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
rank = shapes["rank"]
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test matmul with quantization
out = matmul_lora(X, W, mock_quantstate, A, B, scale)
assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
assert not torch.isnan(out).any()
# Test with different batch sizes
X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale)
assert out2.shape == (4, 6, W.shape[0])
assert not torch.isnan(out2).any()
@pytest.mark.parametrize(
"batch,seq,hidden,rank,out",
[
(1, 1, 32, 4, 64),
(2, 3, 64, 8, 128),
(4, 5, 128, 16, 256),
],
)
def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
"""Tests various input shapes and dimensions"""
X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
scale = 0.5
result = matmul_lora(X, W, None, A, B, scale)
assert result.shape == (batch, seq, out)
def test_gradient_flow(sample_tensors):
"""Tests gradient flow through LoRA layers"""
X = sample_tensors["X"].clone()
W = sample_tensors["W"].clone()
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
rank = shapes["rank"]
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
X.requires_grad = True
A.requires_grad = True
B.requires_grad = True
# Forward pass
out = matmul_lora(X, W, None, A, B, scale)
loss = out.sum()
# Backward pass
loss.backward()
assert X.grad is not None
assert A.grad is not None
assert B.grad is not None
assert not torch.isnan(X.grad).any()
assert not torch.isnan(A.grad).any()
assert not torch.isnan(B.grad).any()
@pytest.mark.parametrize(
"apply_function",
[apply_lora_mlp_swiglu, apply_lora_mlp_geglu],
)
def test_inplace_operations(sample_tensors, apply_function):
"""Tests inplace operation behavior"""
X = sample_tensors["X"]
shapes = sample_tensors["shapes"]
# Create MLP with both inplace=True and inplace=False
mlp = type(
"MLPModule",
(),
{
"gate_proj": nn.Linear(shapes["hidden"], shapes["out"]).to(
device="cuda", dtype=torch.float16
),
"up_proj": nn.Linear(shapes["hidden"], shapes["out"]).to(
device="cuda", dtype=torch.float16
),
"down_proj": nn.Linear(shapes["out"], shapes["hidden"]).to(
device="cuda", dtype=torch.float16
),
},
)
out1 = apply_function(mlp, X.clone(), inplace=True)
out2 = apply_function(mlp, X.clone(), inplace=False)
assert torch.allclose(out1, out2, rtol=1e-3)

View File

@@ -0,0 +1,103 @@
"""Tests for quantization utility functions."""
# pylint: disable=invalid-name
import torch
from bitsandbytes.functional import QuantState
from axolotl.kernels.quantize import dequantize
def test_dequantize_null_state():
"""Test that dequantize returns input unchanged when quant_state is None"""
W = torch.randn(32, 32)
assert torch.equal(dequantize(W, None), W)
def test_dequantize_shape_preservation():
"""Test that dequantization preserves expected shapes"""
shape = (32, 32)
W = torch.randn(shape, device="cuda")
quant_state = QuantState(
absmax=torch.ones(shape[0], device="cuda"),
shape=shape,
code=torch.randint(0, 15, shape, device="cuda"),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=torch.zeros(shape[0], dtype=torch.int32, device="cuda"),
state2=QuantState(
absmax=torch.ones(shape[0], device="cuda"),
shape=shape,
code=torch.randint(0, 15, shape, device="cuda"),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=None,
state2=None,
),
)
result = dequantize(W, quant_state)
assert result.shape == shape
assert result.dtype == torch.float16
assert result.device == W.device
def test_dequantize_transposed():
"""Test that transposed input produces transposed output"""
shape = (32, 32)
W = torch.randn(1, shape[1], device="cuda") # Transposed input
quant_state = QuantState(
absmax=torch.ones(1),
shape=shape,
code=torch.randint(0, 15, shape),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=torch.zeros(1, dtype=torch.int32),
state2=QuantState(
absmax=torch.ones(1),
shape=shape,
code=torch.randint(0, 15, shape),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=None,
state2=None,
),
)
result = dequantize(W, quant_state)
assert result.shape[0] == shape[0]
def test_dequantize_output_tensor():
"""Test dequantization with provided output tensor"""
shape = (32, 32)
W = torch.randn(shape, device="cuda")
out = torch.empty(shape, dtype=torch.float16, device="cuda")
quant_state = QuantState(
absmax=torch.ones(shape[0]),
shape=shape,
code=torch.randint(0, 15, shape),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=torch.zeros(shape[0], dtype=torch.int32),
state2=QuantState(
absmax=torch.ones(shape[0]),
shape=shape,
code=torch.randint(0, 15, shape),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=None,
state2=None,
),
)
result = dequantize(W, quant_state, out=out)
assert result is out

View File

@@ -0,0 +1,78 @@
"""Tests for SwiGLU activation function Triton kernels."""
# pylint: disable=duplicate-code
import torch
import torch.nn.functional as F
from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward
def test_swiglu_forward_shape():
"""Test that SwiGLU forward pass preserves expected shapes"""
batch, seq_len, hidden_dim = 2, 3, 64
gate = torch.randn(batch, seq_len, hidden_dim, device="cuda")
up = torch.randn(batch, seq_len, hidden_dim, device="cuda")
out = swiglu_forward(gate, up)
assert out.shape == (batch, seq_len, hidden_dim)
assert out.dtype == gate.dtype
assert out.device == gate.device
def test_swiglu_forward_values():
"""Test SwiGLU forward pass matches PyTorch reference implementation"""
gate = torch.randn(2, 3, 64, device="cuda")
up = torch.randn(2, 3, 64, device="cuda")
# Custom implementation
triton_out = swiglu_forward(gate.clone(), up.clone())
# PyTorch reference
torch_out = F.silu(gate) * up
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
def test_swiglu_backward():
"""Test SwiGLU backward pass matches PyTorch autograd"""
gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
grad_output = torch.randn(2, 3, 64, device="cuda")
# PyTorch reference - compute intermediates
silu_gate = F.silu(gate)
torch_out = silu_gate * up
torch_out.backward(grad_output)
# Custom backward pass
gate_clone = gate.clone().detach()
up_clone = up.clone().detach()
grad_output_clone = grad_output.clone()
h, our_grad_gate, our_grad_up = swiglu_backward(
grad_output_clone, gate_clone, up_clone
)
# Compare outputs and gradients
assert torch.allclose(h, torch_out, rtol=1e-3)
assert torch.allclose(our_grad_gate, gate.grad, rtol=1e-3)
assert torch.allclose(our_grad_up, up.grad, rtol=1e-3)
def test_swiglu_inplace_preservation():
"""Test that SwiGLU backward doesn't modify original tensors unexpectedly"""
gate = torch.randn(2, 3, 64, device="cuda")
up = torch.randn(2, 3, 64, device="cuda")
grad_output = torch.randn(2, 3, 64, device="cuda")
gate_copy = gate.clone()
up_copy = up.clone()
grad_copy = grad_output.clone()
swiglu_backward(grad_output, gate, up)
assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
assert not torch.equal(up, up_copy), "Up should be modified in-place"
assert not torch.equal(
grad_output, grad_copy
), "Grad output should be modified in-place"

View File

@@ -0,0 +1,414 @@
"""Integration tests for LoRA activation and attention kernels."""
# pylint: disable=redefined-outer-name
import pytest
import torch
from accelerate.state import PartialState
from peft import PeftModelForCausalLM, get_peft_config
from transformers import AutoModelForCausalLM, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
apply_lora_qkv,
)
from axolotl.monkeypatch.lora_kernels import (
apply_lora_kernel_patches,
patch_self_attn_lora,
)
from axolotl.utils.dict import DictDefault
MODEL_CONFIGS = [
{
"name": "openaccess-ai-collective/tiny-mistral",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
{
"name": "Qwen/Qwen2-7B",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
{
"name": "HuggingFaceTB/SmolLM2-135M",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float32,
},
{
"name": "mhenrichsen/gemma-2b",
"expected_activation": apply_lora_mlp_geglu,
"dtype": torch.float16,
},
]
@pytest.fixture(autouse=True)
def init_accelerate():
"""Initialize Accelerate state before tests."""
_ = PartialState()
@pytest.fixture
def small_llama_model():
"""Create a small LLaMA model for testing."""
config = {
"vocab_size": 100,
"hidden_size": 128,
"intermediate_size": 256,
"num_hidden_layers": 2,
"num_attention_heads": 4,
}
return LlamaForCausalLM(LlamaConfig(**config))
def test_attention_patching_integration():
"""Test attention patching in integration context."""
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
# Store the original implementation
original_forward = getattr(LlamaAttention, "forward")
# Apply patch
patch_self_attn_lora(cfg)
# Get the new forward method
patched_forward = LlamaAttention.forward
# Check the forward method was replaced
assert original_forward is not patched_forward
assert patched_forward.__name__ == "axolotl_attn_forward"
# Check original implementation was stored
assert hasattr(LlamaAttention, "_original_forward")
# Clean up
setattr(LlamaAttention, "forward", original_forward)
delattr(LlamaAttention, "_original_forward")
def test_swiglu_mlp_integration(small_llama_model):
"""Test SwiGLU activation in LoRA MLP context."""
peft_config = get_peft_config(
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "none",
}
)
model = PeftModelForCausalLM(small_llama_model, peft_config).to("cuda")
cfg = DictDefault({"lora_mlp_kernel": True})
# Apply patches
patched_model = apply_lora_kernel_patches(model, cfg)
# Verify patches
layer = patched_model.model.model.layers[0]
assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
# Test forward pass
batch_size, seq_len = 2, 10
hidden_states = torch.randn(
batch_size, seq_len, model.config.hidden_size, device=model.device
)
position_ids = (
torch.arange(seq_len, device=model.device).unsqueeze(0).expand(batch_size, -1)
)
cos, sin = model.model.model.rotary_emb(hidden_states, position_ids)
inputs = {
"hidden_states": hidden_states,
"attention_mask": None,
"position_embeddings": (cos, sin),
"output_attentions": False,
"use_cache": False,
"past_key_value": None,
}
# Compare outputs
with torch.no_grad():
original_output = model.model.model.layers[0](**inputs)[0]
patched_output = layer(**inputs)[0]
assert torch.allclose(original_output, patched_output, rtol=1e-4)
def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda"
)
peft_config = get_peft_config(
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "none",
}
)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
patched_model = apply_lora_kernel_patches(model, cfg)
# Verify patches
layer = patched_model.model.model.layers[0]
assert layer.mlp.forward.__func__ is apply_lora_mlp_geglu
# Test end-to-end
inputs = torch.randint(0, 100, (1, 20), device=model.device, dtype=torch.long)
with torch.no_grad():
original_output = model(inputs).logits
patched_output = patched_model(inputs).logits
assert torch.allclose(original_output, patched_output, rtol=1e-4)
@pytest.mark.parametrize(
"model_name,expected_activation",
[
("HuggingFaceTB/SmolLM2-135M", apply_lora_mlp_swiglu),
("mhenrichsen/gemma-2b", apply_lora_mlp_geglu),
],
)
def test_model_specific_activation(model_name, expected_activation):
"""Test that each model type gets the correct activation function."""
model = AutoModelForCausalLM.from_pretrained(model_name)
peft_config = get_peft_config(
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "none",
}
)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0]
assert layer.mlp.forward.__func__ is expected_activation
def test_kernel_patch_conditions():
"""Test various conditions that should prevent kernel patching."""
test_configs = [
# Dropout prevents patching
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0.1,
"bias": "none",
},
# Bias prevents patching
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "lora_only",
},
]
for config in test_configs:
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
# Should not patch
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
# Verify no patches applied
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
assert layer.forward.__func__ is not apply_lora_mlp_geglu
def test_kernel_config_options():
"""Test that kernel configuration options are respected."""
# Test different configurations
test_configs = [
(
{"lora_mlp_kernel": True, "lora_qkv_kernel": False, "lora_o_kernel": False},
lambda layer: (
layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
and layer.self_attn.apply_qkv.__func__ is not apply_lora_qkv
and layer.self_attn.apply_o.__func__ is not apply_lora_o
),
),
(
{"lora_mlp_kernel": False, "lora_qkv_kernel": True, "lora_o_kernel": False},
lambda layer: (
layer.mlp.forward.__func__ is not apply_lora_mlp_swiglu
and layer.self_attn.apply_qkv.__func__ is apply_lora_qkv
and layer.self_attn.apply_o.__func__ is not apply_lora_o
),
),
(
{"lora_mlp_kernel": False, "lora_qkv_kernel": False, "lora_o_kernel": True},
lambda layer: (
layer.mlp.forward.__func__ is not apply_lora_mlp_swiglu
and layer.self_attn.apply_qkv.__func__ is not apply_lora_qkv
and layer.self_attn.apply_o.__func__ is apply_lora_o
),
),
]
for config_dict, check_fn in test_configs:
# Create fresh model for each test
config = {
"vocab_size": 100,
"hidden_size": 128,
"intermediate_size": 256,
"num_hidden_layers": 2,
"num_attention_heads": 4,
}
small_llama_model = LlamaForCausalLM(LlamaConfig(**config))
peft_config = get_peft_config(
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": [
"gate_proj",
"up_proj",
"down_proj",
"q_proj",
"k_proj",
"v_proj",
"o_proj",
],
"lora_dropout": 0,
"bias": "none",
}
)
model = PeftModelForCausalLM(small_llama_model, peft_config).to("cuda")
cfg = DictDefault(config_dict)
patched_model = apply_lora_kernel_patches(model, cfg)
# Verify only requested optimizations were applied
for layer in patched_model.model.model.layers:
assert check_fn(layer), f"Failed for config: {config_dict}"
# Clean up
del model
del small_llama_model
del patched_model
def get_lora_config():
"""Get standard LoRA configuration for testing."""
return {
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "none",
}
def get_test_inputs(model, seq_length=20):
"""Generate test inputs for model evaluation."""
return torch.randint(
0,
model.config.vocab_size,
(1, seq_length),
device=model.device,
dtype=torch.long,
)
@pytest.mark.parametrize("model_config", MODEL_CONFIGS)
def test_model_architecture(model_config):
"""Test LoRA kernel patches across different model architectures."""
# Load model with appropriate dtype
model = AutoModelForCausalLM.from_pretrained(
model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda"
)
# Apply LoRA configuration
peft_config = get_peft_config(get_lora_config())
model = PeftModelForCausalLM(model, peft_config)
# Apply kernel patches
cfg = DictDefault({"lora_mlp_kernel": True})
patched_model = apply_lora_kernel_patches(model, cfg)
# Verify correct activation function
layer = patched_model.model.model.layers[0]
assert (
layer.mlp.forward.__func__ is model_config["expected_activation"]
), f"Wrong activation for {model_config['name']}"
# Test forward pass
inputs = get_test_inputs(model)
with torch.no_grad():
original_output = model(inputs).logits
patched_output = patched_model(inputs).logits
# Check outputs match
assert torch.allclose(
original_output, patched_output, rtol=1e-4
), f"Outputs don't match for {model_config['name']}"
# pylint: disable=duplicate-code
def test_kernel_training_integration():
"""Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer
# Create minimal config
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"sequence_len": 1024,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
# Load model
model, _ = load_model_and_tokenizer(cfg=cfg)
# Verify correct activation function
layer = model.model.model.layers[0]
assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu