From d683c50113d4d00b43446cebc5092538470192ef Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 4 Feb 2025 09:08:08 -0500 Subject: [PATCH] fix config cls --- src/axolotl/core/trainer_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index d0aad6fdc..0891ee96a 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -35,7 +35,6 @@ from transformers import ( EarlyStoppingCallback, TrainerCallback, ) -from trl import DPOConfig from trl.trainer.utils import RewardDataCollatorWithPadding from axolotl.core.trainers.base import ( @@ -49,6 +48,7 @@ from axolotl.core.trainers.base import ( ReLoRATrainer, ) from axolotl.core.trainers.dpo import DPOStrategy +from axolotl.core.trainers.dpo.args import AxolotlDPOConfig from axolotl.core.trainers.grpo import GRPOStrategy from axolotl.core.training_args import ( AxolotlCPOConfig, @@ -1014,7 +1014,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() else: - training_args_cls = DPOConfig + training_args_cls = AxolotlDPOConfig if self.cfg.rl == "ipo": training_args_kwargs["loss_type"] = "ipo" training_args_kwargs["max_length"] = self.cfg.sequence_len