limit packing to sequences of max seq len
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
# pylint: skip-file
|
# pylint: skip-file
|
||||||
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -101,6 +102,18 @@ def allocate(
|
|||||||
return result, result_totseqs, s, len(result) * c * n
|
return result, result_totseqs, s, len(result) * c * n
|
||||||
|
|
||||||
|
|
||||||
|
def chunk(iterable, n):
|
||||||
|
"""
|
||||||
|
Chunk data into tuples of length n
|
||||||
|
"""
|
||||||
|
# batched('ABCDEFG', 3) --> ABC DEF G
|
||||||
|
if n < 1:
|
||||||
|
raise ValueError("n must be at least one")
|
||||||
|
it = iter(iterable)
|
||||||
|
while batch := tuple(itertools.islice(it, n)):
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
class MultipackDistributedDataloader:
|
class MultipackDistributedDataloader:
|
||||||
"""Unpadded data loading using Multipack.
|
"""Unpadded data loading using Multipack.
|
||||||
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
|
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
|
||||||
@@ -150,7 +163,8 @@ class MultipackDistributedDataloader:
|
|||||||
lengths=lengths,
|
lengths=lengths,
|
||||||
lengths_cumsum=lengths_cumsum,
|
lengths_cumsum=lengths_cumsum,
|
||||||
rank=self.rank,
|
rank=self.rank,
|
||||||
c=self.batch_max_length,
|
# c=self.batch_max_length,
|
||||||
|
c=self.seq_max_length,
|
||||||
n=self.num_replicas,
|
n=self.num_replicas,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -167,37 +181,42 @@ class MultipackDistributedDataloader:
|
|||||||
all_batches, _ = self.generate_batches(set_stats=True)
|
all_batches, _ = self.generate_batches(set_stats=True)
|
||||||
features = self.dataset.features.keys()
|
features = self.dataset.features.keys()
|
||||||
len_remaining = self._len_est()
|
len_remaining = self._len_est()
|
||||||
for batch in all_batches:
|
for batches in chunk(all_batches, self.batch_size):
|
||||||
concatenated = {}
|
chunked_data = []
|
||||||
batched = [self.dataset[batch_idx] for batch_idx in batch]
|
for batch in batches:
|
||||||
for feature in features:
|
concatenated = {}
|
||||||
if feature == "attention_mask":
|
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
||||||
arrays = [
|
for feature in features:
|
||||||
(idx + 1) * np.array(item[feature])
|
if feature == "attention_mask":
|
||||||
for idx, item in enumerate(batched)
|
arrays = [
|
||||||
if feature in item
|
(idx + 1) * np.array(item[feature])
|
||||||
]
|
for idx, item in enumerate(batched_data)
|
||||||
concatenated[feature] = np.concatenate(arrays)
|
if feature in item
|
||||||
else:
|
]
|
||||||
arrays = [
|
concatenated[feature] = np.concatenate(arrays)
|
||||||
np.array(item[feature]) for item in batched if feature in item
|
else:
|
||||||
]
|
arrays = [
|
||||||
concatenated[feature] = np.concatenate(arrays)
|
np.array(item[feature])
|
||||||
# num_chunks = int(
|
for item in batched_data
|
||||||
# np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
|
if feature in item
|
||||||
# )
|
]
|
||||||
# chunked_data = []
|
concatenated[feature] = np.concatenate(arrays)
|
||||||
#
|
chunked_data.append(concatenated)
|
||||||
# for i in range(num_chunks):
|
# num_chunks = int(
|
||||||
# chunk = {
|
# np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
|
||||||
# feature: array[
|
# )
|
||||||
# i * self.seq_max_length : (i + 1) * self.seq_max_length
|
# chunked_data = []
|
||||||
# ]
|
#
|
||||||
# for feature, array in concatenated.items()
|
# for i in range(num_chunks):
|
||||||
# }
|
# chunk = {
|
||||||
# chunked_data.append(chunk)
|
# feature: array[
|
||||||
# yield self.collate_fn(chunked_data)
|
# i * self.seq_max_length : (i + 1) * self.seq_max_length
|
||||||
yield self.collate_fn([concatenated])
|
# ]
|
||||||
|
# for feature, array in concatenated.items()
|
||||||
|
# }
|
||||||
|
# chunked_data.append(chunk)
|
||||||
|
# yield self.collate_fn(chunked_data)
|
||||||
|
yield self.collate_fn([chunked_data])
|
||||||
len_remaining -= 1
|
len_remaining -= 1
|
||||||
if not len_remaining:
|
if not len_remaining:
|
||||||
return
|
return
|
||||||
@@ -219,7 +238,7 @@ class MultipackDistributedDataloader:
|
|||||||
* lengths_sum_per_device
|
* lengths_sum_per_device
|
||||||
/ self.packing_efficiency_estimate
|
/ self.packing_efficiency_estimate
|
||||||
/ self.seq_max_length
|
/ self.seq_max_length
|
||||||
/ self.batch_size
|
// self.batch_size
|
||||||
)
|
)
|
||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user