Fix DPO config and add the dpo_use_logits_to_keep option (passed to the trainer as use_logits_to_keep)
This commit is contained in:
@@ -1014,8 +1014,19 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||
|
||||
else:
|
||||
training_args_cls = DPOConfig.get_training_args_class()
|
||||
training_args_kwargs.update(DPOConfig.set_training_args_kwargs(self.cfg))
|
||||
training_args_cls = DPOConfig
|
||||
if self.cfg.rl == "ipo":
|
||||
training_args_kwargs["loss_type"] = "ipo"
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||
if self.cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||
if self.cfg.dpo_use_logits_to_keep is not None:
|
||||
training_args_kwargs[
|
||||
"use_logits_to_keep"
|
||||
] = self.cfg.dpo_use_logits_to_keep
|
||||
|
||||
for blocklist_key in blocklist_args_kwargs:
|
||||
if blocklist_key in training_args_kwargs:
|
||||
|
||||
@@ -677,6 +677,7 @@ class AxolotlInputConfig(
|
||||
dpo_use_weighting: Optional[
|
||||
bool
|
||||
] = None # Whether to use weighting in the DPO trainer. If None, the option is not passed and the trainer's own default (False) applies.
|
||||
dpo_use_logits_to_keep: Optional[bool] = None
|
||||
|
||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
||||
|
||||
Reference in New Issue
Block a user