diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index 9dba48b88..9cceb1bf8 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -41,6 +41,7 @@ from axolotl.utils.callbacks import (
     GCCallback,
     SaveAxolotlConfigtoWandBCallback,
     SaveModelOnFirstStepCallback,
+    SkipEvalOnResumeCallback,
 )
 from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
 from axolotl.utils.distributed import build_parallelism_config
@@ -118,6 +119,9 @@ class TrainerBuilderBase(abc.ABC):
             plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
         )
 
+        if self.cfg.resume_from_checkpoint:
+            callbacks.append(SkipEvalOnResumeCallback())
+
         if self.cfg.gc_steps:
             callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
 
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index afdb7f2a2..8137bac0c 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -98,6 +98,56 @@ class SaveModelOnFirstStepCallback(TrainerCallback):
         return control
 
 
+class SkipEvalOnResumeCallback(TrainerCallback):
+    """Skip the redundant evaluation that fires when resuming from a checkpoint
+    whose step aligns with ``eval_steps``.
+
+    When HuggingFace Trainer resumes, it restores ``global_step`` from the
+    checkpoint and immediately triggers ``_maybe_log_save_evaluate`` for that
+    step. Because the evaluation was already performed during the original
+    run, repeating it wastes time and pollutes metric logs.
+
+    This callback records the ``global_step`` at the start of training (i.e.
+    the checkpoint step when resuming, or 0 for a fresh run) and suppresses
+    any evaluation request on that exact step.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self._resume_step: int | None = None
+
+    def on_train_begin(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **_kwargs,
+    ):
+        # ``global_step`` is already restored from the checkpoint at this
+        # point. For a fresh run it will be 0, so the guard below becomes a
+        # no-op.
+        self._resume_step = state.global_step
+
+    def on_step_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **_kwargs,
+    ) -> TrainerControl:
+        if (
+            self._resume_step
+            and state.global_step <= self._resume_step
+            and control.should_evaluate
+        ):
+            LOG.info(
+                "Skipping evaluation at step %d (already completed before resume)",
+                state.global_step,
+            )
+            control.should_evaluate = False
+        return control
+
+
 def bench_eval_callback_factory(trainer, tokenizer):
     accuracy = evaluate.load("accuracy")
     abcd_idx = [
diff --git a/tests/utils/callbacks/test_skip_eval_on_resume.py b/tests/utils/callbacks/test_skip_eval_on_resume.py
new file mode 100644
index 000000000..55cf27438
--- /dev/null
+++ b/tests/utils/callbacks/test_skip_eval_on_resume.py
@@ -0,0 +1,63 @@
+"""Tests for SkipEvalOnResumeCallback."""
+
+from unittest.mock import MagicMock
+
+from transformers import TrainerControl, TrainerState, TrainingArguments
+
+from axolotl.utils.callbacks import SkipEvalOnResumeCallback
+
+
+class TestSkipEvalOnResumeCallback:
+    """Tests for skipping redundant evaluation on checkpoint resume."""
+
+    @staticmethod
+    def _make_state(global_step: int) -> TrainerState:
+        state = MagicMock(spec=TrainerState)
+        state.global_step = global_step
+        return state
+
+    def test_suppresses_eval_at_resume_step(self):
+        cb = SkipEvalOnResumeCallback()
+        args = MagicMock(spec=TrainingArguments)
+        state = self._make_state(20)
+        control = TrainerControl(should_evaluate=False)
+
+        # Simulate on_train_begin at checkpoint-20
+        cb.on_train_begin(args, state, control)
+
+        # Trainer sets should_evaluate = True for step 20
+        control.should_evaluate = True
+        result = cb.on_step_end(args, state, control)
+
+        assert result.should_evaluate is False
+
+    def test_allows_eval_after_resume_step(self):
+        cb = SkipEvalOnResumeCallback()
+        args = MagicMock(spec=TrainingArguments)
+        state = self._make_state(20)
+        control = TrainerControl(should_evaluate=False)
+
+        cb.on_train_begin(args, state, control)
+
+        # Advance past the resume point
+        state.global_step = 30
+        control.should_evaluate = True
+        result = cb.on_step_end(args, state, control)
+
+        assert result.should_evaluate is True
+
+    def test_noop_on_fresh_run(self):
+        cb = SkipEvalOnResumeCallback()
+        args = MagicMock(spec=TrainingArguments)
+        state = self._make_state(0)
+        control = TrainerControl(should_evaluate=False)
+
+        # Fresh run: global_step starts at 0
+        cb.on_train_begin(args, state, control)
+
+        # Even if eval triggers at step 0 (unlikely but defensive)
+        state.global_step = 10
+        control.should_evaluate = True
+        result = cb.on_step_end(args, state, control)
+
+        assert result.should_evaluate is True
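As a rough usage sketch (not something the diff itself adds), the callback can be driven by hand to see which eval requests it drops after a resume. The eval_steps value, the step numbers, and passing args=None (the callback never reads TrainingArguments) are illustrative assumptions, not taken from the PR.

# Illustrative sketch only: exercises SkipEvalOnResumeCallback outside a real
# Trainer run. Step numbers and eval_steps are made up for the demo, and
# args=None is passed because the callback does not read TrainingArguments.
from transformers import TrainerControl, TrainerState

from axolotl.utils.callbacks import SkipEvalOnResumeCallback

eval_steps = 10
callback = SkipEvalOnResumeCallback()

# Pretend we resumed from checkpoint-20: global_step is already restored
# before on_train_begin fires, exactly as the Trainer does on resume.
state = TrainerState()
state.global_step = 20
control = TrainerControl()
callback.on_train_begin(args=None, state=state, control=control)

for step in (20, 30):
    state.global_step = step
    # What the Trainer's eval schedule would request at this step.
    control.should_evaluate = step % eval_steps == 0
    callback.on_step_end(args=None, state=state, control=control)
    print(f"step {step}: should_evaluate={control.should_evaluate}")
# Expected: step 20 -> False (suppressed as already evaluated), step 30 -> True.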