diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py
index f783af9cc..115813df2 100644
--- a/tests/test_packed_pretraining.py
+++ b/tests/test_packed_pretraining.py
@@ -23,22 +23,26 @@ class TestPretrainingPacking:
         # seed with random.seed(0) for reproducibility
         random.seed(0)

-        # generate 20 rows of random text with "words" of between 2 and 10 characters and
-        # between 400 to 1200 characters per line
-        data = [
-            "".join(random.choices(string.ascii_lowercase, k=random.randint(2, 10)))
-            for _ in range(20)
-        ] + [
-            " ".join(
-                random.choices(string.ascii_lowercase, k=random.randint(400, 1200))
+        # generate rows of random text with "words" of between 2 and 10 characters and
+        # between 50 and 200 words per row
+        def rand_txt():
+            return " ".join(
+                [
+                    "".join(
+                        random.choices(string.ascii_lowercase, k=random.randint(2, 10))
+                    )
+                    for _ in range(random.randint(50, 200))
+                ]
             )
-            for _ in range(20)
-        ]
+
+        # Create a list of 500 random texts up front rather than generating them
+        # inside the generator so the test runs faster
+        data = [rand_txt() for _ in range(500)]

         # Create an IterableDataset
         def generator():
-            for text in data:
-                yield {"text": text}
+            for row in data:
+                yield {"text": row}

         return IterableDataset.from_generator(generator)

@@ -92,7 +96,7 @@ class TestPretrainingPacking:
         )
         idx = 0
         for data in trainer_loader:
-            if idx > 10:
+            if idx > 3:
                 break
             assert data["input_ids"].shape == torch.Size(
                 [1, original_bsz * cfg.sequence_len]