Compare commits


1 Commit

Author: NanoCode012
SHA1: 83ff8bfa1a
Message: fix: change docker miniconda install to workspace
Date: 2025-11-06 18:54:56 +07:00
66 changed files with 69 additions and 3808 deletions

.github/FUNDING.yml vendored
View File

@@ -1,13 +1,13 @@
# These are supported funding model platforms
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
github: [winglian, OpenAccess-AI-Collective] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
ko_fi: axolotl_ai # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
custom: ['https://quickchart.io/qr?text=bitcoin%3Abc1qxlgwlqwfea5s2cxm42xqsfmwjct0rj8w8ea5np&size=480&centerImageUrl=https%3A%2F%2Fupload.wikimedia.org%2Fwikipedia%2Fcommons%2Fthumb%2F4%2F46%2FBitcoin.svg%2F64px-Bitcoin.svg.png'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

View File

@@ -90,6 +90,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-base
axolotlai/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v2

View File

@@ -25,6 +25,7 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -35,7 +36,6 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -45,6 +45,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl
axolotlai/axolotl
tags: |
type=ref,event=branch
@@ -98,6 +99,7 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -108,7 +110,6 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -118,6 +119,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-cloud
axolotlai/axolotl-cloud
tags: |
type=ref,event=branch
@@ -177,6 +179,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-cloud-term
axolotlai/axolotl-cloud-term
tags: |
type=ref,event=branch

View File

@@ -31,6 +31,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl
axolotlai/axolotl
tags: |
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
@@ -83,6 +84,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-cloud
axolotlai/axolotl-cloud
tags: |
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}

View File

@@ -59,10 +59,6 @@ jobs:
timeout-minutes: 20
steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository code
uses: actions/checkout@v4
@@ -95,10 +91,6 @@ jobs:
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
@@ -126,6 +118,10 @@ jobs:
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
@@ -138,10 +134,6 @@ jobs:
timeout-minutes: 20
steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository code
uses: actions/checkout@v4
@@ -175,10 +167,6 @@ jobs:
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
@@ -196,6 +184,10 @@ jobs:
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
gate-skip-e2e:
needs: [pre-commit, pytest, pytest-sdist]
runs-on: ubuntu-latest

View File

@@ -29,9 +29,6 @@
## 🎉 Latest Updates
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
- 2025/07:
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
@@ -39,12 +36,12 @@
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
<details>
<summary>Expand older updates</summary>
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
@@ -157,13 +154,6 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
## 📈 Telemetry
Axolotl has opt-out telemetry that helps us understand how the project is being used
and prioritize improvements. We collect basic system information, model types, and
error rates—never personal data or file paths. Telemetry is enabled by default. To
disable it, set AXOLOTL_DO_NOT_TRACK=1. For more details, see our [telemetry documentation](https://docs.axolotl.ai/docs/telemetry.html).
## ❤️ Sponsors
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)

View File

@@ -241,7 +241,6 @@ website:
- docs/installation.qmd
- docs/inference.qmd
- docs/cli.qmd
- docs/telemetry.qmd
- docs/config-reference.qmd
- text: "API Reference"
href: docs/api

View File

@@ -5,7 +5,7 @@ ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ENV PATH="/workspace/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.10"
ARG PYTORCH_VERSION="2.1.2"
@@ -24,14 +24,14 @@ RUN apt-get update \
&& rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& mkdir -p /workspace/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
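
This hunk is the commit's headline change: Miniconda is installed under `/workspace/miniconda3` instead of `/root/miniconda3`, with `PATH` updated to match (the two Dockerfile variants that follow get the same treatment). As a minimal sketch outside Docker, assuming a POSIX shell, the installer's `-b` flag runs non-interactively and `-p` sets the install prefix, which is what relocates the environment:

```bash
# Hedged sketch of the relocated Miniconda install used in these Dockerfiles.
# -b runs the installer in batch (non-interactive) mode; -p sets the install prefix.
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3
rm -f Miniconda3-latest-Linux-x86_64.sh

export PATH="/workspace/miniconda3/bin:${PATH}"
which conda   # should resolve to /workspace/miniconda3/bin/conda
```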

View File

@@ -5,7 +5,7 @@ ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ENV PATH="/workspace/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="next"
@@ -19,12 +19,12 @@ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& mkdir -p /workspace/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace

View File

@@ -5,7 +5,7 @@ ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ENV PATH="/workspace/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="nightly"
@@ -19,14 +19,14 @@ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& mkdir -p /workspace/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
ENV PATH="/workspace/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace

View File

@@ -218,13 +218,6 @@ If you have tool arguments with same name but different dtypes (like `"time": st
```
"arguments": "{\"...\": \"...\"}"
```
The same is applicable for tool parameters.
```
"parameters": "{\"...\": \"...\"}"
```
:::
Example config for Llama4:

View File

@@ -124,8 +124,6 @@ Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral
```yaml
base_model: mistralai/Voxtral-Mini-3B-2507
processor_type: VoxtralProcessor
```
### Gemma-3 {#sec-gemma-3}

View File

@@ -597,116 +597,6 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt
To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
#### OpenEnv Rollout Functions
GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments.
For example, to implement a simple math-solving environment with step-by-step verification:
```python
# math_env.py
import re
def math_solver_rollout(model, processing_class, prompts, generation_config=None):
"""
Custom rollout function that generates step-by-step math solutions.
Args:
model: The language model
processing_class: The tokenizer/processing_class
prompts: List of prompt dicts (with 'messages' key for chat format)
generation_config: Optional generation configuration
Returns:
List of completion strings
"""
completions = []
for prompt in prompts:
# Apply chat template to prompt
messages = prompt.get("messages", [])
formatted_prompt = processing_class.apply_chat_template(
messages, processing_class=False, add_generation_prompt=True
)
# Generate step-by-step solution
full_response = ""
for step in range(5): # Max 5 reasoning steps
current_input = formatted_prompt + full_response + "\nNext step:"
inputs = processing_class(current_input, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=100,
generation_config=generation_config,
)
step_text = processing_class.decode(
outputs[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True
)
# Check if solution is complete
if "FINAL ANSWER:" in step_text:
full_response += step_text
break
full_response += step_text + "\n"
completions.append(full_response)
return completions
def math_reward(prompts, completions, answers, **kwargs):
"""Reward function that checks mathematical correctness"""
rewards = []
for completion, correct_answer in zip(completions, answers):
# Extract predicted answer
match = re.search(r"FINAL ANSWER:\s*(.+)", completion)
predicted = match.group(1).strip() if match else ""
# Compare with correct answer
reward = 1.0 if predicted == str(correct_answer) else 0.0
rewards.append(reward)
return rewards
def math_transform(cfg, *args, **kwargs):
"""Transform dataset to GRPO format with answer field"""
def transform_fn(example, processing_class=None):
return {
"prompt": [{"role": "user", "content": example["question"]}],
"answer": str(example["answer"]),
}
return transform_fn, {"remove_columns": ["question"]}
```
```yaml
rl: grpo
trl:
beta: 0.001
max_completion_length: 512
num_generations: 4
rollout_func: "math_env.math_solver_rollout" # Custom rollout function
reward_funcs: ["math_env.math_reward"]
reward_weights: [1.0]
datasets:
- path: openai/gsm8k
name: main
type: math_env.math_transform
```
The `rollout_func` parameter accepts a fully qualified name (e.g., `module_name.function_name`) that points to a callable function in your local directory. The function receives:
- `model`: The language model
- `processing_class`: The tokenizer/processing class
- `prompts`: List of prompt dictionaries
- `generation_config` (optional): Generation configuration
And should return a list of completion strings.
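As a minimal stub matching that description (the module and function names here are placeholders, not part of the documented API):

```python
# my_rollouts.py -- hypothetical module; referenced via trl.rollout_func: "my_rollouts.echo_rollout"
def echo_rollout(model, processing_class, prompts, generation_config=None):
    """Trivial rollout sketch: return one completion string per prompt dict."""
    return ["placeholder completion" for _ in prompts]
```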
For more OpenEnv examples, see [TRL OpenEnv Documentation](https://huggingface.co/docs/trl/main/en/openenv).
#### GRPO with DAPO/Dr. GRPO loss
The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.

View File

@@ -1,61 +0,0 @@
---
title: Telemetry
description: A description of the telemetry implementation in Axolotl.
---
# Telemetry in Axolotl
Axolotl implements anonymous telemetry to help maintainers understand how the library
is used and where users encounter issues. This data helps prioritize features, optimize
performance, and fix bugs.
## Data Collection
We collect:
- System info: OS, Python version, Axolotl version, PyTorch version, Transformers
version, etc.
- Hardware info: CPU count, memory, GPU count and models
- Runtime metrics: Training progress, memory usage, timing information
- Usage patterns: Models (from a whitelist) and configurations used
- Error tracking: Stack traces and error messages (sanitized to remove personal
information)
Personally identifiable information (PII) is not collected.
## Implementation
Telemetry is implemented using PostHog and consists of:
- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the
telemetry system and provides methods for tracking events.
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and
sends sanitized stack traces.
- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks
runtime metrics during training.
- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends
runtime metrics telemetry.
The telemetry system will block training startup for 10 seconds to ensure users are
aware of data collection, unless telemetry is explicitly enabled or disabled.
## Opt-Out Mechanism
Telemetry is **enabled by default** on an opt-out basis. To disable it, set
`AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1`.
A warning message will be logged on start to clearly inform users about telemetry.
We will remove this after some period.
To hide the warning message about telemetry that is displayed on train, etc. startup,
explicitly set: `AXOLOTL_DO_NOT_TRACK=0` (enable telemetry) or `AXOLOTL_DO_NOT_TRACK=1`
(explicitly disable telemetry).
## Privacy
- All path-like config information is automatically redacted from telemetry data
- Model information is only collected for whitelisted organizations
- See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
- Each run generates a unique anonymous ID
- This allows us to link different telemetry events in a single same training run
- Telemetry is only sent from the main process to avoid duplicate events
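
For reference, a minimal shell example of the opt-out described above (assumes a POSIX shell; the config path is hypothetical):

```bash
# Explicitly disable telemetry for this shell session
export AXOLOTL_DO_NOT_TRACK=1
# ...or keep it enabled but silence the startup warning
# export AXOLOTL_DO_NOT_TRACK=0

axolotl train my_config.yaml
```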

View File

@@ -1,65 +0,0 @@
# Finetune IBM's Granite 4.0 with Axolotl
[Granite 4.0](https://huggingface.co/collections/ibm-granite/granite-40-language-models) are a family of open source models trained by IBM Research.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Granite4 is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.7.1 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:
```bash
axolotl train examples/granite4/granite-4.0-tiny-fft.yaml
```
This config uses about 40.8GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
### Limitation
Adapter finetuning does not work at the moment. It would error with
```bash
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x3072 and 1x1179648)
```
In addition, if adapter training works, `lora_target_linear: true` will not work due to:
```bash
ValueError: Target module GraniteMoeHybridParallelExperts() is not supported.
```
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [Granite Docs](https://www.ibm.com/granite/docs/models/granite)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
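
The guide above notes that the dataset follows the OpenAI Messages format with `type: chat_template`; as a hedged sketch (the local path and contents are hypothetical), pointing the example config at your own data might look like:

```yaml
datasets:
  - path: ./data/my_conversations.jsonl   # hypothetical local file
    type: chat_template
# where each JSONL line holds one conversation, e.g.
# {"messages": [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}
```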

View File

@@ -1,45 +0,0 @@
base_model: ibm-granite/granite-4.0-tiny-preview
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/model-out
sequence_len: 2048
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,5 +1,5 @@
base_model: mistralai/Voxtral-Mini-3B-2507
processor_type: VoxtralProcessor
processor_type: AutoProcessor
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.48.2
bitsandbytes==0.47.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
@@ -12,12 +12,12 @@ packaging==23.2
huggingface_hub>=0.36.0
peft>=0.17.1
tokenizers>=0.22.1
tokenizers>=0.21.1
transformers==4.57.1
accelerate==1.11.0
datasets==4.4.1
accelerate==1.10.1
datasets==4.3.0
deepspeed>=0.17.0
trl==0.25.0
trl==0.24.0
hf_xet==1.2.0
kernels>=0.9.0
trackio
@@ -64,13 +64,9 @@ immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.13.0
openenv-core==0.1.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.7
axolotl-contribs-mit==0.0.5
# telemetry
posthog==6.7.11
mistral-common==1.8.5

View File

@@ -130,7 +130,7 @@ extras_require = {
"ring-flash-attn>=0.1.7",
],
"deepspeed": [
"deepspeed==0.18.2",
"deepspeed==0.17.5",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -14,8 +14,6 @@ import yaml
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
@@ -33,8 +31,6 @@ LOG = get_logger(__name__)
API_KEY_FIELDS = {"comet_api_key"}
TELEMETRY_MANAGER = TelemetryManager.get_instance()
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
"""
@@ -168,7 +164,6 @@ def plugin_set_cfg(cfg: DictDefault):
plugin_manager.cfg = cfg
@send_errors
def load_cfg(
config: str | Path | DictDefault = Path("examples/"), **kwargs
) -> DictDefault:
@@ -202,8 +197,6 @@ def load_cfg(
temp_file.close()
cfg.axolotl_config_path = temp_file.name
TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg)
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
@@ -247,7 +240,6 @@ def load_cfg(
setup_comet_env_vars(cfg)
plugin_set_cfg(cfg)
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
cfg_to_log = {
k: "[REDACTED]" if k in API_KEY_FIELDS else v
for k, v in cfg.items()

View File

@@ -19,10 +19,7 @@ from axolotl.cli.utils.diffusion import (
launch_diffusion_gradio_ui,
)
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
from axolotl.utils.chat_templates import (
get_chat_template_from_config,
)
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -46,7 +43,6 @@ def get_multi_line_input() -> str:
return instruction
@send_errors
def do_inference(
*,
cfg: DictDefault,
@@ -164,7 +160,6 @@ def do_inference(
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
@send_errors
def do_inference_gradio(
*,
cfg: DictDefault,

View File

@@ -7,14 +7,12 @@ import fire
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@send_errors
def do_merge_lora(*, cfg: DictDefault) -> None:
"""
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config

View File

@@ -23,7 +23,6 @@ from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.config import load_cfg
from axolotl.telemetry.errors import send_errors
from axolotl.utils.logging import get_logger
from axolotl.utils.train import determine_last_checkpoint
@@ -119,7 +118,6 @@ def _distributed_checkpoint_to_merged_weights(
return save_path_
@send_errors
def merge_fsdp_weights(
checkpoint_dir: str,
output_path: str,

View File

@@ -17,7 +17,6 @@ from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import disable_datasets_caching
@@ -25,7 +24,6 @@ from axolotl.utils.trainer import disable_datasets_caching
LOG = get_logger(__name__)
@send_errors
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
"""
Preprocesses dataset specified in axolotl config.

View File

@@ -9,7 +9,6 @@ from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -35,7 +34,6 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
)
@send_errors
def load_datasets(
*,
cfg: DictDefault,
@@ -98,7 +96,6 @@ def load_datasets(
)
@send_errors
def load_preference_datasets(
*, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None
) -> TrainDatasetMeta:

View File

@@ -29,8 +29,6 @@ from transformers.trainer_pt_utils import AcceleratorConfig
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.telemetry.callbacks import TelemetryCallback
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils import (
is_comet_available,
is_mlflow_available,
@@ -120,13 +118,6 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
if self.cfg.dynamic_checkpoint and self.cfg.dynamic_checkpoint.enabled:
from axolotl.utils.callbacks.dynamic_checkpoint import (
DynamicCheckpointCallback,
)
callbacks.append(DynamicCheckpointCallback(self.cfg))
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
@@ -164,10 +155,6 @@ class TrainerBuilderBase(abc.ABC):
)
)
telemetry_manager = TelemetryManager.get_instance()
if telemetry_manager.enabled:
callbacks.append(TelemetryCallback())
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -209,9 +196,9 @@ class TrainerBuilderBase(abc.ABC):
):
warmup_steps = 0
warmup_ratio = 0.0
if self.cfg.warmup_steps is not None:
if self.cfg.warmup_steps:
warmup_steps = self.cfg.warmup_steps
elif self.cfg.warmup_ratio is not None:
elif self.cfg.warmup_ratio:
if total_num_steps:
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:
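
The hunk above changes between `is not None` checks and plain truthiness checks on `warmup_steps` and `warmup_ratio`. A small illustration of the behavioral difference, assuming these config values can be `0`/`0.0` rather than `None`:

```python
# With an `is not None` check, an explicit warmup_steps of 0 still wins and forces 0 warmup steps;
# with a truthiness check, 0 (or 0.0) falls through to the next branch instead.
warmup_steps_cfg = 0
warmup_ratio_cfg = 0.1
total_num_steps = 1000

if warmup_steps_cfg:                  # False for 0, so the ratio branch runs
    warmup_steps = warmup_steps_cfg
elif warmup_ratio_cfg:
    warmup_steps = max(int(warmup_ratio_cfg * total_num_steps), 0)
else:
    warmup_steps = 0

print(warmup_steps)  # 100
```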

View File

@@ -43,7 +43,7 @@ from axolotl.core.trainers.utils import (
from axolotl.utils import get_not_null
from axolotl.utils.bench import get_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_distributed, is_main_process
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -350,11 +350,6 @@ class AxolotlTrainer(
# track number of tokens for tokens per second calculation
if self.args.include_tkps:
inputs_key = "labels" if "labels" in inputs else "input_ids"
num_tokens = (inputs[inputs_key] != -100).sum()
if is_distributed():
torch.distributed.all_reduce(
num_tokens, op=torch.distributed.ReduceOp.SUM
)
if hasattr(self.state, "num_tokens"):
self.state.num_tokens = (
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
@@ -362,11 +357,6 @@ class AxolotlTrainer(
else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
if hasattr(self.state, "total_tokens"):
self.state.total_tokens += num_tokens
else:
self.state.total_tokens = num_tokens
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
@@ -631,7 +621,6 @@ class AxolotlTrainer(
logs["tokens_per_second_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
logs["total_tokens"] = int(self.state.total_tokens.item())
del self._stored_metrics[train_eval]

View File

@@ -126,9 +126,6 @@ class GRPOStrategy:
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
return grpo_args_kwargs
@classmethod
@@ -204,32 +201,3 @@ class GRPOStrategy:
raise ValueError(
f"Reward function {reward_func_fqn} not found."
) from exc
@classmethod
def get_rollout_func(cls, rollout_func_fqn: str):
"""
Returns the rollout function from the given fully qualified name.
Args:
rollout_func_fqn (str): Fully qualified name of the rollout function
(e.g. my_module.my_rollout_func)
Returns:
Callable rollout function
"""
try:
rollout_func_module_name = rollout_func_fqn.split(".")[-1]
rollout_func_module = importlib.import_module(
".".join(rollout_func_fqn.split(".")[:-1])
)
rollout_func = getattr(rollout_func_module, rollout_func_module_name)
if not callable(rollout_func):
raise ValueError(
f"Rollout function {rollout_func_fqn} must be callable"
)
return rollout_func
except ModuleNotFoundError as exc:
raise ValueError(f"Rollout function {rollout_func_fqn} not found.") from exc

View File

@@ -10,7 +10,6 @@ import torch
from datasets import Dataset
from transformers.trainer import Trainer
from axolotl.telemetry.errors import send_errors
from axolotl.train import (
TrainDatasetMeta,
setup_model_and_tokenizer,
@@ -64,7 +63,6 @@ def evaluate_dataset(
return metrics
@send_errors
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
"""
Evaluate a model on training and validation datasets.

View File

@@ -18,9 +18,6 @@ liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
# FLCE-specific
liger_use_token_scaling: true
```
## Supported Models

View File

@@ -16,7 +16,7 @@
Module for handling LIGER input arguments.
"""
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger
@@ -35,15 +35,6 @@ class LigerArgs(BaseModel):
liger_glu_activation: bool | None = None
liger_cross_entropy: bool | None = None
liger_fused_linear_cross_entropy: bool | None = None
liger_use_token_scaling: bool | None = Field(
default=None,
json_schema_extra={
"description": (
"Enables use_token_scaling in fused_linear_cross_entropy. "
"When True, each token's loss is multiplied by its predicted probability (detached from gradients)."
)
},
)
@model_validator(mode="before")
@classmethod
@@ -84,18 +75,6 @@ class LigerArgs(BaseModel):
)
return data
@model_validator(mode="before")
@classmethod
def check_liger_use_token_scaling_flce(cls, data):
if data.get("liger_use_token_scaling") and not data.get(
"liger_fused_linear_cross_entropy"
):
raise ValueError(
"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled."
)
return data
@model_validator(mode="after")
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
# TODO @SalmanMohammadi this is a larger fix - investigate

View File

@@ -48,33 +48,6 @@ class LigerPlugin(BasePlugin):
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
)
if cfg.liger_use_token_scaling:
# Patch FLCE to set token_scaling=True for function and class API
from liger_kernel.transformers import functional
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)
old_liger_fused_linear_cross_entropy = (
functional.liger_fused_linear_cross_entropy
)
def patched_liger_fused_linear_cross_entropy(*args, **kwargs):
kwargs["use_token_scaling"] = True
return old_liger_fused_linear_cross_entropy(*args, **kwargs)
functional.liger_fused_linear_cross_entropy = (
patched_liger_fused_linear_cross_entropy
)
old_init = LigerFusedLinearCrossEntropyLoss.__init__
def patched_init(self, *args, **kwargs):
kwargs["use_token_scaling"] = True
return old_init(self, *args, **kwargs)
LigerFusedLinearCrossEntropyLoss.__init__ = patched_init
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)

View File

@@ -20,7 +20,6 @@ from peft import (
from transformers import PreTrainedModel
from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -173,7 +172,6 @@ def load_lora(
return model, lora_config
@send_errors
def load_adapter(
model: PreTrainedModel,
cfg: DictDefault,

View File

@@ -49,7 +49,6 @@ from axolotl.loaders.utils import (
load_model_config,
)
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.telemetry.errors import send_errors
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
@@ -159,7 +158,6 @@ class ModelLoader:
"""Property that determines if FSDP with QLoRA is enabled."""
return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
@send_errors
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
"""Load and prepare the model with all configurations and patches.

View File

@@ -457,7 +457,7 @@ class PatchManager:
and self.cfg.flash_attention
and not self.inference
):
# TODO(MengqingCao): split these patches separately
# TODO(MengqingCao): split these patches seperately
from axolotl.monkeypatch.llama_attn_hijack_flash import (
is_xformers_swiglu_available,
replace_llama_mlp_with_swiglu,

View File

@@ -1,47 +1,27 @@
"""Processor loading functionality for multi-modal models"""
from typing import Any
import transformers
from transformers import (
AutoProcessor,
PreTrainedTokenizerBase,
)
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@send_errors
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
processor_kwargs: dict[str, Any] = {} # Do we actually need this?
processor_cls = AutoProcessor
if cfg.processor_type:
processor_cls = getattr(transformers, cfg.processor_type)
if cfg.tokenizer_use_mistral_common:
def _patch_mistralcommontokenizer():
"""
Transformers v5 stops reading the sub-processor.
We need to patch this, so both processors use this.
"""
import transformers.tokenization_mistral_common as tokenization_mistral_common
from axolotl.utils.mistral import HFMistralTokenizer
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
_patch_mistralcommontokenizer()
from transformers import VoxtralProcessor
if processor_cls == VoxtralProcessor:
return VoxtralProcessor.from_pretrained(
cfg.processor_config,
)
from axolotl.utils.mistral import Mistral3Processor
return Mistral3Processor(
@@ -52,6 +32,7 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
cfg.processor_config,
trust_remote_code=cfg.trust_remote_code or False,
tokenizer=tokenizer,
**processor_kwargs,
)
# Attempt to load image size from processor if available

View File

@@ -13,7 +13,6 @@ from transformers import (
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.telemetry.errors import send_errors
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
@@ -120,7 +119,6 @@ def modify_tokenizer_files(
return tokenizer_dir
@send_errors
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
"""Load and configure the tokenizer based on the provided config."""

View File

@@ -40,8 +40,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"smollm3",
"granite",
"granitemoe",
"granitemoeshared",
"granitemoehybrid",
"hunyuan_v1_dense",
"hunyuan_v1_moe",
"gpt_oss",

View File

@@ -823,23 +823,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return None
if isinstance(tools, list):
# Process each tool to handle JSON string parameters
for tool in tools:
if isinstance(tool, dict) and "function" in tool:
function = tool["function"]
if "parameters" in function:
params = function["parameters"]
if isinstance(params, str):
try:
function["parameters"] = json.loads(params)
except json.JSONDecodeError as e:
LOG.error(
f"Error parsing tool parameters as JSON. "
f"Function: {function.get('name', 'unknown')}, "
f"Parameters string: {params!r}, "
f"Error: {e}"
)
raise
return tools
raise ValueError(

View File

@@ -1,165 +0,0 @@
"""Trainer callbacks for reporting runtime metrics at regular intervals."""
import logging
import time
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from axolotl.telemetry.manager import TelemetryManager
from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker
LOG = logging.getLogger(__name__)
TIME_SINCE_LAST = 60
class TelemetryCallback(TrainerCallback):
"""
Trainer callback for tracking and reporting runtime metrics.
This callback tracks training progress, runtime, and memory usage,
sending telemetry at configurable intervals.
"""
report_interval_steps: int = 100
def __init__(self):
"""Initialize the metrics callback."""
self.tracker = RuntimeMetricsTracker()
self.telemetry_manager = TelemetryManager.get_instance()
self.current_epoch = -1
self.start_time = time.time()
self.last_report_time = None
self.last_report_step = 0
# pylint: disable=unused-argument
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle training start."""
self.telemetry_manager.send_event(event_type="train-start")
# pylint: disable=unused-argument
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle training end."""
# Send training completion event
self.telemetry_manager.send_event(
event_type="train-end",
properties=self._extract_last_metrics(state)
| self.tracker.metrics.to_dict(),
)
# pylint: disable=unused-argument
def on_epoch_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle epoch start."""
self.current_epoch += 1
self.tracker.start_epoch(self.current_epoch)
# pylint: disable=unused-argument
def on_epoch_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle epoch end."""
self.tracker.end_epoch(self.current_epoch)
# pylint: disable=unused-argument
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""Handle step end."""
step = state.global_step
self.tracker.update_step(step)
# Check if we should report metrics
should_report = (
step % self.report_interval_steps == 0
or step == 1 # Always report first step
or step - self.last_report_step >= self.report_interval_steps
)
if should_report:
current_time = time.time()
if self.last_report_time is not None:
time_since_last_report = current_time - self.last_report_time
else:
time_since_last_report = current_time - self.start_time
steps_since_last_report = step - self.last_report_step
# Only report if enough time has passed
if (
step == 1
or time_since_last_report >= TIME_SINCE_LAST
or steps_since_last_report >= self.report_interval_steps
):
# Calculate steps per second for this interval
if time_since_last_report > 0 and steps_since_last_report > 0:
steps_per_second = steps_since_last_report / time_since_last_report
else:
steps_per_second = 0
# Update memory metrics
self.tracker.update_memory_metrics()
# Prepare metrics to report
metrics = self._extract_last_metrics(state) | {
"step": step,
"epoch": self.current_epoch,
"progress": state.epoch, # Fractional epoch progress
"steps_per_second": steps_per_second,
"elapsed_time": current_time - self.start_time,
"time_since_last_report": time_since_last_report,
}
# Add memory metrics
memory_metrics = self.tracker.get_memory_metrics()
metrics.update({"memory": memory_metrics})
# Send telemetry
self.telemetry_manager.send_event(
event_type="train-progress", properties=metrics
)
# Update last report time and step
self.last_report_time = current_time
self.last_report_step = step
def _extract_last_metrics(self, state: TrainerState) -> dict:
"""Extract last loss, learning_rate, and grad_norm from log history."""
if not state.log_history:
return {"loss": 0, "learning_rate": 0, "grad_norm": 0}
last_log = state.log_history[-1]
return {
"loss": last_log.get("loss", 0),
"learning_rate": last_log.get("learning_rate", 0),
"grad_norm": last_log.get("grad_norm", 0),
}

View File

@@ -1,160 +0,0 @@
"""Telemetry utilities for exception and traceback information."""
import logging
import os
import re
import traceback
from functools import wraps
from inspect import getmodule
from typing import Any, Callable
from axolotl.telemetry.manager import TelemetryManager
LOG = logging.getLogger(__name__)
ERROR_HANDLED = False
def sanitize_stack_trace(stack_trace: str) -> str:
"""
Remove personal information from stack trace messages while keeping Python package codepaths.
This function identifies Python packages by looking for common patterns in virtual environment
and site-packages directories, preserving the package path while removing user-specific paths.
Args:
stack_trace: The original stack trace string.
Returns:
A sanitized version of the stack trace with Python package paths preserved.
"""
# Split the stack trace into lines to process each file path separately
lines = stack_trace.split("\n")
sanitized_lines = []
# Regular expression to find file paths in the stack trace
path_pattern = re.compile(r'(?:File ")(.*?)(?:")')
# Regular expression to identify paths in site-packages or dist-packages
# This matches path segments like "site-packages/package_name" or "dist-packages/package_name"
site_packages_pattern = re.compile(
r"(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
)
# Additional common virtual environment patterns
venv_lib_pattern = re.compile(
r"(?:lib|Lib)[/\\](?:python\d+(?:\.\d+)?[/\\])?(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
)
for line in lines:
# Check if this line contains a file path
path_match = path_pattern.search(line)
if path_match:
full_path = path_match.group(1)
sanitized_path = ""
# Try to match site-packages pattern
site_packages_match = site_packages_pattern.search(full_path)
venv_lib_match = venv_lib_pattern.search(full_path)
if site_packages_match:
# Find the index where the matched pattern starts
idx = full_path.find("site-packages")
if idx == -1:
idx = full_path.find("dist-packages")
# Keep from 'site-packages' onward
if idx >= 0:
sanitized_path = full_path[idx:]
elif venv_lib_match:
# For other virtual environment patterns, find the package directory
match_idx = venv_lib_match.start(1)
if match_idx > 0:
# Keep from the package name onward
package_name = venv_lib_match.group(1)
idx = full_path.rfind(
package_name, 0, match_idx + len(package_name)
)
if idx >= 0:
sanitized_path = full_path[idx:]
# If we couldn't identify a package pattern but path contains 'axolotl'
elif "axolotl" in full_path:
idx = full_path.rfind("axolotl")
if idx >= 0:
sanitized_path = full_path[idx:]
# Apply the sanitization to the line
if sanitized_path:
line = line.replace(full_path, sanitized_path)
else:
# If we couldn't identify a package pattern, just keep the filename
filename = os.path.basename(full_path)
if filename:
line = line.replace(full_path, filename)
else:
line = line.replace(full_path, "")
sanitized_lines.append(line)
return "\n".join(sanitized_lines)
def send_errors(func: Callable) -> Callable:
"""
Decorator to send exception info in a function. If an exception is raised, we send
telemetry containing the stack trace and error message.
If an error occurs in a decorated function that is called by another decorated
function, we'll only send telemetry corresponding to the lower-level function.
Args:
func: Function to decorate.
Returns:
Decorated function.
"""
@wraps(func)
def wrapper(*args, **kwargs) -> Any:
telemetry_manager = TelemetryManager.get_instance()
if not telemetry_manager.enabled:
return func(*args, **kwargs)
try:
return func(*args, **kwargs)
except Exception as exception:
# Only track if we're not already handling an error. This prevents us from
# capturing an error more than once in nested decorated function calls.
global ERROR_HANDLED # pylint: disable=global-statement
if not ERROR_HANDLED:
ERROR_HANDLED = True
# Get function module path
module = getmodule(func)
module_path = (
f"{module.__name__}.{func.__name__}" if module else func.__name__
)
# Get stack trace
stack_trace = "".join(
traceback.format_exception(
type(exception), exception, exception.__traceback__
)
)
stack_trace = sanitize_stack_trace(stack_trace)
# Send error telemetry
telemetry_manager.send_event(
event_type=f"{module_path}-error",
properties={
"exception": str(exception),
"stack_trace": stack_trace,
},
)
raise
return wrapper

View File

@@ -1,416 +0,0 @@
"""Telemetry manager and associated utilities."""
import atexit
import importlib
import logging
import os
import platform
import time
import uuid
from pathlib import Path
from typing import Any
import posthog
import psutil
import torch
import yaml
LOG = logging.getLogger(__name__)
POSTHOG_HOST = "https://app.posthog.com"
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
OPT_OUT_WARNING_SLEEP_SECONDS = 10
OPT_OUT_WARNING = (
"\nTelemetry is now enabled by default to help improve Axolotl. "
"If you'd like to disable it, set AXOLOTL_DO_NOT_TRACK=1 in your environment.\n\n"
"Telemetry data helps us understand:\n"
"- Which features are most used\n"
"- What hardware configurations to prioritize\n"
"- Where users encounter errors\n\n"
"Personally identifiable information (PII) is not collected.\n\n"
"To remove this warning, explicitly set AXOLOTL_DO_NOT_TRACK=0 (enable telemetry) "
"or AXOLOTL_DO_NOT_TRACK=1 (disable telemetry).\n\n"
"For details, see: https://docs.axolotl.ai/docs/telemetry.html\n\n"
f"Sleeping for {OPT_OUT_WARNING_SLEEP_SECONDS}s..."
)
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
# NOTE: Need to keep these up to date with any config schema changes
FIELDS_TO_REDACT = {
"base_model",
"tokenizer_config",
"base_model_config",
"pretraining_dataset", # NOTE: this field may be a string or a dictionary
"resume_from_checkpoint",
"hub_model_id",
}
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"}
PATH_INDICATORS = {"path", "dir"}
# pylint: disable=duplicate-code
RELEVANT_PACKAGES = {
"torch",
"transformers",
"trl",
"datasets",
"peft",
"bitsandbytes",
"accelerate",
"optimum",
"deepspeed",
"ray",
"axolotl",
"triton",
"mamba-ssm",
"flash-attn",
"xformers",
"autoawq",
"tokenizers",
"sentencepiece",
"torchao",
"lm_eval",
}
def is_main_process() -> bool:
"""
Check whether we're running in the main process.
Note:
We're using this function instead of `torch.utils.distributed.is_main_process`
causes issues with DeepSpeed world_size since. This function avoids that issue
by checking env vars that are set by various launchers.
Returns:
Whether we're running in the main process.
"""
# If PyTorch distributed is already initialized, use it
if torch.distributed.is_initialized():
return torch.distributed.get_rank() == 0
# Otherwise check environment variables for global rank
# NOTE: need to verify this in SLURM / OpenMPI environments
global_rank = int(
os.environ.get(
"RANK",
os.environ.get(
"GLOBAL_RANK",
os.environ.get(
"SLURM_PROCID",
os.environ.get(
"OMPI_COMM_WORLD_RANK",
"0",
),
),
),
)
)
return global_rank == 0
class TelemetryManager:
"""Manages telemetry collection and transmission"""
_instance = None
_initialized = False
def __new__(cls):
"""
Telemetry manager constructor. Creates the singleton instance of this class if
it doesn't already exist.
"""
if cls._instance is None:
cls._instance = super(TelemetryManager, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
"""Telemetry manager initializer"""
if self._initialized:
return
self.enabled = self._check_telemetry_enabled()
if self.enabled:
self.run_id = str(uuid.uuid4())
self.whitelist = self._load_whitelist()
try:
self.system_info = self._get_system_info()
except Exception as e: # pylint: disable=broad-exception-caught
LOG.warning(f"Error during system info collection: {e}")
self.system_info = None
self._init_posthog()
# Register shutdown method to flush posthog telemetry
atexit.register(self.shutdown)
self._initialized = True
@classmethod
def get_instance(cls) -> "TelemetryManager":
if cls._instance is None:
cls._instance = TelemetryManager()
return cls._instance
def _check_telemetry_enabled(self) -> bool:
"""
Check if telemetry is enabled based on environment variables. We also check
whether this is the main process (for the distributed setting and to avoid
sending duplicate PostHog events per GPU).
Note: This is enabled by default on an opt-out basis. Set
`AXOLOTL_DO_NOT_TRACK=1` to disable telemetry. For more details, see
https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
Returns:
Boolean denoting whether telemetry is enabled or not.
"""
# Parse relevant env vars
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
do_not_track = os.getenv("DO_NOT_TRACK")
# Default to enabled (opt-out model)
if axolotl_do_not_track is None or axolotl_do_not_track.lower() not in (
"0",
"1",
"false",
"true",
):
# Print opt-out info message for main process only
if is_main_process():
LOG.warning(OPT_OUT_WARNING)
time.sleep(OPT_OUT_WARNING_SLEEP_SECONDS)
return True
# Only rank 0 will send telemetry
if not is_main_process():
return False
if do_not_track is None:
do_not_track = "0"
# Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
enabled = axolotl_do_not_track.lower() not in (
"1",
"true",
) and do_not_track.lower() not in ("1", "true")
return enabled
def _load_whitelist(self) -> dict:
"""Load HuggingFace Hub organization whitelist"""
with open(WHITELIST_PATH, encoding="utf-8") as f:
whitelist = yaml.safe_load(f)
# Send org strings to lowercase since model names are case insensitive
whitelist["organizations"] = {
org.lower() for org in whitelist["organizations"]
}
return whitelist
def _is_whitelisted(self, value: str) -> bool:
"""
Check if model / dataset / etc. org is in whitelist.
Args:
value: Value for one of `axolotl.telemetry.manager.FIELDS_WITH_ORGS`
("base_model", etc.).
Returns:
Boolean indicating whitelist membership.
"""
# NOTE: This membership-checking logic can be improved.
# What happens when a local model path matches a whitelisted org?
parts = value.split("/")
if len(parts) < 2:
return False
org = parts[0]
whitelisted = org.lower() in self.whitelist["organizations"]
return whitelisted
def _init_posthog(self):
"""Initialize PostHog client"""
posthog.api_key = POSTHOG_WRITE_KEY
posthog.project_api_key = POSTHOG_WRITE_KEY
posthog.host = POSTHOG_HOST
def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]:
"""
Redact properties to remove any paths, so as to avoid inadvertently collecting
private or personally identifiable information (PII). We also remove
information related to Wandb, MLflow, etc. configuration.
Args:
properties: Dictionary of properties to redact.
Returns:
Properties dictionary with redaction applied.
"""
if not properties:
return {}
def redact_value(value: Any, key: str = "") -> Any:
"""Recursively sanitize values, redacting those with path-like keys"""
if isinstance(key, str) and isinstance(value, str):
# Other redaction special cases
if (
key in FIELDS_TO_REDACT
or any(prefix in key for prefix in PREFIXES_TO_REDACT)
or any(indicator in key.lower() for indicator in PATH_INDICATORS)
):
# Fields with whitelisted orgs don't need to be redacted
if not self._is_whitelisted(value):
return "[REDACTED]"
# Handle nested values
if isinstance(value, dict):
return {k: redact_value(v, k) for k, v in value.items()}
if isinstance(value, list):
return [redact_value(item) for item in value]
return value
# Create new dict with redacted values
redacted = {k: redact_value(v, k) for k, v in properties.items()}
return redacted
def _get_system_info(self) -> dict[str, Any]:
"""Collect system information for various hardware accelerators"""
gpu_info = []
accelerator_type = "none"
# NVIDIA GPUs
if torch.cuda.is_available():
accelerator_type = "cuda"
for i in range(torch.cuda.device_count()):
gpu_info.append(
{
"name": torch.cuda.get_device_name(i),
"memory": torch.cuda.get_device_properties(i).total_memory,
}
)
# AMD GPUs
elif hasattr(torch, "hip") and torch.hip.is_available():
accelerator_type = "hip"
for i in range(torch.hip.device_count()):
gpu_info.append(
{
"name": torch.hip.get_device_name(i),
"memory": (
torch.hip.get_device_properties(i).total_memory
if hasattr(torch.hip, "get_device_properties")
else None
),
}
)
# Apple Silicon
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
accelerator_type = "mps"
gpu_info.append(
{
"name": "Apple Silicon",
# NOTE: this is memory allocated to this process, not total memory
"memory": torch.mps.driver_allocated_memory(),
}
)
# Intel GPUs
elif hasattr(torch, "xpu") and torch.xpu.is_available():
accelerator_type = "xpu"
for i in range(torch.xpu.device_count()):
memory = None
if hasattr(torch.xpu, "get_device_properties"):
memory = torch.xpu.get_device_properties(i).total_memory
gpu_info.append(
{
"name": torch.xpu.get_device_name(i),
"memory": memory,
}
)
# NPUs
elif hasattr(torch, "npu") and torch.npu.is_available():
accelerator_type = "npu"
for i in range(torch.npu.device_count()):
memory = None
if hasattr(torch.npu, "get_device_properties"):
memory = torch.npu.get_device_properties(i).total_memory
gpu_info.append(
{
"name": torch.npu.get_device_name(i),
"memory": memory,
}
)
# Get relevant package versions
installed_packages = {}
for package in RELEVANT_PACKAGES:
try:
version = importlib.metadata.version(package)
installed_packages[f"{package}_version"] = version
except importlib.metadata.PackageNotFoundError:
pass
return {
"os": platform.system(),
"python_version": platform.python_version(),
"cpu_count": psutil.cpu_count(),
"memory_total": psutil.virtual_memory().total,
"accelerator_type": accelerator_type,
"accelerator_count": len(gpu_info),
"accelerator_info": gpu_info,
**installed_packages,
}
def send_event(self, event_type: str, properties: dict[str, Any] | None = None):
"""Send a telemetry event"""
if not self.enabled:
return
if properties is None:
properties = {}
# Sanitize properties to remove PII
properties = self._redact_paths(properties)
# Wrap PostHog errors in try / except to not raise errors during Axolotl usage
try:
# Send event via PostHog
posthog.capture(
distinct_id=self.run_id,
event=event_type,
properties=properties,
disable_geoip=True,
)
except Exception as e: # pylint: disable=broad-exception-caught
LOG.warning(f"Failed to send telemetry event: {e}")
# Additionally, send system info telemetry when loading config.
# NOTE: Is this the best place for this?
if event_type == "config-loaded":
self.send_system_info()
def send_system_info(self):
"""Helper method for sending system info"""
if self.system_info is not None:
self.send_event(event_type="system-info", properties=self.system_info)
def shutdown(self):
"""Ensure all queued events are processed before shutdown"""
if self.enabled:
posthog.shutdown()
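For illustration, a standalone sketch of the membership check above, including the caveat from the NOTE about local paths whose first path component happens to match a whitelisted org. The function name, org set, and model identifiers below are assumptions for the example, not taken from the shipped whitelist.

def is_whitelisted(value: str, orgs: set[str]) -> bool:
    # Mirrors the check above: only "org/name"-style identifiers can match,
    # and the org component is compared case-insensitively.
    parts = value.split("/")
    if len(parts) < 2:
        return False
    return parts[0].lower() in orgs

orgs = {"meta-llama", "mistralai"}  # assumed example subset
print(is_whitelisted("meta-llama/Llama-3.1-8B", orgs))      # True
print(is_whitelisted("./models/my-finetune", orgs))         # False: no matching org prefix
print(is_whitelisted("meta-llama/local-checkpoint", orgs))  # True, even though this could be a local path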

View File

@@ -1,210 +0,0 @@
"""Telemetry utilities for runtime and memory metrics."""
import logging
import time
from dataclasses import dataclass, field
from typing import Any
import psutil
import torch
from axolotl.telemetry.manager import TelemetryManager
LOG = logging.getLogger(__name__)
@dataclass
class RuntimeMetrics:
"""Container for runtime metrics to be tracked throughout training."""
# Timing metrics
start_time: float
epoch_start_times: dict[int, float] = field(init=False)
epoch_end_times: dict[int, float] = field(init=False)
# Memory metrics
peak_cpu_memory: int = 0
peak_gpu_memory: dict[int, int] = field(init=False)
# Progress metrics
total_steps: int = 0
current_epoch: int = 0
current_step: int = 0
def __post_init__(self):
"""Initialize empty metric mappings."""
self.epoch_start_times = {}
self.epoch_end_times = {}
self.peak_gpu_memory = {}
@property
def elapsed_time(self) -> float:
"""Calculate total elapsed time in seconds."""
return time.time() - self.start_time
def epoch_time(self, epoch: int) -> float | None:
"""Calculate time taken for a specific epoch in seconds."""
if epoch in self.epoch_start_times and epoch in self.epoch_end_times:
return self.epoch_end_times[epoch] - self.epoch_start_times[epoch]
return None
def average_epoch_time(self) -> float | None:
"""Calculate average time per epoch in seconds."""
completed_epochs = [
epoch for epoch in self.epoch_start_times if epoch in self.epoch_end_times
]
if not completed_epochs:
return None
total_time = 0.0
for epoch in completed_epochs:
epoch_time = self.epoch_time(epoch)
if epoch_time is not None: # Check to avoid mypy warning
total_time += epoch_time
return total_time / len(completed_epochs)
def steps_per_second(self) -> float | None:
"""Calculate average steps per second across all training."""
if self.total_steps == 0 or self.elapsed_time == 0:
return None
return self.total_steps / self.elapsed_time
def to_dict(self) -> dict[str, Any]:
"""Convert metrics to a dictionary for telemetry reporting."""
metrics = {
"total_time_seconds": self.elapsed_time,
"total_steps": self.total_steps,
"steps_per_second": self.steps_per_second(),
"epochs_completed": len(
[
epoch
for epoch in self.epoch_start_times
if epoch in self.epoch_end_times
]
),
"peak_cpu_memory_bytes": self.peak_cpu_memory,
}
# Add per-epoch timing if available
epoch_times: dict[str, float] = {}
for epoch in sorted(self.epoch_end_times.keys()):
time_taken = self.epoch_time(epoch)
if time_taken is not None:
epoch_times[f"epoch_{epoch}_seconds"] = time_taken
if epoch_times:
metrics["epoch_times"] = epoch_times # type: ignore
metrics["average_epoch_time_seconds"] = self.average_epoch_time()
# Add GPU memory metrics if available
if self.peak_gpu_memory:
gpu_metrics: dict[str, int] = {}
for gpu_id, memory in self.peak_gpu_memory.items():
gpu_metrics[f"gpu_{gpu_id}_peak_memory_bytes"] = memory
metrics["gpu_memory"] = gpu_metrics # type: ignore
return metrics
class RuntimeMetricsTracker:
"""Tracker for runtime metrics during training."""
update_interval = 100
def __init__(self):
"""Initialize the runtime metrics tracker."""
self.metrics = RuntimeMetrics(start_time=time.time())
self.telemetry_manager = TelemetryManager.get_instance()
self._process = psutil.Process()
def start_epoch(self, epoch: int):
"""Record the start of a new epoch."""
self.metrics.current_epoch = epoch
self.metrics.epoch_start_times[epoch] = time.time()
self.update_memory_metrics()
def end_epoch(self, epoch: int):
"""Record the end of an epoch."""
self.metrics.epoch_end_times[epoch] = time.time()
def update_step(self, step: int):
"""Update the current step count."""
self.metrics.current_step = step
self.metrics.total_steps += 1
# Periodically update memory metrics
if step % self.update_interval == 0:
self.update_memory_metrics()
def _get_allocated_memory(self) -> dict[int, int]:
"""
Helper function for getting accelerator-agnostic allocated memory.
Returns:
A dictionary mapping device IDs to allocated memory in bytes
"""
memory_used: dict[int, int] = {}
# NVIDIA GPUs
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
memory_used[i] = torch.cuda.memory_allocated(i)
# AMD GPUs
elif hasattr(torch, "hip") and torch.hip.is_available():
for i in range(torch.hip.device_count()):
if hasattr(torch.hip, "memory_allocated"):
memory_used[i] = torch.hip.memory_allocated(i)
# Apple Silicon
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
# MPS doesn't have per-device memory stats since there's only one device
if hasattr(torch.mps, "current_allocated_memory"):
memory_used[0] = torch.mps.current_allocated_memory()
# Intel GPUs
elif hasattr(torch, "xpu") and torch.xpu.is_available():
for i in range(torch.xpu.device_count()):
if hasattr(torch.xpu, "memory_allocated"):
memory_used[i] = torch.xpu.memory_allocated(i)
# NPUs
elif hasattr(torch, "npu") and torch.npu.is_available():
for i in range(torch.npu.device_count()):
if hasattr(torch.npu, "memory_allocated"):
memory_used[i] = torch.npu.memory_allocated(i)
return memory_used
def update_memory_metrics(self):
"""Update peak memory usage metrics."""
# CPU memory
cpu_memory = self._process.memory_info().rss
self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory)
# GPU memory (if available)
memory_used = self._get_allocated_memory()
for i, memory in memory_used.items():
self.metrics.peak_gpu_memory[i] = max(
self.metrics.peak_gpu_memory.get(i, 0), memory
)
def get_memory_metrics(self) -> dict[str, Any]:
"""Get the current memory metrics as a dictionary."""
memory_metrics = {
"cpu_memory_bytes": self._process.memory_info().rss,
"peak_cpu_memory_bytes": self.metrics.peak_cpu_memory,
}
# GPU memory (if available)
memory_used = self._get_allocated_memory()
for i, memory in memory_used.items():
memory_metrics[f"gpu_{i}_memory_bytes"] = memory
memory_metrics[f"gpu_{i}_peak_memory_bytes"] = (
self.metrics.peak_gpu_memory.get(i, 0)
)
return memory_metrics
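A minimal usage sketch of the tracker above, assuming an ordinary epoch/step training loop; the loop bounds are illustrative and the import path is the module shown here.

from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker

tracker = RuntimeMetricsTracker()
for epoch in range(2):
    tracker.start_epoch(epoch)
    for step in range(200):
        tracker.update_step(step)  # memory is re-sampled every `update_interval` steps
    tracker.end_epoch(epoch)

print(tracker.metrics.to_dict())     # total_time_seconds, steps_per_second, epoch times, ...
print(tracker.get_memory_metrics())  # current and peak CPU/GPU memory in bytes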

View File

@@ -1,33 +0,0 @@
organizations:
- "axolotl-ai-co"
- "meta-llama"
- "huggingface"
- "nvidia"
- "facebook"
- "google"
- "microsoft"
- "deepseek-ai"
- "HuggingFaceTB"
- "mistralai"
- "Qwen"
- "unsloth"
- "NousResearch"
- "allenai"
- "amd"
- "tiiuae"
- "tencent"
- "zai-org"
- "openai"
- "ibm-granite"
- "arcee-ai"
- "swiss-ai"
- "CohereForAI"
- "deepcogito"
- "THUDM"
- "ai21labs"
- "LiquidAI"
- "canopylabs"
- "state-spaces"
- "mistral-community"
- "llava-hf"
- "ByteDance-Seed"

View File

@@ -31,8 +31,6 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
)
from axolotl.integrations.base import PluginManager
from axolotl.loaders import ModelLoader, load_processor, load_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
@@ -47,9 +45,6 @@ if typing.TYPE_CHECKING:
LOG = get_logger(__name__)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
PLUGIN_MANAGER = PluginManager.get_instance()
def setup_model_and_tokenizer(
cfg: DictDefault,
@@ -67,10 +62,7 @@ def setup_model_and_tokenizer(
`None`), and processor (if multimodal, else `None`).
"""
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
LOG.debug(f"Loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
# Load processor for multimodal models if needed
@@ -86,14 +78,6 @@ def setup_model_and_tokenizer(
if model.generation_config is not None:
model.generation_config.do_sample = True
TELEMETRY_MANAGER.send_event(
event_type="model-load", properties=model.config.to_dict()
)
if peft_config:
TELEMETRY_MANAGER.send_event(
event_type="peft-config-load", properties=peft_config.to_dict()
)
# Apply freezing if specified
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
@@ -212,7 +196,8 @@ def execute_training(
LOG.info("Starting trainer...")
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
PLUGIN_MANAGER.post_train(cfg, trainer.model)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train(cfg, trainer.model)
def save_trained_model(
@@ -536,7 +521,9 @@ def setup_model_and_trainer(
model_ref=model_ref,
peft_config=peft_config,
)
PLUGIN_MANAGER.post_trainer_create(cfg, trainer)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer)
if cfg.use_ray:
try:
@@ -558,7 +545,6 @@ def setup_model_and_trainer(
)
@send_errors
def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
@@ -609,6 +595,5 @@ def train(
create_model_card(cfg, trainer)
if not cfg.use_ray:
cleanup_distributed()
PLUGIN_MANAGER.post_train(cfg, model)
return model, tokenizer, trainer

View File

@@ -1,132 +0,0 @@
from pathlib import Path
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from axolotl.utils.distributed import (
barrier,
is_distributed,
is_main_process,
)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
DEFAULT_TRIGGER_FILENAME = "axolotl_checkpoint.save"
class DynamicCheckpointCallback(TrainerCallback):
"""
Callback to save checkpoints on-demand during training via:
1. File-based trigger (works everywhere, rank 0 checks file)
Thread-safe for multi-GPU distributed training.
Usage:
# File-based:
touch /path/to/output_dir/axolotl_checkpoint.save
"""
def _get_config_value(self, config, key, default=None):
"""Helper to get config value from dict or object."""
if isinstance(config, dict):
return config.get(key, default)
return getattr(config, key, default)
def __init__(self, cfg):
self.cfg = cfg
if not cfg.dynamic_checkpoint or not cfg.dynamic_checkpoint.enabled:
self.enabled = False
return
self.enabled = True
dc_config = cfg.dynamic_checkpoint
trigger_file_path = self._get_config_value(dc_config, "trigger_file_path")
self.trigger_filename = (
trigger_file_path if trigger_file_path else DEFAULT_TRIGGER_FILENAME
)
check_interval = self._get_config_value(dc_config, "check_interval")
self.check_interval = check_interval if check_interval is not None else 100
self.should_save_checkpoint = False
LOG.info(
f"Dynamic checkpoint enabled. To trigger checkpoint save:\n"
f" • File: touch {cfg.output_dir}/{self.trigger_filename}\n"
f" • Check interval: every {self.check_interval} steps",
main_process_only=True,
)
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
) -> TrainerControl:
"""
Check for checkpoint triggers at the end of each step.
ONLY rank 0 checks the file, then all ranks synchronize.
"""
if not self.enabled:
return control
trigger_detected = False
if state.global_step % self.check_interval == 0:
if is_main_process():
trigger_path = Path(args.output_dir) / self.trigger_filename
if trigger_path.exists():
trigger_detected = True
try:
trigger_path.unlink() # Delete the trigger file
LOG.info(
f"Dynamic checkpoint triggered via file '{self.trigger_filename}' "
f"at step {state.global_step}",
main_process_only=True,
)
except OSError as exc:
LOG.warning(
f"Failed to delete trigger file: {exc}",
main_process_only=True,
)
if self.should_save_checkpoint:
trigger_detected = True
self.should_save_checkpoint = False # Reset flag
if is_distributed():
import torch
import torch.distributed as dist
device = getattr(
args,
"device",
torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
trigger_tensor = torch.tensor(
1 if trigger_detected else 0,
dtype=torch.long,
device=device,
)
dist.broadcast(trigger_tensor, src=0)
trigger_detected = bool(trigger_tensor.item())
barrier()
if trigger_detected:
control.should_save = True
LOG.info(
f"Saving dynamic checkpoint at step {state.global_step}",
main_process_only=True,
)
return control
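A hypothetical way to request an on-demand save from another process while training is running, assuming the default trigger filename and that `output_dir` matches the training config (the shell equivalent is simply `touch <output_dir>/axolotl_checkpoint.save`):

from pathlib import Path

output_dir = Path("./outputs/my-run")  # assumed; must match cfg.output_dir
(output_dir / "axolotl_checkpoint.save").touch()

On the next step that is a multiple of `check_interval`, rank 0 detects the file, deletes it, and broadcasts the trigger so every rank saves the same checkpoint.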

View File

@@ -30,7 +30,6 @@ class Mistral3Processor(ProcessorMixin):
Wraps HFMistralTokenizer and adds image processing capabilities.
"""
# TODO(nano): This should be removed in transformers V5
attributes = ["tokenizer"]
tokenizer_class = "HFMistralTokenizer"

View File

@@ -23,7 +23,6 @@ from axolotl.utils.schemas.datasets import (
StepwiseSupervisedDataset,
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.fsdp import FSDPConfig
from axolotl.utils.schemas.integrations import (
@@ -142,13 +141,6 @@ class AxolotlInputConfig(
default=None,
json_schema_extra={"description": "Reward modelling: `True` or `False`"},
)
dynamic_checkpoint: DynamicCheckpointConfig | None = Field(
default=None,
json_schema_extra={
"description": "Configuration for dynamic checkpointing (trigger by file or signal). "
"Set 'enabled: true' to activate this feature."
},
)
process_reward_model: bool | None = Field(
default=None,
json_schema_extra={
@@ -1069,7 +1061,7 @@ class AxolotlInputConfig(
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""Wrapper to valdiate GPU capabilities with the configured options"""
"""wrapper to valdiate GPU capabilities with the configured options"""
capabilities: GPUCapabilities
env_capabilities: EnvCapabilities

View File

@@ -1,31 +0,0 @@
"""Schema for dynamic checkpoint configuration."""
from pydantic import BaseModel, Field
class DynamicCheckpointConfig(BaseModel):
"""Configuration for dynamic checkpoint triggering during training."""
enabled: bool = Field(
default=False,
json_schema_extra={
"description": "Enable dynamic checkpoint triggering during training. "
"Create a file 'axolotl_checkpoint.save' in the configured `output_dir` to trigger. "
},
)
check_interval: int = Field(
default=10,
ge=1,
json_schema_extra={
"description": "Check for trigger file every N steps (reduces I/O overhead). "
"Default: 100"
},
)
trigger_file_path: str = Field(
default="",
json_schema_extra={
"description": "Custom trigger filename (optional). "
"If not specified, defaults to 'axolotl_checkpoint.save'. "
"Specify a filename (not a full path) to override the default."
},
)
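The schema can also be exercised directly. A small sketch, assuming only the import path shown in the config diff above; the values are illustrative.

from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig

cfg = DynamicCheckpointConfig(enabled=True, check_interval=25)
print(cfg.trigger_file_path or "axolotl_checkpoint.save")  # same fallback the callback applies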

View File

@@ -173,9 +173,3 @@ class TRLConfig(BaseModel):
"description": "Enable sleep mode for vLLM to offload VRAM when idle"
},
)
rollout_func: str | None = Field(
default=None,
json_schema_extra={
"description": "Path to custom rollout function. Must be importable from current dir."
},
)

View File

@@ -1,4 +1,6 @@
"""Shared pytest fixtures"""
"""
shared pytest fixtures
"""
import functools
import importlib
@@ -580,9 +582,3 @@ def test_load_fixtures(
download_llama2_model_fixture,
):
pass
@pytest.fixture(autouse=True)
def disable_telemetry(monkeypatch):
monkeypatch.setenv("AXOLOTL_DO_NOT_TRACK", "1")
yield

View File

@@ -396,10 +396,10 @@ def rand_reward_func(prompts, completions) -> list[float]:
),
("orpo_cfg", None), # don't use fixture for orpo to use smaller split
("kto_cfg", None), # no fixture for kto
# (
# "simpo_cfg",
# "dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
# ),
(
"simpo_cfg",
"dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
),
],
)
def test_custom_optimizer_cls_and_kwargs(

View File

@@ -2,8 +2,6 @@
Simple end-to-end test for Liger integration
"""
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -64,11 +62,7 @@ class LigerIntegrationTestCase:
check_model_output_exists(temp_dir, cfg)
@require_torch_2_4_1
@pytest.mark.parametrize(
"liger_use_token_scaling",
[True, False],
)
def test_llama_w_flce(self, temp_dir, liger_use_token_scaling):
def test_llama_w_flce(self, temp_dir):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -80,7 +74,6 @@ class LigerIntegrationTestCase:
"liger_glu_activation": True,
"liger_cross_entropy": False,
"liger_fused_linear_cross_entropy": True,
"liger_use_token_scaling": liger_use_token_scaling,
"sequence_len": 1024,
"val_set_size": 0.05,
"special_tokens": {

View File

@@ -144,7 +144,7 @@ def recursive_kill(process: subprocess.Popen):
@pytest.mark.skip(reason="flaky vllm tests in modal")
class TestGRPO:
"""
Test case for GRPO training using multiple GPUs
Test case for GRPO training using multiple GPUs
"""
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):

View File

@@ -14,7 +14,7 @@ class TestPreprocess:
"""test cases for preprocess"""
def test_w_deepspeed(self, temp_dir):
"""make sure preprocess doesn't choke when using deepspeed in the config"""
"""make sure preproces doesn't choke when using deepspeed in the config"""
cfg = DictDefault(
{

View File

@@ -75,19 +75,3 @@ class TestValidation:
):
prepare_plugins(test_cfg)
validate_config(test_cfg)
def test_use_token_scaling_require_flce(self, minimal_liger_cfg):
test_cfg = DictDefault(
{
"liger_fused_linear_cross_entropy": False,
"liger_use_token_scaling": True,
}
| minimal_liger_cfg
)
with pytest.raises(
ValueError,
match=r"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled.",
):
prepare_plugins(test_cfg)
validate_config(test_cfg)

View File

@@ -69,7 +69,7 @@ class TestQwen3IdenticalConversationArgs:
{
"function": {
"name": function_name,
"arguments": arguments_dict, # dict
"arguments": arguments_dict, # dict格式
}
}
],
@@ -100,7 +100,7 @@ class TestQwen3IdenticalConversationArgs:
{
"function": {
"name": function_name,
"arguments": arguments_str, # str
"arguments": arguments_str, # str格式
}
}
],
@@ -212,294 +212,3 @@ class TestQwen3IdenticalConversationArgs:
decoded = qwen3_tokenizer.decode(processed[0]["input_ids"])
assert "2025-08-01" in decoded, "String time value should be present"
assert "1690876800" in decoded, "Number time value should be present"
class TestQwen3IdenticalToolsParameters:
"""
Test Qwen3 tools parameters handling is identical between JSON string and dict
"""
@pytest.fixture(name="tools_dict_params_dataset")
def fixture_tools_dict_params_dataset(self):
"""
Provides a dataset with tools where parameters is a dict.
"""
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather information",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
},
}
]
data = [
{
"tools": tools,
"messages": [
{"role": "user", "content": "What's the weather?"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "Boston, MA"},
},
}
],
},
{
"role": "tool",
"name": "get_weather",
"content": "72°F and sunny",
},
],
}
]
return Dataset.from_list(data)
@pytest.fixture(name="tools_str_params_dataset")
def fixture_tools_str_params_dataset(self):
"""
Provides a dataset with tools where parameters is a JSON string.
"""
parameters_dict = {
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state"},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
}
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather information",
"parameters": json.dumps(parameters_dict),
},
}
]
data = [
{
"tools": tools,
"messages": [
{"role": "user", "content": "What's the weather?"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "Boston, MA"},
},
}
],
},
{
"role": "tool",
"name": "get_weather",
"content": "72°F and sunny",
},
],
}
]
return Dataset.from_list(data)
@pytest.fixture(name="tools_mixed_type_params_dataset")
def fixture_tools_mixed_type_params_dataset(self):
"""
Provides a dataset where different tools have the same parameter name with different types.
This tests that JSON string format prevents casting issues.
"""
tools = [
{
"type": "function",
"function": {
"name": "tool_with_string_arg",
"description": "Tool expecting string argument",
"parameters": json.dumps(
{
"type": "object",
"properties": {
"arg1": {
"type": "string",
"description": "A string parameter",
}
},
"required": ["arg1"],
}
),
},
},
{
"type": "function",
"function": {
"name": "tool_with_number_arg",
"description": "Tool expecting number argument",
"parameters": json.dumps(
{
"type": "object",
"properties": {
"arg1": {
"type": "number",
"description": "A numeric parameter",
}
},
"required": ["arg1"],
}
),
},
},
]
data = [
{
"tools": tools,
"messages": [
{"role": "user", "content": "Use both tools"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"type": "function",
"function": {
"name": "tool_with_string_arg",
"arguments": json.dumps({"arg1": "hello"}),
},
},
{
"type": "function",
"function": {
"name": "tool_with_number_arg",
"arguments": json.dumps({"arg1": 42}),
},
},
],
},
],
}
]
return Dataset.from_list(data)
def test_dict_and_str_params_produce_equivalent_output(
self,
tools_dict_params_dataset,
tools_str_params_dataset,
qwen3_instruct_prompt_strategy,
qwen3_tokenizer,
):
"""
Tests that after tokenization and decoding, the outputs for both
dict and string `parameters` in tools are semantically equivalent.
"""
import re
processed_dict_params = tools_dict_params_dataset.map(
qwen3_instruct_prompt_strategy.tokenize_prompt,
batched=True,
remove_columns=["messages", "tools"],
)
processed_str_params = tools_str_params_dataset.map(
qwen3_instruct_prompt_strategy.tokenize_prompt,
batched=True,
remove_columns=["messages", "tools"],
)
decoded_dict = qwen3_tokenizer.decode(processed_dict_params[0]["input_ids"])
decoded_str = qwen3_tokenizer.decode(processed_str_params[0]["input_ids"])
# Extract the tool JSON from both outputs
tools_pattern = r"<tools>\n(.*?)\n</tools>"
dict_tools_match = re.search(tools_pattern, decoded_dict, re.DOTALL)
str_tools_match = re.search(tools_pattern, decoded_str, re.DOTALL)
assert dict_tools_match and str_tools_match, (
"Could not find tools section in output"
)
# Parse the JSON and compare as objects (order-independent)
dict_tools_json = json.loads(dict_tools_match.group(1))
str_tools_json = json.loads(str_tools_match.group(1))
# Deep comparison of the tool definitions
assert dict_tools_json == str_tools_json, (
f"Tool definitions are not equivalent:\n"
f"Dict format: {json.dumps(dict_tools_json, indent=2)}\n"
f"String format: {json.dumps(str_tools_json, indent=2)}"
)
# Verify the rest of the structure is the same (excluding the tools JSON part)
# The tools JSON can have different order, so we remove it here.
dict_normalized = re.sub(
r"<tools>.*?</tools>",
"<tools>TOOLS_PLACEHOLDER</tools>",
decoded_dict,
flags=re.DOTALL,
)
str_normalized = re.sub(
r"<tools>.*?</tools>",
"<tools>TOOLS_PLACEHOLDER</tools>",
decoded_str,
flags=re.DOTALL,
)
assert dict_normalized == str_normalized, (
"The overall structure differs between dict and string parameter formats"
)
def test_str_params_with_mixed_types_no_error(
self,
tools_mixed_type_params_dataset,
qwen3_instruct_prompt_strategy,
qwen3_tokenizer,
):
"""
Tests that when different tools have the same parameter name with different types,
JSON string format for parameters doesn't cause casting errors.
"""
processed = tools_mixed_type_params_dataset.map(
qwen3_instruct_prompt_strategy.tokenize_prompt,
batched=True,
remove_columns=["messages", "tools"],
)
assert len(processed) == 1
assert "input_ids" in processed[0]
assert len(processed[0]["input_ids"]) > 0
decoded = qwen3_tokenizer.decode(processed[0]["input_ids"])
# Check that both tools are present
assert "tool_with_string_arg" in decoded
assert "tool_with_number_arg" in decoded
# Check that both argument values are present
assert "hello" in decoded
assert "42" in decoded

View File

@@ -1,9 +0,0 @@
"""Shared pytest fixtures for telemetry tests."""
import pytest
@pytest.fixture(autouse=True)
def del_track_env(monkeypatch):
monkeypatch.delenv("AXOLOTL_DO_NOT_TRACK", raising=False)
yield

View File

@@ -1,373 +0,0 @@
"""Tests for telemetry callback module."""
# pylint: disable=redefined-outer-name
import time
from unittest.mock import MagicMock, patch
import pytest
from transformers import TrainerControl, TrainerState, TrainingArguments
from axolotl.telemetry.callbacks import TIME_SINCE_LAST, TelemetryCallback
def calc_expected_metrics(step, last_step, current_time, last_time, start_time=900.0):
"""Calculate expected metrics values for tests"""
time_diff = current_time - last_time
step_diff = step - last_step
return {
"steps_per_second": (
step_diff / time_diff if time_diff > 0 and step_diff > 0 else 0
),
"time_since_last_report": time_diff,
"elapsed_time": current_time - start_time,
}
@pytest.fixture
def mock_time():
"""Mock time.time() to have predictable values in tests"""
with patch("axolotl.telemetry.callbacks.time") as mock_time:
mock_time.time.return_value = 1000.0
yield mock_time
@pytest.fixture
def mock_telemetry_manager():
"""Create a mock TelemetryManager"""
with patch("axolotl.telemetry.callbacks.TelemetryManager") as mock_manager_class:
mock_manager = MagicMock()
mock_manager_class.get_instance.return_value = mock_manager
yield mock_manager
@pytest.fixture
def mock_runtime_metrics_tracker():
"""Create a mock RuntimeMetricsTracker"""
with patch(
"axolotl.telemetry.callbacks.RuntimeMetricsTracker"
) as mock_tracker_class:
mock_tracker = MagicMock()
# Set up metrics property on the tracker
mock_metrics = MagicMock()
mock_metrics.to_dict.return_value = {
"total_steps": 100,
"peak_cpu_memory_bytes": 1024,
}
mock_tracker.metrics = mock_metrics
# Make the constructor return our mock
mock_tracker_class.return_value = mock_tracker
yield mock_tracker
@pytest.fixture
def training_args():
"""Create a minimal TrainingArguments instance"""
return TrainingArguments(output_dir="./output")
@pytest.fixture
def trainer_state():
"""Create a mock TrainerState"""
state = MagicMock(spec=TrainerState)
state.global_step = 10
state.epoch = 0.5 # halfway through first epoch
state.log_history = [{"loss": 2.5, "learning_rate": 5e-5}]
return state
@pytest.fixture
def trainer_control():
"""Create a mock TrainerControl"""
return MagicMock(spec=TrainerControl)
# pylint: disable=unused-argument
@pytest.fixture
def callback(mock_telemetry_manager, mock_runtime_metrics_tracker):
"""Create a TelemetryCallback instance with mocked dependencies"""
return TelemetryCallback()
class TestTelemetryCallback:
"""Tests for the TelemetryCallback class."""
def test_initialization(self, callback, mock_runtime_metrics_tracker):
"""Test callback initialization."""
assert callback.current_epoch == -1
assert callback.tracker == mock_runtime_metrics_tracker
assert callback.last_report_step == 0
assert hasattr(callback, "start_time")
assert hasattr(callback, "last_report_time")
assert callback.report_interval_steps == 100
def test_on_train_begin(
self,
callback,
mock_telemetry_manager,
training_args,
trainer_state,
trainer_control,
):
"""Test on_train_begin sends expected event."""
callback.on_train_begin(training_args, trainer_state, trainer_control)
mock_telemetry_manager.send_event.assert_called_once_with(
event_type="train-start"
)
def test_on_train_end(
self,
callback,
mock_telemetry_manager,
training_args,
trainer_state,
trainer_control,
):
"""Test on_train_end sends expected event with metrics."""
callback.on_train_end(training_args, trainer_state, trainer_control)
mock_telemetry_manager.send_event.assert_called_once()
call_args = mock_telemetry_manager.send_event.call_args[1]
assert call_args["event_type"] == "train-end"
assert "loss" in call_args["properties"]
assert call_args["properties"]["loss"] == 2.5
assert "learning_rate" in call_args["properties"]
assert call_args["properties"]["learning_rate"] == 5e-5
# Check that metrics from RuntimeMetricsTracker are included
assert "total_steps" in call_args["properties"]
assert call_args["properties"]["total_steps"] == 100
assert "peak_cpu_memory_bytes" in call_args["properties"]
assert call_args["properties"]["peak_cpu_memory_bytes"] == 1024
def test_on_epoch_begin(
self,
callback,
mock_runtime_metrics_tracker,
training_args,
trainer_state,
trainer_control,
):
"""Test on_epoch_begin updates epoch counter and calls tracker."""
initial_epoch = callback.current_epoch
callback.on_epoch_begin(training_args, trainer_state, trainer_control)
assert callback.current_epoch == initial_epoch + 1
mock_runtime_metrics_tracker.start_epoch.assert_called_once_with(
initial_epoch + 1
)
def test_on_epoch_end(
self,
callback,
mock_runtime_metrics_tracker,
training_args,
trainer_state,
trainer_control,
):
"""Test on_epoch_end calls tracker."""
# Set current epoch
callback.current_epoch = 2
callback.on_epoch_end(training_args, trainer_state, trainer_control)
mock_runtime_metrics_tracker.end_epoch.assert_called_once_with(2)
def test_on_step_end_no_report(
self,
callback,
mock_telemetry_manager,
mock_runtime_metrics_tracker,
training_args,
trainer_state,
trainer_control,
):
"""Test on_step_end updates tracker but doesn't report if criteria not met."""
# Set up state to avoid reporting
trainer_state.global_step = 42 # Not divisible by report_interval_steps
callback.last_report_step = 41 # Just 1 step since last report
callback.last_report_time = time.time() # Just now
callback.on_step_end(training_args, trainer_state, trainer_control)
# Should update tracker
mock_runtime_metrics_tracker.update_step.assert_called_once_with(42)
# Should not send telemetry
mock_telemetry_manager.send_event.assert_not_called()
# Should not update last report time/step
assert callback.last_report_step == 41
def test_on_step_end_report_interval_steps(
self,
callback,
mock_telemetry_manager,
mock_runtime_metrics_tracker,
mock_time,
training_args,
trainer_state,
trainer_control,
):
"""Test on_step_end reports when step interval is reached."""
# Set up state with clear values
current_step = 100 # Exactly matches report_interval_steps
last_step = 0
start_time = 900.0
current_time = 1000.0
time_diff = current_time - start_time # 100 seconds
# Configure state and callback
trainer_state.global_step = current_step
callback.report_interval_steps = 100
callback.last_report_step = last_step
callback.start_time = start_time
callback.last_report_time = start_time
# Mock time.time() to return consistent values
mock_time.time.return_value = current_time
callback.on_step_end(training_args, trainer_state, trainer_control)
# Should update tracker
mock_runtime_metrics_tracker.update_step.assert_called_once_with(current_step)
mock_runtime_metrics_tracker.update_memory_metrics.assert_called_once()
# Should send telemetry
mock_telemetry_manager.send_event.assert_called_once()
call_args = mock_telemetry_manager.send_event.call_args[1]
assert call_args["event_type"] == "train-progress"
# Properties should include expected values
props = call_args["properties"]
assert props["step"] == current_step
assert props["elapsed_time"] == time_diff # 1000 - 900 = 100
assert props["time_since_last_report"] == time_diff # 1000 - 900 = 100
assert props["steps_per_second"] == 1.0 # 100 steps / 100 seconds
# Should update last report time/step
assert callback.last_report_step == current_step
assert callback.last_report_time == current_time
def test_on_step_end_report_time_elapsed(
self,
callback,
mock_telemetry_manager,
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
mock_time,
training_args,
trainer_state,
trainer_control,
):
"""Test on_step_end reports when enough time has elapsed."""
# Set up state with clear values
current_step = 120
last_step = 10
start_time = 900.0
current_time = 1000.0
time_diff = TIME_SINCE_LAST + 1 # Just over the threshold
# Configure state and callback
trainer_state.global_step = current_step
callback.report_interval_steps = 100
callback.last_report_step = last_step
callback.start_time = start_time
callback.last_report_time = current_time - time_diff
# Mock time.time() to return consistent values
mock_time.time.return_value = current_time
callback.on_step_end(training_args, trainer_state, trainer_control)
# Should send telemetry
mock_telemetry_manager.send_event.assert_called_once()
# Properties should include expected values
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
expected_metrics = calc_expected_metrics(
current_step, last_step, current_time, current_time - time_diff, start_time
)
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
assert (
props["time_since_last_report"]
== expected_metrics["time_since_last_report"]
)
def test_on_step_end_first_step(
self,
callback,
mock_telemetry_manager,
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
mock_time,
training_args,
trainer_state,
trainer_control,
):
"""Test on_step_end always reports on first step."""
# Set up state with clear values
current_step = 1 # First step
last_step = 0
start_time = 900.0
current_time = 1000.0
last_report_time = 999.0 # Just 1 second ago
# Configure state and callback
trainer_state.global_step = current_step
callback.report_interval_steps = 100
callback.last_report_step = last_step
callback.start_time = start_time
callback.last_report_time = last_report_time
# Mock time.time() to return consistent values
mock_time.time.return_value = current_time
callback.on_step_end(training_args, trainer_state, trainer_control)
# Should send telemetry even though not much time has passed
mock_telemetry_manager.send_event.assert_called_once()
# Properties should include expected values for first step
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
assert props["step"] == current_step
expected_metrics = calc_expected_metrics(
current_step, last_step, current_time, last_report_time, start_time
)
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
def test_log_history_empty(
self,
callback,
mock_telemetry_manager,
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
mock_time,
training_args,
trainer_state,
trainer_control,
):
"""Test handling of empty log history."""
# Set up state with clear values
current_step = 1
start_time = 900.0
current_time = 1000.0
# Configure state and callback
trainer_state.global_step = current_step
trainer_state.log_history = []
callback.start_time = start_time
# Mock time.time() to return consistent values
mock_time.time.return_value = current_time
callback.on_step_end(training_args, trainer_state, trainer_control)
# Should still send telemetry
mock_telemetry_manager.send_event.assert_called_once()
# Properties should have default values for missing log data
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
assert props["loss"] == 0
assert props["learning_rate"] == 0

View File

@@ -1,341 +0,0 @@
"""Tests for telemetry error utilities"""
# pylint: disable=redefined-outer-name
from unittest.mock import MagicMock, patch
import pytest
from axolotl.telemetry.errors import sanitize_stack_trace, send_errors
@pytest.fixture(autouse=True)
def reset_error_flag(monkeypatch):
"""Reset ERROR_HANDLED flag using monkeypatch"""
import axolotl.telemetry.errors
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
yield
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
@pytest.fixture
def example_stack_trace():
"""Provide a sample stack trace with mixed paths"""
return """Traceback (most recent call last):
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
trainer = get_trainer(cfg)
File "/home/user/.local/lib/python3.9/site-packages/axolotl/train.py", line 214, in get_trainer
model = get_model(cfg, tokenizer)
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/models.py", line 120, in get_model
raise ValueError("Model path not found")
ValueError: Model path not found
"""
@pytest.fixture
def windows_stack_trace():
"""Provide a sample stack trace with Windows paths"""
return """Traceback (most recent call last):
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\cli\\train.py", line 83, in main
trainer = get_trainer(cfg)
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\train.py", line 214, in get_trainer
model = get_model(cfg, tokenizer)
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\auto\\modeling_auto.py", line 482, in from_pretrained
raise ValueError(f"Unrecognized configuration class {config.__class__}")
ValueError: Unrecognized configuration class <class 'transformers.models.llama.configuration_llama.LlamaConfig'>
"""
@pytest.fixture
def mixed_stack_trace():
"""Provide a sample stack trace with both axolotl and non-axolotl paths"""
return """Traceback (most recent call last):
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
trainer = get_trainer(cfg)
File "/home/user/.local/lib/python3.9/site-packages/transformers/trainer.py", line 520, in train
self._inner_training_loop()
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/trainer.py", line 75, in _inner_training_loop
super()._inner_training_loop()
File "/home/user/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
data = self._next_data()
RuntimeError: CUDA out of memory
"""
@pytest.fixture
def venv_stack_trace():
"""Provide a sample stack trace with virtual environment paths"""
return """Traceback (most recent call last):
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 1729, in train
self._inner_training_loop()
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 2013, in _inner_training_loop
self.accelerator.backward(loss)
File "/home/user/venv/lib/python3.9/site-packages/accelerate/accelerator.py", line 1851, in backward
self.scaler.scale(loss).backward(**kwargs)
File "/home/user/venv/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
RuntimeError: CUDA out of memory
"""
@pytest.fixture
def dist_packages_stack_trace():
"""Provide a sample stack trace with dist-packages paths"""
return """Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__
data = self._next_data()
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 675, in _next_data
data = self._dataset_fetcher.fetch(index)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2808, in __getitem__
raise IndexError(f"Index {key} out of range for dataset of length {len(self)}.")
IndexError: Index 10000 out of range for dataset of length 9832.
"""
@pytest.fixture
def project_stack_trace():
"""Provide a sample stack trace from a project directory (not a virtual env)"""
return """Traceback (most recent call last):
File "/home/user/projects/myproject/run.py", line 25, in <module>
main()
File "/home/user/projects/myproject/src/cli.py", line 45, in main
app.run()
File "/home/user/projects/myproject/src/app.py", line 102, in run
raise ValueError("Configuration missing")
ValueError: Configuration missing
"""
def test_sanitize_stack_trace(example_stack_trace):
"""Test that sanitize_stack_trace properly preserves axolotl paths"""
sanitized = sanitize_stack_trace(example_stack_trace)
# Check that personal paths are removed
assert "/home/user" not in sanitized
assert ".local/lib/python3.9" not in sanitized
# Check that site-packages is preserved
assert "site-packages/axolotl/cli/train.py" in sanitized
assert "site-packages/axolotl/train.py" in sanitized
assert "site-packages/axolotl/utils/models.py" in sanitized
# Check that error message is preserved
assert "ValueError: Model path not found" in sanitized
def test_sanitize_windows_paths(windows_stack_trace):
"""Test that sanitize_stack_trace handles Windows paths"""
sanitized = sanitize_stack_trace(windows_stack_trace)
# Check that personal paths are removed
assert "C:\\Users\\name" not in sanitized
assert "AppData\\Local\\Programs\\Python" not in sanitized
# Check that both axolotl and transformers packages are preserved
assert (
"site-packages\\axolotl\\cli\\train.py" in sanitized
or "site-packages/axolotl/cli/train.py" in sanitized
)
assert (
"site-packages\\axolotl\\train.py" in sanitized
or "site-packages/axolotl/train.py" in sanitized
)
assert (
"site-packages\\transformers\\models\\auto\\modeling_auto.py" in sanitized
or "site-packages/transformers/models/auto/modeling_auto.py" in sanitized
)
# Check that error message is preserved
assert "ValueError: Unrecognized configuration class" in sanitized
def test_sanitize_mixed_paths(mixed_stack_trace):
"""Test that sanitize_stack_trace preserves all package paths"""
sanitized = sanitize_stack_trace(mixed_stack_trace)
# Check that all package paths are preserved
assert "site-packages/axolotl/cli/train.py" in sanitized
assert "site-packages/transformers/trainer.py" in sanitized
assert "site-packages/axolotl/utils/trainer.py" in sanitized
assert "site-packages/torch/utils/data/dataloader.py" in sanitized
# Check that error message is preserved
assert "RuntimeError: CUDA out of memory" in sanitized
def test_sanitize_venv_paths(venv_stack_trace):
"""Test that sanitize_stack_trace preserves virtual environment package paths"""
sanitized = sanitize_stack_trace(venv_stack_trace)
# Check that personal paths are removed
assert "/home/user/venv" not in sanitized
# Check that all package paths are preserved
assert "site-packages/transformers/trainer.py" in sanitized
assert "site-packages/accelerate/accelerator.py" in sanitized
assert "site-packages/torch/_tensor.py" in sanitized
# Check that error message is preserved
assert "RuntimeError: CUDA out of memory" in sanitized
def test_sanitize_dist_packages(dist_packages_stack_trace):
"""Test that sanitize_stack_trace preserves dist-packages paths"""
sanitized = sanitize_stack_trace(dist_packages_stack_trace)
# Check that system paths are removed
assert "/usr/local/lib/python3.8" not in sanitized
# Check that all package paths are preserved
assert "dist-packages/torch/utils/data/dataloader.py" in sanitized
assert "dist-packages/torch/utils/data/_utils/fetch.py" in sanitized
assert "dist-packages/datasets/arrow_dataset.py" in sanitized
# Check that error message is preserved
assert (
"IndexError: Index 10000 out of range for dataset of length 9832." in sanitized
)
def test_sanitize_project_paths(project_stack_trace):
"""Test handling of project paths (non-virtual env)"""
sanitized = sanitize_stack_trace(project_stack_trace)
# Check that personal paths are removed
assert "/home/user/projects" not in sanitized
# For non-package paths, we should at least preserve the filename
assert "run.py" in sanitized
assert "cli.py" in sanitized
assert "app.py" in sanitized
# Check that error message is preserved
assert "ValueError: Configuration missing" in sanitized
@pytest.fixture
def mock_telemetry_manager():
"""Create a mock TelemetryManager"""
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
mock_manager = MagicMock()
mock_manager.enabled = True
mock_manager_class.get_instance.return_value = mock_manager
yield mock_manager
def test_send_errors_successful_execution(mock_telemetry_manager):
"""Test that send_errors doesn't send telemetry for successful function execution"""
@send_errors
def test_func():
return "success"
result = test_func()
assert result == "success"
mock_telemetry_manager.send_event.assert_not_called()
def test_send_errors_with_exception(mock_telemetry_manager):
"""Test that send_errors sends telemetry when an exception occurs"""
test_error = ValueError("Test error")
@send_errors
def test_func():
raise test_error
with pytest.raises(ValueError) as excinfo:
test_func()
assert excinfo.value == test_error
mock_telemetry_manager.send_event.assert_called_once()
# Check that the error info was passed correctly
call_args = mock_telemetry_manager.send_event.call_args[1]
assert "test_func-error" in call_args["event_type"]
assert "Test error" in call_args["properties"]["exception"]
assert "stack_trace" in call_args["properties"]
def test_send_errors_nested_calls(mock_telemetry_manager):
"""Test that send_errors only sends telemetry once for nested decorated functions"""
@send_errors
def inner_func():
raise ValueError("Inner error")
@send_errors
def outer_func():
return inner_func()
with pytest.raises(ValueError):
outer_func()
# Telemetry should be sent only once for the inner function
assert mock_telemetry_manager.send_event.call_count == 1
call_args = mock_telemetry_manager.send_event.call_args[1]
assert "inner_func-error" in call_args["event_type"]
def test_send_errors_telemetry_disable():
"""Test that send_errors doesn't attempt to send telemetry when disabled"""
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
mock_manager = MagicMock()
mock_manager.enabled = False
mock_manager_class.get_instance.return_value = mock_manager
@send_errors
def test_func():
raise ValueError("Test error")
with pytest.raises(ValueError):
test_func()
mock_manager.send_event.assert_not_called()
def test_error_handled_reset():
"""Test that ERROR_HANDLED flag is properly reset"""
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
# Create and configure the mock manager
mock_manager = MagicMock()
mock_manager.enabled = True
mock_manager_class.get_instance.return_value = mock_manager
from axolotl.telemetry.errors import ERROR_HANDLED
@send_errors
def test_func():
raise ValueError("Test error")
assert not ERROR_HANDLED
with pytest.raises(ValueError):
test_func()
from axolotl.telemetry.errors import ERROR_HANDLED
assert ERROR_HANDLED
def test_module_path_resolution(mock_telemetry_manager):
"""Test that the module path is correctly resolved for the event type"""
import inspect
current_module = inspect.getmodule(test_module_path_resolution).__name__
@send_errors
def test_func():
raise ValueError("Test error")
with pytest.raises(ValueError):
test_func()
assert mock_telemetry_manager.send_event.called
event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"]
expected_event_type = f"{current_module}.test_func-error"
assert expected_event_type == event_type
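The path fixtures earlier in this file spell out what the sanitizer must preserve: strip everything user-specific before `site-packages`/`dist-packages`, keep the package-relative remainder, fall back to bare filenames for project code, and leave the exception message untouched. A simplified, standalone sketch of that scrubbing rule (not the library implementation):

import re

def scrub_path(path: str) -> str:
    # Keep the portion from site-packages/ or dist-packages/ onward; otherwise
    # keep only the filename so user directories never leak.
    match = re.search(r"((?:site|dist)-packages[\\/].*)$", path)
    if match:
        return match.group(1)
    return path.replace("\\", "/").rsplit("/", 1)[-1]

print(scrub_path("/home/user/.local/lib/python3.9/site-packages/axolotl/train.py"))
# site-packages/axolotl/train.py
print(scrub_path("/home/user/projects/myproject/run.py"))
# run.py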

View File

@@ -1,275 +0,0 @@
"""Tests for TelemetryManager class and utilities"""
# pylint: disable=redefined-outer-name,protected-access
import os
from unittest.mock import patch
import pytest
import yaml
from axolotl.telemetry.manager import TelemetryManager
@pytest.fixture
def mock_whitelist(tmp_path):
"""Create a temporary whitelist file for testing"""
whitelist_content = {
"organizations": ["meta-llama", "mistralai"],
}
whitelist_file = tmp_path / "whitelist.yaml"
with open(whitelist_file, "w", encoding="utf-8") as f:
yaml.dump(whitelist_content, f)
return str(whitelist_file)
@pytest.fixture
def telemetry_manager_class():
"""Reset the TelemetryManager singleton between tests"""
original_instance = TelemetryManager._instance
original_initialized = TelemetryManager._initialized
TelemetryManager._instance = None
TelemetryManager._initialized = False
yield TelemetryManager
TelemetryManager._instance = original_instance
TelemetryManager._initialized = original_initialized
@pytest.fixture
def manager(telemetry_manager_class, mock_whitelist):
"""Create a TelemetryManager instance with mocked dependencies"""
with (
patch("posthog.capture"),
patch("posthog.flush"),
patch("time.sleep"),
patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist),
patch.dict(os.environ, {"RANK": "0"}),
):
manager = telemetry_manager_class()
# Manually enable for most tests
manager.enabled = True
return manager
def test_singleton_instance(telemetry_manager_class):
"""Test that TelemetryManager is a singleton"""
with (
patch("posthog.capture"),
patch("time.sleep"),
patch.dict(os.environ, {"RANK": "0"}),
):
first = telemetry_manager_class()
second = telemetry_manager_class()
assert first is second
assert telemetry_manager_class.get_instance() is first
def test_telemetry_enabled_by_default(telemetry_manager_class):
"""Test that telemetry is enabled by default (opt-out)"""
with (
patch.dict(os.environ, {"RANK": "0"}, clear=True),
patch("time.sleep"),
patch("logging.Logger.info"),
):
manager = telemetry_manager_class()
assert manager.enabled
def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class):
"""Test that telemetry is enabled when AXOLOTL_DO_NOT_TRACK=0"""
with (
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}),
patch("time.sleep"),
):
manager = telemetry_manager_class()
assert manager.enabled
def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
"""Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
with (
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}),
patch("time.sleep"),
):
manager = telemetry_manager_class()
assert not manager.enabled
def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
"""Test that telemetry is disabled when DO_NOT_TRACK=1"""
with (
patch.dict(
os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "DO_NOT_TRACK": "1", "RANK": "0"}
),
patch("time.sleep"),
):
manager = telemetry_manager_class()
assert not manager.enabled
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
"""Test that telemetry is disabled for non-main processes"""
with (
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "1"}),
patch("time.sleep"),
):
manager = telemetry_manager_class()
assert not manager.enabled
def test_opt_in_info_displayed(telemetry_manager_class):
"""Test that opt-in info is displayed when telemetry is not configured"""
with (
patch.dict(os.environ, {"RANK": "0"}, clear=True),
patch("logging.Logger.warning") as mock_warning,
patch("time.sleep"),
):
telemetry_manager_class()
assert any(
"Telemetry is now enabled by default" in str(call)
for call in mock_warning.call_args_list
)
def test_is_whitelisted(telemetry_manager_class, mock_whitelist):
"""Test org whitelist functionality"""
with (
patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist),
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
):
manager = telemetry_manager_class()
# Should match organizations from the mock whitelist
assert manager._is_whitelisted("meta-llama/llama-7b")
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
# Should not match
assert not manager._is_whitelisted("unknown/model")
# Should handle case insensitively
assert manager._is_whitelisted("META-LLAMA/Llama-7B")
# Should handle empty input
assert not manager._is_whitelisted("")
def test_system_info_collection(manager):
"""Test system information collection"""
system_info = manager._get_system_info()
# Check essential keys
assert "os" in system_info
assert "python_version" in system_info
assert "cpu_count" in system_info
assert "memory_total" in system_info
assert "accelerator_count" in system_info
def test_send_event(telemetry_manager_class):
"""Test basic event sending"""
with (
patch("posthog.capture") as mock_capture,
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
):
manager = telemetry_manager_class()
# Test with clean properties (no PII)
manager.send_event("test_event", {"key": "value"})
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "test_event"
assert mock_capture.call_args[1]["properties"] == {"key": "value"}
assert mock_capture.call_args[1]["distinct_id"] == manager.run_id
# Test with default properties (None)
mock_capture.reset_mock()
manager.send_event("simple_event")
assert mock_capture.called
assert mock_capture.call_args[1]["properties"] == {}
def test_send_system_info(telemetry_manager_class):
"""Test sending system info"""
with (
patch("posthog.capture") as mock_capture,
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
):
manager = telemetry_manager_class()
manager.send_system_info()
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "system-info"
assert mock_capture.call_args[1]["properties"] == manager.system_info
def test_redacted_properties(telemetry_manager_class):
"""Test path redaction in send_event method"""
with (
patch("posthog.capture") as mock_capture,
patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}),
):
manager = telemetry_manager_class()
# Test with properties containing various paths and non-paths
test_properties = {
"filepath": "/home/user/sensitive/data.txt",
"windows_path": "C:\\Users\\name\\Documents\\project\\file.py",
"output_dir": "/var/lib/data",
"path_to_model": "models/llama/7b",
"message": "Training started", # Should not be redacted
"metrics": {"loss": 0.5, "accuracy": 0.95}, # Should not be redacted
"base_model": "models/local_model",
"nested": {
"model_path": "/models/my_model",
"root_dir": "/home/user/projects",
"stats": {"steps": 1000, "epochs": 3}, # Should not be redacted
},
}
manager.send_event("test_event", test_properties)
# Verify the call was made
assert mock_capture.called
# Get the sanitized properties that were sent
sanitized = mock_capture.call_args[1]["properties"]
# Check that path-like and base_model keys were redacted
assert sanitized["filepath"] == "[REDACTED]"
assert sanitized["windows_path"] == "[REDACTED]"
assert sanitized["path_to_model"] == "[REDACTED]"
assert sanitized["base_model"] == "[REDACTED]"
# Check that non-path values were preserved
assert sanitized["message"] == "Training started"
assert sanitized["metrics"] == {"loss": 0.5, "accuracy": 0.95}
# Check nested structure handling
assert sanitized["nested"]["model_path"] == "[REDACTED]"
assert sanitized["nested"]["root_dir"] == "[REDACTED]"
assert sanitized["nested"]["stats"] == {"steps": 1000, "epochs": 3}
def test_disable_telemetry(manager):
"""Test that disabled telemetry doesn't send events"""
with patch("posthog.capture") as mock_capture:
manager.enabled = False
manager.send_event("test_event")
assert not mock_capture.called
def test_exception_handling_during_send(manager):
"""Test that exceptions in PostHog are handled gracefully"""
with (
patch("posthog.capture", side_effect=Exception("Test error")),
patch("logging.Logger.warning") as mock_warning,
):
manager.send_event("test_event")
warning_logged = False
for call in mock_warning.call_args_list:
if "Failed to send telemetry event" in str(call):
warning_logged = True
break
assert warning_logged
def test_shutdown(manager):
"""Test shutdown behavior"""
with patch("posthog.shutdown") as mock_shutdown:
manager.shutdown()
assert mock_shutdown.called

View File

@@ -1,357 +0,0 @@
"""Tests for runtime metrics telemetry module"""
# pylint: disable=redefined-outer-name
from unittest.mock import MagicMock, patch
import pytest
from axolotl.telemetry.runtime_metrics import RuntimeMetrics, RuntimeMetricsTracker
@pytest.fixture
def mock_time():
"""Mock time.time() to have predictable values in tests"""
with patch("time.time") as mock_time:
# Start with time 1000.0 and increment by 10 seconds on each call
times = [1000.0 + i * 10 for i in range(10)]
mock_time.side_effect = times
yield mock_time
@pytest.fixture
def mock_telemetry_manager():
"""Create a mock TelemetryManager"""
with patch(
"axolotl.telemetry.runtime_metrics.TelemetryManager"
) as mock_manager_class:
mock_manager = MagicMock()
mock_manager.enabled = True
mock_manager_class.get_instance.return_value = mock_manager
yield mock_manager
@pytest.fixture
def mock_psutil():
"""Mock psutil for memory information"""
with patch("axolotl.telemetry.runtime_metrics.psutil") as mock_psutil:
mock_process = MagicMock()
mock_memory_info = MagicMock()
# Set initial memory to 1GB
mock_memory_info.rss = 1024 * 1024 * 1024
mock_process.memory_info.return_value = mock_memory_info
mock_psutil.Process.return_value = mock_process
yield mock_psutil
@pytest.fixture
def mock_torch():
"""Mock torch.cuda functions"""
with patch("axolotl.telemetry.runtime_metrics.torch") as mock_torch:
mock_torch.cuda.is_available.return_value = True
mock_torch.cuda.device_count.return_value = 2
# Mock memory allocated per device (1GB for device 0, 2GB for device 1)
mock_torch.cuda.memory_allocated.side_effect = (
lambda device: (device + 1) * 1024 * 1024 * 1024
)
yield mock_torch
class TestRuntimeMetrics:
"""Tests for RuntimeMetrics class."""
def test_initialization(self):
"""Test RuntimeMetrics initialization."""
metrics = RuntimeMetrics(start_time=1000.0)
assert metrics.start_time == 1000.0
assert metrics.epoch_start_times == {}
assert metrics.epoch_end_times == {}
assert metrics.peak_gpu_memory == {}
assert metrics.total_steps == 0
assert metrics.current_epoch == 0
assert metrics.current_step == 0
assert metrics.peak_cpu_memory == 0
def test_elapsed_time(self, mock_time):
"""Test elapsed_time property."""
metrics = RuntimeMetrics(start_time=1000.0)
# Mock time.time() to return 1050.0
mock_time.side_effect = [1050.0]
assert metrics.elapsed_time == 50.0
def test_epoch_time(self):
"""Test epoch_time method."""
metrics = RuntimeMetrics(start_time=1000.0)
# No epoch data
assert metrics.epoch_time(0) is None
# Add epoch start but no end
metrics.epoch_start_times[0] = 1000.0
assert metrics.epoch_time(0) is None
# Add epoch end
metrics.epoch_end_times[0] = 1060.0
assert metrics.epoch_time(0) == 60.0
def test_average_epoch_time(self):
"""Test average_epoch_time method."""
metrics = RuntimeMetrics(start_time=1000.0)
# No completed epochs
assert metrics.average_epoch_time() is None
# Add one completed epoch
metrics.epoch_start_times[0] = 1000.0
metrics.epoch_end_times[0] = 1060.0
assert metrics.average_epoch_time() == 60.0
# Add second completed epoch
metrics.epoch_start_times[1] = 1060.0
metrics.epoch_end_times[1] = 1140.0 # 80 seconds
assert metrics.average_epoch_time() == 70.0 # Average of 60 and 80
# Add incomplete epoch (should not affect average)
metrics.epoch_start_times[2] = 1140.0
assert metrics.average_epoch_time() == 70.0
def test_steps_per_second(self, mock_time):
"""Test steps_per_second method."""
metrics = RuntimeMetrics(start_time=1000.0)
# No steps - first call to time.time()
mock_time.side_effect = None
mock_time.return_value = 1050.0
assert metrics.steps_per_second() is None
# Add steps - second call to time.time()
metrics.total_steps = 100
mock_time.return_value = 1050.0 # Keep same time for consistent result
assert metrics.steps_per_second() == 2.0 # 100 steps / 50 seconds
def test_to_dict_basic(self, mock_time):
"""Test to_dict method with basic metrics."""
metrics = RuntimeMetrics(start_time=1000.0)
metrics.total_steps = 100
metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024 # 2GB
# Mock elapsed_time
mock_time.side_effect = None
mock_time.return_value = 1050.0
result = metrics.to_dict()
assert result["total_time_seconds"] == 50.0
assert result["total_steps"] == 100
assert result["steps_per_second"] == 2.0
assert result["epochs_completed"] == 0
assert result["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
assert "epoch_times" not in result
assert "gpu_memory" not in result
def test_to_dict_with_epochs(self, mock_time):
"""Test to_dict method with epoch data."""
metrics = RuntimeMetrics(start_time=1000.0)
metrics.total_steps = 100
# Add epoch data
metrics.epoch_start_times[0] = 1000.0
metrics.epoch_end_times[0] = 1060.0
metrics.epoch_start_times[1] = 1060.0
metrics.epoch_end_times[1] = 1140.0
# Mock elapsed_time
mock_time.side_effect = None
mock_time.return_value = 1150.0
result = metrics.to_dict()
assert "epoch_times" in result
assert result["epoch_times"]["epoch_0_seconds"] == 60.0
assert result["epoch_times"]["epoch_1_seconds"] == 80.0
assert result["average_epoch_time_seconds"] == 70.0
def test_to_dict_with_gpu_memory(self, mock_time):
"""Test to_dict method with GPU memory data."""
metrics = RuntimeMetrics(start_time=1000.0)
metrics.peak_gpu_memory = {
0: 1 * 1024 * 1024 * 1024, # 1GB
1: 2 * 1024 * 1024 * 1024, # 2GB
}
# Mock elapsed_time
mock_time.side_effect = [1050.0]
result = metrics.to_dict()
assert "gpu_memory" in result
assert result["gpu_memory"]["gpu_0_peak_memory_bytes"] == 1 * 1024 * 1024 * 1024
assert result["gpu_memory"]["gpu_1_peak_memory_bytes"] == 2 * 1024 * 1024 * 1024
class TestRuntimeMetricsTracker:
"""Tests for RuntimeMetricsTracker class."""
# pylint: disable=unused-argument
def test_initialization(self, mock_time, mock_telemetry_manager):
"""Test RuntimeMetricsTracker initialization."""
tracker = RuntimeMetricsTracker()
assert isinstance(tracker.metrics, RuntimeMetrics)
assert tracker.metrics.start_time == 1000.0 # First value from mock_time
# pylint: disable=unused-argument
def test_start_epoch(
self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
):
"""Test start_epoch method."""
tracker = RuntimeMetricsTracker()
# Reset mock_time to control next value
mock_time.side_effect = [1010.0]
tracker.start_epoch(0)
assert tracker.metrics.current_epoch == 0
assert tracker.metrics.epoch_start_times[0] == 1010.0
# Verify memory metrics were updated
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
assert 0 in tracker.metrics.peak_gpu_memory
assert 1 in tracker.metrics.peak_gpu_memory
# pylint: disable=unused-argument
def test_end_epoch(self, mock_time, mock_telemetry_manager):
"""Test end_epoch method."""
tracker = RuntimeMetricsTracker()
# Start epoch 0
mock_time.side_effect = [1010.0]
tracker.start_epoch(0)
# End epoch 0
mock_time.side_effect = [1060.0]
tracker.end_epoch(0)
assert 0 in tracker.metrics.epoch_end_times
assert tracker.metrics.epoch_end_times[0] == 1060.0
# pylint: disable=unused-argument
def test_update_step(
self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
):
"""Test update_step method."""
tracker = RuntimeMetricsTracker()
# Update step to a non-multiple of 100
tracker.update_step(42)
assert tracker.metrics.current_step == 42
assert tracker.metrics.total_steps == 1
# Memory metrics should not be updated for non-multiple of 100
assert tracker.metrics.peak_cpu_memory == 0
# Update step to a multiple of 100
tracker.update_step(100)
assert tracker.metrics.current_step == 100
assert tracker.metrics.total_steps == 2
# Memory metrics should be updated for multiple of 100
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
# pylint: disable=unused-argument
def test_update_memory_metrics(
self, mock_psutil, mock_torch, mock_telemetry_manager
):
"""Test update_memory_metrics method."""
tracker = RuntimeMetricsTracker()
# Initial memory state
assert tracker.metrics.peak_cpu_memory == 0
assert tracker.metrics.peak_gpu_memory == {}
# Update memory metrics
tracker.update_memory_metrics()
# Verify CPU memory
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
# Verify GPU memory
assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024
# Change mocked memory values to be lower
mock_process = mock_psutil.Process.return_value
mock_memory_info = mock_process.memory_info.return_value
mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024 # 0.5GB
mock_torch.cuda.memory_allocated.side_effect = (
lambda device: (device + 0.5) * 1024 * 1024 * 1024
)
# Update memory metrics again
tracker.update_memory_metrics()
# Peak values should not decrease
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024
# Change mocked memory values to be higher
mock_memory_info.rss = 2 * 1024 * 1024 * 1024 # 2GB
mock_torch.cuda.memory_allocated.side_effect = (
lambda device: (device + 2) * 1024 * 1024 * 1024
)
# Update memory metrics again
tracker.update_memory_metrics()
# Peak values should increase
assert tracker.metrics.peak_cpu_memory == 2 * 1024 * 1024 * 1024
assert tracker.metrics.peak_gpu_memory[0] == 2 * 1024 * 1024 * 1024
assert tracker.metrics.peak_gpu_memory[1] == 3 * 1024 * 1024 * 1024
# pylint: disable=unused-argument
def test_get_memory_metrics(self, mock_psutil, mock_torch, mock_telemetry_manager):
"""Test get_memory_metrics method."""
tracker = RuntimeMetricsTracker()
# Set peak memory values
tracker.metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024
tracker.metrics.peak_gpu_memory = {
0: 3 * 1024 * 1024 * 1024,
1: 4 * 1024 * 1024 * 1024,
}
# Get memory metrics
memory_metrics = tracker.get_memory_metrics()
# Verify CPU memory
assert (
memory_metrics["cpu_memory_bytes"] == 1 * 1024 * 1024 * 1024
) # Current value from mock
assert (
memory_metrics["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
) # Peak value we set
# Verify GPU memory
assert (
memory_metrics["gpu_0_memory_bytes"] == 1 * 1024 * 1024 * 1024
) # Current value from mock
assert (
memory_metrics["gpu_0_peak_memory_bytes"] == 3 * 1024 * 1024 * 1024
) # Peak value we set
assert (
memory_metrics["gpu_1_memory_bytes"] == 2 * 1024 * 1024 * 1024
) # Current value from mock
assert (
memory_metrics["gpu_1_peak_memory_bytes"] == 4 * 1024 * 1024 * 1024
) # Peak value we set
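The peak-memory tests above reduce to one invariant: each update takes the max of the current reading and the stored peak, so peaks never decrease. A small sketch of that update rule, with plain arguments standing in for the tracker's real psutil/torch.cuda readings:

# "Peaks never decrease" rule exercised by the tests above; the byte values
# passed in are stand-ins for psutil/torch.cuda measurements.
def update_peaks(peak_cpu: int, peak_gpu: dict, cpu_now: int, gpu_now: dict):
    peak_cpu = max(peak_cpu, cpu_now)
    for device, used in gpu_now.items():
        peak_gpu[device] = max(peak_gpu.get(device, 0), used)
    return peak_cpu, peak_gpu


GIB = 1024 * 1024 * 1024
peak_cpu, peak_gpu = update_peaks(0, {}, 1 * GIB, {0: 1 * GIB, 1: 2 * GIB})
# A lower reading later leaves the stored peaks untouched
peak_cpu, peak_gpu = update_peaks(peak_cpu, peak_gpu, GIB // 2, {0: GIB // 2})
assert peak_cpu == 1 * GIB and peak_gpu == {0: 1 * GIB, 1: 2 * GIB}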

View File

@@ -1,389 +0,0 @@
"""Unit tests for dynamic checkpoint callback"""
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch
from axolotl.utils.callbacks.dynamic_checkpoint import (
DEFAULT_TRIGGER_FILENAME,
DynamicCheckpointCallback,
)
from axolotl.utils.dict import DictDefault
class TestDynamicCheckpointCallbackInit:
"""Test callback initialization"""
def test_callback_disabled_by_default(self):
"""Test that callback is disabled when config.enabled=False"""
with tempfile.TemporaryDirectory() as tmpdir:
cfg = DictDefault(
{
"dynamic_checkpoint": {"enabled": False},
"output_dir": tmpdir,
}
)
callback = DynamicCheckpointCallback(cfg)
assert callback.enabled is False
def test_callback_disabled_when_none(self):
"""Test that callback is disabled when dynamic_checkpoint is None"""
with tempfile.TemporaryDirectory() as tmpdir:
cfg = DictDefault(
{
"dynamic_checkpoint": None,
"output_dir": tmpdir,
}
)
callback = DynamicCheckpointCallback(cfg)
assert callback.enabled is False
def test_callback_enabled_when_configured(self):
"""Test that callback is enabled when config.enabled=True"""
with tempfile.TemporaryDirectory() as tmpdir:
cfg = DictDefault(
{
"dynamic_checkpoint": {"enabled": True, "check_interval": 10},
"output_dir": tmpdir,
}
)
callback = DynamicCheckpointCallback(cfg)
assert callback.enabled is True
assert callback.check_interval == 10
def test_default_trigger_filename(self):
"""Test that default trigger filename is used"""
with tempfile.TemporaryDirectory() as tmpdir:
cfg = DictDefault(
{
"dynamic_checkpoint": {"enabled": True, "check_interval": 10},
"output_dir": tmpdir,
}
)
callback = DynamicCheckpointCallback(cfg)
assert callback.trigger_filename == DEFAULT_TRIGGER_FILENAME
def test_check_interval_default(self):
"""Test default check interval"""
with tempfile.TemporaryDirectory() as tmpdir:
cfg = DictDefault(
{
"dynamic_checkpoint": {"enabled": True},
"output_dir": tmpdir,
}
)
callback = DynamicCheckpointCallback(cfg)
assert callback.check_interval == 100 # Default from schema
class TestDynamicCheckpointFileDetection:
"""Test file-based checkpoint triggering"""
def test_trigger_file_detected_and_deleted(self):
"""Test that trigger file is detected and deleted"""
with tempfile.TemporaryDirectory() as tmpdir:
cfg = DictDefault(
{
"dynamic_checkpoint": {"enabled": True, "check_interval": 1},
"output_dir": tmpdir,
}
)
callback = DynamicCheckpointCallback(cfg)
trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME
trigger_file.touch()
assert trigger_file.exists()
args = Mock(output_dir=tmpdir)
state = Mock(global_step=1)
control = Mock(should_save=False)
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process",
return_value=True,
):
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.is_distributed",
return_value=False,
):
result = callback.on_step_end(args, state, control)
assert not trigger_file.exists()
assert result.should_save is True
def test_check_interval_honored(self):
"""Test that file is only checked at check_interval steps"""
with tempfile.TemporaryDirectory() as tmpdir:
cfg = DictDefault(
{
"dynamic_checkpoint": {"enabled": True, "check_interval": 10},
"output_dir": tmpdir,
}
)
callback = DynamicCheckpointCallback(cfg)
args = Mock(output_dir=tmpdir)
control = Mock(should_save=False)
trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME
trigger_file.touch()
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process",
return_value=True,
):
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.is_distributed",
return_value=False,
):
# Step 5 - shouldn't check (not divisible by 10)
state = Mock(global_step=5)
result = callback.on_step_end(args, state, control)
assert trigger_file.exists() # Still there
assert result.should_save is False
# Step 10 - should check
state = Mock(global_step=10)
result = callback.on_step_end(args, state, control)
assert not trigger_file.exists() # Deleted
assert result.should_save is True
def test_no_file_no_trigger(self):
"""Test that no trigger occurs when file doesn't exist"""
with tempfile.TemporaryDirectory() as tmpdir:
cfg = DictDefault(
{
"dynamic_checkpoint": {"enabled": True, "check_interval": 1},
"output_dir": tmpdir,
}
)
callback = DynamicCheckpointCallback(cfg)
args = Mock(output_dir=tmpdir)
state = Mock(global_step=1)
control = Mock(should_save=False)
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process",
return_value=True,
):
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.is_distributed",
return_value=False,
):
result = callback.on_step_end(args, state, control)
assert result.should_save is False
def test_file_deletion_error_handling(self):
"""Test that file deletion errors are handled gracefully"""
with tempfile.TemporaryDirectory() as tmpdir:
cfg = DictDefault(
{
"dynamic_checkpoint": {"enabled": True, "check_interval": 1},
"output_dir": tmpdir,
}
)
callback = DynamicCheckpointCallback(cfg)
trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME
trigger_file.touch()
args = Mock(output_dir=tmpdir)
state = Mock(global_step=1)
control = Mock(should_save=False)
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process",
return_value=True,
):
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.is_distributed",
return_value=False,
):
with patch.object(
Path, "unlink", side_effect=OSError("Permission denied")
):
result = callback.on_step_end(args, state, control)
assert result.should_save is True
class TestDynamicCheckpointMultiGPU:
"""Test multi-GPU synchronization"""
def test_only_rank_0_checks_file(self):
"""Test that only rank 0 checks filesystem in multi-GPU setup"""
with tempfile.TemporaryDirectory() as tmpdir:
cfg = DictDefault(
{
"dynamic_checkpoint": {"enabled": True, "check_interval": 1},
"output_dir": tmpdir,
}
)
callback = DynamicCheckpointCallback(cfg)
trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME
trigger_file.touch()
args = Mock(output_dir=tmpdir)
state = Mock(global_step=1)
control = Mock(should_save=False)
# Rank 1 (not main process) - shouldn't check file
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process",
return_value=False,
):
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.is_distributed",
return_value=True,
):
with patch("torch.distributed.broadcast") as mock_broadcast:
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.barrier"
):
mock_tensor = MagicMock()
mock_tensor.item.return_value = 0
with patch("torch.tensor", return_value=mock_tensor):
callback.on_step_end(args, state, control)
assert trigger_file.exists()
# Broadcast should have been called
assert mock_broadcast.called
def test_broadcast_synchronization(self):
"""Test that trigger decision is broadcasted to all ranks"""
with tempfile.TemporaryDirectory() as tmpdir:
cfg = DictDefault(
{
"dynamic_checkpoint": {"enabled": True, "check_interval": 1},
"output_dir": tmpdir,
}
)
callback = DynamicCheckpointCallback(cfg)
trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME
trigger_file.touch()
args = Mock(output_dir=tmpdir)
state = Mock(global_step=1)
control = Mock(should_save=False)
# Rank 0 detects file
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process",
return_value=True,
):
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.is_distributed",
return_value=True,
):
with patch("torch.distributed.broadcast") as mock_broadcast:
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.barrier"
) as mock_barrier:
mock_tensor = MagicMock()
mock_tensor.item.return_value = 1
with patch("torch.tensor", return_value=mock_tensor):
with patch("torch.cuda.current_device", return_value=0):
result = callback.on_step_end(args, state, control)
assert mock_broadcast.called
assert mock_barrier.called
# All ranks should trigger
assert result.should_save is True
class TestDynamicCheckpointSignalHandling:
"""Test signal-based checkpoint triggering"""
def test_signal_trigger_via_callback(self):
"""Test that signal flag triggers checkpoint save"""
with tempfile.TemporaryDirectory() as tmpdir:
cfg = DictDefault(
{
"dynamic_checkpoint": {
"enabled": True,
"check_interval": 1,
"enable_signal": True,
},
"output_dir": tmpdir,
}
)
with patch("signal.signal"):
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process",
return_value=True,
):
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.hasattr",
return_value=True,
):
callback = DynamicCheckpointCallback(cfg)
callback.should_save_checkpoint = True
args = Mock(output_dir=tmpdir)
state = Mock(global_step=1)
control = Mock(should_save=False)
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.is_main_process",
return_value=True,
):
with patch(
"axolotl.utils.callbacks.dynamic_checkpoint.is_distributed",
return_value=False,
):
result = callback.on_step_end(args, state, control)
assert result.should_save is True
assert callback.should_save_checkpoint is False
def test_signal_not_registered_when_disabled(self):
"""Test that signal handler is not registered when disabled"""
with tempfile.TemporaryDirectory() as tmpdir:
cfg = DictDefault(
{
"dynamic_checkpoint": {
"enabled": True,
"check_interval": 10,
"enable_signal": False,
},
"output_dir": tmpdir,
}
)
with patch("signal.signal") as mock_signal_register:
_ = DynamicCheckpointCallback(cfg)
assert not mock_signal_register.called
class TestDynamicCheckpointDisabled:
"""Test behavior when callback is disabled"""
def test_disabled_callback_does_nothing(self):
"""Test that disabled callback doesn't check or trigger"""
with tempfile.TemporaryDirectory() as tmpdir:
cfg = DictDefault(
{
"dynamic_checkpoint": {"enabled": False},
"output_dir": tmpdir,
}
)
callback = DynamicCheckpointCallback(cfg)
trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME
trigger_file.touch()
args = Mock(output_dir=tmpdir)
state = Mock(global_step=1)
control = Mock(should_save=False)
result = callback.on_step_end(args, state, control)
assert trigger_file.exists()
assert result.should_save is False
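Taken together, these tests document the intended workflow: while training runs, creating the trigger file in the run's `output_dir` causes a one-off checkpoint at the next step that is a multiple of `check_interval`, after which the file is removed. A hedged sketch of how a user might request such a checkpoint from outside the training process (the output directory path is illustrative):

from pathlib import Path

from axolotl.utils.callbacks.dynamic_checkpoint import DEFAULT_TRIGGER_FILENAME

# "./outputs/my-run" is a placeholder for the run's configured output_dir.
Path("./outputs/my-run", DEFAULT_TRIGGER_FILENAME).touch()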

View File

@@ -1,18 +0,0 @@
import os
import pytest
from axolotl.core.trainers.grpo import GRPOStrategy
def test_get_rollout_func_loads_successfully():
"""Test that a valid rollout function can be loaded"""
rollout_func = GRPOStrategy.get_rollout_func("os.path.join")
assert callable(rollout_func)
assert rollout_func == os.path.join
def test_get_rollout_func_invalid_module_raises_error():
"""Test that invalid module path raises clear ValueError"""
with pytest.raises(ValueError, match="Rollout function .* not found"):
GRPOStrategy.get_rollout_func("nonexistent_module.my_func")