"""Tests for GEGLU activation function Triton kernels."""
# pylint: disable=duplicate-code
import torch
import torch.nn.functional as F

from axolotl.kernels.geglu import geglu_backward, geglu_forward


def test_geglu_forward_shape():
    """Test that GEGLU forward pass preserves expected shapes."""
    batch, seq_len, hidden_dim = 2, 3, 64
    gate = torch.randn(batch, seq_len, hidden_dim, device="cuda")
    up = torch.randn(batch, seq_len, hidden_dim, device="cuda")

    out = geglu_forward(gate, up)
    assert out.shape == (batch, seq_len, hidden_dim)
    assert out.dtype == gate.dtype
    assert out.device == gate.device


def test_geglu_forward_values():
    """Test GEGLU forward pass matches PyTorch reference implementation."""
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")

    # Custom implementation
    triton_out = geglu_forward(gate.clone(), up.clone())

    # PyTorch reference
    torch_out = F.gelu(gate) * up

    assert torch.allclose(triton_out, torch_out, rtol=1e-3)


def test_geglu_backward():
    """Test GEGLU backward pass matches PyTorch autograd."""
    gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
    up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
    grad_output = torch.randn(2, 3, 64, device="cuda")

    # PyTorch reference - compute intermediates
    gelu_gate = F.gelu(gate)
    torch_out = gelu_gate * up
    torch_out.backward(grad_output)

    # Custom backward pass
    gate_clone = gate.clone().detach()
    up_clone = up.clone().detach()
    grad_output_clone = grad_output.clone()

    h, grad_gate, grad_up = geglu_backward(grad_output_clone, gate_clone, up_clone)

    # Compare outputs and gradients
    assert torch.allclose(h, torch_out, rtol=1e-3)
    assert torch.allclose(grad_gate, gate.grad, rtol=1e-3)
    assert torch.allclose(grad_up, up.grad, rtol=1e-3)


def test_geglu_inplace_preservation():
    """Test that GEGLU backward modifies its input tensors in-place as expected."""
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")
    grad_output = torch.randn(2, 3, 64, device="cuda")

    gate_copy = gate.clone()
    up_copy = up.clone()
    grad_copy = grad_output.clone()

    geglu_backward(grad_output, gate, up)

    assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
    assert not torch.equal(up, up_copy), "Up should be modified in-place"
    assert not torch.equal(
        grad_output, grad_copy
    ), "Grad output should be modified in-place"