add manual seed for flaky test_geglu_backward test (#2763) [skip ci]
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
"""Tests for GEGLU activation function Triton kernels."""
|
"""Tests for GEGLU activation function Triton kernels."""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
import pytest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
@@ -34,8 +33,14 @@ def test_geglu_forward_values():
|
|||||||
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
|
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
def test_geglu_backward():
|
@pytest.mark.parametrize(
|
||||||
|
"torch_seed",
|
||||||
|
[0, 42],
|
||||||
|
)
|
||||||
|
def test_geglu_backward(torch_seed):
|
||||||
"""Test GEGLU backward pass matches PyTorch autograd."""
|
"""Test GEGLU backward pass matches PyTorch autograd."""
|
||||||
|
torch.manual_seed(torch_seed)
|
||||||
|
|
||||||
gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
|
gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
|
||||||
up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
|
up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
|
||||||
grad_output = torch.randn(2, 3, 64, device="cuda")
|
grad_output = torch.randn(2, 3, 64, device="cuda")
|
||||||
|
|||||||
Reference in New Issue
Block a user