From 5afab46cc624970d1618f9a6d9b7d032333754a6 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 19 Feb 2025 13:54:51 +0000 Subject: [PATCH] updates --- src/axolotl/telemetry.py | 206 -------------------- src/axolotl/telemetry/__init__.py | 5 + src/axolotl/telemetry/manager.py | 269 +++++++++++++++++++++++++++ src/axolotl/telemetry/whitelist.yaml | 11 ++ tests/telemetry/test_manager.py | 166 +++++++++++++++++ 5 files changed, 451 insertions(+), 206 deletions(-) delete mode 100644 src/axolotl/telemetry.py create mode 100644 src/axolotl/telemetry/__init__.py create mode 100644 src/axolotl/telemetry/manager.py create mode 100644 src/axolotl/telemetry/whitelist.yaml create mode 100644 tests/telemetry/test_manager.py diff --git a/src/axolotl/telemetry.py b/src/axolotl/telemetry.py deleted file mode 100644 index 4c96cbfe4..000000000 --- a/src/axolotl/telemetry.py +++ /dev/null @@ -1,206 +0,0 @@ -"""Telemetry manager and associated utilities.""" - -import logging -import os -import platform -import threading -import uuid -from contextlib import contextmanager -from dataclasses import dataclass -from pathlib import Path -from queue import Empty, Full, Queue -from typing import Any - -import posthog -import psutil -import torch - -logger = logging.getLogger(__name__) - - -@dataclass -class TelemetryConfig: - """Configuration for telemetry system""" - - enabled: bool - project_api_key: str - host: str = "https://app.posthog.com" # TODO: replace with self-hosted endpoint? - queue_size: int = 100 - batch_size: int = 10 - whitelist_path: str = "telemetry_whitelist.yaml" - - -class TelemetryManager: - """Manages telemetry collection and transmission""" - - def __init__(self, config: TelemetryConfig): - """ - Telemetry manager constructor. - - Args: - config: Telemetry configuration object. - """ - self.config = config - self.enabled = self._check_telemetry_enabled() - self.run_id = str(uuid.uuid4()) - self.event_queue: Queue = Queue(maxsize=config.queue_size) - - if self.enabled: - self._init_posthog() - self._start_worker() - - def _check_telemetry_enabled(self) -> bool: - """Check if telemetry is enabled based on environment variables""" - if not self.config.enabled: - return False - - do_not_track = os.getenv("DO_NOT_TRACK", "0").lower() in ("1", "true") - axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK", "0").lower() in ( - "1", - "true", - ) - - return not (do_not_track or axolotl_do_not_track) - - def _init_posthog(self): - """Initialize PostHog client""" - posthog.project_api_key = self.config.project_api_key - posthog.host = self.config.host - - def _start_worker(self): - """Start background worker thread for processing events""" - self.worker_thread = threading.Thread(target=self._process_queue, daemon=True) - self.worker_thread.start() - - def _process_queue(self): - """Process events from queue and send to PostHog""" - while True: - events = [] - # Always get at least one event (blocking) - events.append(self.event_queue.get()) - - # Try to get more events up to batch size (non-blocking) - remaining_batch = self.config.batch_size - 1 - for _ in range(remaining_batch): - try: - event = self.event_queue.get_nowait() - events.append(event) - except Empty: - # No more events available right now - break - - if events: - try: - posthog.capture_batch(events) - except (posthog.RequestError, posthog.RateLimitError) as e: - logger.warning(f"Failed to send telemetry batch: {e}") - except ConnectionError as e: - logger.warning(f"Network error while sending telemetry: {e}") - finally: - # Mark tasks as done even if sending failed - for _ in range(len(events)): - self.event_queue.task_done() - - def _sanitize_path(self, path: str) -> str: - """Remove personal information from file paths""" - return Path(path).name - - def _sanitize_error(self, error: str) -> str: - """Remove personal information from error messages""" - # Replace file paths with just filename - sanitized = error - try: - for path in Path(error).parents: - sanitized = sanitized.replace(str(path), "") - except (ValueError, RuntimeError) as e: - # ValueError: Invalid path format - # RuntimeError: Other path parsing errors - logger.debug(f"Could not parse path in error message: {e}") - - return sanitized - - def _get_system_info(self) -> dict[str, Any]: - """Collect system information""" - gpu_info = [] - if torch.cuda.is_available(): - for i in range(torch.cuda.device_count()): - gpu_info.append( - { - "name": torch.cuda.get_device_name(i), - "memory": torch.cuda.get_device_properties(i).total_memory, - } - ) - - return { - "os": platform.system(), - "python_version": platform.python_version(), - "cpu_count": psutil.cpu_count(), - "memory_total": psutil.virtual_memory().total, - "gpu_count": len(gpu_info), - "gpu_info": gpu_info, - } - - def track_event(self, event_type: str, properties: dict[str, Any]): - """Track a telemetry event""" - if not self.enabled: - return - - try: - # Get system info first - most likely source of errors - system_info = self._get_system_info() - - # Construct event dict - could raise TypeError if properties aren't serializable - event = { - "event": event_type, - "properties": { - "run_id": self.run_id, - "system_info": system_info, - **properties, - }, - } - - try: - self.event_queue.put_nowait(event) - except Full: - logger.warning("Telemetry queue full, dropping event") - except (RuntimeError, OSError) as e: - # Hardware info collection errors - logger.warning(f"Failed to collect system info for telemetry: {e}") - except TypeError as e: - # Dict construction/serialization errors - logger.warning(f"Invalid property type in telemetry event: {e}") - except AttributeError as e: - # Missing attributes when collecting system info - logger.warning(f"Failed to access system attribute for telemetry: {e}") - - @contextmanager - def track_training(self, config: dict[str, Any]): - """Context manager to track training run""" - if not self.enabled: - yield - return - - # Track training start - sanitized_config = { - k: v - for k, v in config.items() - if not any(p in k.lower() for p in ["path", "dir", "file"]) - } - - self.track_event("training_start", {"config": sanitized_config}) - - try: - yield - # Track successful completion - self.track_event("training_complete", {}) - - except Exception as e: - # Track error - self.track_event("training_error", {"error": self._sanitize_error(str(e))}) - raise - - -def init_telemetry(project_api_key: str, enabled: bool = True) -> TelemetryManager: - """Initialize telemetry system""" - config = TelemetryConfig(enabled=enabled, project_api_key=project_api_key) - return TelemetryManager(config) diff --git a/src/axolotl/telemetry/__init__.py b/src/axolotl/telemetry/__init__.py new file mode 100644 index 000000000..99edb167c --- /dev/null +++ b/src/axolotl/telemetry/__init__.py @@ -0,0 +1,5 @@ +"""Init for axolotl.telemetry module.""" + +from .manager import ModelConfig, TelemetryConfig, TelemetryManager, init_telemetry_manager + +__all__ = ["TelemetryConfig", "TelemetryManager", "ModelConfig", "init_telemetry_manager"] diff --git a/src/axolotl/telemetry/manager.py b/src/axolotl/telemetry/manager.py new file mode 100644 index 000000000..ae5f5e2aa --- /dev/null +++ b/src/axolotl/telemetry/manager.py @@ -0,0 +1,269 @@ +"""Telemetry manager and associated utilities.""" + +import logging +import os +import platform +import uuid +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import posthog +import psutil +import torch +import yaml + +logger = logging.getLogger(__name__) + +POSTHOG_WRITE_KEY = "phc_RbAa7Bxu6TLIN9xd8gbg1PLemrStaymi8pxQbRbIwfC" + + +@dataclass +class ModelConfig: + """Tracked model configuration details""" + + base_model: str + model_type: str + hidden_size: int + num_layers: int + num_attention_heads: int + tokenizer_config: dict + flash_attention: bool + quantization_config: dict | None + training_approach: str # 'lora', 'qlora', 'full_finetune' + + @classmethod + def from_config(cls, config: dict) -> "ModelConfig": + """Create from Axolotl config dict""" + return cls( + base_model=config.get("base_model", ""), + model_type=config.get("model_type", ""), + hidden_size=config.get("hidden_size", 0), + num_layers=config.get("num_layers", 0), + num_attention_heads=config.get("num_attention_heads", 0), + tokenizer_config=config.get("tokenizer", {}), + flash_attention=config.get("flash_attention", False), + quantization_config=config.get("quantization", None), + training_approach=config.get("training_approach", ""), + ) + + def to_dict(self) -> dict: + """Convert to PostHog-compatible dict""" + return { + "base_model": self.base_model, + "model_type": self.model_type, + "architecture": { + "hidden_size": self.hidden_size, + "num_layers": self.num_layers, + "num_attention_heads": self.num_attention_heads, + }, + "optimizations": { + "flash_attention": self.flash_attention, + "quantization": self.quantization_config is not None, + "quantization_config": self.quantization_config, + }, + "training_approach": self.training_approach, + } + + +@dataclass +class TelemetryConfig: + """Configuration for telemetry manager""" + + host: str = "https://app.posthog.com" + queue_size: int = 100 + batch_size: int = 10 + whitelist_path: str = str(Path(__file__).parent / "whitelist.yaml") + retention_days: int = 365 + distinct_id: str = str(uuid.uuid4()) + schema_version: str = "0.1.0" + + +class TelemetryManager: + """Manages telemetry collection and transmission""" + + def __init__(self, config: TelemetryConfig): + """ + Telemetry manager constructor. + + Args: + config: Telemetry configuration object. + """ + self.config = config + self.run_id = str(uuid.uuid4()) + self.enabled = self._check_telemetry_enabled() + + if self.enabled: + self.whitelist = self._load_whitelist() + self._init_posthog() + + def _check_telemetry_enabled(self) -> bool: + """ + Check if telemetry is enabled based on environment variables. + + Note: This is enabled on an opt-in basis. Please consider setting + `AXOLOTL_TELEMETRY=1` to send us valuable data on which models and algos you're + using so we can focus our engineering efforts! + """ + # Only enable if explicitly opted in + axolotl_telemetry = os.getenv("AXOLOTL_TELEMETRY", "0").lower() in ("1", "true") + + # Respect DO_NOT_TRACK as an override even if telemetry is enabled + do_not_track = os.getenv("DO_NOT_TRACK", "0").lower() in ("1", "true") + + return axolotl_telemetry and not do_not_track + + def _load_whitelist(self) -> dict: + """Load organization/model whitelist""" + with open(self.config.whitelist_path, encoding="utf-8") as f: + return yaml.safe_load(f) + + def _is_whitelisted(self, base_model: str) -> bool: + """Check if model/org is in whitelist""" + if not base_model: + return False + + base_model = base_model.lower() + return any( + org.lower() in base_model + for org in self.whitelist.get("organizations", []) + ) + + def _init_posthog(self): + """Initialize PostHog client""" + posthog.project_api_key = POSTHOG_WRITE_KEY + posthog.host = self.config.host + + def _sanitize_path(self, path: str) -> str: + """Remove personal information from file paths""" + return Path(path).name + + def _sanitize_error(self, error: str) -> str: + """Remove personal information from error messages""" + # Replace file paths with just filename + sanitized = error + try: + for path in Path(error).parents: + sanitized = sanitized.replace(str(path), "") + except (ValueError, RuntimeError) as e: + # ValueError: Invalid path format + # RuntimeError: Other path parsing errors + logger.debug(f"Could not parse path in error message: {e}") + + return sanitized + + def _get_system_info(self) -> dict[str, Any]: + """Collect system information""" + gpu_info = [] + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + gpu_info.append( + { + "name": torch.cuda.get_device_name(i), + "memory": torch.cuda.get_device_properties(i).total_memory, + } + ) + + return { + "os": platform.system(), + "python_version": platform.python_version(), + "cpu_count": psutil.cpu_count(), + "memory_total": psutil.virtual_memory().total, + "gpu_count": len(gpu_info), + "gpu_info": gpu_info, + } + + def track_event(self, event_type: str, properties: dict[str, Any]): + """Track a telemetry event""" + if not self.enabled: + return + + try: + # Get system info first - most likely source of errors + system_info = self._get_system_info() + + # Send event via PostHog + try: + posthog.capture( + distinct_id=self.config.distinct_id, + event=event_type, + properties={ + "run_id": self.run_id, + "system_info": system_info, + **properties, + } + ) + except Exception as e: + logger.warning(f"Failed to send telemetry event: {e}") + except (RuntimeError, OSError) as e: + logger.warning(f"Failed to collect system info for telemetry: {e}") + except TypeError as e: + logger.warning(f"Invalid property type in telemetry event: {e}") + except AttributeError as e: + logger.warning(f"Failed to access system attribute for telemetry: {e}") + + @contextmanager + def track_training(self, config: dict[str, Any]): + """Context manager to track training run""" + if not self.enabled: + yield + return + + # Track training start + sanitized_config = { + k: v + for k, v in config.items() + if not any(p in k.lower() for p in ["path", "dir", "file"]) + } + + self.track_event("training_start", {"config": sanitized_config}) + + try: + yield + # Track successful completion + self.track_event("training_complete", {}) + + except Exception as e: + # Track error + self.track_event("training_error", {"error": self._sanitize_error(str(e))}) + raise + + def track_model_load(self, model_config: ModelConfig): + """Track model loading and configuration""" + if not self.enabled or not self._is_whitelisted(model_config.base_model): + return + + self.track_event( + "model_load", + { + "model_config": model_config.to_dict(), + "system_info": self._get_system_info(), + }, + ) + + def track_training_metrics(self, metrics: dict): + """Track training progress metrics""" + if not self.enabled: + return + + self.track_event( + "training_metrics", + { + "duration": metrics.get("duration"), + "peak_memory": metrics.get("peak_memory"), + "steps_completed": metrics.get("steps_completed"), + "current_loss": metrics.get("loss"), + "learning_rate": metrics.get("learning_rate"), + }, + ) + + def shutdown(self): + """Ensure all queued events are processed before shutdown""" + if self.enabled: + posthog.flush() + + +def init_telemetry_manager() -> TelemetryManager: + """Initialize telemetry system""" + return TelemetryManager(TelemetryConfig()) \ No newline at end of file diff --git a/src/axolotl/telemetry/whitelist.yaml b/src/axolotl/telemetry/whitelist.yaml new file mode 100644 index 000000000..a0e3a5562 --- /dev/null +++ b/src/axolotl/telemetry/whitelist.yaml @@ -0,0 +1,11 @@ +organizations: + - "meta-llama" + - "huggingface" + - "nvidia" + - "facebook" + - "briaai" + - "unsloth" + - "NousResearch" + - "allenai" + - "amd" + - "tiiuae" \ No newline at end of file diff --git a/tests/telemetry/test_manager.py b/tests/telemetry/test_manager.py new file mode 100644 index 000000000..43305f278 --- /dev/null +++ b/tests/telemetry/test_manager.py @@ -0,0 +1,166 @@ +import os +from unittest.mock import patch + +import pytest +import yaml + +from axolotl.telemetry import TelemetryConfig, TelemetryManager, ModelConfig + + +@pytest.fixture +def mock_whitelist(tmp_path): + """Create a temporary whitelist file for testing""" + whitelist_content = { + "organizations": ["meta", "mistral"], + "models": ["llama", "mistral-7b"] + } + whitelist_file = tmp_path / "whitelist.yaml" + with open(whitelist_file, "w") as f: + yaml.dump(whitelist_content, f) + return str(whitelist_file) + + +@pytest.fixture +def config(mock_whitelist): + """Create a TelemetryConfig with test settings""" + return TelemetryConfig( + host="https://test.posthog.com", + whitelist_path=mock_whitelist + ) + + +@pytest.fixture +def manager(config): + """Create a TelemetryManager instance with mocked PostHog""" + with patch("posthog.capture"): + return TelemetryManager(config) + + +def test_telemetry_disabled_by_default(): + """Test that telemetry is disabled by default""" + manager = TelemetryManager(TelemetryConfig()) + assert not manager.enabled + + +def test_telemetry_opt_in(): + """Test that telemetry can be enabled via environment variable""" + with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}): + manager = TelemetryManager(TelemetryConfig()) + assert manager.enabled + + +def test_do_not_track_override(): + """Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY""" + with patch.dict(os.environ, { + "AXOLOTL_TELEMETRY": "1", + "DO_NOT_TRACK": "1" + }): + manager = TelemetryManager(TelemetryConfig()) + assert not manager.enabled + + +def test_whitelist_checking(manager): + """Test model whitelist functionality""" + with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}): + # Should match organization + assert manager._is_whitelisted("meta/llama-7b") + # Should match model name + assert manager._is_whitelisted("mistral-7b-instruct") + # Should not match + assert not manager._is_whitelisted("unknown/model") + # Should handle case insensitively + assert manager._is_whitelisted("meta/Llama-7b") + + +def test_event_tracking(manager): + """Test basic event tracking""" + with patch("posthog.capture") as mock_capture: + manager.enabled = True + manager.track_event("test_event", {"key": "value"}) + + assert mock_capture.called + assert mock_capture.call_args[1]["event"] == "test_event" + assert mock_capture.call_args[1]["properties"]["key"] == "value" + assert "run_id" in mock_capture.call_args[1]["properties"] + assert "system_info" in mock_capture.call_args[1]["properties"] + + +def test_model_tracking(manager): + """Test model load tracking""" + with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}): + model_config = ModelConfig( + base_model="meta/llama-7b", + model_type="decoder", + hidden_size=4096, + num_layers=32, + num_attention_heads=32, + tokenizer_config={}, + flash_attention=True, + quantization_config=None, + training_approach="lora" + ) + + with patch("posthog.capture") as mock_capture: + manager.enabled = True + manager.track_model_load(model_config) + + assert mock_capture.called + assert mock_capture.call_args[1]["event"] == "model_load" + assert mock_capture.call_args[1]["properties"]["model_config"] == model_config.to_dict() + + +def test_training_context(manager): + """Test training context manager""" + config = {"model": "llama", "batch_size": 8} + + with patch("posthog.capture") as mock_capture: + manager.enabled = True + + with manager.track_training(config): + pass # Simulate successful training + + # Should have captured training_start and training_complete + events = [call[1]["event"] for call in mock_capture.call_args_list] + assert "training_start" in events + assert "training_complete" in events + + +def test_training_error(manager): + """Test training context manager with error""" + config = {"model": "llama", "batch_size": 8} + + with patch("posthog.capture") as mock_capture: + manager.enabled = True + + with pytest.raises(ValueError): + with manager.track_training(config): + raise ValueError("Test error") + + # Should have captured training_start and training_error + events = [call[1]["event"] for call in mock_capture.call_args_list] + assert "training_start" in events + assert "training_error" in events + + +def test_path_sanitization(manager): + """Test path sanitization""" + path = "/home/user/sensitive/data.txt" + sanitized = manager._sanitize_path(path) + assert sanitized == "data.txt" + assert "/home/user" not in sanitized + + +def test_error_sanitization(manager): + """Test error message sanitization""" + error = "Failed to load /home/user/sensitive/data.txt: File not found" + sanitized = manager._sanitize_error(error) + assert "sensitive" not in sanitized + assert "/home/user" not in sanitized + + +def test_shutdown(manager): + """Test shutdown behavior""" + with patch("posthog.flush") as mock_flush: + manager.enabled = True + manager.shutdown() + assert mock_flush.called