"""Tests for GEGLU activation function Triton kernels."""

import pytest
import torch
import torch.nn.functional as F

from axolotl.kernels.geglu import geglu_backward, geglu_forward

# Every test below allocates CUDA tensors directly; without this guard a
# CPU-only host would ERROR on tensor creation instead of cleanly skipping.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="requires CUDA device"
)


def test_geglu_forward_shape():
    """Test that GEGLU forward pass preserves expected shapes."""
    batch, seq_len, hidden_dim = 2, 3, 64
    gate = torch.randn(batch, seq_len, hidden_dim, device="cuda")
    up = torch.randn(batch, seq_len, hidden_dim, device="cuda")

    out = geglu_forward(gate, up)

    # Output must match the inputs element-wise: same shape, dtype, device.
    assert out.shape == (batch, seq_len, hidden_dim)
    assert out.dtype == gate.dtype
    assert out.device == gate.device


@pytest.mark.flaky(retries=1, delay=5)
@pytest.mark.parametrize(
    "torch_seed",
    [0, 42],
)
def test_geglu_forward_values(torch_seed):
    """Test GEGLU forward pass matches PyTorch reference implementation."""
    torch.manual_seed(torch_seed)
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")

    # Custom implementation — clones so any in-place kernel writes don't
    # corrupt the tensors used to build the reference below.
    triton_out = geglu_forward(gate.clone(), up.clone())

    # PyTorch reference: GEGLU(gate, up) = GELU(gate) * up
    torch_out = F.gelu(gate) * up

    assert torch.allclose(triton_out, torch_out, rtol=1e-3)


@pytest.mark.flaky(retries=1, delay=5)
@pytest.mark.parametrize(
    "torch_seed",
    [0, 42],
)
def test_geglu_backward(torch_seed):
    """Test GEGLU backward pass matches PyTorch autograd."""
    torch.manual_seed(torch_seed)
    gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
    up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
    grad_output = torch.randn(2, 3, 64, device="cuda")

    # PyTorch reference - compute intermediates
    gelu_gate = F.gelu(gate)
    torch_out = gelu_gate * up
    torch_out.backward(grad_output)

    # Custom backward pass mutates its arguments in-place (see
    # test_geglu_inplace_preservation), so hand it detached clones to keep
    # the autograd reference tensors intact.
    gate_clone = gate.clone().detach()
    up_clone = up.clone().detach()
    grad_output_clone = grad_output.clone()

    h, grad_gate, grad_up = geglu_backward(grad_output_clone, gate_clone, up_clone)

    # Compare recomputed forward output and both input gradients.
    assert torch.allclose(h, torch_out, rtol=1e-3)
    assert torch.allclose(grad_gate, gate.grad, rtol=1e-3)
    assert torch.allclose(grad_up, up.grad, rtol=1e-3)


def test_geglu_inplace_preservation():
    """Test that GEGLU backward doesn't modify original tensors unexpectedly."""
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")
    grad_output = torch.randn(2, 3, 64, device="cuda")

    gate_copy = gate.clone()
    up_copy = up.clone()
    grad_copy = grad_output.clone()

    geglu_backward(grad_output, gate, up)

    # The kernel is documented (by these assertions) as writing its results
    # back into the argument buffers — all three must have changed.
    assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
    assert not torch.equal(up, up_copy), "Up should be modified in-place"
    assert not torch.equal(grad_output, grad_copy), (
        "Grad output should be modified in-place"
    )