diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py index f2804b1a5..0798336e6 100644 --- a/src/axolotl/utils/dataloader.py +++ b/src/axolotl/utils/dataloader.py @@ -152,9 +152,10 @@ class MultipackDistributedDataloader: ): # Dataset self.dataset = dataset - self.lengths: np.ndarray = np.array( - [len(sample["input_ids"]) for sample in self.dataset] + lengths_series = ( + dataset.data.column("position_ids").to_pandas().apply(lambda x: x[-1] + 1) ) + self.lengths: np.ndarray = lengths_series.values assert isinstance(self.lengths, np.ndarray) assert batch_size % sample_packing_seq_len_multiplier == 0 assert batch_size >= sample_packing_seq_len_multiplier @@ -208,6 +209,7 @@ class MultipackDistributedDataloader: if set_stats: self.eff_total_used = total_used self.eff_total_slots = total_slots + self.batch_queue.put(None) # Signal the end of batch generation def _generate_batches_thread(self): try: @@ -260,7 +262,8 @@ class MultipackDistributedDataloader: if not len_remaining: break # Wait for the batch generation thread to finish - batch_gen_thread.join() + batch_gen_thread.join(timeout=5) + LOG.info(f"actual packing efficiency: {self.efficiency()}") def _len_est(self): if not self.total_num_tokens: @@ -277,7 +280,7 @@ class MultipackDistributedDataloader: 0.99 * lengths_sum_per_device / self.packing_efficiency_estimate - / self.seq_max_length + // self.seq_max_length // self.batch_size ) - 1