async batching for multipack

This commit is contained in:
Wing Lian
2023-08-10 18:28:15 -04:00
parent a07f432d9c
commit 2565c2f259

View File

@@ -3,6 +3,8 @@ import hashlib
import itertools import itertools
import logging import logging
import math import math
import queue
import threading
from typing import Any, Callable, List, Union from typing import Any, Callable, List, Union
import numba import numba
@@ -78,7 +80,6 @@ def allocate(
s = 0 s = 0
start_index = 0 start_index = 0
result = [] result = []
result_totseqs = []
while True: while True:
# binary search [left, right) # binary search [left, right)
@@ -104,10 +105,8 @@ def allocate(
# add local rank # add local rank
result.append(batch[rank]) result.append(batch[rank])
# add total seqs for all ranks
result_totseqs.append(tot_seqs) yield batch[rank], tot_seqs, s, len(result) * c * n
# yield batch[rank], tot_seqs, s, len(result) * c * n
return result, result_totseqs, s, len(result) * c * n
def chunk(iterable, n): def chunk(iterable, n):
@@ -174,6 +173,11 @@ class MultipackDistributedDataloader:
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count self.device_count = device_count
# for non-blocking batch creation
self.batch_queue: queue.Queue = queue.Queue(
maxsize=10
) # Adjust maxsize as needed
def generate_batches(self, set_stats=False): def generate_batches(self, set_stats=False):
LOG.info("generating packed batches") LOG.info("generating packed batches")
if self.sampler: if self.sampler:
@@ -185,75 +189,76 @@ class MultipackDistributedDataloader:
lengths = self.lengths[indices] lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths) lengths_cumsum = np.cumsum(lengths)
batches, totseqs, total_used, total_slots = allocate( alloc_iter = iter(
lengths=lengths, allocate(
lengths_cumsum=lengths_cumsum, lengths=lengths,
rank=self.rank, lengths_cumsum=lengths_cumsum,
# c=self.batch_max_length, rank=self.rank,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier, # c=self.batch_max_length,
n=self.num_replicas, c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
n=self.num_replicas,
)
) )
batches = [[indices[b_idx] for b_idx in batch] for batch in batches] for batch, tot_seqs, total_used, total_slots in alloc_iter:
self.batch_queue.put([indices[b_idx] for b_idx in batch])
# statistics
if set_stats:
self.eff_total_used = total_used
self.eff_total_slots = total_slots
def _generate_batches_thread(self):
    """Thread target that fills ``self.batch_queue`` with packed batches.

    Runs ``self.generate_batches(set_stats=True)`` and then always enqueues
    a single ``None`` sentinel, so a consumer blocked on
    ``self.batch_queue.get()`` is released whether generation finished
    normally or raised.
    """
    try:
        self.generate_batches(set_stats=True)
    except Exception as e:  # top-level boundary of the worker thread
        LOG.error(f"Error in batch generation thread: {e}")
    finally:
        # Signal the end of batch generation on success AND on error;
        # without this the consumer would wait forever after a clean run.
        self.batch_queue.put(None)
def __iter__(self):
    """Yield collated packed batches produced by a background thread.

    Advances the sampler epoch when the sampler supports ``set_epoch``,
    starts ``self._generate_batches_thread`` in a worker thread, and then
    consumes index lists from ``self.batch_queue``. Each index list is
    turned into one packed sample (features concatenated across its
    items) and passed to ``self.collate_fn``.

    Iteration stops when a ``None`` sentinel is received, when the
    producer thread has exited and the queue is empty, or after
    ``self._len_est()`` batches have been yielded. On any exit, including
    the consumer abandoning the generator, the queue is drained so the
    producer can never stay blocked on a full queue, the thread is
    joined, and leftover items are discarded so they cannot leak into
    the next epoch.
    """
    if hasattr(self.sampler, "set_epoch"):
        new_epoch = self.sampler.epoch + 1
        self.sampler.set_epoch(new_epoch)
        LOG.info(f"calling sampler.set_epoch({new_epoch})")

    # Start the batch generation in a separate thread. Daemon so that a
    # stuck producer can never keep the interpreter from exiting.
    batch_gen_thread = threading.Thread(
        target=self._generate_batches_thread, daemon=True
    )
    batch_gen_thread.start()

    features = self.dataset.features.keys()
    len_remaining = self._len_est()

    try:
        while True:
            try:
                batch = self.batch_queue.get(timeout=1.0)
            except queue.Empty:
                if batch_gen_thread.is_alive():
                    continue  # producer is just slow; keep waiting
                # Producer is gone. Pick up anything it enqueued between
                # our timeout and the liveness check, otherwise stop
                # instead of blocking forever on a missing sentinel.
                try:
                    batch = self.batch_queue.get_nowait()
                except queue.Empty:
                    break
            if batch is None:  # Sentinel value received, stop iteration
                break

            # Each queue item is one packed batch; the collate function
            # still receives a list, here with a single element.
            chunked_data = []
            attn_mask_cum_idx = 0
            concatenated = {}
            batched_data = [self.dataset[batch_idx] for batch_idx in batch]
            for feature in features:
                if feature == "attention_mask":
                    # Scale each item's mask by a distinct 1-based id so
                    # the packed sequences stay distinguishable.
                    arrays = [
                        (attn_mask_cum_idx + idx + 1) * np.array(item[feature])
                        for idx, item in enumerate(batched_data)
                        if feature in item
                    ]
                    attn_mask_cum_idx += len(batched_data)
                else:
                    arrays = [
                        np.array(item[feature])
                        for item in batched_data
                        if feature in item
                    ]
                concatenated[feature] = np.concatenate(arrays)
            chunked_data.append(concatenated)

            yield self.collate_fn(chunked_data)
            len_remaining -= 1
            if not len_remaining:
                break
    finally:
        # The queue is bounded, so a producer with more batches than we
        # consumed would block in put() and join() would deadlock. Keep
        # draining until the producer has finished.
        while batch_gen_thread.is_alive():
            try:
                self.batch_queue.get(timeout=0.1)
            except queue.Empty:
                pass
        # Wait for the batch generation thread to finish
        batch_gen_thread.join()
        # Discard leftovers (stale batches, sentinel) so the next call to
        # __iter__ starts from an empty queue.
        while True:
            try:
                self.batch_queue.get_nowait()
            except queue.Empty:
                break
def _len_est(self): def _len_est(self):
indices = range(0, len(self.dataset)) indices = range(0, len(self.dataset))