async batching for multipack
@@ -3,6 +3,8 @@ import hashlib
 import itertools
 import logging
 import math
+import queue
+import threading
 from typing import Any, Callable, List, Union
 
 import numba
@@ -78,7 +80,6 @@ def allocate(
     s = 0
     start_index = 0
     result = []
-    result_totseqs = []
 
     while True:
         # binary search [left, right)
@@ -104,10 +105,8 @@ def allocate(
 
         # add local rank
         result.append(batch[rank])
-        # add total seqs for all ranks
-        result_totseqs.append(tot_seqs)
-        # yield batch[rank], tot_seqs, s, len(result) * c * n
-    return result, result_totseqs, s, len(result) * c * n
+
+        yield batch[rank], tot_seqs, s, len(result) * c * n
 
 
 def chunk(iterable, n):
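The two hunks above convert allocate from a collect-and-return function into a generator: the result_totseqs accumulator is dropped, and each packed batch is yielded as soon as the binary search places it, instead of after the whole packing pass completes. A toy sketch of that eager-vs-lazy difference, assuming a hypothetical pack_one helper in place of the project's real binary-search packer:

from typing import Iterator, List, Tuple


def pack_one(lengths: List[int], budget: int = 8) -> Tuple[List[int], List[int]]:
    # Toy packing step: take one item, then greedily add more that fit.
    batch, total = [lengths[0]], lengths[0]
    rest = lengths[1:]
    while rest and total + rest[0] <= budget:
        total += rest[0]
        batch.append(rest.pop(0))
    return batch, rest


def allocate_eager(lengths: List[int]) -> List[List[int]]:
    # Old shape: accumulate every batch and return them all at once,
    # so the caller waits for the entire packing pass to finish.
    result = []
    while lengths:
        batch, lengths = pack_one(lengths)
        result.append(batch)
    return result


def allocate_lazy(lengths: List[int]) -> Iterator[List[int]]:
    # New shape: yield each batch as soon as it is packed, so a
    # consumer (here, the queue-feeding thread) can overlap work.
    while lengths:
        batch, lengths = pack_one(lengths)
        yield batch


print(allocate_eager([3, 5, 2, 6, 1]))       # [[3, 5], [2, 6], [1]]
print(list(allocate_lazy([3, 5, 2, 6, 1])))  # same batches, streamed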
@@ -174,6 +173,11 @@ class MultipackDistributedDataloader:
         self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
         self.device_count = device_count
 
+        # for non-blocking batch creation
+        self.batch_queue: queue.Queue = queue.Queue(
+            maxsize=10
+        )  # Adjust maxsize as needed
+
     def generate_batches(self, set_stats=False):
         LOG.info("generating packed batches")
         if self.sampler:
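The bounded queue.Queue(maxsize=10) added above is what keeps the "non-blocking batch creation" from running away: once ten batches are waiting, the producer's put() blocks until the training loop drains a slot, which gives natural backpressure. A minimal standard-library demonstration of that blocking behavior (the tiny maxsize and the names here are illustrative only):

import queue
import threading
import time

q: queue.Queue = queue.Queue(maxsize=2)  # tiny bound to make blocking visible


def producer() -> None:
    for i in range(5):
        q.put(i)  # blocks while two items are already waiting
        print(f"produced {i}")


threading.Thread(target=producer, daemon=True).start()
time.sleep(0.2)  # producer races ahead, then stalls after two puts
while True:
    item = q.get()  # each get() frees a slot, unblocking the producer
    print(f"consumed {item}")
    if item == 4:
        break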
@@ -185,75 +189,76 @@ class MultipackDistributedDataloader:
         lengths = self.lengths[indices]
         lengths_cumsum = np.cumsum(lengths)
 
-        batches, totseqs, total_used, total_slots = allocate(
-            lengths=lengths,
-            lengths_cumsum=lengths_cumsum,
-            rank=self.rank,
-            # c=self.batch_max_length,
-            c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
-            n=self.num_replicas,
+        alloc_iter = iter(
+            allocate(
+                lengths=lengths,
+                lengths_cumsum=lengths_cumsum,
+                rank=self.rank,
+                # c=self.batch_max_length,
+                c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
+                n=self.num_replicas,
+            )
         )
 
-        batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
-
-        # statistics
-        if set_stats:
-            self.eff_total_used = total_used
-            self.eff_total_slots = total_slots
-
-        return batches, totseqs
+        for batch, tot_seqs, total_used, total_slots in alloc_iter:
+            self.batch_queue.put([indices[b_idx] for b_idx in batch])
+            # statistics
+            if set_stats:
+                self.eff_total_used += total_used
+                self.eff_total_slots += total_slots
+
+    def _generate_batches_thread(self):
+        try:
+            self.generate_batches(set_stats=True)
+        except Exception as e:
+            LOG.error(f"Error in batch generation thread: {e}")
+            self.batch_queue.put(
+                None
+            )  # Signal the end of batch generation in case of error
 
     def __iter__(self):
         if hasattr(self.sampler, "set_epoch"):
             new_epoch = self.sampler.epoch + 1
             self.sampler.set_epoch(new_epoch)
             LOG.info(f"calling sampler.set_epoch({new_epoch})")
-        all_batches, _ = self.generate_batches(set_stats=True)
+        # Start the batch generation in a separate thread
+        batch_gen_thread = threading.Thread(target=self._generate_batches_thread)
+        batch_gen_thread.start()
+
         features = self.dataset.features.keys()
         len_remaining = self._len_est()
-        for batches in chunk(
-            all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
-        ):
+        while True:
+            batch = self.batch_queue.get()
+            if batch is None:  # Sentinel value received, stop iteration
+                break
             chunked_data = []
             attn_mask_cum_idx = 0
-            for batch in batches:
-                concatenated = {}
-                batched_data = [self.dataset[batch_idx] for batch_idx in batch]
-                for feature in features:
-                    if feature == "attention_mask":
-                        arrays = [
-                            (attn_mask_cum_idx + idx + 1) * np.array(item[feature])
-                            for idx, item in enumerate(batched_data)
-                            if feature in item
-                        ]
-                        attn_mask_cum_idx += len(batched_data)
-                        concatenated[feature] = np.concatenate(arrays)
-                    else:
-                        arrays = [
-                            np.array(item[feature])
-                            for item in batched_data
-                            if feature in item
-                        ]
-                        concatenated[feature] = np.concatenate(arrays)
-                chunked_data.append(concatenated)
-                # num_chunks = int(
-                #     np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
-                # )
-                # chunked_data = []
-                #
-                # for i in range(num_chunks):
-                #     chunk = {
-                #         feature: array[
-                #             i * self.seq_max_length : (i + 1) * self.seq_max_length
-                #         ]
-                #         for feature, array in concatenated.items()
-                #     }
-                #     chunked_data.append(chunk)
-                # yield self.collate_fn(chunked_data)
+            concatenated = {}
+            batched_data = [self.dataset[batch_idx] for batch_idx in batch]
+            for feature in features:
+                if feature == "attention_mask":
+                    arrays = [
+                        (attn_mask_cum_idx + idx + 1) * np.array(item[feature])
+                        for idx, item in enumerate(batched_data)
+                        if feature in item
+                    ]
+                    attn_mask_cum_idx += len(batched_data)
+                    concatenated[feature] = np.concatenate(arrays)
+                else:
+                    arrays = [
+                        np.array(item[feature])
+                        for item in batched_data
+                        if feature in item
+                    ]
+                    concatenated[feature] = np.concatenate(arrays)
+            chunked_data.append(concatenated)
 
             yield self.collate_fn(chunked_data)
             len_remaining -= 1
             if not len_remaining:
-                return
+                break
+        # Wait for the batch generation thread to finish
+        batch_gen_thread.join()
 
     def _len_est(self):
         indices = range(0, len(self.dataset))
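Taken together, the diff wires generate_batches (the producer, run inside _generate_batches_thread) to __iter__ (the consumer) through batch_queue, with None as the end-of-stream sentinel. Below is a condensed, self-contained sketch of that single-producer/single-consumer shape using hypothetical names; one deliberate difference is that this sketch enqueues the sentinel in a finally: block on every exit, whereas the diff above only enqueues None on error and relies on len_remaining and join() for normal shutdown:

import queue
import threading
from typing import Iterator, List

SENTINEL = None  # end-of-stream marker, mirroring the diff's use of None


def produce(batches: List[list], q: queue.Queue) -> None:
    try:
        for batch in batches:  # stands in for generate_batches()
            q.put(batch)  # blocks whenever the bounded queue is full
    finally:
        q.put(SENTINEL)  # always signal completion, even on error


def consume(q: queue.Queue) -> Iterator[list]:
    while True:
        batch = q.get()
        if batch is SENTINEL:  # sentinel received, stop iterating
            break
        yield batch


q: queue.Queue = queue.Queue(maxsize=10)
producer = threading.Thread(target=produce, args=([[1, 2], [3, 4]], q))
producer.start()
for b in consume(q):
    print(b)
producer.join()  # wait for the producer thread to finish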