feat: remove evalfirst callback with built-in trainer arg (#2797)

This commit is contained in:
NanoCode012
2025-06-17 09:09:33 -07:00
committed by GitHub
parent ccc94da8ad
commit d8e8cd8558
3 changed files with 2 additions and 21 deletions

View File

@@ -386,8 +386,10 @@ class TrainerBuilderBase(abc.ABC):
elif self.cfg.eval_steps:
training_args_kwargs["eval_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
training_args_kwargs["eval_on_start"] = True
elif self.cfg.eval_strategy:
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
training_args_kwargs["eval_on_start"] = True
def _configure_reporting(self, training_args_kwargs: dict):
report_to = []

View File

@@ -27,7 +27,6 @@ from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
@@ -58,7 +57,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks.append(EvalFirstStepCallback())
if self.cfg.relora_steps:
callbacks.append(ReLoRACallback(self.cfg))

View File

@@ -53,25 +53,6 @@ IGNORE_INDEX = -100
LOG = get_logger(__name__)
class EvalFirstStepCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods disable=unused-argument
"""
Callback to trigger evals on the first step
"""
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
control.should_evaluate = True
return control
class SaveBetterTransformerModelCallback(
TrainerCallback
): # pylint: disable=too-few-public-methods