memoize dataset length for eval sample packing (#1974)
* wip on multimodal sample packing support * wip on multimodal packing support * llama-1b-yml * setup logging for test * yml * yml * yml * fix for __len__ for eval sample packing * reverted irrelavant changes * reformatted, reverted log message * reverted unnecessary changes * added e2e multigpu testing for eval sample packing * formatting * fixed e2e test_eval params * fix test_eval e2e multigpu * fix test_eval e2e multigpu * Update tests/e2e/multigpu/test_eval.py Co-authored-by: Wing Lian <wing.lian@gmail.com> * Update tests/e2e/multigpu/test_eval.py Co-authored-by: Wing Lian <wing.lian@gmail.com> --------- Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -133,6 +133,8 @@ class MultipackBatchSampler(BatchSampler):
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
|
||||
self.len_across_ranks = None
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
self.epoch = epoch
|
||||
|
||||
@@ -195,15 +197,14 @@ class MultipackBatchSampler(BatchSampler):
|
||||
LOG.info(f"gather_len_batches: {repr(estimates)}")
|
||||
return math.floor(0.998 * min(estimates))
|
||||
|
||||
min_len_batches = reduce_and_broadcast(
|
||||
lambda: num,
|
||||
calc_min_len,
|
||||
)
|
||||
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
|
||||
return min_len_batches
|
||||
|
||||
def __len__(self):
|
||||
len_batches = self.num_batches()
|
||||
return self.gather_len_batches(len_batches)
|
||||
if not self.len_across_ranks:
|
||||
len_batches = self.num_batches()
|
||||
self.len_across_ranks = self.gather_len_batches(len_batches)
|
||||
return self.len_across_ranks
|
||||
|
||||
def _len_est(self):
|
||||
efficiency = (
|
||||
|
||||
Reference in New Issue
Block a user