diff --git a/src/axolotl/core/trainer_builder/rl.py b/src/axolotl/core/trainer_builder/rl.py
index df2b32904..b584d3e88 100644
--- a/src/axolotl/core/trainer_builder/rl.py
+++ b/src/axolotl/core/trainer_builder/rl.py
@@ -175,10 +175,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         trainer_kwargs = {}
 
         # Pop optimizer_cls_and_kwargs to trainer_kwargs
-        if "optimizer_cls_and_kwargs" in training_args:
-            trainer_kwargs["optimizer_cls_and_kwargs"] = training_args.pop(
-                "optimizer_cls_and_kwargs"
+        if hasattr(training_args, "optimizer_cls_and_kwargs"):
+            trainer_kwargs["optimizer_cls_and_kwargs"] = getattr(
+                training_args, "optimizer_cls_and_kwargs"
             )
+            # prevent duplication downstream
+            delattr(training_args, "optimizer_cls_and_kwargs")
 
         if self.cfg.rl is RLType.IPO:
             if self.cfg.dpo_label_smoothing:
diff --git a/src/axolotl/core/trainer_builder/sft.py b/src/axolotl/core/trainer_builder/sft.py
index 0303f5dae..27750df15 100644
--- a/src/axolotl/core/trainer_builder/sft.py
+++ b/src/axolotl/core/trainer_builder/sft.py
@@ -346,10 +346,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         trainer_kwargs = {}
 
-        # Pop optimizer_cls_and_kwargs to trainer_kwargs
+        # Pop optimizer_cls_and_kwargs to trainer_kwargs; training_arguments_kwargs
+        # is a plain dict here, so key-based membership/pop (not hasattr/delattr,
+        # which check object attributes and never dict keys) is required.
         if "optimizer_cls_and_kwargs" in training_arguments_kwargs:
             trainer_kwargs["optimizer_cls_and_kwargs"] = training_arguments_kwargs.pop(
                 "optimizer_cls_and_kwargs"
             )
 
         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig