memoize dataset length for eval sample packing (#1974)

* wip on multimodal sample packing support

* wip on multimodal packing support

* llama-1b-yml

* setup logging for test

* yml

* yml

* yml

* fix for __len__ for eval sample packing

* reverted irrelevant changes

* reformatted, reverted log message

* reverted unnecessary changes

* added e2e multigpu testing for eval sample packing

* formatting

* fixed e2e test_eval params

* fix test_eval e2e multigpu

* fix test_eval e2e multigpu

* Update tests/e2e/multigpu/test_eval.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update tests/e2e/multigpu/test_eval.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Author: Sunny Liu
Date:   2024-10-17 15:15:29 -04:00
Committed by: GitHub
Parent: 54673fd6ca
Commit: f62e23737b

3 changed files with 239 additions and 6 deletions


@@ -133,6 +133,8 @@ class MultipackBatchSampler(BatchSampler):
         self.eff_total_used = 0
         self.eff_total_slots = 0
 
+        self.len_across_ranks = None
+
     def set_epoch(self, epoch: int):
         self.epoch = epoch
 
@@ -195,15 +197,14 @@ class MultipackBatchSampler(BatchSampler):
             LOG.info(f"gather_len_batches: {repr(estimates)}")
             return math.floor(0.998 * min(estimates))
 
-        min_len_batches = reduce_and_broadcast(
-            lambda: num,
-            calc_min_len,
-        )
+        min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
         return min_len_batches
 
     def __len__(self):
-        len_batches = self.num_batches()
-        return self.gather_len_batches(len_batches)
+        if not self.len_across_ranks:
+            len_batches = self.num_batches()
+            self.len_across_ranks = self.gather_len_batches(len_batches)
+        return self.len_across_ranks
 
     def _len_est(self):
         efficiency = (
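
For context, here is a minimal, self-contained sketch of the memoization pattern this commit introduces. MemoizedLenSampler and the single-process reduce_and_broadcast stub below are hypothetical stand-ins for illustration only; axolotl's real MultipackBatchSampler and its distributed reduce_and_broadcast helper do considerably more work.

import math


# Hypothetical single-process stand-in for axolotl's distributed helper.
# In a real run it would gather value_fn() from every rank and broadcast
# reduce_fn(gathered_values) back to all of them.
def reduce_and_broadcast(value_fn, reduce_fn):
    return reduce_fn([value_fn()])


class MemoizedLenSampler:
    """Toy sampler illustrating the memoized __len__ from this commit."""

    def __init__(self, batches):
        self.batches = batches
        self.len_across_ranks = None  # cache added by the commit

    def num_batches(self):
        # Per-rank batch count; cheap here, but in axolotl this
        # involves re-running the packing estimate.
        return len(self.batches)

    def gather_len_batches(self, num):
        def calc_min_len(estimates):
            # All ranks must agree on one length, so take a slightly
            # discounted minimum of the per-rank estimates.
            return math.floor(0.998 * min(estimates))

        return reduce_and_broadcast(lambda: num, calc_min_len)

    def __len__(self):
        # First call performs the gather; later calls reuse the cached value.
        if not self.len_across_ranks:
            self.len_across_ranks = self.gather_len_batches(self.num_batches())
        return self.len_across_ranks


sampler = MemoizedLenSampler(batches=[[0, 1], [2, 3], [4]])
print(len(sampler))  # gathers once, prints 2 (floor(0.998 * 3))
print(len(sampler))  # served from len_across_ranks, no second gather

Before this change, every len() call on the sampler repeated the cross-rank gather; with the cached len_across_ranks, only the first call pays that cost. Note that set_epoch in the hunk above does not reset the cache, so the length is computed once per sampler instance.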