diff --git a/src/axolotl/core/trainer_builder/rl.py b/src/axolotl/core/trainer_builder/rl.py
index e36af0f71..d41462f87 100644
--- a/src/axolotl/core/trainer_builder/rl.py
+++ b/src/axolotl/core/trainer_builder/rl.py
@@ -175,12 +175,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         trainer_kwargs = {}
 
         # Pop optimizer_cls_and_kwargs to trainer_kwargs
-        if hasattr(training_args, "optimizer_cls_and_kwargs"):
-            trainer_kwargs["optimizer_cls_and_kwargs"] = getattr(
-                training_args, "optimizer_cls_and_kwargs"
+        if "optimizer_cls_and_kwargs" in training_args:
+            trainer_kwargs["optimizer_cls_and_kwargs"] = training_args.pop(
+                "optimizer_cls_and_kwargs"
             )
-            # prevent duplication downstream
-            delattr(training_args, "optimizer_cls_and_kwargs")
 
         if self.cfg.rl is RLType.IPO:
             if self.cfg.dpo_label_smoothing:
diff --git a/src/axolotl/core/trainer_builder/sft.py b/src/axolotl/core/trainer_builder/sft.py
index 22b4e45fa..ca3ba79f3 100644
--- a/src/axolotl/core/trainer_builder/sft.py
+++ b/src/axolotl/core/trainer_builder/sft.py
@@ -145,6 +145,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
     def build(self, total_num_steps):
         training_arguments_kwargs = self._set_base_training_args(total_num_steps)
 
+        trainer_kwargs = {}
+
+        # Move optimizer_cls_and_kwargs into trainer_kwargs so it is not passed twice
+        if "optimizer_cls_and_kwargs" in training_arguments_kwargs:
+            trainer_kwargs["optimizer_cls_and_kwargs"] = training_arguments_kwargs.pop(
+                "optimizer_cls_and_kwargs"
+            )
+
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
         if self.cfg.fsdp_config:
@@ -323,16 +331,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 self.cfg.kd_top_k_before_softmax
             )
 
-        trainer_kwargs = {}
-
-        # Pop optimizer_cls_and_kwargs to trainer_kwargs
-        if hasattr(training_arguments_kwargs, "optimizer_cls_and_kwargs"):
-            trainer_kwargs["optimizer_cls_and_kwargs"] = getattr(
-                training_arguments_kwargs, "optimizer_cls_and_kwargs"
-            )
-            # prevent duplication downstream
-            delattr(training_arguments_kwargs, "optimizer_cls_and_kwargs")
-
         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig
         elif self.cfg.process_reward_model: