From 80e1468b8d444189ea007ccd0dec297458b7d347 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 1 Feb 2025 21:10:34 -0500 Subject: [PATCH] better handling of multipack dataset length (#2296) --- src/axolotl/utils/samplers/multipack.py | 43 +++++++------------------ 1 file changed, 11 insertions(+), 32 deletions(-) diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index db14a6819..6119dff30 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -4,7 +4,6 @@ Multipack Batch Sampler """ import logging import math -import os from typing import Any, Iterable, List, Union import numba @@ -117,6 +116,7 @@ class MultipackBatchSampler(BatchSampler): lengths: np.ndarray, packing_efficiency_estimate: float = 1.0, drop_last: bool = False, + num_count_samples: int = 16, **kwargs, ): super().__init__(sampler, batch_size, drop_last) @@ -133,6 +133,9 @@ class MultipackBatchSampler(BatchSampler): self.eff_total_used = 0 self.eff_total_slots = 0 + # The number of times to calculate the batches to determine the minimum packed dataset length for the local rank + self.num_count_samples = num_count_samples + # the minimum packed dataset length across all ranks determined by a gather/broadcast self.len_across_ranks = None def set_epoch(self, epoch: int): @@ -169,6 +172,9 @@ class MultipackBatchSampler(BatchSampler): def __iter__(self): batches = self.generate_batches(set_stats=True) + if self.len_across_ranks: + # make sure the batches we iterate over is truncated to the same min length across all ranks + batches = batches[: self.len_across_ranks] return iter(batches) def num_batches(self): @@ -195,42 +201,15 @@ class MultipackBatchSampler(BatchSampler): def gather_len_batches(self, num): def calc_min_len(estimates: list[(int, float)]): LOG.info(f"gather_len_batches: {repr(estimates)}") - return math.floor(0.998 * min(estimates)) + return math.floor(min(estimates)) min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len) return min_len_batches def __len__(self): if not self.len_across_ranks: - len_batches = self.num_batches() + len_batches = min( + [self.num_batches() for _ in range(self.num_count_samples)] + ) self.len_across_ranks = self.gather_len_batches(len_batches) return self.len_across_ranks - - def _len_est(self): - efficiency = ( - self.packing_efficiency_estimate - if self.packing_efficiency_estimate - else self.gather_efficiency() - ) - world_size = int(os.getenv("WORLD_SIZE", "1")) - lengths_sum = np.sum(self.lengths) - lengths_sum_per_device = lengths_sum // world_size - LOG.info( - f"packing_efficiency_estimate: {efficiency} " - f"total_num_tokens per device: {lengths_sum_per_device}" - ) - - # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler - return max( - 0, - ( - world_size - * math.floor( - 0.99 - * lengths_sum_per_device - / efficiency - // (self.batch_max_len * self.batch_size) - ) - - 1 - ), - )