Data loader refactor (#2707)
* data loading refactor (wip) * updates * progress * pytest * pytest fix * lint * zero_first -> filelock, more simplifications * small simplification * import change * nit * lint * simplify dedup * couldnt resist * review comments WIP * continued wip * minor changes * fix; remove contrived test * further refactor * set default seed in pydantic config * lint * continued simplication * lint * renaming and nits * filelock tests * fix * fix * lint * remove nullable arg * remove unnecessary code * moving dataset save fn to shared module * remove debug print * matching var naming * fn name change * coderabbit comments * naming nit * fix test
This commit is contained in:
@@ -105,7 +105,7 @@ def start_vllm(
|
||||
print(f"{i}: VLLM server failed to start: {str(exc)}")
|
||||
|
||||
# also check if the process.pid is still running
|
||||
if not process.poll() is None:
|
||||
if process.poll() is not None:
|
||||
break
|
||||
|
||||
time.sleep(period_seconds)
|
||||
|
||||
192
tests/e2e/multigpu/test_locking.py
Normal file
192
tests/e2e/multigpu/test_locking.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Tests for FileLockLoader class."""
|
||||
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.data.lock import FileLockLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestFileLockLoader:
|
||||
"""Class with tests for FileLockLoader."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir(self):
|
||||
"""Create a temporary directory for testing."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
yield Path(tmp_dir)
|
||||
|
||||
@pytest.fixture
|
||||
def cfg(self, temp_dir):
|
||||
"""Create a test configuration."""
|
||||
return DictDefault({"dataset_prepared_path": str(temp_dir)})
|
||||
|
||||
@pytest.fixture
|
||||
def loader(self, cfg):
|
||||
"""Create a FileLockLoader instance for testing."""
|
||||
return FileLockLoader(cfg)
|
||||
|
||||
def test_load_first_process(self, loader):
|
||||
"""Test load() when no ready flag exists (first process)."""
|
||||
mock_load_fn = Mock(return_value="test_data")
|
||||
|
||||
result = loader.load(mock_load_fn)
|
||||
|
||||
# Should call the load function
|
||||
mock_load_fn.assert_called_once()
|
||||
assert result == "test_data"
|
||||
|
||||
# Should create the ready flag
|
||||
assert loader.ready_flag_path.exists()
|
||||
|
||||
def test_load_subsequent_process(self, loader):
|
||||
"""Test load() when ready flag already exists (subsequent process)."""
|
||||
# Create ready flag first
|
||||
loader.ready_flag_path.touch()
|
||||
|
||||
mock_load_fn = Mock(return_value="loaded_data")
|
||||
|
||||
result = loader.load(mock_load_fn)
|
||||
|
||||
# Should still call load function (to load the prepared data)
|
||||
mock_load_fn.assert_called_once()
|
||||
assert result == "loaded_data"
|
||||
|
||||
def test_load_concurrent_processes(self, cfg):
|
||||
"""Test that concurrent processes coordinate correctly."""
|
||||
results = []
|
||||
call_count = 0
|
||||
|
||||
def slow_load_fn():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
time.sleep(0.1) # Simulate slow loading
|
||||
return f"data_{call_count}"
|
||||
|
||||
def worker():
|
||||
loader = FileLockLoader(cfg)
|
||||
result = loader.load(slow_load_fn)
|
||||
results.append(result)
|
||||
|
||||
# Start multiple threads simultaneously
|
||||
threads = [threading.Thread(target=worker) for _ in range(3)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Only one thread should have done the initial loading
|
||||
# All should return data, but the load function should be called
|
||||
# once by the first process and once by each subsequent process
|
||||
assert len(results) == 3
|
||||
assert all(result.startswith("data_") for result in results)
|
||||
|
||||
@patch("time.sleep")
|
||||
def test_load_waiting_for_ready_flag(self, mock_sleep, loader):
|
||||
"""Test that processes wait for the ready flag to appear."""
|
||||
mock_load_fn = Mock(return_value="waiting_data")
|
||||
mock_ready_flag_path = Mock()
|
||||
exists_call_count = 0
|
||||
|
||||
def mock_exists():
|
||||
nonlocal exists_call_count
|
||||
exists_call_count += 1
|
||||
|
||||
if exists_call_count == 1:
|
||||
# First check: ready flag exists (not first process)
|
||||
return True
|
||||
if exists_call_count <= 3:
|
||||
# While loop checks: flag doesn't exist yet
|
||||
return False
|
||||
return True
|
||||
|
||||
mock_ready_flag_path.exists.side_effect = mock_exists
|
||||
|
||||
# Replace the ready_flag_path with our mock
|
||||
original_path = loader.ready_flag_path
|
||||
loader.ready_flag_path = mock_ready_flag_path
|
||||
|
||||
try:
|
||||
result = loader.load(mock_load_fn)
|
||||
finally:
|
||||
# Restore original path
|
||||
loader.ready_flag_path = original_path
|
||||
|
||||
# Should have slept twice while waiting
|
||||
assert mock_sleep.call_count == 2
|
||||
mock_sleep.assert_called_with(1)
|
||||
|
||||
# Should eventually call load function
|
||||
mock_load_fn.assert_called_once()
|
||||
assert result == "waiting_data"
|
||||
|
||||
def test_complete_workflow_with_cleanup(self, loader):
|
||||
"""Test the complete load -> cleanup workflow."""
|
||||
mock_load_fn = Mock(return_value="test_data")
|
||||
|
||||
# First process calls load (this should set up counter)
|
||||
result = loader.load(mock_load_fn)
|
||||
assert result == "test_data"
|
||||
assert loader.ready_flag_path.exists()
|
||||
assert loader.counter_path.exists()
|
||||
|
||||
# Cleanup should remove everything since there's only one process
|
||||
loader.cleanup()
|
||||
assert not loader.ready_flag_path.exists()
|
||||
assert not loader.counter_path.exists()
|
||||
|
||||
def test_multiple_processes_workflow(self, loader):
|
||||
"""Test workflow with multiple processes."""
|
||||
# Simulate multiple processes by manually setting up counter
|
||||
loader.ready_flag_path.touch()
|
||||
loader.counter_path.write_text("3") # 3 processes
|
||||
|
||||
# First process cleanup
|
||||
loader.cleanup()
|
||||
assert loader.ready_flag_path.exists()
|
||||
assert loader.counter_path.read_text().strip() == "2"
|
||||
|
||||
# Second process cleanup
|
||||
loader.cleanup()
|
||||
assert loader.ready_flag_path.exists()
|
||||
assert loader.counter_path.read_text().strip() == "1"
|
||||
|
||||
# Last process cleanup
|
||||
loader.cleanup()
|
||||
assert not loader.ready_flag_path.exists()
|
||||
assert not loader.counter_path.exists()
|
||||
|
||||
def test_load_exception_handling(self, loader):
|
||||
"""Test behavior when load_fn raises an exception."""
|
||||
|
||||
def failing_load_fn():
|
||||
raise ValueError("Load failed")
|
||||
|
||||
with pytest.raises(ValueError, match="Load failed"):
|
||||
loader.load(failing_load_fn)
|
||||
|
||||
# Ready flag should not be created on failure
|
||||
assert not loader.ready_flag_path.exists()
|
||||
|
||||
def test_file_lock_called(self, loader):
|
||||
"""Test that FileLock is properly used."""
|
||||
mock_load_fn = Mock(return_value="locked_data")
|
||||
|
||||
with patch("axolotl.utils.data.lock.FileLock") as mock_filelock:
|
||||
mock_context = MagicMock()
|
||||
mock_filelock.return_value.__enter__ = Mock(return_value=mock_context)
|
||||
mock_filelock.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
loader.load(mock_load_fn)
|
||||
|
||||
# Verify FileLock was called with correct path
|
||||
mock_filelock.assert_called_once_with(str(loader.lock_file_path))
|
||||
|
||||
# Verify context manager was used
|
||||
mock_filelock.return_value.__enter__.assert_called_once()
|
||||
mock_filelock.return_value.__exit__.assert_called_once()
|
||||
Reference in New Issue
Block a user