better handling of multipack dataset length (#2296)
This commit is contained in:
@@ -4,7 +4,6 @@ Multipack Batch Sampler
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Iterable, List, Union
|
||||
|
||||
import numba
|
||||
@@ -117,6 +116,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
lengths: np.ndarray,
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
drop_last: bool = False,
|
||||
num_count_samples: int = 16,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
@@ -133,6 +133,9 @@ class MultipackBatchSampler(BatchSampler):
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
|
||||
# The number of times to recompute the batches when determining the minimum packed dataset length for the local rank
|
||||
self.num_count_samples = num_count_samples
|
||||
# the minimum packed dataset length across all ranks determined by a gather/broadcast
|
||||
self.len_across_ranks = None
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
@@ -169,6 +172,9 @@ class MultipackBatchSampler(BatchSampler):
|
||||
|
||||
def __iter__(self):
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
if self.len_across_ranks:
|
||||
# make sure the batches we iterate over are truncated to the same min length across all ranks
|
||||
batches = batches[: self.len_across_ranks]
|
||||
return iter(batches)
|
||||
|
||||
def num_batches(self):
|
||||
@@ -195,42 +201,15 @@ class MultipackBatchSampler(BatchSampler):
|
||||
def gather_len_batches(self, num):
|
||||
def calc_min_len(estimates: list[(int, float)]):
|
||||
LOG.info(f"gather_len_batches: {repr(estimates)}")
|
||||
return math.floor(0.998 * min(estimates))
|
||||
return math.floor(min(estimates))
|
||||
|
||||
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
|
||||
return min_len_batches
|
||||
|
||||
def __len__(self):
|
||||
if not self.len_across_ranks:
|
||||
len_batches = self.num_batches()
|
||||
len_batches = min(
|
||||
[self.num_batches() for _ in range(self.num_count_samples)]
|
||||
)
|
||||
self.len_across_ranks = self.gather_len_batches(len_batches)
|
||||
return self.len_across_ranks
|
||||
|
||||
def _len_est(self):
|
||||
efficiency = (
|
||||
self.packing_efficiency_estimate
|
||||
if self.packing_efficiency_estimate
|
||||
else self.gather_efficiency()
|
||||
)
|
||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
lengths_sum = np.sum(self.lengths)
|
||||
lengths_sum_per_device = lengths_sum // world_size
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {efficiency} "
|
||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||
)
|
||||
|
||||
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
||||
return max(
|
||||
0,
|
||||
(
|
||||
world_size
|
||||
* math.floor(
|
||||
0.99
|
||||
* lengths_sum_per_device
|
||||
/ efficiency
|
||||
// (self.batch_max_len * self.batch_size)
|
||||
)
|
||||
- 1
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user