diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 0f9f6e4c4..86f125852 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -272,6 +272,20 @@ class AxolotlTrainer( num_workers=self.args.dataloader_num_workers, rank=self.args.process_index, ) + + if ( + self.args.accelerator_config is not None + and self.args.accelerator_config.split_batches + and self.args.accelerator_config.dispatch_batches + ): + if self.args.sample_packing and self.args.pretraining: + if not self.args.eval_sample_packing and not is_training: + dataloader_params["batch_size"] *= self.accelerator.num_processes + else: + dataloader_params["batch_size"] = self.accelerator.num_processes + elif not self.args.sample_packing and self.args.pretraining: + dataloader_params["batch_size"] *= self.accelerator.num_processes + if self.args.sample_packing and ( (is_training and not self.args.pretraining) or (not is_training and self.args.eval_sample_packing is not False)