feat: save checkpoint after training started (#3233)
* add:config parameters for checkpoint * callback main * test file_type fix * lint * unit * simplify dict/obj handling * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * Delete tests/e2e/integrations/__init__.py * remove hard-coded path in test * device check * lint * Update src/axolotl/utils/callbacks/dynamic_checkpoint.py Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * Update src/axolotl/utils/callbacks/dynamic_checkpoint.py Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * lint-2 * remove: signal-based checkpoints * lint * remove signal tests * add:is_main_process * lint * add is_distributed() for tests * remove nested is_main_process * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: Wing Lian <wing.lian@gmail.com> * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: Wing Lian <wing.lian@gmail.com> * add user_defined_filename --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -118,6 +118,13 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if self.cfg.gc_steps:
|
if self.cfg.gc_steps:
|
||||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||||
|
|
||||||
|
if self.cfg.dynamic_checkpoint and self.cfg.dynamic_checkpoint.enabled:
|
||||||
|
from axolotl.utils.callbacks.dynamic_checkpoint import (
|
||||||
|
DynamicCheckpointCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
callbacks.append(DynamicCheckpointCallback(self.cfg))
|
||||||
|
|
||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||||
|
|||||||
132
src/axolotl/utils/callbacks/dynamic_checkpoint.py
Normal file
132
src/axolotl/utils/callbacks/dynamic_checkpoint.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
TrainerCallback,
|
||||||
|
TrainerControl,
|
||||||
|
TrainerState,
|
||||||
|
TrainingArguments,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import (
|
||||||
|
barrier,
|
||||||
|
is_distributed,
|
||||||
|
is_main_process,
|
||||||
|
)
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_TRIGGER_FILENAME = "axolotl_checkpoint.save"
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicCheckpointCallback(TrainerCallback):
    """
    Callback to save a checkpoint on demand while training is running.

    The trigger is a file: create ``<output_dir>/<trigger_filename>`` and the
    next scheduled check consumes (deletes) it and requests a checkpoint.
    Only the main process looks at the filesystem, and only every
    ``check_interval`` steps. Under distributed training the decision is then
    broadcast from rank 0 so every rank sets ``control.should_save`` together.

    Usage:
        touch /path/to/output_dir/axolotl_checkpoint.save
    """

    # Fallback used when the config does not provide `check_interval`.
    _DEFAULT_CHECK_INTERVAL = 100

    @staticmethod
    def _get_config_value(config, key, default=None):
        """Return `key` from `config`, which may be a dict or an attribute-style object."""
        if isinstance(config, dict):
            return config.get(key, default)
        return getattr(config, key, default)

    def __init__(self, cfg):
        """
        Read the `dynamic_checkpoint` section of `cfg` and set up the callback.

        Args:
            cfg: Training config exposing `dynamic_checkpoint` (dict or object
                with `enabled`, optional `trigger_file_path` and
                `check_interval`) and `output_dir`.
        """
        self.cfg = cfg

        # Defaults are assigned up front so that a disabled callback still
        # exposes every attribute instead of raising AttributeError on access.
        self.trigger_filename = DEFAULT_TRIGGER_FILENAME
        self.check_interval = self._DEFAULT_CHECK_INTERVAL
        self.should_save_checkpoint = False

        dc_config = cfg.dynamic_checkpoint
        self.enabled = bool(dc_config and dc_config.enabled)
        if not self.enabled:
            return

        # An empty / missing name falls back to the default trigger filename.
        trigger_file_path = self._get_config_value(dc_config, "trigger_file_path")
        if trigger_file_path:
            self.trigger_filename = trigger_file_path

        check_interval = self._get_config_value(dc_config, "check_interval")
        if check_interval is not None:
            self.check_interval = check_interval

        LOG.info(
            "Dynamic checkpoint enabled. To trigger checkpoint save:\n"
            f" • File: touch {cfg.output_dir}/{self.trigger_filename}\n"
            f" • Check interval: every {self.check_interval} steps",
            main_process_only=True,
        )

    def _consume_trigger_file(self, output_dir, step) -> bool:
        """Return True if the trigger file exists, deleting it so it fires only once."""
        trigger_path = Path(output_dir) / self.trigger_filename
        if not trigger_path.exists():
            return False

        try:
            trigger_path.unlink()  # Delete the trigger file
            LOG.info(
                f"Dynamic checkpoint triggered via file '{self.trigger_filename}' "
                f"at step {step}",
                main_process_only=True,
            )
        except OSError as exc:
            # The save is still requested. The file is left in place, so it
            # will be detected again at the next check until it is removed.
            LOG.warning(
                f"Failed to delete trigger file: {exc}",
                main_process_only=True,
            )
        return True

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **_kwargs,
    ) -> TrainerControl:
        """
        Check for a checkpoint trigger at the end of a training step.

        Only the main process checks the file; when running distributed, the
        result is broadcast from rank 0 to all ranks.

        Args:
            args: Training arguments; `output_dir` locates the trigger file and
                `device` (if present) is used for the broadcast tensor.
            state: Trainer state; `global_step` gates how often the check runs.
            control: Trainer control object; `should_save` is set when triggered.

        Returns:
            The same `control` object, possibly with `should_save = True`.
        """
        if not self.enabled:
            return control

        # Throttle polling. NOTE(review): assumes `global_step` is identical on
        # every rank so that all ranks reach the collective below together.
        if state.global_step % self.check_interval != 0:
            return control

        trigger_detected = False
        if is_main_process():
            trigger_detected = self._consume_trigger_file(
                args.output_dir, state.global_step
            )

        # In-process one-shot flag: honoured once, then cleared.
        if self.should_save_checkpoint:
            trigger_detected = True
            self.should_save_checkpoint = False  # Reset flag

        if is_distributed():
            import torch
            import torch.distributed as dist

            device = getattr(
                args,
                "device",
                torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            )
            trigger_tensor = torch.tensor(
                int(trigger_detected),
                dtype=torch.long,
                device=device,
            )

            # Rank 0's decision overwrites whatever the other ranks computed.
            dist.broadcast(trigger_tensor, src=0)
            trigger_detected = bool(trigger_tensor.item())

            barrier()

        if trigger_detected:
            control.should_save = True
            LOG.info(
                f"Saving dynamic checkpoint at step {state.global_step}",
                main_process_only=True,
            )
        return control
|
||||||
@@ -23,6 +23,7 @@ from axolotl.utils.schemas.datasets import (
|
|||||||
StepwiseSupervisedDataset,
|
StepwiseSupervisedDataset,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
|
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
|
||||||
|
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
|
||||||
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
|
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
|
||||||
from axolotl.utils.schemas.fsdp import FSDPConfig
|
from axolotl.utils.schemas.fsdp import FSDPConfig
|
||||||
from axolotl.utils.schemas.integrations import (
|
from axolotl.utils.schemas.integrations import (
|
||||||
@@ -141,6 +142,13 @@ class AxolotlInputConfig(
|
|||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Reward modelling: `True` or `False`"},
|
json_schema_extra={"description": "Reward modelling: `True` or `False`"},
|
||||||
)
|
)
|
||||||
|
# On-demand checkpointing; only file-based triggering is supported.
dynamic_checkpoint: DynamicCheckpointConfig | None = Field(
    default=None,
    json_schema_extra={
        "description": "Configuration for dynamic checkpointing (triggered by creating a file in `output_dir`). "
        "Set 'enabled: true' to activate this feature."
    },
)
|
||||||
process_reward_model: bool | None = Field(
|
process_reward_model: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
31
src/axolotl/utils/schemas/dynamic_checkpoint.py
Normal file
31
src/axolotl/utils/schemas/dynamic_checkpoint.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""Schema for dynamic checkpoint configuration."""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicCheckpointConfig(BaseModel):
    """Configuration for dynamic checkpoint triggering during training."""

    # Master switch; the feature is off unless this is set to true.
    enabled: bool = Field(
        default=False,
        json_schema_extra={
            "description": "Enable dynamic checkpoint triggering during training. "
            "Create a file 'axolotl_checkpoint.save' in the configured `output_dir` to trigger. "
        },
    )
    # Polling period in training steps; must be at least 1. The default
    # matches the value stated in the description below.
    check_interval: int = Field(
        default=100,
        ge=1,
        json_schema_extra={
            "description": "Check for trigger file every N steps (reduces I/O overhead). "
            "Default: 100"
        },
    )
    # Name of the trigger file inside `output_dir`; empty means "use the default".
    trigger_file_path: str = Field(
        default="",
        json_schema_extra={
            "description": "Custom trigger filename (optional). "
            "If not specified, defaults to 'axolotl_checkpoint.save'. "
            "Specify a filename (not a full path) to override the default."
        },
    )
|
||||||
389
tests/utils/callbacks/test_dynamic_checkpoint.py
Normal file
389
tests/utils/callbacks/test_dynamic_checkpoint.py
Normal file
@@ -0,0 +1,389 @@
|
|||||||
|
"""Unit tests for dynamic checkpoint callback"""
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
|
from axolotl.utils.callbacks.dynamic_checkpoint import (
|
||||||
|
DEFAULT_TRIGGER_FILENAME,
|
||||||
|
DynamicCheckpointCallback,
|
||||||
|
)
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
class TestDynamicCheckpointCallbackInit:
    """Constructor behaviour of DynamicCheckpointCallback."""

    @staticmethod
    def _build(output_dir, dynamic_checkpoint):
        """Create a callback from a minimal config rooted at `output_dir`."""
        return DynamicCheckpointCallback(
            DictDefault(
                {
                    "dynamic_checkpoint": dynamic_checkpoint,
                    "output_dir": output_dir,
                }
            )
        )

    def test_callback_disabled_by_default(self):
        """`enabled: False` leaves the callback switched off."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build(workdir, {"enabled": False})
            assert cb.enabled is False

    def test_callback_disabled_when_none(self):
        """A `dynamic_checkpoint` of None leaves the callback switched off."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build(workdir, None)
            assert cb.enabled is False

    def test_callback_enabled_when_configured(self):
        """`enabled: True` switches the callback on and keeps the given interval."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build(workdir, {"enabled": True, "check_interval": 10})
            assert cb.enabled is True
            assert cb.check_interval == 10

    def test_default_trigger_filename(self):
        """Without a custom name the default trigger filename is used."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build(workdir, {"enabled": True, "check_interval": 10})
            assert cb.trigger_filename == DEFAULT_TRIGGER_FILENAME

    def test_check_interval_default(self):
        """Without `check_interval` in the config the interval is 100."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build(workdir, {"enabled": True})
            assert cb.check_interval == 100
|
||||||
|
|
||||||
|
|
||||||
|
class TestDynamicCheckpointFileDetection:
    """File-trigger handling in `on_step_end`, run as a single main process."""

    # Patch targets: the names as looked up inside the callback module.
    _IS_MAIN = "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process"
    _IS_DIST = "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed"

    @staticmethod
    def _callback(output_dir, interval):
        """Build an enabled callback that checks every `interval` steps."""
        return DynamicCheckpointCallback(
            DictDefault(
                {
                    "dynamic_checkpoint": {
                        "enabled": True,
                        "check_interval": interval,
                    },
                    "output_dir": output_dir,
                }
            )
        )

    def test_trigger_file_detected_and_deleted(self):
        """An existing trigger file is consumed and a save is requested."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._callback(workdir, 1)

            marker = Path(workdir) / DEFAULT_TRIGGER_FILENAME
            marker.touch()
            assert marker.exists()

            args = Mock(output_dir=workdir)
            state = Mock(global_step=1)
            control = Mock(should_save=False)

            with patch(self._IS_MAIN, return_value=True), patch(
                self._IS_DIST, return_value=False
            ):
                outcome = cb.on_step_end(args, state, control)

            assert not marker.exists()
            assert outcome.should_save is True

    def test_check_interval_honored(self):
        """The file is ignored on steps that are not a multiple of the interval."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._callback(workdir, 10)

            args = Mock(output_dir=workdir)
            control = Mock(should_save=False)

            marker = Path(workdir) / DEFAULT_TRIGGER_FILENAME
            marker.touch()

            with patch(self._IS_MAIN, return_value=True), patch(
                self._IS_DIST, return_value=False
            ):
                # Step 5 is not divisible by 10: nothing is checked.
                outcome = cb.on_step_end(args, Mock(global_step=5), control)
                assert marker.exists()
                assert outcome.should_save is False

                # Step 10 is a check step: the file is consumed.
                outcome = cb.on_step_end(args, Mock(global_step=10), control)
                assert not marker.exists()
                assert outcome.should_save is True

    def test_no_file_no_trigger(self):
        """With no trigger file present, no save is requested."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._callback(workdir, 1)

            args = Mock(output_dir=workdir)
            state = Mock(global_step=1)
            control = Mock(should_save=False)

            with patch(self._IS_MAIN, return_value=True), patch(
                self._IS_DIST, return_value=False
            ):
                outcome = cb.on_step_end(args, state, control)

            assert outcome.should_save is False

    def test_file_deletion_error_handling(self):
        """A failing delete is tolerated and the save is still requested."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._callback(workdir, 1)

            marker = Path(workdir) / DEFAULT_TRIGGER_FILENAME
            marker.touch()

            args = Mock(output_dir=workdir)
            state = Mock(global_step=1)
            control = Mock(should_save=False)

            with patch(self._IS_MAIN, return_value=True), patch(
                self._IS_DIST, return_value=False
            ), patch.object(Path, "unlink", side_effect=OSError("Permission denied")):
                outcome = cb.on_step_end(args, state, control)

            assert outcome.should_save is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestDynamicCheckpointMultiGPU:
    """Rank coordination in `on_step_end` with the distributed layer mocked out."""

    # Patch targets: the names as looked up inside the callback module.
    _IS_MAIN = "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process"
    _IS_DIST = "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed"
    _BARRIER = "axolotl.utils.callbacks.dynamic_checkpoint.barrier"

    @staticmethod
    def _setup(output_dir):
        """Return (callback, trigger file, args, state, control) for a step-1 check."""
        cb = DynamicCheckpointCallback(
            DictDefault(
                {
                    "dynamic_checkpoint": {"enabled": True, "check_interval": 1},
                    "output_dir": output_dir,
                }
            )
        )
        marker = Path(output_dir) / DEFAULT_TRIGGER_FILENAME
        marker.touch()
        return (
            cb,
            marker,
            Mock(output_dir=output_dir),
            Mock(global_step=1),
            Mock(should_save=False),
        )

    def test_only_rank_0_checks_file(self):
        """A non-main rank leaves the trigger file alone but still joins the broadcast."""
        with tempfile.TemporaryDirectory() as workdir:
            cb, marker, args, state, control = self._setup(workdir)

            # Stand-in for the broadcast tensor; reports "no trigger".
            fake_tensor = MagicMock()
            fake_tensor.item.return_value = 0

            with patch(self._IS_MAIN, return_value=False), patch(
                self._IS_DIST, return_value=True
            ), patch("torch.distributed.broadcast") as mock_broadcast, patch(
                self._BARRIER
            ), patch(
                "torch.tensor", return_value=fake_tensor
            ):
                cb.on_step_end(args, state, control)

            assert marker.exists()
            # The collective must still have been entered on this rank.
            assert mock_broadcast.called

    def test_broadcast_synchronization(self):
        """The main rank's detection is broadcast, synchronised and acted upon."""
        with tempfile.TemporaryDirectory() as workdir:
            cb, _marker, args, state, control = self._setup(workdir)

            # Stand-in for the broadcast tensor; reports "trigger detected".
            fake_tensor = MagicMock()
            fake_tensor.item.return_value = 1

            with patch(self._IS_MAIN, return_value=True), patch(
                self._IS_DIST, return_value=True
            ), patch("torch.distributed.broadcast") as mock_broadcast, patch(
                self._BARRIER
            ) as mock_barrier, patch(
                "torch.tensor", return_value=fake_tensor
            ), patch(
                "torch.cuda.current_device", return_value=0
            ):
                outcome = cb.on_step_end(args, state, control)

            assert mock_broadcast.called
            assert mock_barrier.called
            # The broadcast value turns into a save request.
            assert outcome.should_save is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestDynamicCheckpointSignalHandling:
    """
    Test the in-process `should_save_checkpoint` flag.

    The callback has no OS-signal trigger; these tests cover the boolean flag
    that requests a save from inside the process, and guard against a signal
    handler being installed by the constructor. Class and test names are kept
    so existing test IDs stay stable.
    """

    def test_signal_trigger_via_callback(self):
        """Test that the flag requests a save on the next check and is then reset."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cfg = DictDefault(
                {
                    "dynamic_checkpoint": {
                        "enabled": True,
                        "check_interval": 1,
                    },
                    "output_dir": tmpdir,
                }
            )
            callback = DynamicCheckpointCallback(cfg)

            # Request a save without creating any trigger file.
            callback.should_save_checkpoint = True

            args = Mock(output_dir=tmpdir)
            state = Mock(global_step=1)
            control = Mock(should_save=False)

            with patch(
                "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process",
                return_value=True,
            ), patch(
                "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed",
                return_value=False,
            ):
                result = callback.on_step_end(args, state, control)

            assert result.should_save is True
            # The flag is one-shot: it must be cleared once honoured.
            assert callback.should_save_checkpoint is False

    def test_signal_not_registered_when_disabled(self):
        """Test that constructing the callback never installs a signal handler."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cfg = DictDefault(
                {
                    "dynamic_checkpoint": {
                        "enabled": True,
                        "check_interval": 10,
                    },
                    "output_dir": tmpdir,
                }
            )

            with patch("signal.signal") as mock_signal_register:
                _ = DynamicCheckpointCallback(cfg)

            assert not mock_signal_register.called
|
||||||
|
|
||||||
|
|
||||||
|
class TestDynamicCheckpointDisabled:
    """Behaviour of `on_step_end` when the feature is switched off."""

    def test_disabled_callback_does_nothing(self):
        """A disabled callback neither consumes the trigger file nor requests a save."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = DynamicCheckpointCallback(
                DictDefault(
                    {
                        "dynamic_checkpoint": {"enabled": False},
                        "output_dir": workdir,
                    }
                )
            )

            # A trigger file is present, but must be ignored.
            marker = Path(workdir) / DEFAULT_TRIGGER_FILENAME
            marker.touch()

            outcome = cb.on_step_end(
                Mock(output_dir=workdir),
                Mock(global_step=1),
                Mock(should_save=False),
            )

            assert marker.exists()
            assert outcome.should_save is False
|
||||||
Reference in New Issue
Block a user