better handling of multipack dataset length (#2296)

This commit is contained in:
Wing Lian
2025-02-01 21:10:34 -05:00
committed by GitHub
parent a20f17689b
commit 80e1468b8d

View File

@@ -4,7 +4,6 @@ Multipack Batch Sampler
"""
import logging
import math
import os
from typing import Any, Iterable, List, Union
import numba
@@ -117,6 +116,7 @@ class MultipackBatchSampler(BatchSampler):
lengths: np.ndarray,
packing_efficiency_estimate: float = 1.0,
drop_last: bool = False,
num_count_samples: int = 16,
**kwargs,
):
super().__init__(sampler, batch_size, drop_last)
@@ -133,6 +133,9 @@ class MultipackBatchSampler(BatchSampler):
self.eff_total_used = 0
self.eff_total_slots = 0
# The number of times to calculate the batches to determine the minimum packed dataset length for the local rank
self.num_count_samples = num_count_samples
# the minimum packed dataset length across all ranks determined by a gather/broadcast
self.len_across_ranks = None
def set_epoch(self, epoch: int):
@@ -169,6 +172,9 @@ class MultipackBatchSampler(BatchSampler):
def __iter__(self):
batches = self.generate_batches(set_stats=True)
if self.len_across_ranks:
# make sure the batches we iterate over is truncated to the same min length across all ranks
batches = batches[: self.len_across_ranks]
return iter(batches)
def num_batches(self):
@@ -195,42 +201,15 @@ class MultipackBatchSampler(BatchSampler):
def gather_len_batches(self, num):
def calc_min_len(estimates: list[(int, float)]):
LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(0.998 * min(estimates))
return math.floor(min(estimates))
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
return min_len_batches
def __len__(self):
if not self.len_across_ranks:
len_batches = self.num_batches()
len_batches = min(
[self.num_batches() for _ in range(self.num_count_samples)]
)
self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks
def _len_est(self):
efficiency = (
self.packing_efficiency_estimate
if self.packing_efficiency_estimate
else self.gather_efficiency()
)
world_size = int(os.getenv("WORLD_SIZE", "1"))
lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // world_size
LOG.info(
f"packing_efficiency_estimate: {efficiency} "
f"total_num_tokens per device: {lengths_sum_per_device}"
)
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
return max(
0,
(
world_size
* math.floor(
0.99
* lengths_sum_per_device
/ efficiency
// (self.batch_max_len * self.batch_size)
)
- 1
),
)