axolotl/tests/e2e/kernels/test_geglu.py

"""Tests for GEGLU activation function Triton kernels."""
import pytest
import torch
import torch.nn.functional as F
from axolotl.kernels.geglu import geglu_backward, geglu_forward
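
# All tests below allocate tensors directly on "cuda". A minimal guard for
# CPU-only environments (a sketch; not something the tests themselves rely on)
# would be a module-level marker like:
#
#     pytestmark = pytest.mark.skipif(
#         not torch.cuda.is_available(), reason="CUDA required"
#     )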


def test_geglu_forward_shape():
    """Test that the GEGLU forward pass preserves shape, dtype, and device."""
    batch, seq_len, hidden_dim = 2, 3, 64
    gate = torch.randn(batch, seq_len, hidden_dim, device="cuda")
    up = torch.randn(batch, seq_len, hidden_dim, device="cuda")

    out = geglu_forward(gate, up)

    assert out.shape == (batch, seq_len, hidden_dim)
    assert out.dtype == gate.dtype
    assert out.device == gate.device


@pytest.mark.flaky(retries=1, delay=5)
@pytest.mark.parametrize("torch_seed", [0, 42])
def test_geglu_forward_values(torch_seed):
    """Test GEGLU forward pass matches PyTorch reference implementation."""
    torch.manual_seed(torch_seed)
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")

    # Custom implementation
    triton_out = geglu_forward(gate.clone(), up.clone())

    # PyTorch reference
    torch_out = F.gelu(gate) * up
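    # F.gelu's default approximate="none" is the exact (erf-based) GELU,
    # so the Triton kernel is being compared against the exact formulation.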

    assert torch.allclose(triton_out, torch_out, rtol=1e-3)


@pytest.mark.flaky(retries=1, delay=5)
@pytest.mark.parametrize("torch_seed", [0, 42])
def test_geglu_backward(torch_seed):
    """Test GEGLU backward pass matches PyTorch autograd."""
    torch.manual_seed(torch_seed)
    gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
    up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
    grad_output = torch.randn(2, 3, 64, device="cuda")

    # PyTorch reference - compute intermediates
    gelu_gate = F.gelu(gate)
    torch_out = gelu_gate * up
    torch_out.backward(grad_output)
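    # By the product rule on out = GELU(gate) * up, autograd produces
    #     up.grad   = grad_output * GELU(gate)
    #     gate.grad = grad_output * up * GELU'(gate)
    # with GELU'(x) = Phi(x) + x * phi(x) for the exact GELU
    # (Phi/phi: standard normal CDF/PDF); geglu_backward must match these.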

    # Custom backward pass
    gate_clone = gate.clone().detach()
    up_clone = up.clone().detach()
    grad_output_clone = grad_output.clone()
    h, grad_gate, grad_up = geglu_backward(grad_output_clone, gate_clone, up_clone)

    # Compare outputs and gradients
    assert torch.allclose(h, torch_out, rtol=1e-3)
    assert torch.allclose(grad_gate, gate.grad, rtol=1e-3)
    assert torch.allclose(grad_up, up.grad, rtol=1e-3)


def test_geglu_inplace_preservation():
    """Test that GEGLU backward modifies its input tensors in-place, as expected."""
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")
    grad_output = torch.randn(2, 3, 64, device="cuda")

    gate_copy = gate.clone()
    up_copy = up.clone()
    grad_copy = grad_output.clone()

    geglu_backward(grad_output, gate, up)

    assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
    assert not torch.equal(up, up_copy), "Up should be modified in-place"
    assert not torch.equal(grad_output, grad_copy), (
        "Grad output should be modified in-place"
    )
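
    # Note: reusing the input buffers looks like a deliberate design choice,
    # presumably to avoid extra allocations in the backward kernel; callers
    # that need the originals should clone first, as the value tests above do.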