From 26983a1974ac5313729d0643732dbc115dfb0b4a Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 8 Aug 2023 15:34:18 -0400
Subject: [PATCH] fix sampler to prevent overfit w new epochs

---
Notes (ignored by `git am`; everything up to the first "diff --git" line is
commentary):

What this patch does
  * Adds a module-level helper, hash_indices(), that joins the given indices
    with "," and returns the SHA-256 hex digest of that string.
  * Third hunk: logs that digest at INFO level right after `indices` has been
    chosen (taken from the sampler, or range(0, len(self.dataset)) otherwise),
    so the index order used on each pass shows up in the log.
  * __iter__: when self.sampler has a `set_epoch` attribute, calls
    set_epoch(self.sampler.epoch + 1) before generating batches, then logs
    the new epoch number.

Review notes (assumptions to confirm, not changes made by this patch)
  * NOTE(review): hash_indices() stringifies and hashes every index each time
    the third hunk's code runs; presumably a debugging aid -- confirm the
    cost is acceptable for large datasets or lower the log level.
  * NOTE(review): __iter__ reads self.sampler.epoch but only checks for
    `set_epoch`; this assumes any sampler exposing set_epoch also exposes
    `epoch` -- verify for custom samplers.
  * NOTE(review): the "calling sampler.set_epoch(...)" message is emitted
    after the call has already happened.
  * NOTE(review): line breaks and indentation in this copy were restored from
    a whitespace-collapsed version; check the context lines against the
    target file before applying.

 src/axolotl/utils/dataloader.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index 78ad6c150..83e91cee5 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -1,4 +1,5 @@
 # pylint: skip-file
+import hashlib
 import itertools
 import logging
 import math
@@ -121,6 +122,17 @@ def chunk(iterable, n):
         yield batch
 
 
+def hash_indices(lst: List[int]) -> str:
+    # Convert the list of integers to a string representation
+    concatenated = ",".join(map(str, lst))
+
+    # Generate the hash
+    sha256 = hashlib.sha256()
+    sha256.update(concatenated.encode())
+
+    return sha256.hexdigest()
+
+
 class MultipackDistributedDataloader:
     """Unpadded data loading using Multipack.
     Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
@@ -169,6 +181,7 @@ class MultipackDistributedDataloader:
         else:
             indices = range(0, len(self.dataset))
 
+        LOG.info(hash_indices(indices))
         lengths = self.lengths[indices]
         lengths_cumsum = np.cumsum(lengths)
 
@@ -191,6 +204,10 @@ class MultipackDistributedDataloader:
         return batches, totseqs
 
     def __iter__(self):
+        if hasattr(self.sampler, "set_epoch"):
+            new_epoch = self.sampler.epoch + 1
+            self.sampler.set_epoch(new_epoch)
+            LOG.info(f"calling sampler.set_epoch({new_epoch})")
         all_batches, _ = self.generate_batches(set_stats=True)
         features = self.dataset.features.keys()
         len_remaining = self._len_est()