Limit packing so that each packed sequence holds at most the max sequence length (`seq_max_length`)

This commit is contained in:
Wing Lian
2023-08-02 22:07:40 -04:00
parent bdd34c7400
commit 8378335dc9

View File

@@ -1,4 +1,5 @@
# pylint: skip-file
import itertools
import logging
import math
import os
@@ -101,6 +102,18 @@ def allocate(
return result, result_totseqs, s, len(result) * c * n
def chunk(iterable, n):
    """
    Split an iterable into consecutive tuples holding up to n items each.

    Every yielded tuple has exactly n items except possibly the last one,
    which carries whatever is left over. Nothing is yielded for an empty
    iterable.

    Raises ValueError when n is below one; because this is a generator, the
    error surfaces when iteration starts, not when chunk() is called.
    """
    # chunk('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError("n must be at least one")
    source = iter(iterable)
    while True:
        group = tuple(itertools.islice(source, n))
        if not group:
            # The source is exhausted; stop without yielding an empty tuple.
            return
        yield group
class MultipackDistributedDataloader:
"""Unpadded data loading using Multipack.
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
@@ -150,7 +163,8 @@ class MultipackDistributedDataloader:
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=self.rank,
c=self.batch_max_length,
# c=self.batch_max_length,
c=self.seq_max_length,
n=self.num_replicas,
)
@@ -167,37 +181,42 @@ class MultipackDistributedDataloader:
all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys()
len_remaining = self._len_est()
for batch in all_batches:
concatenated = {}
batched = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features:
if feature == "attention_mask":
arrays = [
(idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched)
if feature in item
]
concatenated[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in batched if feature in item
]
concatenated[feature] = np.concatenate(arrays)
# num_chunks = int(
# np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
# )
# chunked_data = []
#
# for i in range(num_chunks):
# chunk = {
# feature: array[
# i * self.seq_max_length : (i + 1) * self.seq_max_length
# ]
# for feature, array in concatenated.items()
# }
# chunked_data.append(chunk)
# yield self.collate_fn(chunked_data)
yield self.collate_fn([concatenated])
for batches in chunk(all_batches, self.batch_size):
chunked_data = []
for batch in batches:
concatenated = {}
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features:
if feature == "attention_mask":
arrays = [
(idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched_data)
if feature in item
]
concatenated[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature])
for item in batched_data
if feature in item
]
concatenated[feature] = np.concatenate(arrays)
chunked_data.append(concatenated)
# num_chunks = int(
# np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
# )
# chunked_data = []
#
# for i in range(num_chunks):
# chunk = {
# feature: array[
# i * self.seq_max_length : (i + 1) * self.seq_max_length
# ]
# for feature, array in concatenated.items()
# }
# chunked_data.append(chunk)
# yield self.collate_fn(chunked_data)
yield self.collate_fn([chunked_data])
len_remaining -= 1
if not len_remaining:
return
@@ -219,7 +238,7 @@ class MultipackDistributedDataloader:
* lengths_sum_per_device
/ self.packing_efficiency_estimate
/ self.seq_max_length
/ self.batch_size
// self.batch_size
)
- 1
)