Skip redundant evaluation when resuming from checkpoint (#3575) [skip ci]

* Skip redundant evaluation when resuming from checkpoint

* add condition check for adding callback

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Joaquin Hui
2026-04-13 01:50:15 +01:00
committed by GitHub
parent 66c3e5a3fd
commit a44edda6d7
3 changed files with 117 additions and 0 deletions

View File

@@ -41,6 +41,7 @@ from axolotl.utils.callbacks import (
GCCallback,
SaveAxolotlConfigtoWandBCallback,
SaveModelOnFirstStepCallback,
SkipEvalOnResumeCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.distributed import build_parallelism_config
@@ -118,6 +119,9 @@ class TrainerBuilderBase(abc.ABC):
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.resume_from_checkpoint:
callbacks.append(SkipEvalOnResumeCallback())
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))

View File

@@ -98,6 +98,56 @@ class SaveModelOnFirstStepCallback(TrainerCallback):
return control
class SkipEvalOnResumeCallback(TrainerCallback):
    """Suppress the duplicate evaluation fired when training resumes from a
    checkpoint whose step lines up with ``eval_steps``.

    When the HuggingFace Trainer resumes, it restores ``global_step`` from
    the checkpoint and immediately calls ``_maybe_log_save_evaluate`` for
    that step, even though the evaluation already ran before the checkpoint
    was written. Repeating it wastes compute and duplicates metric logs.

    The ``global_step`` observed in ``on_train_begin`` (the checkpoint step
    on resume, ``0`` for a fresh run) is remembered, and any evaluation
    requested at or before that step is cancelled.
    """

    def __init__(self):
        super().__init__()
        # Step at which training (re)started; None until on_train_begin fires.
        self._resume_step: int | None = None

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **_kwargs,
    ):
        # ``global_step`` has already been restored from the checkpoint by
        # the time this callback runs. A fresh run reports 0 here, which
        # makes the guard in ``on_step_end`` a no-op.
        self._resume_step = state.global_step

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **_kwargs,
    ) -> TrainerControl:
        # A truthy resume step (i.e. not None and not 0) means we resumed
        # from a checkpoint; cancel any eval scheduled at or before it.
        should_skip = (
            bool(self._resume_step)
            and state.global_step <= self._resume_step
            and control.should_evaluate
        )
        if should_skip:
            LOG.info(
                "Skipping evaluation at step %d (already completed before resume)",
                state.global_step,
            )
            control.should_evaluate = False
        return control
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [

View File

@@ -0,0 +1,63 @@
"""Tests for SkipEvalOnResumeCallback."""
from unittest.mock import MagicMock
from transformers import TrainerControl, TrainerState, TrainingArguments
from axolotl.utils.callbacks import SkipEvalOnResumeCallback
class TestSkipEvalOnResumeCallback:
    """Tests for skipping redundant evaluation on checkpoint resume."""

    @staticmethod
    def _make_state(global_step: int) -> TrainerState:
        """Return a mock ``TrainerState`` pinned at *global_step*."""
        mock_state = MagicMock(spec=TrainerState)
        mock_state.global_step = global_step
        return mock_state

    def test_suppresses_eval_at_resume_step(self):
        callback = SkipEvalOnResumeCallback()
        mock_args = MagicMock(spec=TrainingArguments)
        state = self._make_state(20)
        control = TrainerControl(should_evaluate=False)

        # Training begins with global_step restored to the checkpoint (20).
        callback.on_train_begin(mock_args, state, control)

        # The trainer would request an eval for step 20; it must be cancelled.
        control.should_evaluate = True
        updated = callback.on_step_end(mock_args, state, control)
        assert updated.should_evaluate is False

    def test_allows_eval_after_resume_step(self):
        callback = SkipEvalOnResumeCallback()
        mock_args = MagicMock(spec=TrainingArguments)
        state = self._make_state(20)
        control = TrainerControl(should_evaluate=False)

        callback.on_train_begin(mock_args, state, control)

        # Once training has moved past the resume point, evals run normally.
        state.global_step = 30
        control.should_evaluate = True
        updated = callback.on_step_end(mock_args, state, control)
        assert updated.should_evaluate is True

    def test_noop_on_fresh_run(self):
        callback = SkipEvalOnResumeCallback()
        mock_args = MagicMock(spec=TrainingArguments)
        state = self._make_state(0)
        control = TrainerControl(should_evaluate=False)

        # A fresh run records global_step == 0, disabling the guard entirely.
        callback.on_train_begin(mock_args, state, control)

        # Defensive check: evals later in the run must not be suppressed.
        state.global_step = 10
        control.should_evaluate = True
        updated = callback.on_step_end(mock_args, state, control)
        assert updated.should_evaluate is True