async batching for multipack
This commit is contained in:
@@ -3,6 +3,8 @@ import hashlib
|
|||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
from typing import Any, Callable, List, Union
|
from typing import Any, Callable, List, Union
|
||||||
|
|
||||||
import numba
|
import numba
|
||||||
@@ -78,7 +80,6 @@ def allocate(
|
|||||||
s = 0
|
s = 0
|
||||||
start_index = 0
|
start_index = 0
|
||||||
result = []
|
result = []
|
||||||
result_totseqs = []
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# binary search [left, right)
|
# binary search [left, right)
|
||||||
@@ -104,10 +105,8 @@ def allocate(
|
|||||||
|
|
||||||
# add local rank
|
# add local rank
|
||||||
result.append(batch[rank])
|
result.append(batch[rank])
|
||||||
# add total seqs for all ranks
|
|
||||||
result_totseqs.append(tot_seqs)
|
yield batch[rank], tot_seqs, s, len(result) * c * n
|
||||||
# yield batch[rank], tot_seqs, s, len(result) * c * n
|
|
||||||
return result, result_totseqs, s, len(result) * c * n
|
|
||||||
|
|
||||||
|
|
||||||
def chunk(iterable, n):
|
def chunk(iterable, n):
|
||||||
@@ -174,6 +173,11 @@ class MultipackDistributedDataloader:
|
|||||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||||
self.device_count = device_count
|
self.device_count = device_count
|
||||||
|
|
||||||
|
# for non-blocking batch creation
|
||||||
|
self.batch_queue: queue.Queue = queue.Queue(
|
||||||
|
maxsize=10
|
||||||
|
) # Adjust maxsize as needed
|
||||||
|
|
||||||
def generate_batches(self, set_stats=False):
|
def generate_batches(self, set_stats=False):
|
||||||
LOG.info("generating packed batches")
|
LOG.info("generating packed batches")
|
||||||
if self.sampler:
|
if self.sampler:
|
||||||
@@ -185,75 +189,76 @@ class MultipackDistributedDataloader:
|
|||||||
lengths = self.lengths[indices]
|
lengths = self.lengths[indices]
|
||||||
lengths_cumsum = np.cumsum(lengths)
|
lengths_cumsum = np.cumsum(lengths)
|
||||||
|
|
||||||
batches, totseqs, total_used, total_slots = allocate(
|
alloc_iter = iter(
|
||||||
lengths=lengths,
|
allocate(
|
||||||
lengths_cumsum=lengths_cumsum,
|
lengths=lengths,
|
||||||
rank=self.rank,
|
lengths_cumsum=lengths_cumsum,
|
||||||
# c=self.batch_max_length,
|
rank=self.rank,
|
||||||
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
|
# c=self.batch_max_length,
|
||||||
n=self.num_replicas,
|
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
|
||||||
|
n=self.num_replicas,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
for batch, tot_seqs, total_used, total_slots in alloc_iter:
|
||||||
|
self.batch_queue.put([indices[b_idx] for b_idx in batch])
|
||||||
|
# statistics
|
||||||
|
if set_stats:
|
||||||
|
self.eff_total_used = total_used
|
||||||
|
self.eff_total_slots = total_slots
|
||||||
|
|
||||||
# statistics
|
def _generate_batches_thread(self):
|
||||||
if set_stats:
|
try:
|
||||||
self.eff_total_used += total_used
|
self.generate_batches(set_stats=True)
|
||||||
self.eff_total_slots += total_slots
|
except Exception as e:
|
||||||
|
LOG.error(f"Error in batch generation thread: {e}")
|
||||||
return batches, totseqs
|
self.batch_queue.put(
|
||||||
|
None
|
||||||
|
) # Signal the end of batch generation in case of error
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
if hasattr(self.sampler, "set_epoch"):
|
if hasattr(self.sampler, "set_epoch"):
|
||||||
new_epoch = self.sampler.epoch + 1
|
new_epoch = self.sampler.epoch + 1
|
||||||
self.sampler.set_epoch(new_epoch)
|
self.sampler.set_epoch(new_epoch)
|
||||||
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||||
all_batches, _ = self.generate_batches(set_stats=True)
|
# Start the batch generation in a separate thread
|
||||||
|
batch_gen_thread = threading.Thread(target=self._generate_batches_thread)
|
||||||
|
batch_gen_thread.start()
|
||||||
|
|
||||||
features = self.dataset.features.keys()
|
features = self.dataset.features.keys()
|
||||||
len_remaining = self._len_est()
|
len_remaining = self._len_est()
|
||||||
for batches in chunk(
|
while True:
|
||||||
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
|
batch = self.batch_queue.get()
|
||||||
):
|
if batch is None: # Sentinel value received, stop iteration
|
||||||
|
break
|
||||||
chunked_data = []
|
chunked_data = []
|
||||||
attn_mask_cum_idx = 0
|
attn_mask_cum_idx = 0
|
||||||
for batch in batches:
|
concatenated = {}
|
||||||
concatenated = {}
|
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
||||||
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
for feature in features:
|
||||||
for feature in features:
|
if feature == "attention_mask":
|
||||||
if feature == "attention_mask":
|
arrays = [
|
||||||
arrays = [
|
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
|
||||||
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
|
for idx, item in enumerate(batched_data)
|
||||||
for idx, item in enumerate(batched_data)
|
if feature in item
|
||||||
if feature in item
|
]
|
||||||
]
|
attn_mask_cum_idx += len(batched_data)
|
||||||
attn_mask_cum_idx += len(batched_data)
|
concatenated[feature] = np.concatenate(arrays)
|
||||||
concatenated[feature] = np.concatenate(arrays)
|
else:
|
||||||
else:
|
arrays = [
|
||||||
arrays = [
|
np.array(item[feature])
|
||||||
np.array(item[feature])
|
for item in batched_data
|
||||||
for item in batched_data
|
if feature in item
|
||||||
if feature in item
|
]
|
||||||
]
|
concatenated[feature] = np.concatenate(arrays)
|
||||||
concatenated[feature] = np.concatenate(arrays)
|
chunked_data.append(concatenated)
|
||||||
chunked_data.append(concatenated)
|
|
||||||
# num_chunks = int(
|
|
||||||
# np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
|
|
||||||
# )
|
|
||||||
# chunked_data = []
|
|
||||||
#
|
|
||||||
# for i in range(num_chunks):
|
|
||||||
# chunk = {
|
|
||||||
# feature: array[
|
|
||||||
# i * self.seq_max_length : (i + 1) * self.seq_max_length
|
|
||||||
# ]
|
|
||||||
# for feature, array in concatenated.items()
|
|
||||||
# }
|
|
||||||
# chunked_data.append(chunk)
|
|
||||||
# yield self.collate_fn(chunked_data)
|
|
||||||
yield self.collate_fn(chunked_data)
|
yield self.collate_fn(chunked_data)
|
||||||
len_remaining -= 1
|
len_remaining -= 1
|
||||||
if not len_remaining:
|
if not len_remaining:
|
||||||
return
|
break
|
||||||
|
# Wait for the batch generation thread to finish
|
||||||
|
batch_gen_thread.join()
|
||||||
|
|
||||||
def _len_est(self):
|
def _len_est(self):
|
||||||
indices = range(0, len(self.dataset))
|
indices = range(0, len(self.dataset))
|
||||||
|
|||||||
Reference in New Issue
Block a user