updated sanitization logic, tests

This commit is contained in:
Dan Saunders
2025-02-24 20:05:55 +00:00
parent 1edd6b9524
commit d8b0522ea0
4 changed files with 666 additions and 97 deletions

View File

@@ -0,0 +1,340 @@
"""Tests for telemetry error utilities."""
# pylint: disable=redefined-outer-name
from unittest.mock import MagicMock, patch
import pytest
from axolotl.telemetry.errors import sanitize_stack_trace, send_errors
@pytest.fixture(autouse=True)
def reset_error_flag(monkeypatch):
    """Force ERROR_HANDLED to False before and after every test"""
    import axolotl.telemetry.errors as errors_module

    def clear_flag():
        # monkeypatch records the prior value and restores it at teardown
        monkeypatch.setattr(errors_module, "ERROR_HANDLED", False)

    clear_flag()
    yield
    clear_flag()
@pytest.fixture
def example_stack_trace():
    """Return a traceback whose frames all sit under site-packages/axolotl in a user's home"""
    trace_lines = [
        "Traceback (most recent call last):",
        'File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main',
        "trainer = get_trainer(cfg)",
        'File "/home/user/.local/lib/python3.9/site-packages/axolotl/train.py", line 214, in get_trainer',
        "model = get_model(cfg, tokenizer)",
        'File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/models.py", line 120, in get_model',
        'raise ValueError("Model path not found")',
        "ValueError: Model path not found",
    ]
    # The trace ends with a newline after the exception line
    return "\n".join(trace_lines) + "\n"
@pytest.fixture
def windows_stack_trace():
    """Return a traceback whose frames use backslash-separated Windows paths"""
    trace_lines = [
        "Traceback (most recent call last):",
        'File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\cli\\train.py", line 83, in main',
        "trainer = get_trainer(cfg)",
        'File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\train.py", line 214, in get_trainer',
        "model = get_model(cfg, tokenizer)",
        'File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\auto\\modeling_auto.py", line 482, in from_pretrained',
        'raise ValueError(f"Unrecognized configuration class {config.__class__}")',
        "ValueError: Unrecognized configuration class <class 'transformers.models.llama.configuration_llama.LlamaConfig'>",
    ]
    # The trace ends with a newline after the exception line
    return "\n".join(trace_lines) + "\n"
@pytest.fixture
def mixed_stack_trace():
    """Return a traceback that alternates between axolotl frames and third-party frames"""
    trace_lines = [
        "Traceback (most recent call last):",
        'File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main',
        "trainer = get_trainer(cfg)",
        'File "/home/user/.local/lib/python3.9/site-packages/transformers/trainer.py", line 520, in train',
        "self._inner_training_loop()",
        'File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/trainer.py", line 75, in _inner_training_loop',
        "super()._inner_training_loop()",
        'File "/home/user/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__',
        "data = self._next_data()",
        "RuntimeError: CUDA out of memory",
    ]
    # The trace ends with a newline after the exception line
    return "\n".join(trace_lines) + "\n"
@pytest.fixture
def venv_stack_trace():
    """Return a traceback whose frames come from packages installed in a virtual environment"""
    trace_lines = [
        "Traceback (most recent call last):",
        'File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 1729, in train',
        "self._inner_training_loop()",
        'File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 2013, in _inner_training_loop',
        "self.accelerator.backward(loss)",
        'File "/home/user/venv/lib/python3.9/site-packages/accelerate/accelerator.py", line 1851, in backward',
        "self.scaler.scale(loss).backward(**kwargs)",
        'File "/home/user/venv/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward',
        "torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)",
        "RuntimeError: CUDA out of memory",
    ]
    # The trace ends with a newline after the exception line
    return "\n".join(trace_lines) + "\n"
@pytest.fixture
def dist_packages_stack_trace():
    """Return a traceback whose frames come from a system-wide dist-packages directory"""
    trace_lines = [
        "Traceback (most recent call last):",
        'File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__',
        "data = self._next_data()",
        'File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 675, in _next_data',
        "data = self._dataset_fetcher.fetch(index)",
        'File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch',
        "data = [self.dataset[idx] for idx in possibly_batched_index]",
        'File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2808, in __getitem__',
        'raise IndexError(f"Index {key} out of range for dataset of length {len(self)}.")',
        "IndexError: Index 10000 out of range for dataset of length 9832.",
    ]
    # The trace ends with a newline after the exception line
    return "\n".join(trace_lines) + "\n"
@pytest.fixture
def project_stack_trace():
    """Return a traceback whose frames live in a plain project checkout, outside any package directory"""
    trace_lines = [
        "Traceback (most recent call last):",
        'File "/home/user/projects/myproject/run.py", line 25, in <module>',
        "main()",
        'File "/home/user/projects/myproject/src/cli.py", line 45, in main',
        "app.run()",
        'File "/home/user/projects/myproject/src/app.py", line 102, in run',
        'raise ValueError("Configuration missing")',
        "ValueError: Configuration missing",
    ]
    # The trace ends with a newline after the exception line
    return "\n".join(trace_lines) + "\n"
def test_sanitize_stack_trace(example_stack_trace):
    """Check that home-directory prefixes are stripped while axolotl package paths survive"""
    sanitized = sanitize_stack_trace(example_stack_trace)

    # Nothing identifying the user's home or interpreter location may remain
    for private_fragment in ("/home/user", ".local/lib/python3.9"):
        assert private_fragment not in sanitized

    # Paths are kept from the site-packages component onwards
    for kept_path in (
        "site-packages/axolotl/cli/train.py",
        "site-packages/axolotl/train.py",
        "site-packages/axolotl/utils/models.py",
    ):
        assert kept_path in sanitized

    # The exception line itself is untouched
    assert "ValueError: Model path not found" in sanitized
def test_sanitize_windows_paths(windows_stack_trace):
    """Check that Windows-style paths are sanitized like POSIX ones"""
    sanitized = sanitize_stack_trace(windows_stack_trace)

    # The user profile and interpreter install location must not leak
    for private_fragment in ("C:\\Users\\name", "AppData\\Local\\Programs\\Python"):
        assert private_fragment not in sanitized

    # Package-relative paths must survive; either separator style is accepted
    kept_paths = (
        "site-packages\\axolotl\\cli\\train.py",
        "site-packages\\axolotl\\train.py",
        "site-packages\\transformers\\models\\auto\\modeling_auto.py",
    )
    for windows_form in kept_paths:
        posix_form = windows_form.replace("\\", "/")
        assert windows_form in sanitized or posix_form in sanitized

    # The exception line itself is untouched
    assert "ValueError: Unrecognized configuration class" in sanitized
def test_sanitize_mixed_paths(mixed_stack_trace):
    """Check that axolotl and third-party package paths are both kept"""
    sanitized = sanitize_stack_trace(mixed_stack_trace)

    # Every frame's package-relative path should still be present
    for kept_path in (
        "site-packages/axolotl/cli/train.py",
        "site-packages/transformers/trainer.py",
        "site-packages/axolotl/utils/trainer.py",
        "site-packages/torch/utils/data/dataloader.py",
    ):
        assert kept_path in sanitized

    # The exception line itself is untouched
    assert "RuntimeError: CUDA out of memory" in sanitized
def test_sanitize_venv_paths(venv_stack_trace):
    """Check that a virtual environment's location is stripped but its package paths are kept"""
    sanitized = sanitize_stack_trace(venv_stack_trace)

    # The venv root lives under the user's home and must not leak
    assert "/home/user/venv" not in sanitized

    # Every frame's package-relative path should still be present
    for kept_path in (
        "site-packages/transformers/trainer.py",
        "site-packages/accelerate/accelerator.py",
        "site-packages/torch/_tensor.py",
    ):
        assert kept_path in sanitized

    # The exception line itself is untouched
    assert "RuntimeError: CUDA out of memory" in sanitized
def test_sanitize_dist_packages(dist_packages_stack_trace):
    """Check that dist-packages paths are handled like site-packages paths"""
    sanitized = sanitize_stack_trace(dist_packages_stack_trace)

    # The system interpreter prefix must be dropped
    assert "/usr/local/lib/python3.8" not in sanitized

    # Paths are kept from the dist-packages component onwards
    for kept_path in (
        "dist-packages/torch/utils/data/dataloader.py",
        "dist-packages/torch/utils/data/_utils/fetch.py",
        "dist-packages/datasets/arrow_dataset.py",
    ):
        assert kept_path in sanitized

    # The exception line itself is untouched
    expected_error = "IndexError: Index 10000 out of range for dataset of length 9832."
    assert expected_error in sanitized
def test_sanitize_project_paths(project_stack_trace):
    """Check sanitization of frames that are not inside any package directory"""
    sanitized = sanitize_stack_trace(project_stack_trace)

    # The project checkout location must not leak
    assert "/home/user/projects" not in sanitized

    # Without a package directory to anchor on, at least the filename must remain
    for filename in ("run.py", "cli.py", "app.py"):
        assert filename in sanitized

    # The exception line itself is untouched
    assert "ValueError: Configuration missing" in sanitized
@pytest.fixture
def mock_telemetry_manager():
    """Patch TelemetryManager in the errors module and yield the enabled mock it hands out"""
    fake_manager = MagicMock()
    fake_manager.enabled = True

    with patch("axolotl.telemetry.errors.TelemetryManager") as manager_cls:
        # get_instance() on the patched class returns our mock
        manager_cls.get_instance.return_value = fake_manager
        yield fake_manager
def test_send_errors_successful_execution(mock_telemetry_manager):
    """Check that a decorated function that returns normally triggers no telemetry"""

    @send_errors
    def test_func():
        return "success"

    # The return value passes through the decorator unchanged
    assert test_func() == "success"
    mock_telemetry_manager.send_event.assert_not_called()
def test_send_errors_with_exception(mock_telemetry_manager):
    """Check that a raising decorated function reports the error and re-raises it"""
    raised = ValueError("Test error")

    @send_errors
    def test_func():
        raise raised

    with pytest.raises(ValueError) as excinfo:
        test_func()

    # The caller must see the exception that was raised
    assert excinfo.value == raised

    send_event = mock_telemetry_manager.send_event
    send_event.assert_called_once()

    # Inspect the keyword arguments of the single send_event call
    sent_kwargs = send_event.call_args[1]
    sent_properties = sent_kwargs["properties"]
    assert "test_func-error" in sent_kwargs["event_type"]
    assert "Test error" in sent_properties["exception"]
    assert "stack_trace" in sent_properties
def test_send_errors_nested_calls(mock_telemetry_manager):
    """Check that an error passing through two decorated functions is reported only once"""

    @send_errors
    def inner_func():
        raise ValueError("Inner error")

    @send_errors
    def outer_func():
        return inner_func()

    with pytest.raises(ValueError):
        outer_func()

    send_event = mock_telemetry_manager.send_event

    # Only the innermost decorated function should have reported
    assert send_event.call_count == 1
    assert "inner_func-error" in send_event.call_args[1]["event_type"]
def test_send_errors_telemetry_disable():
    """Check that nothing is sent when the telemetry manager is disabled"""
    disabled_manager = MagicMock()
    disabled_manager.enabled = False

    with patch("axolotl.telemetry.errors.TelemetryManager") as manager_cls:
        manager_cls.get_instance.return_value = disabled_manager

        @send_errors
        def test_func():
            raise ValueError("Test error")

        # The exception still propagates to the caller
        with pytest.raises(ValueError):
            test_func()

        disabled_manager.send_event.assert_not_called()
def test_error_handled_reset():
    """Test that ERROR_HANDLED is False before an error and True once one has been reported"""
    # Read the flag through the module object. ``from ... import ERROR_HANDLED``
    # copies the value at import time, so it would not reflect the update made
    # by send_errors without importing the name a second time.
    import axolotl.telemetry.errors as errors_module

    with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
        # Create and configure the mock manager
        mock_manager = MagicMock()
        mock_manager.enabled = True
        mock_manager_class.get_instance.return_value = mock_manager

        @send_errors
        def test_func():
            raise ValueError("Test error")

        assert not errors_module.ERROR_HANDLED

        with pytest.raises(ValueError):
            test_func()

        # Handling the error must have flipped the module-level flag
        assert errors_module.ERROR_HANDLED
def test_module_path_resolution(mock_telemetry_manager):
    """Check that the event type is ``<module name>.<function name>-error``"""
    import inspect

    @send_errors
    def test_func():
        raise ValueError("Test error")

    with pytest.raises(ValueError):
        test_func()

    send_event = mock_telemetry_manager.send_event
    assert send_event.called

    # test_func is defined in this module, so this module's name is the expected prefix
    this_module = inspect.getmodule(test_module_path_resolution).__name__
    reported_type = send_event.call_args[1]["event_type"]
    assert f"{this_module}.test_func-error" == reported_type

View File

@@ -1,5 +1,5 @@
"""Tests for TelemetryManager class and utilities"""
# pylint: disable=redefined-outer-name
# pylint: disable=redefined-outer-name,protected-access
import os
from unittest.mock import patch
@@ -7,7 +7,7 @@ from unittest.mock import patch
import pytest
import yaml
from axolotl.telemetry import TelemetryManager
from axolotl.telemetry.manager import TelemetryConfig, TelemetryManager
@pytest.fixture
@@ -19,116 +19,236 @@ def mock_whitelist(tmp_path):
whitelist_file = tmp_path / "whitelist.yaml"
with open(whitelist_file, "w", encoding="utf-8") as f:
yaml.dump(whitelist_content, f)
return str(whitelist_file)
@pytest.fixture
def manager():
"""Create a TelemetryManager instance with mocked PostHog"""
with patch("posthog.capture"):
return TelemetryManager()
def telemetry_manager_class():
    """Yield TelemetryManager with its singleton state cleared, then restore the saved state"""
    saved_state = (TelemetryManager._instance, TelemetryManager._initialized)

    # Clear the singleton so the test constructs a fresh instance
    TelemetryManager._instance, TelemetryManager._initialized = None, False

    yield TelemetryManager

    TelemetryManager._instance, TelemetryManager._initialized = saved_state
def test_telemetry_disabled_by_default():
"""Test that telemetry is disabled by default"""
manager = TelemetryManager()
assert not manager.enabled
@pytest.fixture
def manager(telemetry_manager_class, mock_whitelist):
    """Build a TelemetryManager with PostHog, sleeping, the whitelist path and the main-process check patched"""
    whitelist_patch = patch.object(TelemetryConfig, "whitelist_path", mock_whitelist)
    main_process_patch = patch(
        "axolotl.telemetry.manager.is_main_process", return_value=True
    )

    with patch("posthog.capture"), patch("posthog.flush"), patch("time.sleep"):
        with whitelist_patch, main_process_patch:
            instance = telemetry_manager_class()
            # Most tests want telemetry on regardless of the environment
            instance.enabled = True
            return instance
def test_telemetry_opt_in():
"""Test that telemetry can be enabled via environment variable"""
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
manager = TelemetryManager()
assert manager.enabled
def test_singleton_instance(telemetry_manager_class):
    """Check that repeated construction and get_instance() return the same object"""
    with patch("posthog.capture"), patch("time.sleep"):
        instances = [telemetry_manager_class() for _ in range(2)]

        assert instances[0] is instances[1]
        assert telemetry_manager_class.get_instance() is instances[0]
def test_do_not_track_override():
"""Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY"""
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}):
manager = TelemetryManager()
def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
"""Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1"}), patch(
"axolotl.telemetry.manager.is_main_process", return_value=True
):
manager = telemetry_manager_class()
assert not manager.enabled
# pylint: disable=protected-access
def test_whitelist_checking(manager):
"""Test model whitelist functionality"""
# Should match organization
assert manager._is_whitelisted("meta-llama/llama-7b")
# Should match model name
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
# Should not match
assert not manager._is_whitelisted("unknown/model")
# Should handle case insensitively
assert manager._is_whitelisted("meta/Llama-7b")
def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
    """Check that DO_NOT_TRACK=1 turns telemetry off"""
    env_patch = patch.dict(os.environ, {"DO_NOT_TRACK": "1"})
    main_process_patch = patch(
        "axolotl.telemetry.manager.is_main_process", return_value=True
    )

    with env_patch, main_process_patch:
        instance = telemetry_manager_class()
        assert not instance.enabled
def test_event_tracking(manager):
"""Test basic event tracking"""
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
    """Check that telemetry is off when is_main_process() reports False"""
    main_process_patch = patch(
        "axolotl.telemetry.manager.is_main_process", return_value=False
    )

    with main_process_patch:
        instance = telemetry_manager_class()
        assert not instance.enabled
def test_telemetry_enabled_by_default(telemetry_manager_class):
    """Check that an empty environment leaves telemetry on without marking it explicit"""
    # clear=True empties os.environ for the duration of the block
    empty_env = patch.dict(os.environ, {}, clear=True)
    main_process_patch = patch(
        "axolotl.telemetry.manager.is_main_process", return_value=True
    )

    with empty_env, main_process_patch, patch("time.sleep"), patch(
        "logging.Logger.warning"
    ):
        instance = telemetry_manager_class()
        assert instance.enabled
        assert not instance.explicit_enable
def test_explicit_enable_disables_warning(telemetry_manager_class):
    """Check that AXOLOTL_DO_NOT_TRACK=0 counts as explicit opt-in and logs no telemetry warning"""
    env_patch = patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"})
    warning_patch = patch("logging.Logger.warning")
    main_process_patch = patch(
        "axolotl.telemetry.manager.is_main_process", return_value=True
    )

    with env_patch, warning_patch as mock_warning, main_process_patch, patch(
        "time.sleep"
    ):
        instance = telemetry_manager_class()
        assert instance.enabled
        assert instance.explicit_enable

        # None of the logged warnings may mention that telemetry is enabled
        assert not any(
            "Telemetry is enabled" in str(logged)
            for logged in mock_warning.call_args_list
        )
def test_warning_displayed_for_implicit_enable(telemetry_manager_class):
    """Check that enabling telemetry by default logs a warning saying so"""
    # clear=True empties os.environ for the duration of the block
    empty_env = patch.dict(os.environ, {}, clear=True)
    warning_patch = patch("logging.Logger.warning")
    main_process_patch = patch(
        "axolotl.telemetry.manager.is_main_process", return_value=True
    )

    with empty_env, warning_patch as mock_warning, main_process_patch, patch(
        "time.sleep"
    ):
        instance = telemetry_manager_class()
        assert instance.enabled
        assert not instance.explicit_enable

        # At least one logged warning must mention that telemetry is enabled
        assert any(
            "Telemetry is enabled" in str(logged)
            for logged in mock_warning.call_args_list
        )
def test_is_whitelisted(manager, mock_whitelist):
    """Check organization whitelist matching, including case and empty input"""
    # (model id, whether it should be accepted), checked in this order
    cases = [
        ("meta-llama/llama-7b", True),
        ("mistralai/mistral-7b-instruct", True),
        ("unknown/model", False),
        # Matching is case-insensitive
        ("META-LLAMA/Llama-7B", True),
        # Empty input is never whitelisted
        ("", False),
        (None, False),
    ]

    with patch.object(TelemetryConfig, "whitelist_path", mock_whitelist):
        for model_id, expected in cases:
            assert bool(manager._is_whitelisted(model_id)) is expected
def test_system_info_collection(manager):
    """Check that the collected system info contains every essential key"""
    collected = manager.system_info

    essential_keys = (
        "os",
        "python_version",
        "pytorch_version",
        "transformers_version",
        "axolotl_version",
        "cpu_count",
        "memory_total",
        "gpu_count",
    )
    for key in essential_keys:
        assert key in collected
def test_send_event(manager):
"""Test basic event sending"""
with patch("posthog.capture") as mock_capture:
manager.enabled = True
manager.track_event("test_event", {"key": "value"})
# Test with clean properties (no PII)
manager.send_event("test_event", {"key": "value"})
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "test_event"
assert mock_capture.call_args[1]["properties"]["key"] == "value"
assert "run_id" in mock_capture.call_args[1]["properties"]
assert "system_info" in mock_capture.call_args[1]["properties"]
assert mock_capture.call_args[1]["properties"] == {"key": "value"}
assert mock_capture.call_args[1]["distinct_id"] == manager.run_id
# Test with default properties (None)
mock_capture.reset_mock()
manager.send_event("simple_event")
assert mock_capture.called
assert mock_capture.call_args[1]["properties"] == {}
def test_training_context(manager):
"""Test training context manager"""
config = {"model": "llama", "batch_size": 8}
def test_send_system_info(manager):
"""Test sending system info"""
with patch("posthog.capture") as mock_capture:
manager.enabled = True
with manager.track_training(config):
pass # Simulate successful training
# Should have captured training_start and training_complete
events = [call[1]["event"] for call in mock_capture.call_args_list]
assert "training_start" in events
assert "training_complete" in events
manager.send_system_info()
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "system-info"
assert mock_capture.call_args[1]["properties"] == manager.system_info
def test_training_error(manager):
"""Test training context manager with error"""
config = {"model": "llama", "batch_size": 8}
def test_sanitize_properties(manager):
"""Test property sanitization in send_event method"""
with patch("posthog.capture") as mock_capture:
manager.enabled = True
# Test with properties containing various PII
test_properties = {
"filepath": "/home/user/sensitive/data.txt",
"windows_path": "C:\\Users\\name\\Documents\\project\\file.py",
"url": "https://example.com/private/user123",
"message": "Error loading /tmp/axolotl/data.csv - check permissions",
"cloud_path": "s3://my-bucket/data/user-files/",
"nested": {
"deep_path": "/var/log/axolotl/training.log",
"list_paths": ["/home/user1/file1.txt", "/home/user2/file2.txt"],
},
}
with pytest.raises(ValueError):
with manager.track_training(config):
raise ValueError("Test error")
manager.send_event("test_event", test_properties)
# Should have captured training_start and training_error
events = [call[1]["event"] for call in mock_capture.call_args_list]
assert "training_start" in events
assert "training_error" in events
# Verify the call was made
assert mock_capture.called
# Get the sanitized properties that were sent
sanitized = mock_capture.call_args[1]["properties"]
# Check that PII was removed/sanitized
assert "/home/user/sensitive" not in str(sanitized)
assert "C:\\Users\\name" not in str(sanitized)
assert "https://example.com/private" not in str(sanitized)
assert "s3://my-bucket" not in str(sanitized)
# Check that filenames were preserved
assert "data.txt" in str(sanitized)
assert "file.py" in str(sanitized)
assert "data.csv" in str(sanitized)
# Check that nested structures were properly sanitized
assert "/var/log/axolotl" not in str(sanitized["nested"])
assert "/home/user1" not in str(sanitized["nested"]["list_paths"])
# pylint: disable=protected-access
def test_path_sanitization(manager):
"""Test path sanitization"""
path = "/home/user/sensitive/data.txt"
sanitized = manager._sanitize_path(path)
assert sanitized == "data.txt"
assert "/home/user" not in sanitized
def test_disable_telemetry(manager):
    """Check that send_event does not reach PostHog once the manager is disabled"""
    manager.enabled = False

    with patch("posthog.capture") as mock_capture:
        manager.send_event("test_event")

    assert not mock_capture.called
# pylint: disable=protected-access
def test_error_sanitization(manager):
"""Test error message sanitization"""
error = "Failed to load /home/user/sensitive/data.txt: File not found"
sanitized = manager._sanitize_error(error)
assert "sensitive" not in sanitized
assert "/home/user" not in sanitized
def test_exception_handling_during_send(manager):
    """Check that a PostHog failure is logged as a warning and not raised"""
    failing_capture = patch("posthog.capture", side_effect=Exception("Test error"))

    with failing_capture, patch("logging.Logger.warning") as mock_warning:
        # Must not raise even though posthog.capture does
        manager.send_event("test_event")

        assert any(
            "Failed to send telemetry event" in str(logged)
            for logged in mock_warning.call_args_list
        )
def test_shutdown(manager):
    """Check that shutting down an enabled manager flushes PostHog"""
    manager.enabled = True

    with patch("posthog.flush") as mock_flush:
        manager.shutdown()

    assert mock_flush.called