experimental expansion of ctx len

Wing Lian
2023-05-31 16:51:19 -04:00
parent 71a43f8479
commit 488a67d75a
2 changed files with 57 additions and 19 deletions


@@ -5,7 +5,8 @@ from hashlib import md5
 from pathlib import Path
 from typing import List, Tuple, Union
-from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
+import torch
+from datasets import Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
@@ -392,3 +393,32 @@ def load_prepare_datasets(
 eval_dataset = dataset["test"]
 return train_dataset, eval_dataset
+
+
+class PretrainingDatasetWrapper(IterableDataset):
+    """
+    Wrapper for pretraining dataset that avoids loading the dataset into memory
+    """
+
+    def __init__(self, tokenizer, dataset_path, max_tokens=2048):
+        self.tokenizer = tokenizer
+        self.dataset_path = dataset_path
+        self.max_tokens = max_tokens
+
+    def __iter__(self):
+        buffer = []
+        for sample in load_dataset(
+            self.dataset_path,
+            name="all",
+            split="train",
+            streaming=True,
+        ).shuffle(buffer_size=10000):
+            buffer += self.tokenizer(sample["text"])["input_ids"]
+            buffer += [self.tokenizer.eos_token_id]
+            while len(buffer) > self.max_tokens:
+                yield torch.tensor(buffer[: self.max_tokens])
+                buffer = buffer[self.max_tokens :]
+
+
+def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
+    return PretrainingDatasetWrapper(tokenizer, path, max_tokens=max_tokens)
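
For context, a minimal usage sketch of the new helper. The module path, dataset name, tokenizer, and max_tokens value below are illustrative assumptions, not part of this commit; the hardcoded name="all" in __iter__ implies a streaming dataset that exposes an "all" config and a "text" column.

    # Usage sketch (not part of the commit): stream a pretraining corpus and
    # yield fixed-length token blocks at the experimental context length.
    from transformers import AutoTokenizer

    from axolotl.utils.data import load_pretraining_dataset  # assumed module path

    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # assumed tokenizer
    dataset = load_pretraining_dataset(
        "EleutherAI/pile",  # assumed: any streamable dataset with an "all" config and "text" field
        tokenizer,
        max_tokens=4096,  # the expanded context length being tested
    )

    for i, chunk in enumerate(dataset):
        # Each chunk is a 1-D LongTensor of exactly max_tokens ids; documents are
        # packed back-to-back, separated by the tokenizer's EOS token.
        print(chunk.shape)  # torch.Size([4096])
        if i == 2:
            break

If the installed datasets release registers torch.utils.data.IterableDataset as a parent class when torch is available (recent versions do), the wrapper can also be passed directly to a torch DataLoader, which stacks chunks into (batch_size, max_tokens) batches.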