Skip redundant evaluation when resuming from checkpoint (#3575) [skip ci]
* Skip redundant evaluation when resuming from checkpoint
* Add condition check for adding callback
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -41,6 +41,7 @@ from axolotl.utils.callbacks import (
|
|||||||
GCCallback,
|
GCCallback,
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
SaveModelOnFirstStepCallback,
|
SaveModelOnFirstStepCallback,
|
||||||
|
SkipEvalOnResumeCallback,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||||
from axolotl.utils.distributed import build_parallelism_config
|
from axolotl.utils.distributed import build_parallelism_config
|
||||||
@@ -118,6 +119,9 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.cfg.resume_from_checkpoint:
|
||||||
|
callbacks.append(SkipEvalOnResumeCallback())
|
||||||
|
|
||||||
if self.cfg.gc_steps:
|
if self.cfg.gc_steps:
|
||||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||||
|
|
||||||
|
|||||||
@@ -98,6 +98,56 @@ class SaveModelOnFirstStepCallback(TrainerCallback):
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
class SkipEvalOnResumeCallback(TrainerCallback):
    """Cancel evaluation requests for steps already covered before a resume.

    The callback remembers ``state.global_step`` as seen in
    ``on_train_begin``. Afterwards, whenever ``on_step_end`` observes a
    pending evaluation request (``control.should_evaluate``) while
    ``state.global_step`` is at or below that remembered step, the request
    is cleared and an info message is logged.

    The motivating case is resuming from a checkpoint whose step lines up
    with ``eval_steps``: the evaluation for that step was already run in the
    original session, so repeating it costs time and duplicates metrics.

    When the remembered step is ``0`` (a fresh run) or was never recorded,
    the callback does nothing.
    """

    def __init__(self):
        super().__init__()
        # Step captured at train start; ``None`` until ``on_train_begin`` runs.
        self._resume_step: int | None = None

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **_kwargs,
    ):
        """Capture the step training starts from.

        NOTE(review): relies on ``state.global_step`` already holding the
        checkpoint step when this hook fires — confirm against the Trainer
        version in use.
        """
        self._resume_step = state.global_step

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **_kwargs,
    ) -> TrainerControl:
        """Clear ``control.should_evaluate`` for steps at or before the resume step.

        Returns the same ``control`` object, modified in place only when an
        evaluation request was suppressed.
        """
        resumed_at = self._resume_step
        # A falsy resume step (None or 0) means there is nothing to skip.
        already_covered = bool(resumed_at) and state.global_step <= resumed_at
        if not (already_covered and control.should_evaluate):
            return control

        LOG.info(
            "Skipping evaluation at step %d (already completed before resume)",
            state.global_step,
        )
        control.should_evaluate = False
        return control
|
||||||
|
|
||||||
|
|
||||||
def bench_eval_callback_factory(trainer, tokenizer):
|
def bench_eval_callback_factory(trainer, tokenizer):
|
||||||
accuracy = evaluate.load("accuracy")
|
accuracy = evaluate.load("accuracy")
|
||||||
abcd_idx = [
|
abcd_idx = [
|
||||||
|
|||||||
63
tests/utils/callbacks/test_skip_eval_on_resume.py
Normal file
63
tests/utils/callbacks/test_skip_eval_on_resume.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""Tests for SkipEvalOnResumeCallback."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from transformers import TrainerControl, TrainerState, TrainingArguments
|
||||||
|
|
||||||
|
from axolotl.utils.callbacks import SkipEvalOnResumeCallback
|
||||||
|
|
||||||
|
|
||||||
|
class TestSkipEvalOnResumeCallback:
    """Checks that SkipEvalOnResumeCallback only drops evals tied to the resume step."""

    @staticmethod
    def _make_state(global_step: int) -> TrainerState:
        """Return a ``TrainerState``-spec'd mock reporting ``global_step``."""
        mock_state = MagicMock(spec=TrainerState)
        mock_state.global_step = global_step
        return mock_state

    def test_suppresses_eval_at_resume_step(self):
        """An eval request on the very step we resumed from is cleared."""
        callback = SkipEvalOnResumeCallback()
        train_args = MagicMock(spec=TrainingArguments)
        trainer_state = self._make_state(20)
        trainer_control = TrainerControl(should_evaluate=False)

        # Training begins from checkpoint-20.
        callback.on_train_begin(train_args, trainer_state, trainer_control)

        # An evaluation is requested while still on step 20.
        trainer_control.should_evaluate = True
        outcome = callback.on_step_end(train_args, trainer_state, trainer_control)

        assert outcome.should_evaluate is False

    def test_allows_eval_after_resume_step(self):
        """An eval request on a later step passes through untouched."""
        callback = SkipEvalOnResumeCallback()
        train_args = MagicMock(spec=TrainingArguments)
        trainer_state = self._make_state(20)
        trainer_control = TrainerControl(should_evaluate=False)

        callback.on_train_begin(train_args, trainer_state, trainer_control)

        # Move beyond the step recorded at train start.
        trainer_state.global_step = 30
        trainer_control.should_evaluate = True
        outcome = callback.on_step_end(train_args, trainer_state, trainer_control)

        assert outcome.should_evaluate is True

    def test_noop_on_fresh_run(self):
        """With a recorded start step of 0 the callback never interferes."""
        callback = SkipEvalOnResumeCallback()
        train_args = MagicMock(spec=TrainingArguments)
        trainer_state = self._make_state(0)
        trainer_control = TrainerControl(should_evaluate=False)

        # No checkpoint: training starts with global_step == 0.
        callback.on_train_begin(train_args, trainer_state, trainer_control)

        # A regular evaluation at step 10 must be left alone.
        trainer_state.global_step = 10
        trainer_control.should_evaluate = True
        outcome = callback.on_step_end(train_args, trainer_state, trainer_control)

        assert outcome.should_evaluate is True
|
||||||
Reference in New Issue
Block a user