don't split batches when packing
This commit is contained in:
@@ -183,20 +183,21 @@ class MultipackDistributedDataloader:
|
||||
np.array(item[feature]) for item in batched if feature in item
|
||||
]
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
num_chunks = int(
|
||||
np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
|
||||
)
|
||||
chunked_data = []
|
||||
|
||||
for i in range(num_chunks):
|
||||
chunk = {
|
||||
feature: array[
|
||||
i * self.seq_max_length : (i + 1) * self.seq_max_length
|
||||
]
|
||||
for feature, array in concatenated.items()
|
||||
}
|
||||
chunked_data.append(chunk)
|
||||
yield self.collate_fn(chunked_data)
|
||||
# num_chunks = int(
|
||||
# np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
|
||||
# )
|
||||
# chunked_data = []
|
||||
#
|
||||
# for i in range(num_chunks):
|
||||
# chunk = {
|
||||
# feature: array[
|
||||
# i * self.seq_max_length : (i + 1) * self.seq_max_length
|
||||
# ]
|
||||
# for feature, array in concatenated.items()
|
||||
# }
|
||||
# chunked_data.append(chunk)
|
||||
# yield self.collate_fn(chunked_data)
|
||||
yield self.collate_fn(concatenated)
|
||||
len_remaining -= 1
|
||||
if not len_remaining:
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user