optimization if total_num_tokens is already known
This commit is contained in:
@@ -5,7 +5,7 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Callable, List, Union
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|
||||||
import numba
|
import numba
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -148,6 +148,7 @@ class MultipackDistributedDataloader:
|
|||||||
packing_efficiency_estimate: float = 1.0,
|
packing_efficiency_estimate: float = 1.0,
|
||||||
sample_packing_seq_len_multiplier: int = 1,
|
sample_packing_seq_len_multiplier: int = 1,
|
||||||
device_count: int = 1,
|
device_count: int = 1,
|
||||||
|
total_num_tokens: Optional[int] = None,
|
||||||
):
|
):
|
||||||
# Dataset
|
# Dataset
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
@@ -168,6 +169,7 @@ class MultipackDistributedDataloader:
|
|||||||
self.rank = 0
|
self.rank = 0
|
||||||
|
|
||||||
# statistics
|
# statistics
|
||||||
|
self.total_num_tokens = total_num_tokens
|
||||||
self.eff_total_used = 0
|
self.eff_total_used = 0
|
||||||
self.eff_total_slots = 0
|
self.eff_total_slots = 0
|
||||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||||
@@ -261,8 +263,9 @@ class MultipackDistributedDataloader:
|
|||||||
batch_gen_thread.join()
|
batch_gen_thread.join()
|
||||||
|
|
||||||
def _len_est(self):
|
def _len_est(self):
|
||||||
lengths_sum = np.sum(self.lengths)
|
if not self.total_num_tokens:
|
||||||
lengths_sum_per_device = lengths_sum // self.device_count
|
self.total_num_tokens = np.sum(self.lengths)
|
||||||
|
lengths_sum_per_device = self.total_num_tokens // self.device_count
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||||
|
|||||||
Reference in New Issue
Block a user