Merge pull request #293 from NanoCode012/fix/tokenize-speed
Fix(tokenizing): Use multi-core
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
"""Module containing Dataset functionality"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from datasets import IterableDataset
|
||||
|
||||
from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
|
||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||
|
||||
# We want this to be a wrapper for an existing dataset that we have loaded
|
||||
# lets use the concept of middlewares to wrap each dataset, for example
|
||||
@@ -34,17 +35,15 @@ class TokenizedPromptDataset(IterableDataset):
|
||||
self.dataset = dataset
|
||||
|
||||
def __iter__(self):
|
||||
iterator = iter(self.dataset)
|
||||
count = 0
|
||||
# Loop through the entire dataset
|
||||
for example in iterator:
|
||||
try:
|
||||
yield self.prompt_tokenizer.tokenize_prompt(example)
|
||||
count += 1
|
||||
except InvalidDataException:
|
||||
pass
|
||||
if count == 0:
|
||||
raise RuntimeError("Expected at least one datapoint in dataset.")
|
||||
features = self.dataset.features.keys()
|
||||
num_proc = os.cpu_count()
|
||||
return iter(
|
||||
self.dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# TODO this isn't the best since it can't interleave datasets
|
||||
|
||||
Reference in New Issue
Block a user