fix: optimizer cls not being popped

NanoCode012
2025-05-22 18:07:07 +07:00
parent c6e730df64
commit 66d4319d80
2 changed files with 11 additions and 15 deletions

View File

@@ -175,12 +175,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         trainer_kwargs = {}
         # Pop optimizer_cls_and_kwargs to trainer_kwargs
-        if hasattr(training_args, "optimizer_cls_and_kwargs"):
-            trainer_kwargs["optimizer_cls_and_kwargs"] = getattr(
-                training_args, "optimizer_cls_and_kwargs"
+        if "optimizer_cls_and_kwargs" in training_args:
+            trainer_kwargs["optimizer_cls_and_kwargs"] = training_args.pop(
+                "optimizer_cls_and_kwargs"
             )
-            # prevent duplication downstream
-            delattr(training_args, "optimizer_cls_and_kwargs")
         if self.cfg.rl is RLType.IPO:
             if self.cfg.dpo_label_smoothing:
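
Why the old branch never fired: training_args at this point is a plain dict of keyword arguments (the replacement `in` / `.pop()` calls only make sense on a mapping), and hasattr() looks up attributes, not dict keys, so the check was always False and the optimizer class was never moved over. A minimal, self-contained sketch of the failure mode; the names are illustrative, not taken from the repo:

    class DummyOptimizer:  # hypothetical stand-in for a torch optimizer class
        pass

    training_args = {"optimizer_cls_and_kwargs": (DummyOptimizer, {"lr": 1e-4})}
    trainer_kwargs = {}

    # Old check: dict keys are not attributes, so this is always False and
    # the whole branch (including the delattr() cleanup) was dead code.
    print(hasattr(training_args, "optimizer_cls_and_kwargs"))  # False

    # New check: `in` tests the dict's keys, and pop() returns the value
    # while deleting the key in one step, so no separate cleanup is needed
    # to prevent duplication downstream.
    if "optimizer_cls_and_kwargs" in training_args:
        trainer_kwargs["optimizer_cls_and_kwargs"] = training_args.pop(
            "optimizer_cls_and_kwargs"
        )

    print(trainer_kwargs)  # {'optimizer_cls_and_kwargs': (<class 'DummyOptimizer'>, {'lr': 0.0001})}
    print(training_args)   # {} -- key removed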

View File

@@ -145,6 +145,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
     def build(self, total_num_steps):
         training_arguments_kwargs = self._set_base_training_args(total_num_steps)
         trainer_kwargs = {}
+        # Pop optimizer_cls_and_kwargs to trainer_kwargs
+        if "optimizer_cls_and_kwargs" in training_arguments_kwargs:
+            trainer_kwargs["optimizer_cls_and_kwargs"] = training_arguments_kwargs.pop(
+                "optimizer_cls_and_kwargs"
+            )
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
         if self.cfg.fsdp_config:
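
(Note: the rendered hunk tested membership against trainer_kwargs, the empty dict created one line earlier, which can never contain the key; the condition above checks training_arguments_kwargs, matching the comment's intent and the pop target.) This moves the pop to the top of build(), before training_arguments_kwargs is splatted into the selected training-args class (training_args_cls below), which would otherwise reject the unknown key. A rough sketch of that ordering constraint, using hypothetical stand-in names rather than axolotl's real classes:

    from dataclasses import dataclass

    @dataclass
    class FakeTrainingArgs:  # hypothetical stand-in for the real training-args class
        output_dir: str

    training_arguments_kwargs = {
        "output_dir": "./out",
        "optimizer_cls_and_kwargs": ("AdamW", {"lr": 1e-4}),
    }
    trainer_kwargs = {}

    # Pop before constructing: leaving the key in place would raise
    # TypeError: unexpected keyword argument 'optimizer_cls_and_kwargs'.
    if "optimizer_cls_and_kwargs" in training_arguments_kwargs:
        trainer_kwargs["optimizer_cls_and_kwargs"] = training_arguments_kwargs.pop(
            "optimizer_cls_and_kwargs"
        )

    args = FakeTrainingArgs(**training_arguments_kwargs)  # only known keys remain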
@@ -323,16 +331,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 self.cfg.kd_top_k_before_softmax
             )
         trainer_kwargs = {}
-        # Pop optimizer_cls_and_kwargs to trainer_kwargs
-        if hasattr(training_arguments_kwargs, "optimizer_cls_and_kwargs"):
-            trainer_kwargs["optimizer_cls_and_kwargs"] = getattr(
-                training_arguments_kwargs, "optimizer_cls_and_kwargs"
-            )
-            # prevent duplication downstream
-            delattr(training_arguments_kwargs, "optimizer_cls_and_kwargs")
         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig
         elif self.cfg.process_reward_model:
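
For context on where the popped tuple ends up: recent transformers releases allow Trainer to accept an optimizer_cls_and_kwargs argument, a pair of (optimizer class, init kwargs) that it instantiates in place of its default optimizer, which is why the value belongs in trainer_kwargs rather than in the training arguments. A minimal sketch with a trivial stand-in model, not axolotl's actual wiring, and assuming a transformers version whose Trainer supports this parameter:

    import torch
    from transformers import Trainer, TrainingArguments

    model = torch.nn.Linear(4, 4)  # trivial stand-in for a real model

    trainer_kwargs = {
        "optimizer_cls_and_kwargs": (torch.optim.AdamW, {"lr": 1e-4}),
    }

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="./out"),
        **trainer_kwargs,
    )
    # When the optimizer is first needed, Trainer builds
    # torch.optim.AdamW(model.parameters(), lr=1e-4)
    # instead of constructing its default optimizer.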