async batching for multipack

This commit is contained in:
Wing Lian
2023-08-10 18:28:15 -04:00
parent a07f432d9c
commit 2565c2f259

View File

@@ -3,6 +3,8 @@ import hashlib
import itertools
import logging
import math
import queue
import threading
from typing import Any, Callable, List, Union
import numba
@@ -78,7 +80,6 @@ def allocate(
s = 0
start_index = 0
result = []
result_totseqs = []
while True:
# binary search [left, right)
@@ -104,10 +105,8 @@ def allocate(
# add local rank
result.append(batch[rank])
# add total seqs for all ranks
result_totseqs.append(tot_seqs)
# yield batch[rank], tot_seqs, s, len(result) * c * n
return result, result_totseqs, s, len(result) * c * n
yield batch[rank], tot_seqs, s, len(result) * c * n
def chunk(iterable, n):
@@ -174,6 +173,11 @@ class MultipackDistributedDataloader:
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count
# for non-blocking batch creation
self.batch_queue: queue.Queue = queue.Queue(
maxsize=10
) # Adjust maxsize as needed
def generate_batches(self, set_stats=False):
LOG.info("generating packed batches")
if self.sampler:
@@ -185,75 +189,76 @@ class MultipackDistributedDataloader:
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
batches, totseqs, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=self.rank,
# c=self.batch_max_length,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
n=self.num_replicas,
alloc_iter = iter(
allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=self.rank,
# c=self.batch_max_length,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
n=self.num_replicas,
)
)
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
for batch, tot_seqs, total_used, total_slots in alloc_iter:
self.batch_queue.put([indices[b_idx] for b_idx in batch])
# statistics
if set_stats:
self.eff_total_used = total_used
self.eff_total_slots = total_slots
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
return batches, totseqs
def _generate_batches_thread(self):
try:
self.generate_batches(set_stats=True)
except Exception as e:
LOG.error(f"Error in batch generation thread: {e}")
self.batch_queue.put(
None
) # Signal the end of batch generation in case of error
def __iter__(self):
if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})")
all_batches, _ = self.generate_batches(set_stats=True)
# Start the batch generation in a separate thread
batch_gen_thread = threading.Thread(target=self._generate_batches_thread)
batch_gen_thread.start()
features = self.dataset.features.keys()
len_remaining = self._len_est()
for batches in chunk(
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
):
while True:
batch = self.batch_queue.get()
if batch is None: # Sentinel value received, stop iteration
break
chunked_data = []
attn_mask_cum_idx = 0
for batch in batches:
concatenated = {}
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features:
if feature == "attention_mask":
arrays = [
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched_data)
if feature in item
]
attn_mask_cum_idx += len(batched_data)
concatenated[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature])
for item in batched_data
if feature in item
]
concatenated[feature] = np.concatenate(arrays)
chunked_data.append(concatenated)
# num_chunks = int(
# np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
# )
# chunked_data = []
#
# for i in range(num_chunks):
# chunk = {
# feature: array[
# i * self.seq_max_length : (i + 1) * self.seq_max_length
# ]
# for feature, array in concatenated.items()
# }
# chunked_data.append(chunk)
# yield self.collate_fn(chunked_data)
concatenated = {}
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features:
if feature == "attention_mask":
arrays = [
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched_data)
if feature in item
]
attn_mask_cum_idx += len(batched_data)
concatenated[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature])
for item in batched_data
if feature in item
]
concatenated[feature] = np.concatenate(arrays)
chunked_data.append(concatenated)
yield self.collate_fn(chunked_data)
len_remaining -= 1
if not len_remaining:
return
break
# Wait for the batch generation thread to finish
batch_gen_thread.join()
def _len_est(self):
indices = range(0, len(self.dataset))