fix config cls

This commit is contained in:
Wing Lian
2025-02-04 09:08:08 -05:00
parent 234cd8311e
commit d683c50113

View File

@@ -35,7 +35,6 @@ from transformers import (
EarlyStoppingCallback, EarlyStoppingCallback,
TrainerCallback, TrainerCallback,
) )
from trl import DPOConfig
from trl.trainer.utils import RewardDataCollatorWithPadding from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import ( from axolotl.core.trainers.base import (
@@ -49,6 +48,7 @@ from axolotl.core.trainers.base import (
ReLoRATrainer, ReLoRATrainer,
) )
from axolotl.core.trainers.dpo import DPOStrategy from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.training_args import ( from axolotl.core.training_args import (
AxolotlCPOConfig, AxolotlCPOConfig,
@@ -1014,7 +1014,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
else: else:
training_args_cls = DPOConfig training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo": if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo" training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["max_length"] = self.cfg.sequence_len