diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py index 18b9c4db1..2f2b0b372 100644 --- a/src/axolotl/utils/dataloader.py +++ b/src/axolotl/utils/dataloader.py @@ -183,20 +183,21 @@ class MultipackDistributedDataloader: np.array(item[feature]) for item in batched if feature in item ] concatenated[feature] = np.concatenate(arrays) - num_chunks = int( - np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length) - ) - chunked_data = [] - - for i in range(num_chunks): - chunk = { - feature: array[ - i * self.seq_max_length : (i + 1) * self.seq_max_length - ] - for feature, array in concatenated.items() - } - chunked_data.append(chunk) - yield self.collate_fn(chunked_data) + # num_chunks = int( + # np.ceil(len(next(iter(concatenated.values()))) / self.seq_max_length) + # ) + # chunked_data = [] + # + # for i in range(num_chunks): + # chunk = { + # feature: array[ + # i * self.seq_max_length : (i + 1) * self.seq_max_length + # ] + # for feature, array in concatenated.items() + # } + # chunked_data.append(chunk) + # yield self.collate_fn(chunked_data) + yield self.collate_fn(concatenated) len_remaining -= 1 if not len_remaining: return