diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py index eb2fb8df8..18b9c4db1 100644 --- a/src/axolotl/utils/dataloader.py +++ b/src/axolotl/utils/dataloader.py @@ -1,11 +1,15 @@ # pylint: skip-file +import logging import math +import os from typing import Any, Callable, List, Union import numba import numpy as np from torch.utils.data import DistributedSampler, Sampler +LOG = logging.getLogger("axolotl.utils.dataloader") + @numba.njit def ffd_check(a: np.ndarray, c: int, n: int): @@ -110,6 +114,7 @@ class MultipackDistributedDataloader: seq_max_length: int = 2048, batch_size: int = 1, sampler: Union[Sampler, DistributedSampler] = None, + packing_efficiency_estimate: float = 1.0, ): # Dataset self.dataset = dataset @@ -130,6 +135,7 @@ class MultipackDistributedDataloader: # statistics self.eff_total_used = 0 self.eff_total_slots = 0 + self.packing_efficiency_estimate = packing_efficiency_estimate def generate_batches(self, set_stats=False): if self.sampler: @@ -160,6 +166,7 @@ class MultipackDistributedDataloader: def __iter__(self): all_batches, _ = self.generate_batches(set_stats=True) features = self.dataset.features.keys() + len_remaining = self._len_est() for batch in all_batches: concatenated = {} batched = [self.dataset[batch_idx] for batch_idx in batch] @@ -190,15 +197,42 @@ class MultipackDistributedDataloader: } chunked_data.append(chunk) yield self.collate_fn(chunked_data) + len_remaining -= 1 + if not len_remaining: + return + + def _len_est(self): + indices = range(0, len(self.dataset)) + lengths = self.lengths[indices] + lengths_sum = np.cumsum(lengths)[-1] + lengths_sum_per_device = lengths_sum // int(os.environ.get("WORLD_SIZE", 1)) + LOG.info( + f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " + f"total_num_tokens per device: {lengths_sum_per_device}" + ) + + # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler + return ( + math.floor( + 0.99 + * lengths_sum_per_device + / self.packing_efficiency_estimate + / self.seq_max_length + / self.batch_size + ) + - 1 + ) def __len__(self): - batches, _ = self.generate_batches() - # shave off 1% for dealing with variance in packing and dataset length - return math.floor(len(batches) * 0.99) - - def num_batches(self): - batches, _ = self.generate_batches() - return math.floor(len(batches) * 0.99) + # this doesn't return the actual length b/c with distributed samplers, not all dataloaders get + # the same share of total tokens + if not self.eff_total_used: + batches, _ = self.generate_batches(set_stats=True) + LOG.info( + f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " + f"actual packing efficiency: {self.efficiency()}" + ) + return self._len_est() def efficiency(self): return self.eff_total_used / self.eff_total_slots diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 8507d604f..f78793bc4 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -187,11 +187,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): else sum(len(s["input_ids"]) for s in train_dataset) ) total_num_steps = ( - math.ceil( - total_num_tokens - / cfg.sample_packing_eff_est - / 2048 - / cfg.batch_size + # match count to len est in dataloader + ( + 0.99 + * math.ceil( + total_num_tokens + / cfg.sample_packing_eff_est + / 2048 + / cfg.batch_size + ) + - 1 ) * cfg.num_epochs ) @@ -210,6 +215,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): padding="longest", ), sampler=sampler, + packing_efficiency_estimate=cfg.sample_packing_eff_est, ) data_loader_len = len(data_loader) LOG.info(f"data_loader_len: {data_loader_len}")