fix: address pr feedback
This commit is contained in:
@@ -175,10 +175,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
|
|
||||||
# Pop optimizer_cls_and_kwargs to trainer_kwargs
|
# Pop optimizer_cls_and_kwargs to trainer_kwargs
|
||||||
if "optimizer_cls_and_kwargs" in training_args:
|
if hasattr(training_args, "optimizer_cls_and_kwargs"):
|
||||||
trainer_kwargs["optimizer_cls_and_kwargs"] = training_args.pop(
|
trainer_kwargs["optimizer_cls_and_kwargs"] = getattr(
|
||||||
"optimizer_cls_and_kwargs"
|
training_args, "optimizer_cls_and_kwargs"
|
||||||
)
|
)
|
||||||
|
# prevent duplication downstream
|
||||||
|
delattr(training_args, "optimizer_cls_and_kwargs")
|
||||||
|
|
||||||
if self.cfg.rl is RLType.IPO:
|
if self.cfg.rl is RLType.IPO:
|
||||||
if self.cfg.dpo_label_smoothing:
|
if self.cfg.dpo_label_smoothing:
|
||||||
|
|||||||
@@ -346,10 +346,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
|
|
||||||
# Pop optimizer_cls_and_kwargs to trainer_kwargs
|
# Pop optimizer_cls_and_kwargs to trainer_kwargs
|
||||||
if "optimizer_cls_and_kwargs" in training_arguments_kwargs:
|
if hasattr(training_arguments_kwargs, "optimizer_cls_and_kwargs"):
|
||||||
trainer_kwargs["optimizer_cls_and_kwargs"] = training_arguments_kwargs.pop(
|
trainer_kwargs["optimizer_cls_and_kwargs"] = getattr(
|
||||||
"optimizer_cls_and_kwargs"
|
training_arguments_kwargs, "optimizer_cls_and_kwargs"
|
||||||
)
|
)
|
||||||
|
# prevent duplication downstream
|
||||||
|
delattr(training_arguments_kwargs, "optimizer_cls_and_kwargs")
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
|
|||||||
Reference in New Issue
Block a user