move deprecated kwargs from trainer to trainingargs (#2028)
This commit is contained in:
@@ -1933,6 +1933,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
training_args_cls = AxolotlDPOConfig
|
training_args_cls = AxolotlDPOConfig
|
||||||
|
if self.cfg.rl == "ipo":
|
||||||
|
training_args_kwargs["loss_type"] = "ipo"
|
||||||
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
training_args_kwargs["max_completion_length"] = None
|
||||||
|
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||||
|
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||||
if self.cfg.dpo_use_weighting is not None:
|
if self.cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||||
|
|
||||||
@@ -1956,7 +1962,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args = self.build_training_arguments(total_num_steps)
|
training_args = self.build_training_arguments(total_num_steps)
|
||||||
dpo_trainer_kwargs = {}
|
dpo_trainer_kwargs = {}
|
||||||
if self.cfg.rl == "ipo":
|
if self.cfg.rl == "ipo":
|
||||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
|
||||||
if self.cfg.dpo_label_smoothing:
|
if self.cfg.dpo_label_smoothing:
|
||||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
@@ -1970,12 +1975,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rl in ["dpo", "ipo"]:
|
if self.cfg.rl in ["dpo", "ipo"]:
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = AxolotlDPOTrainer
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
|
|
||||||
# these aren't used for the ORPO trainer
|
|
||||||
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
|
||||||
dpo_trainer_kwargs["max_target_length"] = None
|
|
||||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
|
||||||
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
|
||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
|
|||||||
Reference in New Issue
Block a user