Optimization: skip summing the sample lengths if total_num_tokens is already known

This commit is contained in:
Wing Lian
2023-08-10 19:02:28 -04:00
parent ac4b700daa
commit 7e977a9b68

View File

@@ -5,7 +5,7 @@ import logging
import math
import queue
import threading
from typing import Any, Callable, List, Union
from typing import Any, Callable, List, Optional, Union
import numba
import numpy as np
@@ -148,6 +148,7 @@ class MultipackDistributedDataloader:
packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1,
total_num_tokens: Optional[int] = None,
):
# Dataset
self.dataset = dataset
@@ -168,6 +169,7 @@ class MultipackDistributedDataloader:
self.rank = 0
# statistics
self.total_num_tokens = total_num_tokens
self.eff_total_used = 0
self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
@@ -261,8 +263,9 @@ class MultipackDistributedDataloader:
batch_gen_thread.join()
def _len_est(self):
lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // self.device_count
if not self.total_num_tokens:
self.total_num_tokens = np.sum(self.lengths)
lengths_sum_per_device = self.total_num_tokens // self.device_count
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}"