fix sampler to prevent overfit w new epochs

This commit is contained in:
Wing Lian
2023-08-08 15:34:18 -04:00
parent 1b8747e319
commit 26983a1974

View File

@@ -1,4 +1,5 @@
# pylint: skip-file
import hashlib
import itertools
import logging
import math
@@ -121,6 +122,17 @@ def chunk(iterable, n):
yield batch
def hash_indices(lst: List[int]) -> str:
# Convert the list of integers to a string representation
concatenated = ",".join(map(str, lst))
# Generate the hash
sha256 = hashlib.sha256()
sha256.update(concatenated.encode())
return sha256.hexdigest()
class MultipackDistributedDataloader:
"""Unpadded data loading using Multipack.
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
@@ -169,6 +181,7 @@ class MultipackDistributedDataloader:
else:
indices = range(0, len(self.dataset))
LOG.info(hash_indices(indices))
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
@@ -191,6 +204,10 @@ class MultipackDistributedDataloader:
return batches, totseqs
def __iter__(self):
if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})")
all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys()
len_remaining = self._len_est()