Skip redundant evaluation when resuming from checkpoint (#3575) [skip ci]
* Skip redundant evaluation when resuming from checkpoint
* Add condition check for adding callback
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -41,6 +41,7 @@ from axolotl.utils.callbacks import (
|
||||
GCCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveModelOnFirstStepCallback,
|
||||
SkipEvalOnResumeCallback,
|
||||
)
|
||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||
from axolotl.utils.distributed import build_parallelism_config
|
||||
@@ -118,6 +119,9 @@ class TrainerBuilderBase(abc.ABC):
|
||||
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
||||
)
|
||||
|
||||
if self.cfg.resume_from_checkpoint:
|
||||
callbacks.append(SkipEvalOnResumeCallback())
|
||||
|
||||
if self.cfg.gc_steps:
|
||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||
|
||||
|
||||
@@ -98,6 +98,56 @@ class SaveModelOnFirstStepCallback(TrainerCallback):
|
||||
return control
|
||||
|
||||
|
||||
class SkipEvalOnResumeCallback(TrainerCallback):
    """Cancel evaluation requests for steps already covered before a resume.

    ``on_train_begin`` remembers ``state.global_step`` as it stands when
    training starts. ``on_step_end`` then clears ``control.should_evaluate``
    whenever that remembered step is non-zero and the current
    ``state.global_step`` is less than or equal to it.

    Motivation (per the original author's note): after resuming from a
    checkpoint whose step aligns with ``eval_steps``, the HuggingFace Trainer
    can request an evaluation for the checkpoint step again, which repeats
    work done in the original run and duplicates entries in the metric logs.

    On a fresh run the remembered step is 0, so the callback never alters
    ``control``.
    """

    def __init__(self):
        super().__init__()
        # Step count observed at train start; ``None`` until
        # ``on_train_begin`` has run.
        self._resume_step: int | None = None

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **_kwargs,
    ):
        """Record the ``global_step`` training starts from."""
        # NOTE(review): relies on the trainer having restored ``global_step``
        # from the checkpoint before this hook fires; a fresh run reports 0,
        # which disables the guard in ``on_step_end``.
        self._resume_step = state.global_step

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **_kwargs,
    ) -> TrainerControl:
        """Clear ``should_evaluate`` for steps at or before the recorded step.

        Returns the same ``control`` object, mutated in place when an
        evaluation request is suppressed.
        """
        resumed_from = self._resume_step

        # Nothing recorded yet, a fresh run (0), or already past the
        # recorded step: leave the control flags untouched.
        if not resumed_from or state.global_step > resumed_from:
            return control

        if control.should_evaluate:
            LOG.info(
                "Skipping evaluation at step %d (already completed before resume)",
                state.global_step,
            )
            control.should_evaluate = False

        return control
|
||||
|
||||
|
||||
def bench_eval_callback_factory(trainer, tokenizer):
|
||||
accuracy = evaluate.load("accuracy")
|
||||
abcd_idx = [
|
||||
|
||||
63
tests/utils/callbacks/test_skip_eval_on_resume.py
Normal file
63
tests/utils/callbacks/test_skip_eval_on_resume.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Tests for SkipEvalOnResumeCallback."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from transformers import TrainerControl, TrainerState, TrainingArguments
|
||||
|
||||
from axolotl.utils.callbacks import SkipEvalOnResumeCallback
|
||||
|
||||
|
||||
class TestSkipEvalOnResumeCallback:
    """Checks when SkipEvalOnResumeCallback cancels or keeps an evaluation."""

    @staticmethod
    def _make_state(global_step: int) -> TrainerState:
        """Return a ``TrainerState`` mock reporting ``global_step``."""
        mocked_state = MagicMock(spec=TrainerState)
        mocked_state.global_step = global_step
        return mocked_state

    def test_suppresses_eval_at_resume_step(self):
        """An evaluation requested on the recorded step is cancelled."""
        callback = SkipEvalOnResumeCallback()
        training_args = MagicMock(spec=TrainingArguments)
        trainer_state = self._make_state(20)
        trainer_control = TrainerControl(should_evaluate=False)

        # Training begins with global_step already at 20 (checkpoint-20).
        callback.on_train_begin(training_args, trainer_state, trainer_control)

        # An evaluation is then requested while still on step 20.
        trainer_control.should_evaluate = True
        returned = callback.on_step_end(training_args, trainer_state, trainer_control)

        assert returned.should_evaluate is False

    def test_allows_eval_after_resume_step(self):
        """An evaluation requested after the recorded step is kept."""
        callback = SkipEvalOnResumeCallback()
        training_args = MagicMock(spec=TrainingArguments)
        trainer_state = self._make_state(20)
        trainer_control = TrainerControl(should_evaluate=False)

        callback.on_train_begin(training_args, trainer_state, trainer_control)

        # Move beyond the recorded step before requesting evaluation.
        trainer_state.global_step = 30
        trainer_control.should_evaluate = True
        returned = callback.on_step_end(training_args, trainer_state, trainer_control)

        assert returned.should_evaluate is True

    def test_noop_on_fresh_run(self):
        """With a recorded step of 0 the callback leaves the request alone."""
        callback = SkipEvalOnResumeCallback()
        training_args = MagicMock(spec=TrainingArguments)
        trainer_state = self._make_state(0)
        trainer_control = TrainerControl(should_evaluate=False)

        # Fresh run: global_step is 0 when training begins.
        callback.on_train_begin(training_args, trainer_state, trainer_control)

        # A regular evaluation request at step 10 must pass through.
        trainer_state.global_step = 10
        trainer_control.should_evaluate = True
        returned = callback.on_step_end(training_args, trainer_state, trainer_control)

        assert returned.should_evaluate is True
|
||||
Reference in New Issue
Block a user