From cd31394e7018a827e6aaf6c7a77bfb6cd814bfe0 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 22 May 2025 18:11:17 +0700 Subject: [PATCH] feat: move epoch setting to base --- src/axolotl/core/trainer_builder/base.py | 2 ++ src/axolotl/core/trainer_builder/rl.py | 1 - src/axolotl/core/trainer_builder/sft.py | 3 +-- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axolotl/core/trainer_builder/base.py b/src/axolotl/core/trainer_builder/base.py index 35fc7d681..6c3735256 100644 --- a/src/axolotl/core/trainer_builder/base.py +++ b/src/axolotl/core/trainer_builder/base.py @@ -304,6 +304,8 @@ class TrainerBuilderBase(abc.ABC): training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1 + training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs + # max_length is not used in CausalTrainer if self.cfg.reward_model or self.cfg.rl: training_args_kwargs["max_length"] = self.cfg.sequence_len diff --git a/src/axolotl/core/trainer_builder/rl.py b/src/axolotl/core/trainer_builder/rl.py index d41462f87..c45edbe4a 100644 --- a/src/axolotl/core/trainer_builder/rl.py +++ b/src/axolotl/core/trainer_builder/rl.py @@ -156,7 +156,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if blocklist_key in training_args_kwargs: del training_args_kwargs[blocklist_key] - training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg logging_first_step=True, **training_args_kwargs, diff --git a/src/axolotl/core/trainer_builder/sft.py b/src/axolotl/core/trainer_builder/sft.py index ca3ba79f3..56b487248 100644 --- a/src/axolotl/core/trainer_builder/sft.py +++ b/src/axolotl/core/trainer_builder/sft.py @@ -148,7 +148,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): trainer_kwargs = {} # Pop optimizer_cls_and_kwargs to trainer_kwargs - if "optimizer_cls_and_kwargs" in trainer_kwargs: + if "optimizer_cls_and_kwargs" in training_arguments_kwargs: trainer_kwargs["optimizer_cls_and_kwargs"] = training_arguments_kwargs.pop( "optimizer_cls_and_kwargs" ) @@ -219,7 +219,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["eval_accumulation_steps"] = ( self.cfg.gradient_accumulation_steps ) - training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs training_arguments_kwargs["load_best_model_at_end"] = ( (