Compare commits

33 Commits: mistral-su…...telemetry-…

| SHA1 |
|---|
| 345a159796 |
| 657bffd85f |
| f0dde8e2d5 |
| 25fa4df70f |
| e735f4270b |
| 035e7a2f4c |
| 2d36c11264 |
| b8ec5bdccf |
| 249405b46e |
| d3be84fec2 |
| 1c74ab175f |
| b2f1fc109a |
| 5a2a80cc48 |
| 4033fe74f8 |
| e9df4444be |
| ffd2985750 |
| 17310f9acc |
| 71ae6f9f87 |
| 9dd1092f8f |
| 2c2f2647a9 |
| 98313a6b3f |
| 8b75205d3b |
| ef4990f304 |
| db3297b090 |
| 86ed554bda |
| f254d7d5a2 |
| d8b0522ea0 |
| 1edd6b9524 |
| 66c6fb56cb |
| 90b39ce112 |
| 5afab46cc6 |
| bd152c6115 |
| 76336743ff |
@@ -112,6 +112,13 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge

 Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.

+## 📈 Telemetry
+
+Axolotl has opt-in telemetry that helps us understand how the project is being used
+and prioritize improvements. We collect basic system information, model types, and
+error rates—never personal data or file paths. Telemetry is disabled by default. To
+enable it, set AXOLOTL_DO_NOT_TRACK=0. For more details, see our [telemetry documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html).
+
 ## Supported Models

 | | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
@@ -236,6 +236,7 @@ website:
       - docs/inference.qmd
       - docs/cli.qmd
       - docs/config.qmd
+      - docs/telemetry.qmd
     - text: "API Reference"
       href: docs/api
docs/telemetry.qmd (new file, 59 lines)

@@ -0,0 +1,59 @@
---
title: Telemetry
description: A description of the opt-in telemetry implementation in Axolotl.
---

# Telemetry in Axolotl

Axolotl implements anonymous telemetry to help maintainers understand how the library
is used and where users encounter issues. This data helps prioritize features, optimize
performance, and fix bugs.

## Data Collection

We collect:

- System info: OS, Python version, Axolotl version, PyTorch version, Transformers version, etc.
- Hardware info: CPU count, memory, GPU count and models
- Runtime metrics: training progress, memory usage, timing information
- Usage patterns: models (from a whitelist) and configurations used
- Error tracking: stack traces and error messages (sanitized to remove personal information)

Personally identifiable information (PII) is not collected.

## Implementation

Telemetry is implemented using PostHog and consists of:

- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the telemetry system and provides methods for tracking events.
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and sends sanitized stack traces.
- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks runtime metrics during training.
- `axolotl.telemetry.callbacks.TelemetryCallback`: A `Trainer` callback that sends runtime-metrics telemetry.

Unless telemetry is explicitly enabled or disabled, the telemetry system blocks training startup for 15 seconds to ensure users are aware of data collection.

## Opt-In Mechanism

Telemetry is **disabled by default** on an opt-in basis. To enable it, set `AXOLOTL_DO_NOT_TRACK=0`.

To remove the telemetry warning displayed on `train` (and similar command) startup, explicitly set `AXOLOTL_DO_NOT_TRACK=0` (enable telemetry) or `AXOLOTL_DO_NOT_TRACK=1` (explicitly disable telemetry).

**Note**: Telemetry will move to an opt-out model in a later release.

## Privacy

- All path-like config information is automatically redacted from telemetry data
- Model information is only collected for whitelisted organizations
  - See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
- Each run generates a unique anonymous ID
  - This allows us to link the telemetry events within a single training run
- Telemetry is only sent from the main process to avoid duplicate events
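As a concrete check of the opt-in mechanism described above, here is a minimal sketch of enabling telemetry for a single run. Only the `AXOLOTL_DO_NOT_TRACK` variable comes from this document; the `axolotl train` invocation is an assumed example of launching a run:

```python
# Minimal sketch: opt in to telemetry for one run by setting the
# documented environment variable before launching training.
import os
import subprocess

env = os.environ.copy()
env["AXOLOTL_DO_NOT_TRACK"] = "0"  # "1" would explicitly disable instead

# Assumed CLI entrypoint and config path; substitute your own.
subprocess.run(["axolotl", "train", "config.yml"], env=env, check=True)
```

Setting the variable either way also skips the 15-second startup pause mentioned under Implementation.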
@@ -20,7 +20,6 @@ datasets==3.6.0
 deepspeed>=0.17.0
 trl==0.18.1
 hf_xet==1.1.2
-mistral-common[hf-hub]==1.6.0

 optimum==1.16.2
 hf_transfer
@@ -68,3 +67,6 @@ schedulefree==1.4.1

 axolotl-contribs-lgpl==0.0.6
 axolotl-contribs-mit==0.0.3
+
+# telemetry
+posthog>=4.2.0
@@ -14,6 +14,8 @@ import yaml
 from transformers.utils import is_torch_bf16_gpu_available

 from axolotl.integrations.base import PluginManager
+from axolotl.telemetry.errors import send_errors
+from axolotl.telemetry.manager import TelemetryManager
 from axolotl.utils.comet_ import setup_comet_env_vars
 from axolotl.utils.config import (
     normalize_cfg_datasets,
@@ -28,6 +30,8 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars

 LOG = get_logger(__name__, use_environ=True)

+TELEMETRY_MANAGER = TelemetryManager.get_instance()
+

 def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
     """
@@ -159,6 +163,7 @@ def plugin_set_cfg(cfg: DictDefault):
     plugin_manager.cfg = cfg


+@send_errors
 def load_cfg(
     config: str | Path | DictDefault = Path("examples/"), **kwargs
 ) -> DictDefault:
@@ -192,6 +197,8 @@ def load_cfg(
         temp_file.close()
         cfg.axolotl_config_path = temp_file.name

+    TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg)
+
     # If there are any options passed in the cli, if it is something that seems valid
     # from the yaml, then overwrite the value
     cfg_keys = cfg.keys()
@@ -233,4 +240,6 @@ def load_cfg(
     setup_comet_env_vars(cfg)
     plugin_set_cfg(cfg)

+    TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
+
     return cfg
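The diff only ever calls `TelemetryManager.get_instance()` and `send_event`; the class body is not part of this compare view. A minimal sketch of the singleton shape those call sites imply (everything beyond the two method names is an assumption):

```python
# Hypothetical stand-in for axolotl.telemetry.manager.TelemetryManager,
# inferred from the call sites in this diff; not the actual implementation.
import os
import uuid


class TelemetryManagerSketch:
    _instance = None

    def __init__(self):
        # Opt-in: enabled only when AXOLOTL_DO_NOT_TRACK is explicitly "0".
        self.enabled = os.environ.get("AXOLOTL_DO_NOT_TRACK") == "0"
        self.run_id = str(uuid.uuid4())  # anonymous per-run ID, per the docs

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def send_event(self, event_type: str, properties: dict | None = None):
        if not self.enabled:
            return
        # The real class reports via PostHog; print stands in here.
        print(f"[telemetry] {self.run_id} {event_type} {properties or {}}")
```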
@@ -16,6 +16,7 @@ from axolotl.cli.args import InferenceCliArgs
 from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.config import load_cfg
 from axolotl.cli.utils import load_model_and_tokenizer
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.chat_templates import (
     get_chat_template,
     get_chat_template_from_config,
@@ -42,6 +43,7 @@ def get_multi_line_input() -> str:
     return instruction


+@send_errors
 def do_inference(
     *,
     cfg: DictDefault,
@@ -135,6 +137,7 @@ def do_inference(
     print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))


+@send_errors
 def do_inference_gradio(
     *,
     cfg: DictDefault,
@@ -9,12 +9,14 @@ from dotenv import load_dotenv
 from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.config import load_cfg
 from axolotl.cli.utils import load_model_and_tokenizer
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger

 LOG = get_logger(__name__)


+@send_errors
 def do_merge_lora(*, cfg: DictDefault) -> None:
     """
     Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
@@ -24,6 +24,7 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner

 from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.config import load_cfg
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.logging import get_logger

 LOG = get_logger(__name__)

@@ -118,6 +119,7 @@ def _distributed_checkpoint_to_merged_weights(
     return save_path_


+@send_errors
 def merge_fsdp_weights(
     checkpoint_dir: str,
     output_path: str,
@@ -18,6 +18,7 @@ from axolotl.cli.config import load_cfg
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.integrations.base import PluginManager
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger
 from axolotl.utils.trainer import disable_datasets_caching
@@ -25,6 +26,7 @@ from axolotl.utils.trainer import disable_datasets_caching

 LOG = get_logger(__name__)


+@send_errors
 def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
     """
     Preprocesses dataset specified in axolotl config.
@@ -305,8 +305,8 @@ def load_model_and_tokenizer(
     ProcessorMixin | None,
 ]:
     """
-    Helper function for loading a model, tokenizer, and processor specified in the
-    given `axolotl` config.
+    Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
+    config.

     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
@@ -10,6 +10,7 @@ from datasets import Dataset
 import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401
 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
 from axolotl.loaders import load_processor, load_tokenizer
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.dict import DictDefault
@@ -45,6 +46,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
     )


+@send_errors
 def load_datasets(
     *,
     cfg: DictDefault,
@@ -112,6 +114,7 @@ def load_datasets(
     )


+@send_errors
 def load_preference_datasets(
     *,
     cfg: DictDefault,
@@ -31,6 +31,8 @@ from transformers.training_args import OptimizerNames

 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
+from axolotl.telemetry.callbacks import TelemetryCallback
+from axolotl.telemetry.manager import TelemetryManager
 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.callbacks import (
     GCCallback,
@@ -145,6 +147,10 @@ class TrainerBuilderBase(abc.ABC):

         callbacks.append(GPUStatsCallback(cfg=self.cfg))

+        telemetry_manager = TelemetryManager.get_instance()
+        if telemetry_manager.enabled:
+            callbacks.append(TelemetryCallback())
+
         return callbacks

     def get_post_trainer_create_callbacks(self, trainer):
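Outside the builder, the same gating can be reproduced in a few lines; a sketch assuming the telemetry modules from this diff are importable:

```python
# Sketch of the callback wiring added above: only attach the telemetry
# callback when the manager reports telemetry as enabled.
from axolotl.telemetry.callbacks import TelemetryCallback
from axolotl.telemetry.manager import TelemetryManager


def telemetry_callbacks() -> list:
    """Extra callbacks to pass to a Trainer via `callbacks=...`."""
    callbacks = []
    if TelemetryManager.get_instance().enabled:
        callbacks.append(TelemetryCallback())
    return callbacks


print(telemetry_callbacks())  # [] unless AXOLOTL_DO_NOT_TRACK=0
```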
@@ -3,7 +3,6 @@
 # pylint: disable=too-many-lines,duplicate-code,protected-access,no-member

 import warnings
-from functools import partial
 from typing import Any

 import datasets
@@ -59,42 +58,6 @@ class AxolotlGRPOTrainer(

     _tag_names = ["trl", "grpo", "axolotl"]

-    def get_train_dataloader(self):
-        if self.train_dataset is None:
-            raise ValueError("Trainer: training requires a train_dataset.")
-
-        train_dataset = self.train_dataset
-        data_collator = self.data_collator
-        if isinstance(train_dataset, datasets.Dataset):
-            train_dataset = self._remove_unused_columns(
-                train_dataset, description="training"
-            )
-        else:
-            data_collator = self._get_collator_with_removed_columns(
-                data_collator, description="training"
-            )
-
-        dataloader_params = {
-            "batch_size": self._train_batch_size
-            * self.args.steps_per_generation,  # < this is the change
-            "collate_fn": data_collator,
-            "num_workers": self.args.dataloader_num_workers,
-            "pin_memory": self.args.dataloader_pin_memory,
-            "persistent_workers": self.args.dataloader_persistent_workers,
-        }
-
-        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
-            dataloader_params["sampler"] = self._get_train_sampler()
-            dataloader_params["drop_last"] = self.args.dataloader_drop_last
-            dataloader_params["worker_init_fn"] = partial(
-                seed_worker,
-                num_workers=self.args.dataloader_num_workers,
-                rank=self.args.process_index,
-            )
-            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
-
-        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
-

 class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
     """Extend the base GRPOTrainer for sequence parallelism handling"""
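The `# < this is the change` comment above marked the point of the removed override: the dataloader batch size was scaled by `steps_per_generation` so a single fetch covered every step between generations. A small illustration of that arithmetic (all numbers made up):

```python
# Illustrative only: how the removed override sized each dataloader fetch.
per_device_train_batch_size = 4  # assumed trainer batch size
steps_per_generation = 8         # assumed GRPO generation interval

# One fetch had to supply prompts for every step between generations:
dataloader_batch_size = per_device_train_batch_size * steps_per_generation
assert dataloader_batch_size == 32
```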
@@ -64,10 +64,6 @@ class TokenizedPromptDataset(Dataset):
             desc="Strategy Filtering Rows",
         )

-        import ipdb
-
-        ipdb.set_trace()
-
         return dataset.map(
             self.prompt_tokenizer.tokenize_prompt,
             num_proc=num_proc,
@@ -11,6 +11,7 @@ from accelerate.logging import get_logger
 from datasets import Dataset
 from transformers.trainer import Trainer

+from axolotl.telemetry.errors import send_errors
 from axolotl.train import (
     TrainDatasetMeta,
     setup_model_and_tokenizer,
@@ -63,6 +64,7 @@ def evaluate_dataset(
     return metrics


+@send_errors
 def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
     """
     Evaluate a model on training and validation datasets.
@@ -19,6 +19,7 @@ from peft import (
 from transformers import PreTrainedModel

 from axolotl.loaders.utils import get_linear_embedding_layers
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger

@@ -162,6 +163,7 @@ def load_lora(
     return model, lora_config


+@send_errors
 def load_adapter(
     model: PreTrainedModel,
     cfg: DictDefault,
@@ -46,6 +46,7 @@ from axolotl.loaders.utils import (
     load_model_config,
 )
 from axolotl.models.mamba import fix_mamba_attn_for_loss
+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import (
@@ -145,6 +146,7 @@ class ModelLoader:
         """Property that determines if FSDP with QLoRA is enabled."""
         return self.cfg.fsdp and self.cfg.adapter == "qlora"

+    @send_errors
     def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
         """Load and prepare the model with all configurations and patches.
@@ -8,12 +8,14 @@ from transformers import (
     PreTrainedTokenizerBase,
 )

+from axolotl.telemetry.errors import send_errors
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger

 LOG = get_logger(__name__)


+@send_errors
 def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
     processor_kwargs: dict[str, Any] = {}  # Do we actually need this?
@@ -2,16 +2,8 @@

import json
import os
from typing import Any, Dict, List, Optional, Union

import torch
import transformers
from huggingface_hub import hf_hub_download
from mistral_common.protocol.instruct.messages import SystemMessage, UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import (
    MistralTokenizer,
)
from transformers import (
    AddedToken,
    AutoTokenizer,
@@ -20,6 +12,7 @@ from transformers import (
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.telemetry.errors import send_errors
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.distributed import (
    barrier,
@@ -31,622 +24,240 @@ from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()
# Constants
LLAMA_TOKENIZER_CLASSES = {
    "LlamaTokenizer",
    "LlamaTokenizerFast",
    "CodeLlamaTokenizer",
    "CodeLlamaTokenizerFast",
}

FAST_LLAMA_TOKENIZER_CLASSES = {"LlamaTokenizerFast", "CodeLlamaTokenizerFast"}

QWEN_DEFAULT_TOKEN = "<|endoftext|>"
GPTNEOX_PAD_TOKEN = "[PAD]"
CHATML_DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."


def modify_tokenizer_files(
    tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
) -> str:
    """
    Modify tokenizer files to replace added_tokens strings, save to output directory,
    and return the path to the modified tokenizer.
    """


class MistralTokenizerWrapper:
    """
    Wrapper to make MistralTokenizer compatible with Hugging Face tokenizer interface.
    This provides a bridge between Mistral's native tokenizer and axolotl's expectations.
    """
    def __init__(self, mistral_tokenizer: "MistralTokenizer", model_id: str):
        self.mistral_tokenizer = mistral_tokenizer
        self.model_id = model_id
        self._system_prompt = None
        self.padding_side = "right"  # Default padding side
        self.chat_template = None

        # Cache token IDs by inspecting the actual tokenizer
        self._token_ids = self._discover_token_ids()

        # Try to load system prompt if available
        try:
            self._system_prompt = self._load_system_prompt(
                model_id, "SYSTEM_PROMPT.txt"
            )
        except Exception as e:
            LOG.debug(f"Could not load system prompt: {e}")

    def _discover_token_ids(self) -> Dict[str, int]:
        """Discover the actual token IDs used by this Mistral tokenizer."""
        token_ids = {}

        try:
            if hasattr(self.mistral_tokenizer, "instruct_tokenizer"):
                instruct_tokenizer = self.mistral_tokenizer.instruct_tokenizer

                # Get BOS token ID from instruct_tokenizer
                token_ids["bos_token_id"] = getattr(instruct_tokenizer, "BOS", 1)

                # Get token IDs from the underlying Tekkenizer
                if hasattr(instruct_tokenizer, "tokenizer"):
                    tekkenizer = instruct_tokenizer.tokenizer

                    # Get BOS ID from tekkenizer (should match instruct_tokenizer.BOS)
                    if hasattr(tekkenizer, "bos_id"):
                        token_ids["bos_token_id"] = tekkenizer.bos_id

                    # Get vocab size to help find EOS token
                    vocab_size = getattr(tekkenizer, "_vocab_size", None)

                    # Check special tokens
                    if hasattr(tekkenizer, "_all_special_tokens"):
                        special_tokens = tekkenizer._all_special_tokens
                        keys = (
                            list(special_tokens.keys())
                            if hasattr(special_tokens, "keys")
                            else special_tokens
                        )
                        LOG.debug(f"Special tokens available: {keys}")

                        # Try to find EOS token in special tokens
                        if hasattr(special_tokens, "get"):
                            # Common EOS token patterns
                            for eos_pattern in ["</s>", "<|endoftext|>", "eos", "EOS"]:
                                if eos_pattern in special_tokens:
                                    token_ids["eos_token_id"] = special_tokens[
                                        eos_pattern
                                    ]
                                    break

                    # Check special tokens reverse vocab
                    if hasattr(tekkenizer, "_special_tokens_reverse_vocab"):
                        reverse_vocab = tekkenizer._special_tokens_reverse_vocab
                        LOG.debug(f"Reverse special tokens: {reverse_vocab}")

                        # Look for common special token IDs
                        for token_id, token_str in reverse_vocab.items():
                            if token_str in ["</s>", "<|endoftext|>"]:
                                token_ids["eos_token_id"] = token_id
                            elif token_str in ["<unk>", "<UNK>"]:
                                token_ids["unk_token_id"] = token_id

                    # If we have vocab_size, EOS is often vocab_size - 1 or similar
                    if "eos_token_id" not in token_ids and vocab_size:
                        # Common patterns: EOS could be 2, vocab_size-1, or other values
                        # Let's try a safer approach by checking what tokens decode to
                        for candidate_id in [2, vocab_size - 1, vocab_size - 2]:
                            try:
                                # Try to decode and see if it looks like EOS
                                decoded = tekkenizer.decode([candidate_id])
                                if decoded in ["</s>", "<|endoftext|>", ""]:
                                    token_ids["eos_token_id"] = candidate_id
                                    break
                            except Exception:
                                continue

        except Exception as e:
            LOG.debug(f"Could not discover token IDs: {e}")

        # Set reasonable defaults for any missing token IDs
        token_ids.setdefault("bos_token_id", 1)
        token_ids.setdefault("eos_token_id", 2)
        token_ids.setdefault("unk_token_id", 0)
        token_ids.setdefault(
            "pad_token_id", token_ids["eos_token_id"]
        )  # Use EOS as pad

        LOG.info(f"Discovered Mistral token IDs: {token_ids}")
        return token_ids

    def _load_system_prompt(self, repo_id: str, filename: str) -> str:
        """Load system prompt from HuggingFace Hub"""
        file_path = hf_hub_download(repo_id=repo_id, filename=filename)
        with open(file_path, "r") as file:
            return file.read()

    def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> List[int]:
        """Encode text to token IDs"""
        if isinstance(text, str):
            # For simple string encoding, create a user message
            messages = []
            if self._system_prompt and add_special_tokens:
                messages.append(SystemMessage(content=self._system_prompt))
            messages.append(UserMessage(content=text))

            tokenized = self.mistral_tokenizer.encode_chat_completion(
                ChatCompletionRequest(messages=messages)
            )
            return tokenized.tokens
        else:
            raise ValueError("MistralTokenizer wrapper only supports string input")

    def decode(
        self,
        token_ids: Union[List[int], torch.Tensor],
        skip_special_tokens: bool = True,
    ) -> str:
        """Decode token IDs to text"""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        return self.mistral_tokenizer.decode(token_ids)

    def __call__(self, text: str, **kwargs):
        """Make the tokenizer callable like HF tokenizers"""
        tokens = self.encode(text, **kwargs)
        return {"input_ids": torch.tensor([tokens])}

    @property
    def eos_token_id(self):
        return self._token_ids["eos_token_id"]

    @property
    def bos_token_id(self):
        return self._token_ids["bos_token_id"]

    @property
    def pad_token_id(self):
        return self._token_ids["pad_token_id"]

    @property
    def unk_token_id(self):
        return self._token_ids["unk_token_id"]

    @property
    def eos_token(self):
        return "</s>"  # Standard Mistral EOS token

    @property
    def bos_token(self):
        return "<s>"  # Standard Mistral BOS token

    @property
    def pad_token(self):
        return self.eos_token  # Use EOS as pad token

    @property
    def unk_token(self):
        return "<unk>"  # Standard UNK token

    @property
    def __class__(self):
        # Create a mock class for compatibility checks
        class MistralTokenizerWrapperClass:
            __name__ = "MistralTokenizerWrapper"

        return MistralTokenizerWrapperClass

    def add_special_tokens(self, special_tokens_dict: Dict[str, str]) -> int:
        """Placeholder for special token addition - Mistral tokenizer handles this internally"""
        LOG.warning(
            "add_special_tokens called on MistralTokenizer wrapper - this is handled internally"
        )
        return 0

    def add_tokens(self, tokens) -> int:
        """Placeholder for token addition - Mistral tokenizer handles this internally"""
        LOG.warning(
            "add_tokens called on MistralTokenizer wrapper - this is handled internally"
        )
        return 0
class TokenizerFileModifier:
    """Handles modification of tokenizer files for token overrides."""

    def __init__(
        self, tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
    ):
        self.tokenizer_path = tokenizer_path
        self.token_mappings = token_mappings
        self.output_dir = output_dir
        self.tokenizer_dir = os.path.join(output_dir, "tokenizer")

    def modify_and_save(self) -> str:
        """Modify tokenizer files and return path to modified tokenizer."""
        os.makedirs(self.tokenizer_dir, exist_ok=True)

        if is_local_main_process():
            self._perform_modifications()
        barrier()

        return self.tokenizer_dir

    def _perform_modifications(self):
        """Perform the actual file modifications."""
        # Load and save tokenizer to output directory
        temp_tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_path, use_fast=True
        )
        temp_tokenizer.save_pretrained(self.tokenizer_dir)

        # Convert token mappings to proper format
        token_id_mappings = {
            int(token_id): new_value
            for token_id, new_value in self.token_mappings.items()
        }

        # Update both tokenizer files
        self._update_tokenizer_config(token_id_mappings)
        self._update_tokenizer_json(token_id_mappings)

    def _update_tokenizer_config(self, token_id_mappings: Dict[int, str]):
        """Update tokenizer_config.json with new token mappings."""
        config_path = os.path.join(self.tokenizer_dir, "tokenizer_config.json")
        if not os.path.exists(config_path):
            return

        with open(config_path, "r", encoding="utf-8") as f:
            config_data = json.load(f)

        if "added_tokens_decoder" in config_data:
            self._update_added_tokens_decoder(config_data, token_id_mappings)

        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config_data, f, indent=2)

    def _update_added_tokens_decoder(
        self, config_data: Dict, token_id_mappings: Dict[int, str]
    ):
        """Update the added_tokens_decoder section."""
        for token_id, new_value in token_id_mappings.items():
            token_id_str = str(token_id)
            if token_id_str in config_data["added_tokens_decoder"]:
                config_data["added_tokens_decoder"][token_id_str]["content"] = new_value
            else:
                raise ValueError(
                    f"Token ID {token_id_str} not found in added_tokens_decoder"
                )

    def _update_tokenizer_json(self, token_id_mappings: Dict[int, str]):
        """Update tokenizer.json with new token mappings."""
        tokenizer_json_path = os.path.join(self.tokenizer_dir, "tokenizer.json")
        if not os.path.exists(tokenizer_json_path):
            return

        with open(tokenizer_json_path, "r", encoding="utf-8") as f:
            tokenizer_data = json.load(f)

        self._update_added_tokens_list(tokenizer_data, token_id_mappings)
        self._update_vocab_mappings(tokenizer_data, token_id_mappings)

        with open(tokenizer_json_path, "w", encoding="utf-8") as f:
            json.dump(tokenizer_data, f, indent=2)

    def _update_added_tokens_list(
        self, tokenizer_data: Dict, token_id_mappings: Dict[int, str]
    ):
        """Update the added_tokens list in tokenizer.json."""
        if "added_tokens" not in tokenizer_data:
            return

        for token_id, new_value in token_id_mappings.items():
            for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
                if token_entry["id"] == token_id:
                    tokenizer_data["added_tokens"][i]["content"] = new_value
                    break
            else:
                raise ValueError(f"Token ID {token_id} not found in added_tokens")

    def _update_vocab_mappings(
        self, tokenizer_data: Dict, token_id_mappings: Dict[int, str]
    ):
        """Update vocab mappings in tokenizer.json."""
        if not (tokenizer_data.get("model") and tokenizer_data["model"].get("vocab")):
            return

        vocab = tokenizer_data["model"]["vocab"]
        for token_id, new_value in token_id_mappings.items():
            # Find and update the vocab entry
            for entry_val, entry_id in list(vocab.items()):
                if entry_id == token_id:
                    del vocab[entry_val]
                    vocab[new_value] = token_id
                    break


class TokenizerConfiguration:
    """Handles tokenizer configuration and initialization."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.model_config = load_model_config(cfg)

    def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
        """Load Mistral tokenizer from model configuration."""
        # Instantiate Mistral tokenizer
        model_id = self.cfg.base_model
        mistral_tokenizer = MistralTokenizer.from_hf_hub(model_id)

        # Wrap it for compatibility
        tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
        LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")

        return tokenizer

    def get_tokenizer_class(self):
        """Get the appropriate tokenizer class."""
        if self.cfg.tokenizer_type:
            return getattr(transformers, self.cfg.tokenizer_type)
        return AutoTokenizer

    def get_tokenizer_kwargs(self) -> Dict[str, Any]:
        """Build tokenizer initialization kwargs."""
        kwargs = {}
        if self.cfg.tokenizer_legacy is not None:
            kwargs["legacy"] = self.cfg.tokenizer_legacy
        return kwargs

    def get_tokenizer_path(self) -> str:
        """Get the tokenizer path, applying overrides if needed."""
        tokenizer_path = self.cfg.tokenizer_config

        if self.cfg.added_tokens_overrides:
            modifier = TokenizerFileModifier(
                tokenizer_path, self.cfg.added_tokens_overrides, self.cfg.output_dir
            )
            tokenizer_path = modifier.modify_and_save()

        return tokenizer_path

    def should_use_fast_tokenizer(self) -> bool:
        """Determine if fast tokenizer should be used."""
        return (
            self.cfg.tokenizer_use_fast
            if self.cfg.tokenizer_use_fast is not None
            else True
        )
class TokenizerPostProcessor:
    """Handles post-processing configuration of loaded tokenizers."""

    def __init__(self, tokenizer, cfg):
        self.tokenizer = tokenizer
        self.cfg = cfg
        self.model_config = load_model_config(cfg)

    def apply_all_configurations(self):
        """Apply all post-processing configurations to the tokenizer."""
        # Skip most configurations for Mistral wrapper
        if isinstance(self.tokenizer, MistralTokenizerWrapper):
            self._configure_mistral_wrapper()
            return

        self._configure_padding_token()
        self._configure_gptneox_settings()
        self._configure_mistral_padding()
        self._configure_qwen_tokens()
        self._add_special_tokens()
        self._add_regular_tokens()
        self._configure_chat_template()

    def _configure_mistral_wrapper(self):
        """Apply limited configurations for Mistral wrapper."""
        # Set padding side if needed
        if (
            self.cfg.is_mistral_derived_model
            and self.cfg.flash_attention
            and not self.cfg.sample_packing
        ):
            self.tokenizer.padding_side = "left"

        # Configure chat template for Mistral
        self._configure_chat_template()

    def _configure_padding_token(self):
        """Configure padding token for Llama-based tokenizers."""
        if (
            self.tokenizer.__class__.__name__ in LLAMA_TOKENIZER_CLASSES
            and hasattr(self.tokenizer, "pad_token")
            and not self.tokenizer.pad_token
        ):
            self.tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN

    def _configure_gptneox_settings(self):
        """Configure GPTNeoX-specific settings."""
        if self.tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
            self.tokenizer.add_special_tokens({"pad_token": GPTNEOX_PAD_TOKEN})
            os.environ["TOKENIZERS_PARALLELISM"] = "false"

    def _configure_mistral_padding(self):
        """Configure left padding for Mistral models with Flash Attention."""
        if (
            self.cfg.is_mistral_derived_model
            and self.cfg.flash_attention
            and not self.cfg.sample_packing
        ):
            self.tokenizer.padding_side = "left"

    def _configure_qwen_tokens(self):
        """Configure special tokens for Qwen models."""
        if not self.cfg.is_qwen_derived_model:
            return

        # Set token IDs
        token_id_attributes = [
            "bos_token_id",
            "eos_token_id",
            "pad_token_id",
            "unk_token_id",
        ]
        for attr_name in token_id_attributes:
            if getattr(self.tokenizer, attr_name) is None:
                setattr(self.tokenizer, attr_name, self.tokenizer.eod_id)

        # Set token strings
        token_name_attributes = ["bos_token", "eos_token", "pad_token", "unk_token"]
        for attr_name in token_name_attributes:
            if getattr(self.tokenizer, attr_name) is None:
                setattr(self.tokenizer, attr_name, QWEN_DEFAULT_TOKEN)

    def _add_special_tokens(self):
        """Add special tokens from configuration."""
        if not self.cfg.special_tokens:
            return

        special_tokens_dict = self.cfg.special_tokens.to_dict()
        additional_special_tokens = special_tokens_dict.pop(
            "additional_special_tokens", None
        )

        self._validate_and_add_special_tokens(special_tokens_dict)
        self._update_post_processor_if_needed(special_tokens_dict)
        self._add_additional_special_tokens_if_present(additional_special_tokens)

    def _validate_and_add_special_tokens(self, special_tokens: Dict[str, str]):
        """Validate special tokens for adapter training and add them."""
        lora_modules_to_save = get_linear_embedding_layers(self.model_config.model_type)

        for key, value in special_tokens.items():
            self._validate_token_for_adapter(key, value, lora_modules_to_save)
            self.tokenizer.add_special_tokens(
                {key: AddedToken(value, rstrip=False, lstrip=False, normalized=False)}
            )

    def _validate_token_for_adapter(
        self, key: str, value: str, lora_modules_to_save: List[str]
    ):
        """Validate a single token for adapter training requirements."""
        if not self._should_validate_token_for_adapter(
            key, value, lora_modules_to_save
        ):
            return

        modules_str = ", ".join(f"`{x}`" for x in lora_modules_to_save)
        raise ValueError(
            f"Please set lora_modules_to_save to [{modules_str}] "
            f"when using an adapter and changing the special tokens."
        )

    def _should_validate_token_for_adapter(
        self, key: str, value: str, lora_modules_to_save: List[str]
    ) -> bool:
        """Check if token should be validated for adapter configuration."""
        if key == "pad_token" or not self.cfg.adapter:
            return False

        current_token = getattr(self.tokenizer, key)
        token_changed = current_token is None or current_token != value
        token_is_multi_char = (
            len(self.tokenizer.encode(value, add_special_tokens=False)) > 2
        )
        lora_modules_missing = not self.cfg.lora_modules_to_save or not all(
            x in self.cfg.lora_modules_to_save for x in lora_modules_to_save
        )

        return token_changed and token_is_multi_char and lora_modules_missing

    def _update_post_processor_if_needed(self, special_tokens: Dict[str, str]):
        """Update post processor for Llama tokenizers when BOS/EOS tokens are added."""
        has_bos_and_eos = (
            "bos_token" in special_tokens and "eos_token" in special_tokens
        )
        is_fast_llama = (
            self.tokenizer.__class__.__name__ in FAST_LLAMA_TOKENIZER_CLASSES
        )

        if is_fast_llama and has_bos_and_eos:
            self.tokenizer.update_post_processor()

    def _add_additional_special_tokens_if_present(
        self, additional_special_tokens: Optional[List[str]]
    ):
        """Add additional special tokens if they exist."""
        if additional_special_tokens is not None:
            self.tokenizer.add_special_tokens(
                {"additional_special_tokens": additional_special_tokens}
            )

    def _add_regular_tokens(self):
        """Add regular (non-special) tokens from configuration."""
        if self.cfg.tokens:
            self.tokenizer.add_tokens(
                [
                    AddedToken(token, rstrip=False, lstrip=False, normalized=False)
                    for token in self.cfg.tokens
                ]
            )

    def _configure_chat_template(self):
        """Configure chat template if specified."""
        if not self.cfg.chat_template:
            LOG.info(
                "No Chat template selected. Consider adding a chat template for easier inference."
            )
            return

        chat_template_string = get_chat_template_from_config(
            cfg=self.cfg,
            tokenizer=self.tokenizer,
        )

        if self._should_replace_default_system_message():
            chat_template_string = chat_template_string.replace(
                CHATML_DEFAULT_SYSTEM_MESSAGE, self.cfg.default_system_message
            )

        self.tokenizer.chat_template = chat_template_string

    def _should_replace_default_system_message(self) -> bool:
        """Check if default system message should be replaced."""
        return self.cfg.default_system_message and self.cfg.chat_template == "chatml"
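The adapter check in `_should_validate_token_for_adapter` reduces to three conditions; a standalone restatement of the same predicate (the module names in the asserts are assumed examples of what `get_linear_embedding_layers` might return):

```python
# Standalone restatement of the validation rule above: a changed,
# multi-token special token is only allowed when the embedding layers
# are listed in lora_modules_to_save.
def needs_lora_modules_to_save(
    token_changed: bool,
    encoded_length: int,
    lora_modules_to_save: list[str],
    required_modules: list[str],
) -> bool:
    token_is_multi_char = encoded_length > 2
    modules_missing = not lora_modules_to_save or not all(
        m in lora_modules_to_save for m in required_modules
    )
    return token_changed and token_is_multi_char and modules_missing


# A changed 3-token special token with nothing saved -> must error out:
assert needs_lora_modules_to_save(True, 3, [], ["embed_tokens", "lm_head"])
# Same token with both embedding layers saved -> allowed:
assert not needs_lora_modules_to_save(
    True, 3, ["embed_tokens", "lm_head"], ["embed_tokens", "lm_head"]
)
```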
def load_tokenizer(cfg):
    """Load and configure the tokenizer based on the provided config.

    This function handles the complete tokenizer loading pipeline:
    - Check if Mistral tokenizer should be used
    - Configure tokenizer parameters and get the appropriate class
    - Handle token file modifications if needed
    - Initialize the tokenizer with the correct parameters
    - Apply all post-processing configurations (padding, special tokens, etc.)
    - Set up chat templates and logging

    This only works with reserved tokens that were added to the tokenizer, not tokens
    already part of the vocab.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
        tokenizer_path: Path or name of the original tokenizer
        token_mappings: Dict mapping {token_id (int): new_token_string}
        output_dir: Directory to save the modified tokenizer

    Returns:
        Fully configured tokenizer instance.
        Path to the modified tokenizer directory

    Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
    """
    # Configure tokenizer parameters
    config = TokenizerConfiguration(cfg)
    # Create the tokenizer directory in output_dir if it doesn't exist
    tokenizer_dir = os.path.join(output_dir, "tokenizer")
    os.makedirs(tokenizer_dir, exist_ok=True)

    # Check if we should use Mistral tokenizer
    try:
        tokenizer = config.load_mistral_tokenizer()
    except:
        # Standard tokenizer loading
        tokenizer_cls = config.get_tokenizer_class()
        tokenizer_path = config.get_tokenizer_path()
        use_fast = config.should_use_fast_tokenizer()
        tokenizer_kwargs = config.get_tokenizer_kwargs()
    if is_local_main_process():  # pylint: disable=too-many-nested-blocks
        # Load the tokenizer
        temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)

        # Initialize the tokenizer
        tokenizer = tokenizer_cls.from_pretrained(
            tokenizer_path,
            trust_remote_code=cfg.trust_remote_code or False,
            use_fast=use_fast,
            **tokenizer_kwargs,
        )

        # Save the tokenizer to the output directory
        temp_tokenizer.save_pretrained(tokenizer_dir)

        # Get the token IDs and map them to their new values
        token_id_mappings = {
            int(token_id): new_value for token_id, new_value in token_mappings.items()
        }

        # 1. Update tokenizer_config.json - added_tokens_decoder
        config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config_data = json.load(f)

            # Update added_tokens_decoder
            if "added_tokens_decoder" in config_data:
                for token_id, new_value in token_id_mappings.items():
                    token_id_str = str(token_id)
                    if token_id_str in config_data["added_tokens_decoder"]:
                        config_data["added_tokens_decoder"][token_id_str][
                            "content"
                        ] = new_value
                    else:
                        raise ValueError(
                            f"Token ID {token_id_str} not found in added_tokens_decoder"
                        )

            # Write the updated config back
            with open(config_path, "w", encoding="utf-8") as f:
                json.dump(config_data, f, indent=2)

        # 2. Update tokenizer.json - added_tokens
        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
        if os.path.exists(tokenizer_path):
            with open(tokenizer_path, "r", encoding="utf-8") as f:
                tokenizer_data = json.load(f)

            # Update added_tokens
            if "added_tokens" in tokenizer_data:
                for token_id, new_value in token_id_mappings.items():
                    for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
                        if token_entry["id"] == token_id:
                            tokenizer_data["added_tokens"][i]["content"] = new_value
                            break
                    else:
                        # Reaching this section means the token_id was not found in tokenizer.json added_tokens
                        raise ValueError(
                            f"Token ID {token_id} not found in added_tokens"
                        )

            if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]:
                for token_id, new_value in token_id_mappings.items():
                    for entry_val, entry_id in tokenizer_data["model"]["vocab"].items():
                        if entry_id == token_id:
                            del tokenizer_data["model"]["vocab"][entry_val]
                            tokenizer_data["model"]["vocab"][new_value] = token_id
                            break

            # Write the updated tokenizer data back
            with open(tokenizer_path, "w", encoding="utf-8") as f:
                json.dump(tokenizer_data, f, indent=2)

    barrier()
    return tokenizer_dir
@send_errors
def load_tokenizer(cfg):
    """Load and configure the tokenizer based on the provided config."""
    model_config = load_model_config(cfg)
    tokenizer_kwargs = {}
    use_fast = True  # this is the default

    if cfg.tokenizer_use_fast is not None:
        use_fast = cfg.tokenizer_use_fast
    if cfg.tokenizer_legacy is not None:
        # True is the default w/ https://github.com/huggingface/transformers/pull/25224
        tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy

    tokenizer_cls = AutoTokenizer
    if cfg.tokenizer_type:
        tokenizer_cls = getattr(transformers, cfg.tokenizer_type)

    # Set base tokenizer path
    tokenizer_path = cfg.tokenizer_config

    # Apply token string overrides if specified
    if cfg.added_tokens_overrides:
        # Modify tokenizer files and get path to modified tokenizer
        tokenizer_path = modify_tokenizer_files(
            tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
        )

    tokenizer = tokenizer_cls.from_pretrained(
        tokenizer_path,
        trust_remote_code=cfg.trust_remote_code or False,
        use_fast=use_fast,
        **tokenizer_kwargs,
    )

    # Apply all post-processing configurations
    post_processor = TokenizerPostProcessor(tokenizer, cfg)
    post_processor.apply_all_configurations()

    if (
        tokenizer.__class__.__name__
        in [
            "LlamaTokenizer",
            "LlamaTokenizerFast",
            "CodeLlamaTokenizer",
            "CodeLlamaTokenizerFast",
        ]
        and hasattr(tokenizer, "pad_token")
        and not tokenizer.pad_token
    ):
        # set a pad_token, but use eos_token so we don't add a new token
        tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN

    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Mistral's official FA implementation requires left padding
    if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
        tokenizer.padding_side = "left"

    # Qwen base only has single token, so we need to set the special tokens
    if cfg.is_qwen_derived_model:
        token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
        for attr_name in token_ids:
            if getattr(tokenizer, attr_name) is None:
                setattr(tokenizer, attr_name, tokenizer.eod_id)

        token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
        for attr_name in token_names:
            if getattr(tokenizer, attr_name) is None:
                setattr(tokenizer, attr_name, "<|endoftext|>")

    additional_special_tokens = None
    if cfg.special_tokens:
        special_tokens = cfg.special_tokens.to_dict()
        additional_special_tokens = special_tokens.pop(
            "additional_special_tokens", None
        )
        lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
        for k, val in special_tokens.items():
            # check if new special token is not already in tokenizer and
            # is adapter training to make sure lora_modules_to_save is set
            # pylint: disable=too-many-boolean-expressions
            if (
                (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
                and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
                and cfg.adapter
                and (
                    not cfg.lora_modules_to_save
                    or not all(
                        x in cfg.lora_modules_to_save for x in lora_modules_to_save
                    )
                )
                and k != "pad_token"
            ):
                lora_modules_to_save = ", ".join(
                    [f"`{x}`" for x in lora_modules_to_save]
                )
                raise ValueError(
                    f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
                )

            tokenizer.add_special_tokens(
                {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
            )

        # If we add bos_token and eos_token, we need to update the post processor to
        # handle them correctly.
        # https://github.com/huggingface/transformers/pull/24132
        bos_or_eos_in_special_tokens = (
            "bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
        )
        if (
            tokenizer.__class__.__name__
            in (
                "LlamaTokenizerFast",
                "CodeLlamaTokenizerFast",
            )
            and bos_or_eos_in_special_tokens
        ):
            tokenizer.update_post_processor()

    if cfg.tokens:
        tokenizer.add_tokens(
            [
                AddedToken(token, rstrip=False, lstrip=False, normalized=False)
                for token in cfg.tokens
            ]
        )

    # Additional special tokens are a List, and need to be treated differently than regular special
    # tokens. We add them after we have called `add_tokens` in case these additional special tokens
    # are new tokens.
    #
    # Usage:
    #
    # ```py
    # special_tokens:
    #   additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
    # ```
    if additional_special_tokens is not None:
        tokenizer.add_special_tokens(
            {"additional_special_tokens": additional_special_tokens}
        )

    if is_main_process(use_environ=True):
        LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
@@ -654,4 +265,19 @@ def load_tokenizer(cfg):
        LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
        LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")

    if cfg.chat_template:
        chat_template_string = get_chat_template_from_config(
            cfg=cfg,
            tokenizer=tokenizer,
        )
        if cfg.default_system_message and cfg.chat_template == "chatml":
            chat_template_string = chat_template_string.replace(
                "You are a helpful assistant.", cfg.default_system_message
            )

        tokenizer.chat_template = chat_template_string
    else:
        LOG.info(
            "No Chat template selected. Consider adding a chat template for easier inference."
        )
    return tokenizer
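A hedged usage sketch for the `added_tokens_overrides` path that `modify_tokenizer_files` serves (the repo name and token ID are illustrative; the config keys are the ones read above, and a real config would carry more fields):

```python
# Illustrative only: rename a reserved added token by ID before training.
# The tokenizer repo and token ID 32002 are assumptions for this example.
from axolotl.loaders import load_tokenizer
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "tokenizer_config": "NousResearch/Llama-2-7b-hf",  # assumed repo
        "output_dir": "./outputs/example",
        "added_tokens_overrides": {32002: "<|tool_call|>"},
    }
)

# Files are rewritten under output_dir/tokenizer, then loaded from there.
tokenizer = load_tokenizer(cfg)
```

Per the docstring above, this only works for reserved tokens present in `added_tokens`, not for tokens already in the base vocab.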
@@ -67,10 +67,6 @@ class PromptTokenizingStrategy(abc.ABC):
             LOG.warning("Empty text requested for tokenization.")
             return empty

-        import ipdb
-
-        ipdb.set_trace()
-
         result = self.tokenizer(
             prompt,
             truncation=True,
src/axolotl/telemetry/__init__.py (new file, 0 lines)

src/axolotl/telemetry/callbacks.py (new file, 164 lines)

@@ -0,0 +1,164 @@
"""Trainer callbacks for reporting runtime metrics at regular intervals."""

import logging
import time

from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)

from axolotl.telemetry.manager import TelemetryManager
from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker

LOG = logging.getLogger(__name__)

TIME_SINCE_LAST = 30


class TelemetryCallback(TrainerCallback):
    """
    Trainer callback for tracking and reporting runtime metrics.

    This callback tracks training progress, runtime, and memory usage,
    sending telemetry at configurable intervals.
    """

    report_interval_steps: int = 100

    def __init__(self):
        """Initialize the metrics callback."""
        self.tracker = RuntimeMetricsTracker()
        self.telemetry_manager = TelemetryManager.get_instance()
        self.current_epoch = -1
        self.start_time = time.time()
        self.last_report_time = None
        self.last_report_step = 0

    # pylint: disable=unused-argument
    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Handle training start."""
        self.telemetry_manager.send_event(event_type="train-start")

    # pylint: disable=unused-argument
    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Handle training end."""
        # Send training completion event
        self.telemetry_manager.send_event(
            event_type="train-end",
            properties=self._extract_last_metrics(state)
            | self.tracker.metrics.to_dict(),
        )

    # pylint: disable=unused-argument
    def on_epoch_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Handle epoch start."""
        self.current_epoch += 1
        self.tracker.start_epoch(self.current_epoch)

    # pylint: disable=unused-argument
    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Handle epoch end."""
        self.tracker.end_epoch(self.current_epoch)

    # pylint: disable=unused-argument
    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Handle step end."""
        step = state.global_step
        self.tracker.update_step(step)

        # Check if we should report metrics
        should_report = (
            step % self.report_interval_steps == 0
            or step == 1  # Always report first step
            or step - self.last_report_step >= self.report_interval_steps
        )

        if should_report:
            current_time = time.time()
            if self.last_report_time is not None:
                time_since_last_report = current_time - self.last_report_time
            else:
                time_since_last_report = current_time - self.start_time
            steps_since_last_report = step - self.last_report_step

            # Only report if enough time has passed
            if (
                step == 1
                or time_since_last_report >= TIME_SINCE_LAST
                or steps_since_last_report >= self.report_interval_steps
            ):
                # Calculate steps per second for this interval
                if time_since_last_report > 0 and steps_since_last_report > 0:
                    steps_per_second = steps_since_last_report / time_since_last_report
                else:
                    steps_per_second = 0

                # Update memory metrics
                self.tracker.update_memory_metrics()

                # Prepare metrics to report
                metrics = self._extract_last_metrics(state) | {
                    "step": step,
                    "epoch": self.current_epoch,
                    "progress": state.epoch,  # Fractional epoch progress
                    "steps_per_second": steps_per_second,
                    "elapsed_time": current_time - self.start_time,
                    "time_since_last_report": time_since_last_report,
                }

                # Add memory metrics
                memory_metrics = self.tracker.get_memory_metrics()
                metrics.update({"memory": memory_metrics})

                # Send telemetry
                self.telemetry_manager.send_event(
                    event_type="train-progress", properties=metrics
                )

                # Update last report time and step
                self.last_report_time = current_time
                self.last_report_step = step

    def _extract_last_metrics(self, state: TrainerState) -> dict:
        """Extract last loss and learning_rate from log history."""
        if not state.log_history:
            return {"loss": 0, "learning_rate": 0}

        last_log = state.log_history[-1]
        return {
            "loss": last_log.get("loss", 0),
            "learning_rate": last_log.get("learning_rate", 0),
        }
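The dual threshold in `on_step_end` (report on the first step, after `TIME_SINCE_LAST` seconds, or after `report_interval_steps` steps) can be exercised in isolation; a distilled sketch of the same decision rule:

```python
# Distilled decision rule from on_step_end above.
TIME_SINCE_LAST = 30
REPORT_INTERVAL_STEPS = 100


def should_send(step: int, seconds_since_last: float, steps_since_last: int) -> bool:
    return (
        step == 1
        or seconds_since_last >= TIME_SINCE_LAST
        or steps_since_last >= REPORT_INTERVAL_STEPS
    )


assert should_send(1, 0.0, 1)         # first step always reports
assert should_send(250, 45.0, 50)     # enough wall-clock time has passed
assert not should_send(150, 5.0, 50)  # too soon on both axes
```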
src/axolotl/telemetry/errors.py (new file, 160 lines)

@@ -0,0 +1,160 @@
"""Telemetry utilities for exception and traceback information."""

import logging
import os
import re
import traceback
from functools import wraps
from inspect import getmodule
from typing import Any, Callable

from axolotl.telemetry.manager import TelemetryManager

LOG = logging.getLogger(__name__)

ERROR_HANDLED = False


def sanitize_stack_trace(stack_trace: str) -> str:
    """
    Remove personal information from stack trace messages while keeping Python package codepaths.

    This function identifies Python packages by looking for common patterns in virtual environment
    and site-packages directories, preserving the package path while removing user-specific paths.

    Args:
        stack_trace: The original stack trace string.

    Returns:
        A sanitized version of the stack trace with Python package paths preserved.
    """
    # Split the stack trace into lines to process each file path separately
    lines = stack_trace.split("\n")
    sanitized_lines = []

    # Regular expression to find file paths in the stack trace
    path_pattern = re.compile(r'(?:File ")(.*?)(?:")')

    # Regular expression to identify paths in site-packages or dist-packages
    # This matches path segments like "site-packages/package_name" or "dist-packages/package_name"
    site_packages_pattern = re.compile(
        r"(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
    )

    # Additional common virtual environment patterns
    venv_lib_pattern = re.compile(
        r"(?:lib|Lib)[/\\](?:python\d+(?:\.\d+)?[/\\])?(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
    )

    for line in lines:
        # Check if this line contains a file path
        path_match = path_pattern.search(line)

        if path_match:
            full_path = path_match.group(1)
            sanitized_path = ""

            # Try to match site-packages pattern
            site_packages_match = site_packages_pattern.search(full_path)
            venv_lib_match = venv_lib_pattern.search(full_path)

            if site_packages_match:
                # Find the index where the matched pattern starts
                idx = full_path.find("site-packages")
                if idx == -1:
                    idx = full_path.find("dist-packages")

                # Keep from 'site-packages' onward
                if idx >= 0:
                    sanitized_path = full_path[idx:]
            elif venv_lib_match:
                # For other virtual environment patterns, find the package directory
                match_idx = venv_lib_match.start(1)
                if match_idx > 0:
                    # Keep from the package name onward
                    package_name = venv_lib_match.group(1)
                    idx = full_path.rfind(
                        package_name, 0, match_idx + len(package_name)
                    )
                    if idx >= 0:
                        sanitized_path = full_path[idx:]

            # If we couldn't identify a package pattern but path contains 'axolotl'
            elif "axolotl" in full_path:
                idx = full_path.rfind("axolotl")
                if idx >= 0:
                    sanitized_path = full_path[idx:]

            # Apply the sanitization to the line
            if sanitized_path:
                line = line.replace(full_path, sanitized_path)
            else:
                # If we couldn't identify a package pattern, just keep the filename
                filename = os.path.basename(full_path)
                if filename:
                    line = line.replace(full_path, filename)
                else:
                    line = line.replace(full_path, "")

        sanitized_lines.append(line)

    return "\n".join(sanitized_lines)
def send_errors(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to send exception info in a function. If an exception is raised, we send
|
||||
telemetry containing the stack trace and error message.
|
||||
|
||||
If an error occurs in a decorated function that is called by another decorated
|
||||
function, we'll only send telemetry corresponding to the lower-level function.
|
||||
|
||||
Args:
|
||||
func: Function to decorate.
|
||||
|
||||
Returns:
|
||||
Decorated function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Any:
|
||||
telemetry_manager = TelemetryManager.get_instance()
|
||||
|
||||
if not telemetry_manager.enabled:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as exception:
|
||||
# Only track if we're not already handling an error. This prevents us from
|
||||
# capturing an error more than once in nested decorated function calls.
|
||||
global ERROR_HANDLED # pylint: disable=global-statement
|
||||
if not ERROR_HANDLED:
|
||||
ERROR_HANDLED = True
|
||||
|
||||
# Get function module path
|
||||
module = getmodule(func)
|
||||
module_path = (
|
||||
f"{module.__name__}.{func.__name__}" if module else func.__name__
|
||||
)
|
||||
|
||||
# Get stack trace
|
||||
stack_trace = "".join(
|
||||
traceback.format_exception(
|
||||
type(exception), exception, exception.__traceback__
|
||||
)
|
||||
)
|
||||
stack_trace = sanitize_stack_trace(stack_trace)
|
||||
|
||||
# Send error telemetry
|
||||
telemetry_manager.send_event(
|
||||
event_type=f"{module_path}-error",
|
||||
properties={
|
||||
"exception": str(exception),
|
||||
"stack_trace": stack_trace,
|
||||
},
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
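For illustration (not part of this changeset): a minimal example of what `sanitize_stack_trace` keeps and drops. The home-directory prefix below is a made-up example path.

```python
from axolotl.telemetry.errors import sanitize_stack_trace

line = '  File "/home/alice/.venv/lib/python3.11/site-packages/axolotl/train.py", line 10, in train'
print(sanitize_stack_trace(line))
# ->   File "site-packages/axolotl/train.py", line 10, in train
```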
417  src/axolotl/telemetry/manager.py  Normal file
@@ -0,0 +1,417 @@
"""Telemetry manager and associated utilities."""
|
||||
|
||||
import atexit
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import posthog
|
||||
import psutil
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
POSTHOG_HOST = "https://app.posthog.com"
|
||||
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
|
||||
|
||||
OPT_IN_WARNING_SLEEP_SECONDS = 10
|
||||
OPT_IN_WARNING = (
|
||||
"\nTelemetry is currently disabled by default. If you'd like to help improve "
|
||||
"Axolotl, consider enabling it by setting AXOLOTL_DO_NOT_TRACK=0 in your environment.\n\n"
|
||||
"Telemetry data helps us understand:\n"
|
||||
"- Which features are most used\n"
|
||||
"- What hardware configurations to prioritize\n"
|
||||
"- Where users encounter errors\n\n"
|
||||
"Personally identifiable information (PII) is not collected.\n\n"
|
||||
"To remove this warning, explicitly set AXOLOTL_DO_NOT_TRACK=0 (enable telemetry) "
|
||||
"or AXOLOTL_DO_NOT_TRACK=1 (explicitly disable telemetry).\n\n"
|
||||
"Note: Telemetry will move to an opt-out in a later release.\n\n"
|
||||
"For details, see: https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html\n\n"
|
||||
f"Sleeping for {OPT_IN_WARNING_SLEEP_SECONDS}s..."
|
||||
)
|
||||
|
||||
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
|
||||
|
||||
# NOTE: Need to keep these up to date with any config schema changes
|
||||
FIELDS_TO_REDACT = {
|
||||
"base_model",
|
||||
"tokenizer_config",
|
||||
"base_model_config",
|
||||
"pretraining_dataset", # NOTE: this field may be a string or a dictionary
|
||||
"resume_from_checkpoint",
|
||||
"hub_model_id",
|
||||
}
|
||||
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"}
|
||||
PATH_INDICATORS = {"path", "dir"}
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
RELEVANT_PACKAGES = {
|
||||
"torch",
|
||||
"transformers",
|
||||
"trl",
|
||||
"datasets",
|
||||
"peft",
|
||||
"bitsandbytes",
|
||||
"accelerate",
|
||||
"optimum",
|
||||
"deepspeed",
|
||||
"ray",
|
||||
"axolotl",
|
||||
"triton",
|
||||
"mamba-ssm",
|
||||
"flash-attn",
|
||||
"xformers",
|
||||
"autoawq",
|
||||
"tokenizers",
|
||||
"sentencepiece",
|
||||
"torchao",
|
||||
"lm_eval",
|
||||
}
|
||||


def is_main_process() -> bool:
    """
    Check whether we're running in the main process.

    Note:
        We use this function instead of `torch.utils.distributed.is_main_process`,
        since that causes issues with DeepSpeed world_size. This function avoids the
        issue by checking env vars that are set by various launchers.

    Returns:
        Whether we're running in the main process.
    """
    # If PyTorch distributed is already initialized, use it
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank() == 0

    # Otherwise check environment variables for global rank
    # NOTE: need to verify this in SLURM / OpenMPI environments
    global_rank = int(
        os.environ.get(
            "RANK",
            os.environ.get(
                "GLOBAL_RANK",
                os.environ.get(
                    "SLURM_PROCID",
                    os.environ.get(
                        "OMPI_COMM_WORLD_RANK",
                        "0",
                    ),
                ),
            ),
        )
    )

    return global_rank == 0


class TelemetryManager:
    """Manages telemetry collection and transmission"""

    _instance = None
    _initialized = False

    def __new__(cls):
        """
        Telemetry manager constructor. Creates the singleton instance of this class if
        it doesn't already exist.
        """
        if cls._instance is None:
            cls._instance = super(TelemetryManager, cls).__new__(cls)
            cls._instance._initialized = False

        return cls._instance

    def __init__(self):
        """Telemetry manager initializer"""
        if self._initialized:
            return

        self.enabled = self._check_telemetry_enabled()

        if self.enabled:
            self.run_id = str(uuid.uuid4())
            self.whitelist = self._load_whitelist()

            try:
                self.system_info = self._get_system_info()
            except Exception as e:  # pylint: disable=broad-exception-caught
                LOG.warning(f"Error during system info collection: {e}")
                self.system_info = None

            self._init_posthog()

            # Register shutdown method to flush posthog telemetry
            atexit.register(self.shutdown)

        self._initialized = True

    @classmethod
    def get_instance(cls) -> "TelemetryManager":
        """Return the singleton instance, creating it if it doesn't exist yet."""
        if cls._instance is None:
            cls._instance = TelemetryManager()

        return cls._instance

    def _check_telemetry_enabled(self) -> bool:
        """
        Check if telemetry is enabled based on environment variables. We also check
        whether this is the main process (for the distributed setting and to avoid
        sending duplicate PostHog events per GPU).

        Note: This is disabled by default on an opt-in basis. Set
        `AXOLOTL_DO_NOT_TRACK=0` to enable telemetry. We plan to move to an opt-out
        model in a later release. For more details, see
        https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.

        Returns:
            Boolean denoting whether telemetry is enabled or not.
        """
        # Parse relevant env vars
        axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
        do_not_track = os.getenv("DO_NOT_TRACK")

        # Default to disabled (opt-in model for initial release)
        if axolotl_do_not_track is None or axolotl_do_not_track.lower() not in (
            "0",
            "1",
            "false",
            "true",
        ):
            # Print opt-in info message for main process only
            if is_main_process():
                LOG.warning(OPT_IN_WARNING)
                time.sleep(OPT_IN_WARNING_SLEEP_SECONDS)

            return False

        # Only rank 0 will send telemetry
        if not is_main_process():
            return False

        if do_not_track is None:
            do_not_track = "0"

        # Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
        enabled = axolotl_do_not_track.lower() not in (
            "1",
            "true",
        ) and do_not_track.lower() not in ("1", "true")

        return enabled

    def _load_whitelist(self) -> dict:
        """Load HuggingFace Hub organization whitelist"""
        with open(WHITELIST_PATH, encoding="utf-8") as f:
            whitelist = yaml.safe_load(f)

        # Lowercase org strings, since model names are case-insensitive
        whitelist["organizations"] = {
            org.lower() for org in whitelist["organizations"]
        }

        return whitelist

    def _is_whitelisted(self, value: str) -> bool:
        """
        Check if model / dataset / etc. org is in whitelist.

        Args:
            value: Value for one of `axolotl.telemetry.manager.FIELDS_WITH_ORGS`
                ("base_model", etc.).

        Returns:
            Boolean indicating whitelist membership.
        """
        # NOTE: This membership-checking logic can be improved.
        # What happens when a local model path matches a whitelisted org?
        parts = value.split("/")
        if len(parts) < 2:
            return False
        org = parts[0]
        whitelisted = org.lower() in self.whitelist["organizations"]

        return whitelisted

    def _init_posthog(self):
        """Initialize PostHog client"""
        posthog.host = POSTHOG_HOST
        posthog.project_api_key = POSTHOG_WRITE_KEY

    def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]:
        """
        Redact properties to remove any paths, so as to avoid inadvertently collecting
        private or personally identifiable information (PII). We also remove
        information related to Wandb, MLflow, etc. configuration.

        Args:
            properties: Dictionary of properties to redact.

        Returns:
            Properties dictionary with redaction applied.
        """
        if not properties:
            return {}

        def redact_value(value: Any, key: str = "") -> Any:
            """Recursively sanitize values, redacting those with path-like keys"""
            if isinstance(key, str) and isinstance(value, str):
                # Other redaction special cases
                if (
                    key in FIELDS_TO_REDACT
                    or any(prefix in key for prefix in PREFIXES_TO_REDACT)
                    or any(indicator in key.lower() for indicator in PATH_INDICATORS)
                ):
                    # Fields with whitelisted orgs don't need to be redacted
                    if not self._is_whitelisted(value):
                        return "[REDACTED]"

            # Handle nested values
            if isinstance(value, dict):
                return {k: redact_value(v, k) for k, v in value.items()}
            if isinstance(value, list):
                return [redact_value(item) for item in value]

            return value

        # Create new dict with redacted values
        redacted = {k: redact_value(v, k) for k, v in properties.items()}

        return redacted

    def _get_system_info(self) -> dict[str, Any]:
        """Collect system information for various hardware accelerators"""
        gpu_info = []
        accelerator_type = "none"

        # NVIDIA GPUs
        if torch.cuda.is_available():
            accelerator_type = "cuda"
            for i in range(torch.cuda.device_count()):
                gpu_info.append(
                    {
                        "name": torch.cuda.get_device_name(i),
                        "memory": torch.cuda.get_device_properties(i).total_memory,
                    }
                )

        # AMD GPUs
        elif hasattr(torch, "hip") and torch.hip.is_available():
            accelerator_type = "hip"
            for i in range(torch.hip.device_count()):
                gpu_info.append(
                    {
                        "name": torch.hip.get_device_name(i),
                        "memory": (
                            torch.hip.get_device_properties(i).total_memory
                            if hasattr(torch.hip, "get_device_properties")
                            else None
                        ),
                    }
                )

        # Apple Silicon
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            accelerator_type = "mps"
            gpu_info.append(
                {
                    "name": "Apple Silicon",
                    # NOTE: this is memory allocated to this process, not total memory
                    "memory": torch.mps.driver_allocated_memory(),
                }
            )

        # Intel GPUs
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            accelerator_type = "xpu"
            for i in range(torch.xpu.device_count()):
                memory = None
                if hasattr(torch.xpu, "get_device_properties"):
                    memory = torch.xpu.get_device_properties(i).total_memory

                gpu_info.append(
                    {
                        "name": torch.xpu.get_device_name(i),
                        "memory": memory,
                    }
                )

        # NPUs
        elif hasattr(torch, "npu") and torch.npu.is_available():
            accelerator_type = "npu"
            for i in range(torch.npu.device_count()):
                memory = None
                if hasattr(torch.npu, "get_device_properties"):
                    memory = torch.npu.get_device_properties(i).total_memory

                gpu_info.append(
                    {
                        "name": torch.npu.get_device_name(i),
                        "memory": memory,
                    }
                )

        # Get relevant package versions
        installed_packages = {}
        for package in RELEVANT_PACKAGES:
            try:
                version = importlib.metadata.version(package)
                installed_packages[f"{package}_version"] = version
            except importlib.metadata.PackageNotFoundError:
                pass

        return {
            "os": platform.system(),
            "python_version": platform.python_version(),
            "cpu_count": psutil.cpu_count(),
            "memory_total": psutil.virtual_memory().total,
            "accelerator_type": accelerator_type,
            "accelerator_count": len(gpu_info),
            "accelerator_info": gpu_info,
            **installed_packages,
        }

    def send_event(self, event_type: str, properties: dict[str, Any] | None = None):
        """Send a telemetry event"""
        if not self.enabled:
            return

        if properties is None:
            properties = {}

        # Sanitize properties to remove PII
        properties = self._redact_paths(properties)

        # Wrap PostHog errors in try / except to not raise errors during Axolotl usage
        try:
            # Send event via PostHog
            posthog.capture(
                distinct_id=self.run_id,
                event=event_type,
                properties=properties,
                disable_geoip=True,
            )
        except Exception as e:  # pylint: disable=broad-exception-caught
            LOG.warning(f"Failed to send telemetry event: {e}")

        # Additionally, send system info telemetry when loading config.
        # NOTE: Is this the best place for this?
        if event_type == "config-loaded":
            self.send_system_info()

    def send_system_info(self):
        """Helper method for sending system info"""
        if self.system_info is not None:
            self.send_event(event_type="system-info", properties=self.system_info)

    def shutdown(self):
        """Ensure all queued events are processed before shutdown"""
        if self.enabled:
            posthog.flush()
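For illustration (not part of this changeset): a minimal sketch of opting in and sending an event. `AXOLOTL_DO_NOT_TRACK` must be set before the first `get_instance()` call, since the enabled check runs once at initialization. The property values here are hypothetical.

```python
import os

os.environ["AXOLOTL_DO_NOT_TRACK"] = "0"  # explicit opt-in

from axolotl.telemetry.manager import TelemetryManager

manager = TelemetryManager.get_instance()
# _redact_paths runs before sending: the whitelisted org below is kept,
# while the "output_dir" value is redacted because its key contains "dir"
manager.send_event(
    event_type="config-loaded",  # this event type also triggers send_system_info()
    properties={"base_model": "meta-llama/Llama-2-7b-hf", "output_dir": "/tmp/run"},
)
```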
209  src/axolotl/telemetry/runtime_metrics.py  Normal file
@@ -0,0 +1,209 @@
"""Telemetry utilities for runtime and memory metrics."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeMetrics:
|
||||
"""Container for runtime metrics to be tracked throughout training."""
|
||||
|
||||
# Timing metrics
|
||||
start_time: float
|
||||
epoch_start_times: dict[int, float] = field(init=False)
|
||||
epoch_end_times: dict[int, float] = field(init=False)
|
||||
|
||||
# Memory metrics
|
||||
peak_cpu_memory: int = 0
|
||||
peak_gpu_memory: dict[int, int] = field(init=False)
|
||||
|
||||
# Progress metrics
|
||||
total_steps: int = 0
|
||||
current_epoch: int = 0
|
||||
current_step: int = 0
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize empty metric mappings."""
|
||||
self.epoch_start_times = {}
|
||||
self.epoch_end_times = {}
|
||||
self.peak_gpu_memory = {}
|
||||
|
||||
@property
|
||||
def elapsed_time(self) -> float:
|
||||
"""Calculate total elapsed time in seconds."""
|
||||
return time.time() - self.start_time
|
||||
|
||||
def epoch_time(self, epoch: int) -> float | None:
|
||||
"""Calculate time taken for a specific epoch in seconds."""
|
||||
if epoch in self.epoch_start_times and epoch in self.epoch_end_times:
|
||||
return self.epoch_end_times[epoch] - self.epoch_start_times[epoch]
|
||||
|
||||
return None
|
||||
|
||||
def average_epoch_time(self) -> float | None:
|
||||
"""Calculate average time per epoch in seconds."""
|
||||
completed_epochs = [
|
||||
epoch for epoch in self.epoch_start_times if epoch in self.epoch_end_times
|
||||
]
|
||||
if not completed_epochs:
|
||||
return None
|
||||
|
||||
total_time = 0.0
|
||||
for epoch in completed_epochs:
|
||||
epoch_time = self.epoch_time(epoch)
|
||||
if epoch_time is not None: # Check to avoid mypy warning
|
||||
total_time += epoch_time
|
||||
|
||||
return total_time / len(completed_epochs)
|
||||
|
||||
def steps_per_second(self) -> float | None:
|
||||
"""Calculate average steps per second across all training."""
|
||||
if self.total_steps == 0 or self.elapsed_time == 0:
|
||||
return None
|
||||
|
||||
return self.total_steps / self.elapsed_time
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert metrics to a dictionary for telemetry reporting."""
|
||||
metrics = {
|
||||
"total_time_seconds": self.elapsed_time,
|
||||
"total_steps": self.total_steps,
|
||||
"steps_per_second": self.steps_per_second(),
|
||||
"epochs_completed": len(
|
||||
[
|
||||
epoch
|
||||
for epoch in self.epoch_start_times
|
||||
if epoch in self.epoch_end_times
|
||||
]
|
||||
),
|
||||
"peak_cpu_memory_bytes": self.peak_cpu_memory,
|
||||
}
|
||||
|
||||
# Add per-epoch timing if available
|
||||
epoch_times: dict[str, float] = {}
|
||||
for epoch in sorted(self.epoch_end_times.keys()):
|
||||
time_taken = self.epoch_time(epoch)
|
||||
if time_taken is not None:
|
||||
epoch_times[f"epoch_{epoch}_seconds"] = time_taken
|
||||
|
||||
if epoch_times:
|
||||
metrics["epoch_times"] = epoch_times # type: ignore
|
||||
metrics["average_epoch_time_seconds"] = self.average_epoch_time()
|
||||
|
||||
# Add GPU memory metrics if available
|
||||
if self.peak_gpu_memory:
|
||||
gpu_metrics: dict[str, int] = {}
|
||||
for gpu_id, memory in self.peak_gpu_memory.items():
|
||||
gpu_metrics[f"gpu_{gpu_id}_peak_memory_bytes"] = memory
|
||||
metrics["gpu_memory"] = gpu_metrics # type: ignore
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
class RuntimeMetricsTracker:
|
||||
"""Tracker for runtime metrics during training."""
|
||||
|
||||
update_interval = 100
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the runtime metrics tracker."""
|
||||
self.metrics = RuntimeMetrics(start_time=time.time())
|
||||
self.telemetry_manager = TelemetryManager.get_instance()
|
||||
|
||||
def start_epoch(self, epoch: int):
|
||||
"""Record the start of a new epoch."""
|
||||
self.metrics.current_epoch = epoch
|
||||
self.metrics.epoch_start_times[epoch] = time.time()
|
||||
self.update_memory_metrics()
|
||||
|
||||
def end_epoch(self, epoch: int):
|
||||
"""Record the end of an epoch."""
|
||||
self.metrics.epoch_end_times[epoch] = time.time()
|
||||
|
||||
def update_step(self, step: int):
|
||||
"""Update the current step count."""
|
||||
self.metrics.current_step = step
|
||||
self.metrics.total_steps += 1
|
||||
|
||||
# Periodically update memory metrics
|
||||
if step % self.update_interval == 0:
|
||||
self.update_memory_metrics()
|
||||
|
||||
def _get_allocated_memory(self) -> dict[int, int]:
|
||||
"""
|
||||
Helper function for getting accelerator-agnostic allocated memory.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping device IDs to allocated memory in bytes
|
||||
"""
|
||||
memory_used: dict[int, int] = {}
|
||||
|
||||
# NVIDIA GPUs
|
||||
if torch.cuda.is_available():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
memory_used[i] = torch.cuda.memory_allocated(i)
|
||||
|
||||
# AMD GPUs
|
||||
elif hasattr(torch, "hip") and torch.hip.is_available():
|
||||
for i in range(torch.hip.device_count()):
|
||||
if hasattr(torch.hip, "memory_allocated"):
|
||||
memory_used[i] = torch.hip.memory_allocated(i)
|
||||
|
||||
# Apple Silicon
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
# MPS doesn't have per-device memory stats since there's only one device
|
||||
if hasattr(torch.mps, "current_allocated_memory"):
|
||||
memory_used[0] = torch.mps.current_allocated_memory()
|
||||
|
||||
# Intel GPUs
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
for i in range(torch.xpu.device_count()):
|
||||
if hasattr(torch.xpu, "memory_allocated"):
|
||||
memory_used[i] = torch.xpu.memory_allocated(i)
|
||||
|
||||
# NPUs
|
||||
elif hasattr(torch, "npu") and torch.npu.is_available():
|
||||
for i in range(torch.npu.device_count()):
|
||||
if hasattr(torch.npu, "memory_allocated"):
|
||||
memory_used[i] = torch.npu.memory_allocated(i)
|
||||
|
||||
return memory_used
|
||||
|
||||
def update_memory_metrics(self):
|
||||
"""Update peak memory usage metrics."""
|
||||
# CPU memory
|
||||
cpu_memory = psutil.Process().memory_info().rss
|
||||
self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory)
|
||||
|
||||
# GPU memory (if available)
|
||||
memory_used = self._get_allocated_memory()
|
||||
for i, memory in memory_used.items():
|
||||
self.metrics.peak_gpu_memory[i] = max(
|
||||
self.metrics.peak_gpu_memory.get(i, 0), memory
|
||||
)
|
||||
|
||||
def get_memory_metrics(self) -> dict[str, Any]:
|
||||
"""Get the current memory metrics as a dictionary."""
|
||||
memory_metrics = {
|
||||
"cpu_memory_bytes": psutil.Process().memory_info().rss,
|
||||
"peak_cpu_memory_bytes": self.metrics.peak_cpu_memory,
|
||||
}
|
||||
|
||||
# GPU memory (if available)
|
||||
memory_used = self._get_allocated_memory()
|
||||
for i, memory in memory_used.items():
|
||||
memory_metrics[f"gpu_{i}_memory_bytes"] = memory
|
||||
memory_metrics[f"gpu_{i}_peak_memory_bytes"] = (
|
||||
self.metrics.peak_gpu_memory.get(i, 0)
|
||||
)
|
||||
|
||||
return memory_metrics
|
||||
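For illustration (not part of this changeset): driving the tracker directly, with telemetry explicitly disabled so the demo has no side effects. Memory metrics refresh on each epoch start and every `update_interval` (100) steps.

```python
import os

os.environ["AXOLOTL_DO_NOT_TRACK"] = "1"  # keep the demo offline

from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker

tracker = RuntimeMetricsTracker()
tracker.start_epoch(0)
for step in range(1, 201):
    tracker.update_step(step)
tracker.end_epoch(0)

summary = tracker.metrics.to_dict()
print(summary["total_steps"], summary["epochs_completed"])  # 200 1
```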
17  src/axolotl/telemetry/whitelist.yaml  Normal file
@@ -0,0 +1,17 @@
organizations:
  - "axolotl-ai-co"
  - "meta-llama"
  - "huggingface"
  - "nvidia"
  - "facebook"
  - "google"
  - "microsoft"
  - "deepseek-ai"
  - "HuggingFaceTB"
  - "mistralai"
  - "Qwen"
  - "unsloth"
  - "NousResearch"
  - "allenai"
  - "amd"
  - "tiiuae"
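For illustration (not part of this changeset): how this whitelist interacts with redaction in `TelemetryManager`. Only the organization prefix of a Hub-style `org/name` string is checked, case-insensitively; the model names below are hypothetical examples.

```python
# "Qwen/Some-Model"          -> kept (org "qwen" is whitelisted)
# "my-company/private-model" -> "[REDACTED]" (org not in whitelist)
# "/data/checkpoints/model"  -> "[REDACTED]" (no Hub org prefix)
```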
@@ -32,6 +32,8 @@ from axolotl.loaders import (
    load_processor,
    load_tokenizer,
)
from axolotl.telemetry.errors import send_errors
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
@@ -47,6 +49,9 @@ except ImportError:

LOG = get_logger(__name__)

TELEMETRY_MANAGER = TelemetryManager.get_instance()
PLUGIN_MANAGER = PluginManager.get_instance()


def setup_model_and_tokenizer(
    cfg: DictDefault,
@@ -64,7 +69,10 @@ def setup_model_and_tokenizer(
        `None`), and processor (if multimodal, else `None`).
    """
    # Load tokenizer
    LOG.debug(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
    LOG.debug(
        f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
        main_process_only=True,
    )
    tokenizer = load_tokenizer(cfg)

    # Load processor for multimodal models if needed
@@ -83,6 +91,14 @@ def setup_model_and_tokenizer(
    if model.generation_config is not None:
        model.generation_config.do_sample = True

    TELEMETRY_MANAGER.send_event(
        event_type="model-load", properties=model.config.to_dict()
    )
    if peft_config:
        TELEMETRY_MANAGER.send_event(
            event_type="peft-config-load", properties=peft_config.to_dict()
        )

    # Apply freezing if specified
    if cfg.unfrozen_parameters:
        freeze_layers_except(model, cfg.unfrozen_parameters)
@@ -517,6 +533,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
        model_ref=model_ref,
        peft_config=peft_config,
    )
    PLUGIN_MANAGER.post_trainer_create(cfg, trainer)

    return (
        trainer,
@@ -527,6 +544,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
    )


@send_errors
def train(
    cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
@@ -551,8 +569,11 @@ def train(
        processor,
    ) = setup_model_and_trainer(cfg, dataset_meta)

    plugin_manager = PluginManager.get_instance()
    plugin_manager.post_trainer_create(cfg, trainer)
    # Determine if we need to resume from a checkpoint
    resume_from_checkpoint = determine_resume_checkpoint(cfg)

    # Configuration for saving
    safe_serialization = cfg.save_safetensors is True

    # Handle untrained tokens if configured
    safe_serialization = cfg.save_safetensors is True
@@ -567,7 +588,6 @@ def train(
    setup_model_card(cfg)

    # Execute the training
    resume_from_checkpoint = determine_resume_checkpoint(cfg)
    execute_training(cfg, trainer, resume_from_checkpoint)

    # Save the trained model and cleanup
@@ -575,7 +595,6 @@ def train(
    create_model_card(cfg, trainer)
    if not cfg.use_ray:
        cleanup_distributed()

    plugin_manager.post_train(cfg, model)
    PLUGIN_MANAGER.post_train(cfg, model)

    return model, tokenizer, trainer

@@ -486,10 +486,6 @@ def get_dataset_wrapper(
        f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
    )

    import ipdb

    ipdb.set_trace()

    if (
        isinstance(dataset, Dataset)
        and "input_ids" in dataset.features

@@ -1259,7 +1259,7 @@ class AxolotlInputConfig(


class AxolotlConfigWCapabilities(AxolotlInputConfig):
    """wrapper to valdiate gpu capabilities with the configured options"""
    """Wrapper to validate GPU capabilities with the config options"""

    capabilities: GPUCapabilities
    env_capabilities: EnvCapabilities

@@ -1,6 +1,4 @@
"""
shared pytest fixtures
"""
"""Shared pytest fixtures"""

import functools
import importlib
@@ -559,3 +557,9 @@ def test_load_fixtures(
    download_llama2_model_fixture,
):
    pass


@pytest.fixture(autouse=True)
def disable_telemetry(monkeypatch):
    monkeypatch.setenv("AXOLOTL_DO_NOT_TRACK", "1")
    yield

0  tests/telemetry/__init__.py  Normal file
9  tests/telemetry/conftest.py  Normal file
@@ -0,0 +1,9 @@
"""Shared pytest fixtures for telemetry tests."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_telemetry(monkeypatch):
|
||||
monkeypatch.delenv("AXOLOTL_DO_NOT_TRACK", raising=False)
|
||||
yield
|
||||
373  tests/telemetry/test_callbacks.py  Normal file
@@ -0,0 +1,373 @@
"""Tests for telemetry callback module."""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from transformers import TrainerControl, TrainerState, TrainingArguments
|
||||
|
||||
from axolotl.telemetry.callbacks import TIME_SINCE_LAST, TelemetryCallback
|
||||
|
||||
|
||||
def calc_expected_metrics(step, last_step, current_time, last_time, start_time=900.0):
|
||||
"""Calculate expected metrics values for tests"""
|
||||
time_diff = current_time - last_time
|
||||
step_diff = step - last_step
|
||||
return {
|
||||
"steps_per_second": (
|
||||
step_diff / time_diff if time_diff > 0 and step_diff > 0 else 0
|
||||
),
|
||||
"time_since_last_report": time_diff,
|
||||
"elapsed_time": current_time - start_time,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_time():
|
||||
"""Mock time.time() to have predictable values in tests"""
|
||||
with patch("axolotl.telemetry.callbacks.time") as mock_time:
|
||||
mock_time.time.return_value = 1000.0
|
||||
yield mock_time
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_telemetry_manager():
|
||||
"""Create a mock TelemetryManager"""
|
||||
with patch("axolotl.telemetry.callbacks.TelemetryManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
yield mock_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runtime_metrics_tracker():
|
||||
"""Create a mock RuntimeMetricsTracker"""
|
||||
with patch(
|
||||
"axolotl.telemetry.callbacks.RuntimeMetricsTracker"
|
||||
) as mock_tracker_class:
|
||||
mock_tracker = MagicMock()
|
||||
# Set up metrics property on the tracker
|
||||
mock_metrics = MagicMock()
|
||||
mock_metrics.to_dict.return_value = {
|
||||
"total_steps": 100,
|
||||
"peak_cpu_memory_bytes": 1024,
|
||||
}
|
||||
mock_tracker.metrics = mock_metrics
|
||||
|
||||
# Make the constructor return our mock
|
||||
mock_tracker_class.return_value = mock_tracker
|
||||
yield mock_tracker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def training_args():
|
||||
"""Create a minimal TrainingArguments instance"""
|
||||
return TrainingArguments(output_dir="./output")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trainer_state():
|
||||
"""Create a mock TrainerState"""
|
||||
state = MagicMock(spec=TrainerState)
|
||||
state.global_step = 10
|
||||
state.epoch = 0.5 # halfway through first epoch
|
||||
state.log_history = [{"loss": 2.5, "learning_rate": 5e-5}]
|
||||
return state
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trainer_control():
|
||||
"""Create a mock TrainerControl"""
|
||||
return MagicMock(spec=TrainerControl)
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@pytest.fixture
|
||||
def callback(mock_telemetry_manager, mock_runtime_metrics_tracker):
|
||||
"""Create a TelemetryCallback instance with mocked dependencies"""
|
||||
return TelemetryCallback()
|
||||
|
||||
|
||||
class TestTelemetryCallback:
|
||||
"""Tests for the TelemetryCallback class."""
|
||||
|
||||
def test_initialization(self, callback, mock_runtime_metrics_tracker):
|
||||
"""Test callback initialization."""
|
||||
assert callback.current_epoch == -1
|
||||
assert callback.tracker == mock_runtime_metrics_tracker
|
||||
assert callback.last_report_step == 0
|
||||
assert hasattr(callback, "start_time")
|
||||
assert hasattr(callback, "last_report_time")
|
||||
assert callback.report_interval_steps == 100
|
||||
|
||||
def test_on_train_begin(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_train_begin sends expected event."""
|
||||
callback.on_train_begin(training_args, trainer_state, trainer_control)
|
||||
|
||||
mock_telemetry_manager.send_event.assert_called_once_with(
|
||||
event_type="train-start"
|
||||
)
|
||||
|
||||
def test_on_train_end(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_train_end sends expected event with metrics."""
|
||||
callback.on_train_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
|
||||
assert call_args["event_type"] == "train-end"
|
||||
assert "loss" in call_args["properties"]
|
||||
assert call_args["properties"]["loss"] == 2.5
|
||||
assert "learning_rate" in call_args["properties"]
|
||||
assert call_args["properties"]["learning_rate"] == 5e-5
|
||||
|
||||
# Check that metrics from RuntimeMetricsTracker are included
|
||||
assert "total_steps" in call_args["properties"]
|
||||
assert call_args["properties"]["total_steps"] == 100
|
||||
assert "peak_cpu_memory_bytes" in call_args["properties"]
|
||||
assert call_args["properties"]["peak_cpu_memory_bytes"] == 1024
|
||||
|
||||
def test_on_epoch_begin(
|
||||
self,
|
||||
callback,
|
||||
mock_runtime_metrics_tracker,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_epoch_begin updates epoch counter and calls tracker."""
|
||||
initial_epoch = callback.current_epoch
|
||||
|
||||
callback.on_epoch_begin(training_args, trainer_state, trainer_control)
|
||||
|
||||
assert callback.current_epoch == initial_epoch + 1
|
||||
mock_runtime_metrics_tracker.start_epoch.assert_called_once_with(
|
||||
initial_epoch + 1
|
||||
)
|
||||
|
||||
def test_on_epoch_end(
|
||||
self,
|
||||
callback,
|
||||
mock_runtime_metrics_tracker,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_epoch_end calls tracker."""
|
||||
# Set current epoch
|
||||
callback.current_epoch = 2
|
||||
|
||||
callback.on_epoch_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
mock_runtime_metrics_tracker.end_epoch.assert_called_once_with(2)
|
||||
|
||||
def test_on_step_end_no_report(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end updates tracker but doesn't report if criteria not met."""
|
||||
# Set up state to avoid reporting
|
||||
trainer_state.global_step = 42 # Not divisible by report_interval_steps
|
||||
callback.last_report_step = 41 # Just 1 step since last report
|
||||
callback.last_report_time = time.time() # Just now
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should update tracker
|
||||
mock_runtime_metrics_tracker.update_step.assert_called_once_with(42)
|
||||
|
||||
# Should not send telemetry
|
||||
mock_telemetry_manager.send_event.assert_not_called()
|
||||
|
||||
# Should not update last report time/step
|
||||
assert callback.last_report_step == 41
|
||||
|
||||
def test_on_step_end_report_interval_steps(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker,
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end reports when step interval is reached."""
|
||||
# Set up state with clear values
|
||||
current_step = 100 # Exactly matches report_interval_steps
|
||||
last_step = 0
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
time_diff = current_time - start_time # 100 seconds
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
callback.report_interval_steps = 100
|
||||
callback.last_report_step = last_step
|
||||
callback.start_time = start_time
|
||||
callback.last_report_time = start_time
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should update tracker
|
||||
mock_runtime_metrics_tracker.update_step.assert_called_once_with(current_step)
|
||||
mock_runtime_metrics_tracker.update_memory_metrics.assert_called_once()
|
||||
|
||||
# Should send telemetry
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
assert call_args["event_type"] == "train-progress"
|
||||
|
||||
# Properties should include expected values
|
||||
props = call_args["properties"]
|
||||
assert props["step"] == current_step
|
||||
assert props["elapsed_time"] == time_diff # 1000 - 900 = 100
|
||||
assert props["time_since_last_report"] == time_diff # 1000 - 900 = 100
|
||||
assert props["steps_per_second"] == 1.0 # 100 steps / 100 seconds
|
||||
|
||||
# Should update last report time/step
|
||||
assert callback.last_report_step == current_step
|
||||
assert callback.last_report_time == current_time
|
||||
|
||||
def test_on_step_end_report_time_elapsed(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end reports when enough time has elapsed."""
|
||||
# Set up state with clear values
|
||||
current_step = 120
|
||||
last_step = 10
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
time_diff = TIME_SINCE_LAST + 1 # Just over the threshold
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
callback.report_interval_steps = 100
|
||||
callback.last_report_step = last_step
|
||||
callback.start_time = start_time
|
||||
callback.last_report_time = current_time - time_diff
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should send telemetry
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Properties should include expected values
|
||||
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
|
||||
expected_metrics = calc_expected_metrics(
|
||||
current_step, last_step, current_time, current_time - time_diff, start_time
|
||||
)
|
||||
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
|
||||
assert (
|
||||
props["time_since_last_report"]
|
||||
== expected_metrics["time_since_last_report"]
|
||||
)
|
||||
|
||||
def test_on_step_end_first_step(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end always reports on first step."""
|
||||
# Set up state with clear values
|
||||
current_step = 1 # First step
|
||||
last_step = 0
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
last_report_time = 999.0 # Just 1 second ago
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
callback.report_interval_steps = 100
|
||||
callback.last_report_step = last_step
|
||||
callback.start_time = start_time
|
||||
callback.last_report_time = last_report_time
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should send telemetry even though not much time has passed
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Properties should include expected values for first step
|
||||
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
|
||||
assert props["step"] == current_step
|
||||
expected_metrics = calc_expected_metrics(
|
||||
current_step, last_step, current_time, last_report_time, start_time
|
||||
)
|
||||
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
|
||||
|
||||
def test_log_history_empty(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test handling of empty log history."""
|
||||
# Set up state with clear values
|
||||
current_step = 1
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
trainer_state.log_history = []
|
||||
callback.start_time = start_time
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should still send telemetry
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Properties should have default values for missing log data
|
||||
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
|
||||
assert props["loss"] == 0
|
||||
assert props["learning_rate"] == 0
|
||||
341  tests/telemetry/test_errors.py  Normal file
@@ -0,0 +1,341 @@
"""Tests for telemetry error utilities"""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.telemetry.errors import sanitize_stack_trace, send_errors
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_error_flag(monkeypatch):
|
||||
"""Reset ERROR_HANDLED flag using monkeypatch"""
|
||||
import axolotl.telemetry.errors
|
||||
|
||||
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
|
||||
yield
|
||||
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_stack_trace():
|
||||
"""Provide a sample stack trace with mixed paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
|
||||
trainer = get_trainer(cfg)
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/train.py", line 214, in get_trainer
|
||||
model = get_model(cfg, tokenizer)
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/models.py", line 120, in get_model
|
||||
raise ValueError("Model path not found")
|
||||
ValueError: Model path not found
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def windows_stack_trace():
|
||||
"""Provide a sample stack trace with Windows paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\cli\\train.py", line 83, in main
|
||||
trainer = get_trainer(cfg)
|
||||
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\train.py", line 214, in get_trainer
|
||||
model = get_model(cfg, tokenizer)
|
||||
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\auto\\modeling_auto.py", line 482, in from_pretrained
|
||||
raise ValueError(f"Unrecognized configuration class {config.__class__}")
|
||||
ValueError: Unrecognized configuration class <class 'transformers.models.llama.configuration_llama.LlamaConfig'>
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixed_stack_trace():
|
||||
"""Provide a sample stack trace with both axolotl and non-axolotl paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
|
||||
trainer = get_trainer(cfg)
|
||||
File "/home/user/.local/lib/python3.9/site-packages/transformers/trainer.py", line 520, in train
|
||||
self._inner_training_loop()
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/trainer.py", line 75, in _inner_training_loop
|
||||
super()._inner_training_loop()
|
||||
File "/home/user/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
|
||||
data = self._next_data()
|
||||
RuntimeError: CUDA out of memory
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def venv_stack_trace():
|
||||
"""Provide a sample stack trace with virtual environment paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 1729, in train
|
||||
self._inner_training_loop()
|
||||
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 2013, in _inner_training_loop
|
||||
self.accelerator.backward(loss)
|
||||
File "/home/user/venv/lib/python3.9/site-packages/accelerate/accelerator.py", line 1851, in backward
|
||||
self.scaler.scale(loss).backward(**kwargs)
|
||||
File "/home/user/venv/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
|
||||
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
|
||||
RuntimeError: CUDA out of memory
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dist_packages_stack_trace():
|
||||
"""Provide a sample stack trace with dist-packages paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__
|
||||
data = self._next_data()
|
||||
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 675, in _next_data
|
||||
data = self._dataset_fetcher.fetch(index)
|
||||
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
|
||||
data = [self.dataset[idx] for idx in possibly_batched_index]
|
||||
File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2808, in __getitem__
|
||||
raise IndexError(f"Index {key} out of range for dataset of length {len(self)}.")
|
||||
IndexError: Index 10000 out of range for dataset of length 9832.
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project_stack_trace():
|
||||
"""Provide a sample stack trace from a project directory (not a virtual env)"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/projects/myproject/run.py", line 25, in <module>
|
||||
main()
|
||||
File "/home/user/projects/myproject/src/cli.py", line 45, in main
|
||||
app.run()
|
||||
File "/home/user/projects/myproject/src/app.py", line 102, in run
|
||||
raise ValueError("Configuration missing")
|
||||
ValueError: Configuration missing
|
||||
"""
|
||||
|
||||
|
||||
def test_sanitize_stack_trace(example_stack_trace):
|
||||
"""Test that sanitize_stack_trace properly preserves axolotl paths"""
|
||||
sanitized = sanitize_stack_trace(example_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "/home/user" not in sanitized
|
||||
assert ".local/lib/python3.9" not in sanitized
|
||||
|
||||
# Check that site-packages is preserved
|
||||
assert "site-packages/axolotl/cli/train.py" in sanitized
|
||||
assert "site-packages/axolotl/train.py" in sanitized
|
||||
assert "site-packages/axolotl/utils/models.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "ValueError: Model path not found" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_windows_paths(windows_stack_trace):
|
||||
"""Test that sanitize_stack_trace handles Windows paths"""
|
||||
sanitized = sanitize_stack_trace(windows_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "C:\\Users\\name" not in sanitized
|
||||
assert "AppData\\Local\\Programs\\Python" not in sanitized
|
||||
|
||||
# Check that both axolotl and transformers packages are preserved
|
||||
assert (
|
||||
"site-packages\\axolotl\\cli\\train.py" in sanitized
|
||||
or "site-packages/axolotl/cli/train.py" in sanitized
|
||||
)
|
||||
assert (
|
||||
"site-packages\\axolotl\\train.py" in sanitized
|
||||
or "site-packages/axolotl/train.py" in sanitized
|
||||
)
|
||||
assert (
|
||||
"site-packages\\transformers\\models\\auto\\modeling_auto.py" in sanitized
|
||||
or "site-packages/transformers/models/auto/modeling_auto.py" in sanitized
|
||||
)
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "ValueError: Unrecognized configuration class" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_mixed_paths(mixed_stack_trace):
|
||||
"""Test that sanitize_stack_trace preserves all package paths"""
|
||||
sanitized = sanitize_stack_trace(mixed_stack_trace)
|
||||
|
||||
# Check that all package paths are preserved
|
||||
assert "site-packages/axolotl/cli/train.py" in sanitized
|
||||
assert "site-packages/transformers/trainer.py" in sanitized
|
||||
assert "site-packages/axolotl/utils/trainer.py" in sanitized
|
||||
assert "site-packages/torch/utils/data/dataloader.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "RuntimeError: CUDA out of memory" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_venv_paths(venv_stack_trace):
|
||||
"""Test that sanitize_stack_trace preserves virtual environment package paths"""
|
||||
sanitized = sanitize_stack_trace(venv_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "/home/user/venv" not in sanitized
|
||||
|
||||
# Check that all package paths are preserved
|
||||
assert "site-packages/transformers/trainer.py" in sanitized
|
||||
assert "site-packages/accelerate/accelerator.py" in sanitized
|
||||
assert "site-packages/torch/_tensor.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "RuntimeError: CUDA out of memory" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_dist_packages(dist_packages_stack_trace):
|
||||
"""Test that sanitize_stack_trace preserves dist-packages paths"""
|
||||
sanitized = sanitize_stack_trace(dist_packages_stack_trace)
|
||||
|
||||
# Check that system paths are removed
|
||||
assert "/usr/local/lib/python3.8" not in sanitized
|
||||
|
||||
# Check that all package paths are preserved
|
||||
assert "dist-packages/torch/utils/data/dataloader.py" in sanitized
|
||||
assert "dist-packages/torch/utils/data/_utils/fetch.py" in sanitized
|
||||
assert "dist-packages/datasets/arrow_dataset.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert (
|
||||
"IndexError: Index 10000 out of range for dataset of length 9832." in sanitized
|
||||
)
|
||||
|
||||
|
||||
def test_sanitize_project_paths(project_stack_trace):
|
||||
"""Test handling of project paths (non-virtual env)"""
|
||||
sanitized = sanitize_stack_trace(project_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "/home/user/projects" not in sanitized
|
||||
|
||||
# For non-package paths, we should at least preserve the filename
|
||||
assert "run.py" in sanitized
|
||||
assert "cli.py" in sanitized
|
||||
assert "app.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "ValueError: Configuration missing" in sanitized
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_telemetry_manager():
|
||||
"""Create a mock TelemetryManager"""
|
||||
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.enabled = True
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
yield mock_manager
|
||||
|
||||
|
||||
def test_send_errors_successful_execution(mock_telemetry_manager):
|
||||
"""Test that send_errors doesn't send telemetry for successful function execution"""
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
return "success"
|
||||
|
||||
result = test_func()
|
||||
assert result == "success"
|
||||
mock_telemetry_manager.send_event.assert_not_called()
|
||||
|
||||
|
||||
def test_send_errors_with_exception(mock_telemetry_manager):
|
||||
"""Test that send_errors sends telemetry when an exception occurs"""
|
||||
test_error = ValueError("Test error")
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise test_error
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
test_func()
|
||||
|
||||
assert excinfo.value == test_error
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Check that the error info was passed correctly
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
assert "test_func-error" in call_args["event_type"]
|
||||
assert "Test error" in call_args["properties"]["exception"]
|
||||
assert "stack_trace" in call_args["properties"]
|
||||
|
||||
|
||||
def test_send_errors_nested_calls(mock_telemetry_manager):
|
||||
"""Test that send_errors only sends telemetry once for nested decorated functions"""
|
||||
|
||||
@send_errors
|
||||
def inner_func():
|
||||
raise ValueError("Inner error")
|
||||
|
||||
@send_errors
|
||||
def outer_func():
|
||||
return inner_func()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
outer_func()
|
||||
|
||||
# Telemetry should be sent only once for the inner function
|
||||
assert mock_telemetry_manager.send_event.call_count == 1
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
assert "inner_func-error" in call_args["event_type"]
|
||||
|
||||
|
||||
def test_send_errors_telemetry_disabled():
    """Test that send_errors doesn't attempt to send telemetry when disabled"""
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.enabled = False
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise ValueError("Test error")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
|
||||
mock_manager.send_event.assert_not_called()
|
||||
|
||||
|
||||
def test_error_handled_reset():
|
||||
"""Test that ERROR_HANDLED flag is properly reset"""
|
||||
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
|
||||
# Create and configure the mock manager
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.enabled = True
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
|
||||
from axolotl.telemetry.errors import ERROR_HANDLED
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise ValueError("Test error")
|
||||
|
||||
assert not ERROR_HANDLED
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
|
||||
from axolotl.telemetry.errors import ERROR_HANDLED
|
||||
|
||||
assert ERROR_HANDLED
|
||||
|
||||
|
||||
def test_module_path_resolution(mock_telemetry_manager):
|
||||
"""Test that the module path is correctly resolved for the event type"""
|
||||
import inspect
|
||||
|
||||
current_module = inspect.getmodule(test_module_path_resolution).__name__
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise ValueError("Test error")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
|
||||
assert mock_telemetry_manager.send_event.called
|
||||
event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"]
|
||||
|
||||
expected_event_type = f"{current_module}.test_func-error"
|
||||
assert expected_event_type == event_type
|
||||
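The assertions above pin down the contract of `sanitize_stack_trace`: system path prefixes are stripped, package-relative paths (from `site-packages`/`dist-packages` onward) are kept, project-local files are reduced to their bare filename, and the error message survives intact. A minimal sketch of behavior that satisfies these tests, assuming the simple quoted-path rewriting below; it is not Axolotl's actual implementation in `axolotl.telemetry.errors`:

```python
import re


def sanitize_stack_trace(stack_trace: str) -> str:
    """Redact user-identifying path prefixes from a formatted traceback."""

    def redact(match: re.Match) -> str:
        path = match.group(0)
        # Keep the package-relative part for installed libraries
        for marker in ("site-packages/", "dist-packages/"):
            idx = path.find(marker)
            if idx != -1:
                return path[idx:]
        # Otherwise keep only the filename (drops /home/<user>/... prefixes)
        return path.rsplit("/", 1)[-1]

    # Rewrite every quoted path in the traceback's 'File "..."' lines
    return re.sub(r'(?<=File ")[^"]+(?=")', redact, stack_trace)
```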
278
tests/telemetry/test_manager.py
Normal file
@@ -0,0 +1,278 @@
"""Tests for TelemetryManager class and utilities"""

# pylint: disable=redefined-outer-name,protected-access

import os
from unittest.mock import patch

import pytest
import yaml

from axolotl.telemetry.manager import TelemetryManager


@pytest.fixture
def mock_whitelist(tmp_path):
    """Create a temporary whitelist file for testing"""
    whitelist_content = {
        "organizations": ["meta-llama", "mistralai"],
    }
    whitelist_file = tmp_path / "whitelist.yaml"
    with open(whitelist_file, "w", encoding="utf-8") as f:
        yaml.dump(whitelist_content, f)

    return str(whitelist_file)


@pytest.fixture
def telemetry_manager_class():
    """Reset the TelemetryManager singleton between tests"""
    original_instance = TelemetryManager._instance
    original_initialized = TelemetryManager._initialized
    TelemetryManager._instance = None
    TelemetryManager._initialized = False
    yield TelemetryManager
    TelemetryManager._instance = original_instance
    TelemetryManager._initialized = original_initialized


@pytest.fixture
def manager(telemetry_manager_class, mock_whitelist):
    """Create a TelemetryManager instance with mocked dependencies"""
    with (
        patch("posthog.capture"),
        patch("posthog.flush"),
        patch("time.sleep"),
        patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist),
        patch.dict(os.environ, {"RANK": "0"}),
    ):
        manager = telemetry_manager_class()
        # Manually enable for most tests
        manager.enabled = True
        return manager


def test_singleton_instance(telemetry_manager_class):
    """Test that TelemetryManager is a singleton"""
    with (
        patch("posthog.capture"),
        patch("time.sleep"),
        patch.dict(os.environ, {"RANK": "0"}),
    ):
        first = telemetry_manager_class()
        second = telemetry_manager_class()
        assert first is second
        assert telemetry_manager_class.get_instance() is first


def test_telemetry_disabled_by_default(telemetry_manager_class):
    """Test that telemetry is disabled by default (opt-in)"""
    with (
        patch.dict(os.environ, {"RANK": "0"}, clear=True),
        patch("time.sleep"),
        patch("logging.Logger.info"),
    ):
        manager = telemetry_manager_class()
        assert not manager.enabled


def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class):
    """Test that telemetry is enabled when AXOLOTL_DO_NOT_TRACK=0"""
    with (
        patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}),
        patch("time.sleep"),
    ):
        manager = telemetry_manager_class()
        assert manager.enabled


def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
    """Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
    with (
        patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}),
        patch("time.sleep"),
    ):
        manager = telemetry_manager_class()
        assert not manager.enabled


def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
    """Test that telemetry is disabled when DO_NOT_TRACK=1"""
    with (
        patch.dict(
            os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "DO_NOT_TRACK": "1", "RANK": "0"}
        ),
        patch("time.sleep"),
    ):
        manager = telemetry_manager_class()
        assert not manager.enabled


def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
    """Test that telemetry is disabled for non-main processes"""
    with (
        patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "1"}),
        patch("time.sleep"),
    ):
        manager = telemetry_manager_class()
        assert not manager.enabled


def test_opt_in_info_displayed(telemetry_manager_class):
    """Test that opt-in info is displayed when telemetry is not configured"""
    with (
        patch.dict(os.environ, {"RANK": "0"}, clear=True),
        patch("logging.Logger.warning") as mock_warning,
        patch("time.sleep"),
    ):
        telemetry_manager_class()
        info_displayed = False
        for call in mock_warning.call_args_list:
            if "Telemetry is currently disabled by default" in str(call):
                info_displayed = True
                break
        assert info_displayed


def test_is_whitelisted(telemetry_manager_class, mock_whitelist):
    """Test org whitelist functionality"""
    with (
        patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist),
        patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
    ):
        manager = telemetry_manager_class()

        # Should match organizations from the mock whitelist
        assert manager._is_whitelisted("meta-llama/llama-7b")
        assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
        # Should not match
        assert not manager._is_whitelisted("unknown/model")
        # Should handle case insensitively
        assert manager._is_whitelisted("META-LLAMA/Llama-7B")
        # Should handle empty input
        assert not manager._is_whitelisted("")


def test_system_info_collection(manager):
    """Test system information collection"""
    system_info = manager._get_system_info()

    # Check essential keys
    assert "os" in system_info
    assert "python_version" in system_info
    assert "cpu_count" in system_info
    assert "memory_total" in system_info
    assert "accelerator_count" in system_info


def test_send_event(telemetry_manager_class):
    """Test basic event sending"""
    with (
        patch("posthog.capture") as mock_capture,
        patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
    ):
        manager = telemetry_manager_class()

        # Test with clean properties (no PII)
        manager.send_event("test_event", {"key": "value"})
        assert mock_capture.called
        assert mock_capture.call_args[1]["event"] == "test_event"
        assert mock_capture.call_args[1]["properties"] == {"key": "value"}
        assert mock_capture.call_args[1]["distinct_id"] == manager.run_id

        # Test with default properties (None)
        mock_capture.reset_mock()
        manager.send_event("simple_event")
        assert mock_capture.called
        assert mock_capture.call_args[1]["properties"] == {}


def test_send_system_info(telemetry_manager_class):
    """Test sending system info"""
    with (
        patch("posthog.capture") as mock_capture,
        patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
    ):
        manager = telemetry_manager_class()
        manager.send_system_info()
        assert mock_capture.called
        assert mock_capture.call_args[1]["event"] == "system-info"
        assert mock_capture.call_args[1]["properties"] == manager.system_info


def test_redacted_properties(telemetry_manager_class):
    """Test path redaction in send_event method"""
    with (
        patch("posthog.capture") as mock_capture,
        patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
    ):
        manager = telemetry_manager_class()
        # Test with properties containing various paths and non-paths
        test_properties = {
            "filepath": "/home/user/sensitive/data.txt",
            "windows_path": "C:\\Users\\name\\Documents\\project\\file.py",
            "output_dir": "/var/lib/data",
            "path_to_model": "models/llama/7b",
            "message": "Training started",  # Should not be redacted
            "metrics": {"loss": 0.5, "accuracy": 0.95},  # Should not be redacted
            "base_model": "models/local_model",
            "nested": {
                "model_path": "/models/my_model",
                "root_dir": "/home/user/projects",
                "stats": {"steps": 1000, "epochs": 3},  # Should not be redacted
            },
        }

        manager.send_event("test_event", test_properties)

        # Verify the call was made
        assert mock_capture.called

        # Get the sanitized properties that were sent
        sanitized = mock_capture.call_args[1]["properties"]

        # Check that path-like and base_model keys were redacted
        assert sanitized["filepath"] == "[REDACTED]"
        assert sanitized["windows_path"] == "[REDACTED]"
        assert sanitized["path_to_model"] == "[REDACTED]"
        assert sanitized["base_model"] == "[REDACTED]"

        # Check that non-path values were preserved
        assert sanitized["message"] == "Training started"
        assert sanitized["metrics"] == {"loss": 0.5, "accuracy": 0.95}

        # Check nested structure handling
        assert sanitized["nested"]["model_path"] == "[REDACTED]"
        assert sanitized["nested"]["root_dir"] == "[REDACTED]"
        assert sanitized["nested"]["stats"] == {"steps": 1000, "epochs": 3}


def test_disable_telemetry(manager):
    """Test that disabled telemetry doesn't send events"""
    with patch("posthog.capture") as mock_capture:
        manager.enabled = False
        manager.send_event("test_event")
        assert not mock_capture.called


def test_exception_handling_during_send(manager):
    """Test that exceptions in PostHog are handled gracefully"""
    with (
        patch("posthog.capture", side_effect=Exception("Test error")),
        patch("logging.Logger.warning") as mock_warning,
    ):
        manager.send_event("test_event")
        warning_logged = False
        for call in mock_warning.call_args_list:
            if "Failed to send telemetry event" in str(call):
                warning_logged = True
                break
        assert warning_logged


def test_shutdown(manager):
    """Test shutdown behavior"""
    with patch("posthog.flush") as mock_flush:
        manager.shutdown()
        assert mock_flush.called
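Taken together, the environment-variable tests above define the opt-in gate: telemetry runs only when `AXOLOTL_DO_NOT_TRACK=0` is set explicitly, the cross-tool `DO_NOT_TRACK=1` convention overrides it, and only the rank-0 process reports in distributed runs. A condensed sketch of that gating logic; the real `TelemetryManager` additionally handles PostHog setup, organization whitelisting, and property redaction:

```python
import os


def telemetry_enabled() -> bool:
    """Opt-in check mirroring the behavior the tests above assert."""
    # Telemetry is off unless the user explicitly opts in
    if os.environ.get("AXOLOTL_DO_NOT_TRACK") != "0":
        return False
    # The cross-tool DO_NOT_TRACK convention always wins
    if os.environ.get("DO_NOT_TRACK") == "1":
        return False
    # In distributed runs, only the main process (rank 0) reports
    return os.environ.get("RANK", "0") == "0"
```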
357
tests/telemetry/test_runtime_metrics.py
Normal file
@@ -0,0 +1,357 @@
"""Tests for runtime metrics telemetry module"""

# pylint: disable=redefined-outer-name

from unittest.mock import MagicMock, patch

import pytest

from axolotl.telemetry.runtime_metrics import RuntimeMetrics, RuntimeMetricsTracker


@pytest.fixture
def mock_time():
    """Mock time.time() to have predictable values in tests"""
    with patch("time.time") as mock_time:
        # Start with time 1000.0 and increment by 10 seconds on each call
        times = [1000.0 + i * 10 for i in range(10)]
        mock_time.side_effect = times
        yield mock_time


@pytest.fixture
def mock_telemetry_manager():
    """Create a mock TelemetryManager"""
    with patch(
        "axolotl.telemetry.runtime_metrics.TelemetryManager"
    ) as mock_manager_class:
        mock_manager = MagicMock()
        mock_manager.enabled = True
        mock_manager_class.get_instance.return_value = mock_manager
        yield mock_manager


@pytest.fixture
def mock_psutil():
    """Mock psutil for memory information"""
    with patch("axolotl.telemetry.runtime_metrics.psutil") as mock_psutil:
        mock_process = MagicMock()
        mock_memory_info = MagicMock()
        # Set initial memory to 1GB
        mock_memory_info.rss = 1024 * 1024 * 1024
        mock_process.memory_info.return_value = mock_memory_info
        mock_psutil.Process.return_value = mock_process
        yield mock_psutil


@pytest.fixture
def mock_torch():
    """Mock torch.cuda functions"""
    with patch("axolotl.telemetry.runtime_metrics.torch") as mock_torch:
        mock_torch.cuda.is_available.return_value = True
        mock_torch.cuda.device_count.return_value = 2

        # Mock memory allocated per device (1GB for device 0, 2GB for device 1)
        mock_torch.cuda.memory_allocated.side_effect = (
            lambda device: (device + 1) * 1024 * 1024 * 1024
        )

        yield mock_torch


class TestRuntimeMetrics:
    """Tests for RuntimeMetrics class."""

    def test_initialization(self):
        """Test RuntimeMetrics initialization."""
        metrics = RuntimeMetrics(start_time=1000.0)

        assert metrics.start_time == 1000.0
        assert metrics.epoch_start_times == {}
        assert metrics.epoch_end_times == {}
        assert metrics.peak_gpu_memory == {}
        assert metrics.total_steps == 0
        assert metrics.current_epoch == 0
        assert metrics.current_step == 0
        assert metrics.peak_cpu_memory == 0

    def test_elapsed_time(self, mock_time):
        """Test elapsed_time property."""
        metrics = RuntimeMetrics(start_time=1000.0)

        # Mock time.time() to return 1050.0
        mock_time.side_effect = [1050.0]

        assert metrics.elapsed_time == 50.0

    def test_epoch_time(self):
        """Test epoch_time method."""
        metrics = RuntimeMetrics(start_time=1000.0)

        # No epoch data
        assert metrics.epoch_time(0) is None

        # Add epoch start but no end
        metrics.epoch_start_times[0] = 1000.0
        assert metrics.epoch_time(0) is None

        # Add epoch end
        metrics.epoch_end_times[0] = 1060.0
        assert metrics.epoch_time(0) == 60.0

    def test_average_epoch_time(self):
        """Test average_epoch_time method."""
        metrics = RuntimeMetrics(start_time=1000.0)

        # No completed epochs
        assert metrics.average_epoch_time() is None

        # Add one completed epoch
        metrics.epoch_start_times[0] = 1000.0
        metrics.epoch_end_times[0] = 1060.0
        assert metrics.average_epoch_time() == 60.0

        # Add second completed epoch
        metrics.epoch_start_times[1] = 1060.0
        metrics.epoch_end_times[1] = 1140.0  # 80 seconds
        assert metrics.average_epoch_time() == 70.0  # Average of 60 and 80

        # Add incomplete epoch (should not affect average)
        metrics.epoch_start_times[2] = 1140.0
        assert metrics.average_epoch_time() == 70.0

    def test_steps_per_second(self, mock_time):
        """Test steps_per_second method."""
        metrics = RuntimeMetrics(start_time=1000.0)

        # No steps - first call to time.time()
        mock_time.side_effect = None
        mock_time.return_value = 1050.0
        assert metrics.steps_per_second() is None

        # Add steps - second call to time.time()
        metrics.total_steps = 100
        mock_time.return_value = 1050.0  # Keep same time for consistent result
        assert metrics.steps_per_second() == 2.0  # 100 steps / 50 seconds

    def test_to_dict_basic(self, mock_time):
        """Test to_dict method with basic metrics."""
        metrics = RuntimeMetrics(start_time=1000.0)
        metrics.total_steps = 100
        metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024  # 2GB

        # Mock elapsed_time
        mock_time.side_effect = None
        mock_time.return_value = 1050.0

        result = metrics.to_dict()

        assert result["total_time_seconds"] == 50.0
        assert result["total_steps"] == 100
        assert result["steps_per_second"] == 2.0
        assert result["epochs_completed"] == 0
        assert result["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
        assert "epoch_times" not in result
        assert "gpu_memory" not in result

    def test_to_dict_with_epochs(self, mock_time):
        """Test to_dict method with epoch data."""
        metrics = RuntimeMetrics(start_time=1000.0)
        metrics.total_steps = 100

        # Add epoch data
        metrics.epoch_start_times[0] = 1000.0
        metrics.epoch_end_times[0] = 1060.0
        metrics.epoch_start_times[1] = 1060.0
        metrics.epoch_end_times[1] = 1140.0

        # Mock elapsed_time
        mock_time.side_effect = None
        mock_time.return_value = 1150.0

        result = metrics.to_dict()

        assert "epoch_times" in result
        assert result["epoch_times"]["epoch_0_seconds"] == 60.0
        assert result["epoch_times"]["epoch_1_seconds"] == 80.0
        assert result["average_epoch_time_seconds"] == 70.0

    def test_to_dict_with_gpu_memory(self, mock_time):
        """Test to_dict method with GPU memory data."""
        metrics = RuntimeMetrics(start_time=1000.0)
        metrics.peak_gpu_memory = {
            0: 1 * 1024 * 1024 * 1024,  # 1GB
            1: 2 * 1024 * 1024 * 1024,  # 2GB
        }

        # Mock elapsed_time
        mock_time.side_effect = [1050.0]

        result = metrics.to_dict()

        assert "gpu_memory" in result
        assert result["gpu_memory"]["gpu_0_peak_memory_bytes"] == 1 * 1024 * 1024 * 1024
        assert result["gpu_memory"]["gpu_1_peak_memory_bytes"] == 2 * 1024 * 1024 * 1024


class TestRuntimeMetricsTracker:
    """Tests for RuntimeMetricsTracker class."""

    # pylint: disable=unused-argument
    def test_initialization(self, mock_time, mock_telemetry_manager):
        """Test RuntimeMetricsTracker initialization."""
        tracker = RuntimeMetricsTracker()

        assert isinstance(tracker.metrics, RuntimeMetrics)
        assert tracker.metrics.start_time == 1000.0  # First value from mock_time

    # pylint: disable=unused-argument
    def test_start_epoch(
        self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
    ):
        """Test start_epoch method."""
        tracker = RuntimeMetricsTracker()

        # Reset mock_time to control next value
        mock_time.side_effect = [1010.0]

        tracker.start_epoch(0)

        assert tracker.metrics.current_epoch == 0
        assert tracker.metrics.epoch_start_times[0] == 1010.0

        # Verify memory metrics were updated
        assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
        assert 0 in tracker.metrics.peak_gpu_memory
        assert 1 in tracker.metrics.peak_gpu_memory

    # pylint: disable=unused-argument
    def test_end_epoch(self, mock_time, mock_telemetry_manager):
        """Test end_epoch method."""
        tracker = RuntimeMetricsTracker()

        # Start epoch 0
        mock_time.side_effect = [1010.0]
        tracker.start_epoch(0)

        # End epoch 0
        mock_time.side_effect = [1060.0]
        tracker.end_epoch(0)

        assert 0 in tracker.metrics.epoch_end_times
        assert tracker.metrics.epoch_end_times[0] == 1060.0

    # pylint: disable=unused-argument
    def test_update_step(
        self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
    ):
        """Test update_step method."""
        tracker = RuntimeMetricsTracker()

        # Update step to a non-multiple of 100
        tracker.update_step(42)

        assert tracker.metrics.current_step == 42
        assert tracker.metrics.total_steps == 1

        # Memory metrics should not be updated for non-multiple of 100
        assert tracker.metrics.peak_cpu_memory == 0

        # Update step to a multiple of 100
        tracker.update_step(100)

        assert tracker.metrics.current_step == 100
        assert tracker.metrics.total_steps == 2

        # Memory metrics should be updated for multiple of 100
        assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024

    # pylint: disable=unused-argument
    def test_update_memory_metrics(
        self, mock_psutil, mock_torch, mock_telemetry_manager
    ):
        """Test update_memory_metrics method."""
        tracker = RuntimeMetricsTracker()

        # Initial memory state
        assert tracker.metrics.peak_cpu_memory == 0
        assert tracker.metrics.peak_gpu_memory == {}

        # Update memory metrics
        tracker.update_memory_metrics()

        # Verify CPU memory
        assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024

        # Verify GPU memory
        assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
        assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024

        # Change mocked memory values to be lower
        mock_process = mock_psutil.Process.return_value
        mock_memory_info = mock_process.memory_info.return_value
        mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024  # 0.5GB

        mock_torch.cuda.memory_allocated.side_effect = (
            lambda device: (device + 0.5) * 1024 * 1024 * 1024
        )

        # Update memory metrics again
        tracker.update_memory_metrics()

        # Peak values should not decrease
        assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
        assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
        assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024

        # Change mocked memory values to be higher
        mock_memory_info.rss = 2 * 1024 * 1024 * 1024  # 2GB

        mock_torch.cuda.memory_allocated.side_effect = (
            lambda device: (device + 2) * 1024 * 1024 * 1024
        )

        # Update memory metrics again
        tracker.update_memory_metrics()

        # Peak values should increase
        assert tracker.metrics.peak_cpu_memory == 2 * 1024 * 1024 * 1024
        assert tracker.metrics.peak_gpu_memory[0] == 2 * 1024 * 1024 * 1024
        assert tracker.metrics.peak_gpu_memory[1] == 3 * 1024 * 1024 * 1024

    # pylint: disable=unused-argument
    def test_get_memory_metrics(self, mock_psutil, mock_torch, mock_telemetry_manager):
        """Test get_memory_metrics method."""
        tracker = RuntimeMetricsTracker()

        # Set peak memory values
        tracker.metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024
        tracker.metrics.peak_gpu_memory = {
            0: 3 * 1024 * 1024 * 1024,
            1: 4 * 1024 * 1024 * 1024,
        }

        # Get memory metrics
        memory_metrics = tracker.get_memory_metrics()

        # Verify CPU memory
        assert (
            memory_metrics["cpu_memory_bytes"] == 1 * 1024 * 1024 * 1024
        )  # Current value from mock
        assert (
            memory_metrics["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
        )  # Peak value we set

        # Verify GPU memory
        assert (
            memory_metrics["gpu_0_memory_bytes"] == 1 * 1024 * 1024 * 1024
        )  # Current value from mock
        assert (
            memory_metrics["gpu_0_peak_memory_bytes"] == 3 * 1024 * 1024 * 1024
        )  # Peak value we set
        assert (
            memory_metrics["gpu_1_memory_bytes"] == 2 * 1024 * 1024 * 1024
        )  # Current value from mock
        assert (
            memory_metrics["gpu_1_peak_memory_bytes"] == 4 * 1024 * 1024 * 1024
        )  # Peak value we set
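Two invariants are worth calling out in the tracker tests above: memory is sampled only on every 100th step, so per-step overhead stays negligible, and recorded peaks never decrease. A self-contained sketch of that peak-memory bookkeeping, assuming `psutil` and `torch` are available (the hypothetical `PeakMemory` class stands in for `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`):

```python
import psutil
import torch


class PeakMemory:
    """Track peak CPU RSS and per-device CUDA memory, never decreasing."""

    def __init__(self) -> None:
        self.peak_cpu: int = 0
        self.peak_gpu: dict[int, int] = {}

    def update(self) -> None:
        # Resident set size of the current process
        rss = psutil.Process().memory_info().rss
        self.peak_cpu = max(self.peak_cpu, rss)
        # Currently allocated CUDA memory, per device
        if torch.cuda.is_available():
            for device in range(torch.cuda.device_count()):
                allocated = torch.cuda.memory_allocated(device)
                self.peak_gpu[device] = max(self.peak_gpu.get(device, 0), allocated)

    def maybe_update(self, step: int) -> None:
        # Sample sparsely, mirroring the every-100-steps cadence in the tests
        if step % 100 == 0:
            self.update()
```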
@@ -1,4 +1,8 @@
-"""Test cases for tokenizer loading."""
+"""
+Test cases for the tokenizer loading
+"""

+import unittest
+
 import pytest
@@ -9,7 +13,9 @@ from tests.hf_offline_utils import enable_hf_offline


 class TestTokenizers:
-    """Test class for the load_tokenizer fn"""
+    """
+    test class for the load_tokenizer fn
+    """

     @enable_hf_offline
     def test_default_use_fast(self):
@@ -149,50 +155,6 @@ class TestTokenizers:
         ):
             load_tokenizer(cfg)

-    def test_mistral_tokenizer_auto_detection(self):
-        """Test that Mistral models are auto-detected and use MistralTokenizerWrapper"""
-        cfg = DictDefault(
-            {
-                "base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
-                "tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
-            }
-        )
-        tokenizer = load_tokenizer(cfg)
-        assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
-
-    def test_mixtral_tokenizer_auto_detection(self):
-        """Test that Mixtral models are auto-detected and use MistralTokenizerWrapper"""
-        cfg = DictDefault(
-            {
-                "base_model": "model-hub/Mixtral-8x7B-v0.1",
-                "tokenizer_config": "model-hub/Mixtral-8x7B-v0.1",
-            }
-        )
-        tokenizer = load_tokenizer(cfg)
-        assert tokenizer.__class__.__name__ == "MistralTokenizerWrapper"
-
-    def test_mistral_tokenizer_basic_functionality(self):
-        """Test basic encode/decode functionality of MistralTokenizerWrapper"""
-        cfg = DictDefault(
-            {
-                "base_model": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
-                "tokenizer_config": "adamo1139/Mistral-Small-24B-Instruct-2501-ungated",
-            }
-        )
-        tokenizer = load_tokenizer(cfg)
-
-        # Test basic encoding
-        text = "Hello, world!"
-        tokens = tokenizer.encode(text)
-        assert isinstance(tokens, list)
-        assert len(tokens) > 0
-
-        # Test basic decoding
-        decoded = tokenizer.decode(tokens)
-        assert isinstance(decoded, str)
-
-        # Test token properties are accessible
-        assert hasattr(tokenizer, "eos_token_id")
-        assert hasattr(tokenizer, "bos_token_id")
-        assert isinstance(tokenizer.eos_token_id, int)
-        assert isinstance(tokenizer.bos_token_id, int)
+
+if __name__ == "__main__":
+    unittest.main()