mark flaky geglu tests and add torch seed (#2876) [skip ci]

* mark flaky geglu tests and add torch seed

* restore accidental removal of seed
This commit is contained in:
Wing Lian
2025-07-07 15:24:16 -04:00
committed by GitHub
parent 9c0d7ee761
commit de2c5ba103

View File

@@ -19,8 +19,15 @@ def test_geglu_forward_shape():
assert out.device == gate.device
def test_geglu_forward_values():
@pytest.mark.flaky(retries=1, delay=5)
@pytest.mark.parametrize(
"torch_seed",
[0, 42],
)
def test_geglu_forward_values(torch_seed):
"""Test GEGLU forward pass matches PyTorch reference implementation."""
torch.manual_seed(torch_seed)
gate = torch.randn(2, 3, 64, device="cuda")
up = torch.randn(2, 3, 64, device="cuda")
@@ -33,6 +40,7 @@ def test_geglu_forward_values():
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
@pytest.mark.flaky(retries=1, delay=5)
@pytest.mark.parametrize(
"torch_seed",
[0, 42],