use DPOConfig

This commit is contained in:
Wing Lian
2024-05-02 13:44:26 -04:00
parent 60fecac367
commit f58fcd09ec
2 changed files with 5 additions and 2 deletions

View File

@@ -39,6 +39,6 @@ s3fs
gcsfs
# adlfs
trl @ git+https://github.com/huggingface/trl.git@adf17a5a269a0bc59162597f81e3d489a8c144e5
trl @ git+https://github.com/huggingface/trl.git@7075cec94df1a0c5be90e75214e996efaf9a6c0b
zstandard==0.22.0
fastcore

View File

@@ -30,7 +30,7 @@ from transformers import (
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer, ORPOConfig, ORPOTrainer
from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer
from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer
@@ -1526,6 +1526,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl == "orpo":
training_args_cls = ORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]:
training_args_cls = DPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args = training_args_cls(
per_device_train_batch_size=self.cfg.micro_batch_size,