Skip redundant evaluation when resuming from checkpoint (#3575) [skip ci]

* Skip redundant evaluation when resuming from checkpoint

* add condition check for adding callback

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Joaquin Hui
2026-04-13 01:50:15 +01:00
committed by GitHub
parent 66c3e5a3fd
commit a44edda6d7
3 changed files with 117 additions and 0 deletions

View File

@@ -41,6 +41,7 @@ from axolotl.utils.callbacks import (
GCCallback,
SaveAxolotlConfigtoWandBCallback,
SaveModelOnFirstStepCallback,
SkipEvalOnResumeCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.distributed import build_parallelism_config
@@ -118,6 +119,9 @@ class TrainerBuilderBase(abc.ABC):
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.resume_from_checkpoint:
callbacks.append(SkipEvalOnResumeCallback())
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))

View File

@@ -98,6 +98,56 @@ class SaveModelOnFirstStepCallback(TrainerCallback):
return control
class SkipEvalOnResumeCallback(TrainerCallback):
"""Skip the redundant evaluation that fires when resuming from a checkpoint
whose step aligns with ``eval_steps``.
When HuggingFace Trainer resumes, it restores ``global_step`` from the
checkpoint and immediately triggers ``_maybe_log_save_evaluate`` for that
step. Because the evaluation was already performed during the original
run, repeating it wastes time and pollutes metric logs.
This callback records the ``global_step`` at the start of training (i.e.
the checkpoint step when resuming, or 0 for a fresh run) and suppresses
any evaluation request on that exact step.
"""
def __init__(self):
super().__init__()
self._resume_step: int | None = None
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
):
# ``global_step`` is already restored from the checkpoint at this
# point. For a fresh run it will be 0, so the guard below becomes a
# no-op.
self._resume_step = state.global_step
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
) -> TrainerControl:
if (
self._resume_step
and state.global_step <= self._resume_step
and control.should_evaluate
):
LOG.info(
"Skipping evaluation at step %d (already completed before resume)",
state.global_step,
)
control.should_evaluate = False
return control
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [