axolotl/tests/test_packed_pretraining.py
Dan Saunders 231a67e70b Streaming SFT support (#3101)
2025-09-02 12:08:44 -04:00

112 lines
3.4 KiB
Python

"""Module for testing streaming dataset sequence packing"""
import functools
import random
import string
import pytest
import torch
from datasets import IterableDataset
from torch.utils.data import DataLoader
from axolotl.utils.data import get_dataset_wrapper, wrap_streaming_dataset
from axolotl.utils.dict import DictDefault
class TestPretrainingPacking:
"""
Test class for packing streaming dataset sequences
"""
@pytest.fixture
    def random_text(self):
        # seed with random.seed(0) for reproducibility
        random.seed(0)

        # generate a row of random text made of "words" of 2-10 lowercase
        # characters, with 50-200 words per row
        def rand_txt():
            return " ".join(
                [
                    "".join(
                        random.choices(string.ascii_lowercase, k=random.randint(2, 10))
                    )
                    for _ in range(random.randint(50, 200))
                ]
            )

        # Create a list of 500 random texts up front rather than generating them
        # inside the generator so the test runs faster
        data = [rand_txt() for _ in range(500)]

        # Create an IterableDataset that streams the rows
        def generator():
            for row in data:
                yield {"text": row}

        return IterableDataset.from_generator(generator)

    @pytest.mark.flaky(retries=1, delay=5)
    def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):
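        # Streaming pretraining config: pack samples into 2048-token sequences
        # with a micro batch size of 2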
        dataset = random_text
        cfg = DictDefault(
            {
                "pretraining_dataset": [
                    {
                        "path": "winglian/tiny-shakespeare",
                        "type": "pretrain",
                    }
                ],
                "sample_packing": True,
                "pretrain_multipack_attn": True,
                "pad_to_sequence_len": True,
                "sequence_len": 2048,
                "micro_batch_size": 2,
                "sample_packing_group_size": 100000,
                "sample_packing_bin_size": 200,
            }
        )
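        # Partial used by the streaming wrapper below to tokenize and wrap the
        # raw rows according to the pretraining dataset config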
        ds_wrapper_partial = functools.partial(
            get_dataset_wrapper,
            cfg.pretraining_dataset[0],
            tokenizer_huggyllama,
            cfg,
            cfg.pretraining_dataset[0]["type"] or "pretrain",
        )
        original_bsz = cfg.micro_batch_size
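        # Wrap the streaming dataset so tokenized rows are packed into
        # fixed-length sequences of `sequence_len` tokens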
        train_dataset = wrap_streaming_dataset(
            dataset,
            tokenizer_huggyllama,
            cfg,
            ds_wrapper_partial,
        )
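        # batch_size=1 because each item from the packed dataset is already a
        # pre-collated micro-batch (see the shape assertions below)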
        trainer_loader = DataLoader(
            train_dataset,
            batch_size=1,
            collate_fn=None,
            drop_last=True,
        )
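        # Inspect the first few packed batches: every tensor should be flattened
        # to [1, micro_batch_size * sequence_len], and no attention_mask should
        # be present (see FIXME below)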
        idx = 0
        for data in trainer_loader:
            if idx > 3:
                break
            assert data["input_ids"].shape == torch.Size(
                [1, original_bsz * cfg.sequence_len]
            )
            assert data["position_ids"].shape == torch.Size(
                [1, original_bsz * cfg.sequence_len]
            )
            assert data["labels"].shape == torch.Size(
                [1, original_bsz * cfg.sequence_len]
            )
            assert "attention_mask" not in data
            # FIXME add back once we fix packing unpad/pad with attention mask
            # assert data["attention_mask"].shape == torch.Size(
            #     [1, original_bsz * cfg.sequence_len]
            # )
            idx += 1