* overwrite cache on preprocess step * don't cache the TokenizedPromptDataset at all * load_from_cache_file no longer needed
187 lines
7.0 KiB
Python
187 lines
7.0 KiB
Python
"""Module containing Dataset functionality"""
|
|
|
|
import logging
|
|
import os
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
from datasets import Dataset, IterableDataset
|
|
|
|
from .prompt_tokenizers import PromptTokenizingStrategy
|
|
|
|
# We want this to be a wrapper for an existing dataset that we have loaded.
# Let's use the concept of middleware to wrap each dataset, for example:
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# Let's check to ensure we don't truncate an item in the middle; we'll use
# the collators later on to pad the datasets.
|
|
|
|
# Module-level logger registered under the name "axolotl"; used in this module
# to warn when a packed batch is dropped because of a tensor size mismatch.
LOG = logging.getLogger("axolotl")
|
|
|
|
|
|
class TokenizedPromptDataset(Dataset):
    """
    Dataset that returns tokenized prompts from a stream of text files.

    The wrapped dataset is tokenized eagerly inside ``__init__`` (through
    ``dataset.map``) and the resulting table is handed to the parent
    ``Dataset`` constructor. The mapped result is kept in memory
    (``keep_in_memory=True``).

    Args:
        prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
        dataset (dataset.Dataset): Dataset with text files.
        process_count (int, optional): Number of worker processes to use for
            tokenization, capped at 64. When unset (or 0) the CPU count is
            used instead.
        **kwargs: Forwarded unchanged to the parent ``Dataset`` constructor.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        prompt_tokenizer: PromptTokenizingStrategy,
        dataset: IterableDataset,
        process_count: Optional[int] = None,
        **kwargs,
    ):
        # Both attributes must be set before ``process`` runs, since it reads them.
        self.prompt_tokenizer = prompt_tokenizer
        self.process_count = process_count
        super().__init__(
            self.process(dataset).data,
            **kwargs,
        )

    def process(self, dataset):
        """
        Tokenize every row of ``dataset`` with the configured prompt tokenizer.

        All original columns are removed, so the result only contains what
        ``prompt_tokenizer.tokenize_prompt`` returns. When the tokenizer
        reports ``supports_batched``, rows are mapped in batches of 100.

        Args:
            dataset: The dataset to tokenize; must expose ``features`` and ``map``.

        Returns:
            The mapped dataset as returned by ``dataset.map``.
        """
        features = dataset.features.keys()

        # ``os.cpu_count()`` returns None when the CPU count cannot be
        # determined; fall back to a single process instead of letting
        # ``min(64, None)`` raise a TypeError.
        requested_procs = self.process_count or os.cpu_count() or 1
        num_proc = min(64, requested_procs)

        map_kwargs = {}
        if self.prompt_tokenizer.supports_batched:
            map_kwargs["batched"] = True
            map_kwargs["batch_size"] = 100

        return dataset.map(
            self.prompt_tokenizer.tokenize_prompt,
            num_proc=num_proc,
            remove_columns=features,
            keep_in_memory=True,
            **map_kwargs,
        )
|
|
|
|
|
|
# TODO this isn't the best since it can't interleave datasets
|
|
class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that packs tokenized examples into sequences of at most
    ``seq_length`` tokens.

    Examples are read from each dataset in turn and concatenated (separated by
    the tokenizer's EOS token) until the next example would no longer fit; the
    accumulated pack is then yielded. Packs are never mixed across datasets,
    and no padding is applied here.

    Args:
        tokenizer (Tokenizer): The processor used for processing the data.
            Must provide ``eos_token_id`` and ``get_vocab()``.
        datasets (list): Datasets of already tokenized examples; each example
            is read via its ``input_ids``, ``attention_mask`` and ``labels``
            entries.
        seq_length (int): Maximum length of the token sequences to return.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        tokenizer,
        datasets,
        seq_length=2048,
    ):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.eos_token_id
        self.datasets: List[IterableDataset] = datasets
        self.seq_length = seq_length

        vocab_size = len(tokenizer.get_vocab())

        # Pick the narrowest integer dtype whose maximum covers the vocab size.
        if vocab_size <= torch.iinfo(torch.int16).max:
            self.tokens_dtype = torch.int16
        elif vocab_size <= torch.iinfo(torch.int32).max:
            self.tokens_dtype = torch.int32
        else:
            self.tokens_dtype = torch.int64

    @staticmethod
    def _new_buffer():
        """Return an empty accumulator holding one tensor list per output field."""
        return {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
            "position_ids": [],
        }

    def _pack_buffer(self, buffer):
        """
        Concatenate the buffered tensors into a single sample.

        Every field is truncated to ``seq_length``. Returns the sample dict, or
        ``None`` (after logging a warning) when ``labels`` or
        ``attention_mask`` do not match ``input_ids`` in size.
        """
        input_ids = torch.cat(buffer["input_ids"], dim=-1)[: self.seq_length]
        attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
            : self.seq_length
        ]
        position_ids = torch.cat(buffer["position_ids"], dim=-1)[: self.seq_length]
        labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]

        if labels.size() == input_ids.size() and (
            attention_mask.size() == input_ids.size()
        ):
            return {
                "input_ids": input_ids,
                "labels": labels,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
            }

        LOG.warning(
            f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
        )
        return None

    def __iter__(self):
        """
        Yield packed samples as dicts of 1-D tensors.

        Each yielded dict has the keys ``input_ids``, ``labels``,
        ``attention_mask`` and ``position_ids``. Within a pack, the attention
        mask of every example is multiplied by a running example counter
        (reset to 1 whenever the buffer is flushed), and ``position_ids``
        restart at 0 for every example.

        Examples whose ``input_ids`` are longer than ``seq_length`` are
        dropped, and examples with empty ``input_ids`` are skipped.
        """
        buffer = self._new_buffer()
        buffer_len = 0
        for dataset in self.datasets:
            idx = 0
            iterator = iter(dataset)
            more_examples = True
            while more_examples:
                try:
                    example = next(iterator)
                    idx += 1
                except StopIteration:
                    # End of this dataset: fall through once more with no
                    # example so that the remaining buffer gets flushed.
                    more_examples = False
                    example = None

                add_concat_token = False
                if example:
                    example_len = len(example["input_ids"])
                    if not example_len:
                        # Nothing to pack. Indexing [-1] on an empty sequence
                        # would raise IndexError, and a zero length would be
                        # mistaken for end-of-data below, so skip the example
                        # and undo the counter increment.
                        idx -= 1
                        continue
                    add_concat_token = example["input_ids"][-1] != self.concat_token_id
                else:
                    example_len = 0

                # Flush when the data ran out or the next example won't fit.
                if not example_len or (
                    buffer_len + int(add_concat_token) + example_len > self.seq_length
                ):
                    if buffer["input_ids"]:
                        packed = self._pack_buffer(buffer)
                        if packed is not None:
                            yield packed
                    buffer = self._new_buffer()
                    buffer_len = 0
                    idx = 1

                if example:
                    # FIXME
                    # just going to drop data points that are too long
                    if len(example["input_ids"]) <= self.seq_length:
                        # Copy before appending so the source example is not
                        # mutated as a side effect of iterating.
                        input_ids = list(example["input_ids"])
                        attention_mask = list(example["attention_mask"])
                        labels = list(example["labels"])

                        if add_concat_token:
                            input_ids.append(self.concat_token_id)
                            attention_mask.append(1)
                            labels.append(self.concat_token_id)

                        input_ids_with_concat = torch.tensor(
                            input_ids, dtype=self.tokens_dtype
                        )
                        # Scale the mask by the example counter so each
                        # example in a pack carries its own mask value.
                        attention_mask_with_concat = torch.tensor(
                            [idx * m for m in attention_mask], dtype=torch.int16
                        )
                        labels_with_concat = torch.tensor(
                            labels, dtype=self.tokens_dtype
                        )
                        position_ids = torch.arange(
                            len(input_ids), dtype=self.tokens_dtype
                        )

                        buffer["input_ids"].append(input_ids_with_concat)
                        buffer["attention_mask"].append(attention_mask_with_concat)
                        buffer["labels"].append(labels_with_concat)
                        buffer["position_ids"].append(position_ids)
                        buffer_len += len(input_ids)
|