diff --git a/src/axolotl/telemetry/__init__.py b/src/axolotl/telemetry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/telemetry/manager.py b/src/axolotl/telemetry/manager.py index e4a6d0bea..2b4460d7b 100644 --- a/src/axolotl/telemetry/manager.py +++ b/src/axolotl/telemetry/manager.py @@ -4,10 +4,8 @@ import atexit import logging import os import platform -import re import time import uuid -from dataclasses import dataclass from pathlib import Path from typing import Any @@ -22,7 +20,9 @@ from axolotl.utils.distributed import is_main_process LOG = logging.getLogger(__name__) +POSTHOG_HOST = "https://app.posthog.com" POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y" + ENABLED_WARNING_SLEEP_SECONDS = 15 ENABLED_WARNING = ( "\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n" @@ -40,16 +40,7 @@ ENABLED_WARNING = ( f"Sleeping for {ENABLED_WARNING_SLEEP_SECONDS}s..." ) - -@dataclass -class TelemetryConfig: - """Configuration for telemetry manager""" - - host: str = "https://app.posthog.com" - queue_size: int = 100 - batch_size: int = 10 - whitelist_path: str = str(Path(__file__).parent / "whitelist.yaml") - retention_days: int = 365 +WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml") class TelemetryManager: @@ -82,7 +73,6 @@ class TelemetryManager: LOG.warning(ENABLED_WARNING) time.sleep(ENABLED_WARNING_SLEEP_SECONDS) - self.config = TelemetryConfig() self.run_id = str(uuid.uuid4()) self.whitelist = self._load_whitelist() self.system_info = self._get_system_info() @@ -142,7 +132,7 @@ class TelemetryManager: def _load_whitelist(self) -> dict: """Load HuggingFace Hub organization whitelist""" - with open(self.config.whitelist_path, encoding="utf-8") as f: + with open(WHITELIST_PATH, encoding="utf-8") as f: return yaml.safe_load(f) def _is_whitelisted(self, base_model: str) -> bool: @@ -157,69 +147,44 @@ class TelemetryManager: def _init_posthog(self): """Initialize PostHog client""" + posthog.host = POSTHOG_HOST posthog.project_api_key = POSTHOG_WRITE_KEY - posthog.host = self.config.host - def _sanitize_properties(self, properties: dict[str, Any]) -> dict[str, Any]: + def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]: """ - Sanitize properties to remove any personally identifiable information such as: - - File paths - - URLs / Links - - Cloud storage locations + Redact properties to remove any paths, so as to avoid inadvertently collecting + private or personally identifiable information (PII). Args: - properties: Dictionary of properties to sanitize. + properties: Dictionary of properties to redact. Returns: - Sanitized properties dictionary. + Properties dictionary with paths redacted. """ if not properties: return {} - # Define regex patterns for different types of personal information - patterns = { - # File paths (Unix and Windows) - "file_path": re.compile(r"(?:/|\\)(?:[^/\\]+(?:/|\\))+[^/\\]+"), - # URLs/Links - "url": re.compile(r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+(?:/[^/\s]*)*"), - # Cloud storage paths (S3, GCS, Azure) - "cloud_path": re.compile(r"s3://|gs://|azure://|blob.core.windows.net"), - } + # TODO: Keep this up to date with any config schema changes + path_indicators = {"path", "dir"} - # Deep copy isn't needed; we'll create a new dict with sanitized values - sanitized = {} + def redact_value(value: Any, key: str = "") -> Any: + """Recursively sanitize values, redacting those with path-like keys""" + # If the key suggests this is a path, redact it + if any(indicator in key.lower() for indicator in path_indicators): + return "[REDACTED]" - def sanitize_value(value): - """Recursively sanitize values within nested structures""" - if isinstance(value, str): - # For file paths, extract just the filename - path_match = patterns["file_path"].search(value) - if path_match: - try: - # Try to extract just the filename - path_str = path_match.group(0) - value = value.replace(path_str, Path(path_str).name) - except (ValueError, RuntimeError): - # If path extraction fails, just redact the path - value = patterns["file_path"].sub("[REDACTED_PATH]", value) - - # Redact other sensitive information - value = patterns["url"].sub("[REDACTED_URL]", value) - value = patterns["cloud_path"].sub("[REDACTED_CLOUD]", value) - - return value + # Handle nested structures if isinstance(value, dict): - return {k: sanitize_value(v) for k, v in value.items()} + return {k: redact_value(v, k) for k, v in value.items()} if isinstance(value, list): - return [sanitize_value(item) for item in value] + return [redact_value(item) for item in value] return value - # Apply the sanitization to all properties - for key, value in properties.items(): - sanitized[key] = sanitize_value(value) + # Create new dict with redacted values + redacted = {k: redact_value(v, k) for k, v in properties.items()} - return sanitized + return redacted def _get_system_info(self) -> dict[str, Any]: """Collect system information""" @@ -254,7 +219,7 @@ class TelemetryManager: properties = {} # Sanitize properties to remove PII - properties = self._sanitize_properties(properties) + properties = self._redact_paths(properties) # Wrap PostHog errors in try / except to not raise errors during Axolotl usage try: diff --git a/tests/telemetry/test_callbacks.py b/tests/telemetry/test_callbacks.py index 4324126e7..6303812cc 100644 --- a/tests/telemetry/test_callbacks.py +++ b/tests/telemetry/test_callbacks.py @@ -113,7 +113,7 @@ class TestTelemetryCallback: callback.on_train_begin(training_args, trainer_state, trainer_control) mock_telemetry_manager.send_event.assert_called_once_with( - event_type="train-start" + event_type="train-started" ) def test_on_train_end( @@ -130,7 +130,7 @@ class TestTelemetryCallback: mock_telemetry_manager.send_event.assert_called_once() call_args = mock_telemetry_manager.send_event.call_args[1] - assert call_args["event_type"] == "train-end" + assert call_args["event_type"] == "train-ended" assert "loss" in call_args["properties"] assert call_args["properties"]["loss"] == 2.5 assert "learning_rate" in call_args["properties"] diff --git a/tests/telemetry/test_errors.py b/tests/telemetry/test_errors.py index 3d00c0f28..021d5fbd8 100644 --- a/tests/telemetry/test_errors.py +++ b/tests/telemetry/test_errors.py @@ -253,7 +253,7 @@ def test_send_errors_with_exception(mock_telemetry_manager): # Check that the error info was passed correctly call_args = mock_telemetry_manager.send_event.call_args[1] - assert "test_func-error" in call_args["event_type"] + assert "test_func-errored" in call_args["event_type"] assert "Test error" in call_args["properties"]["exception"] assert "stack_trace" in call_args["properties"] @@ -336,5 +336,5 @@ def test_module_path_resolution(mock_telemetry_manager): assert mock_telemetry_manager.send_event.called event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"] - expected_event_type = f"{current_module}.test_func-error" + expected_event_type = f"{current_module}.test_func-errored" assert expected_event_type == event_type diff --git a/tests/telemetry/test_manager.py b/tests/telemetry/test_manager.py index 04a6404bc..629016510 100644 --- a/tests/telemetry/test_manager.py +++ b/tests/telemetry/test_manager.py @@ -7,7 +7,7 @@ from unittest.mock import patch import pytest import yaml -from axolotl.telemetry.manager import TelemetryConfig, TelemetryManager +from axolotl.telemetry.manager import TelemetryManager @pytest.fixture @@ -38,11 +38,9 @@ def telemetry_manager_class(): @pytest.fixture def manager(telemetry_manager_class, mock_whitelist): """Create a TelemetryManager instance with mocked dependencies""" - with patch("posthog.capture"), patch("posthog.flush"), patch( - "time.sleep" - ), patch.object(TelemetryConfig, "whitelist_path", mock_whitelist), patch( - "axolotl.telemetry.manager.is_main_process", return_value=True - ): + with patch("posthog.capture"), patch("posthog.flush"), patch("time.sleep"), patch( + "axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist + ), patch("axolotl.telemetry.manager.is_main_process", return_value=True): manager = telemetry_manager_class() # Manually enable for most tests manager.enabled = True @@ -131,7 +129,7 @@ def test_warning_displayed_for_implicit_enable(telemetry_manager_class): def test_is_whitelisted(manager, mock_whitelist): """Test org whitelist functionality""" - with patch.object(TelemetryConfig, "whitelist_path", mock_whitelist): + with patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist): # Should match organizations from the mock whitelist assert manager._is_whitelisted("meta-llama/llama-7b") assert manager._is_whitelisted("mistralai/mistral-7b-instruct") @@ -185,19 +183,21 @@ def test_send_system_info(manager): assert mock_capture.call_args[1]["properties"] == manager.system_info -def test_sanitize_properties(manager): - """Test property sanitization in send_event method""" +def test_redacted_properties(manager): + """Test path redaction in send_event method""" with patch("posthog.capture") as mock_capture: - # Test with properties containing various PII + # Test with properties containing various paths and non-paths test_properties = { "filepath": "/home/user/sensitive/data.txt", "windows_path": "C:\\Users\\name\\Documents\\project\\file.py", - "url": "https://example.com/private/user123", - "message": "Error loading /tmp/axolotl/data.csv - check permissions", - "cloud_path": "s3://my-bucket/data/user-files/", + "output_dir": "/var/lib/data", + "path_to_model": "models/llama/7b", + "message": "Training started", # Should not be redacted + "metrics": {"loss": 0.5, "accuracy": 0.95}, # Should not be redacted "nested": { - "deep_path": "/var/log/axolotl/training.log", - "list_paths": ["/home/user1/file1.txt", "/home/user2/file2.txt"], + "model_path": "/models/local/weights.pt", + "root_dir": "/home/user/projects", + "stats": {"steps": 1000, "epochs": 3}, # Should not be redacted }, } @@ -209,20 +209,19 @@ def test_sanitize_properties(manager): # Get the sanitized properties that were sent sanitized = mock_capture.call_args[1]["properties"] - # Check that PII was removed/sanitized - assert "/home/user/sensitive" not in str(sanitized) - assert "C:\\Users\\name" not in str(sanitized) - assert "https://example.com/private" not in str(sanitized) - assert "s3://my-bucket" not in str(sanitized) + # Check that path-like keys were redacted + assert sanitized["filepath"] == "[REDACTED]" + assert sanitized["windows_path"] == "[REDACTED]" + assert sanitized["path_to_model"] == "[REDACTED]" - # Check that filenames were preserved - assert "data.txt" in str(sanitized) - assert "file.py" in str(sanitized) - assert "data.csv" in str(sanitized) + # Check that non-path values were preserved + assert sanitized["message"] == "Training started" + assert sanitized["metrics"] == {"loss": 0.5, "accuracy": 0.95} - # Check that nested structures were properly sanitized - assert "/var/log/axolotl" not in str(sanitized["nested"]) - assert "/home/user1" not in str(sanitized["nested"]["list_paths"]) + # Check nested structure handling + assert sanitized["nested"]["model_path"] == "[REDACTED]" + assert sanitized["nested"]["root_dir"] == "[REDACTED]" + assert sanitized["nested"]["stats"] == {"steps": 1000, "epochs": 3} def test_disable_telemetry(manager):