add support for rpo_alpha (#1681)

* add support for rpo_alpha

* Add smoke test for dpo + nll loss
This commit is contained in:
Wing Lian
2024-06-04 16:09:51 -04:00
committed by GitHub
parent 1f151c0d52
commit c996881ec2
4 changed files with 58 additions and 3 deletions

View File

@@ -30,7 +30,7 @@ from transformers import (
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer
@@ -238,6 +238,13 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
    """
    Training-argument container for DPO (Direct Preference Optimization) runs:
    combines axolotl's shared `AxolotlTrainingMixins` options with trl's
    `DPOConfig` (which carries DPO-specific knobs such as `rpo_alpha`).
    """
@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
"""
@@ -1608,7 +1615,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# NOTE: trl reuses its `beta` field to carry the ORPO alpha value, so we map orpo_alpha -> beta here
training_args_kwargs["beta"] = self.cfg.orpo_alpha
training_args_cls = AxolotlTrainingArguments
training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes

View File

@@ -619,6 +619,7 @@ class AxolotlInputConfig(
neftune_noise_alpha: Optional[float] = None
orpo_alpha: Optional[float] = None
rpo_alpha: Optional[float] = None
kto_desirable_weight: Optional[float] = None
kto_undesirable_weight: Optional[float] = None