use accelerator prepare for dataloader
This commit is contained in:
@@ -101,12 +101,14 @@ class AxolotlTrainer(Trainer):
|
|||||||
if self.args.sample_packing:
|
if self.args.sample_packing:
|
||||||
train_sampler = self._get_train_sampler()
|
train_sampler = self._get_train_sampler()
|
||||||
|
|
||||||
return MultipackDistributedDataloader(
|
return self.accelerator.prepare(
|
||||||
self.train_dataset,
|
MultipackDistributedDataloader(
|
||||||
batch_size=self._train_batch_size,
|
self.train_dataset,
|
||||||
seq_max_length=self.args.max_seq_length,
|
batch_size=self._train_batch_size,
|
||||||
collate_fn=self.data_collator,
|
seq_max_length=self.args.max_seq_length,
|
||||||
sampler=train_sampler,
|
collate_fn=self.data_collator,
|
||||||
|
sampler=train_sampler,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return super().get_train_dataloader()
|
return super().get_train_dataloader()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user