* LoRA + activation fn Triton kernels: initial commit
* implementing optims
* finalizing MLP LoRA kernels and progress on QKV / W kernels
* updates
* O projection optim
* adding monkey patching logic
* doc strings, typing, pre-commit fixes
* updates
* adding lora 8b kernels example
* working on fsdp support
* tests and fixes
* small fixes, getting tests to pass, adding doc strings
* integration tests for LoRA patching
* config.qmd
* remove unneeded pytest fixture
* fix
* review comments first pass
* improving tests, attention class agnostic patching
* adding support for more archs
* wip SiLU / GELU impls
* improved testing, small updates, etc.
* slightly updating docs
* rebase
* fixing test_attention_patching_integration
* additional review comments, fixing test in CI (hopefully)
* isolating problematic patching test
* relaxing allclose threshold to reduce flakiness
* fixing accidental change
* adding model arch agnostic attention class fetching
* removing unused activations
"""Tests for SwiGLU activation function Triton kernels."""
|
|
# pylint: disable=duplicate-code
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward
|
|
|
|
|
|
def test_swiglu_forward_shape():
|
|
"""Test that SwiGLU forward pass preserves expected shapes"""
|
|
batch, seq_len, hidden_dim = 2, 3, 64
|
|
gate = torch.randn(batch, seq_len, hidden_dim, device="cuda")
|
|
up = torch.randn(batch, seq_len, hidden_dim, device="cuda")
|
|
|
|
out = swiglu_forward(gate, up)
|
|
assert out.shape == (batch, seq_len, hidden_dim)
|
|
assert out.dtype == gate.dtype
|
|
assert out.device == gate.device
|
|
|
|
|
|
def test_swiglu_forward_values():
|
|
"""Test SwiGLU forward pass matches PyTorch reference implementation"""
|
|
gate = torch.randn(2, 3, 64, device="cuda")
|
|
up = torch.randn(2, 3, 64, device="cuda")
|
|
|
|
# Custom implementation
|
|
triton_out = swiglu_forward(gate.clone(), up.clone())
|
|
|
|
# PyTorch reference
|
|
torch_out = F.silu(gate) * up
|
|
|
|
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
|
|
|
|
|
|
def test_swiglu_backward():
|
|
"""Test SwiGLU backward pass matches PyTorch autograd"""
|
|
gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
|
|
up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
|
|
grad_output = torch.randn(2, 3, 64, device="cuda")
|
|
|
|
# PyTorch reference - compute intermediates
|
|
silu_gate = F.silu(gate)
|
|
torch_out = silu_gate * up
|
|
torch_out.backward(grad_output)
|
|
|
|
# Custom backward pass
|
|
gate_clone = gate.clone().detach()
|
|
up_clone = up.clone().detach()
|
|
grad_output_clone = grad_output.clone()
|
|
|
|
h, our_grad_gate, our_grad_up = swiglu_backward(
|
|
grad_output_clone, gate_clone, up_clone
|
|
)
|
|
|
|
# Compare outputs and gradients
|
|
assert torch.allclose(h, torch_out, rtol=1e-3)
|
|
assert torch.allclose(our_grad_gate, gate.grad, rtol=1e-3)
|
|
assert torch.allclose(our_grad_up, up.grad, rtol=1e-3)
|
|
|
|
|
|
def test_swiglu_inplace_preservation():
|
|
"""Test that SwiGLU backward doesn't modify original tensors unexpectedly"""
|
|
gate = torch.randn(2, 3, 64, device="cuda")
|
|
up = torch.randn(2, 3, 64, device="cuda")
|
|
grad_output = torch.randn(2, 3, 64, device="cuda")
|
|
|
|
gate_copy = gate.clone()
|
|
up_copy = up.clone()
|
|
grad_copy = grad_output.clone()
|
|
|
|
swiglu_backward(grad_output, gate, up)
|
|
|
|
assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
|
|
assert not torch.equal(up, up_copy), "Up should be modified in-place"
|
|
assert not torch.equal(
|
|
grad_output, grad_copy
|
|
), "Grad output should be modified in-place"
|
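For context on what these assertions encode: the Triton kernels are expected to compute the SwiGLU activation SiLU(gate) * up and its analytic gradients, with swiglu_backward mutating its input buffers in place (which is what the last test checks). The pure-PyTorch sketch below only illustrates that contract as inferred from the tests above; naive_swiglu_forward and naive_swiglu_backward are hypothetical names, not part of axolotl.kernels.swiglu, and unlike the real kernels this version is out-of-place.

    # Minimal pure-PyTorch sketch of the behavior the tests above assert.
    # Hypothetical reference only; the real swiglu_forward / swiglu_backward
    # in axolotl.kernels.swiglu are Triton kernels, not this code.
    import torch
    import torch.nn.functional as F


    def naive_swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
        """SwiGLU forward: SiLU(gate) * up, elementwise."""
        return F.silu(gate) * up


    def naive_swiglu_backward(
        grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (forward output, grad w.r.t. gate, grad w.r.t. up).

        d/dgate [silu(gate) * up] = up * sigmoid(gate) * (1 + gate * (1 - sigmoid(gate)))
        d/dup   [silu(gate) * up] = silu(gate)
        """
        sig = torch.sigmoid(gate)
        silu_gate = gate * sig          # SiLU(gate) = gate * sigmoid(gate)
        h = silu_gate * up              # recomputed forward output
        grad_gate = grad_output * up * sig * (1 + gate * (1 - sig))
        grad_up = grad_output * silu_gate
        return h, grad_gate, grad_up

Recomputing the forward output inside the backward pass (rather than saving it) mirrors the memory-saving trade-off the Triton kernels appear to make, which is why swiglu_backward returns h alongside the two gradients.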