diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py index 3b281eb2c..f2804b1a5 100644 --- a/src/axolotl/utils/dataloader.py +++ b/src/axolotl/utils/dataloader.py @@ -5,7 +5,7 @@ import logging import math import queue import threading -from typing import Any, Callable, List, Union +from typing import Any, Callable, List, Optional, Union import numba import numpy as np @@ -148,6 +148,7 @@ class MultipackDistributedDataloader: packing_efficiency_estimate: float = 1.0, sample_packing_seq_len_multiplier: int = 1, device_count: int = 1, + total_num_tokens: Optional[int] = None, ): # Dataset self.dataset = dataset @@ -168,6 +169,7 @@ class MultipackDistributedDataloader: self.rank = 0 # statistics + self.total_num_tokens = total_num_tokens self.eff_total_used = 0 self.eff_total_slots = 0 self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 @@ -261,8 +263,9 @@ class MultipackDistributedDataloader: batch_gen_thread.join() def _len_est(self): - lengths_sum = np.sum(self.lengths) - lengths_sum_per_device = lengths_sum // self.device_count + if not self.total_num_tokens: + self.total_num_tokens = np.sum(self.lengths) + lengths_sum_per_device = self.total_num_tokens // self.device_count LOG.info( f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " f"total_num_tokens per device: {lengths_sum_per_device}"