From d8b0522ea0f7edbb942f03cf4a9d34c6f9137b48 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 24 Feb 2025 20:05:55 +0000 Subject: [PATCH] updated sanitization logic, tests --- src/axolotl/telemetry/errors.py | 72 +++++-- src/axolotl/telemetry/manager.py | 75 ++++++- tests/telemetry/test_errors.py | 340 +++++++++++++++++++++++++++++++ tests/telemetry/test_manager.py | 276 ++++++++++++++++++------- 4 files changed, 666 insertions(+), 97 deletions(-) create mode 100644 tests/telemetry/test_errors.py diff --git a/src/axolotl/telemetry/errors.py b/src/axolotl/telemetry/errors.py index 4357ded9b..98acd6a2c 100644 --- a/src/axolotl/telemetry/errors.py +++ b/src/axolotl/telemetry/errors.py @@ -1,6 +1,7 @@ """Telemetry utilities for exception and traceback information.""" import logging +import os import re import traceback from functools import wraps @@ -16,13 +17,16 @@ ERROR_HANDLED = False def sanitize_stack_trace(stack_trace: str) -> str: """ - Remove personal information from stack trace messages while keeping Axolotl codepaths. + Remove personal information from stack trace messages while keeping Python package codepaths. + + This function identifies Python packages by looking for common patterns in virtual environment + and site-packages directories, preserving the package path while removing user-specific paths. Args: stack_trace: The original stack trace string. Returns: - A sanitized version of the stack trace with only axolotl paths preserved. + A sanitized version of the stack trace with Python package paths preserved. 
""" # Split the stack trace into lines to process each file path separately lines = stack_trace.split("\n") @@ -31,23 +35,66 @@ def sanitize_stack_trace(stack_trace: str) -> str: # Regular expression to find file paths in the stack trace path_pattern = re.compile(r'(?:File ")(.*?)(?:")') + # Regular expression to identify paths in site-packages or dist-packages + # This matches path segments like "site-packages/package_name" or "dist-packages/package_name" + site_packages_pattern = re.compile( + r"(?:site-packages|dist-packages)[/\\]([\w\-\.]+)" + ) + + # Additional common virtual environment patterns + venv_lib_pattern = re.compile( + r"(?:lib|Lib)[/\\](?:python\d+(?:\.\d+)?[/\\])?(?:site-packages|dist-packages)[/\\]([\w\-\.]+)" + ) + for line in lines: # Check if this line contains a file path path_match = path_pattern.search(line) if path_match: full_path = path_match.group(1) + sanitized_path = "" - if "axolotl/" in full_path: - # Keep only the 'axolotl' part and onward - axolotl_idx = full_path.rfind("axolotl/") - if axolotl_idx >= 0: - # Replace the original path with the sanitized one - sanitized_path = full_path[axolotl_idx:] - line = line.replace(full_path, sanitized_path) + # Try to match site-packages pattern + site_packages_match = site_packages_pattern.search(full_path) + venv_lib_match = venv_lib_pattern.search(full_path) + + if site_packages_match: + # Find the index where the matched pattern starts + idx = full_path.find("site-packages") + if idx == -1: + idx = full_path.find("dist-packages") + + # Keep from 'site-packages' onward + if idx >= 0: + sanitized_path = full_path[idx:] + elif venv_lib_match: + # For other virtual environment patterns, find the package directory + match_idx = venv_lib_match.start(1) + if match_idx > 0: + # Keep from the package name onward + package_name = venv_lib_match.group(1) + idx = full_path.rfind( + package_name, 0, match_idx + len(package_name) + ) + if idx >= 0: + sanitized_path = full_path[idx:] + + # If we 
couldn't identify a package pattern but path contains 'axolotl' + elif "axolotl" in full_path: + idx = full_path.rfind("axolotl") + if idx >= 0: + sanitized_path = full_path[idx:] + + # Apply the sanitization to the line + if sanitized_path: + line = line.replace(full_path, sanitized_path) else: - # For non-axolotl paths, replace with an empty string or a placeholder - line = line.replace(full_path, "") + # If we couldn't identify a package pattern, just keep the filename + filename = os.path.basename(full_path) + if filename: + line = line.replace(full_path, filename) + else: + line = line.replace(full_path, "") sanitized_lines.append(line) @@ -72,6 +119,7 @@ def send_errors(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs) -> Any: telemetry_manager = TelemetryManager.get_instance() + if not telemetry_manager.enabled: return func(*args, **kwargs) @@ -79,7 +127,7 @@ def send_errors(func: Callable) -> Callable: return func(*args, **kwargs) except Exception as exception: # Only track if we're not already handling an error. This prevents us from - # capturing an error more than once in nested decorated function calls. 
+ # capturing an error more than once in nested decorated function calls. global ERROR_HANDLED # pylint: disable=global-statement if not ERROR_HANDLED: ERROR_HANDLED = True diff --git a/src/axolotl/telemetry/manager.py b/src/axolotl/telemetry/manager.py index 56f150c2e..0b5e2933e 100644 --- a/src/axolotl/telemetry/manager.py +++ b/src/axolotl/telemetry/manager.py @@ -4,6 +4,7 @@ import atexit import logging import os import platform +import re import time import uuid from dataclasses import dataclass @@ -122,6 +123,9 @@ class TelemetryManager: axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK") do_not_track = os.getenv("DO_NOT_TRACK") + # If explicitly enabled, we'll disable the telemetry warning message + explicit_enabled = axolotl_do_not_track in ["0", "false"] + if axolotl_do_not_track is None: axolotl_do_not_track = "0" @@ -134,9 +138,6 @@ class TelemetryManager: "true", ) and do_not_track.lower() not in ("1", "true") - # If explicitly enabled, we'll disable the telemetry warning message - explicit_enabled = axolotl_do_not_track in ["0", "false"] - return enabled, explicit_enabled def _load_whitelist(self) -> dict: @@ -145,7 +146,7 @@ class TelemetryManager: return yaml.safe_load(f) def _is_whitelisted(self, base_model: str) -> bool: - """Check if model/org is in whitelist""" + """Check if model org is in whitelist""" if not base_model: return False @@ -159,9 +160,66 @@ class TelemetryManager: posthog.project_api_key = POSTHOG_WRITE_KEY posthog.host = self.config.host - def _sanitize_path(self, path: str) -> str: - """Remove personal information from file paths""" - return Path(path).name + def _sanitize_properties(self, properties: dict[str, Any]) -> dict[str, Any]: + """ + Sanitize properties to remove any personally identifiable information such as: + - File paths + - URLs / Links + - Cloud storage locations + + Args: + properties: Dictionary of properties to sanitize. + + Returns: + Sanitized properties dictionary. 
+ """ + if not properties: + return {} + + # Define regex patterns for different types of personal information + patterns = { + # File paths (Unix and Windows) + "file_path": re.compile(r"(?:/|\\)(?:[^/\\]+(?:/|\\))+[^/\\]+"), + # URLs/Links + "url": re.compile(r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+(?:/[^/\s]*)*"), + # Cloud storage paths (S3, GCS, Azure) + "cloud_path": re.compile(r"s3://|gs://|azure://|blob.core.windows.net"), + } + + # Deep copy isn't needed; we'll create a new dict with sanitized values + sanitized = {} + + def sanitize_value(value): + """Recursively sanitize values within nested structures""" + if isinstance(value, str): + # For file paths, extract just the filename + path_match = patterns["file_path"].search(value) + if path_match: + try: + # Try to extract just the filename + path_str = path_match.group(0) + value = value.replace(path_str, Path(path_str).name) + except (ValueError, RuntimeError): + # If path extraction fails, just redact the path + value = patterns["file_path"].sub("[REDACTED_PATH]", value) + + # Redact other sensitive information + value = patterns["url"].sub("[REDACTED_URL]", value) + value = patterns["cloud_path"].sub("[REDACTED_CLOUD]", value) + + return value + if isinstance(value, dict): + return {k: sanitize_value(v) for k, v in value.items()} + if isinstance(value, list): + return [sanitize_value(item) for item in value] + + return value + + # Apply the sanitization to all properties + for key, value in properties.items(): + sanitized[key] = sanitize_value(value) + + return sanitized def _get_system_info(self) -> dict[str, Any]: """Collect system information""" @@ -195,6 +253,9 @@ class TelemetryManager: if properties is None: properties = {} + # Sanitize properties to remove PII + properties = self._sanitize_properties(properties) + # Wrap PostHog errors in try / except to not raise errors during Axolotl usage try: LOG.warning(f"*** Sending telemetry for {event_type} ***") diff --git 
a/tests/telemetry/test_errors.py b/tests/telemetry/test_errors.py new file mode 100644 index 000000000..a021fc96f --- /dev/null +++ b/tests/telemetry/test_errors.py @@ -0,0 +1,340 @@ +"""Tests for telemetry error utilities.""" +# pylint: disable=redefined-outer-name + +from unittest.mock import MagicMock, patch + +import pytest + +from axolotl.telemetry.errors import sanitize_stack_trace, send_errors + + +@pytest.fixture(autouse=True) +def reset_error_flag(monkeypatch): + """Reset ERROR_HANDLED flag using monkeypatch""" + import axolotl.telemetry.errors + + monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False) + yield + monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False) + + +@pytest.fixture +def example_stack_trace(): + """Provide a sample stack trace with mixed paths""" + return """Traceback (most recent call last): + File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main + trainer = get_trainer(cfg) + File "/home/user/.local/lib/python3.9/site-packages/axolotl/train.py", line 214, in get_trainer + model = get_model(cfg, tokenizer) + File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/models.py", line 120, in get_model + raise ValueError("Model path not found") +ValueError: Model path not found +""" + + +@pytest.fixture +def windows_stack_trace(): + """Provide a sample stack trace with Windows paths""" + return """Traceback (most recent call last): + File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\cli\\train.py", line 83, in main + trainer = get_trainer(cfg) + File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\train.py", line 214, in get_trainer + model = get_model(cfg, tokenizer) + File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\auto\\modeling_auto.py", line 482, in from_pretrained + raise ValueError(f"Unrecognized 
configuration class {config.__class__}") +ValueError: Unrecognized configuration class +""" + + +@pytest.fixture +def mixed_stack_trace(): + """Provide a sample stack trace with both axolotl and non-axolotl paths""" + return """Traceback (most recent call last): + File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main + trainer = get_trainer(cfg) + File "/home/user/.local/lib/python3.9/site-packages/transformers/trainer.py", line 520, in train + self._inner_training_loop() + File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/trainer.py", line 75, in _inner_training_loop + super()._inner_training_loop() + File "/home/user/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__ + data = self._next_data() +RuntimeError: CUDA out of memory +""" + + +@pytest.fixture +def venv_stack_trace(): + """Provide a sample stack trace with virtual environment paths""" + return """Traceback (most recent call last): + File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 1729, in train + self._inner_training_loop() + File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 2013, in _inner_training_loop + self.accelerator.backward(loss) + File "/home/user/venv/lib/python3.9/site-packages/accelerate/accelerator.py", line 1851, in backward + self.scaler.scale(loss).backward(**kwargs) + File "/home/user/venv/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward + torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs) +RuntimeError: CUDA out of memory +""" + + +@pytest.fixture +def dist_packages_stack_trace(): + """Provide a sample stack trace with dist-packages paths""" + return """Traceback (most recent call last): + File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__ + data = self._next_data() + File 
"/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 675, in _next_data + data = self._dataset_fetcher.fetch(index) + File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch + data = [self.dataset[idx] for idx in possibly_batched_index] + File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2808, in __getitem__ + raise IndexError(f"Index {key} out of range for dataset of length {len(self)}.") +IndexError: Index 10000 out of range for dataset of length 9832. +""" + + +@pytest.fixture +def project_stack_trace(): + """Provide a sample stack trace from a project directory (not a virtual env)""" + return """Traceback (most recent call last): + File "/home/user/projects/myproject/run.py", line 25, in + main() + File "/home/user/projects/myproject/src/cli.py", line 45, in main + app.run() + File "/home/user/projects/myproject/src/app.py", line 102, in run + raise ValueError("Configuration missing") +ValueError: Configuration missing +""" + + +def test_sanitize_stack_trace(example_stack_trace): + """Test that sanitize_stack_trace properly preserves axolotl paths""" + sanitized = sanitize_stack_trace(example_stack_trace) + + # Check that personal paths are removed + assert "/home/user" not in sanitized + assert ".local/lib/python3.9" not in sanitized + + # Check that site-packages is preserved + assert "site-packages/axolotl/cli/train.py" in sanitized + assert "site-packages/axolotl/train.py" in sanitized + assert "site-packages/axolotl/utils/models.py" in sanitized + + # Check that error message is preserved + assert "ValueError: Model path not found" in sanitized + + +def test_sanitize_windows_paths(windows_stack_trace): + """Test that sanitize_stack_trace handles Windows paths""" + sanitized = sanitize_stack_trace(windows_stack_trace) + + # Check that personal paths are removed + assert "C:\\Users\\name" not in sanitized + assert "AppData\\Local\\Programs\\Python" not in sanitized 
+ + # Check that both axolotl and transformers packages are preserved + assert ( + "site-packages\\axolotl\\cli\\train.py" in sanitized + or "site-packages/axolotl/cli/train.py" in sanitized + ) + assert ( + "site-packages\\axolotl\\train.py" in sanitized + or "site-packages/axolotl/train.py" in sanitized + ) + assert ( + "site-packages\\transformers\\models\\auto\\modeling_auto.py" in sanitized + or "site-packages/transformers/models/auto/modeling_auto.py" in sanitized + ) + + # Check that error message is preserved + assert "ValueError: Unrecognized configuration class" in sanitized + + +def test_sanitize_mixed_paths(mixed_stack_trace): + """Test that sanitize_stack_trace preserves all package paths""" + sanitized = sanitize_stack_trace(mixed_stack_trace) + + # Check that all package paths are preserved + assert "site-packages/axolotl/cli/train.py" in sanitized + assert "site-packages/transformers/trainer.py" in sanitized + assert "site-packages/axolotl/utils/trainer.py" in sanitized + assert "site-packages/torch/utils/data/dataloader.py" in sanitized + + # Check that error message is preserved + assert "RuntimeError: CUDA out of memory" in sanitized + + +def test_sanitize_venv_paths(venv_stack_trace): + """Test that sanitize_stack_trace preserves virtual environment package paths""" + sanitized = sanitize_stack_trace(venv_stack_trace) + + # Check that personal paths are removed + assert "/home/user/venv" not in sanitized + + # Check that all package paths are preserved + assert "site-packages/transformers/trainer.py" in sanitized + assert "site-packages/accelerate/accelerator.py" in sanitized + assert "site-packages/torch/_tensor.py" in sanitized + + # Check that error message is preserved + assert "RuntimeError: CUDA out of memory" in sanitized + + +def test_sanitize_dist_packages(dist_packages_stack_trace): + """Test that sanitize_stack_trace preserves dist-packages paths""" + sanitized = sanitize_stack_trace(dist_packages_stack_trace) + + # Check that system 
paths are removed + assert "/usr/local/lib/python3.8" not in sanitized + + # Check that all package paths are preserved + assert "dist-packages/torch/utils/data/dataloader.py" in sanitized + assert "dist-packages/torch/utils/data/_utils/fetch.py" in sanitized + assert "dist-packages/datasets/arrow_dataset.py" in sanitized + + # Check that error message is preserved + assert ( + "IndexError: Index 10000 out of range for dataset of length 9832." in sanitized + ) + + +def test_sanitize_project_paths(project_stack_trace): + """Test handling of project paths (non-virtual env)""" + sanitized = sanitize_stack_trace(project_stack_trace) + + # Check that personal paths are removed + assert "/home/user/projects" not in sanitized + + # For non-package paths, we should at least preserve the filename + assert "run.py" in sanitized + assert "cli.py" in sanitized + assert "app.py" in sanitized + + # Check that error message is preserved + assert "ValueError: Configuration missing" in sanitized + + +@pytest.fixture +def mock_telemetry_manager(): + """Create a mock TelemetryManager""" + with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class: + mock_manager = MagicMock() + mock_manager.enabled = True + mock_manager_class.get_instance.return_value = mock_manager + yield mock_manager + + +def test_send_errors_successful_execution(mock_telemetry_manager): + """Test that send_errors doesn't send telemetry for successful function execution""" + + @send_errors + def test_func(): + return "success" + + result = test_func() + assert result == "success" + mock_telemetry_manager.send_event.assert_not_called() + + +def test_send_errors_with_exception(mock_telemetry_manager): + """Test that send_errors sends telemetry when an exception occurs""" + test_error = ValueError("Test error") + + @send_errors + def test_func(): + raise test_error + + with pytest.raises(ValueError) as excinfo: + test_func() + + assert excinfo.value == test_error + 
mock_telemetry_manager.send_event.assert_called_once() + + # Check that the error info was passed correctly + call_args = mock_telemetry_manager.send_event.call_args[1] + assert "test_func-error" in call_args["event_type"] + assert "Test error" in call_args["properties"]["exception"] + assert "stack_trace" in call_args["properties"] + + +def test_send_errors_nested_calls(mock_telemetry_manager): + """Test that send_errors only sends telemetry once for nested decorated functions""" + + @send_errors + def inner_func(): + raise ValueError("Inner error") + + @send_errors + def outer_func(): + return inner_func() + + with pytest.raises(ValueError): + outer_func() + + # Telemetry should be sent only once for the inner function + assert mock_telemetry_manager.send_event.call_count == 1 + call_args = mock_telemetry_manager.send_event.call_args[1] + assert "inner_func-error" in call_args["event_type"] + + +def test_send_errors_telemetry_disable(): + """Test that send_errors doesn't attempt to send telemetry when disabled""" + + with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class: + mock_manager = MagicMock() + mock_manager.enabled = False + mock_manager_class.get_instance.return_value = mock_manager + + @send_errors + def test_func(): + raise ValueError("Test error") + + with pytest.raises(ValueError): + test_func() + + mock_manager.send_event.assert_not_called() + + +def test_error_handled_reset(): + """Test that ERROR_HANDLED flag is properly reset""" + with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class: + # Create and configure the mock manager + mock_manager = MagicMock() + mock_manager.enabled = True + mock_manager_class.get_instance.return_value = mock_manager + + from axolotl.telemetry.errors import ERROR_HANDLED + + @send_errors + def test_func(): + raise ValueError("Test error") + + assert not ERROR_HANDLED + + with pytest.raises(ValueError): + test_func() + + from axolotl.telemetry.errors import ERROR_HANDLED + 
+ assert ERROR_HANDLED + + +def test_module_path_resolution(mock_telemetry_manager): + """Test that the module path is correctly resolved for the event type""" + import inspect + + current_module = inspect.getmodule(test_module_path_resolution).__name__ + + @send_errors + def test_func(): + raise ValueError("Test error") + + with pytest.raises(ValueError): + test_func() + + assert mock_telemetry_manager.send_event.called + event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"] + + expected_event_type = f"{current_module}.test_func-error" + assert expected_event_type == event_type diff --git a/tests/telemetry/test_manager.py b/tests/telemetry/test_manager.py index 72e554002..04a6404bc 100644 --- a/tests/telemetry/test_manager.py +++ b/tests/telemetry/test_manager.py @@ -1,5 +1,5 @@ """Tests for TelemetryManager class and utilities""" -# pylint: disable=redefined-outer-name +# pylint: disable=redefined-outer-name,protected-access import os from unittest.mock import patch @@ -7,7 +7,7 @@ from unittest.mock import patch import pytest import yaml -from axolotl.telemetry import TelemetryManager +from axolotl.telemetry.manager import TelemetryConfig, TelemetryManager @pytest.fixture @@ -19,116 +19,236 @@ def mock_whitelist(tmp_path): whitelist_file = tmp_path / "whitelist.yaml" with open(whitelist_file, "w", encoding="utf-8") as f: yaml.dump(whitelist_content, f) + return str(whitelist_file) @pytest.fixture -def manager(): - """Create a TelemetryManager instance with mocked PostHog""" - with patch("posthog.capture"): - return TelemetryManager() +def telemetry_manager_class(): + """Reset the TelemetryManager singleton between tests""" + original_instance = TelemetryManager._instance + original_initialized = TelemetryManager._initialized + TelemetryManager._instance = None + TelemetryManager._initialized = False + yield TelemetryManager + TelemetryManager._instance = original_instance + TelemetryManager._initialized = original_initialized -def 
test_telemetry_disabled_by_default(): - """Test that telemetry is disabled by default""" - manager = TelemetryManager() - assert not manager.enabled +@pytest.fixture +def manager(telemetry_manager_class, mock_whitelist): + """Create a TelemetryManager instance with mocked dependencies""" + with patch("posthog.capture"), patch("posthog.flush"), patch( + "time.sleep" + ), patch.object(TelemetryConfig, "whitelist_path", mock_whitelist), patch( + "axolotl.telemetry.manager.is_main_process", return_value=True + ): + manager = telemetry_manager_class() + # Manually enable for most tests + manager.enabled = True + return manager -def test_telemetry_opt_in(): - """Test that telemetry can be enabled via environment variable""" - with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}): - manager = TelemetryManager() - assert manager.enabled +def test_singleton_instance(telemetry_manager_class): + """Test that TelemetryManager is a singleton""" + with patch("posthog.capture"), patch("time.sleep"): + first = telemetry_manager_class() + second = telemetry_manager_class() + assert first is second + assert telemetry_manager_class.get_instance() is first -def test_do_not_track_override(): - """Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY""" - with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}): - manager = TelemetryManager() +def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class): + """Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1""" + with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1"}), patch( + "axolotl.telemetry.manager.is_main_process", return_value=True + ): + manager = telemetry_manager_class() assert not manager.enabled -# pylint: disable=protected-access -def test_whitelist_checking(manager): - """Test model whitelist functionality""" - # Should match organization - assert manager._is_whitelisted("meta-llama/llama-7b") - # Should match model name - assert 
manager._is_whitelisted("mistralai/mistral-7b-instruct") - # Should not match - assert not manager._is_whitelisted("unknown/model") - # Should handle case insensitively - assert manager._is_whitelisted("meta/Llama-7b") +def test_telemetry_disabled_with_do_not_track(telemetry_manager_class): + """Test that telemetry is disabled when DO_NOT_TRACK=1""" + with patch.dict(os.environ, {"DO_NOT_TRACK": "1"}), patch( + "axolotl.telemetry.manager.is_main_process", return_value=True + ): + manager = telemetry_manager_class() + assert not manager.enabled -def test_event_tracking(manager): - """Test basic event tracking""" +def test_telemetry_disabled_for_non_main_process(telemetry_manager_class): + """Test that telemetry is disabled for non-main processes""" + with patch("axolotl.telemetry.manager.is_main_process", return_value=False): + manager = telemetry_manager_class() + assert not manager.enabled + + +def test_telemetry_enabled_by_default(telemetry_manager_class): + """Test that telemetry is enabled by default""" + with patch.dict(os.environ, {}, clear=True), patch( + "axolotl.telemetry.manager.is_main_process", return_value=True + ), patch("time.sleep"), patch("logging.Logger.warning"): + manager = telemetry_manager_class() + assert manager.enabled + assert not manager.explicit_enable + + +def test_explicit_enable_disables_warning(telemetry_manager_class): + """Test that explicit enabling prevents warning""" + with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), patch( + "logging.Logger.warning" + ) as mock_warning, patch( + "axolotl.telemetry.manager.is_main_process", return_value=True + ), patch( + "time.sleep" + ): + manager = telemetry_manager_class() + assert manager.enabled + assert manager.explicit_enable + for call in mock_warning.call_args_list: + assert "Telemetry is enabled" not in str(call) + + +def test_warning_displayed_for_implicit_enable(telemetry_manager_class): + """Test that warning is displayed when telemetry is implicitly enabled""" + with 
patch.dict(os.environ, {}, clear=True), patch( + "logging.Logger.warning" + ) as mock_warning, patch( + "axolotl.telemetry.manager.is_main_process", return_value=True + ), patch( + "time.sleep" + ): + manager = telemetry_manager_class() + assert manager.enabled + assert not manager.explicit_enable + warning_displayed = False + for call in mock_warning.call_args_list: + if "Telemetry is enabled" in str(call): + warning_displayed = True + break + assert warning_displayed + + +def test_is_whitelisted(manager, mock_whitelist): + """Test org whitelist functionality""" + with patch.object(TelemetryConfig, "whitelist_path", mock_whitelist): + # Should match organizations from the mock whitelist + assert manager._is_whitelisted("meta-llama/llama-7b") + assert manager._is_whitelisted("mistralai/mistral-7b-instruct") + # Should not match + assert not manager._is_whitelisted("unknown/model") + # Should handle case insensitively + assert manager._is_whitelisted("META-LLAMA/Llama-7B") + # Should handle empty input + assert not manager._is_whitelisted("") + assert not manager._is_whitelisted(None) + + +def test_system_info_collection(manager): + """Test system information collection""" + system_info = manager.system_info + + # Check essential keys + assert "os" in system_info + assert "python_version" in system_info + assert "pytorch_version" in system_info + assert "transformers_version" in system_info + assert "axolotl_version" in system_info + assert "cpu_count" in system_info + assert "memory_total" in system_info + assert "gpu_count" in system_info + + +def test_send_event(manager): + """Test basic event sending""" with patch("posthog.capture") as mock_capture: - manager.enabled = True - manager.track_event("test_event", {"key": "value"}) - + # Test with clean properties (no PII) + manager.send_event("test_event", {"key": "value"}) assert mock_capture.called assert mock_capture.call_args[1]["event"] == "test_event" - assert mock_capture.call_args[1]["properties"]["key"] == 
"value" - assert "run_id" in mock_capture.call_args[1]["properties"] - assert "system_info" in mock_capture.call_args[1]["properties"] + assert mock_capture.call_args[1]["properties"] == {"key": "value"} + assert mock_capture.call_args[1]["distinct_id"] == manager.run_id + + # Test with default properties (None) + mock_capture.reset_mock() + manager.send_event("simple_event") + assert mock_capture.called + assert mock_capture.call_args[1]["properties"] == {} -def test_training_context(manager): - """Test training context manager""" - config = {"model": "llama", "batch_size": 8} - +def test_send_system_info(manager): + """Test sending system info""" with patch("posthog.capture") as mock_capture: - manager.enabled = True - - with manager.track_training(config): - pass # Simulate successful training - - # Should have captured training_start and training_complete - events = [call[1]["event"] for call in mock_capture.call_args_list] - assert "training_start" in events - assert "training_complete" in events + manager.send_system_info() + assert mock_capture.called + assert mock_capture.call_args[1]["event"] == "system-info" + assert mock_capture.call_args[1]["properties"] == manager.system_info -def test_training_error(manager): - """Test training context manager with error""" - config = {"model": "llama", "batch_size": 8} - +def test_sanitize_properties(manager): + """Test property sanitization in send_event method""" with patch("posthog.capture") as mock_capture: - manager.enabled = True + # Test with properties containing various PII + test_properties = { + "filepath": "/home/user/sensitive/data.txt", + "windows_path": "C:\\Users\\name\\Documents\\project\\file.py", + "url": "https://example.com/private/user123", + "message": "Error loading /tmp/axolotl/data.csv - check permissions", + "cloud_path": "s3://my-bucket/data/user-files/", + "nested": { + "deep_path": "/var/log/axolotl/training.log", + "list_paths": ["/home/user1/file1.txt", "/home/user2/file2.txt"], + }, + 
} - with pytest.raises(ValueError): - with manager.track_training(config): - raise ValueError("Test error") + manager.send_event("test_event", test_properties) - # Should have captured training_start and training_error - events = [call[1]["event"] for call in mock_capture.call_args_list] - assert "training_start" in events - assert "training_error" in events + # Verify the call was made + assert mock_capture.called + + # Get the sanitized properties that were sent + sanitized = mock_capture.call_args[1]["properties"] + + # Check that PII was removed/sanitized + assert "/home/user/sensitive" not in str(sanitized) + assert "C:\\Users\\name" not in str(sanitized) + assert "https://example.com/private" not in str(sanitized) + assert "s3://my-bucket" not in str(sanitized) + + # Check that filenames were preserved + assert "data.txt" in str(sanitized) + assert "file.py" in str(sanitized) + assert "data.csv" in str(sanitized) + + # Check that nested structures were properly sanitized + assert "/var/log/axolotl" not in str(sanitized["nested"]) + assert "/home/user1" not in str(sanitized["nested"]["list_paths"]) -# pylint: disable=protected-access -def test_path_sanitization(manager): - """Test path sanitization""" - path = "/home/user/sensitive/data.txt" - sanitized = manager._sanitize_path(path) - assert sanitized == "data.txt" - assert "/home/user" not in sanitized +def test_disable_telemetry(manager): + """Test that disabled telemetry doesn't send events""" + with patch("posthog.capture") as mock_capture: + manager.enabled = False + manager.send_event("test_event") + assert not mock_capture.called -# pylint: disable=protected-access -def test_error_sanitization(manager): - """Test error message sanitization""" - error = "Failed to load /home/user/sensitive/data.txt: File not found" - sanitized = manager._sanitize_error(error) - assert "sensitive" not in sanitized - assert "/home/user" not in sanitized +def test_exception_handling_during_send(manager): + """Test that 
exceptions in PostHog are handled gracefully""" + with patch("posthog.capture", side_effect=Exception("Test error")), patch( + "logging.Logger.warning" + ) as mock_warning: + manager.send_event("test_event") + warning_logged = False + for call in mock_warning.call_args_list: + if "Failed to send telemetry event" in str(call): + warning_logged = True + break + assert warning_logged def test_shutdown(manager): """Test shutdown behavior""" with patch("posthog.flush") as mock_flush: - manager.enabled = True manager.shutdown() assert mock_flush.called