diff --git a/src/axolotl/core/trainer_builder/base.py b/src/axolotl/core/trainer_builder/base.py
index 35fc7d681..6c3735256 100644
--- a/src/axolotl/core/trainer_builder/base.py
+++ b/src/axolotl/core/trainer_builder/base.py
@@ -304,6 +304,8 @@ class TrainerBuilderBase(abc.ABC):
 
         training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
 
+        training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
+
         # max_length is not used in CausalTrainer
         if self.cfg.reward_model or self.cfg.rl:
             training_args_kwargs["max_length"] = self.cfg.sequence_len
diff --git a/src/axolotl/core/trainer_builder/rl.py b/src/axolotl/core/trainer_builder/rl.py
index d41462f87..c45edbe4a 100644
--- a/src/axolotl/core/trainer_builder/rl.py
+++ b/src/axolotl/core/trainer_builder/rl.py
@@ -156,7 +156,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             if blocklist_key in training_args_kwargs:
                 del training_args_kwargs[blocklist_key]
 
-        training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
         training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
             logging_first_step=True,
             **training_args_kwargs,
diff --git a/src/axolotl/core/trainer_builder/sft.py b/src/axolotl/core/trainer_builder/sft.py
index ca3ba79f3..56b487248 100644
--- a/src/axolotl/core/trainer_builder/sft.py
+++ b/src/axolotl/core/trainer_builder/sft.py
@@ -148,7 +148,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         trainer_kwargs = {}
 
         # Pop optimizer_cls_and_kwargs to trainer_kwargs
-        if "optimizer_cls_and_kwargs" in trainer_kwargs:
+        if "optimizer_cls_and_kwargs" in training_arguments_kwargs:
             trainer_kwargs["optimizer_cls_and_kwargs"] = training_arguments_kwargs.pop(
                 "optimizer_cls_and_kwargs"
             )
@@ -219,7 +219,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["eval_accumulation_steps"] = (
             self.cfg.gradient_accumulation_steps
         )
-        training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
         training_arguments_kwargs["load_best_model_at_end"] = (
             (