From 58e9dee2041268e76eca0b2d706f262daae3c0c0 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 8 Aug 2023 03:49:29 -0400 Subject: [PATCH] fixes and go back to distributed sampler since batch sampler won't work --- src/axolotl/utils/data.py | 8 ++++---- src/axolotl/utils/trainer.py | 30 ++++++++++++------------------ 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index b2a147dab..3861a8c74 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -427,7 +427,7 @@ def load_prepare_datasets( + "|" + "train" + "|" - + cfg.seed + + str(cfg.seed or 42) ) to_hash_test = ( dataset._fingerprint # pylint: disable=protected-access @@ -436,7 +436,7 @@ def load_prepare_datasets( + "|" + "test" + "|" - + cfg.seed + + str(cfg.seed or 42) ) train_fingerprint = hashlib.md5( to_hash_train.encode(), usedforsecurity=False @@ -449,7 +449,7 @@ def load_prepare_datasets( dataset = dataset.train_test_split( test_size=cfg.val_set_size, shuffle=False, - seed=cfg.seed, + seed=cfg.seed or 42, train_new_fingerprint=train_fingerprint, test_new_fingerprint=test_fingerprint, ) @@ -458,7 +458,7 @@ def load_prepare_datasets( dataset = dataset.train_test_split( test_size=cfg.val_set_size, shuffle=False, - seed=cfg.seed, + seed=cfg.seed or 42, train_new_fingerprint=train_fingerprint, test_new_fingerprint=test_fingerprint, ) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 1f7992b0a..f97618a73 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -16,7 +16,7 @@ import transformers from datasets import Dataset, set_caching_enabled from torch import nn from torch.optim.lr_scheduler import OneCycleLR -from torch.utils.data import DataLoader, RandomSampler +from torch.utils.data import DataLoader, DistributedSampler, RandomSampler from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers.trainer_pt_utils import get_parameter_names @@ -159,23 +159,20 @@ class AxolotlTrainer(Trainer): return super().create_scheduler(num_training_steps, optimizer) return self.lr_scheduler - # def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - # if self.args.world_size > 1 and self.args.sample_packing: - # return DistributedSampler( - # self.train_dataset, - # num_replicas=self.args.world_size, - # rank=self.args.process_index, - # seed=self.args.seed, - # ) - # return super()._get_train_sampler() + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.args.world_size > 1 and self.args.sample_packing: + return DistributedSampler( + self.train_dataset, + num_replicas=self.args.world_size, + rank=self.args.process_index, + seed=self.args.seed, + ) + return super()._get_train_sampler() def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]: if self.args.sample_packing: train_sampler = self._get_train_sampler() - # If set to True, the dataloader prepared is only iterated through on the - # main process and then the batches are split and broadcast to each process - self.accelerator.dispatch_batches = True - return self.accelerator.prepare_data_loader( + return self.accelerator.prepare( MultipackDistributedDataloader( self.train_dataset, batch_size=self._train_batch_size, @@ -197,10 +194,7 @@ class AxolotlTrainer(Trainer): eval_dataset if eval_dataset is not None else self.eval_dataset ) eval_sampler = self._get_eval_sampler(eval_dataset) - # If set to True, the datalaoder prepared is only iterated through on the - # main process and then the batches are split and broadcast to each process - self.accelerator.dispatch_batches = True - return self.accelerator.prepare_data_loader( + return self.accelerator.prepare( MultipackDistributedDataloader( eval_dataset, batch_size=self.args.eval_batch_size,