add manual seed for flaky test_geglu_backward test (#2763) [skip ci]

This commit is contained in:
Wing Lian
2025-06-05 09:23:17 -07:00
committed by GitHub
parent cb03c765a1
commit 7909bfb076

View File

@@ -1,7 +1,6 @@
"""Tests for GEGLU activation function Triton kernels."""
# pylint: disable=duplicate-code
import pytest
import torch
import torch.nn.functional as F
@@ -34,8 +33,14 @@ def test_geglu_forward_values():
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
def test_geglu_backward():
@pytest.mark.parametrize(
"torch_seed",
[0, 42],
)
def test_geglu_backward(torch_seed):
"""Test GEGLU backward pass matches PyTorch autograd."""
torch.manual_seed(torch_seed)
gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
grad_output = torch.randn(2, 3, 64, device="cuda")