diff --git a/docs/telemetry.qmd b/docs/telemetry.qmd
index e8d128668..75ab2af93 100644
--- a/docs/telemetry.qmd
+++ b/docs/telemetry.qmd
@@ -3,4 +3,44 @@ title: Telemetry
 description: A description of the opt-out telemetry implementation in Axolotl.
 ---
 
-TODO.
+# Telemetry in Axolotl
+
+Axolotl implements anonymous telemetry to help maintainers understand how the library
+is used and where users encounter issues. This data helps us prioritize features, optimize
+performance, and fix bugs.
+
+## Data Collection
+
+We collect:
+
+- **System info**: OS, Python version, PyTorch version, Transformers version, Axolotl version
+- **Hardware info**: CPU count, memory, GPU count and models
+- **Usage patterns**: Base models used (only those from a whitelist) and training configurations
+- **Error tracking**: Stack traces and error messages (sanitized to remove personal information)
+
+No personally identifiable information (PII) is collected.
+
+## Implementation
+
+Telemetry is implemented using PostHog and consists of:
+
+1. `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the telemetry system and provides methods for sending events.
+2. `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and sends sanitized stack traces.
+
+## Opt-Out Mechanism
+
+Telemetry is **enabled by default** on an opt-out basis. To disable it, set either:
+
+- `AXOLOTL_DO_NOT_TRACK=1` (Axolotl-specific)
+- `DO_NOT_TRACK=1` (global standard)
+
+To acknowledge and explicitly enable telemetry (and suppress the warning message), set:
+`AXOLOTL_DO_NOT_TRACK=0`
+
+## Privacy
+
+- Stack traces are sanitized to remove personal file paths, keeping only the Axolotl code paths
+- Each run generates a unique anonymous ID
+- Model information is only sent for models from whitelisted organizations
+  - See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
+- Telemetry is only sent from the main process to avoid duplicate events
diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py
index c3297e712..e24e3fe61 100644
--- a/src/axolotl/cli/config.py
+++ b/src/axolotl/cli/config.py
@@ -15,7 +15,7 @@ from transformers.utils import is_torch_bf16_gpu_available
 from axolotl.integrations.base import PluginManager
 from axolotl.telemetry import TelemetryManager
-from axolotl.telemetry.manager import track_errors
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.comet_ import setup_comet_env_vars
 from axolotl.utils.config import (
     normalize_cfg_datasets,
@@ -156,7 +156,7 @@ def prepare_plugins(cfg: DictDefault):
         plugin_manager.register(plugin_name)
 
 
-@track_errors
+@send_errors
 def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
     """
     Loads the `axolotl` configuration stored at `config`, validates it, and performs
@@ -177,7 +177,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
     with open(config, encoding="utf-8") as file:
         cfg: DictDefault = DictDefault(yaml.safe_load(file))
 
-    TELEMETRY_MANAGER.track_event(event_type="config-loaded", properties=cfg)
+    TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg)
 
     # If there are any options passed in the cli, if it is something that seems valid
     # from the yaml, then overwrite the value
@@ -221,6 +221,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
     setup_mlflow_env_vars(cfg)
     setup_comet_env_vars(cfg)
 
-    TELEMETRY_MANAGER.track_event(event_type="config-processed", properties=cfg)
+    TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
 
     return cfg
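A note on the opt-out behavior documented above: the enablement check itself lives in `TelemetryManager`, whose internals are mostly outside this diff. As a rough sketch of the documented semantics — only the two environment variable names come from the docs; the helper `telemetry_enabled` is hypothetical:

```python
import os


def telemetry_enabled() -> bool:
    """Hypothetical sketch of the opt-out resolution described in the docs.

    AXOLOTL_DO_NOT_TRACK is checked before the global DO_NOT_TRACK standard;
    setting either to "1" disables telemetry, while "0" is an explicit opt-in.
    """
    for var in ("AXOLOTL_DO_NOT_TRACK", "DO_NOT_TRACK"):
        value = os.environ.get(var)
        if value is not None:
            return value.strip().lower() not in ("1", "true", "yes")
    return True  # enabled by default (opt-out model)


if __name__ == "__main__":
    os.environ["AXOLOTL_DO_NOT_TRACK"] = "1"
    assert not telemetry_enabled()
```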
diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py
index e11a39bd6..38a4382f7 100644
--- a/src/axolotl/cli/inference.py
+++ b/src/axolotl/cli/inference.py
@@ -17,6 +17,7 @@ from axolotl.cli.args import InferenceCliArgs
 from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.config import load_cfg
 from axolotl.cli.utils import load_model_and_tokenizer
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.chat_templates import (
     get_chat_template,
     get_chat_template_from_config,
@@ -42,6 +43,7 @@ def get_multi_line_input() -> str:
     return instruction
 
 
+@send_errors
 def do_inference(
     *,
     cfg: DictDefault,
@@ -135,6 +137,7 @@ def do_inference(
     print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
 
 
+@send_errors
 def do_inference_gradio(
     *,
     cfg: DictDefault,
diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py
index 595eb3eab..59c954c4c 100644
--- a/src/axolotl/cli/merge_lora.py
+++ b/src/axolotl/cli/merge_lora.py
@@ -12,11 +12,13 @@ from axolotl.cli.args import TrainerCliArgs
 from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.config import load_cfg
 from axolotl.cli.utils import load_model_and_tokenizer
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.dict import DictDefault
 
 LOG = logging.getLogger(__name__)
 
 
+@send_errors
 def do_merge_lora(*, cfg: DictDefault) -> None:
     """
     Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py
index d4b36d92c..4c0175716 100644
--- a/src/axolotl/cli/merge_sharded_fsdp_weights.py
+++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py
@@ -27,6 +27,7 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
 
 from axolotl.cli.args import TrainerCliArgs
 from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.config import load_cfg
+from axolotl.telemetry.errors import send_errors
 
 LOG = logging.getLogger(__name__)
 
@@ -120,6 +121,7 @@ def _distributed_checkpoint_to_merged_weights(
     return save_path_
 
 
+@send_errors
 def merge_fsdp_weights(
     checkpoint_dir: str,
     output_path: str,
diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py
index 5585c88a7..0fbbce176 100644
--- a/src/axolotl/cli/preprocess.py
+++ b/src/axolotl/cli/preprocess.py
@@ -18,12 +18,14 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
 from axolotl.cli.config import load_cfg
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.common.datasets import load_datasets, load_preference_datasets
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.trainer import disable_datasets_caching
 
 LOG = logging.getLogger(__name__)
 
 
+@send_errors
 def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
     """
     Preprocesses dataset specified in axolotl config.
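Each of the CLI-facing functions above is wrapped with `@send_errors` rather than gaining its own try/except block. A minimal stand-in showing the contract — the decorator reports and re-raises, so callers still see the original exception (`send_errors_stub` and `do_preprocess_demo` are illustrative names, not part of the diff):

```python
from functools import wraps


def send_errors_stub(func):
    """Illustrative stand-in with the same shape as send_errors."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            # The real decorator sends a "{module}.{func}-error" telemetry
            # event here; the exception itself is never swallowed.
            print(f"[telemetry] {func.__module__}.{func.__name__}-error: {exc}")
            raise

    return wrapper


@send_errors_stub
def do_preprocess_demo():
    raise ValueError("bad dataset config")


try:
    do_preprocess_demo()
except ValueError:
    pass  # the CLI still receives the original exception
```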
diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py
index db07eb43b..024a11d31 100644
--- a/src/axolotl/common/datasets.py
+++ b/src/axolotl/common/datasets.py
@@ -10,6 +10,7 @@ from datasets import Dataset
 
 import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401
 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.dict import DictDefault
@@ -44,6 +45,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
     )
 
 
+@send_errors
 def load_datasets(
     *,
     cfg: DictDefault,
@@ -103,6 +105,7 @@ def load_datasets(
     )
 
 
+@send_errors
 def load_preference_datasets(
     *,
     cfg: DictDefault,
diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py
index 8d9ddc6ab..db10f73ea 100644
--- a/src/axolotl/evaluate.py
+++ b/src/axolotl/evaluate.py
@@ -10,6 +10,7 @@ import torch
 from accelerate.logging import get_logger
 
 from axolotl.logging_config import configure_logging
+from axolotl.telemetry.errors import send_errors
 from axolotl.train import TrainDatasetMeta
 from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.dict import DictDefault
@@ -61,6 +62,7 @@ def evaluate_dataset(
     return metrics
 
 
+@send_errors
 def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
     """
     Evaluate a model on training and validation datasets
diff --git a/src/axolotl/telemetry/errors.py b/src/axolotl/telemetry/errors.py
new file mode 100644
index 000000000..4357ded9b
--- /dev/null
+++ b/src/axolotl/telemetry/errors.py
@@ -0,0 +1,112 @@
+"""Telemetry utilities for exception and traceback information."""
+
+import logging
+import re
+import traceback
+from functools import wraps
+from inspect import getmodule
+from typing import Any, Callable
+
+from axolotl.telemetry.manager import TelemetryManager
+
+LOG = logging.getLogger(__name__)
+
+ERROR_HANDLED = False
+
+
+def sanitize_stack_trace(stack_trace: str) -> str:
+    """
+    Remove personal information from stack trace messages while keeping Axolotl code paths.
+
+    Args:
+        stack_trace: The original stack trace string.
+
+    Returns:
+        A sanitized version of the stack trace with only axolotl paths preserved.
+    """
+    # Split the stack trace into lines to process each file path separately
+    lines = stack_trace.split("\n")
+    sanitized_lines = []
+
+    # Regular expression to find file paths in the stack trace
+    path_pattern = re.compile(r'(?:File ")(.*?)(?:")')
+
+    for line in lines:
+        # Check if this line contains a file path
+        path_match = path_pattern.search(line)
+
+        if path_match:
+            full_path = path_match.group(1)
+
+            if "axolotl/" in full_path:
+                # Keep only the 'axolotl' part and onward
+                axolotl_idx = full_path.rfind("axolotl/")
+                if axolotl_idx >= 0:
+                    # Replace the original path with the sanitized one
+                    sanitized_path = full_path[axolotl_idx:]
+                    line = line.replace(full_path, sanitized_path)
+            else:
+                # For non-axolotl paths, strip the path entirely
+                line = line.replace(full_path, "")
+
+        sanitized_lines.append(line)
+
+    return "\n".join(sanitized_lines)
+
+
+def send_errors(func: Callable) -> Callable:
+    """
+    Decorator that sends telemetry for exceptions raised in the wrapped function. If
+    an exception is raised, we send the sanitized stack trace and error message.
+
+    If an error occurs in a decorated function that is called by another decorated
+    function, we'll only send telemetry corresponding to the lower-level function.
+
+    Args:
+        func: Function to decorate.
+
+    Returns:
+        Decorated function.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs) -> Any:
+        telemetry_manager = TelemetryManager.get_instance()
+        if not telemetry_manager.enabled:
+            return func(*args, **kwargs)
+
+        try:
+            return func(*args, **kwargs)
+        except Exception as exception:
+            # Only track if we're not already handling an error. This prevents us from
+            # capturing an error more than once in nested decorated function calls.
+            global ERROR_HANDLED  # pylint: disable=global-statement
+            if not ERROR_HANDLED:
+                ERROR_HANDLED = True
+
+                # Get function module path
+                module = getmodule(func)
+                module_path = (
+                    f"{module.__name__}.{func.__name__}" if module else func.__name__
+                )
+
+                # Get stack trace
+                stack_trace = "".join(
+                    traceback.format_exception(
+                        type(exception), exception, exception.__traceback__
+                    )
+                )
+                stack_trace = sanitize_stack_trace(stack_trace)
+
+                # Send error telemetry
+                telemetry_manager.send_event(
+                    event_type=f"{module_path}-error",
+                    properties={
+                        "exception": str(exception),
+                        "stack_trace": stack_trace,
+                    },
+                )
+
+            raise
+
+    return wrapper
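A quick usage sketch of the new sanitizer, assuming `axolotl` is importable; the file paths below are made up, but the behavior follows the code above — axolotl paths are trimmed to their `axolotl/`-relative form, and all other paths are replaced with an empty string:

```python
from axolotl.telemetry.errors import sanitize_stack_trace

trace = (
    "Traceback (most recent call last):\n"
    '  File "/home/jane/project/run.py", line 12, in <module>\n'
    "    main()\n"
    '  File "/home/jane/venv/lib/python3.11/site-packages/axolotl/utils/models.py", line 170, in load_tokenizer\n'
    "    model_config = load_model_config(cfg)\n"
)

print(sanitize_stack_trace(trace))
# The site-packages prefix is trimmed, leaving "axolotl/utils/models.py";
# the personal path "/home/jane/project/run.py" is blanked out.
```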
diff --git a/src/axolotl/telemetry/manager.py b/src/axolotl/telemetry/manager.py
index dc257f117..56f150c2e 100644
--- a/src/axolotl/telemetry/manager.py
+++ b/src/axolotl/telemetry/manager.py
@@ -5,13 +5,10 @@ import logging
 import os
 import platform
 import time
-import traceback
 import uuid
 from dataclasses import dataclass
-from functools import wraps
-from inspect import getmodule
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any
 
 import posthog
 import psutil
@@ -24,8 +21,8 @@ from axolotl.utils.distributed import is_main_process
 
 LOG = logging.getLogger(__name__)
 
-POSTHOG_WRITE_KEY = "phc_RbAa7Bxu6TLIN9xd8gbg1PLemrStaymi8pxQbRbIwfC"
-ENABLED_WARNING_SLEEP_SECONDS = 10
+POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
+ENABLED_WARNING_SLEEP_SECONDS = 15
 ENABLED_WARNING = (
     "\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n"
     "- Which models and configurations are most commonly used\n"
@@ -166,18 +163,6 @@ class TelemetryManager:
         """Remove personal information from file paths"""
         return Path(path).name
 
-    def _sanitize_error(self, error: str) -> str:
-        """Remove personal information from error messages"""
-        # Replace file paths with just filename
-        sanitized = error
-        try:
-            for path in Path(error).parents:
-                sanitized = sanitized.replace(str(path), "")
-        except (ValueError, RuntimeError) as e:
-            LOG.debug(f"Could not parse path in error message: {e}")
-
-        return sanitized
-
     def _get_system_info(self) -> dict[str, Any]:
         """Collect system information"""
         gpu_info = []
@@ -202,8 +187,8 @@ class TelemetryManager:
             "gpu_info": gpu_info,
         }
 
-    def track_event(self, event_type: str, properties: dict[str, Any] | None = None):
-        """Track a telemetry event"""
+    def send_event(self, event_type: str, properties: dict[str, Any] | None = None):
+        """Send a telemetry event"""
         if not self.enabled:
             return
 
@@ -218,63 +203,16 @@ class TelemetryManager:
             posthog.capture(
                 distinct_id=self.run_id,
                 event=event_type,
-                properties={
-                    "system_info": self.system_info,
-                    **properties,
-                },
+                properties=properties,
             )
         except Exception as e:  # pylint: disable=broad-exception-caught
             LOG.warning(f"Failed to send telemetry event: {e}")
 
+    def send_system_info(self):
+        """Helper method for sending system info"""
+        self.send_event(event_type="system-info", properties=self.system_info)
+
     def shutdown(self):
         """Ensure all queued events are processed before shutdown"""
         if self.enabled:
             posthog.flush()
-
-
-ERROR_HANDLED = False
-
-
-def track_errors(func: Callable) -> Callable:
-    """Decorator to track errors in a function"""
-
-    @wraps(func)
-    def wrapper(*args, **kwargs) -> Any:
-        telemetry_manager = TelemetryManager.get_instance()
-        if not telemetry_manager.enabled:
-            return func(*args, **kwargs)
-
-        try:
-            return func(*args, **kwargs)
-        except Exception as exception:
-            # Only track if we're not already handling an error. This prevents us from
-            # capturing an error more than once in nested decorated function calls.
-            global ERROR_HANDLED  # pylint: disable=global-statement
-            if not ERROR_HANDLED:
-                ERROR_HANDLED = True
-
-                # Get function module path
-                module = getmodule(func)
-                module_path = (
-                    f"{module.__name__}.{func.__name__}" if module else func.__name__
-                )
-
-                # Get stack trace
-                stack_trace = "".join(
-                    traceback.format_exception(
-                        type(exception), exception, exception.__traceback__
-                    )
-                )
-
-                # Send error telemetry
-                telemetry_manager.track_event(
-                    event_type=f"{module_path}-error",
-                    properties={
-                        "exception": str(exception),
-                        "stack_trace": stack_trace,
-                    },
-                )
-
-            raise
-
-    return wrapper
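With this change, `send_event` no longer attaches `system_info` to every event; system information is instead reported once via the new `send_system_info` helper. A sketch of the resulting call pattern (the exact call sites are not shown in this diff, and the `micro_batch_size` property is just an example):

```python
from axolotl.telemetry import TelemetryManager

manager = TelemetryManager.get_instance()  # singleton accessor used throughout
if manager.enabled:
    manager.send_system_info()  # one "system-info" event per run
    manager.send_event(
        event_type="config-loaded", properties={"micro_batch_size": 1}
    )
manager.shutdown()  # flush any queued PostHog events
```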
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 95f40da5f..9833705de 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -23,7 +23,7 @@ from axolotl.contribs.lgpl.unsloth import (  # pylint: disable = no-name-in-module
 )
 from axolotl.logging_config import configure_logging
 from axolotl.telemetry import TelemetryManager
-from axolotl.telemetry.manager import track_errors
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.freeze import freeze_layers_except
 from axolotl.utils.models import load_model, load_processor, load_tokenizer
@@ -44,7 +44,7 @@ LOG = get_logger(__name__)
 TELEMETRY_MANAGER = TelemetryManager.get_instance()
 
 
-@track_errors
+@send_errors
 def train(
     *, cfg: DictDefault, dataset_meta: TrainDatasetMeta
 ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
@@ -89,11 +89,11 @@ def train(
     if model.generation_config is not None:
         model.generation_config.do_sample = True
 
-    TELEMETRY_MANAGER.track_event(
+    TELEMETRY_MANAGER.send_event(
         event_type="model-load", properties=model.config.to_dict()
     )
     if peft_config:
-        TELEMETRY_MANAGER.track_event(
+        TELEMETRY_MANAGER.send_event(
             event_type="peft-config-load", properties=peft_config.to_dict()
         )
 
@@ -187,7 +187,7 @@ def train(
     if cfg.group_by_length:
         LOG.info("hang tight... sorting dataset for group_by_length")
 
-    TELEMETRY_MANAGER.track_event(event_type="train-start")
+    TELEMETRY_MANAGER.send_event(event_type="train-start")
 
     pretrain_hooks(cfg, trainer)
 
@@ -204,7 +204,7 @@ def train(
 
     post_train_hooks(cfg, trainer)
 
-    TELEMETRY_MANAGER.track_event(event_type="train-end")
+    TELEMETRY_MANAGER.send_event(event_type="train-end")
 
     LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
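`train` brackets the run with paired `train-start`/`train-end` events. The same pattern could be factored into a context manager; the helper below is purely illustrative and not part of the diff:

```python
from contextlib import contextmanager

from axolotl.telemetry import TelemetryManager


@contextmanager
def telemetry_span(name: str):
    """Emit paired {name}-start / {name}-end events around a block."""
    manager = TelemetryManager.get_instance()
    manager.send_event(event_type=f"{name}-start")
    try:
        yield
    finally:
        manager.send_event(event_type=f"{name}-end")
```

Usage would mirror `train`: `with telemetry_span("train"): trainer.train()`, with the `finally` clause guaranteeing the end event even on failure.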
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index d79722e67..00d70c54c 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -54,7 +54,7 @@ from axolotl.monkeypatch.multipack import (
     patch_for_multipack,
 )
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
-from axolotl.telemetry.manager import track_errors
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import get_chat_template_from_config
 from axolotl.utils.dict import DictDefault
@@ -166,7 +166,7 @@ def load_model_config(cfg):
     return model_config
 
 
-@track_errors
+@send_errors
 def load_tokenizer(cfg):
     model_config = load_model_config(cfg)
     tokenizer_kwargs = {}
@@ -320,7 +320,7 @@ def load_tokenizer(cfg):
     return tokenizer
 
 
-@track_errors
+@send_errors
 def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
     processor_kwargs: Dict[str, Any] = {}  # do we actually need this?
@@ -1195,7 +1195,7 @@ class ModelLoader:
         return self.model, lora_config
 
 
-@track_errors
+@send_errors
 def load_model(
     cfg: DictDefault,
     tokenizer: PreTrainedTokenizerBase,
@@ -1217,7 +1217,7 @@ def load_model(
     return loader.load_model()
 
 
-@track_errors
+@send_errors
 def load_adapter(model, cfg, adapter, inference=False):
     # type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
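Note that `load_model` and `load_adapter` are decorated even though their callers (`train`, `load_cfg`) are too; the module-level `ERROR_HANDLED` flag in `errors.py` is what keeps that nesting from double-reporting. A self-contained stand-in of that guard (real names replaced with demo functions):

```python
from functools import wraps

ERROR_HANDLED = False
events: list[str] = []


def send_errors_demo(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        global ERROR_HANDLED
        try:
            return func(*args, **kwargs)
        except Exception:
            if not ERROR_HANDLED:  # only the innermost decorated frame reports
                ERROR_HANDLED = True
                events.append(f"{func.__name__}-error")
            raise

    return wrapper


@send_errors_demo
def inner():
    raise RuntimeError("boom")


@send_errors_demo
def outer():
    inner()


try:
    outer()
except RuntimeError:
    pass

assert events == ["inner-error"]  # outer() did not add a second event
```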