better handling of multipack dataset length (#2296)

This commit is contained in:
Wing Lian
2025-02-01 21:10:34 -05:00
committed by GitHub
parent a20f17689b
commit 80e1468b8d

View File

@@ -4,7 +4,6 @@ Multipack Batch Sampler
""" """
import logging import logging
import math import math
import os
from typing import Any, Iterable, List, Union from typing import Any, Iterable, List, Union
import numba import numba
@@ -117,6 +116,7 @@ class MultipackBatchSampler(BatchSampler):
lengths: np.ndarray, lengths: np.ndarray,
packing_efficiency_estimate: float = 1.0, packing_efficiency_estimate: float = 1.0,
drop_last: bool = False, drop_last: bool = False,
num_count_samples: int = 16,
**kwargs, **kwargs,
): ):
super().__init__(sampler, batch_size, drop_last) super().__init__(sampler, batch_size, drop_last)
@@ -133,6 +133,9 @@ class MultipackBatchSampler(BatchSampler):
self.eff_total_used = 0 self.eff_total_used = 0
self.eff_total_slots = 0 self.eff_total_slots = 0
# The number of times to calculate the batches to determine the minimum packed dataset length for the local rank
self.num_count_samples = num_count_samples
# the minimum packed dataset length across all ranks determined by a gather/broadcast
self.len_across_ranks = None self.len_across_ranks = None
def set_epoch(self, epoch: int): def set_epoch(self, epoch: int):
@@ -169,6 +172,9 @@ class MultipackBatchSampler(BatchSampler):
def __iter__(self): def __iter__(self):
batches = self.generate_batches(set_stats=True) batches = self.generate_batches(set_stats=True)
if self.len_across_ranks:
# make sure the batches we iterate over are truncated to the same min length across all ranks
batches = batches[: self.len_across_ranks]
return iter(batches) return iter(batches)
def num_batches(self): def num_batches(self):
@@ -195,42 +201,15 @@ class MultipackBatchSampler(BatchSampler):
def gather_len_batches(self, num): def gather_len_batches(self, num):
def calc_min_len(estimates: list[(int, float)]): def calc_min_len(estimates: list[(int, float)]):
LOG.info(f"gather_len_batches: {repr(estimates)}") LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(0.998 * min(estimates)) return math.floor(min(estimates))
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len) min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
return min_len_batches return min_len_batches
def __len__(self): def __len__(self):
if not self.len_across_ranks: if not self.len_across_ranks:
len_batches = self.num_batches() len_batches = min(
[self.num_batches() for _ in range(self.num_count_samples)]
)
self.len_across_ranks = self.gather_len_batches(len_batches) self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks return self.len_across_ranks
def _len_est(self):
efficiency = (
self.packing_efficiency_estimate
if self.packing_efficiency_estimate
else self.gather_efficiency()
)
world_size = int(os.getenv("WORLD_SIZE", "1"))
lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // world_size
LOG.info(
f"packing_efficiency_estimate: {efficiency} "
f"total_num_tokens per device: {lengths_sum_per_device}"
)
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
return max(
0,
(
world_size
* math.floor(
0.99
* lengths_sum_per_device
/ efficiency
// (self.batch_max_len * self.batch_size)
)
- 1
),
)