mark flaky geglu tests and add torch seed (#2876) [skip ci]
* mark flaky geglu tests and add torch seed * restore accidental removal of seed
This commit is contained in:
@@ -19,8 +19,15 @@ def test_geglu_forward_shape():
|
|||||||
assert out.device == gate.device
|
assert out.device == gate.device
|
||||||
|
|
||||||
|
|
||||||
def test_geglu_forward_values():
|
@pytest.mark.flaky(retries=1, delay=5)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"torch_seed",
|
||||||
|
[0, 42],
|
||||||
|
)
|
||||||
|
def test_geglu_forward_values(torch_seed):
|
||||||
"""Test GEGLU forward pass matches PyTorch reference implementation."""
|
"""Test GEGLU forward pass matches PyTorch reference implementation."""
|
||||||
|
torch.manual_seed(torch_seed)
|
||||||
|
|
||||||
gate = torch.randn(2, 3, 64, device="cuda")
|
gate = torch.randn(2, 3, 64, device="cuda")
|
||||||
up = torch.randn(2, 3, 64, device="cuda")
|
up = torch.randn(2, 3, 64, device="cuda")
|
||||||
|
|
||||||
@@ -33,6 +40,7 @@ def test_geglu_forward_values():
|
|||||||
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
|
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.flaky(retries=1, delay=5)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"torch_seed",
|
"torch_seed",
|
||||||
[0, 42],
|
[0, 42],
|
||||||
|
|||||||
Reference in New Issue
Block a user