From 2565c2f259cc0a18e11e374b7b9c59d7c8702a33 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 10 Aug 2023 18:28:15 -0400 Subject: [PATCH] async batching for multipack --- src/axolotl/utils/dataloader.py | 121 +++++++++++++++++--------------- 1 file changed, 63 insertions(+), 58 deletions(-) diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py index 83e91cee5..ca74de271 100644 --- a/src/axolotl/utils/dataloader.py +++ b/src/axolotl/utils/dataloader.py @@ -3,6 +3,8 @@ import hashlib import itertools import logging import math +import queue +import threading from typing import Any, Callable, List, Union import numba @@ -78,7 +80,6 @@ def allocate( s = 0 start_index = 0 result = [] - result_totseqs = [] while True: # binary search [left, right) @@ -104,10 +105,8 @@ def allocate( # add local rank result.append(batch[rank]) - # add total seqs for all ranks - result_totseqs.append(tot_seqs) - # yield batch[rank], tot_seqs, s, len(result) * c * n - return result, result_totseqs, s, len(result) * c * n + + yield batch[rank], tot_seqs, s, len(result) * c * n def chunk(iterable, n): @@ -174,6 +173,11 @@ class MultipackDistributedDataloader: self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 self.device_count = device_count + # for non-blocking batch creation + self.batch_queue: queue.Queue = queue.Queue( + maxsize=10 + ) # Adjust maxsize as needed + def generate_batches(self, set_stats=False): LOG.info("generating packed batches") if self.sampler: @@ -185,75 +189,76 @@ class MultipackDistributedDataloader: lengths = self.lengths[indices] lengths_cumsum = np.cumsum(lengths) - batches, totseqs, total_used, total_slots = allocate( - lengths=lengths, - lengths_cumsum=lengths_cumsum, - rank=self.rank, - # c=self.batch_max_length, - c=self.seq_max_length * self.sample_packing_seq_len_multiplier, - n=self.num_replicas, + alloc_iter = iter( + allocate( + lengths=lengths, + lengths_cumsum=lengths_cumsum, + rank=self.rank, + # c=self.batch_max_length, + c=self.seq_max_length * self.sample_packing_seq_len_multiplier, + n=self.num_replicas, + ) ) - batches = [[indices[b_idx] for b_idx in batch] for batch in batches] + for batch, tot_seqs, total_used, total_slots in alloc_iter: + self.batch_queue.put([indices[b_idx] for b_idx in batch]) + # statistics + if set_stats: + self.eff_total_used = total_used + self.eff_total_slots = total_slots - # statistics - if set_stats: - self.eff_total_used += total_used - self.eff_total_slots += total_slots - - return batches, totseqs + def _generate_batches_thread(self): + try: + self.generate_batches(set_stats=True) + except Exception as e: + LOG.error(f"Error in batch generation thread: {e}") + self.batch_queue.put( + None + ) # Signal the end of batch generation in case of error def __iter__(self): if hasattr(self.sampler, "set_epoch"): new_epoch = self.sampler.epoch + 1 self.sampler.set_epoch(new_epoch) LOG.info(f"calling sampler.set_epoch({new_epoch})") - all_batches, _ = self.generate_batches(set_stats=True) + # Start the batch generation in a separate thread + batch_gen_thread = threading.Thread(target=self._generate_batches_thread) + batch_gen_thread.start() + features = self.dataset.features.keys() len_remaining = self._len_est() - for batches in chunk( - all_batches, self.batch_size // self.sample_packing_seq_len_multiplier - ): + while True: + batch = self.batch_queue.get() + if batch is None: # Sentinel value received, stop iteration + break chunked_data = [] attn_mask_cum_idx = 0 - for batch in batches: - concatenated = {} - batched_data = [self.dataset[batch_idx] for batch_idx in batch] - for feature in features: - if feature == "attention_mask": - arrays = [ - (attn_mask_cum_idx + idx + 1) * np.array(item[feature]) - for idx, item in enumerate(batched_data) - if feature in item - ] - attn_mask_cum_idx += len(batched_data) - concatenated[feature] = np.concatenate(arrays) - else: - arrays = [ - np.array(item[feature]) - for item in batched_data - if feature in item - ] - concatenated[feature] = np.concatenate(arrays) - chunked_data.append(concatenated) - # num_chunks = int( - # np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length) - # ) - # chunked_data = [] - # - # for i in range(num_chunks): - # chunk = { - # feature: array[ - # i * self.seq_max_length : (i + 1) * self.seq_max_length - # ] - # for feature, array in concatenated.items() - # } - # chunked_data.append(chunk) - # yield self.collate_fn(chunked_data) + concatenated = {} + batched_data = [self.dataset[batch_idx] for batch_idx in batch] + for feature in features: + if feature == "attention_mask": + arrays = [ + (attn_mask_cum_idx + idx + 1) * np.array(item[feature]) + for idx, item in enumerate(batched_data) + if feature in item + ] + attn_mask_cum_idx += len(batched_data) + concatenated[feature] = np.concatenate(arrays) + else: + arrays = [ + np.array(item[feature]) + for item in batched_data + if feature in item + ] + concatenated[feature] = np.concatenate(arrays) + chunked_data.append(concatenated) + yield self.collate_fn(chunked_data) len_remaining -= 1 if not len_remaining: - return + break + # Wait for the batch generation thread to finish + batch_gen_thread.join() def _len_est(self): indices = range(0, len(self.dataset))