simplifying path redaction
This commit is contained in:
0
src/axolotl/telemetry/__init__.py
Normal file
0
src/axolotl/telemetry/__init__.py
Normal file
@@ -4,10 +4,8 @@ import atexit
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import re
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -22,7 +20,9 @@ from axolotl.utils.distributed import is_main_process
|
|||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
POSTHOG_HOST = "https://app.posthog.com"
|
||||||
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
|
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
|
||||||
|
|
||||||
ENABLED_WARNING_SLEEP_SECONDS = 15
|
ENABLED_WARNING_SLEEP_SECONDS = 15
|
||||||
ENABLED_WARNING = (
|
ENABLED_WARNING = (
|
||||||
"\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n"
|
"\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n"
|
||||||
@@ -40,16 +40,7 @@ ENABLED_WARNING = (
|
|||||||
f"Sleeping for {ENABLED_WARNING_SLEEP_SECONDS}s..."
|
f"Sleeping for {ENABLED_WARNING_SLEEP_SECONDS}s..."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
|
||||||
@dataclass
|
|
||||||
class TelemetryConfig:
|
|
||||||
"""Configuration for telemetry manager"""
|
|
||||||
|
|
||||||
host: str = "https://app.posthog.com"
|
|
||||||
queue_size: int = 100
|
|
||||||
batch_size: int = 10
|
|
||||||
whitelist_path: str = str(Path(__file__).parent / "whitelist.yaml")
|
|
||||||
retention_days: int = 365
|
|
||||||
|
|
||||||
|
|
||||||
class TelemetryManager:
|
class TelemetryManager:
|
||||||
@@ -82,7 +73,6 @@ class TelemetryManager:
|
|||||||
LOG.warning(ENABLED_WARNING)
|
LOG.warning(ENABLED_WARNING)
|
||||||
time.sleep(ENABLED_WARNING_SLEEP_SECONDS)
|
time.sleep(ENABLED_WARNING_SLEEP_SECONDS)
|
||||||
|
|
||||||
self.config = TelemetryConfig()
|
|
||||||
self.run_id = str(uuid.uuid4())
|
self.run_id = str(uuid.uuid4())
|
||||||
self.whitelist = self._load_whitelist()
|
self.whitelist = self._load_whitelist()
|
||||||
self.system_info = self._get_system_info()
|
self.system_info = self._get_system_info()
|
||||||
@@ -142,7 +132,7 @@ class TelemetryManager:
|
|||||||
|
|
||||||
def _load_whitelist(self) -> dict:
    """
    Load the HuggingFace Hub organization whitelist.

    Returns:
        The parsed contents of the YAML file located at `WHITELIST_PATH`.
    """
    with open(WHITELIST_PATH, encoding="utf-8") as whitelist_file:
        parsed_whitelist = yaml.safe_load(whitelist_file)

    return parsed_whitelist
||||||
|
|
||||||
def _is_whitelisted(self, base_model: str) -> bool:
|
def _is_whitelisted(self, base_model: str) -> bool:
|
||||||
@@ -157,69 +147,44 @@ class TelemetryManager:
|
|||||||
|
|
||||||
def _init_posthog(self):
    """Configure the PostHog client with the module-level write key and host"""
    # The two settings are independent module attributes on `posthog`.
    posthog.project_api_key = POSTHOG_WRITE_KEY
    posthog.host = POSTHOG_HOST
||||||
|
|
||||||
def _sanitize_properties(self, properties: dict[str, Any]) -> dict[str, Any]:
|
def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Sanitize properties to remove any personally identifiable information such as:
|
Redact properties to remove any paths, so as to avoid inadvertently collecting
|
||||||
- File paths
|
private or personally identifiable information (PII).
|
||||||
- URLs / Links
|
|
||||||
- Cloud storage locations
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
properties: Dictionary of properties to sanitize.
|
properties: Dictionary of properties to redact.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Sanitized properties dictionary.
|
Properties dictionary with paths redacted.
|
||||||
"""
|
"""
|
||||||
if not properties:
|
if not properties:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# Define regex patterns for different types of personal information
|
# TODO: Keep this up to date with any config schema changes
|
||||||
patterns = {
|
path_indicators = {"path", "dir"}
|
||||||
# File paths (Unix and Windows)
|
|
||||||
"file_path": re.compile(r"(?:/|\\)(?:[^/\\]+(?:/|\\))+[^/\\]+"),
|
|
||||||
# URLs/Links
|
|
||||||
"url": re.compile(r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+(?:/[^/\s]*)*"),
|
|
||||||
# Cloud storage paths (S3, GCS, Azure)
|
|
||||||
"cloud_path": re.compile(r"s3://|gs://|azure://|blob.core.windows.net"),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Deep copy isn't needed; we'll create a new dict with sanitized values
|
def redact_value(value: Any, key: str = "") -> Any:
|
||||||
sanitized = {}
|
"""Recursively sanitize values, redacting those with path-like keys"""
|
||||||
|
# If the key suggests this is a path, redact it
|
||||||
|
if any(indicator in key.lower() for indicator in path_indicators):
|
||||||
|
return "[REDACTED]"
|
||||||
|
|
||||||
def sanitize_value(value):
|
# Handle nested structures
|
||||||
"""Recursively sanitize values within nested structures"""
|
|
||||||
if isinstance(value, str):
|
|
||||||
# For file paths, extract just the filename
|
|
||||||
path_match = patterns["file_path"].search(value)
|
|
||||||
if path_match:
|
|
||||||
try:
|
|
||||||
# Try to extract just the filename
|
|
||||||
path_str = path_match.group(0)
|
|
||||||
value = value.replace(path_str, Path(path_str).name)
|
|
||||||
except (ValueError, RuntimeError):
|
|
||||||
# If path extraction fails, just redact the path
|
|
||||||
value = patterns["file_path"].sub("[REDACTED_PATH]", value)
|
|
||||||
|
|
||||||
# Redact other sensitive information
|
|
||||||
value = patterns["url"].sub("[REDACTED_URL]", value)
|
|
||||||
value = patterns["cloud_path"].sub("[REDACTED_CLOUD]", value)
|
|
||||||
|
|
||||||
return value
|
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
return {k: sanitize_value(v) for k, v in value.items()}
|
return {k: redact_value(v, k) for k, v in value.items()}
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
return [sanitize_value(item) for item in value]
|
return [redact_value(item) for item in value]
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
# Apply the sanitization to all properties
|
# Create new dict with redacted values
|
||||||
for key, value in properties.items():
|
redacted = {k: redact_value(v, k) for k, v in properties.items()}
|
||||||
sanitized[key] = sanitize_value(value)
|
|
||||||
|
|
||||||
return sanitized
|
return redacted
|
||||||
|
|
||||||
def _get_system_info(self) -> dict[str, Any]:
|
def _get_system_info(self) -> dict[str, Any]:
|
||||||
"""Collect system information"""
|
"""Collect system information"""
|
||||||
@@ -254,7 +219,7 @@ class TelemetryManager:
|
|||||||
properties = {}
|
properties = {}
|
||||||
|
|
||||||
# Sanitize properties to remove PII
|
# Redact path-like properties to avoid sending PII
|
||||||
properties = self._sanitize_properties(properties)
|
properties = self._redact_paths(properties)
|
||||||
|
|
||||||
# Wrap PostHog errors in try / except to not raise errors during Axolotl usage
|
# Wrap PostHog errors in try / except to not raise errors during Axolotl usage
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ class TestTelemetryCallback:
|
|||||||
callback.on_train_begin(training_args, trainer_state, trainer_control)
|
callback.on_train_begin(training_args, trainer_state, trainer_control)
|
||||||
|
|
||||||
mock_telemetry_manager.send_event.assert_called_once_with(
|
mock_telemetry_manager.send_event.assert_called_once_with(
|
||||||
event_type="train-start"
|
event_type="train-started"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_on_train_end(
|
def test_on_train_end(
|
||||||
@@ -130,7 +130,7 @@ class TestTelemetryCallback:
|
|||||||
mock_telemetry_manager.send_event.assert_called_once()
|
mock_telemetry_manager.send_event.assert_called_once()
|
||||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||||
|
|
||||||
assert call_args["event_type"] == "train-end"
|
assert call_args["event_type"] == "train-ended"
|
||||||
assert "loss" in call_args["properties"]
|
assert "loss" in call_args["properties"]
|
||||||
assert call_args["properties"]["loss"] == 2.5
|
assert call_args["properties"]["loss"] == 2.5
|
||||||
assert "learning_rate" in call_args["properties"]
|
assert "learning_rate" in call_args["properties"]
|
||||||
|
|||||||
@@ -253,7 +253,7 @@ def test_send_errors_with_exception(mock_telemetry_manager):
|
|||||||
|
|
||||||
# Check that the error info was passed correctly
|
# Check that the error info was passed correctly
|
||||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||||
assert "test_func-error" in call_args["event_type"]
|
assert "test_func-errored" in call_args["event_type"]
|
||||||
assert "Test error" in call_args["properties"]["exception"]
|
assert "Test error" in call_args["properties"]["exception"]
|
||||||
assert "stack_trace" in call_args["properties"]
|
assert "stack_trace" in call_args["properties"]
|
||||||
|
|
||||||
@@ -336,5 +336,5 @@ def test_module_path_resolution(mock_telemetry_manager):
|
|||||||
assert mock_telemetry_manager.send_event.called
|
assert mock_telemetry_manager.send_event.called
|
||||||
event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"]
|
event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"]
|
||||||
|
|
||||||
expected_event_type = f"{current_module}.test_func-error"
|
expected_event_type = f"{current_module}.test_func-errored"
|
||||||
assert expected_event_type == event_type
|
assert expected_event_type == event_type
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from axolotl.telemetry.manager import TelemetryConfig, TelemetryManager
|
from axolotl.telemetry.manager import TelemetryManager
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -38,11 +38,9 @@ def telemetry_manager_class():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def manager(telemetry_manager_class, mock_whitelist):
|
def manager(telemetry_manager_class, mock_whitelist):
|
||||||
"""Create a TelemetryManager instance with mocked dependencies"""
|
"""Create a TelemetryManager instance with mocked dependencies"""
|
||||||
with patch("posthog.capture"), patch("posthog.flush"), patch(
|
with patch("posthog.capture"), patch("posthog.flush"), patch("time.sleep"), patch(
|
||||||
"time.sleep"
|
"axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist
|
||||||
), patch.object(TelemetryConfig, "whitelist_path", mock_whitelist), patch(
|
), patch("axolotl.telemetry.manager.is_main_process", return_value=True):
|
||||||
"axolotl.telemetry.manager.is_main_process", return_value=True
|
|
||||||
):
|
|
||||||
manager = telemetry_manager_class()
|
manager = telemetry_manager_class()
|
||||||
# Manually enable for most tests
|
# Manually enable for most tests
|
||||||
manager.enabled = True
|
manager.enabled = True
|
||||||
@@ -131,7 +129,7 @@ def test_warning_displayed_for_implicit_enable(telemetry_manager_class):
|
|||||||
|
|
||||||
def test_is_whitelisted(manager, mock_whitelist):
|
def test_is_whitelisted(manager, mock_whitelist):
|
||||||
"""Test org whitelist functionality"""
|
"""Test org whitelist functionality"""
|
||||||
with patch.object(TelemetryConfig, "whitelist_path", mock_whitelist):
|
with patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist):
|
||||||
# Should match organizations from the mock whitelist
|
# Should match organizations from the mock whitelist
|
||||||
assert manager._is_whitelisted("meta-llama/llama-7b")
|
assert manager._is_whitelisted("meta-llama/llama-7b")
|
||||||
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
|
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
|
||||||
@@ -185,19 +183,21 @@ def test_send_system_info(manager):
|
|||||||
assert mock_capture.call_args[1]["properties"] == manager.system_info
|
assert mock_capture.call_args[1]["properties"] == manager.system_info
|
||||||
|
|
||||||
|
|
||||||
def test_sanitize_properties(manager):
|
def test_redacted_properties(manager):
|
||||||
"""Test property sanitization in send_event method"""
|
"""Test path redaction in send_event method"""
|
||||||
with patch("posthog.capture") as mock_capture:
|
with patch("posthog.capture") as mock_capture:
|
||||||
# Test with properties containing various PII
|
# Test with properties containing various paths and non-paths
|
||||||
test_properties = {
|
test_properties = {
|
||||||
"filepath": "/home/user/sensitive/data.txt",
|
"filepath": "/home/user/sensitive/data.txt",
|
||||||
"windows_path": "C:\\Users\\name\\Documents\\project\\file.py",
|
"windows_path": "C:\\Users\\name\\Documents\\project\\file.py",
|
||||||
"url": "https://example.com/private/user123",
|
"output_dir": "/var/lib/data",
|
||||||
"message": "Error loading /tmp/axolotl/data.csv - check permissions",
|
"path_to_model": "models/llama/7b",
|
||||||
"cloud_path": "s3://my-bucket/data/user-files/",
|
"message": "Training started", # Should not be redacted
|
||||||
|
"metrics": {"loss": 0.5, "accuracy": 0.95}, # Should not be redacted
|
||||||
"nested": {
|
"nested": {
|
||||||
"deep_path": "/var/log/axolotl/training.log",
|
"model_path": "/models/local/weights.pt",
|
||||||
"list_paths": ["/home/user1/file1.txt", "/home/user2/file2.txt"],
|
"root_dir": "/home/user/projects",
|
||||||
|
"stats": {"steps": 1000, "epochs": 3}, # Should not be redacted
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -209,20 +209,19 @@ def test_sanitize_properties(manager):
|
|||||||
# Get the sanitized properties that were sent
|
# Get the sanitized properties that were sent
|
||||||
sanitized = mock_capture.call_args[1]["properties"]
|
sanitized = mock_capture.call_args[1]["properties"]
|
||||||
|
|
||||||
# Check that PII was removed/sanitized
|
# Check that path-like keys were redacted
|
||||||
assert "/home/user/sensitive" not in str(sanitized)
|
assert sanitized["filepath"] == "[REDACTED]"
|
||||||
assert "C:\\Users\\name" not in str(sanitized)
|
assert sanitized["windows_path"] == "[REDACTED]"
|
||||||
assert "https://example.com/private" not in str(sanitized)
|
assert sanitized["path_to_model"] == "[REDACTED]"
|
||||||
assert "s3://my-bucket" not in str(sanitized)
|
|
||||||
|
|
||||||
# Check that filenames were preserved
|
# Check that non-path values were preserved
|
||||||
assert "data.txt" in str(sanitized)
|
assert sanitized["message"] == "Training started"
|
||||||
assert "file.py" in str(sanitized)
|
assert sanitized["metrics"] == {"loss": 0.5, "accuracy": 0.95}
|
||||||
assert "data.csv" in str(sanitized)
|
|
||||||
|
|
||||||
# Check that nested structures were properly sanitized
|
# Check nested structure handling
|
||||||
assert "/var/log/axolotl" not in str(sanitized["nested"])
|
assert sanitized["nested"]["model_path"] == "[REDACTED]"
|
||||||
assert "/home/user1" not in str(sanitized["nested"]["list_paths"])
|
assert sanitized["nested"]["root_dir"] == "[REDACTED]"
|
||||||
|
assert sanitized["nested"]["stats"] == {"steps": 1000, "epochs": 3}
|
||||||
|
|
||||||
|
|
||||||
def test_disable_telemetry(manager):
|
def test_disable_telemetry(manager):
|
||||||
|
|||||||
Reference in New Issue
Block a user