From f58fcd09ec84a706947ad77d5f49ee72c4c80460 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 2 May 2024 13:44:26 -0400 Subject: [PATCH] use DPOConfig --- requirements.txt | 2 +- src/axolotl/core/trainer_builder.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 19eb13d19..4ec2aec89 100644 --- a/requirements.txt +++ b/requirements.txt @@ -39,6 +39,6 @@ s3fs gcsfs # adlfs -trl @ git+https://github.com/huggingface/trl.git@adf17a5a269a0bc59162597f81e3d489a8c144e5 +trl @ git+https://github.com/huggingface/trl.git@7075cec94df1a0c5be90e75214e996efaf9a6c0b zstandard==0.22.0 fastcore diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 55eecf839..576c303b7 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -30,7 +30,7 @@ from transformers import ( ) from transformers.trainer_utils import seed_worker from transformers.utils import is_sagemaker_mp_enabled -from trl import DPOTrainer, ORPOConfig, ORPOTrainer +from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer from trl.trainer.utils import pad_to_length from axolotl.loraplus import create_loraplus_optimizer @@ -1526,6 +1526,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.rl == "orpo": training_args_cls = ORPOConfig training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes + elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]: + training_args_cls = DPOConfig + training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args = training_args_cls( per_device_train_batch_size=self.cfg.micro_batch_size,