progress on telemetry: config load/process, model load, train start/end, error tracking
docs/telemetry.qmd (new file, +6 lines)
@@ -0,0 +1,6 @@
+---
+title: Telemetry
+description: A description of the opt-out telemetry implementation in Axolotl.
+---
+
+TODO.
@@ -14,6 +14,8 @@ import yaml
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.integrations.base import PluginManager
+from axolotl.telemetry import TelemetryManager
+from axolotl.telemetry.manager import track_errors
 from axolotl.utils.comet_ import setup_comet_env_vars
 from axolotl.utils.config import (
     normalize_cfg_datasets,
@@ -27,6 +29,8 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
 
 LOG = logging.getLogger(__name__)
 
+TELEMETRY_MANAGER = TelemetryManager.get_instance()
+
 
 def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
     """
@@ -152,6 +156,7 @@ def prepare_plugins(cfg: DictDefault):
         plugin_manager.register(plugin_name)
 
 
+@track_errors
 def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
     """
     Loads the `axolotl` configuration stored at `config`, validates it, and performs
@@ -172,6 +177,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
     with open(config, encoding="utf-8") as file:
         cfg: DictDefault = DictDefault(yaml.safe_load(file))
 
+    TELEMETRY_MANAGER.track_event(event_type="config-loaded", properties=cfg)
+
     # If there are any options passed in the cli, if it is something that seems valid
     # from the yaml, then overwrite the value
     cfg_keys = cfg.keys()
@@ -214,4 +221,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
     setup_mlflow_env_vars(cfg)
     setup_comet_env_vars(cfg)
 
+    TELEMETRY_MANAGER.track_event(event_type="config-processed", properties=cfg)
+
     return cfg
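A minimal usage sketch of the two events above, assuming load_cfg keeps its usual CLI import path (this hunk's file header was lost) and that the opt-out variable is set before Axolotl is imported:

    import os

    os.environ["AXOLOTL_DO_NOT_TRACK"] = "1"  # opt out before importing axolotl

    from axolotl.cli import load_cfg  # assumed import location

    # With telemetry disabled, neither "config-loaded" nor "config-processed"
    # is sent; with it enabled (the default), both fire inside this one call.
    cfg = load_cfg("examples/config.yaml")  # hypothetical config path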
@@ -1,15 +1,8 @@
 """Init for axolotl.telemetry module."""
 
-from .manager import (
-    ModelConfig,
-    TelemetryConfig,
-    TelemetryManager,
-    init_telemetry_manager,
-)
+from .manager import TelemetryConfig, TelemetryManager
 
 __all__ = [
     "TelemetryConfig",
     "TelemetryManager",
-    "ModelConfig",
-    "init_telemetry_manager",
 ]
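A quick sketch of the slimmed-down public surface re-exported above (the accessor is defined later in this diff):

    from axolotl.telemetry import TelemetryConfig, TelemetryManager

    manager = TelemetryManager.get_instance()  # singleton accessor
    assert isinstance(manager.config, TelemetryConfig)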
@@ -1,70 +1,46 @@
 """Telemetry manager and associated utilities."""
 
+import atexit
 import logging
 import os
 import platform
+import time
+import traceback
 import uuid
-from contextlib import contextmanager
 from dataclasses import dataclass
+from functools import wraps
+from inspect import getmodule
 from pathlib import Path
-from typing import Any
+from typing import Any, Callable
 
 import posthog
 import psutil
 import torch
 import transformers
 import yaml
 
-logger = logging.getLogger(__name__)
+import axolotl
+from axolotl.utils.distributed import is_main_process
 
+LOG = logging.getLogger(__name__)
 
-@dataclass
-class ModelConfig:
-    """Tracked model configuration details"""
-
-    base_model: str
-    model_type: str
-    hidden_size: int
-    num_layers: int
-    num_attention_heads: int
-    tokenizer_config: dict
-    flash_attention: bool
-    quantization_config: dict | None
-    training_approach: str  # 'lora', 'qlora', 'full_finetune'
-
-    @classmethod
-    def from_config(cls, config: dict) -> "ModelConfig":
-        """Create from Axolotl config dict"""
-        return cls(
-            base_model=config.get("base_model", ""),
-            model_type=config.get("model_type", ""),
-            hidden_size=config.get("hidden_size", 0),
-            num_layers=config.get("num_layers", 0),
-            num_attention_heads=config.get("num_attention_heads", 0),
-            tokenizer_config=config.get("tokenizer", {}),
-            flash_attention=config.get("flash_attention", False),
-            quantization_config=config.get("quantization", None),
-            training_approach=config.get("training_approach", ""),
-        )
-
-    def to_dict(self) -> dict:
-        """Convert to PostHog-compatible dict"""
-        return {
-            "base_model": self.base_model,
-            "model_type": self.model_type,
-            "architecture": {
-                "hidden_size": self.hidden_size,
-                "num_layers": self.num_layers,
-                "num_attention_heads": self.num_attention_heads,
-            },
-            "optimizations": {
-                "flash_attention": self.flash_attention,
-                "quantization": self.quantization_config is not None,
-                "quantization_config": self.quantization_config,
-            },
-            "training_approach": self.training_approach,
-        }
+POSTHOG_WRITE_KEY = "phc_RbAa7Bxu6TLIN9xd8gbg1PLemrStaymi8pxQbRbIwfC"
+
+ENABLED_WARNING_SLEEP_SECONDS = 10
+ENABLED_WARNING = (
+    "\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n"
+    "- Which models and configurations are most commonly used\n"
+    "- What hardware setups need to be supported\n"
+    "- Where users encounter errors\n\n"
+    "This data helps us prioritize features, optimize performance, and fix bugs.\n\n"
+    "To disable telemetry, set either:\n"
+    "- AXOLOTL_DO_NOT_TRACK=1 (Axolotl-specific)\n"
+    "- DO_NOT_TRACK=1 (Global standard)\n\n"
+    "To remove this warning and continue with telemetry enabled, "
+    "explicitly set AXOLOTL_DO_NOT_TRACK=0 (and leave DO_NOT_TRACK unset / set to 0)\n\n"
+    "No personally identifiable information is collected. "
+    "For details, see: https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html\n\n"
+    f"Sleeping for {ENABLED_WARNING_SLEEP_SECONDS}s..."
+)
 
 
 @dataclass
@@ -76,43 +52,95 @@ class TelemetryConfig:
     batch_size: int = 10
     whitelist_path: str = str(Path(__file__).parent / "whitelist.yaml")
     retention_days: int = 365
     distinct_id: str = str(uuid.uuid4())
     schema_version: str = "0.1.0"
 
 
 class TelemetryManager:
     """Manages telemetry collection and transmission"""
 
-    def __init__(self, config: TelemetryConfig):
-        """
-        Telemetry manager constructor.
-
-        Args:
-            config: Telemetry configuration object.
-        """
-        self.config = config
-        self.run_id = str(uuid.uuid4())
-        self.enabled = self._check_telemetry_enabled()
+    _instance = None
+    _initialized = False
+
+    def __new__(cls):
+        """
+        Telemetry manager constructor. Creates the singleton instance of this class if
+        it doesn't already exist.
+        """
+        if cls._instance is None:
+            cls._instance = super(TelemetryManager, cls).__new__(cls)
+            cls._instance._initialized = False
+
+        return cls._instance
+
+    def __init__(self):
+        """Telemetry manager initializer"""
+        if self._initialized:
+            return
+
+        self.enabled, self.explicit_enable = self._check_telemetry_enabled()
+
+        if self.enabled:
+            # Warn about telemetry collection
+            if not self.explicit_enable:
+                LOG.warning(ENABLED_WARNING)
+                time.sleep(ENABLED_WARNING_SLEEP_SECONDS)
+
+        self.config = TelemetryConfig()
+        self.run_id = str(uuid.uuid4())
         self.whitelist = self._load_whitelist()
+        self.system_info = self._get_system_info()
         self._init_posthog()
 
-    def _check_telemetry_enabled(self) -> bool:
+        # Register shutdown method to flush posthog telemetry
+        atexit.register(self.shutdown)
+
+        self._initialized = True
+
+    @classmethod
+    def get_instance(cls) -> "TelemetryManager":
+        if cls._instance is None:
+            cls._instance = TelemetryManager()
+
+        return cls._instance
+
+    def _check_telemetry_enabled(self) -> tuple[bool, bool]:
         """
-        Check if telemetry is enabled based on environment variables.
+        Check if telemetry is enabled based on environment variables. We also check
+        whether this is the main process (for the distributed setting and to avoid
+        sending duplicate PostHog events per GPU).
 
-        Note: This is enabled on an opt-in basis. Please consider setting
-        `AXOLOTL_TELEMETRY=1` to send us valuable data on which models and algos you're
-        using so we can focus our engineering efforts!
+        Note: This is enabled by default on an opt-out basis. Set either
+        `AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1` to disable telemetry. For more
+        details, see https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
+
+        Returns:
+            Tuple containing:
+                - Boolean denoting whether telemetry is enabled or disabled.
+                - Boolean denoting whether telemetry is explicitly enabled or not.
         """
-        # Only enable if explicitly opted in
-        axolotl_telemetry = os.getenv("AXOLOTL_TELEMETRY", "0").lower() in ("1", "true")
+        # In the distributed setting, check whether we're running on rank 0
+        if not is_main_process():
+            return False, False
 
-        # Respect DO_NOT_TRACK as an override even if telemetry is enabled
-        do_not_track = os.getenv("DO_NOT_TRACK", "0").lower() in ("1", "true")
+        # Parse relevant env vars and fill opt-out default values
+        axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
+        do_not_track = os.getenv("DO_NOT_TRACK")
 
-        return axolotl_telemetry and not do_not_track
+        if axolotl_do_not_track is None:
+            axolotl_do_not_track = "0"
+
+        if do_not_track is None:
+            do_not_track = "0"
+
+        # Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
+        enabled = axolotl_do_not_track.lower() not in (
+            "1",
+            "true",
+        ) and do_not_track.lower() not in ("1", "true")
+
+        # If explicitly enabled, we'll disable the telemetry warning message
+        explicit_enabled = axolotl_do_not_track in ["0", "false"]
+
+        return enabled, explicit_enabled
 
     def _load_whitelist(self) -> dict:
        """Load organization/model whitelist"""
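A minimal sketch of the singleton plus opt-out semantics introduced above, assuming it runs on the main process:

    import os

    os.environ["DO_NOT_TRACK"] = "1"  # the global standard also disables telemetry

    from axolotl.telemetry import TelemetryManager

    a = TelemetryManager.get_instance()
    b = TelemetryManager()  # __new__ hands back the same instance

    assert a is b
    assert a.enabled is False  # either do-not-track var set to 1/true wins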
@@ -146,9 +174,7 @@ class TelemetryManager:
             for path in Path(error).parents:
                 sanitized = sanitized.replace(str(path), "")
         except (ValueError, RuntimeError) as e:
-            # ValueError: Invalid path format
-            # RuntimeError: Other path parsing errors
-            logger.debug(f"Could not parse path in error message: {e}")
+            LOG.debug(f"Could not parse path in error message: {e}")
 
         return sanitized
@@ -167,95 +193,38 @@ class TelemetryManager:
         return {
             "os": platform.system(),
             "python_version": platform.python_version(),
             "pytorch_version": torch.__version__,
             "transformers_version": transformers.__version__,
             "axolotl_version": axolotl.__version__,
             "cpu_count": psutil.cpu_count(),
             "memory_total": psutil.virtual_memory().total,
             "gpu_count": len(gpu_info),
             "gpu_info": gpu_info,
         }
 
-    def track_event(self, event_type: str, properties: dict[str, Any]):
+    def track_event(self, event_type: str, properties: dict[str, Any] | None = None):
         """Track a telemetry event"""
         if not self.enabled:
             return
 
+        if properties is None:
+            properties = {}
+
+        # Wrap PostHog errors in try / except to not raise errors during Axolotl usage
         try:
-            # Get system info first - most likely source of errors
-            system_info = self._get_system_info()
+            LOG.warning(f"*** Sending telemetry for {event_type} ***")
 
-            # Send event via PostHog
-            try:
-                posthog.capture(
-                    distinct_id=self.config.distinct_id,
-                    event=event_type,
-                    properties={
-                        "run_id": self.run_id,
-                        "system_info": system_info,
-                        **properties,
-                    },
-                )
-            except Exception as e:
-                logger.warning(f"Failed to send telemetry event: {e}")
-        except (RuntimeError, OSError) as e:
-            logger.warning(f"Failed to collect system info for telemetry: {e}")
-        except TypeError as e:
-            logger.warning(f"Invalid property type in telemetry event: {e}")
-        except AttributeError as e:
-            logger.warning(f"Failed to access system attribute for telemetry: {e}")
-
-    @contextmanager
-    def track_training(self, config: dict[str, Any]):
-        """Context manager to track training run"""
-        if not self.enabled:
-            yield
-            return
-
-        # Track training start
-        sanitized_config = {
-            k: v
-            for k, v in config.items()
-            if not any(p in k.lower() for p in ["path", "dir", "file"])
-        }
-
-        self.track_event("training_start", {"config": sanitized_config})
-
-        try:
-            yield
-            # Track successful completion
-            self.track_event("training_complete", {})
-
-        except Exception as e:
-            # Track error
-            self.track_event("training_error", {"error": self._sanitize_error(str(e))})
-            raise
-
-    def track_model_load(self, model_config: ModelConfig):
-        """Track model loading and configuration"""
-        if not self.enabled or not self._is_whitelisted(model_config.base_model):
-            return
-
-        self.track_event(
-            "model_load",
-            {
-                "model_config": model_config.to_dict(),
-                "system_info": self._get_system_info(),
-            },
-        )
-
-    def track_training_metrics(self, metrics: dict):
-        """Track training progress metrics"""
-        if not self.enabled:
-            return
-
-        self.track_event(
-            "training_metrics",
-            {
-                "duration": metrics.get("duration"),
-                "peak_memory": metrics.get("peak_memory"),
-                "steps_completed": metrics.get("steps_completed"),
-                "current_loss": metrics.get("loss"),
-                "learning_rate": metrics.get("learning_rate"),
-            },
-        )
+            posthog.capture(
+                distinct_id=self.run_id,
+                event=event_type,
+                properties={
+                    "system_info": self.system_info,
+                    **properties,
+                },
+            )
+        except Exception as e:  # pylint: disable=broad-exception-caught
+            LOG.warning(f"Failed to send telemetry event: {e}")
 
     def shutdown(self):
         """Ensure all queued events are processed before shutdown"""
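For reference, the approximate shape of what the new track_event forwards to PostHog (illustrative values only):

    payload = {
        "distinct_id": "<uuid4 run_id>",  # self.run_id, regenerated each run
        "event": "config-loaded",
        "properties": {
            "system_info": {"os": "Linux", "gpu_count": 1},  # cached at init
            # plus whatever properties the caller passed in
        },
    }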
@@ -263,6 +232,49 @@ class TelemetryManager:
             posthog.flush()
 
 
-def init_telemetry_manager() -> TelemetryManager:
-    """Initialize telemetry system"""
-    return TelemetryManager(TelemetryConfig())
+ERROR_HANDLED = False
+
+
+def track_errors(func: Callable) -> Callable:
+    """Decorator to track errors in a function"""
+
+    @wraps(func)
+    def wrapper(*args, **kwargs) -> Any:
+        telemetry_manager = TelemetryManager.get_instance()
+        if not telemetry_manager.enabled:
+            return func(*args, **kwargs)
+
+        try:
+            return func(*args, **kwargs)
+        except Exception as exception:
+            # Only track if we're not already handling an error. This prevents us from
+            # capturing an error more than once in nested decorated function calls.
+            global ERROR_HANDLED  # pylint: disable=global-statement
+            if not ERROR_HANDLED:
+                ERROR_HANDLED = True
+
+                # Get function module path
+                module = getmodule(func)
+                module_path = (
+                    f"{module.__name__}.{func.__name__}" if module else func.__name__
+                )
+
+                # Get stack trace
+                stack_trace = "".join(
+                    traceback.format_exception(
+                        type(exception), exception, exception.__traceback__
+                    )
+                )
+
+                # Send error telemetry
+                telemetry_manager.track_event(
+                    event_type=f"{module_path}-error",
+                    properties={
+                        "exception": str(exception),
+                        "stack_trace": stack_trace,
+                    },
+                )
+
+            raise
+
+    return wrapper
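A hedged usage sketch of the decorator above; the decorated function is hypothetical:

    from axolotl.telemetry.manager import track_errors

    @track_errors
    def flaky_step():
        raise ValueError("boom")

    # When telemetry is enabled, the first exception under any decorated call
    # emits a single "<module>.flaky_step-error" event carrying the stack trace
    # (ERROR_HANDLED dedupes nested decorated calls), then re-raises as usual.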
@@ -3,6 +3,7 @@ organizations:
 - "huggingface"
 - "nvidia"
 - "facebook"
 - "mistralai"
+- "briaai"
 - "unsloth"
 - "NousResearch"
@@ -22,6 +22,8 @@ from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
     fix_untrained_tokens,
 )
 from axolotl.logging_config import configure_logging
+from axolotl.telemetry import TelemetryManager
+from axolotl.telemetry.manager import track_errors
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.freeze import freeze_layers_except
 from axolotl.utils.models import load_model, load_processor, load_tokenizer
@@ -39,13 +41,16 @@ sys.path.insert(0, src_dir)
 configure_logging()
 LOG = get_logger(__name__)
 
+TELEMETRY_MANAGER = TelemetryManager.get_instance()
+
 
+@track_errors
 def train(
     *, cfg: DictDefault, dataset_meta: TrainDatasetMeta
 ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
     # Load tokenizer
     LOG.debug(
-        f"loading tokenizer... {cfg.tokenizer_cocnfig or cfg.base_model_config}",
+        f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
         main_process_only=True,
     )
     tokenizer = load_tokenizer(cfg)
@@ -75,7 +80,7 @@ def train(
     )
     resume_from_checkpoint = cfg.resume_from_checkpoint
 
-    # Load the model and tokenizer
+    # Load model
     msg = "loading model"
     if cfg.adapter:
         msg += " and peft_config..."
@@ -84,6 +89,14 @@ def train(
     if model.generation_config is not None:
         model.generation_config.do_sample = True
 
+    TELEMETRY_MANAGER.track_event(
+        event_type="model-load", properties=model.config.to_dict()
+    )
+    if peft_config:
+        TELEMETRY_MANAGER.track_event(
+            event_type="peft-config-load", properties=peft_config.to_dict()
+        )
+
     model_ref = None
     if cfg.rl and cfg.rl != "orpo":
         if cfg.adapter and not cfg.rl_adapter_ref_model:
@@ -91,7 +104,7 @@ def train(
             LOG.debug("Passing model_ref: None to RL trainer")
             model_ref = None  # explicit setting to None
         else:
-            # load the model again for model_ref/baseline
+            # load the model again for model_ref / baseline
             model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
 
     safe_serialization = cfg.save_safetensors is True
@@ -174,6 +187,8 @@ def train(
     if cfg.group_by_length:
         LOG.info("hang tight... sorting dataset for group_by_length")
 
+    TELEMETRY_MANAGER.track_event(event_type="train-start")
+
     pretrain_hooks(cfg, trainer)
 
     if cfg.flash_optimum:
@@ -189,6 +204,8 @@ def train(
 
     post_train_hooks(cfg, trainer)
 
+    TELEMETRY_MANAGER.track_event(event_type="train-end")
+
     LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
     # post training
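Taken together, the track_event calls added across load_cfg and train produce roughly this event sequence for one successful run (names copied from the diff; ordering assumed from call sites):

    EVENTS = [
        "config-loaded",     # load_cfg, right after the YAML is parsed
        "config-processed",  # load_cfg, after validation and env var setup
        "model-load",        # train, with model.config.to_dict() as properties
        "peft-config-load",  # train, only when a peft_config is present
        "train-start",       # train, just before pretrain_hooks
        "train-end",         # train, just after post_train_hooks
    ]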
@@ -1683,7 +1683,7 @@ class AxolotlInputConfig(
 
 
 class AxolotlConfigWCapabilities(AxolotlInputConfig):
-    """wrapper to valdiate gpu capabilities with the configured options"""
+    """Wrapper to validate GPU capabilities with the config options"""
 
     capabilities: GPUCapabilities
     env_capabilities: EnvCapabilities
@@ -54,6 +54,7 @@ from axolotl.monkeypatch.multipack import (
     patch_for_multipack,
 )
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
+from axolotl.telemetry.manager import track_errors
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import get_chat_template_from_config
 from axolotl.utils.dict import DictDefault
@@ -165,6 +166,7 @@ def load_model_config(cfg):
     return model_config
 
 
+@track_errors
 def load_tokenizer(cfg):
     model_config = load_model_config(cfg)
     tokenizer_kwargs = {}
@@ -318,6 +320,7 @@ def load_tokenizer(cfg):
     return tokenizer
 
 
+@track_errors
 def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
     processor_kwargs: Dict[str, Any] = {}  # do we actually need this?
 
@@ -1192,18 +1195,17 @@ class ModelLoader:
         return self.model, lora_config
 
 
+@track_errors
 def load_model(
     cfg: DictDefault,
     tokenizer: PreTrainedTokenizerBase,
     *,
-    processor: ProcessorMixin = None,  # pylint: disable=unused-argument
+    processor: ProcessorMixin = None,
     inference: bool = False,
     reference_model: bool = False,
-    **kwargs,  # pylint: disable=unused-argument
-) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
-    """
-    Load a model for a given configuration and tokenizer.
-    """
+    **kwargs,
+) -> Tuple[PreTrainedModel, PeftConfig | None]:
+    """Load a model for a given configuration and tokenizer"""
     loader = ModelLoader(
         cfg,
         tokenizer,
@@ -1215,6 +1217,7 @@ def load_model(
     return loader.load_model()
 
 
+@track_errors
 def load_adapter(model, cfg, adapter, inference=False):
     # type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
@@ -1,71 +1,65 @@
 """Tests for TelemetryManager class and utilities"""
 # pylint: disable=redefined-outer-name
 
 import os
 from unittest.mock import patch
 
 import pytest
 import yaml
 
-from axolotl.telemetry import ModelConfig, TelemetryConfig, TelemetryManager
+from axolotl.telemetry import TelemetryManager
 
 
 @pytest.fixture
 def mock_whitelist(tmp_path):
     """Create a temporary whitelist file for testing"""
     whitelist_content = {
-        "organizations": ["meta", "mistral"],
-        "models": ["llama", "mistral-7b"],
+        "organizations": ["meta-llama", "mistralai"],
     }
     whitelist_file = tmp_path / "whitelist.yaml"
-    with open(whitelist_file, "w") as f:
+    with open(whitelist_file, "w", encoding="utf-8") as f:
         yaml.dump(whitelist_content, f)
     return str(whitelist_file)
 
 
-@pytest.fixture
-def config(mock_whitelist):
-    """Create a TelemetryConfig with test settings"""
-    return TelemetryConfig(
-        host="https://test.posthog.com", whitelist_path=mock_whitelist
-    )
-
-
 @pytest.fixture
-def manager(config):
+def manager():
     """Create a TelemetryManager instance with mocked PostHog"""
     with patch("posthog.capture"):
-        return TelemetryManager(config)
+        return TelemetryManager()
 
 
 def test_telemetry_disabled_by_default():
     """Test that telemetry is disabled by default"""
-    manager = TelemetryManager(TelemetryConfig())
+    manager = TelemetryManager()
     assert not manager.enabled
 
 
 def test_telemetry_opt_in():
     """Test that telemetry can be enabled via environment variable"""
     with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
-        manager = TelemetryManager(TelemetryConfig())
+        manager = TelemetryManager()
         assert manager.enabled
 
 
 def test_do_not_track_override():
     """Test that DO_NOT_TRACK overrides AXOLOTL_TELEMETRY"""
     with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1", "DO_NOT_TRACK": "1"}):
-        manager = TelemetryManager(TelemetryConfig())
+        manager = TelemetryManager()
         assert not manager.enabled
 
 
+# pylint: disable=protected-access
 def test_whitelist_checking(manager):
     """Test model whitelist functionality"""
-    with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
-        # Should match organization
-        assert manager._is_whitelisted("meta/llama-7b")
-        # Should match model name
-        assert manager._is_whitelisted("mistral-7b-instruct")
-        # Should not match
-        assert not manager._is_whitelisted("unknown/model")
-        # Should handle case insensitively
-        assert manager._is_whitelisted("meta/Llama-7b")
+    # Should match organization
+    assert manager._is_whitelisted("meta-llama/llama-7b")
+    # Should match model name
+    assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
+    # Should not match
+    assert not manager._is_whitelisted("unknown/model")
+    # Should handle case insensitively
+    assert manager._is_whitelisted("meta/Llama-7b")
 
 
 def test_event_tracking(manager):
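Because TelemetryManager is now a process-wide singleton, constructing it twice in one pytest session returns the first instance, so env-var-driven tests like those above can observe stale state. A hypothetical isolation fixture, not part of this commit:

    import pytest

    from axolotl.telemetry import TelemetryManager

    @pytest.fixture(autouse=True)
    def _reset_telemetry_singleton():
        TelemetryManager._instance = None  # pylint: disable=protected-access
        TelemetryManager._initialized = False
        yield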
@@ -81,33 +75,6 @@ def test_event_tracking(manager):
        assert "system_info" in mock_capture.call_args[1]["properties"]
 
 
-def test_model_tracking(manager):
-    """Test model load tracking"""
-    with patch.dict(os.environ, {"AXOLOTL_TELEMETRY": "1"}):
-        model_config = ModelConfig(
-            base_model="meta/llama-7b",
-            model_type="decoder",
-            hidden_size=4096,
-            num_layers=32,
-            num_attention_heads=32,
-            tokenizer_config={},
-            flash_attention=True,
-            quantization_config=None,
-            training_approach="lora",
-        )
-
-        with patch("posthog.capture") as mock_capture:
-            manager.enabled = True
-            manager.track_model_load(model_config)
-
-            assert mock_capture.called
-            assert mock_capture.call_args[1]["event"] == "model_load"
-            assert (
-                mock_capture.call_args[1]["properties"]["model_config"]
-                == model_config.to_dict()
-            )
-
-
 def test_training_context(manager):
     """Test training context manager"""
     config = {"model": "llama", "batch_size": 8}
@@ -141,6 +108,7 @@ def test_training_error(manager):
     assert "training_error" in events
 
 
+# pylint: disable=protected-access
 def test_path_sanitization(manager):
     """Test path sanitization"""
     path = "/home/user/sensitive/data.txt"
@@ -149,6 +117,7 @@ def test_path_sanitization(manager):
     assert "/home/user" not in sanitized
 
 
+# pylint: disable=protected-access
 def test_error_sanitization(manager):
     """Test error message sanitization"""
     error = "Failed to load /home/user/sensitive/data.txt: File not found"