use DPOConfig
This commit is contained in:
@@ -39,6 +39,6 @@ s3fs
|
|||||||
gcsfs
|
gcsfs
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl @ git+https://github.com/huggingface/trl.git@adf17a5a269a0bc59162597f81e3d489a8c144e5
|
trl @ git+https://github.com/huggingface/trl.git@7075cec94df1a0c5be90e75214e996efaf9a6c0b
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_utils import seed_worker
|
from transformers.trainer_utils import seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer, ORPOConfig, ORPOTrainer
|
from trl import DPOConfig, DPOTrainer, ORPOConfig, ORPOTrainer
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
from axolotl.loraplus import create_loraplus_optimizer
|
from axolotl.loraplus import create_loraplus_optimizer
|
||||||
@@ -1526,6 +1526,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rl == "orpo":
|
if self.cfg.rl == "orpo":
|
||||||
training_args_cls = ORPOConfig
|
training_args_cls = ORPOConfig
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]:
|
||||||
|
training_args_cls = DPOConfig
|
||||||
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
|
||||||
training_args = training_args_cls(
|
training_args = training_args_cls(
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
|
|||||||
Reference in New Issue
Block a user