feat: save checkpoint after training started (#3233)
* add:config parameters for checkpoint * callback main * test file_type fix * lint * unit * simplify dict/obj handling * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * Delete tests/e2e/integrations/__init__.py * remove hard code path in test * device check * lint * Update src/axolotl/utils/callbacks/dynamic_checkpoint.py Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * Update src/axolotl/utils/callbacks/dynamic_checkpoint.py Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * lint-2 * remove: signal-based checkpoints * lint * remove signal tests * add:is_main_process * lint * add: is_distributed() for tests * remove nested is_main_process * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: Wing Lian <wing.lian@gmail.com> * Update src/axolotl/utils/schemas/dynamic_checkpoint.py Co-authored-by: Wing Lian <wing.lian@gmail.com> * add user_defined_filename --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
389
tests/utils/callbacks/test_dynamic_checkpoint.py
Normal file
389
tests/utils/callbacks/test_dynamic_checkpoint.py
Normal file
@@ -0,0 +1,389 @@
|
||||
"""Unit tests for dynamic checkpoint callback"""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from axolotl.utils.callbacks.dynamic_checkpoint import (
|
||||
DEFAULT_TRIGGER_FILENAME,
|
||||
DynamicCheckpointCallback,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestDynamicCheckpointCallbackInit:
    """Checks the attributes ``DynamicCheckpointCallback`` exposes after construction."""

    @staticmethod
    def _build(checkpoint_section, output_dir):
        """Create a callback from a ``dynamic_checkpoint`` section and an output dir.

        Args:
            checkpoint_section: Value stored under the ``dynamic_checkpoint`` key
                (a dict of options, or ``None``).
            output_dir: Directory stored under the ``output_dir`` key.

        Returns:
            The constructed ``DynamicCheckpointCallback``.
        """
        settings = DictDefault(
            {
                "dynamic_checkpoint": checkpoint_section,
                "output_dir": output_dir,
            }
        )
        return DynamicCheckpointCallback(settings)

    def test_callback_disabled_by_default(self):
        """``enabled: False`` in the config yields a disabled callback."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build({"enabled": False}, workdir)
            assert cb.enabled is False

    def test_callback_disabled_when_none(self):
        """A ``dynamic_checkpoint`` value of ``None`` yields a disabled callback."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build(None, workdir)
            assert cb.enabled is False

    def test_callback_enabled_when_configured(self):
        """``enabled: True`` turns the callback on and keeps the given interval."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build({"enabled": True, "check_interval": 10}, workdir)
            assert cb.enabled is True
            assert cb.check_interval == 10

    def test_default_trigger_filename(self):
        """Without an explicit filename the module-level default is used."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build({"enabled": True, "check_interval": 10}, workdir)
            assert cb.trigger_filename == DEFAULT_TRIGGER_FILENAME

    def test_check_interval_default(self):
        """Omitting ``check_interval`` falls back to the default of 100."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build({"enabled": True}, workdir)
            # 100 is the expected default interval when none is configured.
            assert cb.check_interval == 100
|
||||
|
||||
|
||||
class TestDynamicCheckpointFileDetection:
    """Exercises trigger-file handling in ``on_step_end`` for a non-distributed run."""

    # Patch targets inside the callback module.
    _MAIN_PROCESS = "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process"
    _DISTRIBUTED = "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed"

    @staticmethod
    def _build(output_dir, check_interval):
        """Create an enabled callback that writes to ``output_dir``."""
        settings = DictDefault(
            {
                "dynamic_checkpoint": {
                    "enabled": True,
                    "check_interval": check_interval,
                },
                "output_dir": output_dir,
            }
        )
        return DynamicCheckpointCallback(settings)

    def _single_process(self):
        """Return patchers reporting "main process" and "not distributed"."""
        as_main = patch(self._MAIN_PROCESS, return_value=True)
        not_distributed = patch(self._DISTRIBUTED, return_value=False)
        return as_main, not_distributed

    def test_trigger_file_detected_and_deleted(self):
        """A present trigger file is removed and a save is requested."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build(workdir, 1)

            marker = Path(workdir) / DEFAULT_TRIGGER_FILENAME
            marker.touch()
            assert marker.exists()

            trainer_args = Mock(output_dir=workdir)
            trainer_state = Mock(global_step=1)
            trainer_control = Mock(should_save=False)

            as_main, not_distributed = self._single_process()
            with as_main, not_distributed:
                outcome = cb.on_step_end(trainer_args, trainer_state, trainer_control)

            assert not marker.exists()
            assert outcome.should_save is True

    def test_check_interval_honored(self):
        """The trigger file is only looked at on multiples of ``check_interval``."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build(workdir, 10)

            trainer_args = Mock(output_dir=workdir)
            trainer_control = Mock(should_save=False)

            marker = Path(workdir) / DEFAULT_TRIGGER_FILENAME
            marker.touch()

            as_main, not_distributed = self._single_process()
            with as_main, not_distributed:
                # Step 5 is not a multiple of 10, so the file must be left alone.
                outcome = cb.on_step_end(
                    trainer_args, Mock(global_step=5), trainer_control
                )
                assert marker.exists()
                assert outcome.should_save is False

                # Step 10 is a multiple of 10, so the file is consumed.
                outcome = cb.on_step_end(
                    trainer_args, Mock(global_step=10), trainer_control
                )
                assert not marker.exists()
                assert outcome.should_save is True

    def test_no_file_no_trigger(self):
        """No trigger file means no save request."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build(workdir, 1)

            trainer_args = Mock(output_dir=workdir)
            trainer_state = Mock(global_step=1)
            trainer_control = Mock(should_save=False)

            as_main, not_distributed = self._single_process()
            with as_main, not_distributed:
                outcome = cb.on_step_end(trainer_args, trainer_state, trainer_control)

            assert outcome.should_save is False

    def test_file_deletion_error_handling(self):
        """A failing ``Path.unlink`` does not prevent the save request."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build(workdir, 1)

            marker = Path(workdir) / DEFAULT_TRIGGER_FILENAME
            marker.touch()

            trainer_args = Mock(output_dir=workdir)
            trainer_state = Mock(global_step=1)
            trainer_control = Mock(should_save=False)

            as_main, not_distributed = self._single_process()
            # Make every unlink attempt fail while the callback runs.
            broken_unlink = patch.object(
                Path, "unlink", side_effect=OSError("Permission denied")
            )
            with as_main, not_distributed, broken_unlink:
                outcome = cb.on_step_end(trainer_args, trainer_state, trainer_control)

            assert outcome.should_save is True
|
||||
|
||||
|
||||
class TestDynamicCheckpointMultiGPU:
    """Exercises ``on_step_end`` with the distributed helpers mocked out."""

    # Patch targets inside the callback module.
    _MAIN_PROCESS = "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process"
    _DISTRIBUTED = "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed"
    _BARRIER = "axolotl.utils.callbacks.dynamic_checkpoint.barrier"

    @staticmethod
    def _build(output_dir):
        """Create an enabled callback that checks on every step."""
        settings = DictDefault(
            {
                "dynamic_checkpoint": {"enabled": True, "check_interval": 1},
                "output_dir": output_dir,
            }
        )
        return DynamicCheckpointCallback(settings)

    def test_only_rank_0_checks_file(self):
        """A non-main rank leaves the trigger file alone but still broadcasts."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build(workdir)

            marker = Path(workdir) / DEFAULT_TRIGGER_FILENAME
            marker.touch()

            trainer_args = Mock(output_dir=workdir)
            trainer_state = Mock(global_step=1)
            trainer_control = Mock(should_save=False)

            # Stand-in for the flag tensor; ``item()`` reports "no trigger".
            flag_tensor = MagicMock()
            flag_tensor.item.return_value = 0

            # Pretend to be a non-main rank in a distributed run.
            not_main = patch(self._MAIN_PROCESS, return_value=False)
            distributed = patch(self._DISTRIBUTED, return_value=True)
            broadcast = patch("torch.distributed.broadcast")
            barrier = patch(self._BARRIER)
            tensor_factory = patch("torch.tensor", return_value=flag_tensor)

            with not_main, distributed, broadcast as mock_broadcast, barrier, \
                    tensor_factory:
                cb.on_step_end(trainer_args, trainer_state, trainer_control)

            assert marker.exists()
            # Broadcast should have been called
            assert mock_broadcast.called

    def test_broadcast_synchronization(self):
        """The main rank broadcasts its decision, hits the barrier and saves."""
        with tempfile.TemporaryDirectory() as workdir:
            cb = self._build(workdir)

            marker = Path(workdir) / DEFAULT_TRIGGER_FILENAME
            marker.touch()

            trainer_args = Mock(output_dir=workdir)
            trainer_state = Mock(global_step=1)
            trainer_control = Mock(should_save=False)

            # Stand-in for the flag tensor; ``item()`` reports "trigger".
            flag_tensor = MagicMock()
            flag_tensor.item.return_value = 1

            # Pretend to be the main rank in a distributed run.
            is_main = patch(self._MAIN_PROCESS, return_value=True)
            distributed = patch(self._DISTRIBUTED, return_value=True)
            broadcast = patch("torch.distributed.broadcast")
            barrier = patch(self._BARRIER)
            tensor_factory = patch("torch.tensor", return_value=flag_tensor)
            device = patch("torch.cuda.current_device", return_value=0)

            with is_main, distributed, broadcast as mock_broadcast, \
                    barrier as mock_barrier, tensor_factory, device:
                outcome = cb.on_step_end(trainer_args, trainer_state, trainer_control)

            assert mock_broadcast.called
            assert mock_barrier.called
            # All ranks should trigger
            assert outcome.should_save is True
|
||||
|
||||
|
||||
class TestDynamicCheckpointSignalHandling:
    """Exercises the ``should_save_checkpoint`` flag and signal registration.

    NOTE(review): the commit message for this change lists signal-based
    checkpoints and their tests as removed; confirm ``enable_signal`` is still
    a supported option before relying on these tests.
    """

    # Patch targets inside the callback module.
    _MAIN_PROCESS = "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process"
    _DISTRIBUTED = "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed"
    _HASATTR = "axolotl.utils.callbacks.dynamic_checkpoint.hasattr"

    def test_signal_trigger_via_callback(self):
        """A pre-set save flag produces a save request and is then cleared."""
        with tempfile.TemporaryDirectory() as workdir:
            settings = DictDefault(
                {
                    "dynamic_checkpoint": {
                        "enabled": True,
                        "check_interval": 1,
                        "enable_signal": True,
                    },
                    "output_dir": workdir,
                }
            )

            # Build the callback with signal registration and hasattr mocked.
            signal_register = patch("signal.signal")
            is_main = patch(self._MAIN_PROCESS, return_value=True)
            always_has_attr = patch(self._HASATTR, return_value=True)
            with signal_register, is_main, always_has_attr:
                cb = DynamicCheckpointCallback(settings)

            # Simulate the flag having been raised before the step ends.
            cb.should_save_checkpoint = True

            trainer_args = Mock(output_dir=workdir)
            trainer_state = Mock(global_step=1)
            trainer_control = Mock(should_save=False)

            is_main = patch(self._MAIN_PROCESS, return_value=True)
            not_distributed = patch(self._DISTRIBUTED, return_value=False)
            with is_main, not_distributed:
                outcome = cb.on_step_end(trainer_args, trainer_state, trainer_control)

            assert outcome.should_save is True
            assert cb.should_save_checkpoint is False

    def test_signal_not_registered_when_disabled(self):
        """With ``enable_signal: False`` no handler is passed to ``signal.signal``."""
        with tempfile.TemporaryDirectory() as workdir:
            settings = DictDefault(
                {
                    "dynamic_checkpoint": {
                        "enabled": True,
                        "check_interval": 10,
                        "enable_signal": False,
                    },
                    "output_dir": workdir,
                }
            )

            with patch("signal.signal") as mock_signal_register:
                _ = DynamicCheckpointCallback(settings)

            assert not mock_signal_register.called
|
||||
|
||||
|
||||
class TestDynamicCheckpointDisabled:
    """Checks that a disabled callback is inert."""

    def test_disabled_callback_does_nothing(self):
        """A disabled callback neither removes the trigger file nor asks to save."""
        with tempfile.TemporaryDirectory() as workdir:
            settings = DictDefault(
                {
                    "dynamic_checkpoint": {"enabled": False},
                    "output_dir": workdir,
                }
            )
            cb = DynamicCheckpointCallback(settings)

            # Leave a trigger file in place; it must not be consumed.
            marker = Path(workdir) / DEFAULT_TRIGGER_FILENAME
            marker.touch()

            trainer_args = Mock(output_dir=workdir)
            trainer_state = Mock(global_step=1)
            trainer_control = Mock(should_save=False)

            outcome = cb.on_step_end(trainer_args, trainer_state, trainer_control)

            assert marker.exists()
            assert outcome.should_save is False
|
||||
Reference in New Issue
Block a user