diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 09651bdc9..bf18a287a 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1462,6 +1462,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): training_args_kwargs["eval_steps"] = self.cfg.eval_steps else: training_args_kwargs["evaluation_strategy"] = "no" + if self.cfg.bf16 or self.cfg.bfloat16: training_args_kwargs["bf16"] = True @@ -1520,6 +1521,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): training_args_cls = TrainingArguments if self.cfg.rl == "orpo": training_args_cls = ORPOConfig + training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args = training_args_cls( per_device_train_batch_size=self.cfg.micro_batch_size, @@ -1564,6 +1566,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase): dpo_trainer_kwargs["max_target_length"] = None dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len dpo_trainer_kwargs["generate_during_eval"] = True + if self.cfg.rl == "dpo": + dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes elif self.cfg.rl == "orpo": trainer_cls = AxolotlORPOTrainer trainer_cls_args = [self.model]