handle batch size correchtly when using split and dispatch batches
This commit is contained in:
@@ -272,6 +272,20 @@ class AxolotlTrainer(
|
|||||||
num_workers=self.args.dataloader_num_workers,
|
num_workers=self.args.dataloader_num_workers,
|
||||||
rank=self.args.process_index,
|
rank=self.args.process_index,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.args.accelerator_config is not None
|
||||||
|
and self.args.accelerator_config.split_batches
|
||||||
|
and self.args.accelerator_config.dispatch_batches
|
||||||
|
):
|
||||||
|
if self.args.sample_packing and self.args.pretraining:
|
||||||
|
if not self.args.eval_sample_packing and not is_training:
|
||||||
|
dataloader_params["batch_size"] *= self.accelerator.num_processes
|
||||||
|
else:
|
||||||
|
dataloader_params["batch_size"] = self.accelerator.num_processes
|
||||||
|
elif not self.args.sample_packing and self.args.pretraining:
|
||||||
|
dataloader_params["batch_size"] *= self.accelerator.num_processes
|
||||||
|
|
||||||
if self.args.sample_packing and (
|
if self.args.sample_packing and (
|
||||||
(is_training and not self.args.pretraining)
|
(is_training and not self.args.pretraining)
|
||||||
or (not is_training and self.args.eval_sample_packing is not False)
|
or (not is_training and self.args.eval_sample_packing is not False)
|
||||||
|
|||||||
Reference in New Issue
Block a user