diff --git a/tests/e2e/kernels/test_geglu.py b/tests/e2e/kernels/test_geglu.py index 90403ab4a..4094a8ce7 100644 --- a/tests/e2e/kernels/test_geglu.py +++ b/tests/e2e/kernels/test_geglu.py @@ -19,8 +19,15 @@ def test_geglu_forward_shape(): assert out.device == gate.device -def test_geglu_forward_values(): +@pytest.mark.flaky(retries=1, delay=5) +@pytest.mark.parametrize( + "torch_seed", + [0, 42], +) +def test_geglu_forward_values(torch_seed): """Test GEGLU forward pass matches PyTorch reference implementation.""" + torch.manual_seed(torch_seed) + gate = torch.randn(2, 3, 64, device="cuda") up = torch.randn(2, 3, 64, device="cuda") @@ -33,6 +40,7 @@ def test_geglu_forward_values(): assert torch.allclose(triton_out, torch_out, rtol=1e-3) +@pytest.mark.flaky(retries=1, delay=5) @pytest.mark.parametrize( "torch_seed", [0, 42],