fix: address PR feedback

This commit is contained in:
NanoCode012
2025-05-16 14:36:38 +07:00
parent 00bfdb6b2b
commit 49888eccb9
2 changed files with 10 additions and 6 deletions

View File

@@ -175,10 +175,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
# Pop optimizer_cls_and_kwargs to trainer_kwargs
if "optimizer_cls_and_kwargs" in training_args:
trainer_kwargs["optimizer_cls_and_kwargs"] = training_args.pop(
"optimizer_cls_and_kwargs"
if hasattr(training_args, "optimizer_cls_and_kwargs"):
trainer_kwargs["optimizer_cls_and_kwargs"] = getattr(
training_args, "optimizer_cls_and_kwargs"
)
# prevent duplication downstream
delattr(training_args, "optimizer_cls_and_kwargs")
if self.cfg.rl is RLType.IPO:
if self.cfg.dpo_label_smoothing:

View File

@@ -346,10 +346,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
# Pop optimizer_cls_and_kwargs to trainer_kwargs
if "optimizer_cls_and_kwargs" in training_arguments_kwargs:
trainer_kwargs["optimizer_cls_and_kwargs"] = training_arguments_kwargs.pop(
"optimizer_cls_and_kwargs"
if hasattr(training_arguments_kwargs, "optimizer_cls_and_kwargs"):
trainer_kwargs["optimizer_cls_and_kwargs"] = getattr(
training_arguments_kwargs, "optimizer_cls_and_kwargs"
)
# prevent duplication downstream
delattr(training_arguments_kwargs, "optimizer_cls_and_kwargs")
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig