Merge pull request #293 from NanoCode012/fix/tokenize-speed

Fix(tokenizing): Use multi-core
This commit is contained in:
NanoCode012
2023-07-19 11:02:04 +09:00
committed by GitHub

View File

@@ -1,12 +1,13 @@
 """Module containing Dataset functionality"""

 import logging
+import os
 from typing import List

 import torch
 from datasets import IterableDataset

-from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
+from .prompt_tokenizers import PromptTokenizingStrategy

 # We want this to be a wrapper for an existing dataset that we have loaded
 # lets use the concept of middlewares to wrap each dataset, for example
@@ -34,17 +35,15 @@ class TokenizedPromptDataset(IterableDataset):
         self.dataset = dataset

     def __iter__(self):
-        iterator = iter(self.dataset)
-        count = 0
-        # Loop through the entire dataset
-        for example in iterator:
-            try:
-                yield self.prompt_tokenizer.tokenize_prompt(example)
-                count += 1
-            except InvalidDataException:
-                pass
-        if count == 0:
-            raise RuntimeError("Expected at least one datapoint in dataset.")
+        features = self.dataset.features.keys()
+        num_proc = os.cpu_count()
+        return iter(
+            self.dataset.map(
+                self.prompt_tokenizer.tokenize_prompt,
+                num_proc=num_proc,
+                remove_columns=features,
+            )
+        )

 # TODO this isn't the best since it can't interleave datasets