removing some obvious comments
This commit is contained in:
@@ -551,25 +551,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
if self.args.dataloader_prefetch_factor:
|
if self.args.dataloader_prefetch_factor:
|
||||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||||
|
|
||||||
# Use the same sampling logic for all modes, including sequence parallelism
|
|
||||||
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||||
sampler = self._get_train_sampler()
|
sampler = self._get_train_sampler()
|
||||||
|
|
||||||
if isinstance(sampler, BatchSampler):
|
if isinstance(sampler, BatchSampler):
|
||||||
dataloader_params["batch_sampler"] = sampler
|
|
||||||
# batch_size and batch_sampler are mutually exclusive
|
# batch_size and batch_sampler are mutually exclusive
|
||||||
if "batch_size" in dataloader_params:
|
dataloader_params["batch_sampler"] = sampler
|
||||||
del dataloader_params["batch_size"]
|
del dataloader_params["batch_size"]
|
||||||
else:
|
else:
|
||||||
dataloader_params["sampler"] = sampler
|
dataloader_params["sampler"] = sampler
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
dataloader_params["worker_init_fn"] = seed_worker
|
dataloader_params["worker_init_fn"] = seed_worker
|
||||||
|
|
||||||
# Create dataloader
|
|
||||||
dataloader = DataLoader(train_dataset, **dataloader_params)
|
dataloader = DataLoader(train_dataset, **dataloader_params)
|
||||||
|
|
||||||
# Sample packing with accelerator preparation
|
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
|
|
||||||
@@ -578,7 +573,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
if self.args.sequence_parallel_size > 1:
|
if self.args.sequence_parallel_size > 1:
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
# Prepare dataloader for accelerate distributed training
|
|
||||||
return self.accelerator.prepare_data_loader(dataloader)
|
return self.accelerator.prepare_data_loader(dataloader)
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||||
@@ -627,17 +621,15 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
|
|
||||||
# Create dataloader
|
|
||||||
dataloader = DataLoader(eval_dataset, **dataloader_params)
|
dataloader = DataLoader(eval_dataset, **dataloader_params)
|
||||||
|
|
||||||
# Don't prepare dataloader for sequence parallelism
|
# Don't prepare dataloader for sequence parallelism
|
||||||
|
# We use a distributed sampler in this case
|
||||||
if self.args.sequence_parallel_size > 1:
|
if self.args.sequence_parallel_size > 1:
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
return self.accelerator.prepare_data_loader(dataloader)
|
return self.accelerator.prepare_data_loader(dataloader)
|
||||||
if self.args.sequence_parallel_size > 1:
|
if self.args.sequence_parallel_size > 1:
|
||||||
# We need to customize the default dataloader for sequence parallelism
|
|
||||||
eval_dataset = (
|
eval_dataset = (
|
||||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
)
|
)
|
||||||
@@ -669,7 +661,8 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
sampler = self._get_eval_sampler(eval_dataset)
|
sampler = self._get_eval_sampler(eval_dataset)
|
||||||
dataloader_params["sampler"] = sampler
|
dataloader_params["sampler"] = sampler
|
||||||
|
|
||||||
# Create dataloader without accelerator preparation for sequence parallelism
|
# Don't prepare dataloader for sequence parallelism
|
||||||
|
# We use a distributed sampler in this case
|
||||||
return DataLoader(eval_dataset, **dataloader_params)
|
return DataLoader(eval_dataset, **dataloader_params)
|
||||||
|
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|||||||
Reference in New Issue
Block a user