simplifying path redaction

This commit is contained in:
Dan Saunders
2025-02-24 00:06:08 +00:00
parent db3297b090
commit ef4990f304
5 changed files with 54 additions and 90 deletions

View File

View File

@@ -4,10 +4,8 @@ import atexit
import logging import logging
import os import os
import platform import platform
import re
import time import time
import uuid import uuid
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -22,7 +20,9 @@ from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
POSTHOG_HOST = "https://app.posthog.com"
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y" POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
ENABLED_WARNING_SLEEP_SECONDS = 15 ENABLED_WARNING_SLEEP_SECONDS = 15
ENABLED_WARNING = ( ENABLED_WARNING = (
"\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n" "\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n"
@@ -40,16 +40,7 @@ ENABLED_WARNING = (
f"Sleeping for {ENABLED_WARNING_SLEEP_SECONDS}s..." f"Sleeping for {ENABLED_WARNING_SLEEP_SECONDS}s..."
) )
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
@dataclass
class TelemetryConfig:
"""Configuration for telemetry manager"""
host: str = "https://app.posthog.com"
queue_size: int = 100
batch_size: int = 10
whitelist_path: str = str(Path(__file__).parent / "whitelist.yaml")
retention_days: int = 365
class TelemetryManager: class TelemetryManager:
@@ -82,7 +73,6 @@ class TelemetryManager:
LOG.warning(ENABLED_WARNING) LOG.warning(ENABLED_WARNING)
time.sleep(ENABLED_WARNING_SLEEP_SECONDS) time.sleep(ENABLED_WARNING_SLEEP_SECONDS)
self.config = TelemetryConfig()
self.run_id = str(uuid.uuid4()) self.run_id = str(uuid.uuid4())
self.whitelist = self._load_whitelist() self.whitelist = self._load_whitelist()
self.system_info = self._get_system_info() self.system_info = self._get_system_info()
@@ -142,7 +132,7 @@ class TelemetryManager:
def _load_whitelist(self) -> dict: def _load_whitelist(self) -> dict:
"""Load HuggingFace Hub organization whitelist""" """Load HuggingFace Hub organization whitelist"""
with open(self.config.whitelist_path, encoding="utf-8") as f: with open(WHITELIST_PATH, encoding="utf-8") as f:
return yaml.safe_load(f) return yaml.safe_load(f)
def _is_whitelisted(self, base_model: str) -> bool: def _is_whitelisted(self, base_model: str) -> bool:
@@ -157,69 +147,44 @@ class TelemetryManager:
def _init_posthog(self): def _init_posthog(self):
"""Initialize PostHog client""" """Initialize PostHog client"""
posthog.host = POSTHOG_HOST
posthog.project_api_key = POSTHOG_WRITE_KEY posthog.project_api_key = POSTHOG_WRITE_KEY
posthog.host = self.config.host
def _sanitize_properties(self, properties: dict[str, Any]) -> dict[str, Any]: def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]:
""" """
Sanitize properties to remove any personally identifiable information such as: Redact properties to remove any paths, so as to avoid inadvertently collecting
- File paths private or personally identifiable information (PII).
- URLs / Links
- Cloud storage locations
Args: Args:
properties: Dictionary of properties to sanitize. properties: Dictionary of properties to redact.
Returns: Returns:
Sanitized properties dictionary. Properties dictionary with paths redacted.
""" """
if not properties: if not properties:
return {} return {}
# Define regex patterns for different types of personal information # TODO: Keep this up to date with any config schema changes
patterns = { path_indicators = {"path", "dir"}
# File paths (Unix and Windows)
"file_path": re.compile(r"(?:/|\\)(?:[^/\\]+(?:/|\\))+[^/\\]+"),
# URLs/Links
"url": re.compile(r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+(?:/[^/\s]*)*"),
# Cloud storage paths (S3, GCS, Azure)
"cloud_path": re.compile(r"s3://|gs://|azure://|blob.core.windows.net"),
}
# Deep copy isn't needed; we'll create a new dict with sanitized values def redact_value(value: Any, key: str = "") -> Any:
sanitized = {} """Recursively sanitize values, redacting those with path-like keys"""
# If the key suggests this is a path, redact it
if any(indicator in key.lower() for indicator in path_indicators):
return "[REDACTED]"
def sanitize_value(value): # Handle nested structures
"""Recursively sanitize values within nested structures"""
if isinstance(value, str):
# For file paths, extract just the filename
path_match = patterns["file_path"].search(value)
if path_match:
try:
# Try to extract just the filename
path_str = path_match.group(0)
value = value.replace(path_str, Path(path_str).name)
except (ValueError, RuntimeError):
# If path extraction fails, just redact the path
value = patterns["file_path"].sub("[REDACTED_PATH]", value)
# Redact other sensitive information
value = patterns["url"].sub("[REDACTED_URL]", value)
value = patterns["cloud_path"].sub("[REDACTED_CLOUD]", value)
return value
if isinstance(value, dict): if isinstance(value, dict):
return {k: sanitize_value(v) for k, v in value.items()} return {k: redact_value(v, k) for k, v in value.items()}
if isinstance(value, list): if isinstance(value, list):
return [sanitize_value(item) for item in value] return [redact_value(item) for item in value]
return value return value
# Apply the sanitization to all properties # Create new dict with redacted values
for key, value in properties.items(): redacted = {k: redact_value(v, k) for k, v in properties.items()}
sanitized[key] = sanitize_value(value)
return sanitized return redacted
def _get_system_info(self) -> dict[str, Any]: def _get_system_info(self) -> dict[str, Any]:
"""Collect system information""" """Collect system information"""
@@ -254,7 +219,7 @@ class TelemetryManager:
properties = {} properties = {}
# Sanitize properties to remove PII # Sanitize properties to remove PII
properties = self._sanitize_properties(properties) properties = self._redact_paths(properties)
# Wrap PostHog errors in try / except to not raise errors during Axolotl usage # Wrap PostHog errors in try / except to not raise errors during Axolotl usage
try: try:

View File

@@ -113,7 +113,7 @@ class TestTelemetryCallback:
callback.on_train_begin(training_args, trainer_state, trainer_control) callback.on_train_begin(training_args, trainer_state, trainer_control)
mock_telemetry_manager.send_event.assert_called_once_with( mock_telemetry_manager.send_event.assert_called_once_with(
event_type="train-start" event_type="train-started"
) )
def test_on_train_end( def test_on_train_end(
@@ -130,7 +130,7 @@ class TestTelemetryCallback:
mock_telemetry_manager.send_event.assert_called_once() mock_telemetry_manager.send_event.assert_called_once()
call_args = mock_telemetry_manager.send_event.call_args[1] call_args = mock_telemetry_manager.send_event.call_args[1]
assert call_args["event_type"] == "train-end" assert call_args["event_type"] == "train-ended"
assert "loss" in call_args["properties"] assert "loss" in call_args["properties"]
assert call_args["properties"]["loss"] == 2.5 assert call_args["properties"]["loss"] == 2.5
assert "learning_rate" in call_args["properties"] assert "learning_rate" in call_args["properties"]

View File

@@ -253,7 +253,7 @@ def test_send_errors_with_exception(mock_telemetry_manager):
# Check that the error info was passed correctly # Check that the error info was passed correctly
call_args = mock_telemetry_manager.send_event.call_args[1] call_args = mock_telemetry_manager.send_event.call_args[1]
assert "test_func-error" in call_args["event_type"] assert "test_func-errored" in call_args["event_type"]
assert "Test error" in call_args["properties"]["exception"] assert "Test error" in call_args["properties"]["exception"]
assert "stack_trace" in call_args["properties"] assert "stack_trace" in call_args["properties"]
@@ -336,5 +336,5 @@ def test_module_path_resolution(mock_telemetry_manager):
assert mock_telemetry_manager.send_event.called assert mock_telemetry_manager.send_event.called
event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"] event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"]
expected_event_type = f"{current_module}.test_func-error" expected_event_type = f"{current_module}.test_func-errored"
assert expected_event_type == event_type assert expected_event_type == event_type

View File

@@ -7,7 +7,7 @@ from unittest.mock import patch
import pytest import pytest
import yaml import yaml
from axolotl.telemetry.manager import TelemetryConfig, TelemetryManager from axolotl.telemetry.manager import TelemetryManager
@pytest.fixture @pytest.fixture
@@ -38,11 +38,9 @@ def telemetry_manager_class():
@pytest.fixture @pytest.fixture
def manager(telemetry_manager_class, mock_whitelist): def manager(telemetry_manager_class, mock_whitelist):
"""Create a TelemetryManager instance with mocked dependencies""" """Create a TelemetryManager instance with mocked dependencies"""
with patch("posthog.capture"), patch("posthog.flush"), patch( with patch("posthog.capture"), patch("posthog.flush"), patch("time.sleep"), patch(
"time.sleep" "axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist
), patch.object(TelemetryConfig, "whitelist_path", mock_whitelist), patch( ), patch("axolotl.telemetry.manager.is_main_process", return_value=True):
"axolotl.telemetry.manager.is_main_process", return_value=True
):
manager = telemetry_manager_class() manager = telemetry_manager_class()
# Manually enable for most tests # Manually enable for most tests
manager.enabled = True manager.enabled = True
@@ -131,7 +129,7 @@ def test_warning_displayed_for_implicit_enable(telemetry_manager_class):
def test_is_whitelisted(manager, mock_whitelist): def test_is_whitelisted(manager, mock_whitelist):
"""Test org whitelist functionality""" """Test org whitelist functionality"""
with patch.object(TelemetryConfig, "whitelist_path", mock_whitelist): with patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist):
# Should match organizations from the mock whitelist # Should match organizations from the mock whitelist
assert manager._is_whitelisted("meta-llama/llama-7b") assert manager._is_whitelisted("meta-llama/llama-7b")
assert manager._is_whitelisted("mistralai/mistral-7b-instruct") assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
@@ -185,19 +183,21 @@ def test_send_system_info(manager):
assert mock_capture.call_args[1]["properties"] == manager.system_info assert mock_capture.call_args[1]["properties"] == manager.system_info
def test_sanitize_properties(manager): def test_redacted_properties(manager):
"""Test property sanitization in send_event method""" """Test path redaction in send_event method"""
with patch("posthog.capture") as mock_capture: with patch("posthog.capture") as mock_capture:
# Test with properties containing various PII # Test with properties containing various paths and non-paths
test_properties = { test_properties = {
"filepath": "/home/user/sensitive/data.txt", "filepath": "/home/user/sensitive/data.txt",
"windows_path": "C:\\Users\\name\\Documents\\project\\file.py", "windows_path": "C:\\Users\\name\\Documents\\project\\file.py",
"url": "https://example.com/private/user123", "output_dir": "/var/lib/data",
"message": "Error loading /tmp/axolotl/data.csv - check permissions", "path_to_model": "models/llama/7b",
"cloud_path": "s3://my-bucket/data/user-files/", "message": "Training started", # Should not be redacted
"metrics": {"loss": 0.5, "accuracy": 0.95}, # Should not be redacted
"nested": { "nested": {
"deep_path": "/var/log/axolotl/training.log", "model_path": "/models/local/weights.pt",
"list_paths": ["/home/user1/file1.txt", "/home/user2/file2.txt"], "root_dir": "/home/user/projects",
"stats": {"steps": 1000, "epochs": 3}, # Should not be redacted
}, },
} }
@@ -209,20 +209,19 @@ def test_sanitize_properties(manager):
# Get the sanitized properties that were sent # Get the sanitized properties that were sent
sanitized = mock_capture.call_args[1]["properties"] sanitized = mock_capture.call_args[1]["properties"]
# Check that PII was removed/sanitized # Check that path-like keys were redacted
assert "/home/user/sensitive" not in str(sanitized) assert sanitized["filepath"] == "[REDACTED]"
assert "C:\\Users\\name" not in str(sanitized) assert sanitized["windows_path"] == "[REDACTED]"
assert "https://example.com/private" not in str(sanitized) assert sanitized["path_to_model"] == "[REDACTED]"
assert "s3://my-bucket" not in str(sanitized)
# Check that filenames were preserved # Check that non-path values were preserved
assert "data.txt" in str(sanitized) assert sanitized["message"] == "Training started"
assert "file.py" in str(sanitized) assert sanitized["metrics"] == {"loss": 0.5, "accuracy": 0.95}
assert "data.csv" in str(sanitized)
# Check that nested structures were properly sanitized # Check nested structure handling
assert "/var/log/axolotl" not in str(sanitized["nested"]) assert sanitized["nested"]["model_path"] == "[REDACTED]"
assert "/home/user1" not in str(sanitized["nested"]["list_paths"]) assert sanitized["nested"]["root_dir"] == "[REDACTED]"
assert sanitized["nested"]["stats"] == {"steps": 1000, "epochs": 3}
def test_disable_telemetry(manager): def test_disable_telemetry(manager):