fix streaming packing test (#2454)
* fix streaming packing test * constrain amount of text generated
This commit is contained in:
@@ -23,22 +23,26 @@ class TestPretrainingPacking:
|
|||||||
# seed with random.seed(0) for reproducibility
|
# seed with random.seed(0) for reproducibility
|
||||||
random.seed(0)
|
random.seed(0)
|
||||||
|
|
||||||
# generate 20 rows of random text with "words" of between 2 and 10 characters and
|
# generate a row of random text with "words" of between 2 and 10 characters and
|
||||||
# between 400 to 1200 characters per line
|
# between 50 and 200 words per line
|
||||||
data = [
|
def rand_txt():
|
||||||
"".join(random.choices(string.ascii_lowercase, k=random.randint(2, 10)))
|
return " ".join(
|
||||||
for _ in range(20)
|
[
|
||||||
] + [
|
"".join(
|
||||||
" ".join(
|
random.choices(string.ascii_lowercase, k=random.randint(2, 10))
|
||||||
random.choices(string.ascii_lowercase, k=random.randint(400, 1200))
|
)
|
||||||
|
for _ in range(random.randint(50, 200))
|
||||||
|
]
|
||||||
)
|
)
|
||||||
for _ in range(20)
|
|
||||||
]
|
# Create a list of 500 random texts rather than just using it within the
|
||||||
|
# generator so the test runs faster
|
||||||
|
data = [rand_txt() for _ in range(500)]
|
||||||
|
|
||||||
# Create an IterableDataset
|
# Create an IterableDataset
|
||||||
def generator():
|
def generator():
|
||||||
for text in data:
|
for row in data:
|
||||||
yield {"text": text}
|
yield {"text": row}
|
||||||
|
|
||||||
return IterableDataset.from_generator(generator)
|
return IterableDataset.from_generator(generator)
|
||||||
|
|
||||||
@@ -92,7 +96,7 @@ class TestPretrainingPacking:
|
|||||||
)
|
)
|
||||||
idx = 0
|
idx = 0
|
||||||
for data in trainer_loader:
|
for data in trainer_loader:
|
||||||
if idx > 10:
|
if idx > 3:
|
||||||
break
|
break
|
||||||
assert data["input_ids"].shape == torch.Size(
|
assert data["input_ids"].shape == torch.Size(
|
||||||
[1, original_bsz * cfg.sequence_len]
|
[1, original_bsz * cfg.sequence_len]
|
||||||
|
|||||||
Reference in New Issue
Block a user