handle batch size correchtly when using split and dispatch batches

This commit is contained in:
Wing Lian
2025-08-16 22:05:31 -04:00
parent ecbe8b2b61
commit 7fd3d8abc4

View File

@@ -272,6 +272,20 @@ class AxolotlTrainer(
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
if (
self.args.accelerator_config is not None
and self.args.accelerator_config.split_batches
and self.args.accelerator_config.dispatch_batches
):
if self.args.sample_packing and self.args.pretraining:
if not self.args.eval_sample_packing and not is_training:
dataloader_params["batch_size"] *= self.accelerator.num_processes
else:
dataloader_params["batch_size"] = self.accelerator.num_processes
elif not self.args.sample_packing and self.args.pretraining:
dataloader_params["batch_size"] *= self.accelerator.num_processes
if self.args.sample_packing and (
(is_training and not self.args.pretraining)
or (not is_training and self.args.eval_sample_packing is not False)