From dcf24fd24ed59993e03cde0fc17e464a542bf52e Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Thu, 13 Nov 2025 20:51:05 +0530 Subject: [PATCH] feat: save checkpoint after training started (#3233) * add:config parameters for checkpoint * callback main * test file_type fix * lint * unit * simplify dict/obj handeling * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * Delete tests/e2e/integrations/__init__.py * remove hard code path in test * device check * lint * Update src/axolotl/utils/callbacks/dynamic_checkpoint.py Co-authored-by: NanoCode012 * Update src/axolotl/utils/callbacks/dynamic_checkpoint.py Co-authored-by: NanoCode012 * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: NanoCode012 * lint-2 * remove: singal based checkpoints * lint * remove signal tests * add:is_main_process * lint * addis_d:istributed() for tests * remove nested is_main_process * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: Wing Lian * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: Wing Lian * add user_defined_filename --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: NanoCode012 Co-authored-by: Wing Lian --- src/axolotl/core/builders/base.py | 7 + .../utils/callbacks/dynamic_checkpoint.py | 132 ++++++ src/axolotl/utils/schemas/config.py | 8 + .../utils/schemas/dynamic_checkpoint.py | 31 ++ tests/e2e/integrations/__init__.py | 0 .../callbacks/test_dynamic_checkpoint.py | 389 ++++++++++++++++++ 6 files changed, 567 insertions(+) create mode 100644 src/axolotl/utils/callbacks/dynamic_checkpoint.py create mode 100644 src/axolotl/utils/schemas/dynamic_checkpoint.py delete mode 100644 tests/e2e/integrations/__init__.py create mode 100644 tests/utils/callbacks/test_dynamic_checkpoint.py diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 7954e1fbd..fc6759ffb 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -118,6 +118,13 @@ class TrainerBuilderBase(abc.ABC): if self.cfg.gc_steps: callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps)) + if self.cfg.dynamic_checkpoint and self.cfg.dynamic_checkpoint.enabled: + from axolotl.utils.callbacks.dynamic_checkpoint import ( + DynamicCheckpointCallback, + ) + + callbacks.append(DynamicCheckpointCallback(self.cfg)) + if self.cfg.use_wandb: callbacks.append( SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) diff --git a/src/axolotl/utils/callbacks/dynamic_checkpoint.py b/src/axolotl/utils/callbacks/dynamic_checkpoint.py new file mode 100644 index 000000000..632109225 --- /dev/null +++ b/src/axolotl/utils/callbacks/dynamic_checkpoint.py @@ -0,0 +1,132 @@ +from pathlib import Path + +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + +from axolotl.utils.distributed import ( + barrier, + is_distributed, + is_main_process, +) +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +DEFAULT_TRIGGER_FILENAME = "axolotl_checkpoint.save" + + +class DynamicCheckpointCallback(TrainerCallback): + """ + Callback to save checkpoints on-demand during training via: + 1. File-based trigger (works everywhere, rank 0 checks file) + + Thread-safe for multi-GPU distributed training. + + Usage: + # File-based: + touch /path/to/output_dir/axolotl_checkpoint.save + """ + + def _get_config_value(self, config, key, default=None): + """Helper to get config value from dict or object.""" + if isinstance(config, dict): + return config.get(key, default) + return getattr(config, key, default) + + def __init__(self, cfg): + self.cfg = cfg + if not cfg.dynamic_checkpoint or not cfg.dynamic_checkpoint.enabled: + self.enabled = False + return + + self.enabled = True + dc_config = cfg.dynamic_checkpoint + + trigger_file_path = self._get_config_value(dc_config, "trigger_file_path") + self.trigger_filename = ( + trigger_file_path if trigger_file_path else DEFAULT_TRIGGER_FILENAME + ) + + check_interval = self._get_config_value(dc_config, "check_interval") + self.check_interval = check_interval if check_interval is not None else 100 + self.should_save_checkpoint = False + + LOG.info( + f"Dynamic checkpoint enabled. To trigger checkpoint save:\n" + f" • File: touch {cfg.output_dir}/{self.trigger_filename}\n" + f" • Check interval: every {self.check_interval} steps", + main_process_only=True, + ) + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **_kwargs, + ) -> TrainerControl: + """ + Check for checkpoint triggers at the end of each step. + ONLY rank 0 checks the file, then all ranks synchronize. + """ + if not self.enabled: + return control + + trigger_detected = False + + if state.global_step % self.check_interval == 0: + if is_main_process(): + trigger_path = Path(args.output_dir) / self.trigger_filename + + if trigger_path.exists(): + trigger_detected = True + try: + trigger_path.unlink() # Delete the trigger file + LOG.info( + f"Dynamic checkpoint triggered via file '{self.trigger_filename}' " + f"at step {state.global_step}", + main_process_only=True, + ) + except OSError as exc: + LOG.warning( + f"Failed to delete trigger file: {exc}", + main_process_only=True, + ) + + if self.should_save_checkpoint: + trigger_detected = True + self.should_save_checkpoint = False # Reset flag + + if is_distributed(): + import torch + import torch.distributed as dist + + device = getattr( + args, + "device", + torch.device("cuda" if torch.cuda.is_available() else "cpu"), + ) + + trigger_tensor = torch.tensor( + 1 if trigger_detected else 0, + dtype=torch.long, + device=device, + ) + + dist.broadcast(trigger_tensor, src=0) + + trigger_detected = bool(trigger_tensor.item()) + + barrier() + + if trigger_detected: + control.should_save = True + LOG.info( + f"Saving dynamic checkpoint at step {state.global_step}", + main_process_only=True, + ) + return control diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 86b3aa17b..5ad55f8b7 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -23,6 +23,7 @@ from axolotl.utils.schemas.datasets import ( StepwiseSupervisedDataset, ) from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters +from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType from axolotl.utils.schemas.fsdp import FSDPConfig from axolotl.utils.schemas.integrations import ( @@ -141,6 +142,13 @@ class AxolotlInputConfig( default=None, json_schema_extra={"description": "Reward modelling: `True` or `False`"}, ) + dynamic_checkpoint: DynamicCheckpointConfig | None = Field( + default=None, + json_schema_extra={ + "description": "Configuration for dynamic checkpointing (trigger by file or signal). " + "Set 'enabled: true' to activate this feature." + }, + ) process_reward_model: bool | None = Field( default=None, json_schema_extra={ diff --git a/src/axolotl/utils/schemas/dynamic_checkpoint.py b/src/axolotl/utils/schemas/dynamic_checkpoint.py new file mode 100644 index 000000000..e0e1d0c1d --- /dev/null +++ b/src/axolotl/utils/schemas/dynamic_checkpoint.py @@ -0,0 +1,31 @@ +"""Schema for dynamic checkpoint configuration.""" + +from pydantic import BaseModel, Field + + +class DynamicCheckpointConfig(BaseModel): + """Configuration for dynamic checkpoint triggering during training.""" + + enabled: bool = Field( + default=False, + json_schema_extra={ + "description": "Enable dynamic checkpoint triggering during training. " + "Create a file 'axolotl_checkpoint.save' in the configured `output_dir` to trigger. " + }, + ) + check_interval: int = Field( + default=10, + ge=1, + json_schema_extra={ + "description": "Check for trigger file every N steps (reduces I/O overhead). " + "Default: 100" + }, + ) + trigger_file_path: str = Field( + default="", + json_schema_extra={ + "description": "Custom trigger filename (optional). " + "If not specified, defaults to 'axolotl_checkpoint.save'. " + "Specify a filename (not a full path) to override the default." + }, + ) diff --git a/tests/e2e/integrations/__init__.py b/tests/e2e/integrations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/utils/callbacks/test_dynamic_checkpoint.py b/tests/utils/callbacks/test_dynamic_checkpoint.py new file mode 100644 index 000000000..1fd792102 --- /dev/null +++ b/tests/utils/callbacks/test_dynamic_checkpoint.py @@ -0,0 +1,389 @@ +"""Unit tests for dynamic checkpoint callback""" + +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, Mock, patch + +from axolotl.utils.callbacks.dynamic_checkpoint import ( + DEFAULT_TRIGGER_FILENAME, + DynamicCheckpointCallback, +) +from axolotl.utils.dict import DictDefault + + +class TestDynamicCheckpointCallbackInit: + """Test callback initialization""" + + def test_callback_disabled_by_default(self): + """Test that callback is disabled when config.enabled=False""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": False}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + assert callback.enabled is False + + def test_callback_disabled_when_none(self): + """Test that callback is disabled when dynamic_checkpoint is None""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": None, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + assert callback.enabled is False + + def test_callback_enabled_when_configured(self): + """Test that callback is enabled when config.enabled=True""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 10}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + assert callback.enabled is True + assert callback.check_interval == 10 + + def test_default_trigger_filename(self): + """Test that default trigger filename is used""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 10}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + assert callback.trigger_filename == DEFAULT_TRIGGER_FILENAME + + def test_check_interval_default(self): + """Test default check interval""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + assert callback.check_interval == 100 # Default from schema + + +class TestDynamicCheckpointFileDetection: + """Test file-based checkpoint triggering""" + + def test_trigger_file_detected_and_deleted(self): + """Test that trigger file is detected and deleted""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 1}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + assert trigger_file.exists() + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=False, + ): + result = callback.on_step_end(args, state, control) + + assert not trigger_file.exists() + assert result.should_save is True + + def test_check_interval_honored(self): + """Test that file is only checked at check_interval steps""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 10}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + args = Mock(output_dir=tmpdir) + control = Mock(should_save=False) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=False, + ): + # Step 5 - shouldn't check (not divisible by 10) + state = Mock(global_step=5) + result = callback.on_step_end(args, state, control) + assert trigger_file.exists() # Still there + assert result.should_save is False + + # Step 10 - should check + state = Mock(global_step=10) + result = callback.on_step_end(args, state, control) + assert not trigger_file.exists() # Deleted + assert result.should_save is True + + def test_no_file_no_trigger(self): + """Test that no trigger occurs when file doesn't exist""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 1}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=False, + ): + result = callback.on_step_end(args, state, control) + + assert result.should_save is False + + def test_file_deletion_error_handling(self): + """Test that file deletion errors are handled gracefully""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 1}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=False, + ): + with patch.object( + Path, "unlink", side_effect=OSError("Permission denied") + ): + result = callback.on_step_end(args, state, control) + + assert result.should_save is True + + +class TestDynamicCheckpointMultiGPU: + """Test multi-GPU synchronization""" + + def test_only_rank_0_checks_file(self): + """Test that only rank 0 checks filesystem in multi-GPU setup""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 1}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + # Rank 1 (not main process) - shouldn't check file + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=False, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=True, + ): + with patch("torch.distributed.broadcast") as mock_broadcast: + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.barrier" + ): + mock_tensor = MagicMock() + mock_tensor.item.return_value = 0 + with patch("torch.tensor", return_value=mock_tensor): + callback.on_step_end(args, state, control) + + assert trigger_file.exists() + # Broadcast should have been called + assert mock_broadcast.called + + def test_broadcast_synchronization(self): + """Test that trigger decision is broadcasted to all ranks""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 1}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + # Rank 0 detects file + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=True, + ): + with patch("torch.distributed.broadcast") as mock_broadcast: + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.barrier" + ) as mock_barrier: + mock_tensor = MagicMock() + mock_tensor.item.return_value = 1 + with patch("torch.tensor", return_value=mock_tensor): + with patch("torch.cuda.current_device", return_value=0): + result = callback.on_step_end(args, state, control) + + assert mock_broadcast.called + assert mock_barrier.called + # All ranks should trigger + assert result.should_save is True + + +class TestDynamicCheckpointSignalHandling: + """Test signal-based checkpoint triggering""" + + def test_signal_trigger_via_callback(self): + """Test that signal flag triggers checkpoint save""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": { + "enabled": True, + "check_interval": 1, + "enable_signal": True, + }, + "output_dir": tmpdir, + } + ) + + with patch("signal.signal"): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.hasattr", + return_value=True, + ): + callback = DynamicCheckpointCallback(cfg) + + callback.should_save_checkpoint = True + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=False, + ): + result = callback.on_step_end(args, state, control) + + assert result.should_save is True + assert callback.should_save_checkpoint is False + + def test_signal_not_registered_when_disabled(self): + """Test that signal handler is not registered when disabled""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": { + "enabled": True, + "check_interval": 10, + "enable_signal": False, + }, + "output_dir": tmpdir, + } + ) + + with patch("signal.signal") as mock_signal_register: + _ = DynamicCheckpointCallback(cfg) + + assert not mock_signal_register.called + + +class TestDynamicCheckpointDisabled: + """Test behavior when callback is disabled""" + + def test_disabled_callback_does_nothing(self): + """Test that disabled callback doesn't check or trigger""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": False}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + result = callback.on_step_end(args, state, control) + + assert trigger_file.exists() + assert result.should_save is False