Activation function Triton kernels, LoRA custom autograd functions (#2324)
* LoRA + activation fn Triton kernels: initial commit * implementing optims * finalizing MLP LoRA kernels and progress on QKV / W kernels * updates * O projection optim * adding monkey patching logic * doc strings, typing, pre-commit fixes * updates * adding lora 8b kernels example * working on fsdp support * tests and fixes * small fixes, getting tests to pass, adding doc strings * integration tests for LoRA patching * config.qmd * remove unneeded pytest fixture * fix * review comments first pass * improving tests, attention class agnostic patching * adding support for more archs * wip SiLU / GELU impls * improved testing, small updates, etc. * slightly updating docs * rebase * fixing test_attention_patching_integration * additional review comments, fixing test in CI (hopefully) * isolating problematic patching test * relaxing allclose threshold to reduce flakiness * fixing accidental change * adding model arch agnostic attention class fetching * removing unused activations
This commit is contained in:
76
tests/e2e/kernels/test_geglu.py
Normal file
76
tests/e2e/kernels/test_geglu.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Tests for GEGLU activation function Triton kernels."""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.kernels.geglu import geglu_backward, geglu_forward
|
||||
|
||||
|
||||
def test_geglu_forward_shape():
    """Check that geglu_forward's result mirrors the gate tensor's metadata.

    The output must have the same shape, dtype and device as ``gate``.
    """
    expected_shape = (2, 3, 64)  # (batch, seq_len, hidden_dim)
    gate = torch.randn(*expected_shape, device="cuda")
    up = torch.randn(*expected_shape, device="cuda")

    result = geglu_forward(gate, up)

    assert result.shape == expected_shape
    assert result.dtype == gate.dtype
    assert result.device == gate.device
|
||||
|
||||
|
||||
def test_geglu_forward_values():
    """Compare geglu_forward with the eager ``F.gelu(gate) * up`` reference."""
    dims = (2, 3, 64)
    gate = torch.randn(*dims, device="cuda")
    up = torch.randn(*dims, device="cuda")

    # The kernel is handed clones; the reference below uses the originals.
    kernel_result = geglu_forward(gate.clone(), up.clone())
    reference = F.gelu(gate) * up

    assert torch.allclose(kernel_result, reference, rtol=1e-3)
|
||||
|
||||
|
||||
def test_geglu_backward():
    """Check geglu_backward against PyTorch autograd.

    ``geglu_backward`` returns three tensors; they are compared with the eager
    forward output and with the autograd gradients of ``gate`` and ``up``.
    """
    dims = (2, 3, 64)
    gate = torch.randn(*dims, device="cuda", requires_grad=True)
    up = torch.randn(*dims, device="cuda", requires_grad=True)
    grad_output = torch.randn(*dims, device="cuda")

    # Reference path: eager forward followed by autograd backward.
    reference_out = F.gelu(gate) * up
    reference_out.backward(grad_output)

    # Kernel path: runs on detached copies of the same values.
    detached_gate = gate.clone().detach()
    detached_up = up.clone().detach()
    kernel_out, kernel_grad_gate, kernel_grad_up = geglu_backward(
        grad_output.clone(), detached_gate, detached_up
    )

    comparisons = (
        (kernel_out, reference_out),
        (kernel_grad_gate, gate.grad),
        (kernel_grad_up, up.grad),
    )
    for actual, expected in comparisons:
        assert torch.allclose(actual, expected, rtol=1e-3)
|
||||
|
||||
|
||||
def test_geglu_inplace_preservation():
    """Test that GEGLU backward overwrites its input tensors in-place.

    Despite the test's name, every assertion below requires ``gate``, ``up``
    and ``grad_output`` to differ from the snapshots taken before the call.
    """
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")
    grad_output = torch.randn(2, 3, 64, device="cuda")

    # Snapshots of the inputs before the kernel gets a chance to touch them.
    gate_copy = gate.clone()
    up_copy = up.clone()
    grad_copy = grad_output.clone()

    geglu_backward(grad_output, gate, up)

    assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
    assert not torch.equal(up, up_copy), "Up should be modified in-place"
    assert not torch.equal(
        grad_output, grad_copy
    ), "Grad output should be modified in-place"
|
||||
531
tests/e2e/kernels/test_lora.py
Normal file
531
tests/e2e/kernels/test_lora.py
Normal file
@@ -0,0 +1,531 @@
|
||||
"""Tests for LoRA custom autograd."""
|
||||
# pylint: disable=invalid-name,redefined-outer-name
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from bitsandbytes.functional import QuantState
|
||||
from torch import nn
|
||||
|
||||
from axolotl.kernels.geglu import geglu_backward, geglu_forward
|
||||
from axolotl.kernels.lora import (
|
||||
LoRA_MLP,
|
||||
LoRA_O,
|
||||
LoRA_QKV,
|
||||
apply_lora_mlp_geglu,
|
||||
apply_lora_mlp_swiglu,
|
||||
get_lora_parameters,
|
||||
matmul_lora,
|
||||
)
|
||||
from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward
|
||||
|
||||
|
||||
@pytest.fixture
def mock_quantstate():
    """Build a mock ``QuantState`` that wraps a second, nested ``QuantState``."""
    shape = (64, 64)
    n_blocks = shape[0]  # assumes blockwise quantization along the first dim

    # Settings shared by the nested and the outer state.
    common = {
        "shape": shape,
        "dtype": torch.float16,
        "blocksize": 64,
        "quant_type": "nf4",
    }

    # Inner state: no offset and no further nesting.
    # torch.randint excludes its upper bound, so the codes fall in 0..14.
    inner_state = QuantState(
        absmax=torch.ones(n_blocks, device="cuda"),
        code=torch.randint(0, 15, shape, device="cuda"),
        offset=None,
        state2=None,
        **common,
    )

    # Outer state: carries an int32 offset and points at the inner state.
    return QuantState(
        absmax=torch.ones(n_blocks, device="cuda"),
        code=torch.randint(0, 15, shape, device="cuda"),
        offset=torch.zeros(n_blocks, dtype=torch.int32, device="cuda"),
        state2=inner_state,
        **common,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def sample_tensors():
    """Provide seeded fp16 CUDA tensors plus the dimensions used to build them."""
    torch.manual_seed(42)
    dims = {"batch": 2, "seq": 3, "hidden": 64, "out": 64, "rank": 8}
    fp16_cuda = {"device": "cuda", "dtype": torch.float16}

    inputs = torch.randn(dims["batch"], dims["seq"], dims["hidden"], **fp16_cuda)
    weight = torch.randn(dims["out"], dims["hidden"], **fp16_cuda)

    return {"X": inputs, "W": weight, "scale": 0.5, "shapes": dims}
|
||||
|
||||
|
||||
@pytest.fixture
def mock_proj():
    """Return a CUDA module exposing the attributes set in ``MockProj.__init__``."""

    class MockProj(nn.Module):
        """Linear base layer with one "default" LoRA A/B pair."""

        def __init__(self, in_features=64, out_features=128, rank=8):
            super().__init__()
            adapter = "default"

            self.base_layer = nn.Linear(in_features, out_features)
            self.base_layer.to("cuda")

            lora_down = nn.Linear(in_features, rank, bias=False).to("cuda")
            lora_up = nn.Linear(rank, out_features, bias=False).to("cuda")
            self.lora_A = nn.ModuleDict({adapter: lora_down})
            self.lora_B = nn.ModuleDict({adapter: lora_up})

            self.scaling = {adapter: 0.5}
            self.active_adapter = adapter
            self.disable_adapters = False
            self.merged = False

    return MockProj()
|
||||
|
||||
|
||||
def test_get_lora_parameters(mock_proj):
    """Exercise get_lora_parameters with adapters active, disabled and merged."""
    # Adapters active: the weight, both LoRA factors and the scale come back.
    W, _, A, B, s = get_lora_parameters(mock_proj)
    assert isinstance(W, torch.Tensor)
    assert W.shape == (128, 64)
    assert A.shape == (8, 64)
    assert B.shape == (128, 8)
    assert s == 0.5

    # Adapters disabled, then merged: neither state may return LoRA terms.
    for disabled, merged in ((True, False), (False, True)):
        mock_proj.disable_adapters = disabled
        mock_proj.merged = merged
        W, _, A, B, s = get_lora_parameters(mock_proj)
        assert A is None and B is None and s is None
|
||||
|
||||
|
||||
def test_matmul_lora(sample_tensors):
    """Check matmul_lora with and without LoRA terms against plain matmuls."""
    X, W = sample_tensors["X"], sample_tensors["W"]
    scale = sample_tensors["scale"]
    dims = sample_tensors["shapes"]

    fp16_cuda = {"device": "cuda", "dtype": torch.float16}
    A = torch.randn(dims["rank"], dims["hidden"], **fp16_cuda)
    B = torch.randn(dims["out"], dims["rank"], **fp16_cuda)

    # Base path: no quant state and no adapter, so the result is X @ W^T.
    base_out = matmul_lora(X, W, None, None, None, None)
    base_expected = X @ W.t()
    assert torch.allclose(base_out, base_expected, rtol=1e-3)

    # Adapter path: adds scale * (X @ A^T) @ B^T on top of the base product.
    lora_out = matmul_lora(X, W, None, A, B, scale)
    lora_expected = base_expected + scale * ((X @ A.t()) @ B.t())
    assert torch.allclose(lora_out, lora_expected, rtol=1e-3)

    # 3D input: leading dims are kept and the last dim becomes W's row count.
    out_3d = matmul_lora(X.clone(), W, None, A, B, scale)
    assert out_3d.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "activation_forward,activation_backward",
    [(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
)
def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward):
    """Run LoRA_MLP without adapters for each activation forward/backward pair."""
    X = sample_tensors["X"]
    dims = sample_tensors["shapes"]
    hidden_dim, out_dim = dims["hidden"], dims["out"]

    def fp16_linear(n_in, n_out):
        return nn.Linear(n_in, n_out).to(device="cuda", dtype=torch.float16)

    gate_proj = fp16_linear(hidden_dim, out_dim)
    up_proj = fp16_linear(hidden_dim, out_dim)
    down_proj = fp16_linear(out_dim, hidden_dim)

    # Each projection is passed as (weight, quant, A, B, scale); only the
    # weight is supplied here, the remaining four slots stay None.
    no_adapter = (None, None, None, None)

    X.requires_grad = True
    output = LoRA_MLP.apply(
        X,
        gate_proj.weight,
        *no_adapter,
        up_proj.weight,
        *no_adapter,
        down_proj.weight,
        *no_adapter,
        activation_forward,
        activation_backward,
        True,  # inplace
    )

    assert output.shape == X.shape
    assert not torch.isnan(output).any()

    # Backward pass must populate a NaN-free gradient on the input.
    output.sum().backward()
    assert X.grad is not None
    assert not torch.isnan(X.grad).any()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "activation_forward,activation_backward",
    [(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
)
def test_lora_mlp_with_adapters(
    sample_tensors, activation_forward, activation_backward
):
    """Run LoRA_MLP with LoRA factors on all three projections.

    Checks the output shape and that the input and every LoRA factor receive
    a gradient free of NaNs.
    """
    X = sample_tensors["X"]
    dims = sample_tensors["shapes"]
    hidden_dim, out_dim, rank = dims["hidden"], dims["out"], dims["rank"]
    fp16_cuda = {"device": "cuda", "dtype": torch.float16}

    # LoRA factors, created in gate -> up -> down order.
    adapter_shapes = {
        "gate_A": (rank, hidden_dim),
        "gate_B": (out_dim, rank),
        "up_A": (rank, hidden_dim),
        "up_B": (out_dim, rank),
        "down_A": (rank, out_dim),
        "down_B": (hidden_dim, rank),
    }
    lora = {
        name: torch.randn(*shape, **fp16_cuda)
        for name, shape in adapter_shapes.items()
    }
    scale = 0.5

    gate_proj = nn.Linear(hidden_dim, out_dim).to(**fp16_cuda)
    up_proj = nn.Linear(hidden_dim, out_dim).to(**fp16_cuda)
    down_proj = nn.Linear(out_dim, hidden_dim).to(**fp16_cuda)

    # Everything that is expected to receive a gradient.
    trainable = [X, *lora.values()]
    for tensor in trainable:
        tensor.requires_grad = True

    output = LoRA_MLP.apply(
        X,
        gate_proj.weight,
        None,
        lora["gate_A"],
        lora["gate_B"],
        scale,
        up_proj.weight,
        None,
        lora["up_A"],
        lora["up_B"],
        scale,
        down_proj.weight,
        None,
        lora["down_A"],
        lora["down_B"],
        scale,
        activation_forward,
        activation_backward,
        True,
    )

    assert output.shape == X.shape
    assert not torch.isnan(output).any()

    output.sum().backward()

    for tensor in trainable:
        assert tensor.grad is not None
    for tensor in trainable:
        assert not torch.isnan(tensor.grad).any()
|
||||
|
||||
|
||||
def test_lora_qkv(sample_tensors):
    """Run LoRA_QKV forward/backward first without, then with LoRA factors."""
    X = sample_tensors["X"]
    dims = sample_tensors["shapes"]
    hidden_dim, rank = dims["hidden"], dims["rank"]
    fp16_cuda = {"device": "cuda", "dtype": torch.float16}

    # Square base weights for the three projections.
    q_weight, k_weight, v_weight = (
        torch.randn(hidden_dim, hidden_dim, **fp16_cuda) for _ in range(3)
    )

    def lora_pair():
        """Create one trainable (A, B) pair."""
        factor_a = torch.randn(rank, hidden_dim, requires_grad=True, **fp16_cuda)
        factor_b = torch.randn(hidden_dim, rank, requires_grad=True, **fp16_cuda)
        return factor_a, factor_b

    q_A, q_B = lora_pair()
    k_A, k_B = lora_pair()
    v_A, v_B = lora_pair()
    scale = 0.5

    X.requires_grad = True

    # Pass 1: base weights only; quant/A/B/scale slots are all None.
    no_adapter = (None, None, None, None)
    Q1, K1, V1 = LoRA_QKV.apply(
        X,
        q_weight,
        *no_adapter,
        k_weight,
        *no_adapter,
        v_weight,
        *no_adapter,
        True,
    )

    assert Q1.shape == K1.shape == V1.shape == X.shape
    (Q1 + K1 + V1).sum().backward()
    assert X.grad is not None

    # Reset the input gradient so pass 2 is checked on its own.
    X.grad = None

    # Pass 2: same base weights plus a LoRA pair per projection.
    Q2, K2, V2 = LoRA_QKV.apply(
        X,
        q_weight,
        None,
        q_A,
        q_B,
        scale,
        k_weight,
        None,
        k_A,
        k_B,
        scale,
        v_weight,
        None,
        v_A,
        v_B,
        scale,
        True,
    )

    assert Q2.shape == K2.shape == V2.shape == X.shape
    (Q2 + K2 + V2).sum().backward()

    tracked = (X, q_A, q_B, k_A, k_B, v_A, v_B)
    for tensor in tracked:
        assert tensor.grad is not None
    for tensor in tracked:
        assert not torch.isnan(tensor.grad).any()
|
||||
|
||||
|
||||
def test_lora_o(sample_tensors):
    """Run LoRA_O forward and backward; check output shape and input gradient."""
    X, W, scale = (sample_tensors[key] for key in ("X", "W", "scale"))
    dims = sample_tensors["shapes"]
    fp16_cuda = {"device": "cuda", "dtype": torch.float16}

    A = torch.randn(dims["rank"], dims["hidden"], **fp16_cuda)
    B = torch.randn(dims["out"], dims["rank"], **fp16_cuda)

    # Forward: leading dims follow X, last dim follows W's row count.
    X.requires_grad = True
    output = LoRA_O.apply(X, W, None, A, B, scale)
    batch, seq = X.shape[:2]
    assert output.shape == (batch, seq, W.shape[0])

    # Backward: the input must receive a gradient.
    output.sum().backward()
    assert X.grad is not None
|
||||
|
||||
|
||||
def test_with_quantization(sample_tensors, mock_quantstate):
    """Call matmul_lora with a mock quant state for two different input sizes."""
    X = sample_tensors["X"]  # [batch, seq, hidden]
    W = sample_tensors["W"]  # [out, hidden]
    scale = 0.5

    dims = sample_tensors["shapes"]
    hidden_dim = dims["hidden"]
    fp16_cuda = {"device": "cuda", "dtype": torch.float16}

    A = torch.randn(dims["rank"], hidden_dim, **fp16_cuda)
    B = torch.randn(dims["out"], dims["rank"], **fp16_cuda)

    def check(inputs):
        """Assert the output shape and the absence of NaNs for one input."""
        result = matmul_lora(inputs, W, mock_quantstate, A, B, scale)
        assert result.shape == (*inputs.shape[:2], W.shape[0])
        assert not torch.isnan(result).any()

    # Fixture-sized input first, then a larger batch/sequence.
    check(X)
    check(torch.randn(4, 6, hidden_dim, **fp16_cuda))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "batch,seq,hidden,rank,out",
    [
        (1, 1, 32, 4, 64),
        (2, 3, 64, 8, 128),
        (4, 5, 128, 16, 256),
    ],
)
def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
    """Check matmul_lora's output shape across several size combinations."""
    fp16_cuda = {"device": "cuda", "dtype": torch.float16}
    X = torch.randn(batch, seq, hidden, **fp16_cuda)
    W = torch.randn(out, hidden, **fp16_cuda)
    A = torch.randn(rank, hidden, **fp16_cuda)
    B = torch.randn(out, rank, **fp16_cuda)

    result = matmul_lora(X, W, None, A, B, 0.5)

    assert result.shape == (batch, seq, out)
|
||||
|
||||
|
||||
def test_gradient_flow(sample_tensors):
    """Check that matmul_lora backpropagates NaN-free gradients to X, A and B."""
    X = sample_tensors["X"].clone()
    W = sample_tensors["W"].clone()
    scale = sample_tensors["scale"]
    dims = sample_tensors["shapes"]
    fp16_cuda = {"device": "cuda", "dtype": torch.float16}

    A = torch.randn(dims["rank"], dims["hidden"], **fp16_cuda)
    B = torch.randn(dims["out"], dims["rank"], **fp16_cuda)

    # W is deliberately left out: only the input and LoRA factors are tracked.
    tracked = (X, A, B)
    for tensor in tracked:
        tensor.requires_grad = True

    matmul_lora(X, W, None, A, B, scale).sum().backward()

    for tensor in tracked:
        assert tensor.grad is not None
    for tensor in tracked:
        assert not torch.isnan(tensor.grad).any()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "apply_function",
    [apply_lora_mlp_swiglu, apply_lora_mlp_geglu],
)
def test_inplace_operations(sample_tensors, apply_function):
    """Check that inplace=True and inplace=False give matching MLP outputs."""
    X = sample_tensors["X"]
    dims = sample_tensors["shapes"]
    hidden_dim, out_dim = dims["hidden"], dims["out"]

    def fp16_linear(n_in, n_out):
        return nn.Linear(n_in, n_out).to(device="cuda", dtype=torch.float16)

    # A bare class (not an nn.Module instance) whose class attributes are the
    # three projections; the class object itself is passed as the MLP.
    projections = {
        "gate_proj": fp16_linear(hidden_dim, out_dim),
        "up_proj": fp16_linear(hidden_dim, out_dim),
        "down_proj": fp16_linear(out_dim, hidden_dim),
    }
    mlp = type("MLPModule", (), projections)

    # Each call gets its own clone of the input.
    inplace_out = apply_function(mlp, X.clone(), inplace=True)
    regular_out = apply_function(mlp, X.clone(), inplace=False)

    assert torch.allclose(inplace_out, regular_out, rtol=1e-3)
|
||||
103
tests/e2e/kernels/test_quantize.py
Normal file
103
tests/e2e/kernels/test_quantize.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Tests for quantization utility functions."""
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
import torch
|
||||
from bitsandbytes.functional import QuantState
|
||||
|
||||
from axolotl.kernels.quantize import dequantize
|
||||
|
||||
|
||||
def test_dequantize_null_state():
    """With ``quant_state=None``, dequantize must return a tensor equal to W."""
    weight = torch.randn(32, 32)
    result = dequantize(weight, None)
    assert torch.equal(result, weight)
|
||||
|
||||
|
||||
def test_dequantize_shape_preservation():
    """Check the shape, dtype and device of dequantize's result."""
    shape = (32, 32)
    W = torch.randn(shape, device="cuda")

    # Settings shared by the outer state and the nested ``state2``.
    common = {
        "shape": shape,
        "dtype": torch.float16,
        "blocksize": 32,
        "quant_type": "nf4",
    }
    quant_state = QuantState(
        absmax=torch.ones(shape[0], device="cuda"),
        code=torch.randint(0, 15, shape, device="cuda"),
        offset=torch.zeros(shape[0], dtype=torch.int32, device="cuda"),
        state2=QuantState(
            absmax=torch.ones(shape[0], device="cuda"),
            code=torch.randint(0, 15, shape, device="cuda"),
            offset=None,
            state2=None,
            **common,
        ),
        **common,
    )

    result = dequantize(W, quant_state)

    assert result.shape == shape
    assert result.dtype == torch.float16
    assert result.device == W.device
|
||||
|
||||
|
||||
def test_dequantize_transposed():
    """Dequantize a single-row weight and check the result's leading dimension.

    The only assertion is that ``result.shape[0]`` equals the first entry of
    the quant state's ``shape``.
    """
    shape = (32, 32)
    W = torch.randn(1, shape[1], device="cuda")  # Transposed input

    # Unlike W, the quant-state tensors here are created without a device.
    common = {
        "shape": shape,
        "dtype": torch.float16,
        "blocksize": 32,
        "quant_type": "nf4",
    }
    quant_state = QuantState(
        absmax=torch.ones(1),
        code=torch.randint(0, 15, shape),
        offset=torch.zeros(1, dtype=torch.int32),
        state2=QuantState(
            absmax=torch.ones(1),
            code=torch.randint(0, 15, shape),
            offset=None,
            state2=None,
            **common,
        ),
        **common,
    )

    result = dequantize(W, quant_state)

    assert result.shape[0] == shape[0]
|
||||
|
||||
|
||||
def test_dequantize_output_tensor():
    """Check that dequantize returns the very tensor passed through ``out``."""
    shape = (32, 32)
    W = torch.randn(shape, device="cuda")
    out = torch.empty(shape, dtype=torch.float16, device="cuda")

    # Unlike W and out, the quant-state tensors are created without a device.
    common = {
        "shape": shape,
        "dtype": torch.float16,
        "blocksize": 32,
        "quant_type": "nf4",
    }
    quant_state = QuantState(
        absmax=torch.ones(shape[0]),
        code=torch.randint(0, 15, shape),
        offset=torch.zeros(shape[0], dtype=torch.int32),
        state2=QuantState(
            absmax=torch.ones(shape[0]),
            code=torch.randint(0, 15, shape),
            offset=None,
            state2=None,
            **common,
        ),
        **common,
    )

    result = dequantize(W, quant_state, out=out)

    # Identity, not just equality: the provided buffer must be reused.
    assert result is out
|
||||
78
tests/e2e/kernels/test_swiglu.py
Normal file
78
tests/e2e/kernels/test_swiglu.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Tests for SwiGLU activation function Triton kernels."""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward
|
||||
|
||||
|
||||
def test_swiglu_forward_shape():
    """Check that swiglu_forward's result mirrors the gate tensor's metadata.

    The output must have the same shape, dtype and device as ``gate``.
    """
    expected_shape = (2, 3, 64)  # (batch, seq_len, hidden_dim)
    gate = torch.randn(*expected_shape, device="cuda")
    up = torch.randn(*expected_shape, device="cuda")

    result = swiglu_forward(gate, up)

    assert result.shape == expected_shape
    assert result.dtype == gate.dtype
    assert result.device == gate.device
|
||||
|
||||
|
||||
def test_swiglu_forward_values():
    """Compare swiglu_forward with the eager ``F.silu(gate) * up`` reference."""
    dims = (2, 3, 64)
    gate = torch.randn(*dims, device="cuda")
    up = torch.randn(*dims, device="cuda")

    # The kernel is handed clones; the reference below uses the originals.
    kernel_result = swiglu_forward(gate.clone(), up.clone())
    reference = F.silu(gate) * up

    assert torch.allclose(kernel_result, reference, rtol=1e-3)
|
||||
|
||||
|
||||
def test_swiglu_backward():
    """Check swiglu_backward against PyTorch autograd.

    ``swiglu_backward`` returns three tensors; they are compared with the eager
    forward output and with the autograd gradients of ``gate`` and ``up``.
    """
    dims = (2, 3, 64)
    gate = torch.randn(*dims, device="cuda", requires_grad=True)
    up = torch.randn(*dims, device="cuda", requires_grad=True)
    grad_output = torch.randn(*dims, device="cuda")

    # Reference path: eager forward followed by autograd backward.
    reference_out = F.silu(gate) * up
    reference_out.backward(grad_output)

    # Kernel path: runs on detached copies of the same values.
    detached_gate = gate.clone().detach()
    detached_up = up.clone().detach()
    kernel_out, kernel_grad_gate, kernel_grad_up = swiglu_backward(
        grad_output.clone(), detached_gate, detached_up
    )

    comparisons = (
        (kernel_out, reference_out),
        (kernel_grad_gate, gate.grad),
        (kernel_grad_up, up.grad),
    )
    for actual, expected in comparisons:
        assert torch.allclose(actual, expected, rtol=1e-3)
|
||||
|
||||
|
||||
def test_swiglu_inplace_preservation():
    """Test that SwiGLU backward overwrites its input tensors in-place.

    Despite the test's name, every assertion below requires ``gate``, ``up``
    and ``grad_output`` to differ from the snapshots taken before the call.
    """
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")
    grad_output = torch.randn(2, 3, 64, device="cuda")

    # Snapshots of the inputs before the kernel gets a chance to touch them.
    gate_copy = gate.clone()
    up_copy = up.clone()
    grad_copy = grad_output.clone()

    swiglu_backward(grad_output, gate, up)

    assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
    assert not torch.equal(up, up_copy), "Up should be modified in-place"
    assert not torch.equal(
        grad_output, grad_copy
    ), "Grad output should be modified in-place"
|
||||
0
tests/e2e/patched/lora_kernels/__init__.py
Normal file
0
tests/e2e/patched/lora_kernels/__init__.py
Normal file
414
tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
Normal file
414
tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
Normal file
@@ -0,0 +1,414 @@
|
||||
"""Integration tests for LoRA activation and attention kernels."""
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from accelerate.state import PartialState
|
||||
from peft import PeftModelForCausalLM, get_peft_config
|
||||
from transformers import AutoModelForCausalLM, LlamaForCausalLM
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||
|
||||
from axolotl.kernels.lora import (
|
||||
apply_lora_mlp_geglu,
|
||||
apply_lora_mlp_swiglu,
|
||||
apply_lora_o,
|
||||
apply_lora_qkv,
|
||||
)
|
||||
from axolotl.monkeypatch.lora_kernels import (
|
||||
apply_lora_kernel_patches,
|
||||
patch_self_attn_lora,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
# One entry per checkpoint: its name, the LoRA MLP function listed as the
# expected activation, and the dtype recorded for it.
MODEL_CONFIGS = [
    {"name": model_name, "expected_activation": activation, "dtype": dtype}
    for model_name, activation, dtype in (
        ("openaccess-ai-collective/tiny-mistral", apply_lora_mlp_swiglu, torch.float16),
        ("Qwen/Qwen2-7B", apply_lora_mlp_swiglu, torch.float16),
        ("HuggingFaceTB/SmolLM2-135M", apply_lora_mlp_swiglu, torch.float32),
        ("mhenrichsen/gemma-2b", apply_lora_mlp_geglu, torch.float16),
    )
]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def init_accelerate():
    """Instantiate Accelerate's ``PartialState`` before every test in the module."""
    # Only the construction matters; the instance itself is discarded.
    PartialState()
|
||||
|
||||
|
||||
@pytest.fixture
def small_llama_model():
    """Build a tiny two-layer LLaMA causal LM directly from a config."""
    tiny_config = LlamaConfig(
        vocab_size=100,
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
    )
    return LlamaForCausalLM(tiny_config)
|
||||
|
||||
|
||||
def test_attention_patching_integration():
    """Test that patch_self_attn_lora replaces ``LlamaAttention.forward``.

    The patch is applied at class level, so it is rolled back in a ``finally``
    block: a failing assertion (or a failing patch call) must not leak the
    patched forward into tests that run afterwards.
    """
    cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}

    # Keep a handle on the unpatched implementation for comparison and restore.
    original_forward = LlamaAttention.forward

    try:
        patch_self_attn_lora(cfg)
        patched_forward = LlamaAttention.forward

        # The forward method must have been swapped for the axolotl version.
        assert original_forward is not patched_forward
        assert patched_forward.__name__ == "axolotl_attn_forward"

        # The patch must also stash the original implementation on the class.
        assert hasattr(LlamaAttention, "_original_forward")
    finally:
        # Always undo the class-level changes, whatever happened above.
        LlamaAttention.forward = original_forward
        if hasattr(LlamaAttention, "_original_forward"):
            delattr(LlamaAttention, "_original_forward")
|
||||
|
||||
|
||||
def test_swiglu_mlp_integration(small_llama_model):
    """Test SwiGLU activation in LoRA MLP context.

    The reference output is taken from the first decoder layer *before* the
    kernel patches are applied. Computing it afterwards would compare the
    patched layer with itself whenever ``apply_lora_kernel_patches`` mutates
    the model in place, making the final assertion vacuous.
    """
    peft_config = get_peft_config(
        {
            "peft_type": "LORA",
            "task_type": "CAUSAL_LM",
            "r": 8,
            "lora_alpha": 16,
            "target_modules": ["gate_proj", "up_proj", "down_proj"],
            "lora_dropout": 0,
            "bias": "none",
        }
    )
    model = PeftModelForCausalLM(small_llama_model, peft_config).to("cuda")
    cfg = DictDefault({"lora_mlp_kernel": True})

    # Build the decoder-layer inputs up front so both runs share them.
    batch_size, seq_len = 2, 10
    hidden_states = torch.randn(
        batch_size, seq_len, model.config.hidden_size, device=model.device
    )
    position_ids = (
        torch.arange(seq_len, device=model.device).unsqueeze(0).expand(batch_size, -1)
    )
    cos, sin = model.model.model.rotary_emb(hidden_states, position_ids)

    inputs = {
        "hidden_states": hidden_states,
        "attention_mask": None,
        "position_embeddings": (cos, sin),
        "output_attentions": False,
        "use_cache": False,
        "past_key_value": None,
    }

    # Baseline from the unpatched layer.
    with torch.no_grad():
        original_output = model.model.model.layers[0](**inputs)[0]

    # Apply patches
    patched_model = apply_lora_kernel_patches(model, cfg)

    # Verify patches
    layer = patched_model.model.model.layers[0]
    assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu

    # Same inputs through the patched layer.
    with torch.no_grad():
        patched_output = layer(**inputs)[0]

    assert torch.allclose(original_output, patched_output, rtol=1e-4)
|
||||
|
||||
|
||||
def test_geglu_model_integration():
    """Check that a Gemma checkpoint gets the GEGLU LoRA MLP kernel.

    Loads ``mhenrichsen/gemma-2b`` in float16 on CUDA, wraps it in a LoRA
    adapter over the MLP projections, applies the kernel patches with only
    ``lora_mlp_kernel`` enabled, asserts the first decoder layer's MLP
    ``forward`` is bound to ``apply_lora_mlp_geglu``, and compares logits
    for one batch of random token ids.
    """
    base_model = AutoModelForCausalLM.from_pretrained(
        "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda"
    )
    lora_settings = {
        "peft_type": "LORA",
        "task_type": "CAUSAL_LM",
        "r": 8,
        "lora_alpha": 16,
        "target_modules": ["gate_proj", "up_proj", "down_proj"],
        "lora_dropout": 0,
        "bias": "none",
    }
    model = PeftModelForCausalLM(base_model, get_peft_config(lora_settings))

    patched_model = apply_lora_kernel_patches(
        model, DictDefault({"lora_mlp_kernel": True})
    )

    # The first decoder layer's MLP must now dispatch to the GEGLU kernel.
    first_layer = patched_model.model.model.layers[0]
    assert first_layer.mlp.forward.__func__ is apply_lora_mlp_geglu

    # End-to-end forward pass on random token ids drawn from [0, 100).
    token_ids = torch.randint(
        0, 100, (1, 20), device=model.device, dtype=torch.long
    )
    with torch.no_grad():
        # NOTE(review): `model` is evaluated after patching. If
        # apply_lora_kernel_patches patches in place and returns the same
        # object, both logits come from the patched model - confirm.
        reference_logits = model(token_ids).logits
        kernel_logits = patched_model(token_ids).logits

    assert torch.allclose(reference_logits, kernel_logits, rtol=1e-4)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model_name,expected_activation",
    [
        ("HuggingFaceTB/SmolLM2-135M", apply_lora_mlp_swiglu),
        ("mhenrichsen/gemma-2b", apply_lora_mlp_geglu),
    ],
)
def test_model_specific_activation(model_name, expected_activation):
    """Check that each checkpoint is patched with its matching MLP kernel.

    For every ``(model_name, expected_activation)`` pair, loads the
    checkpoint, adds a LoRA adapter over the MLP projections, applies the
    kernel patches with ``lora_mlp_kernel`` enabled, and asserts that the
    first decoder layer's MLP ``forward`` is bound to ``expected_activation``.
    """
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    lora_settings = {
        "peft_type": "LORA",
        "task_type": "CAUSAL_LM",
        "r": 8,
        "lora_alpha": 16,
        "target_modules": ["gate_proj", "up_proj", "down_proj"],
        "lora_dropout": 0,
        "bias": "none",
    }
    peft_model = PeftModelForCausalLM(base_model, get_peft_config(lora_settings))

    patched_model = apply_lora_kernel_patches(
        peft_model, DictDefault({"lora_mlp_kernel": True})
    )

    first_layer_mlp = patched_model.model.model.layers[0].mlp
    assert first_layer_mlp.forward.__func__ is expected_activation
|
||||
|
||||
|
||||
def test_kernel_patch_conditions():
    """Check that LoRA settings the kernels cannot handle leave the MLP alone.

    Two configurations are tried on a fresh ``HuggingFaceTB/SmolLM2-135M``
    model: one with non-zero LoRA dropout and one with ``bias="lora_only"``.
    In both cases the first layer's MLP ``forward`` must be neither
    ``apply_lora_mlp_swiglu`` nor ``apply_lora_mlp_geglu`` after
    ``apply_lora_kernel_patches`` runs with ``lora_mlp_kernel`` enabled.
    """
    # Each entry supplies the two settings that differ between the cases;
    # everything else is shared and filled in inside the loop.
    blocking_overrides = (
        # Dropout prevents patching
        {"lora_dropout": 0.1, "bias": "none"},
        # Bias prevents patching
        {"lora_dropout": 0, "bias": "lora_only"},
    )

    for overrides in blocking_overrides:
        base_model = AutoModelForCausalLM.from_pretrained(
            "HuggingFaceTB/SmolLM2-135M"
        )
        lora_settings = {
            "peft_type": "LORA",
            "task_type": "CAUSAL_LM",
            "r": 8,
            "lora_alpha": 16,
            "target_modules": ["gate_proj", "up_proj", "down_proj"],
            **overrides,
        }
        peft_model = PeftModelForCausalLM(
            base_model, get_peft_config(lora_settings)
        )

        # Patching is requested, but is expected to be skipped.
        patched_model = apply_lora_kernel_patches(
            peft_model, DictDefault({"lora_mlp_kernel": True})
        )
        mlp = patched_model.model.model.layers[0].mlp

        # Neither fused MLP kernel may have been installed.
        assert mlp.forward.__func__ is not apply_lora_mlp_swiglu
        assert mlp.forward.__func__ is not apply_lora_mlp_geglu
|
||||
|
||||
|
||||
def test_kernel_config_options():
    """Check that each kernel flag patches only its own component.

    Enables exactly one of ``lora_mlp_kernel``, ``lora_qkv_kernel`` and
    ``lora_o_kernel`` at a time on a freshly built two-layer Llama model with
    LoRA on all MLP and attention projections, and asserts for every decoder
    layer that the enabled component is patched and the other two are not.
    """

    def only_requested_patched(layer, flags):
        """Return True when each component's patch state equals its flag."""
        # Checked in MLP, QKV, O order; `and` stops at the first mismatch.
        return (
            (layer.mlp.forward.__func__ is apply_lora_mlp_swiglu)
            == flags["lora_mlp_kernel"]
            and (layer.self_attn.apply_qkv.__func__ is apply_lora_qkv)
            == flags["lora_qkv_kernel"]
            and (layer.self_attn.apply_o.__func__ is apply_lora_o)
            == flags["lora_o_kernel"]
        )

    # One configuration per kernel, each enabling just that kernel.
    flag_sets = [
        {"lora_mlp_kernel": True, "lora_qkv_kernel": False, "lora_o_kernel": False},
        {"lora_mlp_kernel": False, "lora_qkv_kernel": True, "lora_o_kernel": False},
        {"lora_mlp_kernel": False, "lora_qkv_kernel": False, "lora_o_kernel": True},
    ]

    for config_dict in flag_sets:
        # Build a fresh, tiny model so patches from earlier rounds cannot leak.
        small_llama_model = LlamaForCausalLM(
            LlamaConfig(
                vocab_size=100,
                hidden_size=128,
                intermediate_size=256,
                num_hidden_layers=2,
                num_attention_heads=4,
            )
        )

        lora_settings = {
            "peft_type": "LORA",
            "task_type": "CAUSAL_LM",
            "r": 8,
            "lora_alpha": 16,
            "target_modules": [
                "gate_proj",
                "up_proj",
                "down_proj",
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
            ],
            "lora_dropout": 0,
            "bias": "none",
        }
        model = PeftModelForCausalLM(
            small_llama_model, get_peft_config(lora_settings)
        ).to("cuda")
        patched_model = apply_lora_kernel_patches(model, DictDefault(config_dict))

        # Every decoder layer must carry exactly the requested patches.
        for layer in patched_model.model.model.layers:
            assert only_requested_patched(
                layer, config_dict
            ), f"Failed for config: {config_dict}"

        # Drop the references before the next round builds a new model.
        del model, small_llama_model, patched_model
|
||||
|
||||
|
||||
def get_lora_config():
    """Return a new dict holding the baseline LoRA settings for these tests.

    The settings describe a rank-8 causal-LM LoRA adapter over the three MLP
    projections, with no dropout and no bias. A fresh dict (and a fresh
    ``target_modules`` list) is built on every call.
    """
    mlp_projections = ["gate_proj", "up_proj", "down_proj"]
    lora_settings = {
        "peft_type": "LORA",
        "task_type": "CAUSAL_LM",
        "r": 8,
        "lora_alpha": 16,
        "target_modules": mlp_projections,
        "lora_dropout": 0,
        "bias": "none",
    }
    return lora_settings
|
||||
|
||||
|
||||
def get_test_inputs(model, seq_length=20):
    """Return one row of random token ids suitable as input to ``model``.

    The result is an int64 tensor of shape ``(1, seq_length)`` placed on
    ``model.device``, with values drawn by ``torch.randint`` from
    ``[0, model.config.vocab_size)``.
    """
    vocab_size = model.config.vocab_size
    batch_shape = (1, seq_length)
    return torch.randint(
        0, vocab_size, batch_shape, device=model.device, dtype=torch.long
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_config", MODEL_CONFIGS)
def test_model_architecture(model_config):
    """Check MLP kernel patching for every entry in ``MODEL_CONFIGS``.

    Each ``model_config`` supplies a checkpoint ``name``, a ``dtype`` and the
    ``expected_activation`` function. The checkpoint is loaded on CUDA, given
    the standard LoRA adapter from ``get_lora_config``, patched with
    ``lora_mlp_kernel`` enabled, checked for the expected MLP ``forward``,
    and run on random token ids to compare logits.
    """
    # Load the checkpoint in the dtype requested by this configuration.
    base_model = AutoModelForCausalLM.from_pretrained(
        model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda"
    )

    # Attach the standard LoRA adapter.
    model = PeftModelForCausalLM(base_model, get_peft_config(get_lora_config()))

    # Request only the MLP kernel.
    patched_model = apply_lora_kernel_patches(
        model, DictDefault({"lora_mlp_kernel": True})
    )

    # The first decoder layer must use this architecture's activation kernel.
    first_layer = patched_model.model.model.layers[0]
    assert (
        first_layer.mlp.forward.__func__ is model_config["expected_activation"]
    ), f"Wrong activation for {model_config['name']}"

    # Forward pass on random token ids.
    token_ids = get_test_inputs(model)
    with torch.no_grad():
        # NOTE(review): `model` is evaluated after patching. If
        # apply_lora_kernel_patches patches in place and returns the same
        # object, both logits come from the patched model - confirm.
        reference_logits = model(token_ids).logits
        kernel_logits = patched_model(token_ids).logits

    # The two sets of logits must agree within the relative tolerance.
    assert torch.allclose(
        reference_logits, kernel_logits, rtol=1e-4
    ), f"Outputs don't match for {model_config['name']}"
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def test_kernel_training_integration():
    """Check that loading a model via the CLI helper applies the MLP kernel.

    Builds a minimal LoRA training config for ``HuggingFaceTB/SmolLM2-135M``
    with all three kernel flags enabled, loads the model through
    ``load_model_and_tokenizer``, and asserts that the first decoder layer's
    MLP ``forward`` is bound to ``apply_lora_mlp_swiglu``.
    """
    from axolotl.cli.utils import load_model_and_tokenizer

    # Smallest set of options used here to load a LoRA model with kernels on.
    settings = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
        "learning_rate": 0.000001,
        "datasets": [
            {
                "path": "mhenrichsen/alpaca_2k_test",
                "type": "alpaca",
            }
        ],
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "adapter": "lora",
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.0,
        "lora_target_linear": True,
        "sequence_len": 1024,
        "lora_mlp_kernel": True,
        "lora_qkv_kernel": True,
        "lora_o_kernel": True,
    }
    cfg = DictDefault(settings)

    # Only the model is needed; the tokenizer is discarded.
    model, _tokenizer = load_model_and_tokenizer(cfg=cfg)

    # The first decoder layer's MLP must dispatch to the SwiGLU kernel.
    first_layer = model.model.model.layers[0]
    assert first_layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
|
||||
Reference in New Issue
Block a user