don't split batches when packing
This commit is contained in:
@@ -183,20 +183,21 @@ class MultipackDistributedDataloader:
|
|||||||
np.array(item[feature]) for item in batched if feature in item
|
np.array(item[feature]) for item in batched if feature in item
|
||||||
]
|
]
|
||||||
concatenated[feature] = np.concatenate(arrays)
|
concatenated[feature] = np.concatenate(arrays)
|
||||||
num_chunks = int(
|
# num_chunks = int(
|
||||||
np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
|
# np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
|
||||||
)
|
# )
|
||||||
chunked_data = []
|
# chunked_data = []
|
||||||
|
#
|
||||||
for i in range(num_chunks):
|
# for i in range(num_chunks):
|
||||||
chunk = {
|
# chunk = {
|
||||||
feature: array[
|
# feature: array[
|
||||||
i * self.seq_max_length : (i + 1) * self.seq_max_length
|
# i * self.seq_max_length : (i + 1) * self.seq_max_length
|
||||||
]
|
# ]
|
||||||
for feature, array in concatenated.items()
|
# for feature, array in concatenated.items()
|
||||||
}
|
# }
|
||||||
chunked_data.append(chunk)
|
# chunked_data.append(chunk)
|
||||||
yield self.collate_fn(chunked_data)
|
# yield self.collate_fn(chunked_data)
|
||||||
|
yield self.collate_fn(concatenated)
|
||||||
len_remaining -= 1
|
len_remaining -= 1
|
||||||
if not len_remaining:
|
if not len_remaining:
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user