Fix sampler to prevent overfitting with new epochs
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
# pylint: skip-file
|
# pylint: skip-file
|
||||||
|
import hashlib
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
@@ -121,6 +122,17 @@ def chunk(iterable, n):
|
|||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
|
def hash_indices(lst: List[int]) -> str:
    """Return the SHA-256 hex digest identifying a sequence of indices.

    Each index is rendered with ``str``, the pieces are joined with commas
    in iteration order, and the resulting text is encoded and hashed. Two
    inputs therefore share a digest only when they yield the same indices
    in the same order.

    Args:
        lst: The integer indices to fingerprint.

    Returns:
        The 64-character hexadecimal SHA-256 digest of the joined indices.
    """
    # Serialize first so the digest depends only on the values and their order.
    serialized = ",".join(str(index) for index in lst)
    return hashlib.sha256(serialized.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
class MultipackDistributedDataloader:
|
class MultipackDistributedDataloader:
|
||||||
"""Unpadded data loading using Multipack.
|
"""Unpadded data loading using Multipack.
|
||||||
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
|
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
|
||||||
@@ -169,6 +181,7 @@ class MultipackDistributedDataloader:
|
|||||||
else:
|
else:
|
||||||
indices = range(0, len(self.dataset))
|
indices = range(0, len(self.dataset))
|
||||||
|
|
||||||
|
LOG.info(hash_indices(indices))
|
||||||
lengths = self.lengths[indices]
|
lengths = self.lengths[indices]
|
||||||
lengths_cumsum = np.cumsum(lengths)
|
lengths_cumsum = np.cumsum(lengths)
|
||||||
|
|
||||||
@@ -191,6 +204,10 @@ class MultipackDistributedDataloader:
|
|||||||
return batches, totseqs
|
return batches, totseqs
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
if hasattr(self.sampler, "set_epoch"):
|
||||||
|
new_epoch = self.sampler.epoch + 1
|
||||||
|
self.sampler.set_epoch(new_epoch)
|
||||||
|
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||||
all_batches, _ = self.generate_batches(set_stats=True)
|
all_batches, _ = self.generate_batches(set_stats=True)
|
||||||
features = self.dataset.features.keys()
|
features = self.dataset.features.keys()
|
||||||
len_remaining = self._len_est()
|
len_remaining = self._len_est()
|
||||||
|
|||||||
Reference in New Issue
Block a user