Compare commits
2 Commits
version-de
...
split-batc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bb65157dcf | ||
|
|
7fd3d8abc4 |
@@ -424,7 +424,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
):
|
):
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
if (
|
if (
|
||||||
self.cfg.pretraining_sample_concatenation is False
|
not self.cfg.pretraining_sample_concatenation
|
||||||
or self.cfg.micro_batch_size > 1
|
or self.cfg.micro_batch_size > 1
|
||||||
):
|
):
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
|
|||||||
@@ -272,6 +272,20 @@ class AxolotlTrainer(
|
|||||||
num_workers=self.args.dataloader_num_workers,
|
num_workers=self.args.dataloader_num_workers,
|
||||||
rank=self.args.process_index,
|
rank=self.args.process_index,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.args.accelerator_config is not None
|
||||||
|
and self.args.accelerator_config.split_batches
|
||||||
|
and self.args.accelerator_config.dispatch_batches
|
||||||
|
):
|
||||||
|
if self.args.sample_packing and self.args.pretraining:
|
||||||
|
if not self.args.eval_sample_packing and not is_training:
|
||||||
|
dataloader_params["batch_size"] *= self.accelerator.num_processes
|
||||||
|
else:
|
||||||
|
dataloader_params["batch_size"] = self.accelerator.num_processes
|
||||||
|
elif not self.args.sample_packing and self.args.pretraining:
|
||||||
|
dataloader_params["batch_size"] *= self.accelerator.num_processes
|
||||||
|
|
||||||
if self.args.sample_packing and (
|
if self.args.sample_packing and (
|
||||||
(is_training and not self.args.pretraining)
|
(is_training and not self.args.pretraining)
|
||||||
or (not is_training and self.args.eval_sample_packing is not False)
|
or (not is_training and self.args.eval_sample_packing is not False)
|
||||||
|
|||||||
0
src/axolotl/exception_handling.py
Normal file
0
src/axolotl/exception_handling.py
Normal file
Reference in New Issue
Block a user