fix streaming packing test (#2454)

* fix streaming packing test

* constrain amount of text generated
This commit is contained in:
Wing Lian
2025-03-29 08:30:06 -04:00
committed by GitHub
parent c49682132b
commit 4ba80a0e5a

View File

@@ -23,22 +23,26 @@ class TestPretrainingPacking:
# seed with random.seed(0) for reproducibility # seed with random.seed(0) for reproducibility
random.seed(0) random.seed(0)
# generate 20 rows of random text with "words" of between 2 and 10 characters and # generate row of random text with "words" of between 2 and 10 characters and
# between 400 to 1200 characters per line # between 400 to 1200 characters per line
data = [ def rand_txt():
"".join(random.choices(string.ascii_lowercase, k=random.randint(2, 10))) return " ".join(
for _ in range(20) [
] + [ "".join(
" ".join( random.choices(string.ascii_lowercase, k=random.randint(2, 10))
random.choices(string.ascii_lowercase, k=random.randint(400, 1200)) )
for _ in range(random.randint(50, 200))
]
) )
for _ in range(20)
] # Create a list of 2000 random texts rather than just using it within the
# generator so the test runs faster
data = [rand_txt() for _ in range(500)]
# Create an IterableDataset # Create an IterableDataset
def generator(): def generator():
for text in data: for row in data:
yield {"text": text} yield {"text": row}
return IterableDataset.from_generator(generator) return IterableDataset.from_generator(generator)
@@ -92,7 +96,7 @@ class TestPretrainingPacking:
) )
idx = 0 idx = 0
for data in trainer_loader: for data in trainer_loader:
if idx > 10: if idx > 3:
break break
assert data["input_ids"].shape == torch.Size( assert data["input_ids"].shape == torch.Size(
[1, original_bsz * cfg.sequence_len] [1, original_bsz * cfg.sequence_len]