feat: save checkpoint on demand after training has started (#3233)

* add: config parameters for checkpoint

* callback main

* test file_type fix

* lint

* unit

* simplify dict/obj handling

* Update src/axolotl/utils/schemas/dynamic_checkpoint.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Delete tests/e2e/integrations/__init__.py

* remove hard-coded path in test

* device check

* lint

* Update src/axolotl/utils/callbacks/dynamic_checkpoint.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update src/axolotl/utils/callbacks/dynamic_checkpoint.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update src/axolotl/utils/schemas/dynamic_checkpoint.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* lint-2

* remove: signal-based checkpoints

* lint

* remove signal tests

* add: is_main_process

* lint

* add: is_distributed() for tests

* remove nested is_main_process

* Update src/axolotl/utils/schemas/dynamic_checkpoint.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update src/axolotl/utils/schemas/dynamic_checkpoint.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* add user_defined_filename

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
Author: VED
Date: 2025-11-13 20:51:05 +05:30
Committed by: GitHub
Parent: 49b8107989
Commit: dcf24fd24e
6 changed files with 567 additions and 0 deletions


@@ -118,6 +118,13 @@ class TrainerBuilderBase(abc.ABC):
        if self.cfg.gc_steps:
            callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
        if self.cfg.dynamic_checkpoint and self.cfg.dynamic_checkpoint.enabled:
            from axolotl.utils.callbacks.dynamic_checkpoint import (
                DynamicCheckpointCallback,
            )

            callbacks.append(DynamicCheckpointCallback(self.cfg))
        if self.cfg.use_wandb:
            callbacks.append(
                SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)

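For context, a minimal sketch of the config surface this gate reads. The field names come from the `DynamicCheckpointConfig` schema added later in this commit; the values shown are arbitrary examples, not defaults from the code:

```python
# Illustrative only: field names are from DynamicCheckpointConfig in this
# commit; the values are example choices.
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig

cfg = DynamicCheckpointConfig(
    enabled=True,          # gate checked in TrainerBuilderBase before the callback is registered
    check_interval=50,     # poll for the trigger file every 50 steps
    trigger_file_path="",  # empty string falls back to "axolotl_checkpoint.save"
)
print(cfg.model_dump())  # {'enabled': True, 'check_interval': 50, 'trigger_file_path': ''}
```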
src/axolotl/utils/callbacks/dynamic_checkpoint.py

@@ -0,0 +1,132 @@
"""Dynamic checkpoint callback: save a checkpoint on demand via a trigger file."""

from pathlib import Path

from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)

from axolotl.utils.distributed import (
    barrier,
    is_distributed,
    is_main_process,
)
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

DEFAULT_TRIGGER_FILENAME = "axolotl_checkpoint.save"


class DynamicCheckpointCallback(TrainerCallback):
    """
    Callback to save checkpoints on demand during training via a file-based
    trigger (works everywhere; only rank 0 checks for the file).

    Safe for multi-GPU distributed training: the trigger decision is made on
    rank 0 and broadcast to all ranks.

    Usage:
        touch /path/to/output_dir/axolotl_checkpoint.save
    """

    def _get_config_value(self, config, key, default=None):
        """Helper to get a config value from a dict or an object."""
        if isinstance(config, dict):
            return config.get(key, default)
        return getattr(config, key, default)

    def __init__(self, cfg):
        self.cfg = cfg
        if not cfg.dynamic_checkpoint or not cfg.dynamic_checkpoint.enabled:
            self.enabled = False
            return

        self.enabled = True
        dc_config = cfg.dynamic_checkpoint

        trigger_file_path = self._get_config_value(dc_config, "trigger_file_path")
        self.trigger_filename = (
            trigger_file_path if trigger_file_path else DEFAULT_TRIGGER_FILENAME
        )

        check_interval = self._get_config_value(dc_config, "check_interval")
        self.check_interval = check_interval if check_interval is not None else 100

        self.should_save_checkpoint = False

        LOG.info(
            f"Dynamic checkpoint enabled. To trigger a checkpoint save:\n"
            f"  • File: touch {cfg.output_dir}/{self.trigger_filename}\n"
            f"  • Check interval: every {self.check_interval} steps",
            main_process_only=True,
        )

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **_kwargs,
    ) -> TrainerControl:
        """
        Check for checkpoint triggers at the end of each step.

        Only rank 0 checks the file; the result is then broadcast so that all
        ranks stay in sync.
        """
        if not self.enabled:
            return control

        trigger_detected = False

        if state.global_step % self.check_interval == 0:
            if is_main_process():
                trigger_path = Path(args.output_dir) / self.trigger_filename
                if trigger_path.exists():
                    trigger_detected = True
                    try:
                        trigger_path.unlink()  # Delete the trigger file
                        LOG.info(
                            f"Dynamic checkpoint triggered via file "
                            f"'{self.trigger_filename}' at step {state.global_step}",
                            main_process_only=True,
                        )
                    except OSError as exc:
                        LOG.warning(
                            f"Failed to delete trigger file: {exc}",
                            main_process_only=True,
                        )

            if self.should_save_checkpoint:
                trigger_detected = True
                self.should_save_checkpoint = False  # Reset the flag

            if is_distributed():
                import torch
                import torch.distributed as dist

                device = getattr(
                    args,
                    "device",
                    torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                )
                trigger_tensor = torch.tensor(
                    1 if trigger_detected else 0,
                    dtype=torch.long,
                    device=device,
                )
                dist.broadcast(trigger_tensor, src=0)
                trigger_detected = bool(trigger_tensor.item())
                barrier()

        if trigger_detected:
            control.should_save = True
            LOG.info(
                f"Saving dynamic checkpoint at step {state.global_step}",
                main_process_only=True,
            )

        return control

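To see the file-trigger flow end to end, here is a hypothetical single-process smoke test. The `SimpleNamespace` config stub and the step values are illustrative and not part of this commit; it assumes axolotl and transformers are installed:

```python
# Hypothetical smoke test: exercise the trigger-file path without a trainer.
import tempfile
from pathlib import Path
from types import SimpleNamespace

from transformers import TrainerControl, TrainerState, TrainingArguments

from axolotl.utils.callbacks.dynamic_checkpoint import DynamicCheckpointCallback

out_dir = tempfile.mkdtemp()
cfg = SimpleNamespace(
    dynamic_checkpoint=SimpleNamespace(
        enabled=True, check_interval=1, trigger_file_path=""
    ),
    output_dir=out_dir,
)
callback = DynamicCheckpointCallback(cfg)

args = TrainingArguments(output_dir=out_dir)
state = TrainerState()
state.global_step = 10
control = TrainerControl()

# No trigger file yet: the callback leaves control.should_save untouched.
callback.on_step_end(args, state, control)
assert control.should_save is False

# Create the trigger file; the next step detects it, deletes it, and
# flags a checkpoint save for the trainer.
(Path(out_dir) / "axolotl_checkpoint.save").touch()
callback.on_step_end(args, state, control)
assert control.should_save is True
```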
src/axolotl/utils/schemas/config.py

@@ -23,6 +23,7 @@ from axolotl.utils.schemas.datasets import (
    StepwiseSupervisedDataset,
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.fsdp import FSDPConfig
from axolotl.utils.schemas.integrations import (
@@ -141,6 +142,13 @@ class AxolotlInputConfig(
        default=None,
        json_schema_extra={"description": "Reward modelling: `True` or `False`"},
    )
    dynamic_checkpoint: DynamicCheckpointConfig | None = Field(
        default=None,
        json_schema_extra={
            "description": "Configuration for dynamic checkpointing (triggered by creating a file). "
            "Set 'enabled: true' to activate this feature."
        },
    )
    process_reward_model: bool | None = Field(
        default=None,
        json_schema_extra={
src/axolotl/utils/schemas/dynamic_checkpoint.py

@@ -0,0 +1,31 @@
"""Schema for dynamic checkpoint configuration."""

from pydantic import BaseModel, Field


class DynamicCheckpointConfig(BaseModel):
    """Configuration for dynamic checkpoint triggering during training."""

    enabled: bool = Field(
        default=False,
        json_schema_extra={
            "description": "Enable dynamic checkpoint triggering during training. "
            "Create a file 'axolotl_checkpoint.save' in the configured `output_dir` to trigger."
        },
    )
    check_interval: int = Field(
        default=10,
        ge=1,
        json_schema_extra={
            "description": "Check for the trigger file every N steps (reduces I/O overhead). "
            "Default: 10"
        },
    )
    trigger_file_path: str = Field(
        default="",
        json_schema_extra={
            "description": "Custom trigger filename (optional). "
            "If not specified, defaults to 'axolotl_checkpoint.save'. "
            "Specify a filename (not a full path) to override the default."
        },
    )
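As a quick check of the `ge=1` bound above, a hypothetical snippet showing how pydantic rejects an out-of-range `check_interval`; the values here are examples, not from the commit:

```python
# Hypothetical usage: pydantic enforces the ge=1 constraint on check_interval.
from pydantic import ValidationError

from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig

DynamicCheckpointConfig(enabled=True, check_interval=1)  # accepted: 1 >= 1

try:
    DynamicCheckpointConfig(enabled=True, check_interval=0)
except ValidationError as exc:
    # Reports that check_interval must be greater than or equal to 1
    print(exc)
```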