From de2c5ba103d4170539703634e9d8800e017794bf Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 7 Jul 2025 15:24:16 -0400 Subject: [PATCH] mark flaky geglu tests and add torch seed (#2876) [skip ci] * mark flaky geglu tests and add torch seed * restore accidental removal of seed --- tests/e2e/kernels/test_geglu.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/e2e/kernels/test_geglu.py b/tests/e2e/kernels/test_geglu.py index 90403ab4a..4094a8ce7 100644 --- a/tests/e2e/kernels/test_geglu.py +++ b/tests/e2e/kernels/test_geglu.py @@ -19,8 +19,15 @@ def test_geglu_forward_shape(): assert out.device == gate.device -def test_geglu_forward_values(): +@pytest.mark.flaky(retries=1, delay=5) +@pytest.mark.parametrize( + "torch_seed", + [0, 42], +) +def test_geglu_forward_values(torch_seed): """Test GEGLU forward pass matches PyTorch reference implementation.""" + torch.manual_seed(torch_seed) + gate = torch.randn(2, 3, 64, device="cuda") up = torch.randn(2, 3, 64, device="cuda") @@ -33,6 +40,7 @@ def test_geglu_forward_values(): assert torch.allclose(triton_out, torch_out, rtol=1e-3) +@pytest.mark.flaky(retries=1, delay=5) @pytest.mark.parametrize( "torch_seed", [0, 42],