From f9893e3842a71176d02ec6bca45e835c72a13398 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 4 Feb 2025 08:39:37 -0500 Subject: [PATCH] fix dpo config and add use_logits_to_keep --- src/axolotl/core/trainer_builder.py | 15 +++++++++++++-- .../utils/config/models/input/v0_4_1/__init__.py | 1 + 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 2a7f9ca21..d0aad6fdc 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1014,8 +1014,19 @@ class HFRLTrainerBuilder(TrainerBuilderBase): blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() else: - training_args_cls = DPOConfig.get_training_args_class() - training_args_kwargs.update(DPOConfig.set_training_args_kwargs(self.cfg)) + training_args_cls = DPOConfig + if self.cfg.rl == "ipo": + training_args_kwargs["loss_type"] = "ipo" + training_args_kwargs["max_length"] = self.cfg.sequence_len + training_args_kwargs["max_completion_length"] = None + training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len + training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb + if self.cfg.dpo_use_weighting is not None: + training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting + if self.cfg.dpo_use_logits_to_keep is not None: + training_args_kwargs[ + "use_logits_to_keep" + ] = self.cfg.dpo_use_logits_to_keep for blocklist_key in blocklist_args_kwargs: if blocklist_key in training_args_kwargs: diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index d9df16943..b09d89bd5 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -677,6 +677,7 @@ class AxolotlInputConfig( dpo_use_weighting: Optional[ bool ] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer. + dpo_use_logits_to_keep: Optional[bool] = None datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore