Optimize length reducer from 9 min to under 5 sec

This commit is contained in:
Wing Lian
2023-08-11 08:30:30 -04:00
parent 79500f358a
commit 6b5cf8b5ea

View File

@@ -152,9 +152,10 @@ class MultipackDistributedDataloader:
):
# Dataset
self.dataset = dataset
self.lengths: np.ndarray = np.array(
[len(sample["input_ids"]) for sample in self.dataset]
lengths_series = (
dataset.data.column("position_ids").to_pandas().apply(lambda x: x[-1] + 1)
)
self.lengths: np.ndarray = lengths_series.values
assert isinstance(self.lengths, np.ndarray)
assert batch_size % sample_packing_seq_len_multiplier == 0
assert batch_size >= sample_packing_seq_len_multiplier
@@ -208,6 +209,7 @@ class MultipackDistributedDataloader:
if set_stats:
self.eff_total_used = total_used
self.eff_total_slots = total_slots
self.batch_queue.put(None) # Signal the end of batch generation
def _generate_batches_thread(self):
try:
@@ -260,7 +262,8 @@ class MultipackDistributedDataloader:
if not len_remaining:
break
# Wait for the batch generation thread to finish
batch_gen_thread.join()
batch_gen_thread.join(timeout=5)
LOG.info(f"actual packing efficiency: {self.efficiency()}")
def _len_est(self):
if not self.total_num_tokens:
@@ -277,7 +280,7 @@ class MultipackDistributedDataloader:
0.99
* lengths_sum_per_device
/ self.packing_efficiency_estimate
/ self.seq_max_length
// self.seq_max_length
// self.batch_size
)
- 1