feat: Add opt-out Telemetry (#3237)
* initial telemetry manager impl
* adding todo
* updates
* updates
* progress on telemetry: config load, process, model load, train start / end, error tracking
* update error file path sanitization function; adding more error tracking
* updated sanitization logic, tests
* adding runtime metrics (cpu + gpu memory, steps/s, etc.)
* tests for runtime metrics telemetry and assoc. callback
* small update / fix
* simplifying path redaction
* sleep on all ranks in distributed setting
* adding back in base_model redaction w/ whitelist
* fix
* doc update
* improved redaction, send system info during model config load telemetry, etc.
* adding runtime metrics / system info additional accelerator support, etc.
* adding runtime metrics / system info additional accelerator support, etc.
* remove duplicate info
* fixes
* fix issue with tests in ci
* distributed fix
* opt-in version of telemetry
* enable / disable logic update
* docs fix
* doc update
* minor fixes
* simplifying
* slight changes
* fix
* lint
* update posthog dep
* coderabbit comments
* fix: opt-in model
* fix: increase time since last
* fix: increase whitelist orgs
* fix: posthog init and shutdown
* fix: imports
* fix: also check grad norm
* fix: duplicate plugin_manager calls
* fix: bad merge
* chore: update docs
* fix: cache process per comment
* fix: error handling
* fix: tests
* Revert "fix: error handling"
This reverts commit 22d1ea5755.
* fix: test telemetry error_handled bool
* fix: revert test
* chore: final doc fixes
---------
Co-authored-by: Dan Saunders <danjsaund@gmail.com>
Co-authored-by: Dan Saunders <dan@axolotl.ai>
This commit is contained in:
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
shared pytest fixtures
|
||||
"""
|
||||
"""Shared pytest fixtures"""
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
@@ -582,3 +580,9 @@ def test_load_fixtures(
|
||||
download_llama2_model_fixture,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_telemetry(monkeypatch):
|
||||
monkeypatch.setenv("AXOLOTL_DO_NOT_TRACK", "1")
|
||||
yield
|
||||
|
||||
0
tests/telemetry/__init__.py
Normal file
0
tests/telemetry/__init__.py
Normal file
9
tests/telemetry/conftest.py
Normal file
9
tests/telemetry/conftest.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Shared pytest fixtures for telemetry tests."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def del_track_env(monkeypatch):
|
||||
monkeypatch.delenv("AXOLOTL_DO_NOT_TRACK", raising=False)
|
||||
yield
|
||||
373
tests/telemetry/test_callbacks.py
Normal file
373
tests/telemetry/test_callbacks.py
Normal file
@@ -0,0 +1,373 @@
|
||||
"""Tests for telemetry callback module."""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from transformers import TrainerControl, TrainerState, TrainingArguments
|
||||
|
||||
from axolotl.telemetry.callbacks import TIME_SINCE_LAST, TelemetryCallback
|
||||
|
||||
|
||||
def calc_expected_metrics(step, last_step, current_time, last_time, start_time=900.0):
|
||||
"""Calculate expected metrics values for tests"""
|
||||
time_diff = current_time - last_time
|
||||
step_diff = step - last_step
|
||||
return {
|
||||
"steps_per_second": (
|
||||
step_diff / time_diff if time_diff > 0 and step_diff > 0 else 0
|
||||
),
|
||||
"time_since_last_report": time_diff,
|
||||
"elapsed_time": current_time - start_time,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_time():
|
||||
"""Mock time.time() to have predictable values in tests"""
|
||||
with patch("axolotl.telemetry.callbacks.time") as mock_time:
|
||||
mock_time.time.return_value = 1000.0
|
||||
yield mock_time
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_telemetry_manager():
|
||||
"""Create a mock TelemetryManager"""
|
||||
with patch("axolotl.telemetry.callbacks.TelemetryManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
yield mock_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runtime_metrics_tracker():
|
||||
"""Create a mock RuntimeMetricsTracker"""
|
||||
with patch(
|
||||
"axolotl.telemetry.callbacks.RuntimeMetricsTracker"
|
||||
) as mock_tracker_class:
|
||||
mock_tracker = MagicMock()
|
||||
# Set up metrics property on the tracker
|
||||
mock_metrics = MagicMock()
|
||||
mock_metrics.to_dict.return_value = {
|
||||
"total_steps": 100,
|
||||
"peak_cpu_memory_bytes": 1024,
|
||||
}
|
||||
mock_tracker.metrics = mock_metrics
|
||||
|
||||
# Make the constructor return our mock
|
||||
mock_tracker_class.return_value = mock_tracker
|
||||
yield mock_tracker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def training_args():
|
||||
"""Create a minimal TrainingArguments instance"""
|
||||
return TrainingArguments(output_dir="./output")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trainer_state():
|
||||
"""Create a mock TrainerState"""
|
||||
state = MagicMock(spec=TrainerState)
|
||||
state.global_step = 10
|
||||
state.epoch = 0.5 # halfway through first epoch
|
||||
state.log_history = [{"loss": 2.5, "learning_rate": 5e-5}]
|
||||
return state
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trainer_control():
|
||||
"""Create a mock TrainerControl"""
|
||||
return MagicMock(spec=TrainerControl)
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@pytest.fixture
|
||||
def callback(mock_telemetry_manager, mock_runtime_metrics_tracker):
|
||||
"""Create a TelemetryCallback instance with mocked dependencies"""
|
||||
return TelemetryCallback()
|
||||
|
||||
|
||||
class TestTelemetryCallback:
|
||||
"""Tests for the TelemetryCallback class."""
|
||||
|
||||
def test_initialization(self, callback, mock_runtime_metrics_tracker):
|
||||
"""Test callback initialization."""
|
||||
assert callback.current_epoch == -1
|
||||
assert callback.tracker == mock_runtime_metrics_tracker
|
||||
assert callback.last_report_step == 0
|
||||
assert hasattr(callback, "start_time")
|
||||
assert hasattr(callback, "last_report_time")
|
||||
assert callback.report_interval_steps == 100
|
||||
|
||||
def test_on_train_begin(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_train_begin sends expected event."""
|
||||
callback.on_train_begin(training_args, trainer_state, trainer_control)
|
||||
|
||||
mock_telemetry_manager.send_event.assert_called_once_with(
|
||||
event_type="train-start"
|
||||
)
|
||||
|
||||
def test_on_train_end(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_train_end sends expected event with metrics."""
|
||||
callback.on_train_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
|
||||
assert call_args["event_type"] == "train-end"
|
||||
assert "loss" in call_args["properties"]
|
||||
assert call_args["properties"]["loss"] == 2.5
|
||||
assert "learning_rate" in call_args["properties"]
|
||||
assert call_args["properties"]["learning_rate"] == 5e-5
|
||||
|
||||
# Check that metrics from RuntimeMetricsTracker are included
|
||||
assert "total_steps" in call_args["properties"]
|
||||
assert call_args["properties"]["total_steps"] == 100
|
||||
assert "peak_cpu_memory_bytes" in call_args["properties"]
|
||||
assert call_args["properties"]["peak_cpu_memory_bytes"] == 1024
|
||||
|
||||
def test_on_epoch_begin(
|
||||
self,
|
||||
callback,
|
||||
mock_runtime_metrics_tracker,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_epoch_begin updates epoch counter and calls tracker."""
|
||||
initial_epoch = callback.current_epoch
|
||||
|
||||
callback.on_epoch_begin(training_args, trainer_state, trainer_control)
|
||||
|
||||
assert callback.current_epoch == initial_epoch + 1
|
||||
mock_runtime_metrics_tracker.start_epoch.assert_called_once_with(
|
||||
initial_epoch + 1
|
||||
)
|
||||
|
||||
def test_on_epoch_end(
|
||||
self,
|
||||
callback,
|
||||
mock_runtime_metrics_tracker,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_epoch_end calls tracker."""
|
||||
# Set current epoch
|
||||
callback.current_epoch = 2
|
||||
|
||||
callback.on_epoch_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
mock_runtime_metrics_tracker.end_epoch.assert_called_once_with(2)
|
||||
|
||||
def test_on_step_end_no_report(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end updates tracker but doesn't report if criteria not met."""
|
||||
# Set up state to avoid reporting
|
||||
trainer_state.global_step = 42 # Not divisible by report_interval_steps
|
||||
callback.last_report_step = 41 # Just 1 step since last report
|
||||
callback.last_report_time = time.time() # Just now
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should update tracker
|
||||
mock_runtime_metrics_tracker.update_step.assert_called_once_with(42)
|
||||
|
||||
# Should not send telemetry
|
||||
mock_telemetry_manager.send_event.assert_not_called()
|
||||
|
||||
# Should not update last report time/step
|
||||
assert callback.last_report_step == 41
|
||||
|
||||
def test_on_step_end_report_interval_steps(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker,
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end reports when step interval is reached."""
|
||||
# Set up state with clear values
|
||||
current_step = 100 # Exactly matches report_interval_steps
|
||||
last_step = 0
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
time_diff = current_time - start_time # 100 seconds
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
callback.report_interval_steps = 100
|
||||
callback.last_report_step = last_step
|
||||
callback.start_time = start_time
|
||||
callback.last_report_time = start_time
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should update tracker
|
||||
mock_runtime_metrics_tracker.update_step.assert_called_once_with(current_step)
|
||||
mock_runtime_metrics_tracker.update_memory_metrics.assert_called_once()
|
||||
|
||||
# Should send telemetry
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
assert call_args["event_type"] == "train-progress"
|
||||
|
||||
# Properties should include expected values
|
||||
props = call_args["properties"]
|
||||
assert props["step"] == current_step
|
||||
assert props["elapsed_time"] == time_diff # 1000 - 900 = 100
|
||||
assert props["time_since_last_report"] == time_diff # 1000 - 900 = 100
|
||||
assert props["steps_per_second"] == 1.0 # 100 steps / 100 seconds
|
||||
|
||||
# Should update last report time/step
|
||||
assert callback.last_report_step == current_step
|
||||
assert callback.last_report_time == current_time
|
||||
|
||||
def test_on_step_end_report_time_elapsed(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end reports when enough time has elapsed."""
|
||||
# Set up state with clear values
|
||||
current_step = 120
|
||||
last_step = 10
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
time_diff = TIME_SINCE_LAST + 1 # Just over the threshold
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
callback.report_interval_steps = 100
|
||||
callback.last_report_step = last_step
|
||||
callback.start_time = start_time
|
||||
callback.last_report_time = current_time - time_diff
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should send telemetry
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Properties should include expected values
|
||||
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
|
||||
expected_metrics = calc_expected_metrics(
|
||||
current_step, last_step, current_time, current_time - time_diff, start_time
|
||||
)
|
||||
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
|
||||
assert (
|
||||
props["time_since_last_report"]
|
||||
== expected_metrics["time_since_last_report"]
|
||||
)
|
||||
|
||||
def test_on_step_end_first_step(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end always reports on first step."""
|
||||
# Set up state with clear values
|
||||
current_step = 1 # First step
|
||||
last_step = 0
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
last_report_time = 999.0 # Just 1 second ago
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
callback.report_interval_steps = 100
|
||||
callback.last_report_step = last_step
|
||||
callback.start_time = start_time
|
||||
callback.last_report_time = last_report_time
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should send telemetry even though not much time has passed
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Properties should include expected values for first step
|
||||
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
|
||||
assert props["step"] == current_step
|
||||
expected_metrics = calc_expected_metrics(
|
||||
current_step, last_step, current_time, last_report_time, start_time
|
||||
)
|
||||
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
|
||||
|
||||
def test_log_history_empty(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test handling of empty log history."""
|
||||
# Set up state with clear values
|
||||
current_step = 1
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
trainer_state.log_history = []
|
||||
callback.start_time = start_time
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should still send telemetry
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Properties should have default values for missing log data
|
||||
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
|
||||
assert props["loss"] == 0
|
||||
assert props["learning_rate"] == 0
|
||||
341
tests/telemetry/test_errors.py
Normal file
341
tests/telemetry/test_errors.py
Normal file
@@ -0,0 +1,341 @@
|
||||
"""Tests for telemetry error utilities"""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.telemetry.errors import sanitize_stack_trace, send_errors
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_error_flag(monkeypatch):
|
||||
"""Reset ERROR_HANDLED flag using monkeypatch"""
|
||||
import axolotl.telemetry.errors
|
||||
|
||||
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
|
||||
yield
|
||||
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_stack_trace():
|
||||
"""Provide a sample stack trace with mixed paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
|
||||
trainer = get_trainer(cfg)
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/train.py", line 214, in get_trainer
|
||||
model = get_model(cfg, tokenizer)
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/models.py", line 120, in get_model
|
||||
raise ValueError("Model path not found")
|
||||
ValueError: Model path not found
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def windows_stack_trace():
|
||||
"""Provide a sample stack trace with Windows paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\cli\\train.py", line 83, in main
|
||||
trainer = get_trainer(cfg)
|
||||
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\train.py", line 214, in get_trainer
|
||||
model = get_model(cfg, tokenizer)
|
||||
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\auto\\modeling_auto.py", line 482, in from_pretrained
|
||||
raise ValueError(f"Unrecognized configuration class {config.__class__}")
|
||||
ValueError: Unrecognized configuration class <class 'transformers.models.llama.configuration_llama.LlamaConfig'>
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixed_stack_trace():
|
||||
"""Provide a sample stack trace with both axolotl and non-axolotl paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
|
||||
trainer = get_trainer(cfg)
|
||||
File "/home/user/.local/lib/python3.9/site-packages/transformers/trainer.py", line 520, in train
|
||||
self._inner_training_loop()
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/trainer.py", line 75, in _inner_training_loop
|
||||
super()._inner_training_loop()
|
||||
File "/home/user/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
|
||||
data = self._next_data()
|
||||
RuntimeError: CUDA out of memory
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def venv_stack_trace():
|
||||
"""Provide a sample stack trace with virtual environment paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 1729, in train
|
||||
self._inner_training_loop()
|
||||
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 2013, in _inner_training_loop
|
||||
self.accelerator.backward(loss)
|
||||
File "/home/user/venv/lib/python3.9/site-packages/accelerate/accelerator.py", line 1851, in backward
|
||||
self.scaler.scale(loss).backward(**kwargs)
|
||||
File "/home/user/venv/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
|
||||
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
|
||||
RuntimeError: CUDA out of memory
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dist_packages_stack_trace():
|
||||
"""Provide a sample stack trace with dist-packages paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__
|
||||
data = self._next_data()
|
||||
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 675, in _next_data
|
||||
data = self._dataset_fetcher.fetch(index)
|
||||
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
|
||||
data = [self.dataset[idx] for idx in possibly_batched_index]
|
||||
File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2808, in __getitem__
|
||||
raise IndexError(f"Index {key} out of range for dataset of length {len(self)}.")
|
||||
IndexError: Index 10000 out of range for dataset of length 9832.
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project_stack_trace():
|
||||
"""Provide a sample stack trace from a project directory (not a virtual env)"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/projects/myproject/run.py", line 25, in <module>
|
||||
main()
|
||||
File "/home/user/projects/myproject/src/cli.py", line 45, in main
|
||||
app.run()
|
||||
File "/home/user/projects/myproject/src/app.py", line 102, in run
|
||||
raise ValueError("Configuration missing")
|
||||
ValueError: Configuration missing
|
||||
"""
|
||||
|
||||
|
||||
def test_sanitize_stack_trace(example_stack_trace):
|
||||
"""Test that sanitize_stack_trace properly preserves axolotl paths"""
|
||||
sanitized = sanitize_stack_trace(example_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "/home/user" not in sanitized
|
||||
assert ".local/lib/python3.9" not in sanitized
|
||||
|
||||
# Check that site-packages is preserved
|
||||
assert "site-packages/axolotl/cli/train.py" in sanitized
|
||||
assert "site-packages/axolotl/train.py" in sanitized
|
||||
assert "site-packages/axolotl/utils/models.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "ValueError: Model path not found" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_windows_paths(windows_stack_trace):
|
||||
"""Test that sanitize_stack_trace handles Windows paths"""
|
||||
sanitized = sanitize_stack_trace(windows_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "C:\\Users\\name" not in sanitized
|
||||
assert "AppData\\Local\\Programs\\Python" not in sanitized
|
||||
|
||||
# Check that both axolotl and transformers packages are preserved
|
||||
assert (
|
||||
"site-packages\\axolotl\\cli\\train.py" in sanitized
|
||||
or "site-packages/axolotl/cli/train.py" in sanitized
|
||||
)
|
||||
assert (
|
||||
"site-packages\\axolotl\\train.py" in sanitized
|
||||
or "site-packages/axolotl/train.py" in sanitized
|
||||
)
|
||||
assert (
|
||||
"site-packages\\transformers\\models\\auto\\modeling_auto.py" in sanitized
|
||||
or "site-packages/transformers/models/auto/modeling_auto.py" in sanitized
|
||||
)
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "ValueError: Unrecognized configuration class" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_mixed_paths(mixed_stack_trace):
|
||||
"""Test that sanitize_stack_trace preserves all package paths"""
|
||||
sanitized = sanitize_stack_trace(mixed_stack_trace)
|
||||
|
||||
# Check that all package paths are preserved
|
||||
assert "site-packages/axolotl/cli/train.py" in sanitized
|
||||
assert "site-packages/transformers/trainer.py" in sanitized
|
||||
assert "site-packages/axolotl/utils/trainer.py" in sanitized
|
||||
assert "site-packages/torch/utils/data/dataloader.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "RuntimeError: CUDA out of memory" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_venv_paths(venv_stack_trace):
|
||||
"""Test that sanitize_stack_trace preserves virtual environment package paths"""
|
||||
sanitized = sanitize_stack_trace(venv_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "/home/user/venv" not in sanitized
|
||||
|
||||
# Check that all package paths are preserved
|
||||
assert "site-packages/transformers/trainer.py" in sanitized
|
||||
assert "site-packages/accelerate/accelerator.py" in sanitized
|
||||
assert "site-packages/torch/_tensor.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "RuntimeError: CUDA out of memory" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_dist_packages(dist_packages_stack_trace):
|
||||
"""Test that sanitize_stack_trace preserves dist-packages paths"""
|
||||
sanitized = sanitize_stack_trace(dist_packages_stack_trace)
|
||||
|
||||
# Check that system paths are removed
|
||||
assert "/usr/local/lib/python3.8" not in sanitized
|
||||
|
||||
# Check that all package paths are preserved
|
||||
assert "dist-packages/torch/utils/data/dataloader.py" in sanitized
|
||||
assert "dist-packages/torch/utils/data/_utils/fetch.py" in sanitized
|
||||
assert "dist-packages/datasets/arrow_dataset.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert (
|
||||
"IndexError: Index 10000 out of range for dataset of length 9832." in sanitized
|
||||
)
|
||||
|
||||
|
||||
def test_sanitize_project_paths(project_stack_trace):
|
||||
"""Test handling of project paths (non-virtual env)"""
|
||||
sanitized = sanitize_stack_trace(project_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "/home/user/projects" not in sanitized
|
||||
|
||||
# For non-package paths, we should at least preserve the filename
|
||||
assert "run.py" in sanitized
|
||||
assert "cli.py" in sanitized
|
||||
assert "app.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "ValueError: Configuration missing" in sanitized
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_telemetry_manager():
|
||||
"""Create a mock TelemetryManager"""
|
||||
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.enabled = True
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
yield mock_manager
|
||||
|
||||
|
||||
def test_send_errors_successful_execution(mock_telemetry_manager):
|
||||
"""Test that send_errors doesn't send telemetry for successful function execution"""
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
return "success"
|
||||
|
||||
result = test_func()
|
||||
assert result == "success"
|
||||
mock_telemetry_manager.send_event.assert_not_called()
|
||||
|
||||
|
||||
def test_send_errors_with_exception(mock_telemetry_manager):
|
||||
"""Test that send_errors sends telemetry when an exception occurs"""
|
||||
test_error = ValueError("Test error")
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise test_error
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
test_func()
|
||||
|
||||
assert excinfo.value == test_error
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Check that the error info was passed correctly
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
assert "test_func-error" in call_args["event_type"]
|
||||
assert "Test error" in call_args["properties"]["exception"]
|
||||
assert "stack_trace" in call_args["properties"]
|
||||
|
||||
|
||||
def test_send_errors_nested_calls(mock_telemetry_manager):
|
||||
"""Test that send_errors only sends telemetry once for nested decorated functions"""
|
||||
|
||||
@send_errors
|
||||
def inner_func():
|
||||
raise ValueError("Inner error")
|
||||
|
||||
@send_errors
|
||||
def outer_func():
|
||||
return inner_func()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
outer_func()
|
||||
|
||||
# Telemetry should be sent only once for the inner function
|
||||
assert mock_telemetry_manager.send_event.call_count == 1
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
assert "inner_func-error" in call_args["event_type"]
|
||||
|
||||
|
||||
def test_send_errors_telemetry_disable():
|
||||
"""Test that send_errors doesn't attempt to send telemetry when disabled"""
|
||||
|
||||
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.enabled = False
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise ValueError("Test error")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
|
||||
mock_manager.send_event.assert_not_called()
|
||||
|
||||
|
||||
def test_error_handled_reset():
|
||||
"""Test that ERROR_HANDLED flag is properly reset"""
|
||||
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
|
||||
# Create and configure the mock manager
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.enabled = True
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
|
||||
from axolotl.telemetry.errors import ERROR_HANDLED
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise ValueError("Test error")
|
||||
|
||||
assert not ERROR_HANDLED
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
|
||||
from axolotl.telemetry.errors import ERROR_HANDLED
|
||||
|
||||
assert ERROR_HANDLED
|
||||
|
||||
|
||||
def test_module_path_resolution(mock_telemetry_manager):
|
||||
"""Test that the module path is correctly resolved for the event type"""
|
||||
import inspect
|
||||
|
||||
current_module = inspect.getmodule(test_module_path_resolution).__name__
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise ValueError("Test error")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
|
||||
assert mock_telemetry_manager.send_event.called
|
||||
event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"]
|
||||
|
||||
expected_event_type = f"{current_module}.test_func-error"
|
||||
assert expected_event_type == event_type
|
||||
275
tests/telemetry/test_manager.py
Normal file
275
tests/telemetry/test_manager.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""Tests for TelemetryManager class and utilities"""
|
||||
|
||||
# pylint: disable=redefined-outer-name,protected-access
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_whitelist(tmp_path):
|
||||
"""Create a temporary whitelist file for testing"""
|
||||
whitelist_content = {
|
||||
"organizations": ["meta-llama", "mistralai"],
|
||||
}
|
||||
whitelist_file = tmp_path / "whitelist.yaml"
|
||||
with open(whitelist_file, "w", encoding="utf-8") as f:
|
||||
yaml.dump(whitelist_content, f)
|
||||
|
||||
return str(whitelist_file)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def telemetry_manager_class():
|
||||
"""Reset the TelemetryManager singleton between tests"""
|
||||
original_instance = TelemetryManager._instance
|
||||
original_initialized = TelemetryManager._initialized
|
||||
TelemetryManager._instance = None
|
||||
TelemetryManager._initialized = False
|
||||
yield TelemetryManager
|
||||
TelemetryManager._instance = original_instance
|
||||
TelemetryManager._initialized = original_initialized
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager(telemetry_manager_class, mock_whitelist):
|
||||
"""Create a TelemetryManager instance with mocked dependencies"""
|
||||
with (
|
||||
patch("posthog.capture"),
|
||||
patch("posthog.flush"),
|
||||
patch("time.sleep"),
|
||||
patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist),
|
||||
patch.dict(os.environ, {"RANK": "0"}),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
# Manually enable for most tests
|
||||
manager.enabled = True
|
||||
return manager
|
||||
|
||||
|
||||
def test_singleton_instance(telemetry_manager_class):
|
||||
"""Test that TelemetryManager is a singleton"""
|
||||
with (
|
||||
patch("posthog.capture"),
|
||||
patch("time.sleep"),
|
||||
patch.dict(os.environ, {"RANK": "0"}),
|
||||
):
|
||||
first = telemetry_manager_class()
|
||||
second = telemetry_manager_class()
|
||||
assert first is second
|
||||
assert telemetry_manager_class.get_instance() is first
|
||||
|
||||
|
||||
def test_telemetry_enabled_by_default(telemetry_manager_class):
|
||||
"""Test that telemetry is enabled by default (opt-out)"""
|
||||
with (
|
||||
patch.dict(os.environ, {"RANK": "0"}, clear=True),
|
||||
patch("time.sleep"),
|
||||
patch("logging.Logger.info"),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
assert manager.enabled
|
||||
|
||||
|
||||
def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class):
|
||||
"""Test that telemetry is enabled when AXOLOTL_DO_NOT_TRACK=0"""
|
||||
with (
|
||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}),
|
||||
patch("time.sleep"),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
assert manager.enabled
|
||||
|
||||
|
||||
def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
|
||||
"""Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
|
||||
with (
|
||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}),
|
||||
patch("time.sleep"),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
assert not manager.enabled
|
||||
|
||||
|
||||
def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
|
||||
"""Test that telemetry is disabled when DO_NOT_TRACK=1"""
|
||||
with (
|
||||
patch.dict(
|
||||
os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "DO_NOT_TRACK": "1", "RANK": "0"}
|
||||
),
|
||||
patch("time.sleep"),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
assert not manager.enabled
|
||||
|
||||
|
||||
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
|
||||
"""Test that telemetry is disabled for non-main processes"""
|
||||
with (
|
||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "1"}),
|
||||
patch("time.sleep"),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
assert not manager.enabled
|
||||
|
||||
|
||||
def test_opt_in_info_displayed(telemetry_manager_class):
|
||||
"""Test that opt-in info is displayed when telemetry is not configured"""
|
||||
with (
|
||||
patch.dict(os.environ, {"RANK": "0"}, clear=True),
|
||||
patch("logging.Logger.warning") as mock_warning,
|
||||
patch("time.sleep"),
|
||||
):
|
||||
telemetry_manager_class()
|
||||
assert any(
|
||||
"Telemetry is now enabled by default" in str(call)
|
||||
for call in mock_warning.call_args_list
|
||||
)
|
||||
|
||||
|
||||
def test_is_whitelisted(telemetry_manager_class, mock_whitelist):
|
||||
"""Test org whitelist functionality"""
|
||||
with (
|
||||
patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist),
|
||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
|
||||
# Should match organizations from the mock whitelist
|
||||
assert manager._is_whitelisted("meta-llama/llama-7b")
|
||||
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
|
||||
# Should not match
|
||||
assert not manager._is_whitelisted("unknown/model")
|
||||
# Should handle case insensitively
|
||||
assert manager._is_whitelisted("META-LLAMA/Llama-7B")
|
||||
# Should handle empty input
|
||||
assert not manager._is_whitelisted("")
|
||||
|
||||
|
||||
def test_system_info_collection(manager):
|
||||
"""Test system information collection"""
|
||||
system_info = manager._get_system_info()
|
||||
|
||||
# Check essential keys
|
||||
assert "os" in system_info
|
||||
assert "python_version" in system_info
|
||||
assert "cpu_count" in system_info
|
||||
assert "memory_total" in system_info
|
||||
assert "accelerator_count" in system_info
|
||||
|
||||
|
||||
def test_send_event(telemetry_manager_class):
|
||||
"""Test basic event sending"""
|
||||
with (
|
||||
patch("posthog.capture") as mock_capture,
|
||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
|
||||
# Test with clean properties (no PII)
|
||||
manager.send_event("test_event", {"key": "value"})
|
||||
assert mock_capture.called
|
||||
assert mock_capture.call_args[1]["event"] == "test_event"
|
||||
assert mock_capture.call_args[1]["properties"] == {"key": "value"}
|
||||
assert mock_capture.call_args[1]["distinct_id"] == manager.run_id
|
||||
|
||||
# Test with default properties (None)
|
||||
mock_capture.reset_mock()
|
||||
manager.send_event("simple_event")
|
||||
assert mock_capture.called
|
||||
assert mock_capture.call_args[1]["properties"] == {}
|
||||
|
||||
|
||||
def test_send_system_info(telemetry_manager_class):
|
||||
"""Test sending system info"""
|
||||
with (
|
||||
patch("posthog.capture") as mock_capture,
|
||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
manager.send_system_info()
|
||||
assert mock_capture.called
|
||||
assert mock_capture.call_args[1]["event"] == "system-info"
|
||||
assert mock_capture.call_args[1]["properties"] == manager.system_info
|
||||
|
||||
|
||||
def test_redacted_properties(telemetry_manager_class):
|
||||
"""Test path redaction in send_event method"""
|
||||
with (
|
||||
patch("posthog.capture") as mock_capture,
|
||||
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
# Test with properties containing various paths and non-paths
|
||||
test_properties = {
|
||||
"filepath": "/home/user/sensitive/data.txt",
|
||||
"windows_path": "C:\\Users\\name\\Documents\\project\\file.py",
|
||||
"output_dir": "/var/lib/data",
|
||||
"path_to_model": "models/llama/7b",
|
||||
"message": "Training started", # Should not be redacted
|
||||
"metrics": {"loss": 0.5, "accuracy": 0.95}, # Should not be redacted
|
||||
"base_model": "models/local_model",
|
||||
"nested": {
|
||||
"model_path": "/models/my_model",
|
||||
"root_dir": "/home/user/projects",
|
||||
"stats": {"steps": 1000, "epochs": 3}, # Should not be redacted
|
||||
},
|
||||
}
|
||||
|
||||
manager.send_event("test_event", test_properties)
|
||||
|
||||
# Verify the call was made
|
||||
assert mock_capture.called
|
||||
|
||||
# Get the sanitized properties that were sent
|
||||
sanitized = mock_capture.call_args[1]["properties"]
|
||||
|
||||
# Check that path-like and base_model keys were redacted
|
||||
assert sanitized["filepath"] == "[REDACTED]"
|
||||
assert sanitized["windows_path"] == "[REDACTED]"
|
||||
assert sanitized["path_to_model"] == "[REDACTED]"
|
||||
assert sanitized["base_model"] == "[REDACTED]"
|
||||
|
||||
# Check that non-path values were preserved
|
||||
assert sanitized["message"] == "Training started"
|
||||
assert sanitized["metrics"] == {"loss": 0.5, "accuracy": 0.95}
|
||||
|
||||
# Check nested structure handling
|
||||
assert sanitized["nested"]["model_path"] == "[REDACTED]"
|
||||
assert sanitized["nested"]["root_dir"] == "[REDACTED]"
|
||||
assert sanitized["nested"]["stats"] == {"steps": 1000, "epochs": 3}
|
||||
|
||||
|
||||
def test_disable_telemetry(manager):
|
||||
"""Test that disabled telemetry doesn't send events"""
|
||||
with patch("posthog.capture") as mock_capture:
|
||||
manager.enabled = False
|
||||
manager.send_event("test_event")
|
||||
assert not mock_capture.called
|
||||
|
||||
|
||||
def test_exception_handling_during_send(manager):
|
||||
"""Test that exceptions in PostHog are handled gracefully"""
|
||||
with (
|
||||
patch("posthog.capture", side_effect=Exception("Test error")),
|
||||
patch("logging.Logger.warning") as mock_warning,
|
||||
):
|
||||
manager.send_event("test_event")
|
||||
warning_logged = False
|
||||
for call in mock_warning.call_args_list:
|
||||
if "Failed to send telemetry event" in str(call):
|
||||
warning_logged = True
|
||||
break
|
||||
assert warning_logged
|
||||
|
||||
|
||||
def test_shutdown(manager):
|
||||
"""Test shutdown behavior"""
|
||||
with patch("posthog.shutdown") as mock_shutdown:
|
||||
manager.shutdown()
|
||||
assert mock_shutdown.called
|
||||
357
tests/telemetry/test_runtime_metrics.py
Normal file
357
tests/telemetry/test_runtime_metrics.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""Tests for runtime metrics telemetry module"""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.telemetry.runtime_metrics import RuntimeMetrics, RuntimeMetricsTracker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_time():
|
||||
"""Mock time.time() to have predictable values in tests"""
|
||||
with patch("time.time") as mock_time:
|
||||
# Start with time 1000.0 and increment by 10 seconds on each call
|
||||
times = [1000.0 + i * 10 for i in range(10)]
|
||||
mock_time.side_effect = times
|
||||
yield mock_time
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_telemetry_manager():
|
||||
"""Create a mock TelemetryManager"""
|
||||
with patch(
|
||||
"axolotl.telemetry.runtime_metrics.TelemetryManager"
|
||||
) as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.enabled = True
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
yield mock_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_psutil():
|
||||
"""Mock psutil for memory information"""
|
||||
with patch("axolotl.telemetry.runtime_metrics.psutil") as mock_psutil:
|
||||
mock_process = MagicMock()
|
||||
mock_memory_info = MagicMock()
|
||||
# Set initial memory to 1GB
|
||||
mock_memory_info.rss = 1024 * 1024 * 1024
|
||||
mock_process.memory_info.return_value = mock_memory_info
|
||||
mock_psutil.Process.return_value = mock_process
|
||||
yield mock_psutil
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_torch():
|
||||
"""Mock torch.cuda functions"""
|
||||
with patch("axolotl.telemetry.runtime_metrics.torch") as mock_torch:
|
||||
mock_torch.cuda.is_available.return_value = True
|
||||
mock_torch.cuda.device_count.return_value = 2
|
||||
|
||||
# Mock memory allocated per device (1GB for device 0, 2GB for device 1)
|
||||
mock_torch.cuda.memory_allocated.side_effect = (
|
||||
lambda device: (device + 1) * 1024 * 1024 * 1024
|
||||
)
|
||||
|
||||
yield mock_torch
|
||||
|
||||
|
||||
class TestRuntimeMetrics:
|
||||
"""Tests for RuntimeMetrics class."""
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test RuntimeMetrics initialization."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
|
||||
assert metrics.start_time == 1000.0
|
||||
assert metrics.epoch_start_times == {}
|
||||
assert metrics.epoch_end_times == {}
|
||||
assert metrics.peak_gpu_memory == {}
|
||||
assert metrics.total_steps == 0
|
||||
assert metrics.current_epoch == 0
|
||||
assert metrics.current_step == 0
|
||||
assert metrics.peak_cpu_memory == 0
|
||||
|
||||
def test_elapsed_time(self, mock_time):
|
||||
"""Test elapsed_time property."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
|
||||
# Mock time.time() to return 1050.0
|
||||
mock_time.side_effect = [1050.0]
|
||||
|
||||
assert metrics.elapsed_time == 50.0
|
||||
|
||||
def test_epoch_time(self):
|
||||
"""Test epoch_time method."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
|
||||
# No epoch data
|
||||
assert metrics.epoch_time(0) is None
|
||||
|
||||
# Add epoch start but no end
|
||||
metrics.epoch_start_times[0] = 1000.0
|
||||
assert metrics.epoch_time(0) is None
|
||||
|
||||
# Add epoch end
|
||||
metrics.epoch_end_times[0] = 1060.0
|
||||
assert metrics.epoch_time(0) == 60.0
|
||||
|
||||
def test_average_epoch_time(self):
|
||||
"""Test average_epoch_time method."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
|
||||
# No completed epochs
|
||||
assert metrics.average_epoch_time() is None
|
||||
|
||||
# Add one completed epoch
|
||||
metrics.epoch_start_times[0] = 1000.0
|
||||
metrics.epoch_end_times[0] = 1060.0
|
||||
assert metrics.average_epoch_time() == 60.0
|
||||
|
||||
# Add second completed epoch
|
||||
metrics.epoch_start_times[1] = 1060.0
|
||||
metrics.epoch_end_times[1] = 1140.0 # 80 seconds
|
||||
assert metrics.average_epoch_time() == 70.0 # Average of 60 and 80
|
||||
|
||||
# Add incomplete epoch (should not affect average)
|
||||
metrics.epoch_start_times[2] = 1140.0
|
||||
assert metrics.average_epoch_time() == 70.0
|
||||
|
||||
def test_steps_per_second(self, mock_time):
|
||||
"""Test steps_per_second method."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
|
||||
# No steps - first call to time.time()
|
||||
mock_time.side_effect = None
|
||||
mock_time.return_value = 1050.0
|
||||
assert metrics.steps_per_second() is None
|
||||
|
||||
# Add steps - second call to time.time()
|
||||
metrics.total_steps = 100
|
||||
mock_time.return_value = 1050.0 # Keep same time for consistent result
|
||||
assert metrics.steps_per_second() == 2.0 # 100 steps / 50 seconds
|
||||
|
||||
def test_to_dict_basic(self, mock_time):
|
||||
"""Test to_dict method with basic metrics."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
metrics.total_steps = 100
|
||||
metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024 # 2GB
|
||||
|
||||
# Mock elapsed_time
|
||||
mock_time.side_effect = None
|
||||
mock_time.return_value = 1050.0
|
||||
|
||||
result = metrics.to_dict()
|
||||
|
||||
assert result["total_time_seconds"] == 50.0
|
||||
assert result["total_steps"] == 100
|
||||
assert result["steps_per_second"] == 2.0
|
||||
assert result["epochs_completed"] == 0
|
||||
assert result["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
||||
assert "epoch_times" not in result
|
||||
assert "gpu_memory" not in result
|
||||
|
||||
def test_to_dict_with_epochs(self, mock_time):
|
||||
"""Test to_dict method with epoch data."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
metrics.total_steps = 100
|
||||
|
||||
# Add epoch data
|
||||
metrics.epoch_start_times[0] = 1000.0
|
||||
metrics.epoch_end_times[0] = 1060.0
|
||||
metrics.epoch_start_times[1] = 1060.0
|
||||
metrics.epoch_end_times[1] = 1140.0
|
||||
|
||||
# Mock elapsed_time
|
||||
mock_time.side_effect = None
|
||||
mock_time.return_value = 1150.0
|
||||
|
||||
result = metrics.to_dict()
|
||||
|
||||
assert "epoch_times" in result
|
||||
assert result["epoch_times"]["epoch_0_seconds"] == 60.0
|
||||
assert result["epoch_times"]["epoch_1_seconds"] == 80.0
|
||||
assert result["average_epoch_time_seconds"] == 70.0
|
||||
|
||||
def test_to_dict_with_gpu_memory(self, mock_time):
|
||||
"""Test to_dict method with GPU memory data."""
|
||||
metrics = RuntimeMetrics(start_time=1000.0)
|
||||
metrics.peak_gpu_memory = {
|
||||
0: 1 * 1024 * 1024 * 1024, # 1GB
|
||||
1: 2 * 1024 * 1024 * 1024, # 2GB
|
||||
}
|
||||
|
||||
# Mock elapsed_time
|
||||
mock_time.side_effect = [1050.0]
|
||||
|
||||
result = metrics.to_dict()
|
||||
|
||||
assert "gpu_memory" in result
|
||||
assert result["gpu_memory"]["gpu_0_peak_memory_bytes"] == 1 * 1024 * 1024 * 1024
|
||||
assert result["gpu_memory"]["gpu_1_peak_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
||||
|
||||
|
||||
class TestRuntimeMetricsTracker:
|
||||
"""Tests for RuntimeMetricsTracker class."""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def test_initialization(self, mock_time, mock_telemetry_manager):
|
||||
"""Test RuntimeMetricsTracker initialization."""
|
||||
tracker = RuntimeMetricsTracker()
|
||||
|
||||
assert isinstance(tracker.metrics, RuntimeMetrics)
|
||||
assert tracker.metrics.start_time == 1000.0 # First value from mock_time
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def test_start_epoch(
|
||||
self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
|
||||
):
|
||||
"""Test start_epoch method."""
|
||||
tracker = RuntimeMetricsTracker()
|
||||
|
||||
# Reset mock_time to control next value
|
||||
mock_time.side_effect = [1010.0]
|
||||
|
||||
tracker.start_epoch(0)
|
||||
|
||||
assert tracker.metrics.current_epoch == 0
|
||||
assert tracker.metrics.epoch_start_times[0] == 1010.0
|
||||
|
||||
# Verify memory metrics were updated
|
||||
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
|
||||
assert 0 in tracker.metrics.peak_gpu_memory
|
||||
assert 1 in tracker.metrics.peak_gpu_memory
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def test_end_epoch(self, mock_time, mock_telemetry_manager):
|
||||
"""Test end_epoch method."""
|
||||
tracker = RuntimeMetricsTracker()
|
||||
|
||||
# Start epoch 0
|
||||
mock_time.side_effect = [1010.0]
|
||||
tracker.start_epoch(0)
|
||||
|
||||
# End epoch 0
|
||||
mock_time.side_effect = [1060.0]
|
||||
tracker.end_epoch(0)
|
||||
|
||||
assert 0 in tracker.metrics.epoch_end_times
|
||||
assert tracker.metrics.epoch_end_times[0] == 1060.0
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def test_update_step(
|
||||
self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
|
||||
):
|
||||
"""Test update_step method."""
|
||||
tracker = RuntimeMetricsTracker()
|
||||
|
||||
# Update step to a non-multiple of 100
|
||||
tracker.update_step(42)
|
||||
|
||||
assert tracker.metrics.current_step == 42
|
||||
assert tracker.metrics.total_steps == 1
|
||||
|
||||
# Memory metrics should not be updated for non-multiple of 100
|
||||
assert tracker.metrics.peak_cpu_memory == 0
|
||||
|
||||
# Update step to a multiple of 100
|
||||
tracker.update_step(100)
|
||||
|
||||
assert tracker.metrics.current_step == 100
|
||||
assert tracker.metrics.total_steps == 2
|
||||
|
||||
# Memory metrics should be updated for multiple of 100
|
||||
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def test_update_memory_metrics(
|
||||
self, mock_psutil, mock_torch, mock_telemetry_manager
|
||||
):
|
||||
"""Test update_memory_metrics method."""
|
||||
tracker = RuntimeMetricsTracker()
|
||||
|
||||
# Initial memory state
|
||||
assert tracker.metrics.peak_cpu_memory == 0
|
||||
assert tracker.metrics.peak_gpu_memory == {}
|
||||
|
||||
# Update memory metrics
|
||||
tracker.update_memory_metrics()
|
||||
|
||||
# Verify CPU memory
|
||||
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
|
||||
|
||||
# Verify GPU memory
|
||||
assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
|
||||
assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024
|
||||
|
||||
# Change mocked memory values to be lower
|
||||
mock_process = mock_psutil.Process.return_value
|
||||
mock_memory_info = mock_process.memory_info.return_value
|
||||
mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024 # 0.5GB
|
||||
|
||||
mock_torch.cuda.memory_allocated.side_effect = (
|
||||
lambda device: (device + 0.5) * 1024 * 1024 * 1024
|
||||
)
|
||||
|
||||
# Update memory metrics again
|
||||
tracker.update_memory_metrics()
|
||||
|
||||
# Peak values should not decrease
|
||||
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
|
||||
assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
|
||||
assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024
|
||||
|
||||
# Change mocked memory values to be higher
|
||||
mock_memory_info.rss = 2 * 1024 * 1024 * 1024 # 2GB
|
||||
|
||||
mock_torch.cuda.memory_allocated.side_effect = (
|
||||
lambda device: (device + 2) * 1024 * 1024 * 1024
|
||||
)
|
||||
|
||||
# Update memory metrics again
|
||||
tracker.update_memory_metrics()
|
||||
|
||||
# Peak values should increase
|
||||
assert tracker.metrics.peak_cpu_memory == 2 * 1024 * 1024 * 1024
|
||||
assert tracker.metrics.peak_gpu_memory[0] == 2 * 1024 * 1024 * 1024
|
||||
assert tracker.metrics.peak_gpu_memory[1] == 3 * 1024 * 1024 * 1024
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def test_get_memory_metrics(self, mock_psutil, mock_torch, mock_telemetry_manager):
|
||||
"""Test get_memory_metrics method."""
|
||||
tracker = RuntimeMetricsTracker()
|
||||
|
||||
# Set peak memory values
|
||||
tracker.metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024
|
||||
tracker.metrics.peak_gpu_memory = {
|
||||
0: 3 * 1024 * 1024 * 1024,
|
||||
1: 4 * 1024 * 1024 * 1024,
|
||||
}
|
||||
|
||||
# Get memory metrics
|
||||
memory_metrics = tracker.get_memory_metrics()
|
||||
|
||||
# Verify CPU memory
|
||||
assert (
|
||||
memory_metrics["cpu_memory_bytes"] == 1 * 1024 * 1024 * 1024
|
||||
) # Current value from mock
|
||||
assert (
|
||||
memory_metrics["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
||||
) # Peak value we set
|
||||
|
||||
# Verify GPU memory
|
||||
assert (
|
||||
memory_metrics["gpu_0_memory_bytes"] == 1 * 1024 * 1024 * 1024
|
||||
) # Current value from mock
|
||||
assert (
|
||||
memory_metrics["gpu_0_peak_memory_bytes"] == 3 * 1024 * 1024 * 1024
|
||||
) # Peak value we set
|
||||
assert (
|
||||
memory_metrics["gpu_1_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
||||
) # Current value from mock
|
||||
assert (
|
||||
memory_metrics["gpu_1_peak_memory_bytes"] == 4 * 1024 * 1024 * 1024
|
||||
) # Peak value we set
|
||||
Reference in New Issue
Block a user