optimization if total_num_tokens is already known

This commit is contained in:
Wing Lian
2023-08-10 19:02:28 -04:00
parent ac4b700daa
commit 7e977a9b68

View File

@@ -5,7 +5,7 @@ import logging
import math import math
import queue import queue
import threading import threading
from typing import Any, Callable, List, Union from typing import Any, Callable, List, Optional, Union
import numba import numba
import numpy as np import numpy as np
@@ -148,6 +148,7 @@ class MultipackDistributedDataloader:
packing_efficiency_estimate: float = 1.0, packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1, sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1, device_count: int = 1,
total_num_tokens: Optional[int] = None,
): ):
# Dataset # Dataset
self.dataset = dataset self.dataset = dataset
@@ -168,6 +169,7 @@ class MultipackDistributedDataloader:
self.rank = 0 self.rank = 0
# statistics # statistics
self.total_num_tokens = total_num_tokens
self.eff_total_used = 0 self.eff_total_used = 0
self.eff_total_slots = 0 self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
@@ -261,8 +263,9 @@ class MultipackDistributedDataloader:
batch_gen_thread.join() batch_gen_thread.join()
def _len_est(self): def _len_est(self):
lengths_sum = np.sum(self.lengths) if not self.total_num_tokens:
lengths_sum_per_device = lengths_sum // self.device_count self.total_num_tokens = np.sum(self.lengths)
lengths_sum_per_device = self.total_num_tokens // self.device_count
LOG.info( LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}" f"total_num_tokens per device: {lengths_sum_per_device}"