diff --git a/requirements-tests.txt b/requirements-tests.txt index 9cda381d0..7a34809da 100644 --- a/requirements-tests.txt +++ b/requirements-tests.txt @@ -1,2 +1,3 @@ pytest pytest-xdist +pytest-retry diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py index 5d517585f..fbb776aa5 100644 --- a/tests/test_packed_pretraining.py +++ b/tests/test_packed_pretraining.py @@ -2,6 +2,7 @@ import functools import unittest +import pytest import torch from datasets import load_dataset from torch.utils.data import DataLoader @@ -21,6 +22,7 @@ class TestPretrainingPacking(unittest.TestCase): self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer.pad_token = "" + @pytest.mark.flaky(retries=3, delay=5) def test_packing_stream_dataset(self): # pylint: disable=duplicate-code dataset = load_dataset(