diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index e399cf3c5..eed43e542 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -386,8 +386,10 @@ class TrainerBuilderBase(abc.ABC): elif self.cfg.eval_steps: training_args_kwargs["eval_strategy"] = "steps" training_args_kwargs["eval_steps"] = self.cfg.eval_steps + training_args_kwargs["eval_on_start"] = True elif self.cfg.eval_strategy: training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy + training_args_kwargs["eval_on_start"] = True def _configure_reporting(self, training_args_kwargs: dict): report_to = [] diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 6ed298d9f..47e33a332 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -27,7 +27,6 @@ from axolotl.monkeypatch.relora import ReLoRACallback from axolotl.processing_strategies import get_processing_strategy from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.callbacks import ( - EvalFirstStepCallback, LossWatchDogCallback, SaveBetterTransformerModelCallback, bench_eval_callback_factory, @@ -58,7 +57,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): def get_callbacks(self): callbacks = super().get_callbacks() - callbacks.append(EvalFirstStepCallback()) if self.cfg.relora_steps: callbacks.append(ReLoRACallback(self.cfg)) diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 8b8a77611..2a93ceef5 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -53,25 +53,6 @@ IGNORE_INDEX = -100 LOG = get_logger(__name__) -class EvalFirstStepCallback( - TrainerCallback -): # pylint: disable=too-few-public-methods disable=unused-argument - """ - Callback to trigger evals on the first step - """ - - def on_step_end( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - **kwargs, - ): - if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1: - control.should_evaluate = True - return control - - class SaveBetterTransformerModelCallback( TrainerCallback ): # pylint: disable=too-few-public-methods