Merge pull request #293 from NanoCode012/fix/tokenize-speed
Fix(tokenizing): Use multi-core
This commit is contained in:
@@ -1,12 +1,13 @@
|
|||||||
"""Module containing Dataset functionality"""
|
"""Module containing Dataset functionality"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import IterableDataset
|
from datasets import IterableDataset
|
||||||
|
|
||||||
from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
|
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||||
|
|
||||||
# We want this to be a wrapper for an existing dataset that we have loaded
|
# We want this to be a wrapper for an existing dataset that we have loaded
|
||||||
# let's use the concept of middlewares to wrap each dataset, for example
|
# let's use the concept of middlewares to wrap each dataset, for example
|
||||||
@@ -34,17 +35,15 @@ class TokenizedPromptDataset(IterableDataset):
|
|||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
iterator = iter(self.dataset)
|
features = self.dataset.features.keys()
|
||||||
count = 0
|
num_proc = os.cpu_count()
|
||||||
# Loop through the entire dataset
|
return iter(
|
||||||
for example in iterator:
|
self.dataset.map(
|
||||||
try:
|
self.prompt_tokenizer.tokenize_prompt,
|
||||||
yield self.prompt_tokenizer.tokenize_prompt(example)
|
num_proc=num_proc,
|
||||||
count += 1
|
remove_columns=features,
|
||||||
except InvalidDataException:
|
)
|
||||||
pass
|
)
|
||||||
if count == 0:
|
|
||||||
raise RuntimeError("Expected at least one datapoint in dataset.")
|
|
||||||
|
|
||||||
|
|
||||||
# TODO this isn't the best since it can't interleave datasets
|
# TODO this isn't the best since it can't interleave datasets
|
||||||
|
|||||||
Reference in New Issue
Block a user