optimize length reducer from 9m -> <5sec
This commit is contained in:
@@ -152,9 +152,10 @@ class MultipackDistributedDataloader:
|
||||
):
|
||||
# Dataset
|
||||
self.dataset = dataset
|
||||
self.lengths: np.ndarray = np.array(
|
||||
[len(sample["input_ids"]) for sample in self.dataset]
|
||||
lengths_series = (
|
||||
dataset.data.column("position_ids").to_pandas().apply(lambda x: x[-1] + 1)
|
||||
)
|
||||
self.lengths: np.ndarray = lengths_series.values
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
assert batch_size % sample_packing_seq_len_multiplier == 0
|
||||
assert batch_size >= sample_packing_seq_len_multiplier
|
||||
@@ -208,6 +209,7 @@ class MultipackDistributedDataloader:
|
||||
if set_stats:
|
||||
self.eff_total_used = total_used
|
||||
self.eff_total_slots = total_slots
|
||||
self.batch_queue.put(None) # Signal the end of batch generation
|
||||
|
||||
def _generate_batches_thread(self):
|
||||
try:
|
||||
@@ -260,7 +262,8 @@ class MultipackDistributedDataloader:
|
||||
if not len_remaining:
|
||||
break
|
||||
# Wait for the batch generation thread to finish
|
||||
batch_gen_thread.join()
|
||||
batch_gen_thread.join(timeout=5)
|
||||
LOG.info(f"actual packing efficiency: {self.efficiency()}")
|
||||
|
||||
def _len_est(self):
|
||||
if not self.total_num_tokens:
|
||||
@@ -277,7 +280,7 @@ class MultipackDistributedDataloader:
|
||||
0.99
|
||||
* lengths_sum_per_device
|
||||
/ self.packing_efficiency_estimate
|
||||
/ self.seq_max_length
|
||||
// self.seq_max_length
|
||||
// self.batch_size
|
||||
)
|
||||
- 1
|
||||
|
||||
Reference in New Issue
Block a user