Skip redundant evaluation when resuming from checkpoint (#3575) [skip ci]
* Skip redundant evaluation when resuming from checkpoint

* add condition check for adding callback

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
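For context, a rough illustration of why the redundant evaluation fires on resume. This is not from the patch: the simplified check stands in for the step-interval logic in transformers' default flow callback, and the concrete numbers are made up.

    # Simplified stand-in for the step-interval evaluation check that the
    # Trainer's default flow callback re-runs after a resume.
    eval_steps = 100
    restored_global_step = 500  # global_step read back from checkpoint-500

    # checkpoint-500 was saved right after the step-500 evaluation, so the
    # interval check fires again immediately on resume and the same
    # evaluation is repeated.
    should_evaluate = restored_global_step % eval_steps == 0
    assert should_evaluate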
@@ -41,6 +41,7 @@ from axolotl.utils.callbacks import (
     GCCallback,
     SaveAxolotlConfigtoWandBCallback,
     SaveModelOnFirstStepCallback,
+    SkipEvalOnResumeCallback,
 )
 from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
 from axolotl.utils.distributed import build_parallelism_config
@@ -118,6 +119,9 @@ class TrainerBuilderBase(abc.ABC):
             plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
         )

+        if self.cfg.resume_from_checkpoint:
+            callbacks.append(SkipEvalOnResumeCallback())
+
         if self.cfg.gc_steps:
             callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))

@@ -98,6 +98,56 @@ class SaveModelOnFirstStepCallback(TrainerCallback):
         return control


+class SkipEvalOnResumeCallback(TrainerCallback):
+    """Skip the redundant evaluation that fires when resuming from a checkpoint
+    whose step aligns with ``eval_steps``.
+
+    When HuggingFace Trainer resumes, it restores ``global_step`` from the
+    checkpoint and immediately triggers ``_maybe_log_save_evaluate`` for that
+    step. Because the evaluation was already performed during the original
+    run, repeating it wastes time and pollutes metric logs.
+
+    This callback records the ``global_step`` at the start of training (i.e.
+    the checkpoint step when resuming, or 0 for a fresh run) and suppresses
+    any evaluation request on that exact step.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self._resume_step: int | None = None
+
+    def on_train_begin(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **_kwargs,
+    ):
+        # ``global_step`` is already restored from the checkpoint at this
+        # point. For a fresh run it will be 0, so the guard below becomes a
+        # no-op.
+        self._resume_step = state.global_step
+
+    def on_step_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **_kwargs,
+    ) -> TrainerControl:
+        if (
+            self._resume_step
+            and state.global_step <= self._resume_step
+            and control.should_evaluate
+        ):
+            LOG.info(
+                "Skipping evaluation at step %d (already completed before resume)",
+                state.global_step,
+            )
+            control.should_evaluate = False
+        return control
+
+
 def bench_eval_callback_factory(trainer, tokenizer):
     accuracy = evaluate.load("accuracy")
     abcd_idx = [
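Not part of the patch, but a minimal sketch of the callback driven by hand, assuming `SkipEvalOnResumeCallback` is importable from `axolotl.utils.callbacks` as the import hunk above suggests; the step values and `output_dir` are made up for illustration.

    from transformers import TrainerControl, TrainerState, TrainingArguments

    from axolotl.utils.callbacks import SkipEvalOnResumeCallback

    callback = SkipEvalOnResumeCallback()
    args = TrainingArguments(output_dir="out")  # unused by the callback's logic

    # Pretend the Trainer restored step 500 from checkpoint-500.
    state = TrainerState()
    state.global_step = 500
    control = TrainerControl()
    callback.on_train_begin(args, state, control)

    # The Trainer re-requests the evaluation it already ran at step 500 ...
    control.should_evaluate = True
    control = callback.on_step_end(args, state, control)
    assert control.should_evaluate is False  # ... and the callback clears it

    # The next step evaluates normally.
    state.global_step = 501
    control.should_evaluate = True
    control = callback.on_step_end(args, state, control)
    assert control.should_evaluate is True

Clearing the flag in `on_step_end` works because the Trainer consults `control.should_evaluate` only after all step-end callbacks have run, so the request raised by the default flow callback can still be overridden here.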