fix streaming packing test (#2454)

* fix streaming packing test

* constrain amount of text generated
This commit is contained in:
Wing Lian
2025-03-29 08:30:06 -04:00
committed by GitHub
parent c49682132b
commit 4ba80a0e5a

View File

@@ -23,22 +23,26 @@ class TestPretrainingPacking:
# seed with random.seed(0) for reproducibility
random.seed(0)
# generate 20 rows of random text with "words" of between 2 and 10 characters and
# generate rows of random text with "words" of between 2 and 10 characters and
# between 400 and 1200 characters per line
data = [
"".join(random.choices(string.ascii_lowercase, k=random.randint(2, 10)))
for _ in range(20)
] + [
" ".join(
random.choices(string.ascii_lowercase, k=random.randint(400, 1200))
def rand_txt():
return " ".join(
[
"".join(
random.choices(string.ascii_lowercase, k=random.randint(2, 10))
)
for _ in range(random.randint(50, 200))
]
)
for _ in range(20)
]
# Create a list of 500 random texts up front rather than generating them
# lazily inside the generator, so the test runs faster
data = [rand_txt() for _ in range(500)]
# Create an IterableDataset
def generator():
for text in data:
yield {"text": text}
for row in data:
yield {"text": row}
return IterableDataset.from_generator(generator)
@@ -92,7 +96,7 @@ class TestPretrainingPacking:
)
idx = 0
for data in trainer_loader:
if idx > 10:
if idx > 3:
break
assert data["input_ids"].shape == torch.Size(
[1, original_bsz * cfg.sequence_len]