"""Module for testing streaming dataset sequence packing"""

import functools
import random
import string

import pytest
import torch
from datasets import IterableDataset
from torch.utils.data import DataLoader

from axolotl.utils.data import get_dataset_wrapper, wrap_streaming_dataset
from axolotl.utils.dict import DictDefault


class TestPretrainingPacking:
    """
    Test class for packing streaming dataset sequences
    """

    @pytest.fixture
    def random_text(self):
        # seed with random.seed(0) for reproducibility
        random.seed(0)

        # generate a row of random text with "words" of between 2 and 10 characters
        # and between 400 and 1200 characters per line
        def rand_txt():
            return " ".join(
                [
                    "".join(
                        random.choices(string.ascii_lowercase, k=random.randint(2, 10))
                    )
                    for _ in range(random.randint(50, 200))
                ]
            )

        # Materialize a list of 500 random texts up front rather than generating
        # them inside the generator so the test runs faster
        data = [rand_txt() for _ in range(500)]

        # Create an IterableDataset
        def generator():
            for row in data:
                yield {"text": row}

        return IterableDataset.from_generator(generator)

    @pytest.mark.flaky(retries=1, delay=5)
    def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):
        dataset = random_text

        cfg = DictDefault(
            {
                "pretraining_dataset": [
                    {
                        "path": "winglian/tiny-shakespeare",
                        "type": "pretrain",
                    }
                ],
                "sample_packing": True,
                "pretrain_multipack_attn": True,
                "pad_to_sequence_len": True,
                "sequence_len": 2048,
                "micro_batch_size": 2,
                "sample_packing_group_size": 100000,
                "sample_packing_bin_size": 200,
            }
        )
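
        # With sample_packing enabled, short examples are concatenated into fixed
        # 2048-token sequences; pretrain_multipack_attn is expected to separate the
        # packed samples via position_ids rather than an attention_mask, which is
        # why the assertions below check that "attention_mask" is absent.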

        ds_wrapper_partial = functools.partial(
            get_dataset_wrapper,
            cfg.pretraining_dataset[0],
            tokenizer_huggyllama,
            cfg,
            cfg.pretraining_dataset[0]["type"] or "pretrain",
        )
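
        # The partial above bundles the dataset config, tokenizer, global cfg, and
        # dataset type; wrap_streaming_dataset uses it to build the tokenized,
        # packed wrapper around the raw streaming dataset.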

        original_bsz = cfg.micro_batch_size
        train_dataset = wrap_streaming_dataset(
            dataset,
            tokenizer_huggyllama,
            cfg,
            ds_wrapper_partial,
        )
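
        # batch_size=1 with no collate_fn: the wrapped streaming dataset already
        # yields fully packed, collated batches, so the DataLoader only iterates.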
        trainer_loader = DataLoader(
            train_dataset,
            batch_size=1,
            collate_fn=None,
            drop_last=True,
        )
        idx = 0
        for data in trainer_loader:
            if idx > 3:
                break
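            # Each packed batch is flattened to a single row of
            # micro_batch_size * sequence_len tokens (2 * 2048 = 4096 here).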
            assert data["input_ids"].shape == torch.Size(
                [1, original_bsz * cfg.sequence_len]
            )
            assert data["position_ids"].shape == torch.Size(
                [1, original_bsz * cfg.sequence_len]
            )
            assert data["labels"].shape == torch.Size(
                [1, original_bsz * cfg.sequence_len]
            )
            assert "attention_mask" not in data
            # FIXME add back once we fix packing unpad/pad with attention mask
            # assert data["attention_mask"].shape == torch.Size(
            #     [1, original_bsz * cfg.sequence_len]
            # )
            idx += 1