better handling of multipack dataset length (#2296)
This commit is contained in:
@@ -4,7 +4,6 @@ Multipack Batch Sampler
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
|
||||||
from typing import Any, Iterable, List, Union
|
from typing import Any, Iterable, List, Union
|
||||||
|
|
||||||
import numba
|
import numba
|
||||||
@@ -117,6 +116,7 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
lengths: np.ndarray,
|
lengths: np.ndarray,
|
||||||
packing_efficiency_estimate: float = 1.0,
|
packing_efficiency_estimate: float = 1.0,
|
||||||
drop_last: bool = False,
|
drop_last: bool = False,
|
||||||
|
num_count_samples: int = 16,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(sampler, batch_size, drop_last)
|
super().__init__(sampler, batch_size, drop_last)
|
||||||
@@ -133,6 +133,9 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
self.eff_total_used = 0
|
self.eff_total_used = 0
|
||||||
self.eff_total_slots = 0
|
self.eff_total_slots = 0
|
||||||
|
|
||||||
|
# The number of times to calculate the batches to determine the minimum packed dataset length for the local rank
|
||||||
|
self.num_count_samples = num_count_samples
|
||||||
|
# the minimum packed dataset length across all ranks determined by a gather/broadcast
|
||||||
self.len_across_ranks = None
|
self.len_across_ranks = None
|
||||||
|
|
||||||
def set_epoch(self, epoch: int):
|
def set_epoch(self, epoch: int):
|
||||||
@@ -169,6 +172,9 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
batches = self.generate_batches(set_stats=True)
|
batches = self.generate_batches(set_stats=True)
|
||||||
|
if self.len_across_ranks:
|
||||||
|
# make sure the batches we iterate over is truncated to the same min length across all ranks
|
||||||
|
batches = batches[: self.len_across_ranks]
|
||||||
return iter(batches)
|
return iter(batches)
|
||||||
|
|
||||||
def num_batches(self):
|
def num_batches(self):
|
||||||
@@ -195,42 +201,15 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
def gather_len_batches(self, num):
|
def gather_len_batches(self, num):
|
||||||
def calc_min_len(estimates: list[(int, float)]):
|
def calc_min_len(estimates: list[(int, float)]):
|
||||||
LOG.info(f"gather_len_batches: {repr(estimates)}")
|
LOG.info(f"gather_len_batches: {repr(estimates)}")
|
||||||
return math.floor(0.998 * min(estimates))
|
return math.floor(min(estimates))
|
||||||
|
|
||||||
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
|
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
|
||||||
return min_len_batches
|
return min_len_batches
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
if not self.len_across_ranks:
|
if not self.len_across_ranks:
|
||||||
len_batches = self.num_batches()
|
len_batches = min(
|
||||||
|
[self.num_batches() for _ in range(self.num_count_samples)]
|
||||||
|
)
|
||||||
self.len_across_ranks = self.gather_len_batches(len_batches)
|
self.len_across_ranks = self.gather_len_batches(len_batches)
|
||||||
return self.len_across_ranks
|
return self.len_across_ranks
|
||||||
|
|
||||||
def _len_est(self):
|
|
||||||
efficiency = (
|
|
||||||
self.packing_efficiency_estimate
|
|
||||||
if self.packing_efficiency_estimate
|
|
||||||
else self.gather_efficiency()
|
|
||||||
)
|
|
||||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
|
||||||
lengths_sum = np.sum(self.lengths)
|
|
||||||
lengths_sum_per_device = lengths_sum // world_size
|
|
||||||
LOG.info(
|
|
||||||
f"packing_efficiency_estimate: {efficiency} "
|
|
||||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
|
||||||
return max(
|
|
||||||
0,
|
|
||||||
(
|
|
||||||
world_size
|
|
||||||
* math.floor(
|
|
||||||
0.99
|
|
||||||
* lengths_sum_per_device
|
|
||||||
/ efficiency
|
|
||||||
// (self.batch_max_len * self.batch_size)
|
|
||||||
)
|
|
||||||
- 1
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user