diff --git a/requirements.txt b/requirements.txt index 9bf627619..6f1d40b24 100644 --- a/requirements.txt +++ b/requirements.txt @@ -43,7 +43,7 @@ s3fs>=2024.5.0 gcsfs>=2024.5.0 # adlfs -trl==0.12.0 +trl @ git+https://github.com/huggingface/trl.git@5e90682836969310e16ed8aa711dd429f85863b7 zstandard==0.22.0 fastcore diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index e6a7b48cd..78b90a141 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1926,16 +1926,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase): self.cfg.kto_undesirable_weight or 1.0 ) - training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_kwargs["max_length"] = self.cfg.sequence_len if self.cfg.max_prompt_len: training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len else: training_args_cls = AxolotlDPOConfig + + training_args_kwargs["max_length"] = self.cfg.sequence_len + + training_args_kwargs["max_target_length"] = None + if self.cfg.max_prompt_len is not None: + training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len + if self.cfg.dpo_use_weighting is not None: training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting + if self.cfg.rl == "ipo": + training_args_kwargs["loss_type"] = "ipo" + if self.cfg.dpo_label_smoothing: + training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing + + if self.cfg.precompute_ref_log_probs is not None: + training_args_kwargs["precompute_ref_log_probs"] = self.cfg.precompute_ref_log_probs + + training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb + training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg output_dir=self.cfg.output_dir, per_device_train_batch_size=self.cfg.micro_batch_size, @@ -1955,27 +1971,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase): def build(self, total_num_steps): training_args = self.build_training_arguments(total_num_steps) dpo_trainer_kwargs = {} - if
self.cfg.rl == "ipo": - dpo_trainer_kwargs["loss_type"] = "ipo" - if self.cfg.dpo_label_smoothing: - dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing + if self.eval_dataset: dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset if self.cfg.adapter and self.peft_config: dpo_trainer_kwargs["peft_config"] = self.peft_config - if self.cfg.precompute_ref_log_probs is not None: - dpo_trainer_kwargs[ - "precompute_ref_log_probs" - ] = self.cfg.precompute_ref_log_probs + if self.cfg.rl in ["dpo", "ipo"]: trainer_cls = AxolotlDPOTrainer trainer_cls_args = [self.model, self.model_ref] - # these aren't used for the ORPO trainer - dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len - dpo_trainer_kwargs["max_target_length"] = None - dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len - dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb elif self.cfg.rl == "orpo": trainer_cls = AxolotlORPOTrainer trainer_cls_args = [self.model]