Fix config class: use AxolotlDPOConfig instead of trl's DPOConfig as the DPO training-args class

This commit is contained in:
Wing Lian
2025-02-04 09:08:08 -05:00
parent 234cd8311e
commit d683c50113

View File

@@ -35,7 +35,6 @@ from transformers import (
EarlyStoppingCallback,
TrainerCallback,
)
from trl import DPOConfig
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import (
@@ -49,6 +48,7 @@ from axolotl.core.trainers.base import (
ReLoRATrainer,
)
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.training_args import (
AxolotlCPOConfig,
@@ -1014,7 +1014,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
else:
training_args_cls = DPOConfig
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len