From b9bb169602550979cd82460ebbcfdb7c9364cd3c Mon Sep 17 00:00:00 2001 From: Ali Mosavian Date: Fri, 3 May 2024 17:02:59 +0200 Subject: [PATCH] FIX: TRL trainer preprocessing step was running in one process (#1583) * FIX: TRL trainer preprocessing step was running in one process * FIX: Changed so that dataset_num_proc is sent to CPO, KTO and ORPO trainer args and directly to the trainer when DPO * FIX: Changed back to only support ORPO for now, since KTO is handled in another way --------- Co-authored-by: Ali Mosavian --- src/axolotl/core/trainer_builder.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 09651bdc9..bf18a287a 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1462,6 +1462,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): training_args_kwargs["eval_steps"] = self.cfg.eval_steps else: training_args_kwargs["evaluation_strategy"] = "no" + if self.cfg.bf16 or self.cfg.bfloat16: training_args_kwargs["bf16"] = True @@ -1520,6 +1521,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): training_args_cls = TrainingArguments if self.cfg.rl == "orpo": training_args_cls = ORPOConfig + training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args = training_args_cls( per_device_train_batch_size=self.cfg.micro_batch_size, @@ -1564,6 +1566,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase): dpo_trainer_kwargs["max_target_length"] = None dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len dpo_trainer_kwargs["generate_during_eval"] = True + if self.cfg.rl == "dpo": + dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes elif self.cfg.rl == "orpo": trainer_cls = AxolotlORPOTrainer trainer_cls_args = [self.model]