use offline for precached stream dataset (#2453)
@@ -1,38 +1,50 @@
 """Module for testing streaming dataset sequence packing"""
 
 import functools
-import unittest
+import random
+import string
 
 import pytest
 import torch
-from datasets import load_dataset
+from datasets import IterableDataset
 from torch.utils.data import DataLoader
 from transformers import AutoTokenizer
-from utils import disable_hf_offline, enable_hf_offline
 
 from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
 from axolotl.utils.dict import DictDefault
 
 
-class TestPretrainingPacking(unittest.TestCase):
+class TestPretrainingPacking:
     """
     Test class for packing streaming dataset sequences
     """
 
-    @enable_hf_offline
-    def setUp(self) -> None:
-        # pylint: disable=duplicate-code
-        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
-        self.tokenizer.pad_token = "</s>"
+    @pytest.fixture
+    def random_text(self):
+        # seed with random.seed(0) for reproducibility
+        random.seed(0)
+
+        # generate 20 rows of random text with "words" of between 2 and 10 characters and
+        # between 400 to 1200 characters per line
+        data = [
+            "".join(random.choices(string.ascii_lowercase, k=random.randint(2, 10)))
+            for _ in range(20)
+        ] + [
+            " ".join(
+                random.choices(string.ascii_lowercase, k=random.randint(400, 1200))
+            )
+            for _ in range(20)
+        ]
+
+        # Create an IterableDataset
+        def generator():
+            for text in data:
+                yield {"text": text}
+
+        return IterableDataset.from_generator(generator)
 
-    @pytest.mark.flaky(retries=1, delay=5)
-    @disable_hf_offline
-    def test_packing_stream_dataset(self):
-        # pylint: disable=duplicate-code
-        dataset = load_dataset(
-            "winglian/tiny-shakespeare",
-            streaming=True,
-        )["train"]
+    def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):
+        dataset = random_text
 
         cfg = DictDefault(
             {
@@ -55,15 +67,16 @@ class TestPretrainingPacking(unittest.TestCase):
         ds_wrapper_partial = functools.partial(
             get_dataset_wrapper,
             cfg.pretraining_dataset[0],
-            self.tokenizer,
+            tokenizer_huggyllama,
             cfg,
             cfg.pretraining_dataset[0]["type"] or "pretrain",
         )
 
+        # pylint: disable=duplicate-code
         original_bsz = cfg.micro_batch_size
         train_dataset = wrap_pretraining_dataset(
             dataset,
-            self.tokenizer,
+            tokenizer_huggyllama,
             cfg,
             ds_wrapper_partial,
             max_tokens=cfg.sequence_len,
@@ -96,7 +109,3 @@ class TestPretrainingPacking(unittest.TestCase):
             # [1, original_bsz * cfg.sequence_len]
             # )
             idx += 1
-
-
-if __name__ == "__main__":
-    unittest.main()
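
The dropped `from utils import disable_hf_offline, enable_hf_offline` import is the heart of the change: the old test flipped Hugging Face offline mode off so it could stream `winglian/tiny-shakespeare` over the network (hence the `@pytest.mark.flaky` retry), while the new test never touches the network. The helpers themselves live in the test suite's `utils` module and are not shown in this diff; what follows is only a minimal sketch of how such a decorator pair is commonly implemented, assuming it just toggles the `HF_HUB_OFFLINE` environment variable around the wrapped callable.

# Hypothetical sketch -- the real enable_hf_offline/disable_hf_offline live in
# the test suite's utils module, outside this diff.
import functools
import os


def _hf_offline(value):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            previous = os.environ.get("HF_HUB_OFFLINE")
            os.environ["HF_HUB_OFFLINE"] = value
            try:
                return func(*args, **kwargs)
            finally:
                # restore the prior setting so other tests are unaffected
                if previous is None:
                    os.environ.pop("HF_HUB_OFFLINE", None)
                else:
                    os.environ["HF_HUB_OFFLINE"] = previous

        return wrapper

    return decorator


enable_hf_offline = _hf_offline("1")   # force cache-only lookups
disable_hf_offline = _hf_offline("0")  # allow network access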
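
The test also trades the per-test `setUp` tokenizer for a `tokenizer_huggyllama` fixture argument. That fixture is defined outside this diff, presumably in the suite's conftest.py; a plausible sketch, assuming a session-scoped fixture that builds the same tokenizer the removed `setUp` did from the local cache:

# Hypothetical conftest.py sketch -- the real fixture is not part of this diff.
import pytest
from transformers import AutoTokenizer


@pytest.fixture(scope="session")
def tokenizer_huggyllama():
    # mirrors the removed setUp(): load once from the precached model
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    tokenizer.pad_token = "</s>"
    return tokenizer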
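
Replacing the streamed `winglian/tiny-shakespeare` split with `IterableDataset.from_generator` keeps the test on the same streaming code path (`wrap_pretraining_dataset` still receives an iterable dataset) while removing the download entirely. A quick standalone check of the pattern, with toy strings standing in for the fixture's seeded random text:

from datasets import IterableDataset


def generator():
    # same {"text": ...} row shape the random_text fixture yields
    for text in ["hello world", "foo bar baz"]:
        yield {"text": text}


ds = IterableDataset.from_generator(generator)
for row in ds:
    print(row["text"])  # iterates lazily; no network or HF cache required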