Optimization: skip recomputing the sum of lengths when total_num_tokens is already known
This commit is contained in:
@@ -5,7 +5,7 @@ import logging
|
||||
import math
|
||||
import queue
|
||||
import threading
|
||||
from typing import Any, Callable, List, Union
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
@@ -148,6 +148,7 @@ class MultipackDistributedDataloader:
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
sample_packing_seq_len_multiplier: int = 1,
|
||||
device_count: int = 1,
|
||||
total_num_tokens: Optional[int] = None,
|
||||
):
|
||||
# Dataset
|
||||
self.dataset = dataset
|
||||
@@ -168,6 +169,7 @@ class MultipackDistributedDataloader:
|
||||
self.rank = 0
|
||||
|
||||
# statistics
|
||||
self.total_num_tokens = total_num_tokens
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
@@ -261,8 +263,9 @@ class MultipackDistributedDataloader:
|
||||
batch_gen_thread.join()
|
||||
|
||||
def _len_est(self):
|
||||
lengths_sum = np.sum(self.lengths)
|
||||
lengths_sum_per_device = lengths_sum // self.device_count
|
||||
if not self.total_num_tokens:
|
||||
self.total_num_tokens = np.sum(self.lengths)
|
||||
lengths_sum_per_device = self.total_num_tokens // self.device_count
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||
|
||||
Reference in New Issue
Block a user