make it work with pythia in the cloud
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import List
|
||||
|
||||
import torch
|
||||
from datasets import IterableDataset
|
||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||
from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
|
||||
|
||||
|
||||
# We want this to be a wrapper for an existing dataset that we have loaded
|
||||
@@ -23,7 +23,12 @@ class TokenizedPromptDataset(IterableDataset):
|
||||
|
||||
def __iter__(self):
|
||||
iterator = iter(self.dataset)
|
||||
yield self.prompt_tokenizer.tokenize_prompt(next(iterator))
|
||||
# Loop through the entire dataset
|
||||
for example in iterator:
|
||||
try:
|
||||
yield self.prompt_tokenizer.tokenize_prompt(example)
|
||||
except InvalidDataException:
|
||||
pass
|
||||
|
||||
|
||||
class ConstantLengthDataset(IterableDataset):
|
||||
@@ -32,55 +37,68 @@ class ConstantLengthDataset(IterableDataset):
|
||||
Args:
|
||||
tokenizer (Tokenizer): The processor used for proccessing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
infinite (bool): If True the iterator is reset after dataset reaches end else stops.
|
||||
seq_length (int): Length of token sequences to return.
|
||||
chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
datasets,
|
||||
infinite=False,
|
||||
seq_length=2048,
|
||||
num_of_sequences=1024,
|
||||
chars_per_token=3.6,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else args.eos_token_id
|
||||
self.concat_token_id = tokenizer.eos_token_id
|
||||
self.datasets: List[IterableDataset] = datasets
|
||||
self.seq_length = seq_length
|
||||
self.infinite = infinite
|
||||
self.current_size = 0
|
||||
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
|
||||
|
||||
def __iter__(self):
|
||||
iterator = iter(self.datasets)
|
||||
more_examples = True
|
||||
while more_examples:
|
||||
buffer, buffer_len = [], 0
|
||||
while True:
|
||||
if buffer_len >= self.max_buffer_size:
|
||||
break
|
||||
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
buffer_len = 0
|
||||
for dataset in self.datasets:
|
||||
iterator = iter(dataset)
|
||||
more_examples = True
|
||||
while more_examples:
|
||||
try:
|
||||
buffer.append(next(iterator))
|
||||
buffer_len += len(buffer[-1])
|
||||
example = next(iterator)
|
||||
except StopIteration:
|
||||
if self.infinite:
|
||||
iterator = iter(self.datasets)
|
||||
else:
|
||||
more_examples = False
|
||||
break
|
||||
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
|
||||
all_token_ids = []
|
||||
for tokenized_input in tokenized_inputs:
|
||||
all_token_ids.extend(tokenized_input + [self.concat_token_id])
|
||||
for i in range(0, len(all_token_ids), self.seq_length):
|
||||
input_ids = all_token_ids[i : i + self.seq_length]
|
||||
if len(input_ids) == self.seq_length:
|
||||
self.current_size += 1
|
||||
yield {
|
||||
"input_ids": torch.LongTensor(input_ids),
|
||||
"labels": torch.LongTensor(input_ids),
|
||||
"attention_masks": torch.LongTensor(input_ids),
|
||||
}
|
||||
more_examples = False
|
||||
example = None
|
||||
|
||||
add_concat_token = False
|
||||
if example:
|
||||
example_len = len(example["input_ids"])
|
||||
add_concat_token = example["input_ids"][-1] != self.concat_token_id
|
||||
else:
|
||||
example_len = 0
|
||||
|
||||
if not example_len or buffer_len + int(add_concat_token) + example_len > self.seq_length:
|
||||
if buffer["input_ids"]:
|
||||
input_ids = torch.cat(buffer["input_ids"], dim=-1)[: self.seq_length]
|
||||
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[: self.seq_length]
|
||||
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
||||
yield {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
buffer_len = 0
|
||||
|
||||
if example:
|
||||
input_ids = example["input_ids"]
|
||||
attention_mask = example["attention_mask"]
|
||||
labels = example["labels"]
|
||||
|
||||
if add_concat_token:
|
||||
input_ids.append(self.concat_token_id)
|
||||
attention_mask.append(1)
|
||||
labels.append(self.concat_token_id)
|
||||
|
||||
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
|
||||
attention_mask_with_concat = torch.tensor(attention_mask, dtype=torch.long)
|
||||
labels_with_concat = torch.tensor(labels, dtype=torch.long)
|
||||
|
||||
buffer["input_ids"].append(input_ids_with_concat)
|
||||
buffer["attention_mask"].append(attention_mask_with_concat)
|
||||
buffer["labels"].append(labels_with_concat)
|
||||
buffer_len += len(input_ids)
|
||||
|
||||
Reference in New Issue
Block a user