use accelerator prepare for dataloader

This commit is contained in:
Wing Lian
2023-07-19 22:58:16 -04:00
parent 4ab9ab79fd
commit 2e295c9f94

View File

@@ -101,12 +101,14 @@ class AxolotlTrainer(Trainer):
if self.args.sample_packing:
train_sampler = self._get_train_sampler()
return MultipackDistributedDataloader(
self.train_dataset,
batch_size=self._train_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=train_sampler,
return self.accelerator.prepare(
MultipackDistributedDataloader(
self.train_dataset,
batch_size=self._train_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=train_sampler,
)
)
return super().get_train_dataloader()