From 345a9dd8313c9af27783a81e53dea627d457a2a4 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Thu, 13 Mar 2025 23:05:27 +0000
Subject: [PATCH] removing some obvious comments

---
 src/axolotl/core/trainers/base.py | 17 +++++------------
 1 file changed, 5 insertions(+), 12 deletions(-)

diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 63195a55a..05a6f92f0 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -551,25 +551,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         if self.args.dataloader_prefetch_factor:
             dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
 
-        # Use the same sampling logic for all modes, including sequence parallelism
         if not isinstance(train_dataset, torch.utils.data.IterableDataset):
             sampler = self._get_train_sampler()
             if isinstance(sampler, BatchSampler):
-                dataloader_params["batch_sampler"] = sampler  # batch_size and batch_sampler are mutually exclusive
-                if "batch_size" in dataloader_params:
-                    del dataloader_params["batch_size"]
+                dataloader_params["batch_sampler"] = sampler
+                del dataloader_params["batch_size"]
             else:
                 dataloader_params["sampler"] = sampler
                 dataloader_params["drop_last"] = self.args.dataloader_drop_last
             dataloader_params["worker_init_fn"] = seed_worker
 
-        # Create dataloader
         dataloader = DataLoader(train_dataset, **dataloader_params)
-
-        # Sample packing with accelerator preparation
         if self.args.sample_packing and not self.args.pretraining:
             self.accelerator.even_batches = False
@@ -578,7 +573,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         if self.args.sequence_parallel_size > 1:
             return dataloader
 
-        # Prepare dataloader for accelerate distributed training
         return self.accelerator.prepare_data_loader(dataloader)
 
     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
@@ -627,17 +621,15 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
                 dataloader_params["drop_last"] = self.args.dataloader_drop_last
 
             self.accelerator.even_batches = False
-
-            # Create dataloader
             dataloader = DataLoader(eval_dataset, **dataloader_params)
             # Don't prepare dataloader for sequence parallelism
+            # We use a distributed sampler in this case
             if self.args.sequence_parallel_size > 1:
                 return dataloader
 
             return self.accelerator.prepare_data_loader(dataloader)
 
         if self.args.sequence_parallel_size > 1:
-            # We need to customize the default dataloader for sequence parallelism
            eval_dataset = (
                 eval_dataset if eval_dataset is not None else self.eval_dataset
             )
@@ -669,7 +661,8 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
             sampler = self._get_eval_sampler(eval_dataset)
             dataloader_params["sampler"] = sampler
 
-            # Create dataloader without accelerator preparation for sequence parallelism
+            # Don't prepare dataloader for sequence parallelism
+            # We use a distributed sampler in this case
             return DataLoader(eval_dataset, **dataloader_params)
 
         return super().get_eval_dataloader(eval_dataset)