fix: optimizer_cls_and_kwargs not being popped from the training args into trainer_kwargs
This commit is contained in:
@@ -175,12 +175,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
|
|
||||||
# Pop optimizer_cls_and_kwargs to trainer_kwargs
|
# Pop optimizer_cls_and_kwargs to trainer_kwargs
|
||||||
if hasattr(training_args, "optimizer_cls_and_kwargs"):
|
if "optimizer_cls_and_kwargs" in training_args:
|
||||||
trainer_kwargs["optimizer_cls_and_kwargs"] = getattr(
|
trainer_kwargs["optimizer_cls_and_kwargs"] = training_args.pop(
|
||||||
training_args, "optimizer_cls_and_kwargs"
|
"optimizer_cls_and_kwargs"
|
||||||
)
|
)
|
||||||
# prevent duplication downstream
|
|
||||||
delattr(training_args, "optimizer_cls_and_kwargs")
|
|
||||||
|
|
||||||
if self.cfg.rl is RLType.IPO:
|
if self.cfg.rl is RLType.IPO:
|
||||||
if self.cfg.dpo_label_smoothing:
|
if self.cfg.dpo_label_smoothing:
|
||||||
|
|||||||
@@ -145,6 +145,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
training_arguments_kwargs = self._set_base_training_args(total_num_steps)
|
training_arguments_kwargs = self._set_base_training_args(total_num_steps)
|
||||||
|
|
||||||
|
trainer_kwargs = {}
|
||||||
|
|
||||||
|
# Pop optimizer_cls_and_kwargs to trainer_kwargs
|
||||||
|
if "optimizer_cls_and_kwargs" in trainer_kwargs:
|
||||||
|
trainer_kwargs["optimizer_cls_and_kwargs"] = training_arguments_kwargs.pop(
|
||||||
|
"optimizer_cls_and_kwargs"
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp:
|
||||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||||
if self.cfg.fsdp_config:
|
if self.cfg.fsdp_config:
|
||||||
@@ -323,16 +331,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.kd_top_k_before_softmax
|
self.cfg.kd_top_k_before_softmax
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer_kwargs = {}
|
|
||||||
|
|
||||||
# Pop optimizer_cls_and_kwargs to trainer_kwargs
|
|
||||||
if hasattr(training_arguments_kwargs, "optimizer_cls_and_kwargs"):
|
|
||||||
trainer_kwargs["optimizer_cls_and_kwargs"] = getattr(
|
|
||||||
training_arguments_kwargs, "optimizer_cls_and_kwargs"
|
|
||||||
)
|
|
||||||
# prevent duplication downstream
|
|
||||||
delattr(training_arguments_kwargs, "optimizer_cls_and_kwargs")
|
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
elif self.cfg.process_reward_model:
|
elif self.cfg.process_reward_model:
|
||||||
|
|||||||
Reference in New Issue
Block a user