diff --git a/requirements.txt b/requirements.txt index 3af94421d..e65781896 100644 --- a/requirements.txt +++ b/requirements.txt @@ -67,3 +67,5 @@ schedulefree==1.4.1 axolotl-contribs-lgpl==0.0.6 axolotl-contribs-mit==0.0.3 + +posthog-3.13.0 diff --git a/src/axolotl/telemetry.py b/src/axolotl/telemetry.py new file mode 100644 index 000000000..c031a2c7b --- /dev/null +++ b/src/axolotl/telemetry.py @@ -0,0 +1,206 @@ +"""Telemetry manager and associated utilities.""" + +import logging +import os +import platform +import threading +import uuid +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from queue import Empty, Full, Queue +from typing import Any + +import posthog +import psutil +import torch + +logger = logging.getLogger(__name__) + + +@dataclass +class TelemetryConfig: + """Configuration for telemetry system""" + + enabled: bool + project_api_key: str + host: str = "https://app.posthog.com" + queue_size: int = 100 + batch_size: int = 10 + whitelist_path: str = "telemetry_whitelist.yaml" + + +class TelemetryManager: + """Manages telemetry collection and transmission""" + + def __init__(self, config: TelemetryConfig): + """ + Telemetry manager constructor. + + Args: + config: Telemetry configuration object. + """ + self.config = config + self.enabled = self._check_telemetry_enabled() + self.run_id = str(uuid.uuid4()) + self.event_queue: Queue = Queue(maxsize=config.queue_size) + + if self.enabled: + self._init_posthog() + self._start_worker() + + def _check_telemetry_enabled(self) -> bool: + """Check if telemetry is enabled based on environment variables""" + if not self.config.enabled: + return False + + do_not_track = os.getenv("DO_NOT_TRACK", "0").lower() in ("1", "true") + axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK", "0").lower() in ( + "1", + "true", + ) + + return not (do_not_track or axolotl_do_not_track) + + def _init_posthog(self): + """Initialize PostHog client""" + posthog.project_api_key = self.config.project_api_key + posthog.host = self.config.host + + def _start_worker(self): + """Start background worker thread for processing events""" + self.worker_thread = threading.Thread(target=self._process_queue, daemon=True) + self.worker_thread.start() + + def _process_queue(self): + """Process events from queue and send to PostHog""" + while True: + events = [] + # Always get at least one event (blocking) + events.append(self.event_queue.get()) + + # Try to get more events up to batch size (non-blocking) + remaining_batch = self.config.batch_size - 1 + for _ in range(remaining_batch): + try: + event = self.event_queue.get_nowait() + events.append(event) + except Empty: + # No more events available right now + break + + if events: + try: + posthog.capture_batch(events) + except (posthog.RequestError, posthog.RateLimitError) as e: + logger.warning(f"Failed to send telemetry batch: {e}") + except ConnectionError as e: + logger.warning(f"Network error while sending telemetry: {e}") + finally: + # Mark tasks as done even if sending failed + for _ in range(len(events)): + self.event_queue.task_done() + + def _sanitize_path(self, path: str) -> str: + """Remove personal information from file paths""" + return Path(path).name + + def _sanitize_error(self, error: str) -> str: + """Remove personal information from error messages""" + # Replace file paths with just filename + sanitized = error + try: + for path in Path(error).parents: + sanitized = sanitized.replace(str(path), "") + except (ValueError, RuntimeError) as e: + # ValueError: Invalid path format + # RuntimeError: Other path parsing errors + logger.debug(f"Could not parse path in error message: {e}") + + return sanitized + + def _get_system_info(self) -> dict[str, Any]: + """Collect system information""" + gpu_info = [] + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + gpu_info.append( + { + "name": torch.cuda.get_device_name(i), + "memory": torch.cuda.get_device_properties(i).total_memory, + } + ) + + return { + "os": platform.system(), + "python_version": platform.python_version(), + "cpu_count": psutil.cpu_count(), + "memory_total": psutil.virtual_memory().total, + "gpu_count": len(gpu_info), + "gpu_info": gpu_info, + } + + def track_event(self, event_type: str, properties: dict[str, Any]): + """Track a telemetry event""" + if not self.enabled: + return + + try: + # Get system info first - most likely source of errors + system_info = self._get_system_info() + + # Construct event dict - could raise TypeError if properties aren't serializable + event = { + "event": event_type, + "properties": { + "run_id": self.run_id, + "system_info": system_info, + **properties, + }, + } + + try: + self.event_queue.put_nowait(event) + except Full: + logger.warning("Telemetry queue full, dropping event") + except (RuntimeError, OSError) as e: + # Hardware info collection errors + logger.warning(f"Failed to collect system info for telemetry: {e}") + except TypeError as e: + # Dict construction/serialization errors + logger.warning(f"Invalid property type in telemetry event: {e}") + except AttributeError as e: + # Missing attributes when collecting system info + logger.warning(f"Failed to access system attribute for telemetry: {e}") + + @contextmanager + def track_training(self, config: dict[str, Any]): + """Context manager to track training run""" + if not self.enabled: + yield + return + + # Track training start + sanitized_config = { + k: v + for k, v in config.items() + if not any(p in k.lower() for p in ["path", "dir", "file"]) + } + + self.track_event("training_start", {"config": sanitized_config}) + + try: + yield + # Track successful completion + self.track_event("training_complete", {}) + + except Exception as e: + # Track error + self.track_event("training_error", {"error": self._sanitize_error(str(e))}) + raise + + +def init_telemetry(project_api_key: str, enabled: bool = True) -> TelemetryManager: + """Initialize telemetry system""" + config = TelemetryConfig(enabled=enabled, project_api_key=project_api_key) + return TelemetryManager(config)