fix dpo config and add use_logits_to_keep

This commit is contained in:
Wing Lian
2025-02-04 08:39:37 -05:00
parent ac1ebc58a8
commit f9893e3842
2 changed files with 14 additions and 2 deletions

View File

@@ -1014,8 +1014,19 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
else:
training_args_cls = DPOConfig.get_training_args_class()
training_args_kwargs.update(DPOConfig.set_training_args_kwargs(self.cfg))
training_args_cls = DPOConfig
if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
if self.cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs[
"use_logits_to_keep"
] = self.cfg.dpo_use_logits_to_keep
for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs:

View File

@@ -677,6 +677,7 @@ class AxolotlInputConfig(
dpo_use_weighting: Optional[
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
dpo_use_logits_to_keep: Optional[bool] = None
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore