fix streaming packing test (#2454)
* fix streaming packing test * constrain amount of text generated
This commit is contained in:
@@ -23,22 +23,26 @@ class TestPretrainingPacking:
|
||||
# seed with random.seed(0) for reproducibility
|
||||
random.seed(0)
|
||||
|
||||
# generate 20 rows of random text with "words" of between 2 and 10 characters and
|
||||
# generate a row of random text with "words" of between 2 and 10 characters and
|
||||
# between 50 and 200 words per row
|
||||
data = [
|
||||
"".join(random.choices(string.ascii_lowercase, k=random.randint(2, 10)))
|
||||
for _ in range(20)
|
||||
] + [
|
||||
" ".join(
|
||||
random.choices(string.ascii_lowercase, k=random.randint(400, 1200))
|
||||
def rand_txt():
|
||||
return " ".join(
|
||||
[
|
||||
"".join(
|
||||
random.choices(string.ascii_lowercase, k=random.randint(2, 10))
|
||||
)
|
||||
for _ in range(random.randint(50, 200))
|
||||
]
|
||||
)
|
||||
for _ in range(20)
|
||||
]
|
||||
|
||||
# Create a list of 500 random texts rather than just using it within the
|
||||
# generator so the test runs faster
|
||||
data = [rand_txt() for _ in range(500)]
|
||||
|
||||
# Create an IterableDataset
|
||||
def generator():
|
||||
for text in data:
|
||||
yield {"text": text}
|
||||
for row in data:
|
||||
yield {"text": row}
|
||||
|
||||
return IterableDataset.from_generator(generator)
|
||||
|
||||
@@ -92,7 +96,7 @@ class TestPretrainingPacking:
|
||||
)
|
||||
idx = 0
|
||||
for data in trainer_loader:
|
||||
if idx > 10:
|
||||
if idx > 3:
|
||||
break
|
||||
assert data["input_ids"].shape == torch.Size(
|
||||
[1, original_bsz * cfg.sequence_len]
|
||||
|
||||
Reference in New Issue
Block a user