move deprecated kwargs from trainer to trainingargs (#2028)

This commit is contained in:
Wing Lian
2024-11-10 12:45:47 -05:00
committed by GitHub
parent e4af51eb66
commit d356740ffa

View File

@@ -1933,6 +1933,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
@@ -1956,7 +1962,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args = self.build_training_arguments(total_num_steps)
dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo":
dpo_trainer_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
if self.eval_dataset:
@@ -1970,12 +1975,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer
trainer_cls_args = [self.model, self.model_ref]
# these aren't used for the ORPO trainer
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]