fix: address pr feedback
This commit is contained in:
@@ -175,10 +175,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_kwargs = {}
|
||||
|
||||
# Pop optimizer_cls_and_kwargs to trainer_kwargs
|
||||
if "optimizer_cls_and_kwargs" in training_args:
|
||||
trainer_kwargs["optimizer_cls_and_kwargs"] = training_args.pop(
|
||||
"optimizer_cls_and_kwargs"
|
||||
if hasattr(training_args, "optimizer_cls_and_kwargs"):
|
||||
trainer_kwargs["optimizer_cls_and_kwargs"] = getattr(
|
||||
training_args, "optimizer_cls_and_kwargs"
|
||||
)
|
||||
# prevent duplication downstream
|
||||
delattr(training_args, "optimizer_cls_and_kwargs")
|
||||
|
||||
if self.cfg.rl is RLType.IPO:
|
||||
if self.cfg.dpo_label_smoothing:
|
||||
|
||||
@@ -346,10 +346,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_kwargs = {}
|
||||
|
||||
# Pop optimizer_cls_and_kwargs to trainer_kwargs
|
||||
if "optimizer_cls_and_kwargs" in training_arguments_kwargs:
|
||||
trainer_kwargs["optimizer_cls_and_kwargs"] = training_arguments_kwargs.pop(
|
||||
"optimizer_cls_and_kwargs"
|
||||
if hasattr(training_arguments_kwargs, "optimizer_cls_and_kwargs"):
|
||||
trainer_kwargs["optimizer_cls_and_kwargs"] = getattr(
|
||||
training_arguments_kwargs, "optimizer_cls_and_kwargs"
|
||||
)
|
||||
# prevent duplication downstream
|
||||
delattr(training_arguments_kwargs, "optimizer_cls_and_kwargs")
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_args_cls = AxolotlRewardConfig
|
||||
|
||||
Reference in New Issue
Block a user