progress on telemetry: config load, process, model load, train start / end, error tracking

This commit is contained in:
Dan Saunders
2025-02-19 22:05:12 +00:00
parent 90b39ce112
commit 66c6fb56cb
12 changed files with 227 additions and 217 deletions

6
docs/telemetry.qmd Normal file
View File

@@ -0,0 +1,6 @@
---
title: Telemetry
description: A description of the opt-out telemetry implementation in Axolotl.
---
TODO.

View File

@@ -14,6 +14,8 @@ import yaml
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.telemetry import TelemetryManager
from axolotl.telemetry.manager import track_errors
from axolotl.utils.comet_ import setup_comet_env_vars from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import ( from axolotl.utils.config import (
normalize_cfg_datasets, normalize_cfg_datasets,
@@ -28,6 +30,8 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__, use_environ=True) LOG = get_logger(__name__, use_environ=True)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
""" """
@@ -159,6 +163,7 @@ def plugin_set_cfg(cfg: DictDefault):
plugin_manager.cfg = cfg plugin_manager.cfg = cfg
@track_errors
def load_cfg( def load_cfg(
config: str | Path | DictDefault = Path("examples/"), **kwargs config: str | Path | DictDefault = Path("examples/"), **kwargs
) -> DictDefault: ) -> DictDefault:
@@ -192,6 +197,8 @@ def load_cfg(
temp_file.close() temp_file.close()
cfg.axolotl_config_path = temp_file.name cfg.axolotl_config_path = temp_file.name
TELEMETRY_MANAGER.track_event(event_type="config-loaded", properties=cfg)
# If there are any options passed in the cli, if it is something that seems valid # If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value # from the yaml, then overwrite the value
cfg_keys = cfg.keys() cfg_keys = cfg.keys()
@@ -233,4 +240,6 @@ def load_cfg(
setup_comet_env_vars(cfg) setup_comet_env_vars(cfg)
plugin_set_cfg(cfg) plugin_set_cfg(cfg)
TELEMETRY_MANAGER.track_event(event_type="config-processed", properties=cfg)
return cfg return cfg

View File

@@ -162,6 +162,7 @@ def load_lora(
return model, lora_config return model, lora_config
@send_errors
def load_adapter( def load_adapter(
model: PreTrainedModel, model: PreTrainedModel,
cfg: DictDefault, cfg: DictDefault,

View File

@@ -145,6 +145,7 @@ class ModelLoader:
"""Property that determines if FSDP with QLoRA is enabled.""" """Property that determines if FSDP with QLoRA is enabled."""
return self.cfg.fsdp and self.cfg.adapter == "qlora" return self.cfg.fsdp and self.cfg.adapter == "qlora"
@send_errors
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]: def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
"""Load and prepare the model with all configurations and patches. """Load and prepare the model with all configurations and patches.

View File

@@ -14,6 +14,7 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__) LOG = get_logger(__name__)
@send_errors
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
processor_kwargs: dict[str, Any] = {} # Do we actually need this? processor_kwargs: dict[str, Any] = {} # Do we actually need this?

View File

@@ -117,6 +117,7 @@ def modify_tokenizer_files(
return tokenizer_dir return tokenizer_dir
@send_errors
def load_tokenizer(cfg): def load_tokenizer(cfg):
"""Load and configure the tokenizer based on the provided config.""" """Load and configure the tokenizer based on the provided config."""
model_config = load_model_config(cfg) model_config = load_model_config(cfg)

View File

@@ -1,15 +1,8 @@
"""Init for axolotl.telemetry module.""" """Init for axolotl.telemetry module."""
from .manager import ( from .manager import TelemetryConfig, TelemetryManager
ModelConfig,
TelemetryConfig,
TelemetryManager,
init_telemetry_manager,
)
__all__ = [ __all__ = [
"TelemetryConfig", "TelemetryConfig",
"TelemetryManager", "TelemetryManager",
"ModelConfig",
"init_telemetry_manager",
] ]

View File

@@ -1,70 +1,46 @@
"""Telemetry manager and associated utilities.""" """Telemetry manager and associated utilities."""
import atexit
import logging import logging
import os import os
import platform import platform
import time
import traceback
import uuid import uuid
from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from functools import wraps
from inspect import getmodule
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Callable
import posthog import posthog
import psutil import psutil
import torch import torch
import transformers
import yaml import yaml
logger = logging.getLogger(__name__) import axolotl
from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger(__name__)
POSTHOG_WRITE_KEY = "phc_RbAa7Bxu6TLIN9xd8gbg1PLemrStaymi8pxQbRbIwfC" POSTHOG_WRITE_KEY = "phc_RbAa7Bxu6TLIN9xd8gbg1PLemrStaymi8pxQbRbIwfC"
ENABLED_WARNING_SLEEP_SECONDS = 10
ENABLED_WARNING = (
@dataclass "\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n"
class ModelConfig: "- Which models and configurations are most commonly used\n"
"""Tracked model configuration details""" "- What hardware setups need to be supported\n"
"- Where users encounter errors\n\n"
base_model: str "This data helps us prioritize features, optimize performance, and fix bugs.\n\n"
model_type: str "To disable telemetry, set either:\n"
hidden_size: int "- AXOLOTL_DO_NOT_TRACK=1 (Axolotl-specific)\n"
num_layers: int "- DO_NOT_TRACK=1 (Global standard)\n\n"
num_attention_heads: int "To remove this warning and continue with telemetry enabled,"
tokenizer_config: dict "explicitly set AXOLOTL_DO_NOT_TRACK=0 (and leave DO_NOT_TRACK unset / set to 0)\n\n"
flash_attention: bool "No personally identifiable information is collected."
quantization_config: dict | None "For details, see: https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html\n\n"
training_approach: str # 'lora', 'qlora', 'full_finetune' f"Sleeping for {ENABLED_WARNING_SLEEP_SECONDS}s..."
)
@classmethod
def from_config(cls, config: dict) -> "ModelConfig":
"""Create from Axolotl config dict"""
return cls(
base_model=config.get("base_model", ""),
model_type=config.get("model_type", ""),
hidden_size=config.get("hidden_size", 0),
num_layers=config.get("num_layers", 0),
num_attention_heads=config.get("num_attention_heads", 0),
tokenizer_config=config.get("tokenizer", {}),
flash_attention=config.get("flash_attention", False),
quantization_config=config.get("quantization", None),
training_approach=config.get("training_approach", ""),
)
def to_dict(self) -> dict:
"""Convert to PostHog-compatible dict"""
return {
"base_model": self.base_model,
"model_type": self.model_type,
"architecture": {
"hidden_size": self.hidden_size,
"num_layers": self.num_layers,
"num_attention_heads": self.num_attention_heads,
},
"optimizations": {
"flash_attention": self.flash_attention,
"quantization": self.quantization_config is not None,
"quantization_config": self.quantization_config,
},
"training_approach": self.training_approach,
}
@dataclass @dataclass
@@ -76,43 +52,95 @@ class TelemetryConfig:
batch_size: int = 10 batch_size: int = 10
whitelist_path: str = str(Path(__file__).parent / "whitelist.yaml") whitelist_path: str = str(Path(__file__).parent / "whitelist.yaml")
retention_days: int = 365 retention_days: int = 365
distinct_id: str = str(uuid.uuid4())
schema_version: str = "0.1.0"
class TelemetryManager: class TelemetryManager:
"""Manages telemetry collection and transmission""" """Manages telemetry collection and transmission"""
def __init__(self, config: TelemetryConfig): _instance = None
""" _initialized = False
Telemetry manager constructor.
Args: def __new__(cls):
config: Telemetry configuration object.
""" """
self.config = config Telemetry manager constructor. Creates the singleton instance of this class if
self.run_id = str(uuid.uuid4()) it doesn't already exist.
self.enabled = self._check_telemetry_enabled() """
if cls._instance is None:
cls._instance = super(TelemetryManager, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
"""Telemetry manager initializer"""
if self._initialized:
return
self.enabled, self.explicit_enable = self._check_telemetry_enabled()
if self.enabled: if self.enabled:
# Warn about telemetry collection
if not self.explicit_enable:
LOG.warning(ENABLED_WARNING)
time.sleep(ENABLED_WARNING_SLEEP_SECONDS)
self.config = TelemetryConfig()
self.run_id = str(uuid.uuid4())
self.whitelist = self._load_whitelist() self.whitelist = self._load_whitelist()
self.system_info = self._get_system_info()
self._init_posthog() self._init_posthog()
def _check_telemetry_enabled(self) -> bool: # Register shutdown method to flush posthog telemetry
atexit.register(self.shutdown)
self._initialized = True
@classmethod
def get_instance(cls) -> "TelemetryManager":
if cls._instance is None:
cls._instance = TelemetryManager()
return cls._instance
def _check_telemetry_enabled(self) -> tuple[bool, bool]:
    """
    Check if telemetry is enabled based on environment variables. We also check
    whether this is the main process (for the distributed setting and to avoid
    sending duplicate PostHog events per GPU).

    Note: This is enabled by default on an opt-out basis. Set either
    `AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1` to disable telemetry. For more
    details, see https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.

    Returns:
        Tuple containing:
            - Boolean denoting whether telemetry is enabled or disabled.
            - Boolean denoting whether telemetry is explicitly enabled or not,
              i.e. whether `AXOLOTL_DO_NOT_TRACK` was actually set to "0" or
              "false" (case-insensitive) rather than merely left unset.
    """
    # In the distributed setting, check whether we're running on rank 0
    if not is_main_process():
        return False, False

    # Read the raw values; `None` means "not set by the user"
    axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
    do_not_track = os.getenv("DO_NOT_TRACK")

    # Decide on explicit opt-in from the raw value, *before* any default is
    # applied. Filling in "0" first would make an unset variable look like an
    # explicit opt-in and suppress the telemetry warning for every default user.
    explicit_enabled = (
        axolotl_do_not_track is not None
        and axolotl_do_not_track.lower() in ("0", "false")
    )

    # Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if either requests opt-out;
    # unset variables fall back to the opt-out default of "0"
    enabled = all(
        (value or "0").lower() not in ("1", "true")
        for value in (axolotl_do_not_track, do_not_track)
    )

    return enabled, explicit_enabled
def _load_whitelist(self) -> dict: def _load_whitelist(self) -> dict:
"""Load organization/model whitelist""" """Load organization/model whitelist"""
@@ -146,9 +174,7 @@ class TelemetryManager:
for path in Path(error).parents: for path in Path(error).parents:
sanitized = sanitized.replace(str(path), "") sanitized = sanitized.replace(str(path), "")
except (ValueError, RuntimeError) as e: except (ValueError, RuntimeError) as e:
# ValueError: Invalid path format LOG.debug(f"Could not parse path in error message: {e}")
# RuntimeError: Other path parsing errors
logger.debug(f"Could not parse path in error message: {e}")
return sanitized return sanitized
@@ -167,95 +193,38 @@ class TelemetryManager:
return { return {
"os": platform.system(), "os": platform.system(),
"python_version": platform.python_version(), "python_version": platform.python_version(),
"pytorch_version": torch.__version__,
"transformers_version": transformers.__version__,
"axolotl_version": axolotl.__version__,
"cpu_count": psutil.cpu_count(), "cpu_count": psutil.cpu_count(),
"memory_total": psutil.virtual_memory().total, "memory_total": psutil.virtual_memory().total,
"gpu_count": len(gpu_info), "gpu_count": len(gpu_info),
"gpu_info": gpu_info, "gpu_info": gpu_info,
} }
def track_event(self, event_type: str, properties: dict[str, Any]): def track_event(self, event_type: str, properties: dict[str, Any] | None = None):
"""Track a telemetry event""" """Track a telemetry event"""
if not self.enabled: if not self.enabled:
return return
if properties is None:
properties = {}
# Wrap PostHog errors in try / except to not raise errors during Axolotl usage
try: try:
# Get system info first - most likely source of errors LOG.warning(f"*** Sending telemetry for {event_type} ***")
system_info = self._get_system_info()
# Send event via PostHog # Send event via PostHog
try: posthog.capture(
posthog.capture( distinct_id=self.run_id,
distinct_id=self.config.distinct_id, event=event_type,
event=event_type, properties={
properties={ "system_info": self.system_info,
"run_id": self.run_id, **properties,
"system_info": system_info, },
**properties, )
}, except Exception as e: # pylint: disable=broad-exception-caught
) LOG.warning(f"Failed to send telemetry event: {e}")
except Exception as e:
logger.warning(f"Failed to send telemetry event: {e}")
except (RuntimeError, OSError) as e:
logger.warning(f"Failed to collect system info for telemetry: {e}")
except TypeError as e:
logger.warning(f"Invalid property type in telemetry event: {e}")
except AttributeError as e:
logger.warning(f"Failed to access system attribute for telemetry: {e}")
@contextmanager
def track_training(self, config: dict[str, Any]):
"""Context manager to track training run"""
if not self.enabled:
yield
return
# Track training start
sanitized_config = {
k: v
for k, v in config.items()
if not any(p in k.lower() for p in ["path", "dir", "file"])
}
self.track_event("training_start", {"config": sanitized_config})
try:
yield
# Track successful completion
self.track_event("training_complete", {})
except Exception as e:
# Track error
self.track_event("training_error", {"error": self._sanitize_error(str(e))})
raise
def track_model_load(self, model_config: ModelConfig):
"""Track model loading and configuration"""
if not self.enabled or not self._is_whitelisted(model_config.base_model):
return
self.track_event(
"model_load",
{
"model_config": model_config.to_dict(),
"system_info": self._get_system_info(),
},
)
def track_training_metrics(self, metrics: dict):
"""Track training progress metrics"""
if not self.enabled:
return
self.track_event(
"training_metrics",
{
"duration": metrics.get("duration"),
"peak_memory": metrics.get("peak_memory"),
"steps_completed": metrics.get("steps_completed"),
"current_loss": metrics.get("loss"),
"learning_rate": metrics.get("learning_rate"),
},
)
def shutdown(self): def shutdown(self):
"""Ensure all queued events are processed before shutdown""" """Ensure all queued events are processed before shutdown"""
@@ -263,6 +232,49 @@ class TelemetryManager:
posthog.flush() posthog.flush()
def init_telemetry_manager() -> TelemetryManager: ERROR_HANDLED = False
"""Initialize telemetry system"""
return TelemetryManager(TelemetryConfig())
def _report_error(telemetry_manager, func: Callable, exception: Exception) -> None:
    """Send one `"<module>.<function>-error"` telemetry event for `exception`."""
    # Get function module path
    module = getmodule(func)
    module_path = f"{module.__name__}.{func.__name__}" if module else func.__name__

    # Get stack trace
    stack_trace = "".join(
        traceback.format_exception(type(exception), exception, exception.__traceback__)
    )

    # Send error telemetry
    telemetry_manager.track_event(
        event_type=f"{module_path}-error",
        properties={
            "exception": str(exception),
            "stack_trace": stack_trace,
        },
    )


def track_errors(func: Callable) -> Callable:
    """Decorator to track errors in a function.

    If telemetry is disabled, the wrapped function is called directly. Otherwise,
    an `Exception` escaping `func` is reported as a telemetry event (unless the
    module-level `ERROR_HANDLED` flag is already set) and then re-raised unchanged.
    Reporting is best-effort: a failure while building or sending the event is
    logged at debug level and never replaces the original exception.

    Args:
        func: Callable to wrap.

    Returns:
        Wrapper that forwards all arguments to `func` and returns its result.
    """

    @wraps(func)
    def wrapper(*args, **kwargs) -> Any:
        telemetry_manager = TelemetryManager.get_instance()

        if not telemetry_manager.enabled:
            return func(*args, **kwargs)

        try:
            return func(*args, **kwargs)
        except Exception as exception:
            # Only track if we're not already handling an error. This prevents us from
            # capturing an error more than once in nested decorated function calls.
            global ERROR_HANDLED  # pylint: disable=global-statement
            if not ERROR_HANDLED:
                ERROR_HANDLED = True

                # Building the event touches `func.__name__`, `str(exception)` and
                # the traceback formatter, any of which can raise for unusual
                # callables / exceptions. Swallow that here so the caller always
                # sees the original error rather than a telemetry failure.
                try:
                    _report_error(telemetry_manager, func, exception)
                except Exception as err:  # pylint: disable=broad-exception-caught
                    LOG.debug(f"Failed to report error telemetry: {err}")

            # Bare raise re-raises the exception from `func`, not the inner one
            raise

    return wrapper

View File

@@ -3,6 +3,7 @@ organizations:
- "huggingface" - "huggingface"
- "nvidia" - "nvidia"
- "facebook" - "facebook"
- "mistralai"
- "briaai" - "briaai"
- "unsloth" - "unsloth"
- "NousResearch" - "NousResearch"

View File

@@ -10,6 +10,7 @@ from contextlib import ExitStack
from pathlib import Path from pathlib import Path
from typing import Any, Dict from typing import Any, Dict
from axolotl.telemetry.manager import track_errors
import torch import torch
import transformers.modelcard import transformers.modelcard
from accelerate.utils import save_fsdp_model from accelerate.utils import save_fsdp_model
@@ -47,6 +48,7 @@ except ImportError:
LOG = get_logger(__name__) LOG = get_logger(__name__)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
def setup_model_and_tokenizer( def setup_model_and_tokenizer(
cfg: DictDefault, cfg: DictDefault,
@@ -64,7 +66,10 @@ def setup_model_and_tokenizer(
`None`), and processor (if multimodal, else `None`). `None`), and processor (if multimodal, else `None`).
""" """
# Load tokenizer # Load tokenizer
LOG.debug(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") LOG.debug(
f"loading tokenizer... {cfg.tokenizer_cocnfig or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
# Load processor for multimodal models if needed # Load processor for multimodal models if needed
@@ -83,6 +88,14 @@ def setup_model_and_tokenizer(
if model.generation_config is not None: if model.generation_config is not None:
model.generation_config.do_sample = True model.generation_config.do_sample = True
TELEMETRY_MANAGER.track_event(
event_type="model-load", properties=model.config.to_dict()
)
if peft_config:
TELEMETRY_MANAGER.track_event(
event_type="peft-config-load", properties=peft_config.to_dict()
)
# Apply freezing if specified # Apply freezing if specified
if cfg.unfrozen_parameters: if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters) freeze_layers_except(model, cfg.unfrozen_parameters)
@@ -527,6 +540,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
) )
@track_errors
def train( def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]: ) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
@@ -565,10 +579,12 @@ def train(
save_initial_configs(cfg, tokenizer, model, peft_config, processor) save_initial_configs(cfg, tokenizer, model, peft_config, processor)
setup_signal_handler(cfg, model, safe_serialization) setup_signal_handler(cfg, model, safe_serialization)
setup_model_card(cfg) setup_model_card(cfg)
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Execute the training # Execute the training
resume_from_checkpoint = determine_resume_checkpoint(cfg) TELEMETRY_MANAGER.track_event(event_type="train-start")
execute_training(cfg, trainer, resume_from_checkpoint) execute_training(cfg, trainer, resume_from_checkpoint)
TELEMETRY_MANAGER.track_event(event_type="train-end")
# Save the trained model and cleanup # Save the trained model and cleanup
save_trained_model(cfg, trainer, model, safe_serialization) save_trained_model(cfg, trainer, model, safe_serialization)

View File

@@ -1259,7 +1259,7 @@ class AxolotlInputConfig(
class AxolotlConfigWCapabilities(AxolotlInputConfig): class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options""" """Wrapper to validate GPU capabilities with the config options"""
capabilities: GPUCapabilities capabilities: GPUCapabilities
env_capabilities: EnvCapabilities env_capabilities: EnvCapabilities

View File

@@ -1,71 +1,65 @@
"""Tests for TelemetryManager class and utilities"""
# pylint: disable=redefined-outer-name
import os import os
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import yaml import yaml
from axolotl.telemetry import ModelConfig, TelemetryConfig, TelemetryManager from axolotl.telemetry import TelemetryManager
@pytest.fixture @pytest.fixture
def mock_whitelist(tmp_path): def mock_whitelist(tmp_path):
"""Create a temporary whitelist file for testing""" """Create a temporary whitelist file for testing"""
whitelist_content = { whitelist_content = {
"organizations": ["meta", "mistral"], "organizations": ["meta-llama", "mistralai"],
"models": ["llama", "mistral-7b"],
} }
whitelist_file = tmp_path / "whitelist.yaml" whitelist_file = tmp_path / "whitelist.yaml"
with open(whitelist_file, "w") as f: with open(whitelist_file, "w", encoding="utf-8") as f:
yaml.dump(whitelist_content, f) yaml.dump(whitelist_content, f)
return str(whitelist_file) return str(whitelist_file)
@pytest.fixture @pytest.fixture
def config(mock_whitelist): def manager():
"""Create a TelemetryConfig with test settings"""
return TelemetryConfig(
host="https://test.posthog.com", whitelist_path=mock_whitelist
)
@pytest.fixture
def manager(config):
"""Create a TelemetryManager instance with mocked PostHog""" """Create a TelemetryManager instance with mocked PostHog"""
with patch("posthog.capture"): with patch("posthog.capture"):
return TelemetryManager(config) return TelemetryManager()
def test_telemetry_disabled_by_default(): def test_telemetry_disabled_by_default():
"""Test that telemetry is disabled by default""" """Test that telemetry is disabled by default"""
manager = TelemetryManager(TelemetryConfig()) manager = TelemetryManager()
assert not manager.enabled assert not manager.enabled
def test_telemetry_opt_in(): def test_telemetry_opt_in():
"""Test that telemetry can be enabled via environment variable""" """Test that telemetry can be enabled via environment variable"""
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}): with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
manager = TelemetryManager(TelemetryConfig()) manager = TelemetryManager()
assert manager.enabled assert manager.enabled
def test_do_not_track_override(): def test_do_not_track_override():
"""Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY""" """Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY"""
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}): with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}):
manager = TelemetryManager(TelemetryConfig()) manager = TelemetryManager()
assert not manager.enabled assert not manager.enabled
# pylint: disable=protected-access
def test_whitelist_checking(manager): def test_whitelist_checking(manager):
"""Test model whitelist functionality""" """Test model whitelist functionality"""
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}): # Should match organization
# Should match organization assert manager._is_whitelisted("meta-llama/llama-7b")
assert manager._is_whitelisted("meta/llama-7b") # Should match model name
# Should match model name assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
assert manager._is_whitelisted("mistral-7b-instruct") # Should not match
# Should not match assert not manager._is_whitelisted("unknown/model")
assert not manager._is_whitelisted("unknown/model") # Should handle case insensitively
# Should handle case insensitively assert manager._is_whitelisted("meta/Llama-7b")
assert manager._is_whitelisted("meta/Llama-7b")
def test_event_tracking(manager): def test_event_tracking(manager):
@@ -81,33 +75,6 @@ def test_event_tracking(manager):
assert "system_info" in mock_capture.call_args[1]["properties"] assert "system_info" in mock_capture.call_args[1]["properties"]
def test_model_tracking(manager):
"""Test model load tracking"""
with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
model_config = ModelConfig(
base_model="meta/llama-7b",
model_type="decoder",
hidden_size=4096,
num_layers=32,
num_attention_heads=32,
tokenizer_config={},
flash_attention=True,
quantization_config=None,
training_approach="lora",
)
with patch("posthog.capture") as mock_capture:
manager.enabled = True
manager.track_model_load(model_config)
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "model_load"
assert (
mock_capture.call_args[1]["properties"]["model_config"]
== model_config.to_dict()
)
def test_training_context(manager): def test_training_context(manager):
"""Test training context manager""" """Test training context manager"""
config = {"model": "llama", "batch_size": 8} config = {"model": "llama", "batch_size": 8}
@@ -141,6 +108,7 @@ def test_training_error(manager):
assert "training_error" in events assert "training_error" in events
# pylint: disable=protected-access
def test_path_sanitization(manager): def test_path_sanitization(manager):
"""Test path sanitization""" """Test path sanitization"""
path = "/home/user/sensitive/data.txt" path = "/home/user/sensitive/data.txt"
@@ -149,6 +117,7 @@ def test_path_sanitization(manager):
assert "/home/user" not in sanitized assert "/home/user" not in sanitized
# pylint: disable=protected-access
def test_error_sanitization(manager): def test_error_sanitization(manager):
"""Test error message sanitization""" """Test error message sanitization"""
error = "Failed to load /home/user/sensitive/data.txt: File not found" error = "Failed to load /home/user/sensitive/data.txt: File not found"