Threaded MultipackDistributedDataloader with prefetched samples (#759)

* Multithreading implementation [WIP]

* Added benchmarking

* 35% increased throughput

* Memory pinning

* Start threads in init

* Correct print of samples

* Sleep if queue is full

* Remove pin_memory (worse)

* Simplify logic to one thread

* Remove benchmark

* Use deque for constant speed

* Formatting

* Formatting

* Formatting

* Formatting

* Rollback to use queue

* Fix multi-epoch training

* Add num epochs arg

* Start thread in __iter__

* Formatting

* Use is_alive correctly

* Simplify loading thread
This commit is contained in:
Casper
2023-10-26 07:49:52 +02:00
committed by GitHub
parent 20aa4b57d2
commit 05bd6f1122
3 changed files with 51 additions and 6 deletions

View File

@@ -111,7 +111,8 @@ class AxolotlTrainer(Trainer):
args = None # type: AxolotlTrainingArguments args = None # type: AxolotlTrainingArguments
def __init__(self, *args, bench_data_collator=None, **kwargs): def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator self.bench_data_collator = bench_data_collator
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@@ -182,6 +183,7 @@ class AxolotlTrainer(Trainer):
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)), device_count=int(os.environ.get("WORLD_SIZE", 1)),
num_epochs=self.num_epochs,
) )
) )
return super().get_train_dataloader() return super().get_train_dataloader()
@@ -205,6 +207,7 @@ class AxolotlTrainer(Trainer):
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.eval_batch_size, sample_packing_seq_len_multiplier=self.args.eval_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)), device_count=int(os.environ.get("WORLD_SIZE", 1)),
num_epochs=self.num_epochs,
) )
) )
return super().get_eval_dataloader(eval_dataset) return super().get_eval_dataloader(eval_dataset)
@@ -680,6 +683,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**data_collator_kwargs, **data_collator_kwargs,
), ),
callbacks=self.get_callbacks(), callbacks=self.get_callbacks(),
num_epochs=self.cfg.num_epochs,
**trainer_kwargs, **trainer_kwargs,
) )
trainer = self.hook_post_create_trainer(trainer) trainer = self.hook_post_create_trainer(trainer)

View File

@@ -3,6 +3,9 @@ import hashlib
import itertools import itertools
import logging import logging
import math import math
import time
from queue import Queue
from threading import Thread
from typing import Any, Callable, List, Union from typing import Any, Callable, List, Union
import numba import numba
@@ -149,6 +152,8 @@ class MultipackDistributedDataloader:
packing_efficiency_estimate: float = 1.0, packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1, sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1, device_count: int = 1,
prefetch_max: int = 1000,
num_epochs: int = 1,
): ):
# Dataset # Dataset
self.dataset = dataset self.dataset = dataset
@@ -167,6 +172,7 @@ class MultipackDistributedDataloader:
self.seq_max_length = seq_max_length self.seq_max_length = seq_max_length
self.batch_max_length = batch_size * seq_max_length self.batch_max_length = batch_size * seq_max_length
self.collate_fn = collate_fn self.collate_fn = collate_fn
self.num_epochs = num_epochs
self.num_replicas = 1 self.num_replicas = 1
self.rank = 0 self.rank = 0
@@ -177,6 +183,44 @@ class MultipackDistributedDataloader:
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count self.device_count = device_count
# maxsize is maximum number of samples in queue
self.prefetch_max = prefetch_max
self.queue: Queue = Queue(maxsize=prefetch_max)
self.thread = None
def _worker(self):
    """Producer loop executed on the background prefetch thread.

    For each of ``self.num_epochs`` epochs, every batch yielded by
    ``self._internal_batch_generator()`` is pushed onto ``self.queue``,
    followed by a single ``None`` sentinel that marks the end of that epoch
    for the consumer in ``__iter__``.
    """
    LOG.info(
        f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
    )
    for _ in range(self.num_epochs):
        for sample in self._internal_batch_generator():
            # Queue.put() already blocks while a bounded queue is full and
            # wakes as soon as a slot frees up, so no full()/sleep polling
            # is needed (polling only added up to 1s of latency per stall).
            self.queue.put(sample)

        # stop the queue when epoch is done
        self.queue.put(None)
    # NOTE(review): if the batch generator raises, no sentinel is enqueued
    # and a consumer blocked in queue.get() will wait indefinitely --
    # consider propagating the exception to the consumer.
def __iter__(self):
    """Yield the prefetched batches of one epoch.

    Advances the sampler epoch when the sampler exposes ``set_epoch``,
    lazily starts the background worker thread, then drains ``self.queue``
    until the ``None`` end-of-epoch sentinel is received.

    The worker emits exactly ``self.num_epochs`` sentinels before it exits.
    Once that many epochs have been fully consumed, ``self.thread`` is reset
    so that a later iteration starts a fresh worker instead of blocking
    forever on an empty queue.
    """
    if hasattr(self.sampler, "set_epoch"):
        new_epoch = self.sampler.epoch + 1
        self.sampler.set_epoch(new_epoch)
        LOG.info(f"calling sampler.set_epoch({new_epoch})")

    if self.thread is None:
        # Number of epochs fully consumed from the current worker.
        self._epochs_served = 0
        self.thread = Thread(target=self._worker, daemon=True)
        self.thread.start()

    while True:
        item = self.queue.get()

        if item is None:
            break

        yield item

    # Reached only when the epoch was consumed up to its sentinel.
    self._epochs_served += 1
    if self._epochs_served >= self.num_epochs:
        # The worker has nothing left to produce; allow a restart next time.
        self.thread = None
def generate_batches(self, set_stats=False): def generate_batches(self, set_stats=False):
LOG.info("generating packed batches") LOG.info("generating packed batches")
if self.sampler: if self.sampler:
@@ -206,11 +250,7 @@ class MultipackDistributedDataloader:
return batches, totseqs return batches, totseqs
def __iter__(self): def _internal_batch_generator(self):
if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})")
all_batches, _ = self.generate_batches(set_stats=True) all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys() features = self.dataset.features.keys()
len_remaining = self._len_est() len_remaining = self._len_est()

View File

@@ -216,6 +216,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
packing_efficiency_estimate=cfg.sample_packing_eff_est, packing_efficiency_estimate=cfg.sample_packing_eff_est,
sample_packing_seq_len_multiplier=cfg.micro_batch_size, sample_packing_seq_len_multiplier=cfg.micro_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)), device_count=int(os.environ.get("WORLD_SIZE", 1)),
num_epochs=cfg.num_epochs,
) )
data_loader_len = data_loader.len_w_stats() data_loader_len = data_loader.len_w_stats()
actual_eff = data_loader.efficiency() actual_eff = data_loader.efficiency()