don't split batches when packing

This commit is contained in:
Wing Lian
2023-08-02 08:26:49 -04:00
parent 958d423e7c
commit 83f7362480

View File

@@ -183,20 +183,21 @@ class MultipackDistributedDataloader:
np.array(item[feature]) for item in batched if feature in item np.array(item[feature]) for item in batched if feature in item
] ]
concatenated[feature] = np.concatenate(arrays) concatenated[feature] = np.concatenate(arrays)
num_chunks = int( # num_chunks = int(
np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length) # np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length)
) # )
chunked_data = [] # chunked_data = []
#
for i in range(num_chunks): # for i in range(num_chunks):
chunk = { # chunk = {
feature: array[ # feature: array[
i * self.seq_max_length : (i + 1) * self.seq_max_length # i * self.seq_max_length : (i + 1) * self.seq_max_length
] # ]
for feature, array in concatenated.items() # for feature, array in concatenated.items()
} # }
chunked_data.append(chunk) # chunked_data.append(chunk)
yield self.collate_fn(chunked_data) # yield self.collate_fn(chunked_data)
yield self.collate_fn(concatenated)
len_remaining -= 1 len_remaining -= 1
if not len_remaining: if not len_remaining:
return return