Threaded MultipackDistributedDataloader with prefetched samples (#759)
* Multithreading implementation [WIP] * Added benchmarking * 35% increased throughput * Memory pinning * Start threads in init * Correct print of samples * Sleep if queue is full * Remove pin_memory (worse) * Simplify logic to one thread * Remove benchmark * Use deque for constant speed * Formatting * Formatting * Formatting * Formatting * Rollback to use queue * Fix multi-epoch training * Add num epochs arg * Start thread in __iter__ * Formatting * Use is_alive correctly * Simplify loading thread
This commit is contained in:
@@ -111,7 +111,8 @@ class AxolotlTrainer(Trainer):
|
|||||||
|
|
||||||
args = None # type: AxolotlTrainingArguments
|
args = None # type: AxolotlTrainingArguments
|
||||||
|
|
||||||
def __init__(self, *args, bench_data_collator=None, **kwargs):
|
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
||||||
|
self.num_epochs = num_epochs
|
||||||
self.bench_data_collator = bench_data_collator
|
self.bench_data_collator = bench_data_collator
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
@@ -182,6 +183,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
|
num_epochs=self.num_epochs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return super().get_train_dataloader()
|
return super().get_train_dataloader()
|
||||||
@@ -205,6 +207,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
||||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
|
num_epochs=self.num_epochs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
@@ -680,6 +683,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
**data_collator_kwargs,
|
**data_collator_kwargs,
|
||||||
),
|
),
|
||||||
callbacks=self.get_callbacks(),
|
callbacks=self.get_callbacks(),
|
||||||
|
num_epochs=self.cfg.num_epochs,
|
||||||
**trainer_kwargs,
|
**trainer_kwargs,
|
||||||
)
|
)
|
||||||
trainer = self.hook_post_create_trainer(trainer)
|
trainer = self.hook_post_create_trainer(trainer)
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ import hashlib
|
|||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import time
|
||||||
|
from queue import Queue
|
||||||
|
from threading import Thread
|
||||||
from typing import Any, Callable, List, Union
|
from typing import Any, Callable, List, Union
|
||||||
|
|
||||||
import numba
|
import numba
|
||||||
@@ -149,6 +152,8 @@ class MultipackDistributedDataloader:
|
|||||||
packing_efficiency_estimate: float = 1.0,
|
packing_efficiency_estimate: float = 1.0,
|
||||||
sample_packing_seq_len_multiplier: int = 1,
|
sample_packing_seq_len_multiplier: int = 1,
|
||||||
device_count: int = 1,
|
device_count: int = 1,
|
||||||
|
prefetch_max: int = 1000,
|
||||||
|
num_epochs: int = 1,
|
||||||
):
|
):
|
||||||
# Dataset
|
# Dataset
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
@@ -167,6 +172,7 @@ class MultipackDistributedDataloader:
|
|||||||
self.seq_max_length = seq_max_length
|
self.seq_max_length = seq_max_length
|
||||||
self.batch_max_length = batch_size * seq_max_length
|
self.batch_max_length = batch_size * seq_max_length
|
||||||
self.collate_fn = collate_fn
|
self.collate_fn = collate_fn
|
||||||
|
self.num_epochs = num_epochs
|
||||||
|
|
||||||
self.num_replicas = 1
|
self.num_replicas = 1
|
||||||
self.rank = 0
|
self.rank = 0
|
||||||
@@ -177,6 +183,44 @@ class MultipackDistributedDataloader:
|
|||||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||||
self.device_count = device_count
|
self.device_count = device_count
|
||||||
|
|
||||||
|
# maxsize is maximum number of samples in queue
|
||||||
|
self.prefetch_max = prefetch_max
|
||||||
|
self.queue: Queue = Queue(maxsize=prefetch_max)
|
||||||
|
self.thread = None
|
||||||
|
|
||||||
|
def _worker(self):
|
||||||
|
LOG.info(
|
||||||
|
f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
|
||||||
|
)
|
||||||
|
for epoch in range(self.num_epochs):
|
||||||
|
for sample in self._internal_batch_generator():
|
||||||
|
while True:
|
||||||
|
if self.queue.full():
|
||||||
|
time.sleep(1)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
self.queue.put(sample)
|
||||||
|
|
||||||
|
# stop the queue when epoch is done
|
||||||
|
self.queue.put(None)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
if hasattr(self.sampler, "set_epoch"):
|
||||||
|
new_epoch = self.sampler.epoch + 1
|
||||||
|
self.sampler.set_epoch(new_epoch)
|
||||||
|
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||||
|
|
||||||
|
if self.thread is None:
|
||||||
|
self.thread = Thread(target=self._worker, daemon=True)
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
item = self.queue.get()
|
||||||
|
|
||||||
|
if item is None:
|
||||||
|
break
|
||||||
|
yield item
|
||||||
|
|
||||||
def generate_batches(self, set_stats=False):
|
def generate_batches(self, set_stats=False):
|
||||||
LOG.info("generating packed batches")
|
LOG.info("generating packed batches")
|
||||||
if self.sampler:
|
if self.sampler:
|
||||||
@@ -206,11 +250,7 @@ class MultipackDistributedDataloader:
|
|||||||
|
|
||||||
return batches, totseqs
|
return batches, totseqs
|
||||||
|
|
||||||
def __iter__(self):
|
def _internal_batch_generator(self):
|
||||||
if hasattr(self.sampler, "set_epoch"):
|
|
||||||
new_epoch = self.sampler.epoch + 1
|
|
||||||
self.sampler.set_epoch(new_epoch)
|
|
||||||
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
|
||||||
all_batches, _ = self.generate_batches(set_stats=True)
|
all_batches, _ = self.generate_batches(set_stats=True)
|
||||||
features = self.dataset.features.keys()
|
features = self.dataset.features.keys()
|
||||||
len_remaining = self._len_est()
|
len_remaining = self._len_est()
|
||||||
|
|||||||
@@ -216,6 +216,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
|||||||
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
||||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
|
num_epochs=cfg.num_epochs,
|
||||||
)
|
)
|
||||||
data_loader_len = data_loader.len_w_stats()
|
data_loader_len = data_loader.len_w_stats()
|
||||||
actual_eff = data_loader.efficiency()
|
actual_eff = data_loader.efficiency()
|
||||||
|
|||||||
Reference in New Issue
Block a user