don't split batches when packing

This commit is contained in:
Wing Lian
2023-08-02 08:26:49 -04:00
parent 958d423e7c
commit 83f7362480

View File

@@ -183,20 +183,21 @@ class MultipackDistributedDataloader:
np.array(item[feature]) for item in batched if feature in item
]
concatenated[feature] = np.concatenate(arrays)
num_chunks = int(
np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
)
chunked_data = []
for i in range(num_chunks):
chunk = {
feature: array[
i * self.seq_max_length : (i + 1) * self.seq_max_length
]
for feature, array in concatenated.items()
}
chunked_data.append(chunk)
yield self.collate_fn(chunked_data)
# num_chunks = int(
# np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
# )
# chunked_data = []
#
# for i in range(num_chunks):
# chunk = {
# feature: array[
# i * self.seq_max_length : (i + 1) * self.seq_max_length
# ]
# for feature, array in concatenated.items()
# }
# chunked_data.append(chunk)
# yield self.collate_fn(chunked_data)
yield self.collate_fn(concatenated)
len_remaining -= 1
if not len_remaining:
return