Activation function Triton kernels, LoRA custom autograd functions (#2324)
* LoRA + activation fn Triton kernels: initial commit * implementing optims * finalizing MLP LoRA kernels and progress on QKV / W kernels * updates * O projection optim * adding monkey patching logic * doc strings, typing, pre-commit fixes * updates * adding lora 8b kernels example * working on fsdp support * tests and fixes * small fixes, getting tests to pass, adding doc strings * integration tests for LoRA patching * config.qmd * remove unneeded pytest fixture * fix * review comments first pass * improving tests, attention class agnostic patching * adding support for more archs * wip SiLU / GELU impls * improved testing, small updates, etc. * slightly updating docs * rebase * fixing test_attention_patching_integration * additional review comments, fixing test in CI (hopefully) * isolating problematic patching test * relaxing allclose threshold to reduce flakiness * fixing accidental change * adding model arch agnostic attention class fetching * removing unused activations
This commit is contained in:
76
tests/e2e/kernels/test_geglu.py
Normal file
76
tests/e2e/kernels/test_geglu.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Tests for GEGLU activation function Triton kernels."""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.kernels.geglu import geglu_backward, geglu_forward
|
||||
|
||||
|
||||
def test_geglu_forward_shape():
|
||||
"""Test that GEGLU forward pass preserves expected shapes."""
|
||||
batch, seq_len, hidden_dim = 2, 3, 64
|
||||
gate = torch.randn(batch, seq_len, hidden_dim, device="cuda")
|
||||
up = torch.randn(batch, seq_len, hidden_dim, device="cuda")
|
||||
|
||||
out = geglu_forward(gate, up)
|
||||
assert out.shape == (batch, seq_len, hidden_dim)
|
||||
assert out.dtype == gate.dtype
|
||||
assert out.device == gate.device
|
||||
|
||||
|
||||
def test_geglu_forward_values():
|
||||
"""Test GEGLU forward pass matches PyTorch reference implementation."""
|
||||
gate = torch.randn(2, 3, 64, device="cuda")
|
||||
up = torch.randn(2, 3, 64, device="cuda")
|
||||
|
||||
# Custom implementation
|
||||
triton_out = geglu_forward(gate.clone(), up.clone())
|
||||
|
||||
# PyTorch reference
|
||||
torch_out = F.gelu(gate) * up
|
||||
|
||||
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
|
||||
|
||||
|
||||
def test_geglu_backward():
|
||||
"""Test GEGLU backward pass matches PyTorch autograd."""
|
||||
gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
|
||||
up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
|
||||
grad_output = torch.randn(2, 3, 64, device="cuda")
|
||||
|
||||
# PyTorch reference - compute intermediates
|
||||
gelu_gate = F.gelu(gate)
|
||||
torch_out = gelu_gate * up
|
||||
torch_out.backward(grad_output)
|
||||
|
||||
# Custom backward pass
|
||||
gate_clone = gate.clone().detach()
|
||||
up_clone = up.clone().detach()
|
||||
grad_output_clone = grad_output.clone()
|
||||
|
||||
h, grad_gate, grad_up = geglu_backward(grad_output_clone, gate_clone, up_clone)
|
||||
|
||||
# Compare outputs and gradients
|
||||
assert torch.allclose(h, torch_out, rtol=1e-3)
|
||||
assert torch.allclose(grad_gate, gate.grad, rtol=1e-3)
|
||||
assert torch.allclose(grad_up, up.grad, rtol=1e-3)
|
||||
|
||||
|
||||
def test_geglu_inplace_preservation():
|
||||
"""Test that GEGLU backward doesn't modify original tensors unexpectedly."""
|
||||
gate = torch.randn(2, 3, 64, device="cuda")
|
||||
up = torch.randn(2, 3, 64, device="cuda")
|
||||
grad_output = torch.randn(2, 3, 64, device="cuda")
|
||||
|
||||
gate_copy = gate.clone()
|
||||
up_copy = up.clone()
|
||||
grad_copy = grad_output.clone()
|
||||
|
||||
geglu_backward(grad_output, gate, up)
|
||||
|
||||
assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
|
||||
assert not torch.equal(up, up_copy), "Up should be modified in-place"
|
||||
assert not torch.equal(
|
||||
grad_output, grad_copy
|
||||
), "Grad output should be modified in-place"
|
||||
531
tests/e2e/kernels/test_lora.py
Normal file
531
tests/e2e/kernels/test_lora.py
Normal file
@@ -0,0 +1,531 @@
|
||||
"""Tests for LoRA custom autograd."""
|
||||
# pylint: disable=invalid-name,redefined-outer-name
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from bitsandbytes.functional import QuantState
|
||||
from torch import nn
|
||||
|
||||
from axolotl.kernels.geglu import geglu_backward, geglu_forward
|
||||
from axolotl.kernels.lora import (
|
||||
LoRA_MLP,
|
||||
LoRA_O,
|
||||
LoRA_QKV,
|
||||
apply_lora_mlp_geglu,
|
||||
apply_lora_mlp_swiglu,
|
||||
get_lora_parameters,
|
||||
matmul_lora,
|
||||
)
|
||||
from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_quantstate():
|
||||
"""Creates a mock QuantState for testing"""
|
||||
shape = (64, 64)
|
||||
n_blocks = shape[0] # Assuming blockwise quantization along first dimension
|
||||
|
||||
# Create nested state first
|
||||
nested_state = QuantState(
|
||||
absmax=torch.ones(n_blocks, device="cuda"), # One value per block
|
||||
shape=shape,
|
||||
code=torch.randint(0, 15, shape, device="cuda"), # NF4 range is 0-15
|
||||
dtype=torch.float16,
|
||||
blocksize=64,
|
||||
quant_type="nf4",
|
||||
offset=None,
|
||||
state2=None,
|
||||
)
|
||||
|
||||
# Create main state with nested state
|
||||
return QuantState(
|
||||
absmax=torch.ones(n_blocks, device="cuda"),
|
||||
shape=shape,
|
||||
code=torch.randint(0, 15, shape, device="cuda"),
|
||||
dtype=torch.float16,
|
||||
blocksize=64,
|
||||
quant_type="nf4",
|
||||
offset=torch.zeros(n_blocks, dtype=torch.int32, device="cuda"),
|
||||
state2=nested_state,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tensors():
|
||||
"""Creates sample tensors for testing"""
|
||||
torch.manual_seed(42)
|
||||
batch_size, seq_len, hidden_dim = 2, 3, 64
|
||||
rank = 8
|
||||
out_dim = hidden_dim
|
||||
|
||||
return {
|
||||
"X": torch.randn(
|
||||
batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
|
||||
),
|
||||
"W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
|
||||
"scale": 0.5,
|
||||
"shapes": {
|
||||
"batch": batch_size,
|
||||
"seq": seq_len,
|
||||
"hidden": hidden_dim,
|
||||
"out": out_dim,
|
||||
"rank": rank,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_proj():
|
||||
"""Creates a mock projection module for testing."""
|
||||
|
||||
class MockProj(nn.Module):
|
||||
"""Mock projection class."""
|
||||
|
||||
def __init__(self, in_features=64, out_features=128, rank=8):
|
||||
super().__init__()
|
||||
self.base_layer = nn.Linear(in_features, out_features)
|
||||
self.base_layer.to("cuda")
|
||||
self.lora_A = nn.ModuleDict(
|
||||
{"default": nn.Linear(in_features, rank, bias=False).to("cuda")}
|
||||
)
|
||||
self.lora_B = nn.ModuleDict(
|
||||
{"default": nn.Linear(rank, out_features, bias=False).to("cuda")}
|
||||
)
|
||||
self.scaling = {"default": 0.5}
|
||||
self.active_adapter = "default"
|
||||
self.disable_adapters = False
|
||||
self.merged = False
|
||||
|
||||
return MockProj()
|
||||
|
||||
|
||||
def test_get_lora_parameters(mock_proj):
|
||||
"""Tests get_lora_parameters function"""
|
||||
# Test with LoRA enabled
|
||||
W, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
|
||||
assert isinstance(W, torch.Tensor)
|
||||
assert W.shape == (128, 64)
|
||||
assert A.shape == (8, 64)
|
||||
assert B.shape == (128, 8)
|
||||
assert s == 0.5
|
||||
|
||||
# Test with LoRA disabled
|
||||
mock_proj.disable_adapters = True
|
||||
W, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
assert A is None and B is None and s is None
|
||||
|
||||
# Test with merged state
|
||||
mock_proj.disable_adapters = False
|
||||
mock_proj.merged = True
|
||||
W, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
assert A is None and B is None and s is None
|
||||
|
||||
|
||||
def test_matmul_lora(sample_tensors):
|
||||
"""Tests matmul_lora function"""
|
||||
X = sample_tensors["X"]
|
||||
W = sample_tensors["W"]
|
||||
scale = sample_tensors["scale"]
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
hidden_dim = shapes["hidden"]
|
||||
out_dim = shapes["out"]
|
||||
rank = shapes["rank"]
|
||||
|
||||
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
|
||||
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
||||
|
||||
# Test base matmul
|
||||
out1 = matmul_lora(X, W, None, None, None, None)
|
||||
expected1 = torch.matmul(X, W.t())
|
||||
assert torch.allclose(out1, expected1, rtol=1e-3)
|
||||
|
||||
# Test with LoRA
|
||||
out2 = matmul_lora(X, W, None, A, B, scale)
|
||||
lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
|
||||
expected2 = expected1 + lora_term
|
||||
assert torch.allclose(out2, expected2, rtol=1e-3)
|
||||
|
||||
# Test 3D input reshaping
|
||||
X_3d = X.clone()
|
||||
out3 = matmul_lora(X_3d, W, None, A, B, scale)
|
||||
assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"activation_forward,activation_backward",
|
||||
[(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
|
||||
)
|
||||
def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward):
|
||||
"""Tests LoRA_MLP directly with different activation functions"""
|
||||
X = sample_tensors["X"]
|
||||
shapes = sample_tensors["shapes"]
|
||||
hidden_dim = shapes["hidden"]
|
||||
out_dim = shapes["out"]
|
||||
|
||||
# Create linear layers
|
||||
gate_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
|
||||
up_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
|
||||
down_proj = nn.Linear(out_dim, hidden_dim).to(device="cuda", dtype=torch.float16)
|
||||
|
||||
# Test SwiGLU path
|
||||
X.requires_grad = True
|
||||
output = LoRA_MLP.apply(
|
||||
X,
|
||||
gate_proj.weight,
|
||||
None, # gate_quant
|
||||
None, # gate_A
|
||||
None, # gate_B
|
||||
None, # gate_scale
|
||||
up_proj.weight,
|
||||
None, # up_quant
|
||||
None, # up_A
|
||||
None, # up_B
|
||||
None, # up_scale
|
||||
down_proj.weight,
|
||||
None, # down_quant
|
||||
None, # down_A
|
||||
None, # down_B
|
||||
None, # down_scale
|
||||
activation_forward,
|
||||
activation_backward,
|
||||
True, # inplace
|
||||
)
|
||||
|
||||
assert output.shape == X.shape
|
||||
assert not torch.isnan(output).any()
|
||||
|
||||
# Test backward pass
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
assert X.grad is not None
|
||||
assert not torch.isnan(X.grad).any()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"activation_forward,activation_backward",
|
||||
[(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
|
||||
)
|
||||
def test_lora_mlp_with_adapters(
|
||||
sample_tensors, activation_forward, activation_backward
|
||||
):
|
||||
"""Tests LoRA_MLP with LoRA adapters"""
|
||||
X = sample_tensors["X"]
|
||||
shapes = sample_tensors["shapes"]
|
||||
hidden_dim = shapes["hidden"]
|
||||
out_dim = shapes["out"]
|
||||
rank = shapes["rank"]
|
||||
|
||||
# Create LoRA components
|
||||
gate_A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
|
||||
gate_B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
||||
up_A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
|
||||
up_B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
||||
down_A = torch.randn(rank, out_dim, device="cuda", dtype=torch.float16)
|
||||
down_B = torch.randn(hidden_dim, rank, device="cuda", dtype=torch.float16)
|
||||
scale = 0.5
|
||||
|
||||
gate_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
|
||||
up_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
|
||||
down_proj = nn.Linear(out_dim, hidden_dim).to(device="cuda", dtype=torch.float16)
|
||||
|
||||
X.requires_grad = True
|
||||
gate_A.requires_grad = True
|
||||
gate_B.requires_grad = True
|
||||
up_A.requires_grad = True
|
||||
up_B.requires_grad = True
|
||||
down_A.requires_grad = True
|
||||
down_B.requires_grad = True
|
||||
|
||||
# Forward pass with adapters
|
||||
output = LoRA_MLP.apply(
|
||||
X,
|
||||
gate_proj.weight,
|
||||
None,
|
||||
gate_A,
|
||||
gate_B,
|
||||
scale,
|
||||
up_proj.weight,
|
||||
None,
|
||||
up_A,
|
||||
up_B,
|
||||
scale,
|
||||
down_proj.weight,
|
||||
None,
|
||||
down_A,
|
||||
down_B,
|
||||
scale,
|
||||
activation_forward,
|
||||
activation_backward,
|
||||
True,
|
||||
)
|
||||
|
||||
assert output.shape == X.shape
|
||||
assert not torch.isnan(output).any()
|
||||
|
||||
# Test backward pass
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
# Check all gradients
|
||||
assert X.grad is not None
|
||||
assert gate_A.grad is not None
|
||||
assert gate_B.grad is not None
|
||||
assert up_A.grad is not None
|
||||
assert up_B.grad is not None
|
||||
assert down_A.grad is not None
|
||||
assert down_B.grad is not None
|
||||
|
||||
assert not torch.isnan(X.grad).any()
|
||||
assert not torch.isnan(gate_A.grad).any()
|
||||
assert not torch.isnan(gate_B.grad).any()
|
||||
assert not torch.isnan(up_A.grad).any()
|
||||
assert not torch.isnan(up_B.grad).any()
|
||||
assert not torch.isnan(down_A.grad).any()
|
||||
assert not torch.isnan(down_B.grad).any()
|
||||
|
||||
|
||||
def test_lora_qkv(sample_tensors):
|
||||
"""Tests LoRA QKV implementation with and without adapters"""
|
||||
X = sample_tensors["X"]
|
||||
shapes = sample_tensors["shapes"]
|
||||
hidden_dim = shapes["hidden"]
|
||||
rank = shapes["rank"]
|
||||
|
||||
# Create base weights
|
||||
q_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)
|
||||
k_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)
|
||||
v_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)
|
||||
|
||||
# Create LoRA matrices
|
||||
q_A = torch.randn(
|
||||
rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True
|
||||
)
|
||||
q_B = torch.randn(
|
||||
hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True
|
||||
)
|
||||
k_A = torch.randn(
|
||||
rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True
|
||||
)
|
||||
k_B = torch.randn(
|
||||
hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True
|
||||
)
|
||||
v_A = torch.randn(
|
||||
rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True
|
||||
)
|
||||
v_B = torch.randn(
|
||||
hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True
|
||||
)
|
||||
scale = 0.5
|
||||
|
||||
X.requires_grad = True
|
||||
|
||||
# Test without LoRA adapters
|
||||
Q1, K1, V1 = LoRA_QKV.apply(
|
||||
X,
|
||||
q_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
k_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
v_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
)
|
||||
|
||||
assert Q1.shape == K1.shape == V1.shape == X.shape
|
||||
loss1 = (Q1 + K1 + V1).sum()
|
||||
loss1.backward()
|
||||
assert X.grad is not None
|
||||
|
||||
# Clear gradients
|
||||
X.grad = None
|
||||
|
||||
# Test with LoRA adapters
|
||||
Q2, K2, V2 = LoRA_QKV.apply(
|
||||
X,
|
||||
q_weight,
|
||||
None,
|
||||
q_A,
|
||||
q_B,
|
||||
scale,
|
||||
k_weight,
|
||||
None,
|
||||
k_A,
|
||||
k_B,
|
||||
scale,
|
||||
v_weight,
|
||||
None,
|
||||
v_A,
|
||||
v_B,
|
||||
scale,
|
||||
True,
|
||||
)
|
||||
|
||||
assert Q2.shape == K2.shape == V2.shape == X.shape
|
||||
loss2 = (Q2 + K2 + V2).sum()
|
||||
loss2.backward()
|
||||
|
||||
# Check gradients
|
||||
assert X.grad is not None
|
||||
assert q_A.grad is not None
|
||||
assert q_B.grad is not None
|
||||
assert k_A.grad is not None
|
||||
assert k_B.grad is not None
|
||||
assert v_A.grad is not None
|
||||
assert v_B.grad is not None
|
||||
|
||||
# Check for NaN values
|
||||
assert not torch.isnan(X.grad).any()
|
||||
assert not torch.isnan(q_A.grad).any()
|
||||
assert not torch.isnan(q_B.grad).any()
|
||||
assert not torch.isnan(k_A.grad).any()
|
||||
assert not torch.isnan(k_B.grad).any()
|
||||
assert not torch.isnan(v_A.grad).any()
|
||||
assert not torch.isnan(v_B.grad).any()
|
||||
|
||||
|
||||
def test_lora_o(sample_tensors):
|
||||
"""Tests LoRA output projection"""
|
||||
X = sample_tensors["X"]
|
||||
W = sample_tensors["W"]
|
||||
scale = sample_tensors["scale"]
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
hidden_dim = shapes["hidden"]
|
||||
out_dim = shapes["out"]
|
||||
rank = shapes["rank"]
|
||||
|
||||
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
|
||||
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
||||
|
||||
# Test forward pass
|
||||
X.requires_grad = True
|
||||
output = LoRA_O.apply(X, W, None, A, B, scale)
|
||||
|
||||
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
|
||||
# Test backward pass
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
assert X.grad is not None
|
||||
|
||||
|
||||
def test_with_quantization(sample_tensors, mock_quantstate):
|
||||
"""Tests LoRA with quantized weights"""
|
||||
X = sample_tensors["X"] # [batch, seq, hidden]
|
||||
W = sample_tensors["W"] # [out, hidden]
|
||||
scale = 0.5
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
hidden_dim = shapes["hidden"]
|
||||
out_dim = shapes["out"]
|
||||
rank = shapes["rank"]
|
||||
|
||||
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
|
||||
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
||||
|
||||
# Test matmul with quantization
|
||||
out = matmul_lora(X, W, mock_quantstate, A, B, scale)
|
||||
assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
assert not torch.isnan(out).any()
|
||||
|
||||
# Test with different batch sizes
|
||||
X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
|
||||
out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale)
|
||||
assert out2.shape == (4, 6, W.shape[0])
|
||||
assert not torch.isnan(out2).any()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch,seq,hidden,rank,out",
|
||||
[
|
||||
(1, 1, 32, 4, 64),
|
||||
(2, 3, 64, 8, 128),
|
||||
(4, 5, 128, 16, 256),
|
||||
],
|
||||
)
|
||||
def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
|
||||
"""Tests various input shapes and dimensions"""
|
||||
X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
|
||||
W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
|
||||
A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
|
||||
B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
|
||||
scale = 0.5
|
||||
|
||||
result = matmul_lora(X, W, None, A, B, scale)
|
||||
assert result.shape == (batch, seq, out)
|
||||
|
||||
|
||||
def test_gradient_flow(sample_tensors):
|
||||
"""Tests gradient flow through LoRA layers"""
|
||||
X = sample_tensors["X"].clone()
|
||||
W = sample_tensors["W"].clone()
|
||||
scale = sample_tensors["scale"]
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
hidden_dim = shapes["hidden"]
|
||||
out_dim = shapes["out"]
|
||||
rank = shapes["rank"]
|
||||
|
||||
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
|
||||
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
||||
|
||||
X.requires_grad = True
|
||||
A.requires_grad = True
|
||||
B.requires_grad = True
|
||||
|
||||
# Forward pass
|
||||
out = matmul_lora(X, W, None, A, B, scale)
|
||||
loss = out.sum()
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
|
||||
assert X.grad is not None
|
||||
assert A.grad is not None
|
||||
assert B.grad is not None
|
||||
assert not torch.isnan(X.grad).any()
|
||||
assert not torch.isnan(A.grad).any()
|
||||
assert not torch.isnan(B.grad).any()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"apply_function",
|
||||
[apply_lora_mlp_swiglu, apply_lora_mlp_geglu],
|
||||
)
|
||||
def test_inplace_operations(sample_tensors, apply_function):
|
||||
"""Tests inplace operation behavior"""
|
||||
X = sample_tensors["X"]
|
||||
shapes = sample_tensors["shapes"]
|
||||
|
||||
# Create MLP with both inplace=True and inplace=False
|
||||
mlp = type(
|
||||
"MLPModule",
|
||||
(),
|
||||
{
|
||||
"gate_proj": nn.Linear(shapes["hidden"], shapes["out"]).to(
|
||||
device="cuda", dtype=torch.float16
|
||||
),
|
||||
"up_proj": nn.Linear(shapes["hidden"], shapes["out"]).to(
|
||||
device="cuda", dtype=torch.float16
|
||||
),
|
||||
"down_proj": nn.Linear(shapes["out"], shapes["hidden"]).to(
|
||||
device="cuda", dtype=torch.float16
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
out1 = apply_function(mlp, X.clone(), inplace=True)
|
||||
out2 = apply_function(mlp, X.clone(), inplace=False)
|
||||
|
||||
assert torch.allclose(out1, out2, rtol=1e-3)
|
||||
103
tests/e2e/kernels/test_quantize.py
Normal file
103
tests/e2e/kernels/test_quantize.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Tests for quantization utility functions."""
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
import torch
|
||||
from bitsandbytes.functional import QuantState
|
||||
|
||||
from axolotl.kernels.quantize import dequantize
|
||||
|
||||
|
||||
def test_dequantize_null_state():
|
||||
"""Test that dequantize returns input unchanged when quant_state is None"""
|
||||
W = torch.randn(32, 32)
|
||||
assert torch.equal(dequantize(W, None), W)
|
||||
|
||||
|
||||
def test_dequantize_shape_preservation():
|
||||
"""Test that dequantization preserves expected shapes"""
|
||||
shape = (32, 32)
|
||||
W = torch.randn(shape, device="cuda")
|
||||
|
||||
quant_state = QuantState(
|
||||
absmax=torch.ones(shape[0], device="cuda"),
|
||||
shape=shape,
|
||||
code=torch.randint(0, 15, shape, device="cuda"),
|
||||
dtype=torch.float16,
|
||||
blocksize=32,
|
||||
quant_type="nf4",
|
||||
offset=torch.zeros(shape[0], dtype=torch.int32, device="cuda"),
|
||||
state2=QuantState(
|
||||
absmax=torch.ones(shape[0], device="cuda"),
|
||||
shape=shape,
|
||||
code=torch.randint(0, 15, shape, device="cuda"),
|
||||
dtype=torch.float16,
|
||||
blocksize=32,
|
||||
quant_type="nf4",
|
||||
offset=None,
|
||||
state2=None,
|
||||
),
|
||||
)
|
||||
|
||||
result = dequantize(W, quant_state)
|
||||
assert result.shape == shape
|
||||
assert result.dtype == torch.float16
|
||||
assert result.device == W.device
|
||||
|
||||
|
||||
def test_dequantize_transposed():
|
||||
"""Test that transposed input produces transposed output"""
|
||||
shape = (32, 32)
|
||||
W = torch.randn(1, shape[1], device="cuda") # Transposed input
|
||||
|
||||
quant_state = QuantState(
|
||||
absmax=torch.ones(1),
|
||||
shape=shape,
|
||||
code=torch.randint(0, 15, shape),
|
||||
dtype=torch.float16,
|
||||
blocksize=32,
|
||||
quant_type="nf4",
|
||||
offset=torch.zeros(1, dtype=torch.int32),
|
||||
state2=QuantState(
|
||||
absmax=torch.ones(1),
|
||||
shape=shape,
|
||||
code=torch.randint(0, 15, shape),
|
||||
dtype=torch.float16,
|
||||
blocksize=32,
|
||||
quant_type="nf4",
|
||||
offset=None,
|
||||
state2=None,
|
||||
),
|
||||
)
|
||||
|
||||
result = dequantize(W, quant_state)
|
||||
assert result.shape[0] == shape[0]
|
||||
|
||||
|
||||
def test_dequantize_output_tensor():
|
||||
"""Test dequantization with provided output tensor"""
|
||||
shape = (32, 32)
|
||||
W = torch.randn(shape, device="cuda")
|
||||
out = torch.empty(shape, dtype=torch.float16, device="cuda")
|
||||
|
||||
quant_state = QuantState(
|
||||
absmax=torch.ones(shape[0]),
|
||||
shape=shape,
|
||||
code=torch.randint(0, 15, shape),
|
||||
dtype=torch.float16,
|
||||
blocksize=32,
|
||||
quant_type="nf4",
|
||||
offset=torch.zeros(shape[0], dtype=torch.int32),
|
||||
state2=QuantState(
|
||||
absmax=torch.ones(shape[0]),
|
||||
shape=shape,
|
||||
code=torch.randint(0, 15, shape),
|
||||
dtype=torch.float16,
|
||||
blocksize=32,
|
||||
quant_type="nf4",
|
||||
offset=None,
|
||||
state2=None,
|
||||
),
|
||||
)
|
||||
|
||||
result = dequantize(W, quant_state, out=out)
|
||||
assert result is out
|
||||
78
tests/e2e/kernels/test_swiglu.py
Normal file
78
tests/e2e/kernels/test_swiglu.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Tests for SwiGLU activation function Triton kernels."""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward
|
||||
|
||||
|
||||
def test_swiglu_forward_shape():
|
||||
"""Test that SwiGLU forward pass preserves expected shapes"""
|
||||
batch, seq_len, hidden_dim = 2, 3, 64
|
||||
gate = torch.randn(batch, seq_len, hidden_dim, device="cuda")
|
||||
up = torch.randn(batch, seq_len, hidden_dim, device="cuda")
|
||||
|
||||
out = swiglu_forward(gate, up)
|
||||
assert out.shape == (batch, seq_len, hidden_dim)
|
||||
assert out.dtype == gate.dtype
|
||||
assert out.device == gate.device
|
||||
|
||||
|
||||
def test_swiglu_forward_values():
|
||||
"""Test SwiGLU forward pass matches PyTorch reference implementation"""
|
||||
gate = torch.randn(2, 3, 64, device="cuda")
|
||||
up = torch.randn(2, 3, 64, device="cuda")
|
||||
|
||||
# Custom implementation
|
||||
triton_out = swiglu_forward(gate.clone(), up.clone())
|
||||
|
||||
# PyTorch reference
|
||||
torch_out = F.silu(gate) * up
|
||||
|
||||
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
|
||||
|
||||
|
||||
def test_swiglu_backward():
|
||||
"""Test SwiGLU backward pass matches PyTorch autograd"""
|
||||
gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
|
||||
up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
|
||||
grad_output = torch.randn(2, 3, 64, device="cuda")
|
||||
|
||||
# PyTorch reference - compute intermediates
|
||||
silu_gate = F.silu(gate)
|
||||
torch_out = silu_gate * up
|
||||
torch_out.backward(grad_output)
|
||||
|
||||
# Custom backward pass
|
||||
gate_clone = gate.clone().detach()
|
||||
up_clone = up.clone().detach()
|
||||
grad_output_clone = grad_output.clone()
|
||||
|
||||
h, our_grad_gate, our_grad_up = swiglu_backward(
|
||||
grad_output_clone, gate_clone, up_clone
|
||||
)
|
||||
|
||||
# Compare outputs and gradients
|
||||
assert torch.allclose(h, torch_out, rtol=1e-3)
|
||||
assert torch.allclose(our_grad_gate, gate.grad, rtol=1e-3)
|
||||
assert torch.allclose(our_grad_up, up.grad, rtol=1e-3)
|
||||
|
||||
|
||||
def test_swiglu_inplace_preservation():
|
||||
"""Test that SwiGLU backward doesn't modify original tensors unexpectedly"""
|
||||
gate = torch.randn(2, 3, 64, device="cuda")
|
||||
up = torch.randn(2, 3, 64, device="cuda")
|
||||
grad_output = torch.randn(2, 3, 64, device="cuda")
|
||||
|
||||
gate_copy = gate.clone()
|
||||
up_copy = up.clone()
|
||||
grad_copy = grad_output.clone()
|
||||
|
||||
swiglu_backward(grad_output, gate, up)
|
||||
|
||||
assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
|
||||
assert not torch.equal(up, up_copy), "Up should be modified in-place"
|
||||
assert not torch.equal(
|
||||
grad_output, grad_copy
|
||||
), "Grad output should be modified in-place"
|
||||
Reference in New Issue
Block a user