From 86ed554bda82ece1a7f4e8ee47863d25b6046224 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Fri, 21 Feb 2025 20:31:07 +0000
Subject: [PATCH] Add tests for runtime metrics telemetry and associated
 callback

Also adjust TelemetryCallback progress reporting: last_report_time now
starts as None (the first interval falls back to start_time), and a
progress report is always sent on step 1.

---
 src/axolotl/telemetry/__init__.py       |   8 -
 src/axolotl/telemetry/callbacks.py      |  10 +-
 tests/telemetry/__init__.py             |   0
 tests/telemetry/test_callbacks.py       | 372 ++++++++++++++++++++++++
 tests/telemetry/test_errors.py          |   2 +-
 tests/telemetry/test_runtime_metrics.py | 360 +++++++++++++++++++++++
 6 files changed, 740 insertions(+), 12 deletions(-)
 delete mode 100644 src/axolotl/telemetry/__init__.py
 create mode 100644 tests/telemetry/__init__.py
 create mode 100644 tests/telemetry/test_callbacks.py
 create mode 100644 tests/telemetry/test_runtime_metrics.py

diff --git a/src/axolotl/telemetry/__init__.py b/src/axolotl/telemetry/__init__.py
deleted file mode 100644
index 2df5fd8a5..000000000
--- a/src/axolotl/telemetry/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-"""Init for axolotl.telemetry module."""
-
-from .manager import TelemetryConfig, TelemetryManager
-
-__all__ = [
-    "TelemetryConfig",
-    "TelemetryManager",
-]
diff --git a/src/axolotl/telemetry/callbacks.py b/src/axolotl/telemetry/callbacks.py
index cbafe2086..ca0aaae92 100644
--- a/src/axolotl/telemetry/callbacks.py
+++ b/src/axolotl/telemetry/callbacks.py
@@ -34,7 +34,7 @@ class TelemetryCallback(TrainerCallback):
         self.telemetry_manager = TelemetryManager.get_instance()
         self.current_epoch = -1
         self.start_time = time.time()
-        self.last_report_time = self.start_time
+        self.last_report_time = None
         self.last_report_step = 0
 
     def on_train_begin(
@@ -110,12 +110,16 @@ class TelemetryCallback(TrainerCallback):
 
         if should_report:
             current_time = time.time()
-            time_since_last_report = current_time - self.last_report_time
+            if self.last_report_time is not None:
+                time_since_last_report = current_time - self.last_report_time
+            else:
+                time_since_last_report = current_time - self.start_time
             steps_since_last_report = step - self.last_report_step
 
             # Only report if enough time has passed to avoid flooding
             if (
-                time_since_last_report >= TIME_SINCE_LAST
+                step == 1
+                or time_since_last_report >= TIME_SINCE_LAST
                 or steps_since_last_report >= self.report_interval_steps
             ):
                 # Calculate steps per second for this interval
diff --git a/tests/telemetry/__init__.py b/tests/telemetry/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/telemetry/test_callbacks.py b/tests/telemetry/test_callbacks.py
new file mode 100644
index 000000000..4324126e7
--- /dev/null
+++ b/tests/telemetry/test_callbacks.py
@@ -0,0 +1,372 @@
+"""Tests for telemetry callback module."""
+# pylint: disable=redefined-outer-name
+
+import time
+from unittest.mock import MagicMock, patch
+
+import pytest
+from transformers import TrainerControl, TrainerState, TrainingArguments
+
+from axolotl.telemetry.callbacks import TIME_SINCE_LAST, TelemetryCallback
+
+
+def calc_expected_metrics(step, last_step, current_time, last_time, start_time=900.0):
+    """Calculate expected metrics values for tests"""
+    time_diff = current_time - last_time
+    step_diff = step - last_step
+    return {
+        "steps_per_second": step_diff / time_diff
+        if time_diff > 0 and step_diff > 0
+        else 0,
+        "time_since_last_report": time_diff,
+        "elapsed_time": current_time - start_time,
+    }
+
+
+@pytest.fixture
+def mock_time():
+    """Mock time.time() to have predictable values in tests"""
+    with patch("axolotl.telemetry.callbacks.time") as mock_time:
+        mock_time.time.return_value = 1000.0
+        yield mock_time
+
+
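+# NOTE: the mock fixtures in this module patch attributes on
+# axolotl.telemetry.callbacks (where the callback looks the names up), so
+# only the module under test sees the mocks.
+
+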
+@pytest.fixture
+def mock_telemetry_manager():
+    """Create a mock TelemetryManager"""
+    with patch("axolotl.telemetry.callbacks.TelemetryManager") as mock_manager_class:
+        mock_manager = MagicMock()
+        mock_manager_class.get_instance.return_value = mock_manager
+        yield mock_manager
+
+
+@pytest.fixture
+def mock_runtime_metrics_tracker():
+    """Create a mock RuntimeMetricsTracker"""
+    with patch(
+        "axolotl.telemetry.callbacks.RuntimeMetricsTracker"
+    ) as mock_tracker_class:
+        mock_tracker = MagicMock()
+        # Set up metrics property on the tracker
+        mock_metrics = MagicMock()
+        mock_metrics.to_dict.return_value = {
+            "total_steps": 100,
+            "peak_cpu_memory_bytes": 1024,
+        }
+        mock_tracker.metrics = mock_metrics
+
+        # Make the constructor return our mock
+        mock_tracker_class.return_value = mock_tracker
+        yield mock_tracker
+
+
+@pytest.fixture
+def training_args():
+    """Create a minimal TrainingArguments instance"""
+    return TrainingArguments(output_dir="./output")
+
+
+@pytest.fixture
+def trainer_state():
+    """Create a mock TrainerState"""
+    state = MagicMock(spec=TrainerState)
+    state.global_step = 10
+    state.epoch = 0.5  # halfway through first epoch
+    state.log_history = [{"loss": 2.5, "learning_rate": 5e-5}]
+    return state
+
+
+@pytest.fixture
+def trainer_control():
+    """Create a mock TrainerControl"""
+    return MagicMock(spec=TrainerControl)
+
+
+# pylint: disable=unused-argument
+@pytest.fixture
+def callback(mock_telemetry_manager, mock_runtime_metrics_tracker):
+    """Create a TelemetryCallback instance with mocked dependencies"""
+    return TelemetryCallback()
+
+
+class TestTelemetryCallback:
+    """Tests for the TelemetryCallback class."""
+
+    def test_initialization(self, callback, mock_runtime_metrics_tracker):
+        """Test callback initialization."""
+        assert callback.current_epoch == -1
+        assert callback.tracker == mock_runtime_metrics_tracker
+        assert callback.last_report_step == 0
+        assert hasattr(callback, "start_time")
+        assert callback.last_report_time is None
+        assert callback.report_interval_steps == 100
+
+    def test_on_train_begin(
+        self,
+        callback,
+        mock_telemetry_manager,
+        training_args,
+        trainer_state,
+        trainer_control,
+    ):
+        """Test on_train_begin sends expected event."""
+        callback.on_train_begin(training_args, trainer_state, trainer_control)
+
+        mock_telemetry_manager.send_event.assert_called_once_with(
+            event_type="train-start"
+        )
+
+    def test_on_train_end(
+        self,
+        callback,
+        mock_telemetry_manager,
+        training_args,
+        trainer_state,
+        trainer_control,
+    ):
+        """Test on_train_end sends expected event with metrics."""
+        callback.on_train_end(training_args, trainer_state, trainer_control)
+
+        mock_telemetry_manager.send_event.assert_called_once()
+        call_args = mock_telemetry_manager.send_event.call_args[1]
+
+        assert call_args["event_type"] == "train-end"
+        assert "loss" in call_args["properties"]
+        assert call_args["properties"]["loss"] == 2.5
+        assert "learning_rate" in call_args["properties"]
+        assert call_args["properties"]["learning_rate"] == 5e-5
+
+        # Check that metrics from RuntimeMetricsTracker are included
+        assert "total_steps" in call_args["properties"]
+        assert call_args["properties"]["total_steps"] == 100
+        assert "peak_cpu_memory_bytes" in call_args["properties"]
+        assert call_args["properties"]["peak_cpu_memory_bytes"] == 1024
+
+    def test_on_epoch_begin(
+        self,
+        callback,
+        mock_runtime_metrics_tracker,
+        training_args,
+        trainer_state,
+        trainer_control,
+    ):
+        """Test on_epoch_begin updates epoch counter and calls tracker."""
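+        # The callback initializes current_epoch to -1 (see callbacks.py
+        # above), so the first epoch observed should be 0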
+        initial_epoch = callback.current_epoch
+
+        callback.on_epoch_begin(training_args, trainer_state, trainer_control)
+
+        assert callback.current_epoch == initial_epoch + 1
+        mock_runtime_metrics_tracker.start_epoch.assert_called_once_with(
+            initial_epoch + 1
+        )
+
+    def test_on_epoch_end(
+        self,
+        callback,
+        mock_runtime_metrics_tracker,
+        training_args,
+        trainer_state,
+        trainer_control,
+    ):
+        """Test on_epoch_end calls tracker."""
+        # Set current epoch
+        callback.current_epoch = 2
+
+        callback.on_epoch_end(training_args, trainer_state, trainer_control)
+
+        mock_runtime_metrics_tracker.end_epoch.assert_called_once_with(2)
+
+    def test_on_step_end_no_report(
+        self,
+        callback,
+        mock_telemetry_manager,
+        mock_runtime_metrics_tracker,
+        training_args,
+        trainer_state,
+        trainer_control,
+    ):
+        """Test on_step_end updates tracker but doesn't report if criteria not met."""
+        # Set up state to avoid reporting
+        trainer_state.global_step = 42  # Not the first step
+        callback.last_report_step = 41  # Just 1 step since last report
+        callback.last_report_time = time.time()  # Just now
+
+        callback.on_step_end(training_args, trainer_state, trainer_control)
+
+        # Should update tracker
+        mock_runtime_metrics_tracker.update_step.assert_called_once_with(42)
+
+        # Should not send telemetry
+        mock_telemetry_manager.send_event.assert_not_called()
+
+        # Should not update last report time/step
+        assert callback.last_report_step == 41
+
+    def test_on_step_end_report_interval_steps(
+        self,
+        callback,
+        mock_telemetry_manager,
+        mock_runtime_metrics_tracker,
+        mock_time,
+        training_args,
+        trainer_state,
+        trainer_control,
+    ):
+        """Test on_step_end reports when step interval is reached."""
+        # Set up state with clear values
+        current_step = 100  # Exactly matches report_interval_steps
+        last_step = 0
+        start_time = 900.0
+        current_time = 1000.0
+        time_diff = current_time - start_time  # 100 seconds
+
+        # Configure state and callback
+        trainer_state.global_step = current_step
+        callback.report_interval_steps = 100
+        callback.last_report_step = last_step
+        callback.start_time = start_time
+        callback.last_report_time = start_time
+
+        # Mock time.time() to return consistent values
+        mock_time.time.return_value = current_time
+
+        callback.on_step_end(training_args, trainer_state, trainer_control)
+
+        # Should update tracker
+        mock_runtime_metrics_tracker.update_step.assert_called_once_with(current_step)
+        mock_runtime_metrics_tracker.update_memory_metrics.assert_called_once()
+
+        # Should send telemetry
+        mock_telemetry_manager.send_event.assert_called_once()
+        call_args = mock_telemetry_manager.send_event.call_args[1]
+        assert call_args["event_type"] == "train-progress"
+
+        # Properties should include expected values
+        props = call_args["properties"]
+        assert props["step"] == current_step
+        assert props["elapsed_time"] == time_diff  # 1000 - 900 = 100
+        assert props["time_since_last_report"] == time_diff  # 1000 - 900 = 100
+        assert props["steps_per_second"] == 1.0  # 100 steps / 100 seconds
+
+        # Should update last report time/step
+        assert callback.last_report_step == current_step
+        assert callback.last_report_time == current_time
+
+    def test_on_step_end_report_time_elapsed(
+        self,
+        callback,
+        mock_telemetry_manager,
+        mock_runtime_metrics_tracker,  # pylint: disable=unused-argument
+        mock_time,
+        training_args,
+        trainer_state,
+        trainer_control,
+    ):
+        """Test on_step_end reports when enough time has elapsed."""
+        # Set up state with clear values
+        current_step = 120
+        last_step = 50  # Only 70 steps since last report, below the 100-step interval
+        start_time = 900.0
+        current_time = 1000.0
+        time_diff = TIME_SINCE_LAST + 1  # Just over the threshold
+
+        # Configure state and callback
+        trainer_state.global_step = current_step
+        callback.report_interval_steps = 100
+        callback.last_report_step = last_step
+        callback.start_time = start_time
+        callback.last_report_time = current_time - time_diff
+
+        # Mock time.time() to return consistent values
+        mock_time.time.return_value = current_time
+
+        callback.on_step_end(training_args, trainer_state, trainer_control)
+
+        # Should send telemetry
+        mock_telemetry_manager.send_event.assert_called_once()
+
+        # Properties should include expected values
+        props = mock_telemetry_manager.send_event.call_args[1]["properties"]
+        expected_metrics = calc_expected_metrics(
+            current_step, last_step, current_time, current_time - time_diff, start_time
+        )
+        assert props["steps_per_second"] == expected_metrics["steps_per_second"]
+        assert (
+            props["time_since_last_report"]
+            == expected_metrics["time_since_last_report"]
+        )
+
+    def test_on_step_end_first_step(
+        self,
+        callback,
+        mock_telemetry_manager,
+        mock_runtime_metrics_tracker,  # pylint: disable=unused-argument
+        mock_time,
+        training_args,
+        trainer_state,
+        trainer_control,
+    ):
+        """Test on_step_end always reports on first step."""
+        # Set up state with clear values
+        current_step = 1  # First step
+        last_step = 0
+        start_time = 900.0
+        current_time = 1000.0
+        last_report_time = 999.0  # Just 1 second ago
+
+        # Configure state and callback
+        trainer_state.global_step = current_step
+        callback.report_interval_steps = 100
+        callback.last_report_step = last_step
+        callback.start_time = start_time
+        callback.last_report_time = last_report_time
+
+        # Mock time.time() to return consistent values
+        mock_time.time.return_value = current_time
+
+        callback.on_step_end(training_args, trainer_state, trainer_control)
+
+        # Should send telemetry even though not much time has passed
+        mock_telemetry_manager.send_event.assert_called_once()
+
+        # Properties should include expected values for first step
+        props = mock_telemetry_manager.send_event.call_args[1]["properties"]
+        assert props["step"] == current_step
+        expected_metrics = calc_expected_metrics(
+            current_step, last_step, current_time, last_report_time, start_time
+        )
+        assert props["steps_per_second"] == expected_metrics["steps_per_second"]
+
+    def test_log_history_empty(
+        self,
+        callback,
+        mock_telemetry_manager,
+        mock_runtime_metrics_tracker,  # pylint: disable=unused-argument
+        mock_time,
+        training_args,
+        trainer_state,
+        trainer_control,
+    ):
+        """Test handling of empty log history."""
+        # Set up state with clear values
+        current_step = 1
+        start_time = 900.0
+        current_time = 1000.0
+
+        # Configure state and callback
+        trainer_state.global_step = current_step
+        trainer_state.log_history = []
+        callback.start_time = start_time
+
+        # Mock time.time() to return consistent values
+        mock_time.time.return_value = current_time
+
+        callback.on_step_end(training_args, trainer_state, trainer_control)
+
+        # Should still send telemetry
+        mock_telemetry_manager.send_event.assert_called_once()
+
+        # Properties should have default values for missing log data
+        props = mock_telemetry_manager.send_event.call_args[1]["properties"]
+        assert props["loss"] == 0
+        assert props["learning_rate"] == 0
diff --git a/tests/telemetry/test_errors.py b/tests/telemetry/test_errors.py
index a021fc96f..3d00c0f28 100644
--- a/tests/telemetry/test_errors.py
+++ b/tests/telemetry/test_errors.py
@@ -1,4 +1,4 @@
-"""Tests for telemetry error utilities."""
+"""Tests for telemetry error utilities"""
 # pylint: disable=redefined-outer-name
 
 from unittest.mock import MagicMock, patch
error utilities.""" +"""Tests for telemetry error utilities""" # pylint: disable=redefined-outer-name from unittest.mock import MagicMock, patch diff --git a/tests/telemetry/test_runtime_metrics.py b/tests/telemetry/test_runtime_metrics.py new file mode 100644 index 000000000..11c7faf98 --- /dev/null +++ b/tests/telemetry/test_runtime_metrics.py @@ -0,0 +1,360 @@ +"""Tests for runtime metrics telemetry module""" +# pylint: disable=redefined-outer-name + +from unittest.mock import MagicMock, patch + +import pytest + +from axolotl.telemetry.runtime_metrics import RuntimeMetrics, RuntimeMetricsTracker + + +@pytest.fixture +def mock_time(): + """Mock time.time() to have predictable values in tests""" + with patch("time.time") as mock_time: + # Start with time 1000.0 and increment by 10 seconds on each call + times = [1000.0 + i * 10 for i in range(10)] + mock_time.side_effect = times + yield mock_time + + +@pytest.fixture +def mock_telemetry_manager(): + """Create a mock TelemetryManager""" + with patch( + "axolotl.telemetry.runtime_metrics.TelemetryManager" + ) as mock_manager_class: + mock_manager = MagicMock() + mock_manager.enabled = True + mock_manager_class.get_instance.return_value = mock_manager + yield mock_manager + + +@pytest.fixture +def mock_psutil(): + """Mock psutil for memory information""" + with patch("axolotl.telemetry.runtime_metrics.psutil") as mock_psutil: + mock_process = MagicMock() + mock_memory_info = MagicMock() + # Set initial memory to 1GB + mock_memory_info.rss = 1024 * 1024 * 1024 + mock_process.memory_info.return_value = mock_memory_info + mock_psutil.Process.return_value = mock_process + yield mock_psutil + + +@pytest.fixture +def mock_torch(): + """Mock torch.cuda functions""" + with patch("axolotl.telemetry.runtime_metrics.torch") as mock_torch: + mock_torch.cuda.is_available.return_value = True + mock_torch.cuda.device_count.return_value = 2 + + # Mock memory allocated per device (1GB for device 0, 2GB for device 1) + mock_torch.cuda.memory_allocated.side_effect = ( + lambda device: (device + 1) * 1024 * 1024 * 1024 + ) + + yield mock_torch + + +class TestRuntimeMetrics: + """Tests for RuntimeMetrics class.""" + + def test_initialization(self): + """Test RuntimeMetrics initialization.""" + metrics = RuntimeMetrics(start_time=1000.0) + + assert metrics.start_time == 1000.0 + assert metrics.epoch_start_times == {} + assert metrics.epoch_end_times == {} + assert metrics.peak_gpu_memory == {} + assert metrics.total_steps == 0 + assert metrics.current_epoch == 0 + assert metrics.current_step == 0 + assert metrics.peak_cpu_memory == 0 + + def test_elapsed_time(self, mock_time): + """Test elapsed_time property.""" + metrics = RuntimeMetrics(start_time=1000.0) + + # Mock time.time() to return 1050.0 + mock_time.side_effect = [1050.0] + + assert metrics.elapsed_time == 50.0 + + def test_epoch_time(self): + """Test epoch_time method.""" + metrics = RuntimeMetrics(start_time=1000.0) + + # No epoch data + assert metrics.epoch_time(0) is None + + # Add epoch start but no end + metrics.epoch_start_times[0] = 1000.0 + assert metrics.epoch_time(0) is None + + # Add epoch end + metrics.epoch_end_times[0] = 1060.0 + assert metrics.epoch_time(0) == 60.0 + + def test_average_epoch_time(self): + """Test average_epoch_time method.""" + metrics = RuntimeMetrics(start_time=1000.0) + + # No completed epochs + assert metrics.average_epoch_time() is None + + # Add one completed epoch + metrics.epoch_start_times[0] = 1000.0 + metrics.epoch_end_times[0] = 1060.0 + assert 
+class TestRuntimeMetricsTracker:
+    """Tests for RuntimeMetricsTracker class."""
+
+    # pylint: disable=unused-argument
+    def test_initialization(self, mock_time, mock_telemetry_manager):
+        """Test RuntimeMetricsTracker initialization."""
+        tracker = RuntimeMetricsTracker()
+
+        assert isinstance(tracker.metrics, RuntimeMetrics)
+        assert tracker.metrics.start_time == 1000.0  # First value from mock_time
+
+    # pylint: disable=unused-argument
+    def test_start_epoch(
+        self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
+    ):
+        """Test start_epoch method."""
+        tracker = RuntimeMetricsTracker()
+
+        # Reset mock_time to control next value
+        mock_time.side_effect = [1010.0]
+
+        tracker.start_epoch(0)
+
+        assert tracker.metrics.current_epoch == 0
+        assert tracker.metrics.epoch_start_times[0] == 1010.0
+
+        # Verify memory metrics were updated
+        assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
+        assert 0 in tracker.metrics.peak_gpu_memory
+        assert 1 in tracker.metrics.peak_gpu_memory
+
+    # pylint: disable=unused-argument
+    def test_end_epoch(self, mock_time, mock_telemetry_manager):
+        """Test end_epoch method."""
+        tracker = RuntimeMetricsTracker()
+
+        # Start epoch 0
+        mock_time.side_effect = [1010.0]
+        tracker.start_epoch(0)
+
+        # End epoch 0
+        mock_time.side_effect = [1060.0]
+        tracker.end_epoch(0)
+
+        assert 0 in tracker.metrics.epoch_end_times
+        assert tracker.metrics.epoch_end_times[0] == 1060.0
+
+    # pylint: disable=unused-argument
+    def test_update_step(
+        self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
+    ):
+        """Test update_step method."""
+        tracker = RuntimeMetricsTracker()
+
+        # Update step to a non-multiple of 100
+        tracker.update_step(42)
+
+        assert tracker.metrics.current_step == 42
+        assert tracker.metrics.total_steps == 1
+
+        # Memory metrics should not be updated for non-multiple of 100
+        assert tracker.metrics.peak_cpu_memory == 0
+
+        # Update step to a multiple of 100
+        tracker.update_step(100)
+
+        assert tracker.metrics.current_step == 100
+        assert tracker.metrics.total_steps == 2
+
+        # Memory metrics should be updated for multiple of 100
+        assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
+
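+    # Peak memory values are monotonic: update_memory_metrics should never
+    # lower a previously recorded peak (exercised below)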
+    # pylint: disable=unused-argument
+    def test_update_memory_metrics(
+        self, mock_psutil, mock_torch, mock_telemetry_manager
+    ):
+        """Test update_memory_metrics method."""
+        tracker = RuntimeMetricsTracker()
+
+        # Initial memory state
+        assert tracker.metrics.peak_cpu_memory == 0
+        assert tracker.metrics.peak_gpu_memory == {}
+
+        # Update memory metrics
+        tracker.update_memory_metrics()
+
+        # Verify CPU memory
+        assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
+
+        # Verify GPU memory
+        assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
+        assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024
+
+        # Change mocked memory values to be lower
+        mock_process = mock_psutil.Process.return_value
+        mock_memory_info = mock_process.memory_info.return_value
+        mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024  # 0.5GB
+
+        mock_torch.cuda.memory_allocated.side_effect = (
+            lambda device: (device + 0.5) * 1024 * 1024 * 1024
+        )
+
+        # Update memory metrics again
+        tracker.update_memory_metrics()
+
+        # Peak values should not decrease
+        assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
+        assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
+        assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024
+
+        # Change mocked memory values to be higher
+        mock_memory_info.rss = 2 * 1024 * 1024 * 1024  # 2GB
+
+        mock_torch.cuda.memory_allocated.side_effect = (
+            lambda device: (device + 2) * 1024 * 1024 * 1024
+        )
+
+        # Update memory metrics again
+        tracker.update_memory_metrics()
+
+        # Peak values should increase
+        assert tracker.metrics.peak_cpu_memory == 2 * 1024 * 1024 * 1024
+        assert tracker.metrics.peak_gpu_memory[0] == 2 * 1024 * 1024 * 1024
+        assert tracker.metrics.peak_gpu_memory[1] == 3 * 1024 * 1024 * 1024
+
+    # pylint: disable=unused-argument
+    def test_get_memory_metrics(self, mock_psutil, mock_torch, mock_telemetry_manager):
+        """Test get_memory_metrics method."""
+        tracker = RuntimeMetricsTracker()
+
+        # Set peak memory values
+        tracker.metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024
+        tracker.metrics.peak_gpu_memory = {
+            0: 3 * 1024 * 1024 * 1024,
+            1: 4 * 1024 * 1024 * 1024,
+        }
+
+        # Get memory metrics
+        result = tracker.get_memory_metrics()
+
+        # Verify structure
+        assert "memory" in result
+        memory = result["memory"]
+
+        # Verify CPU memory
+        assert (
+            memory["cpu_memory_bytes"] == 1 * 1024 * 1024 * 1024
+        )  # Current value from mock
+        assert (
+            memory["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
+        )  # Peak value we set
+
+        # Verify GPU memory
+        assert (
+            memory["gpu_0_memory_bytes"] == 1 * 1024 * 1024 * 1024
+        )  # Current value from mock
+        assert (
+            memory["gpu_0_peak_memory_bytes"] == 3 * 1024 * 1024 * 1024
+        )  # Peak value we set
+        assert (
+            memory["gpu_1_memory_bytes"] == 2 * 1024 * 1024 * 1024
+        )  # Current value from mock
+        assert (
+            memory["gpu_1_peak_memory_bytes"] == 4 * 1024 * 1024 * 1024
+        )  # Peak value we set
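
-- 
These suites should run standalone, assuming a standard pytest setup:
  pytest tests/telemetry/test_callbacks.py tests/telemetry/test_runtime_metrics.py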