limit packing to sequences of max seq len

This commit is contained in:
Wing Lian
2023-08-02 22:07:40 -04:00
parent bdd34c7400
commit 8378335dc9

View File

@@ -1,4 +1,5 @@
# pylint: skip-file # pylint: skip-file
import itertools
import logging import logging
import math import math
import os import os
@@ -101,6 +102,18 @@ def allocate(
return result, result_totseqs, s, len(result) * c * n return result, result_totseqs, s, len(result) * c * n
def chunk(iterable, n):
    """
    Split an iterable into consecutive tuples holding at most n items each.

    Every yielded tuple has exactly n items except possibly the last one,
    which holds whatever remains (e.g. 'ABCDEFG' with n=3 gives ABC, DEF, G).
    An empty iterable yields nothing.

    Because this is a generator, the check on n only runs once iteration
    starts; a ValueError is raised at that point when n is smaller than one.
    """
    if n < 1:
        raise ValueError("n must be at least one")
    source = iter(iterable)
    while True:
        # islice pulls up to n items from the shared iterator; an empty
        # tuple means the underlying iterable is exhausted.
        group = tuple(itertools.islice(source, n))
        if not group:
            return
        yield group
class MultipackDistributedDataloader: class MultipackDistributedDataloader:
"""Unpadded data loading using Multipack. """Unpadded data loading using Multipack.
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
@@ -150,7 +163,8 @@ class MultipackDistributedDataloader:
lengths=lengths, lengths=lengths,
lengths_cumsum=lengths_cumsum, lengths_cumsum=lengths_cumsum,
rank=self.rank, rank=self.rank,
c=self.batch_max_length, # c=self.batch_max_length,
c=self.seq_max_length,
n=self.num_replicas, n=self.num_replicas,
) )
@@ -167,37 +181,42 @@ class MultipackDistributedDataloader:
all_batches, _ = self.generate_batches(set_stats=True) all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys() features = self.dataset.features.keys()
len_remaining = self._len_est() len_remaining = self._len_est()
for batch in all_batches: for batches in chunk(all_batches, self.batch_size):
concatenated = {} chunked_data = []
batched = [self.dataset[batch_idx] for batch_idx in batch] for batch in batches:
for feature in features: concatenated = {}
if feature == "attention_mask": batched_data = [self.dataset[batch_idx] for batch_idx in batch]
arrays = [ for feature in features:
(idx + 1) * np.array(item[feature]) if feature == "attention_mask":
for idx, item in enumerate(batched) arrays = [
if feature in item (idx + 1) * np.array(item[feature])
] for idx, item in enumerate(batched_data)
concatenated[feature] = np.concatenate(arrays) if feature in item
else: ]
arrays = [ concatenated[feature] = np.concatenate(arrays)
np.array(item[feature]) for item in batched if feature in item else:
] arrays = [
concatenated[feature] = np.concatenate(arrays) np.array(item[feature])
# num_chunks = int( for item in batched_data
# np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length) if feature in item
# ) ]
# chunked_data = [] concatenated[feature] = np.concatenate(arrays)
# chunked_data.append(concatenated)
# for i in range(num_chunks): # num_chunks = int(
# chunk = { # np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
# feature: array[ # )
# i * self.seq_max_length : (i + 1) * self.seq_max_length # chunked_data = []
# ] #
# for feature, array in concatenated.items() # for i in range(num_chunks):
# } # chunk = {
# chunked_data.append(chunk) # feature: array[
# yield self.collate_fn(chunked_data) # i * self.seq_max_length : (i + 1) * self.seq_max_length
yield self.collate_fn([concatenated]) # ]
# for feature, array in concatenated.items()
# }
# chunked_data.append(chunk)
# yield self.collate_fn(chunked_data)
yield self.collate_fn([chunked_data])
len_remaining -= 1 len_remaining -= 1
if not len_remaining: if not len_remaining:
return return
@@ -219,7 +238,7 @@ class MultipackDistributedDataloader:
* lengths_sum_per_device * lengths_sum_per_device
/ self.packing_efficiency_estimate / self.packing_efficiency_estimate
/ self.seq_max_length / self.seq_max_length
/ self.batch_size // self.batch_size
) )
- 1 - 1
) )