Skip redundant evaluation when resuming from checkpoint (#3575) [skip ci]

* Skip redundant evaluation when resuming from checkpoint

* add condition check for adding callback

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Joaquin Hui
2026-04-13 01:50:15 +01:00
committed by GitHub
parent 66c3e5a3fd
commit a44edda6d7
3 changed files with 117 additions and 0 deletions

View File

@@ -0,0 +1,63 @@
"""Tests for SkipEvalOnResumeCallback."""
from unittest.mock import MagicMock
from transformers import TrainerControl, TrainerState, TrainingArguments
from axolotl.utils.callbacks import SkipEvalOnResumeCallback
class TestSkipEvalOnResumeCallback:
"""Tests for skipping redundant evaluation on checkpoint resume."""
@staticmethod
def _make_state(global_step: int) -> TrainerState:
state = MagicMock(spec=TrainerState)
state.global_step = global_step
return state
def test_suppresses_eval_at_resume_step(self):
cb = SkipEvalOnResumeCallback()
args = MagicMock(spec=TrainingArguments)
state = self._make_state(20)
control = TrainerControl(should_evaluate=False)
# Simulate on_train_begin at checkpoint-20
cb.on_train_begin(args, state, control)
# Trainer sets should_evaluate = True for step 20
control.should_evaluate = True
result = cb.on_step_end(args, state, control)
assert result.should_evaluate is False
def test_allows_eval_after_resume_step(self):
cb = SkipEvalOnResumeCallback()
args = MagicMock(spec=TrainingArguments)
state = self._make_state(20)
control = TrainerControl(should_evaluate=False)
cb.on_train_begin(args, state, control)
# Advance past the resume point
state.global_step = 30
control.should_evaluate = True
result = cb.on_step_end(args, state, control)
assert result.should_evaluate is True
def test_noop_on_fresh_run(self):
cb = SkipEvalOnResumeCallback()
args = MagicMock(spec=TrainingArguments)
state = self._make_state(0)
control = TrainerControl(should_evaluate=False)
# Fresh run: global_step starts at 0
cb.on_train_begin(args, state, control)
# Even if eval triggers at step 0 (unlikely but defensive)
state.global_step = 10
control.should_evaluate = True
result = cb.on_step_end(args, state, control)
assert result.should_evaluate is True