mark flaky geglu tests and add torch seed (#2876) [skip ci]
* mark flaky geglu tests and add torch seed * restore accidental removal of seed
This commit is contained in:
@@ -19,8 +19,15 @@ def test_geglu_forward_shape():
|
||||
assert out.device == gate.device
|
||||
|
||||
|
||||
def test_geglu_forward_values():
|
||||
@pytest.mark.flaky(retries=1, delay=5)
|
||||
@pytest.mark.parametrize(
|
||||
"torch_seed",
|
||||
[0, 42],
|
||||
)
|
||||
def test_geglu_forward_values(torch_seed):
|
||||
"""Test GEGLU forward pass matches PyTorch reference implementation."""
|
||||
torch.manual_seed(torch_seed)
|
||||
|
||||
gate = torch.randn(2, 3, 64, device="cuda")
|
||||
up = torch.randn(2, 3, 64, device="cuda")
|
||||
|
||||
@@ -33,6 +40,7 @@ def test_geglu_forward_values():
|
||||
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=1, delay=5)
|
||||
@pytest.mark.parametrize(
|
||||
"torch_seed",
|
||||
[0, 42],
|
||||
|
||||
Reference in New Issue
Block a user