limit packing to sequences of max seq len
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
# pylint: skip-file
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -101,6 +102,18 @@ def allocate(
|
||||
return result, result_totseqs, s, len(result) * c * n
|
||||
|
||||
|
||||
def chunk(iterable, n):
    """
    Chunk data into tuples of at most length n.

    Items are consumed from ``iterable`` in order and yielded as tuples of
    ``n`` items each. The final tuple is shorter than ``n`` when the number
    of items is not a multiple of ``n``; an empty iterable yields nothing.

    Because this is a generator, the ``ValueError`` for ``n < 1`` is raised
    when iteration starts, not when ``chunk`` is called.
    """
    # Same behavior as the itertools "batched" recipe:
    # chunk('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError("n must be at least one")
    # A single shared iterator so each islice call resumes where the
    # previous one stopped.
    it = iter(iterable)
    # An empty tuple is falsy, which ends the loop once the input is exhausted.
    while batch := tuple(itertools.islice(it, n)):
        yield batch
|
||||
|
||||
|
||||
class MultipackDistributedDataloader:
|
||||
"""Unpadded data loading using Multipack.
|
||||
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
|
||||
@@ -150,7 +163,8 @@ class MultipackDistributedDataloader:
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=self.rank,
|
||||
c=self.batch_max_length,
|
||||
# c=self.batch_max_length,
|
||||
c=self.seq_max_length,
|
||||
n=self.num_replicas,
|
||||
)
|
||||
|
||||
@@ -167,37 +181,42 @@ class MultipackDistributedDataloader:
|
||||
all_batches, _ = self.generate_batches(set_stats=True)
|
||||
features = self.dataset.features.keys()
|
||||
len_remaining = self._len_est()
|
||||
for batch in all_batches:
|
||||
concatenated = {}
|
||||
batched = [self.dataset[batch_idx] for batch_idx in batch]
|
||||
for feature in features:
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(idx + 1) * np.array(item[feature])
|
||||
for idx, item in enumerate(batched)
|
||||
if feature in item
|
||||
]
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in batched if feature in item
|
||||
]
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
# num_chunks = int(
|
||||
# np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
|
||||
# )
|
||||
# chunked_data = []
|
||||
#
|
||||
# for i in range(num_chunks):
|
||||
# chunk = {
|
||||
# feature: array[
|
||||
# i * self.seq_max_length : (i + 1) * self.seq_max_length
|
||||
# ]
|
||||
# for feature, array in concatenated.items()
|
||||
# }
|
||||
# chunked_data.append(chunk)
|
||||
# yield self.collate_fn(chunked_data)
|
||||
yield self.collate_fn([concatenated])
|
||||
for batches in chunk(all_batches, self.batch_size):
|
||||
chunked_data = []
|
||||
for batch in batches:
|
||||
concatenated = {}
|
||||
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
||||
for feature in features:
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(idx + 1) * np.array(item[feature])
|
||||
for idx, item in enumerate(batched_data)
|
||||
if feature in item
|
||||
]
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature])
|
||||
for item in batched_data
|
||||
if feature in item
|
||||
]
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
chunked_data.append(concatenated)
|
||||
# num_chunks = int(
|
||||
# np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
|
||||
# )
|
||||
# chunked_data = []
|
||||
#
|
||||
# for i in range(num_chunks):
|
||||
# chunk = {
|
||||
# feature: array[
|
||||
# i * self.seq_max_length : (i + 1) * self.seq_max_length
|
||||
# ]
|
||||
# for feature, array in concatenated.items()
|
||||
# }
|
||||
# chunked_data.append(chunk)
|
||||
# yield self.collate_fn(chunked_data)
|
||||
yield self.collate_fn([chunked_data])
|
||||
len_remaining -= 1
|
||||
if not len_remaining:
|
||||
return
|
||||
@@ -219,7 +238,7 @@ class MultipackDistributedDataloader:
|
||||
* lengths_sum_per_device
|
||||
/ self.packing_efficiency_estimate
|
||||
/ self.seq_max_length
|
||||
/ self.batch_size
|
||||
// self.batch_size
|
||||
)
|
||||
- 1
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user