diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py index f1ab86e37..b12ea338c 100644 --- a/src/axolotl/utils/dataloader.py +++ b/src/axolotl/utils/dataloader.py @@ -1,4 +1,5 @@ # pylint: skip-file +import itertools import logging import math import os @@ -101,6 +102,18 @@ def allocate( return result, result_totseqs, s, len(result) * c * n +def chunk(iterable, n): + """ + Chunk data into tuples of length n + """ + # batched('ABCDEFG', 3) --> ABC DEF G + if n < 1: + raise ValueError("n must be at least one") + it = iter(iterable) + while batch := tuple(itertools.islice(it, n)): + yield batch + + class MultipackDistributedDataloader: """Unpadded data loading using Multipack. Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py @@ -150,7 +163,8 @@ class MultipackDistributedDataloader: lengths=lengths, lengths_cumsum=lengths_cumsum, rank=self.rank, - c=self.batch_max_length, + # c=self.batch_max_length, + c=self.seq_max_length, n=self.num_replicas, ) @@ -167,37 +181,42 @@ class MultipackDistributedDataloader: all_batches, _ = self.generate_batches(set_stats=True) features = self.dataset.features.keys() len_remaining = self._len_est() - for batch in all_batches: - concatenated = {} - batched = [self.dataset[batch_idx] for batch_idx in batch] - for feature in features: - if feature == "attention_mask": - arrays = [ - (idx + 1) * np.array(item[feature]) - for idx, item in enumerate(batched) - if feature in item - ] - concatenated[feature] = np.concatenate(arrays) - else: - arrays = [ - np.array(item[feature]) for item in batched if feature in item - ] - concatenated[feature] = np.concatenate(arrays) - # num_chunks = int( - # np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length) - # ) - # chunked_data = [] - # - # for i in range(num_chunks): - # chunk = { - # feature: array[ - # i * self.seq_max_length : (i + 1) * self.seq_max_length - # ] - # for feature, array in concatenated.items() - # } - # chunked_data.append(chunk) - # yield self.collate_fn(chunked_data) - yield self.collate_fn([concatenated]) + for batches in chunk(all_batches, self.batch_size): + chunked_data = [] + for batch in batches: + concatenated = {} + batched_data = [self.dataset[batch_idx] for batch_idx in batch] + for feature in features: + if feature == "attention_mask": + arrays = [ + (idx + 1) * np.array(item[feature]) + for idx, item in enumerate(batched_data) + if feature in item + ] + concatenated[feature] = np.concatenate(arrays) + else: + arrays = [ + np.array(item[feature]) + for item in batched_data + if feature in item + ] + concatenated[feature] = np.concatenate(arrays) + chunked_data.append(concatenated) + # num_chunks = int( + # np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length) + # ) + # chunked_data = [] + # + # for i in range(num_chunks): + # chunk = { + # feature: array[ + # i * self.seq_max_length : (i + 1) * self.seq_max_length + # ] + # for feature, array in concatenated.items() + # } + # chunked_data.append(chunk) + # yield self.collate_fn(chunked_data) + yield self.collate_fn([chunked_data]) len_remaining -= 1 if not len_remaining: return @@ -219,7 +238,7 @@ class MultipackDistributedDataloader: * lengths_sum_per_device / self.packing_efficiency_estimate / self.seq_max_length - / self.batch_size + // self.batch_size ) - 1 )