use offline for precached stream dataset (#2453)

Wing Lian
2025-03-28 23:39:09 -04:00
committed by GitHub
parent e46239f8d3
commit c49682132b
8 changed files with 179 additions and 124 deletions
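The "offline" in the commit title refers to Hugging Face's offline mode: the streaming dataset is now generated locally instead of pulled from the Hub, so the test can run without network access. As a rough illustration only (the environment variables below are standard Hugging Face settings, not something this commit adds), offline mode is typically toggled like this before the libraries are imported:

import os

# Hugging Face Hub and datasets skip network calls when these are set,
# so tests that only touch locally generated or cached data run hermetically.
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"

The enable_hf_offline / disable_hf_offline helpers visible in the diff below presumably wrap this kind of toggle on a per-test basis.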


@@ -1,38 +1,50 @@
"""Module for testing streaming dataset sequence packing"""
import functools
import unittest
import random
import string
import pytest
import torch
from datasets import load_dataset
from datasets import IterableDataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from utils import disable_hf_offline, enable_hf_offline
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
from axolotl.utils.dict import DictDefault


-class TestPretrainingPacking(unittest.TestCase):
+class TestPretrainingPacking:
     """
     Test class for packing streaming dataset sequences
     """

-    @enable_hf_offline
-    def setUp(self) -> None:
-        # pylint: disable=duplicate-code
-        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
-        self.tokenizer.pad_token = "</s>"
+    @pytest.fixture
+    def random_text(self):
+        # seed with random.seed(0) for reproducibility
+        random.seed(0)
+
+        # generate 20 rows of random text with "words" of between 2 and 10 characters and
+        # between 400 to 1200 characters per line
+        data = [
+            "".join(random.choices(string.ascii_lowercase, k=random.randint(2, 10)))
+            for _ in range(20)
+        ] + [
+            " ".join(
+                random.choices(string.ascii_lowercase, k=random.randint(400, 1200))
+            )
+            for _ in range(20)
+        ]
+
+        # Create an IterableDataset
+        def generator():
+            for text in data:
+                yield {"text": text}
+
+        return IterableDataset.from_generator(generator)

-    @pytest.mark.flaky(retries=1, delay=5)
-    @disable_hf_offline
-    def test_packing_stream_dataset(self):
-        # pylint: disable=duplicate-code
-        dataset = load_dataset(
-            "winglian/tiny-shakespeare",
-            streaming=True,
-        )["train"]
+    def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):
+        dataset = random_text

         cfg = DictDefault(
             {
@@ -55,15 +67,16 @@ class TestPretrainingPacking(unittest.TestCase):
         ds_wrapper_partial = functools.partial(
             get_dataset_wrapper,
             cfg.pretraining_dataset[0],
-            self.tokenizer,
+            tokenizer_huggyllama,
             cfg,
             cfg.pretraining_dataset[0]["type"] or "pretrain",
         )
         # pylint: disable=duplicate-code
         original_bsz = cfg.micro_batch_size
         train_dataset = wrap_pretraining_dataset(
             dataset,
-            self.tokenizer,
+            tokenizer_huggyllama,
             cfg,
             ds_wrapper_partial,
             max_tokens=cfg.sequence_len,
@@ -96,7 +109,3 @@ class TestPretrainingPacking(unittest.TestCase):
             # [1, original_bsz * cfg.sequence_len]
             # )
             idx += 1
-
-
-if __name__ == "__main__":
-    unittest.main()
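The rewritten test takes a tokenizer_huggyllama fixture instead of building the tokenizer in setUp(); that fixture is defined elsewhere in the test suite and is not part of this file. A minimal sketch of what such a fixture could look like, assuming it simply reproduces the removed setUp() logic:

# conftest.py (hypothetical sketch, not the actual fixture from this commit)
import pytest
from transformers import AutoTokenizer

@pytest.fixture
def tokenizer_huggyllama():
    # mirror the removed setUp(): load the LLaMA tokenizer and set an explicit pad token
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    tokenizer.pad_token = "</s>"
    return tokenizer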