updated sanitization logic, tests

This commit is contained in:
Dan Saunders
2025-02-24 20:05:55 +00:00
parent 1edd6b9524
commit d8b0522ea0
4 changed files with 666 additions and 97 deletions

View File

@@ -1,6 +1,7 @@
"""Telemetry utilities for exception and traceback information."""
import logging
import os
import re
import traceback
from functools import wraps
@@ -16,13 +17,16 @@ ERROR_HANDLED = False
def sanitize_stack_trace(stack_trace: str) -> str:
"""
Remove personal information from stack trace messages while keeping Axolotl codepaths.
Remove personal information from stack trace messages while keeping Python package codepaths.
This function identifies Python packages by looking for common patterns in virtual environment
and site-packages directories, preserving the package path while removing user-specific paths.
Args:
stack_trace: The original stack trace string.
Returns:
A sanitized version of the stack trace with only axolotl paths preserved.
A sanitized version of the stack trace with Python package paths preserved.
"""
# Split the stack trace into lines to process each file path separately
lines = stack_trace.split("\n")
@@ -31,23 +35,66 @@ def sanitize_stack_trace(stack_trace: str) -> str:
# Regular expression to find file paths in the stack trace
path_pattern = re.compile(r'(?:File ")(.*?)(?:")')
# Regular expression to identify paths in site-packages or dist-packages
# This matches path segments like "site-packages/package_name" or "dist-packages/package_name"
site_packages_pattern = re.compile(
r"(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
)
# Additional common virtual environment patterns
venv_lib_pattern = re.compile(
r"(?:lib|Lib)[/\\](?:python\d+(?:\.\d+)?[/\\])?(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
)
for line in lines:
# Check if this line contains a file path
path_match = path_pattern.search(line)
if path_match:
full_path = path_match.group(1)
sanitized_path = ""
if "axolotl/" in full_path:
# Keep only the 'axolotl' part and onward
axolotl_idx = full_path.rfind("axolotl/")
if axolotl_idx >= 0:
# Replace the original path with the sanitized one
sanitized_path = full_path[axolotl_idx:]
line = line.replace(full_path, sanitized_path)
# Try to match site-packages pattern
site_packages_match = site_packages_pattern.search(full_path)
venv_lib_match = venv_lib_pattern.search(full_path)
if site_packages_match:
# Find the index where the matched pattern starts
idx = full_path.find("site-packages")
if idx == -1:
idx = full_path.find("dist-packages")
# Keep from 'site-packages' onward
if idx >= 0:
sanitized_path = full_path[idx:]
elif venv_lib_match:
# For other virtual environment patterns, find the package directory
match_idx = venv_lib_match.start(1)
if match_idx > 0:
# Keep from the package name onward
package_name = venv_lib_match.group(1)
idx = full_path.rfind(
package_name, 0, match_idx + len(package_name)
)
if idx >= 0:
sanitized_path = full_path[idx:]
# If we couldn't identify a package pattern but path contains 'axolotl'
elif "axolotl" in full_path:
idx = full_path.rfind("axolotl")
if idx >= 0:
sanitized_path = full_path[idx:]
# Apply the sanitization to the line
if sanitized_path:
line = line.replace(full_path, sanitized_path)
else:
# For non-axolotl paths, replace with an empty string or a placeholder
line = line.replace(full_path, "")
# If we couldn't identify a package pattern, just keep the filename
filename = os.path.basename(full_path)
if filename:
line = line.replace(full_path, filename)
else:
line = line.replace(full_path, "")
sanitized_lines.append(line)
@@ -72,6 +119,7 @@ def send_errors(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs) -> Any:
telemetry_manager = TelemetryManager.get_instance()
if not telemetry_manager.enabled:
return func(*args, **kwargs)
@@ -79,7 +127,7 @@ def send_errors(func: Callable) -> Callable:
return func(*args, **kwargs)
except Exception as exception:
# Only track if we're not already handling an error. This prevents us from
# capturing an error more than once in nested decorated function calls.
# capturing an error more than once in nested decorated function calls.=
global ERROR_HANDLED # pylint: disable=global-statement
if not ERROR_HANDLED:
ERROR_HANDLED = True

View File

@@ -4,6 +4,7 @@ import atexit
import logging
import os
import platform
import re
import time
import uuid
from dataclasses import dataclass
@@ -122,6 +123,9 @@ class TelemetryManager:
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
do_not_track = os.getenv("DO_NOT_TRACK")
# If explicitly enabled, we'll disable the telemetry warning message
explicit_enabled = axolotl_do_not_track in ["0", "false"]
if axolotl_do_not_track is None:
axolotl_do_not_track = "0"
@@ -134,9 +138,6 @@ class TelemetryManager:
"true",
) and do_not_track.lower() not in ("1", "true")
# If explicitly enabled, we'll disable the telemetry warning message
explicit_enabled = axolotl_do_not_track in ["0", "false"]
return enabled, explicit_enabled
def _load_whitelist(self) -> dict:
@@ -145,7 +146,7 @@ class TelemetryManager:
return yaml.safe_load(f)
def _is_whitelisted(self, base_model: str) -> bool:
"""Check if model/org is in whitelist"""
"""Check if model org is in whitelist"""
if not base_model:
return False
@@ -159,9 +160,66 @@ class TelemetryManager:
posthog.project_api_key = POSTHOG_WRITE_KEY
posthog.host = self.config.host
def _sanitize_path(self, path: str) -> str:
"""Remove personal information from file paths"""
return Path(path).name
def _sanitize_properties(self, properties: dict[str, Any]) -> dict[str, Any]:
"""
Sanitize properties to remove any personally identifiable information such as:
- File paths
- URLs / Links
- Cloud storage locations
Args:
properties: Dictionary of properties to sanitize.
Returns:
Sanitized properties dictionary.
"""
if not properties:
return {}
# Define regex patterns for different types of personal information
patterns = {
# File paths (Unix and Windows)
"file_path": re.compile(r"(?:/|\\)(?:[^/\\]+(?:/|\\))+[^/\\]+"),
# URLs/Links
"url": re.compile(r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+(?:/[^/\s]*)*"),
# Cloud storage paths (S3, GCS, Azure)
"cloud_path": re.compile(r"s3://|gs://|azure://|blob.core.windows.net"),
}
# Deep copy isn't needed; we'll create a new dict with sanitized values
sanitized = {}
def sanitize_value(value):
"""Recursively sanitize values within nested structures"""
if isinstance(value, str):
# For file paths, extract just the filename
path_match = patterns["file_path"].search(value)
if path_match:
try:
# Try to extract just the filename
path_str = path_match.group(0)
value = value.replace(path_str, Path(path_str).name)
except (ValueError, RuntimeError):
# If path extraction fails, just redact the path
value = patterns["file_path"].sub("[REDACTED_PATH]", value)
# Redact other sensitive information
value = patterns["url"].sub("[REDACTED_URL]", value)
value = patterns["cloud_path"].sub("[REDACTED_CLOUD]", value)
return value
if isinstance(value, dict):
return {k: sanitize_value(v) for k, v in value.items()}
if isinstance(value, list):
return [sanitize_value(item) for item in value]
return value
# Apply the sanitization to all properties
for key, value in properties.items():
sanitized[key] = sanitize_value(value)
return sanitized
def _get_system_info(self) -> dict[str, Any]:
"""Collect system information"""
@@ -195,6 +253,9 @@ class TelemetryManager:
if properties is None:
properties = {}
# Sanitize properties to remove PII
properties = self._sanitize_properties(properties)
# Wrap PostHog errors in try / except to not raise errors during Axolotl usage
try:
LOG.warning(f"*** Sending telemetry for {event_type} ***")

View File

@@ -0,0 +1,340 @@
"""Tests for telemetry error utilities."""
# pylint: disable=redefined-outer-name
from unittest.mock import MagicMock, patch
import pytest
from axolotl.telemetry.errors import sanitize_stack_trace, send_errors
@pytest.fixture(autouse=True)
def reset_error_flag(monkeypatch):
"""Reset ERROR_HANDLED flag using monkeypatch"""
import axolotl.telemetry.errors
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
yield
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
@pytest.fixture
def example_stack_trace():
"""Provide a sample stack trace with mixed paths"""
return """Traceback (most recent call last):
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
trainer = get_trainer(cfg)
File "/home/user/.local/lib/python3.9/site-packages/axolotl/train.py", line 214, in get_trainer
model = get_model(cfg, tokenizer)
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/models.py", line 120, in get_model
raise ValueError("Model path not found")
ValueError: Model path not found
"""
@pytest.fixture
def windows_stack_trace():
"""Provide a sample stack trace with Windows paths"""
return """Traceback (most recent call last):
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\cli\\train.py", line 83, in main
trainer = get_trainer(cfg)
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\train.py", line 214, in get_trainer
model = get_model(cfg, tokenizer)
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\auto\\modeling_auto.py", line 482, in from_pretrained
raise ValueError(f"Unrecognized configuration class {config.__class__}")
ValueError: Unrecognized configuration class <class 'transformers.models.llama.configuration_llama.LlamaConfig'>
"""
@pytest.fixture
def mixed_stack_trace():
"""Provide a sample stack trace with both axolotl and non-axolotl paths"""
return """Traceback (most recent call last):
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
trainer = get_trainer(cfg)
File "/home/user/.local/lib/python3.9/site-packages/transformers/trainer.py", line 520, in train
self._inner_training_loop()
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/trainer.py", line 75, in _inner_training_loop
super()._inner_training_loop()
File "/home/user/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
data = self._next_data()
RuntimeError: CUDA out of memory
"""
@pytest.fixture
def venv_stack_trace():
"""Provide a sample stack trace with virtual environment paths"""
return """Traceback (most recent call last):
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 1729, in train
self._inner_training_loop()
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 2013, in _inner_training_loop
self.accelerator.backward(loss)
File "/home/user/venv/lib/python3.9/site-packages/accelerate/accelerator.py", line 1851, in backward
self.scaler.scale(loss).backward(**kwargs)
File "/home/user/venv/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
RuntimeError: CUDA out of memory
"""
@pytest.fixture
def dist_packages_stack_trace():
"""Provide a sample stack trace with dist-packages paths"""
return """Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__
data = self._next_data()
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 675, in _next_data
data = self._dataset_fetcher.fetch(index)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2808, in __getitem__
raise IndexError(f"Index {key} out of range for dataset of length {len(self)}.")
IndexError: Index 10000 out of range for dataset of length 9832.
"""
@pytest.fixture
def project_stack_trace():
"""Provide a sample stack trace from a project directory (not a virtual env)"""
return """Traceback (most recent call last):
File "/home/user/projects/myproject/run.py", line 25, in <module>
main()
File "/home/user/projects/myproject/src/cli.py", line 45, in main
app.run()
File "/home/user/projects/myproject/src/app.py", line 102, in run
raise ValueError("Configuration missing")
ValueError: Configuration missing
"""
def test_sanitize_stack_trace(example_stack_trace):
"""Test that sanitize_stack_trace properly preserves axolotl paths"""
sanitized = sanitize_stack_trace(example_stack_trace)
# Check that personal paths are removed
assert "/home/user" not in sanitized
assert ".local/lib/python3.9" not in sanitized
# Check that site-packages is preserved
assert "site-packages/axolotl/cli/train.py" in sanitized
assert "site-packages/axolotl/train.py" in sanitized
assert "site-packages/axolotl/utils/models.py" in sanitized
# Check that error message is preserved
assert "ValueError: Model path not found" in sanitized
def test_sanitize_windows_paths(windows_stack_trace):
"""Test that sanitize_stack_trace handles Windows paths"""
sanitized = sanitize_stack_trace(windows_stack_trace)
# Check that personal paths are removed
assert "C:\\Users\\name" not in sanitized
assert "AppData\\Local\\Programs\\Python" not in sanitized
# Check that both axolotl and transformers packages are preserved
assert (
"site-packages\\axolotl\\cli\\train.py" in sanitized
or "site-packages/axolotl/cli/train.py" in sanitized
)
assert (
"site-packages\\axolotl\\train.py" in sanitized
or "site-packages/axolotl/train.py" in sanitized
)
assert (
"site-packages\\transformers\\models\\auto\\modeling_auto.py" in sanitized
or "site-packages/transformers/models/auto/modeling_auto.py" in sanitized
)
# Check that error message is preserved
assert "ValueError: Unrecognized configuration class" in sanitized
def test_sanitize_mixed_paths(mixed_stack_trace):
"""Test that sanitize_stack_trace preserves all package paths"""
sanitized = sanitize_stack_trace(mixed_stack_trace)
# Check that all package paths are preserved
assert "site-packages/axolotl/cli/train.py" in sanitized
assert "site-packages/transformers/trainer.py" in sanitized
assert "site-packages/axolotl/utils/trainer.py" in sanitized
assert "site-packages/torch/utils/data/dataloader.py" in sanitized
# Check that error message is preserved
assert "RuntimeError: CUDA out of memory" in sanitized
def test_sanitize_venv_paths(venv_stack_trace):
"""Test that sanitize_stack_trace preserves virtual environment package paths"""
sanitized = sanitize_stack_trace(venv_stack_trace)
# Check that personal paths are removed
assert "/home/user/venv" not in sanitized
# Check that all package paths are preserved
assert "site-packages/transformers/trainer.py" in sanitized
assert "site-packages/accelerate/accelerator.py" in sanitized
assert "site-packages/torch/_tensor.py" in sanitized
# Check that error message is preserved
assert "RuntimeError: CUDA out of memory" in sanitized
def test_sanitize_dist_packages(dist_packages_stack_trace):
"""Test that sanitize_stack_trace preserves dist-packages paths"""
sanitized = sanitize_stack_trace(dist_packages_stack_trace)
# Check that system paths are removed
assert "/usr/local/lib/python3.8" not in sanitized
# Check that all package paths are preserved
assert "dist-packages/torch/utils/data/dataloader.py" in sanitized
assert "dist-packages/torch/utils/data/_utils/fetch.py" in sanitized
assert "dist-packages/datasets/arrow_dataset.py" in sanitized
# Check that error message is preserved
assert (
"IndexError: Index 10000 out of range for dataset of length 9832." in sanitized
)
def test_sanitize_project_paths(project_stack_trace):
"""Test handling of project paths (non-virtual env)"""
sanitized = sanitize_stack_trace(project_stack_trace)
# Check that personal paths are removed
assert "/home/user/projects" not in sanitized
# For non-package paths, we should at least preserve the filename
assert "run.py" in sanitized
assert "cli.py" in sanitized
assert "app.py" in sanitized
# Check that error message is preserved
assert "ValueError: Configuration missing" in sanitized
@pytest.fixture
def mock_telemetry_manager():
"""Create a mock TelemetryManager"""
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
mock_manager = MagicMock()
mock_manager.enabled = True
mock_manager_class.get_instance.return_value = mock_manager
yield mock_manager
def test_send_errors_successful_execution(mock_telemetry_manager):
"""Test that send_errors doesn't send telemetry for successful function execution"""
@send_errors
def test_func():
return "success"
result = test_func()
assert result == "success"
mock_telemetry_manager.send_event.assert_not_called()
def test_send_errors_with_exception(mock_telemetry_manager):
"""Test that send_errors sends telemetry when an exception occurs"""
test_error = ValueError("Test error")
@send_errors
def test_func():
raise test_error
with pytest.raises(ValueError) as excinfo:
test_func()
assert excinfo.value == test_error
mock_telemetry_manager.send_event.assert_called_once()
# Check that the error info was passed correctly
call_args = mock_telemetry_manager.send_event.call_args[1]
assert "test_func-error" in call_args["event_type"]
assert "Test error" in call_args["properties"]["exception"]
assert "stack_trace" in call_args["properties"]
def test_send_errors_nested_calls(mock_telemetry_manager):
"""Test that send_errors only sends telemetry once for nested decorated functions"""
@send_errors
def inner_func():
raise ValueError("Inner error")
@send_errors
def outer_func():
return inner_func()
with pytest.raises(ValueError):
outer_func()
# Telemetry should be sent only once for the inner function
assert mock_telemetry_manager.send_event.call_count == 1
call_args = mock_telemetry_manager.send_event.call_args[1]
assert "inner_func-error" in call_args["event_type"]
def test_send_errors_telemetry_disable():
"""Test that send_errors doesn't attempt to send telemetry when disabled"""
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
mock_manager = MagicMock()
mock_manager.enabled = False
mock_manager_class.get_instance.return_value = mock_manager
@send_errors
def test_func():
raise ValueError("Test error")
with pytest.raises(ValueError):
test_func()
mock_manager.send_event.assert_not_called()
def test_error_handled_reset():
"""Test that ERROR_HANDLED flag is properly reset"""
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
# Create and configure the mock manager
mock_manager = MagicMock()
mock_manager.enabled = True
mock_manager_class.get_instance.return_value = mock_manager
from axolotl.telemetry.errors import ERROR_HANDLED
@send_errors
def test_func():
raise ValueError("Test error")
assert not ERROR_HANDLED
with pytest.raises(ValueError):
test_func()
from axolotl.telemetry.errors import ERROR_HANDLED
assert ERROR_HANDLED
def test_module_path_resolution(mock_telemetry_manager):
"""Test that the module path is correctly resolved for the event type"""
import inspect
current_module = inspect.getmodule(test_module_path_resolution).__name__
@send_errors
def test_func():
raise ValueError("Test error")
with pytest.raises(ValueError):
test_func()
assert mock_telemetry_manager.send_event.called
event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"]
expected_event_type = f"{current_module}.test_func-error"
assert expected_event_type == event_type

View File

@@ -1,5 +1,5 @@
"""Tests for TelemetryManager class and utilities"""
# pylint: disable=redefined-outer-name
# pylint: disable=redefined-outer-name,protected-access
import os
from unittest.mock import patch
@@ -7,7 +7,7 @@ from unittest.mock import patch
import pytest
import yaml
from axolotl.telemetry import TelemetryManager
from axolotl.telemetry.manager import TelemetryConfig, TelemetryManager
@pytest.fixture
@@ -19,116 +19,236 @@ def mock_whitelist(tmp_path):
whitelist_file = tmp_path / "whitelist.yaml"
with open(whitelist_file, "w", encoding="utf-8") as f:
yaml.dump(whitelist_content, f)
return str(whitelist_file)
@pytest.fixture
def manager():
"""Create a TelemetryManager instance with mocked PostHog"""
with patch("posthog.capture"):
return TelemetryManager()
def telemetry_manager_class():
"""Reset the TelemetryManager singleton between tests"""
original_instance = TelemetryManager._instance
original_initialized = TelemetryManager._initialized
TelemetryManager._instance = None
TelemetryManager._initialized = False
yield TelemetryManager
TelemetryManager._instance = original_instance
TelemetryManager._initialized = original_initialized
def test_telemetry_disabled_by_default():
"""Test that telemetry is disabled by default"""
manager = TelemetryManager()
assert not manager.enabled
@pytest.fixture
def manager(telemetry_manager_class, mock_whitelist):
"""Create a TelemetryManager instance with mocked dependencies"""
with patch("posthog.capture"), patch("posthog.flush"), patch(
"time.sleep"
), patch.object(TelemetryConfig, "whitelist_path", mock_whitelist), patch(
"axolotl.telemetry.manager.is_main_process", return_value=True
):
manager = telemetry_manager_class()
# Manually enable for most tests
manager.enabled = True
return manager
def test_telemetry_opt_in():
"""Test that telemetry can be enabled via environment variable"""
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
manager = TelemetryManager()
assert manager.enabled
def test_singleton_instance(telemetry_manager_class):
"""Test that TelemetryManager is a singleton"""
with patch("posthog.capture"), patch("time.sleep"):
first = telemetry_manager_class()
second = telemetry_manager_class()
assert first is second
assert telemetry_manager_class.get_instance() is first
def test_do_not_track_override():
"""Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY"""
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}):
manager = TelemetryManager()
def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
"""Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1"}), patch(
"axolotl.telemetry.manager.is_main_process", return_value=True
):
manager = telemetry_manager_class()
assert not manager.enabled
# pylint: disable=protected-access
def test_whitelist_checking(manager):
"""Test model whitelist functionality"""
# Should match organization
assert manager._is_whitelisted("meta-llama/llama-7b")
# Should match model name
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
# Should not match
assert not manager._is_whitelisted("unknown/model")
# Should handle case insensitively
assert manager._is_whitelisted("meta/Llama-7b")
def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
"""Test that telemetry is disabled when DO_NOT_TRACK=1"""
with patch.dict(os.environ, {"DO_NOT_TRACK": "1"}), patch(
"axolotl.telemetry.manager.is_main_process", return_value=True
):
manager = telemetry_manager_class()
assert not manager.enabled
def test_event_tracking(manager):
"""Test basic event tracking"""
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
"""Test that telemetry is disabled for non-main processes"""
with patch("axolotl.telemetry.manager.is_main_process", return_value=False):
manager = telemetry_manager_class()
assert not manager.enabled
def test_telemetry_enabled_by_default(telemetry_manager_class):
"""Test that telemetry is enabled by default"""
with patch.dict(os.environ, {}, clear=True), patch(
"axolotl.telemetry.manager.is_main_process", return_value=True
), patch("time.sleep"), patch("logging.Logger.warning"):
manager = telemetry_manager_class()
assert manager.enabled
assert not manager.explicit_enable
def test_explicit_enable_disables_warning(telemetry_manager_class):
"""Test that explicit enabling prevents warning"""
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), patch(
"logging.Logger.warning"
) as mock_warning, patch(
"axolotl.telemetry.manager.is_main_process", return_value=True
), patch(
"time.sleep"
):
manager = telemetry_manager_class()
assert manager.enabled
assert manager.explicit_enable
for call in mock_warning.call_args_list:
assert "Telemetry is enabled" not in str(call)
def test_warning_displayed_for_implicit_enable(telemetry_manager_class):
"""Test that warning is displayed when telemetry is implicitly enabled"""
with patch.dict(os.environ, {}, clear=True), patch(
"logging.Logger.warning"
) as mock_warning, patch(
"axolotl.telemetry.manager.is_main_process", return_value=True
), patch(
"time.sleep"
):
manager = telemetry_manager_class()
assert manager.enabled
assert not manager.explicit_enable
warning_displayed = False
for call in mock_warning.call_args_list:
if "Telemetry is enabled" in str(call):
warning_displayed = True
break
assert warning_displayed
def test_is_whitelisted(manager, mock_whitelist):
"""Test org whitelist functionality"""
with patch.object(TelemetryConfig, "whitelist_path", mock_whitelist):
# Should match organizations from the mock whitelist
assert manager._is_whitelisted("meta-llama/llama-7b")
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
# Should not match
assert not manager._is_whitelisted("unknown/model")
# Should handle case insensitively
assert manager._is_whitelisted("META-LLAMA/Llama-7B")
# Should handle empty input
assert not manager._is_whitelisted("")
assert not manager._is_whitelisted(None)
def test_system_info_collection(manager):
"""Test system information collection"""
system_info = manager.system_info
# Check essential keys
assert "os" in system_info
assert "python_version" in system_info
assert "pytorch_version" in system_info
assert "transformers_version" in system_info
assert "axolotl_version" in system_info
assert "cpu_count" in system_info
assert "memory_total" in system_info
assert "gpu_count" in system_info
def test_send_event(manager):
"""Test basic event sending"""
with patch("posthog.capture") as mock_capture:
manager.enabled = True
manager.track_event("test_event", {"key": "value"})
# Test with clean properties (no PII)
manager.send_event("test_event", {"key": "value"})
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "test_event"
assert mock_capture.call_args[1]["properties"]["key"] == "value"
assert "run_id" in mock_capture.call_args[1]["properties"]
assert "system_info" in mock_capture.call_args[1]["properties"]
assert mock_capture.call_args[1]["properties"] == {"key": "value"}
assert mock_capture.call_args[1]["distinct_id"] == manager.run_id
# Test with default properties (None)
mock_capture.reset_mock()
manager.send_event("simple_event")
assert mock_capture.called
assert mock_capture.call_args[1]["properties"] == {}
def test_training_context(manager):
"""Test training context manager"""
config = {"model": "llama", "batch_size": 8}
def test_send_system_info(manager):
"""Test sending system info"""
with patch("posthog.capture") as mock_capture:
manager.enabled = True
with manager.track_training(config):
pass # Simulate successful training
# Should have captured training_start and training_complete
events = [call[1]["event"] for call in mock_capture.call_args_list]
assert "training_start" in events
assert "training_complete" in events
manager.send_system_info()
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "system-info"
assert mock_capture.call_args[1]["properties"] == manager.system_info
def test_training_error(manager):
"""Test training context manager with error"""
config = {"model": "llama", "batch_size": 8}
def test_sanitize_properties(manager):
"""Test property sanitization in send_event method"""
with patch("posthog.capture") as mock_capture:
manager.enabled = True
# Test with properties containing various PII
test_properties = {
"filepath": "/home/user/sensitive/data.txt",
"windows_path": "C:\\Users\\name\\Documents\\project\\file.py",
"url": "https://example.com/private/user123",
"message": "Error loading /tmp/axolotl/data.csv - check permissions",
"cloud_path": "s3://my-bucket/data/user-files/",
"nested": {
"deep_path": "/var/log/axolotl/training.log",
"list_paths": ["/home/user1/file1.txt", "/home/user2/file2.txt"],
},
}
with pytest.raises(ValueError):
with manager.track_training(config):
raise ValueError("Test error")
manager.send_event("test_event", test_properties)
# Should have captured training_start and training_error
events = [call[1]["event"] for call in mock_capture.call_args_list]
assert "training_start" in events
assert "training_error" in events
# Verify the call was made
assert mock_capture.called
# Get the sanitized properties that were sent
sanitized = mock_capture.call_args[1]["properties"]
# Check that PII was removed/sanitized
assert "/home/user/sensitive" not in str(sanitized)
assert "C:\\Users\\name" not in str(sanitized)
assert "https://example.com/private" not in str(sanitized)
assert "s3://my-bucket" not in str(sanitized)
# Check that filenames were preserved
assert "data.txt" in str(sanitized)
assert "file.py" in str(sanitized)
assert "data.csv" in str(sanitized)
# Check that nested structures were properly sanitized
assert "/var/log/axolotl" not in str(sanitized["nested"])
assert "/home/user1" not in str(sanitized["nested"]["list_paths"])
# pylint: disable=protected-access
def test_path_sanitization(manager):
"""Test path sanitization"""
path = "/home/user/sensitive/data.txt"
sanitized = manager._sanitize_path(path)
assert sanitized == "data.txt"
assert "/home/user" not in sanitized
def test_disable_telemetry(manager):
"""Test that disabled telemetry doesn't send events"""
with patch("posthog.capture") as mock_capture:
manager.enabled = False
manager.send_event("test_event")
assert not mock_capture.called
# pylint: disable=protected-access
def test_error_sanitization(manager):
"""Test error message sanitization"""
error = "Failed to load /home/user/sensitive/data.txt: File not found"
sanitized = manager._sanitize_error(error)
assert "sensitive" not in sanitized
assert "/home/user" not in sanitized
def test_exception_handling_during_send(manager):
"""Test that exceptions in PostHog are handled gracefully"""
with patch("posthog.capture", side_effect=Exception("Test error")), patch(
"logging.Logger.warning"
) as mock_warning:
manager.send_event("test_event")
warning_logged = False
for call in mock_warning.call_args_list:
if "Failed to send telemetry event" in str(call):
warning_logged = True
break
assert warning_logged
def test_shutdown(manager):
"""Test shutdown behavior"""
with patch("posthog.flush") as mock_flush:
manager.enabled = True
manager.shutdown()
assert mock_flush.called