diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index 4f6ea8de7..cd443a197 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,13 +1,13 @@ # These are supported funding model platforms -github: [winglian, OpenAccess-AI-Collective] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] patreon: # Replace with a single Patreon username open_collective: # Replace with a single Open Collective username -ko_fi: axolotl_ai # Replace with a single Ko-fi username +ko_fi: # Replace with a single Ko-fi username tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry liberapay: # Replace with a single Liberapay username issuehunt: # Replace with a single IssueHunt username otechie: # Replace with a single Otechie username lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry -custom: ['https://quickchart.io/qr?text=bitcoin%3Abc1qxlgwlqwfea5s2cxm42xqsfmwjct0rj8w8ea5np&size=480&centerImageUrl=https%3A%2F%2Fupload.wikimedia.org%2Fwikipedia%2Fcommons%2Fthumb%2F4%2F46%2FBitcoin.svg%2F64px-Bitcoin.svg.png'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] +custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 87d6772dd..eddce1438 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -57,14 +57,14 @@ jobs: cuda_version: 12.8.1 cudnn_version: "" python_version: "3.11" - pytorch: 2.9.0 + pytorch: 2.9.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" dockerfile: "Dockerfile-base" - cuda: "130" cuda_version: 13.0.0 cudnn_version: "" python_version: "3.11" - pytorch: 2.9.0 + pytorch: 2.9.1 torch_cuda_arch_list: "9.0+PTX" dockerfile: "Dockerfile-base" # - cuda: "128" @@ -90,7 +90,6 @@ jobs: uses: docker/metadata-action@v5 with: images: | - winglian/axolotl-base axolotlai/axolotl-base - name: Login to Docker Hub uses: docker/login-action@v2 @@ -147,14 +146,14 @@ jobs: cuda_version: 12.8.1 cudnn_version: "" python_version: "3.11" - pytorch: 2.9.0 + pytorch: 2.9.1 torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" dockerfile: "Dockerfile-uv-base" - cuda: "130" cuda_version: 13.0.0 cudnn_version: "" python_version: "3.11" - pytorch: 2.9.0 + pytorch: 2.9.1 torch_cuda_arch_list: "9.0+PTX" dockerfile: "Dockerfile-uv-base" steps: diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4040ccdc9..f34a0cf2f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -25,7 +25,6 @@ jobs: python_version: "3.11" pytorch: 2.7.1 axolotl_extras: vllm - is_latest: true - cuda: 128 cuda_version: 12.8.1 python_version: "3.11" @@ -36,6 +35,17 @@ jobs: python_version: "3.11" pytorch: 2.8.0 axolotl_extras: + is_latest: true + - cuda: 128 + cuda_version: 12.8.1 + python_version: "3.11" + pytorch: 2.9.0 + axolotl_extras: + - cuda: 128 + cuda_version: 12.8.1 + python_version: "3.11" + pytorch: 2.9.1 + axolotl_extras: runs-on: axolotl-gpu-runner steps: - name: Checkout @@ -45,7 +55,6 @@ jobs: uses: docker/metadata-action@v5 with: images: | - winglian/axolotl axolotlai/axolotl tags: | type=ref,event=branch @@ -99,7 +108,6 @@ jobs: python_version: "3.11" pytorch: 2.7.1 axolotl_extras: vllm - is_latest: true - cuda: 128 cuda_version: 12.8.1 python_version: "3.11" @@ -110,6 +118,17 @@ jobs:
python_version: "3.11" pytorch: 2.8.0 axolotl_extras: + is_latest: true + - cuda: 128 + cuda_version: 12.8.1 + python_version: "3.11" + pytorch: 2.9.0 + axolotl_extras: + - cuda: 128 + cuda_version: 12.8.1 + python_version: "3.11" + pytorch: 2.9.1 + axolotl_extras: runs-on: axolotl-gpu-runner steps: - name: Checkout @@ -119,7 +138,6 @@ jobs: uses: docker/metadata-action@v5 with: images: | - winglian/axolotl-cloud axolotlai/axolotl-cloud tags: | type=ref,event=branch @@ -179,7 +197,6 @@ jobs: uses: docker/metadata-action@v5 with: images: | - winglian/axolotl-cloud-term axolotlai/axolotl-cloud-term tags: | type=ref,event=branch diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml index 18b036a0d..a24946ae9 100644 --- a/.github/workflows/nightlies.yml +++ b/.github/workflows/nightlies.yml @@ -31,7 +31,6 @@ jobs: uses: docker/metadata-action@v5 with: images: | - winglian/axolotl axolotlai/axolotl tags: | type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }} @@ -84,7 +83,6 @@ jobs: uses: docker/metadata-action@v5 with: images: | - winglian/axolotl-cloud axolotlai/axolotl-cloud tags: | type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7ad9d1ab4..95370ca3d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -59,6 +59,10 @@ jobs: timeout-minutes: 20 steps: + - name: cleanup node + run: | + sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL + - name: Check out repository code uses: actions/checkout@v4 @@ -91,6 +95,10 @@ jobs: python scripts/cutcrossentropy_install.py | sh pip3 install -r requirements-dev.txt -r requirements-tests.txt + - name: cleanup pip cache + run: | + find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; + - name: Make sure PyTorch version wasn't clobbered run: | python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__" @@ -118,10 +126,6 @@ jobs: flags: unittests,pytorch-${{ matrix.pytorch_version }} fail_ci_if_error: false - - name: cleanup pip cache - run: | - find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; - pytest-sdist: name: PyTest from Source Dist runs-on: ubuntu-latest @@ -134,6 +138,10 @@ jobs: timeout-minutes: 20 steps: + - name: cleanup node + run: | + sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL + - name: Check out repository code uses: actions/checkout@v4 @@ -167,6 +175,10 @@ jobs: python scripts/cutcrossentropy_install.py | sh pip3 install -r requirements-dev.txt -r requirements-tests.txt + - name: cleanup pip cache + run: | + find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; + - name: Make sure PyTorch version wasn't clobbered run: | python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__" @@ -184,10 +196,6 @@ jobs: pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml pytest -v --durations=10 tests/cli/ - - name: cleanup pip cache - run: | - find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; - gate-skip-e2e: needs: [pre-commit, pytest, pytest-sdist] runs-on: ubuntu-latest diff --git a/README.md b/README.md index 6313a73ca..1517fb874 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,10 @@ ## 🎉 Latest Updates +- 2025/11: Axolotl now includes support for [Olmo 3](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/olmo3).
+- 2025/10: New model support has been added to Axolotl: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss). +- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion). +- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107). - 2025/07: - ND Parallelism support has been added to Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info. - Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm). @@ -36,12 +40,12 @@ - [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) have been integrated into Axolotl with mistral-common tokenizer support! - TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed, and FSDP has been added to enable Arctic Long Sequence Training (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl! - 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more! -- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
Expand older updates +- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning. - 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl! - 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version! - 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own! @@ -154,6 +158,13 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details. +## 📈 Telemetry + +Axolotl has opt-out telemetry that helps us understand how the project is being used +and prioritize improvements. We collect basic system information, model types, and +error rates, never personal data or file paths. Telemetry is enabled by default. To +disable it, set `AXOLOTL_DO_NOT_TRACK=1`. For more details, see our [telemetry documentation](https://docs.axolotl.ai/docs/telemetry.html). + ## ❤️ Sponsors Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai) diff --git a/_quarto.yml b/_quarto.yml index fad3f6786..c97b9838e 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -241,6 +241,7 @@ website: - docs/installation.qmd - docs/inference.qmd - docs/cli.qmd + - docs/telemetry.qmd - docs/config-reference.qmd - text: "API Reference" href: docs/api diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base index 25eae4fde..cfd30b851 100644 --- a/docker/Dockerfile-base +++ b/docker/Dockerfile-base @@ -51,7 +51,7 @@ RUN git lfs install --skip-repo && \ pip3 install -U --no-cache-dir pydantic==1.10.10 && \ pip3 cache purge -RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \ +RUN if [ "$PYTORCH_VERSION" = "2.9.1" ] && [ "$CUDA" = "128" ] ; then \ wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \ diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd index 870a2b67d..34fde45fb 100644 --- a/docs/dataset-formats/conversation.qmd +++ b/docs/dataset-formats/conversation.qmd @@ -218,6 +218,13 @@ If you have tool arguments with the same name but different dtypes (like `"time": st ``` "arguments": "{\"...\": \"...\"}" ``` + +The same applies to tool parameters: + +``` +"parameters": "{\"...\": \"...\"}" +``` + :::
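+
+If your source data stores these fields as dicts, a small preprocessing sketch along the
+following lines (hypothetical helper name; assumes rows follow the OpenAI tools format
+shown in this guide) can stringify them up front:
+
+```python
+import json
+
+def stringify_tool_schemas(sample: dict) -> dict:
+    """Serialize dict-valued tool `parameters` and tool-call `arguments` to JSON strings."""
+    for tool in sample.get("tools", []):
+        function = tool.get("function", {})
+        if isinstance(function.get("parameters"), dict):
+            function["parameters"] = json.dumps(function["parameters"])
+    for message in sample.get("messages", []):
+        for call in message.get("tool_calls") or []:
+            function = call.get("function", {})
+            if isinstance(function.get("arguments"), dict):
+                function["arguments"] = json.dumps(function["arguments"])
+    return sample
+```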
Example config for Llama4: diff --git a/docs/multi-gpu.qmd b/docs/multi-gpu.qmd index 57a941b04..1b58a108c 100644 --- a/docs/multi-gpu.qmd +++ b/docs/multi-gpu.qmd @@ -4,7 +4,7 @@ format: html: toc: true toc-depth: 3 - number-sections: true + # number-sections: true code-tools: true execute: enabled: false @@ -14,12 +14,18 @@ This guide covers advanced training configurations for multi-GPU setups using Ax ## Overview {#sec-overview} -Axolotl supports several methods for multi-GPU training: +When training on multiple GPUs, Axolotl supports three sharding/parallelism strategies. Additionally, you can layer specific optimization features on top of your chosen strategy. -- DeepSpeed (recommended) -- FSDP (Fully Sharded Data Parallel) -- Sequence parallelism -- FSDP + QLoRA +You generally cannot combine these strategies; pick exactly one. + +1. **DeepSpeed**: Powerful optimization library, supports ZeRO stages 1-3. +2. **FSDP (Fully Sharded Data Parallel)**: PyTorch's native sharding implementation (recommended). +3. **DDP (Distributed Data Parallel)**: PyTorch's native parallelism implementation (the default if neither of the above is selected). + +These features can often be combined with the strategies above: + +* **Sequence Parallelism**: Splits long sequences across GPUs (compatible with DDP, DeepSpeed, and FSDP). +* **FSDP + QLoRA**: Combines 4-bit quantization with FSDP (specific to FSDP).
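+
+Selecting a strategy happens in your Axolotl config. A minimal sketch (values are
+illustrative; see the sections below for the full set of options):
+
+```yaml
+# DeepSpeed: point at one of the ZeRO configs shipped with Axolotl
+deepspeed: deepspeed_configs/zero1.json
+
+# ...or FSDP2: leave `deepspeed` unset and configure FSDP instead
+fsdp_version: 2
+fsdp_config:
+  auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  transformer_layer_cls_to_wrap: LlamaDecoderLayer
+
+# ...or DDP: set neither, and Axolotl falls back to DDP automatically
+```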
## DeepSpeed {#sec-deepspeed} @@ -65,12 +71,18 @@ Start from Stage 1 -> Stage 2 -> Stage 3. ## Fully Sharded Data Parallel (FSDP) {#sec-fsdp} +FSDP allows you to shard model parameters, gradients, and optimizer states across data parallel workers. + ::: {.callout-note} FSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in an upcoming release of Axolotl. ::: +### FSDP + QLoRA {#sec-fsdp-qlora} + +For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd). + ### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2} To migrate your config from FSDP1 to FSDP2, you must use the `fsdp_version` top-level config field to specify the FSDP version, and @@ -145,10 +157,6 @@ single sequence causes OOM errors during model training. See our [dedicated guide](sequence_parallelism.qmd) for more information. -### FSDP + QLoRA {#sec-fsdp-qlora} - -For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd). - ## Performance Optimization {#sec-performance} ### Liger Kernel Integration {#sec-liger} diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd index 1c4e28ea7..e63a553b2 100644 --- a/docs/multimodal.qmd +++ b/docs/multimodal.qmd @@ -124,6 +124,8 @@ Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral ```yaml base_model: mistralai/Voxtral-Mini-3B-2507 + +processor_type: VoxtralProcessor ``` ### Gemma-3 {#sec-gemma-3} diff --git a/docs/telemetry.qmd b/docs/telemetry.qmd new file mode 100644 index 000000000..62d7c9bbc --- /dev/null +++ b/docs/telemetry.qmd @@ -0,0 +1,61 @@ +--- +title: Telemetry +description: A description of the telemetry implementation in Axolotl. +--- + +# Telemetry in Axolotl + +Axolotl implements anonymous telemetry to help maintainers understand how the library +is used and where users encounter issues. This data helps prioritize features, optimize +performance, and fix bugs. + +## Data Collection + +We collect: + +- System info: OS, Python version, Axolotl version, PyTorch version, Transformers version, etc. +- Hardware info: CPU count, memory, GPU count and models +- Runtime metrics: Training progress, memory usage, timing information +- Usage patterns: Models (from a whitelist) and configurations used +- Error tracking: Stack traces and error messages (sanitized to remove personal information) + +Personally identifiable information (PII) is not collected. + +## Implementation + +Telemetry is implemented using PostHog and consists of: + +- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the +telemetry system and provides methods for tracking events. +- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and +sends sanitized stack traces. +- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks +runtime metrics during training. +- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends +runtime metrics telemetry. + +The telemetry system will block training startup for 10 seconds to ensure users are +aware of data collection, unless telemetry is explicitly enabled or disabled. + +## Opt-Out Mechanism + +Telemetry is **enabled by default** on an opt-out basis. To disable it, set +`AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1`. + +A warning message will be logged on start to clearly inform users about telemetry. +We will remove this warning in a future release. + +To hide the warning message shown at startup (for `train` and other commands), +explicitly set `AXOLOTL_DO_NOT_TRACK=0` (enable telemetry) or `AXOLOTL_DO_NOT_TRACK=1` +(explicitly disable telemetry).
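+
+For example, in a shell:
+
+```bash
+# Disable telemetry entirely
+export AXOLOTL_DO_NOT_TRACK=1
+
+# Or keep telemetry enabled and suppress the startup warning
+export AXOLOTL_DO_NOT_TRACK=0
+```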
+ +## Privacy + +- All path-like config information is automatically redacted from telemetry data +- Model information is only collected for whitelisted organizations + - See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations +- Each run generates a unique anonymous ID + - This allows us to link different telemetry events in a single training run +- Telemetry is only sent from the main process to avoid duplicate events diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index cea1aeda0..57a638948 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -40,7 +40,7 @@ "%%capture\n", "# This step can take ~5-10 minutes to install dependencies\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", - "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec\"" + "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953\"" ] }, { diff --git a/examples/granite4/README.md b/examples/granite4/README.md new file mode 100644 index 000000000..d5efd3349 --- /dev/null +++ b/examples/granite4/README.md @@ -0,0 +1,65 @@ +# Finetune IBM's Granite 4.0 with Axolotl + +[Granite 4.0](https://huggingface.co/collections/ibm-granite/granite-40-language-models) is a family of open-source models trained by IBM Research. + +This guide shows how to fine-tune them with Axolotl on multi-turn conversations with proper masking. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from `main`, as Granite 4.0 support is only available in nightly builds, or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html). + + Here is an example of how to install from main for pip: +```bash +# Ensure you have PyTorch installed (PyTorch 2.7.1 min) +git clone https://github.com/axolotl-ai-cloud/axolotl.git +cd axolotl + +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation -e '.[flash-attn]' + +# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy +python scripts/cutcrossentropy_install.py | sh +``` + +2. Run the finetuning example: + +```bash +axolotl train examples/granite4/granite-4.0-tiny-fft.yaml +``` + +This config uses about 40.8GiB VRAM. + +Let us know how it goes. Happy finetuning! 🚀 + +### TIPS + +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +### Limitations + +Adapter finetuning does not work at the moment. It errors with + +```bash +RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x3072 and 1x1179648) +``` + +In addition, even once adapter training works, `lora_target_linear: true` will not work due to: +```bash +ValueError: Target module GraniteMoeHybridParallelExperts() is not supported. +``` + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) + +## Related Resources + +- [Granite Docs](https://www.ibm.com/granite/docs/models/granite) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/granite4/granite-4.0-tiny-fft.yaml b/examples/granite4/granite-4.0-tiny-fft.yaml new file mode 100644 index 000000000..7ff8207ae --- /dev/null +++ b/examples/granite4/granite-4.0-tiny-fft.yaml @@ -0,0 +1,45 @@ +base_model: ibm-granite/granite-4.0-tiny-preview + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/model-out + +sequence_len: 2048 +sample_packing: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/olmo3/README.md b/examples/olmo3/README.md new file mode 100644 index 000000000..d4dbe05a9 --- /dev/null +++ b/examples/olmo3/README.md @@ -0,0 +1,46 @@ +# Finetune Allenai's Olmo 3 with Axolotl + +[Olmo 3](https://huggingface.co/collections/allenai/olmo-3) is a family of 7B and 32B open-source models trained by the Allen Institute for Artificial Intelligence. + +This guide shows how to fine-tune them with Axolotl on multi-turn conversations with proper masking. + +## Getting started + +
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + + Here is an example of how to install from pip: + ```bash + # Ensure you have a compatible version of PyTorch installed + pip3 install packaging setuptools wheel ninja + pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' + + # Install Cut Cross Entropy + python scripts/cutcrossentropy_install.py | sh + ``` + +2. Run the finetuning example: + +```bash +axolotl train examples/olmo3/olmo3-7b-qlora.yaml +``` + +Let us know how it goes. Happy finetuning! 🚀 + +### TIPS + +- The example config can be re-used for Olmo and Olmo 2. +- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true` from the config (see the sketch below). +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
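+
+A minimal sketch of that change (illustrative; all other keys in `olmo3-7b-qlora.yaml`
+stay as they are):
+
+```yaml
+base_model: allenai/Olmo-3-7B-Instruct-SFT
+
+# Remove or comment out the QLoRA settings to train all weights:
+# adapter: qlora
+# load_in_4bit: true
+
+# The lora_* options no longer apply and can be dropped as well.
+```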
+ +## Optimization Guides + +Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html). + +## Related Resources + +- [Olmo 3 Blog](https://allenai.org/blog/olmo3) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/olmo3/olmo3-7b-qlora.yaml b/examples/olmo3/olmo3-7b-qlora.yaml new file mode 100644 index 000000000..c8878d79f --- /dev/null +++ b/examples/olmo3/olmo3-7b-qlora.yaml @@ -0,0 +1,64 @@ +base_model: allenai/Olmo-3-7B-Instruct-SFT + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/seed-oss/README.md b/examples/seed-oss/README.md index 5610c1316..aeb8635e3 100644 --- a/examples/seed-oss/README.md +++ b/examples/seed-oss/README.md @@ -6,21 +6,17 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations ## Getting started -1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Seed-OSS is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html). +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). - Here is an example of how to install from main for pip: + Here is an example of how to install from pip: + ```bash + # Ensure you have a compatible version of PyTorch installed + pip3 install packaging setuptools wheel ninja + pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' -```bash -# Ensure you have Pytorch installed (Pytorch 2.6.0 min) -git clone https://github.com/axolotl-ai-cloud/axolotl.git -cd axolotl - -pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja -pip3 install --no-build-isolation -e '.[flash-attn]' - -# Install Cut Cross Entropy -python scripts/cutcrossentropy_install.py | sh -``` + # Install Cut Cross Entropy + python scripts/cutcrossentropy_install.py | sh + ``` 2. Run the finetuning example: @@ -41,9 +37,7 @@ Let us know how it goes. Happy finetuning! 🚀 ## Optimization Guides -- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) -- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) -- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) +Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html). ## Related Resources diff --git a/examples/smolvlm2/README.md b/examples/smolvlm2/README.md index 9c0ae4836..74c1a1c0f 100644 --- a/examples/smolvlm2/README.md +++ b/examples/smolvlm2/README.md @@ -37,9 +37,7 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl. ## Optimization Guides -- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) -- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) -- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html). ## Related Resources diff --git a/examples/voxtral/voxtral-mini-audio-qlora.yml b/examples/voxtral/voxtral-mini-audio-qlora.yml index 8fe6adbff..59150c4ca 100644 --- a/examples/voxtral/voxtral-mini-audio-qlora.yml +++ b/examples/voxtral/voxtral-mini-audio-qlora.yml @@ -1,5 +1,5 @@ base_model: mistralai/Voxtral-Mini-3B-2507 -processor_type: AutoProcessor +processor_type: VoxtralProcessor # Automatically upload checkpoint and final model to HF # hub_model_id: username/custom_model_name diff --git a/requirements.txt b/requirements.txt index a12a3941b..08759279d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,7 @@ peft>=0.17.1 tokenizers>=0.22.1 transformers==4.57.1 accelerate==1.11.0 -datasets==4.3.0 +datasets==4.4.1 deepspeed>=0.17.0 trl==0.25.0 hf_xet==1.2.0 @@ -42,7 +42,6 @@ numpy>=2.2.6 # qlora things evaluate==0.4.1 scipy -scikit-learn==1.4.2 nvidia-ml-py==12.560.30 art tensorboard @@ -70,4 +69,7 @@ schedulefree==1.4.1 axolotl-contribs-lgpl==0.0.7 axolotl-contribs-mit==0.0.5 + +# telemetry +posthog==6.7.11 + mistral-common==1.8.5 diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index cb498c002..91d0f45d6 100644 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else "" print( UNINSTALL_PREFIX - + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"' + + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"' ) diff --git a/setup.py b/setup.py index 9c1161642..a1bdd6bdf 100644 --- a/setup.py +++ b/setup.py @@ -130,7 +130,7 @@ extras_require = { "ring-flash-attn>=0.1.7", ], "deepspeed": [ - "deepspeed==0.17.5", +
"deepspeed==0.18.2", "deepspeed-kernels", ], "mamba-ssm": [ diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 93ac6147d..3c4ace7b0 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -14,6 +14,8 @@ import yaml from transformers.utils import is_torch_bf16_gpu_available from axolotl.integrations.base import PluginManager +from axolotl.telemetry.errors import send_errors +from axolotl.telemetry.manager import TelemetryManager from axolotl.utils.comet_ import setup_comet_env_vars from axolotl.utils.config import ( normalize_cfg_datasets, @@ -31,6 +33,8 @@ LOG = get_logger(__name__) API_KEY_FIELDS = {"comet_api_key"} +TELEMETRY_MANAGER = TelemetryManager.get_instance() + def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: """ @@ -164,6 +168,7 @@ def plugin_set_cfg(cfg: DictDefault): plugin_manager.cfg = cfg +@send_errors def load_cfg( config: str | Path | DictDefault = Path("examples/"), **kwargs ) -> DictDefault: @@ -197,6 +202,8 @@ def load_cfg( temp_file.close() cfg.axolotl_config_path = temp_file.name + TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg) + # If there are any options passed in the cli, if it is something that seems valid # from the yaml, then overwrite the value cfg_keys = cfg.keys() @@ -240,6 +247,7 @@ def load_cfg( setup_comet_env_vars(cfg) plugin_set_cfg(cfg) + TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg) cfg_to_log = { k: "[REDACTED]" if k in API_KEY_FIELDS else v for k, v in cfg.items() diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index 3e1c01520..640be3696 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -19,7 +19,10 @@ from axolotl.cli.utils.diffusion import ( launch_diffusion_gradio_ui, ) from axolotl.integrations.base import PluginManager -from axolotl.utils.chat_templates import get_chat_template_from_config +from axolotl.telemetry.errors import send_errors +from axolotl.utils.chat_templates import ( + get_chat_template_from_config, +) from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -43,6 +46,7 @@ def get_multi_line_input() -> str: return instruction +@send_errors def do_inference( *, cfg: DictDefault, @@ -160,6 +164,7 @@ def do_inference( print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) +@send_errors def do_inference_gradio( *, cfg: DictDefault, diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 657ddcfe4..482767b12 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -7,12 +7,14 @@ import fire from axolotl.cli.config import load_cfg from axolotl.cli.utils import load_model_and_tokenizer +from axolotl.telemetry.errors import send_errors from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger LOG = get_logger(__name__) +@send_errors def do_merge_lora(*, cfg: DictDefault) -> None: """ Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py index 43142d79e..1d9736b9d 100644 --- a/src/axolotl/cli/merge_sharded_fsdp_weights.py +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -23,6 +23,7 @@ from safetensors.torch import save_file as safe_save_file from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner from axolotl.cli.config import load_cfg +from axolotl.telemetry.errors import 
send_errors from axolotl.utils.logging import get_logger from axolotl.utils.train import determine_last_checkpoint @@ -118,6 +119,7 @@ def _distributed_checkpoint_to_merged_weights( return save_path_ +@send_errors def merge_fsdp_weights( checkpoint_dir: str, output_path: str, diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index 6c05a55f1..af35dd801 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -17,6 +17,7 @@ from axolotl.cli.config import load_cfg from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.integrations.base import PluginManager +from axolotl.telemetry.errors import send_errors from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger from axolotl.utils.trainer import disable_datasets_caching @@ -24,6 +25,7 @@ from axolotl.utils.trainer import disable_datasets_caching LOG = get_logger(__name__) +@send_errors def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: """ Preprocesses dataset specified in axolotl config. diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index 8d7758e66..c95ddb80e 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -9,6 +9,7 @@ from datasets import Dataset import axolotl.monkeypatch.data.batch_dataset_fetcher # noqa: F401 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs from axolotl.loaders import load_processor, load_tokenizer +from axolotl.telemetry.errors import send_errors from axolotl.utils.data import prepare_datasets, prepare_preference_datasets from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -34,6 +35,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset: ) +@send_errors def load_datasets( *, cfg: DictDefault, @@ -96,6 +98,7 @@ def load_datasets( ) +@send_errors def load_preference_datasets( *, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None ) -> TrainDatasetMeta: diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 2c949f8e7..0d19b369f 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -29,6 +29,8 @@ from transformers.trainer_pt_utils import AcceleratorConfig from axolotl.integrations.base import PluginManager from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr +from axolotl.telemetry.callbacks import TelemetryCallback +from axolotl.telemetry.manager import TelemetryManager from axolotl.utils import ( is_comet_available, is_mlflow_available, @@ -118,6 +120,13 @@ class TrainerBuilderBase(abc.ABC): if self.cfg.gc_steps: callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps)) + if self.cfg.dynamic_checkpoint and self.cfg.dynamic_checkpoint.enabled: + from axolotl.utils.callbacks.dynamic_checkpoint import ( + DynamicCheckpointCallback, + ) + + callbacks.append(DynamicCheckpointCallback(self.cfg)) + if self.cfg.use_wandb: callbacks.append( SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) @@ -155,6 +164,10 @@ class TrainerBuilderBase(abc.ABC): ) ) + telemetry_manager = TelemetryManager.get_instance() + if telemetry_manager.enabled: + callbacks.append(TelemetryCallback()) + return callbacks def get_post_trainer_create_callbacks(self, trainer): @@ -196,9 +209,9 @@ class TrainerBuilderBase(abc.ABC): ): warmup_steps = 0 warmup_ratio = 0.0 - if self.cfg.warmup_steps: + if 
self.cfg.warmup_steps is not None: warmup_steps = self.cfg.warmup_steps - elif self.cfg.warmup_ratio: + elif self.cfg.warmup_ratio is not None: if total_num_steps: warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0) else: diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py index e4496bee6..db6fb3f16 100644 --- a/src/axolotl/evaluate.py +++ b/src/axolotl/evaluate.py @@ -10,6 +10,7 @@ import torch from datasets import Dataset from transformers.trainer import Trainer +from axolotl.telemetry.errors import send_errors from axolotl.train import ( TrainDatasetMeta, setup_model_and_tokenizer, @@ -63,6 +64,7 @@ def evaluate_dataset( return metrics +@send_errors def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]: """ Evaluate a model on training and validation datasets. diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index 81dd6a3a3..1c793137c 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh - If you are installing from pip ```bash -pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec" +pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953" ``` ## Usage @@ -66,6 +66,9 @@ plugins: - mistral3 - mixtral - mllama +- olmo +- olmo2 +- olmo3 - phi - phi3 - phi4_multimodal diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index bd0124b93..b8f7e9da3 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -35,7 +35,7 @@ LOG = get_logger(__name__) _CCE_INSTALL_MESSAGE = ( "Please install Axolotl's fork of cut_cross_entropy with transformers support using " - '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"`' + '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"`' ) diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index bcde4bf96..8e8177b62 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -20,6 +20,7 @@ from peft import ( from transformers import PreTrainedModel from axolotl.loaders.utils import get_linear_embedding_layers +from axolotl.telemetry.errors import send_errors from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -172,6 +173,7 @@ def load_lora( return model, lora_config +@send_errors def load_adapter( model: PreTrainedModel, cfg: DictDefault, diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index aeec46584..1eeed3565 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -49,6 +49,7 @@ from axolotl.loaders.utils import ( load_model_config, ) from axolotl.models.mamba import fix_mamba_attn_for_loss +from axolotl.telemetry.errors import send_errors from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import ( @@ -158,6 +159,7 @@ class ModelLoader: """Property that determines if FSDP with QLoRA is enabled.""" return 
self.is_fsdp_enabled and self.cfg.adapter == "qlora" + @send_errors def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]: """Load and prepare the model with all configurations and patches. diff --git a/src/axolotl/loaders/processor.py b/src/axolotl/loaders/processor.py index 7580b2008..827b4be35 100644 --- a/src/axolotl/loaders/processor.py +++ b/src/axolotl/loaders/processor.py @@ -1,27 +1,47 @@ """Processor loading functionality for multi-modal models""" -from typing import Any - import transformers from transformers import ( AutoProcessor, PreTrainedTokenizerBase, ) +from axolotl.telemetry.errors import send_errors from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger LOG = get_logger(__name__) +@send_errors def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): - processor_kwargs: dict[str, Any] = {} # Do we actually need this? - processor_cls = AutoProcessor if cfg.processor_type: processor_cls = getattr(transformers, cfg.processor_type) if cfg.tokenizer_use_mistral_common: + + def _patch_mistralcommontokenizer(): + """ + Transformers v5 stops reading the sub-processor. + + We need to patch this, so both processors use this. + """ + import transformers.tokenization_mistral_common as tokenization_mistral_common + + from axolotl.utils.mistral import HFMistralTokenizer + + tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer + + _patch_mistralcommontokenizer() + + from transformers import VoxtralProcessor + + if processor_cls == VoxtralProcessor: + return VoxtralProcessor.from_pretrained( + cfg.processor_config, + ) + from axolotl.utils.mistral import Mistral3Processor return Mistral3Processor( @@ -32,7 +52,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): cfg.processor_config, trust_remote_code=cfg.trust_remote_code or False, tokenizer=tokenizer, - **processor_kwargs, ) # Attempt to load image size from processor if available diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py index 69455dd77..48856116c 100644 --- a/src/axolotl/loaders/tokenizer.py +++ b/src/axolotl/loaders/tokenizer.py @@ -13,6 +13,7 @@ from transformers import ( from axolotl.integrations.base import PluginManager from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN +from axolotl.telemetry.errors import send_errors from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import ( @@ -119,6 +120,7 @@ def modify_tokenizer_files( return tokenizer_dir +@send_errors def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer: """Load and configure the tokenizer based on the provided config.""" diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 9e5c4b324..ad6b6f4ef 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -50,6 +50,9 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "seed_oss", "lfm2", "lfm2_moe", + "olmo", + "olmo2", + "olmo3", ] diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index f4dcbd7cd..28155810f 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -823,6 +823,23 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): return None if isinstance(tools, list): + # Process each 
tool to handle JSON string parameters + for tool in tools: + if isinstance(tool, dict) and "function" in tool: + function = tool["function"] + if "parameters" in function: + params = function["parameters"] + if isinstance(params, str): + try: + function["parameters"] = json.loads(params) + except json.JSONDecodeError as e: + LOG.error( + f"Error parsing tool parameters as JSON. " + f"Function: {function.get('name', 'unknown')}, " + f"Parameters string: {params!r}, " + f"Error: {e}" + ) + raise return tools raise ValueError( diff --git a/tests/e2e/integrations/__init__.py b/src/axolotl/telemetry/__init__.py similarity index 100% rename from tests/e2e/integrations/__init__.py rename to src/axolotl/telemetry/__init__.py diff --git a/src/axolotl/telemetry/callbacks.py b/src/axolotl/telemetry/callbacks.py new file mode 100644 index 000000000..0ce52ffa4 --- /dev/null +++ b/src/axolotl/telemetry/callbacks.py @@ -0,0 +1,165 @@ +"""Trainer callbacks for reporting runtime metrics at regular intervals.""" + +import logging +import time + +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + +from axolotl.telemetry.manager import TelemetryManager +from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker + +LOG = logging.getLogger(__name__) + +TIME_SINCE_LAST = 60 + + +class TelemetryCallback(TrainerCallback): + """ + Trainer callback for tracking and reporting runtime metrics. + + This callback tracks training progress, runtime, and memory usage, + sending telemetry at configurable intervals. + """ + + report_interval_steps: int = 100 + + def __init__(self): + """Initialize the metrics callback.""" + self.tracker = RuntimeMetricsTracker() + self.telemetry_manager = TelemetryManager.get_instance() + self.current_epoch = -1 + self.start_time = time.time() + self.last_report_time = None + self.last_report_step = 0 + + # pylint: disable=unused-argument + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Handle training start.""" + self.telemetry_manager.send_event(event_type="train-start") + + # pylint: disable=unused-argument + def on_train_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Handle training end.""" + # Send training completion event + self.telemetry_manager.send_event( + event_type="train-end", + properties=self._extract_last_metrics(state) + | self.tracker.metrics.to_dict(), + ) + + # pylint: disable=unused-argument + def on_epoch_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Handle epoch start.""" + self.current_epoch += 1 + self.tracker.start_epoch(self.current_epoch) + + # pylint: disable=unused-argument + def on_epoch_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Handle epoch end.""" + self.tracker.end_epoch(self.current_epoch) + + # pylint: disable=unused-argument + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + """Handle step end.""" + step = state.global_step + self.tracker.update_step(step) + + # Check if we should report metrics + should_report = ( + step % self.report_interval_steps == 0 + or step == 1 # Always report first step + or step - self.last_report_step >= self.report_interval_steps + ) + + if should_report: + current_time = time.time() + if self.last_report_time is 
not None: + time_since_last_report = current_time - self.last_report_time + else: + time_since_last_report = current_time - self.start_time + steps_since_last_report = step - self.last_report_step + + # Only report if enough time has passed + if ( + step == 1 + or time_since_last_report >= TIME_SINCE_LAST + or steps_since_last_report >= self.report_interval_steps + ): + # Calculate steps per second for this interval + if time_since_last_report > 0 and steps_since_last_report > 0: + steps_per_second = steps_since_last_report / time_since_last_report + else: + steps_per_second = 0 + + # Update memory metrics + self.tracker.update_memory_metrics() + + # Prepare metrics to report + metrics = self._extract_last_metrics(state) | { + "step": step, + "epoch": self.current_epoch, + "progress": state.epoch, # Fractional epoch progress + "steps_per_second": steps_per_second, + "elapsed_time": current_time - self.start_time, + "time_since_last_report": time_since_last_report, + } + + # Add memory metrics + memory_metrics = self.tracker.get_memory_metrics() + metrics.update({"memory": memory_metrics}) + + # Send telemetry + self.telemetry_manager.send_event( + event_type="train-progress", properties=metrics + ) + + # Update last report time and step + self.last_report_time = current_time + self.last_report_step = step + + def _extract_last_metrics(self, state: TrainerState) -> dict: + """Extract last loss, learning_rate, and grad_norm from log history.""" + if not state.log_history: + return {"loss": 0, "learning_rate": 0, "grad_norm": 0} + + last_log = state.log_history[-1] + return { + "loss": last_log.get("loss", 0), + "learning_rate": last_log.get("learning_rate", 0), + "grad_norm": last_log.get("grad_norm", 0), + } diff --git a/src/axolotl/telemetry/errors.py b/src/axolotl/telemetry/errors.py new file mode 100644 index 000000000..27f2d2192 --- /dev/null +++ b/src/axolotl/telemetry/errors.py @@ -0,0 +1,160 @@ +"""Telemetry utilities for exception and traceback information.""" + +import logging +import os +import re +import traceback +from functools import wraps +from inspect import getmodule +from typing import Any, Callable + +from axolotl.telemetry.manager import TelemetryManager + +LOG = logging.getLogger(__name__) + +ERROR_HANDLED = False + + +def sanitize_stack_trace(stack_trace: str) -> str: + """ + Remove personal information from stack trace messages while keeping Python package codepaths. + + This function identifies Python packages by looking for common patterns in virtual environment + and site-packages directories, preserving the package path while removing user-specific paths. + + Args: + stack_trace: The original stack trace string. + + Returns: + A sanitized version of the stack trace with Python package paths preserved. 
+ """ + # Split the stack trace into lines to process each file path separately + lines = stack_trace.split("\n") + sanitized_lines = [] + + # Regular expression to find file paths in the stack trace + path_pattern = re.compile(r'(?:File ")(.*?)(?:")') + + # Regular expression to identify paths in site-packages or dist-packages + # This matches path segments like "site-packages/package_name" or "dist-packages/package_name" + site_packages_pattern = re.compile( + r"(?:site-packages|dist-packages)[/\\]([\w\-\.]+)" + ) + + # Additional common virtual environment patterns + venv_lib_pattern = re.compile( + r"(?:lib|Lib)[/\\](?:python\d+(?:\.\d+)?[/\\])?(?:site-packages|dist-packages)[/\\]([\w\-\.]+)" + ) + + for line in lines: + # Check if this line contains a file path + path_match = path_pattern.search(line) + + if path_match: + full_path = path_match.group(1) + sanitized_path = "" + + # Try to match site-packages pattern + site_packages_match = site_packages_pattern.search(full_path) + venv_lib_match = venv_lib_pattern.search(full_path) + + if site_packages_match: + # Find the index where the matched pattern starts + idx = full_path.find("site-packages") + if idx == -1: + idx = full_path.find("dist-packages") + + # Keep from 'site-packages' onward + if idx >= 0: + sanitized_path = full_path[idx:] + elif venv_lib_match: + # For other virtual environment patterns, find the package directory + match_idx = venv_lib_match.start(1) + if match_idx > 0: + # Keep from the package name onward + package_name = venv_lib_match.group(1) + idx = full_path.rfind( + package_name, 0, match_idx + len(package_name) + ) + if idx >= 0: + sanitized_path = full_path[idx:] + + # If we couldn't identify a package pattern but path contains 'axolotl' + elif "axolotl" in full_path: + idx = full_path.rfind("axolotl") + if idx >= 0: + sanitized_path = full_path[idx:] + + # Apply the sanitization to the line + if sanitized_path: + line = line.replace(full_path, sanitized_path) + else: + # If we couldn't identify a package pattern, just keep the filename + filename = os.path.basename(full_path) + if filename: + line = line.replace(full_path, filename) + else: + line = line.replace(full_path, "") + + sanitized_lines.append(line) + + return "\n".join(sanitized_lines) + + +def send_errors(func: Callable) -> Callable: + """ + Decorator to send exception info in a function. If an exception is raised, we send + telemetry containing the stack trace and error message. + + If an error occurs in a decorated function that is called by another decorated + function, we'll only send telemetry corresponding to the lower-level function. + + Args: + func: Function to decorate. + + Returns: + Decorated function. + """ + + @wraps(func) + def wrapper(*args, **kwargs) -> Any: + telemetry_manager = TelemetryManager.get_instance() + + if not telemetry_manager.enabled: + return func(*args, **kwargs) + + try: + return func(*args, **kwargs) + except Exception as exception: + # Only track if we're not already handling an error. This prevents us from + # capturing an error more than once in nested decorated function calls. 
global ERROR_HANDLED # pylint: disable=global-statement + if not ERROR_HANDLED: + ERROR_HANDLED = True + + # Get function module path + module = getmodule(func) + module_path = ( + f"{module.__name__}.{func.__name__}" if module else func.__name__ + ) + + # Get stack trace + stack_trace = "".join( + traceback.format_exception( + type(exception), exception, exception.__traceback__ + ) + ) + stack_trace = sanitize_stack_trace(stack_trace) + + # Send error telemetry + telemetry_manager.send_event( + event_type=f"{module_path}-error", + properties={ + "exception": str(exception), + "stack_trace": stack_trace, + }, + ) + + raise + + return wrapper diff --git a/src/axolotl/telemetry/manager.py b/src/axolotl/telemetry/manager.py new file mode 100644 index 000000000..82d310cdc --- /dev/null +++ b/src/axolotl/telemetry/manager.py @@ -0,0 +1,416 @@ +"""Telemetry manager and associated utilities.""" + +import atexit +import importlib +import logging +import os +import platform +import time +import uuid +from pathlib import Path +from typing import Any + +import posthog +import psutil +import torch +import yaml + +LOG = logging.getLogger(__name__) + +POSTHOG_HOST = "https://app.posthog.com" +POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y" + +OPT_OUT_WARNING_SLEEP_SECONDS = 10 +OPT_OUT_WARNING = ( + "\nTelemetry is now enabled by default to help improve Axolotl. " + "If you'd like to disable it, set AXOLOTL_DO_NOT_TRACK=1 in your environment.\n\n" + "Telemetry data helps us understand:\n" + "- Which features are most used\n" + "- What hardware configurations to prioritize\n" + "- Where users encounter errors\n\n" + "Personally identifiable information (PII) is not collected.\n\n" + "To remove this warning, explicitly set AXOLOTL_DO_NOT_TRACK=0 (enable telemetry) " + "or AXOLOTL_DO_NOT_TRACK=1 (disable telemetry).\n\n" + "For details, see: https://docs.axolotl.ai/docs/telemetry.html\n\n" + f"Sleeping for {OPT_OUT_WARNING_SLEEP_SECONDS}s..." +) + +WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml") + +# NOTE: Need to keep these up to date with any config schema changes +FIELDS_TO_REDACT = { + "base_model", + "tokenizer_config", + "base_model_config", + "pretraining_dataset", # NOTE: this field may be a string or a dictionary + "resume_from_checkpoint", + "hub_model_id", +} +PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"} +PATH_INDICATORS = {"path", "dir"} + +# pylint: disable=duplicate-code +RELEVANT_PACKAGES = { + "torch", + "transformers", + "trl", + "datasets", + "peft", + "bitsandbytes", + "accelerate", + "optimum", + "deepspeed", + "ray", + "axolotl", + "triton", + "mamba-ssm", + "flash-attn", + "xformers", + "autoawq", + "tokenizers", + "sentencepiece", + "torchao", + "lm_eval", +} + + +def is_main_process() -> bool: + """ + Check whether we're running in the main process. + + Note: + We're using this function instead of `torch.utils.distributed.is_main_process`, + since that causes issues with DeepSpeed world_size. This function avoids the + issue by checking env vars that are set by various launchers. + + Returns: + Whether we're running in the main process.
+ """ + # If PyTorch distributed is already initialized, use it + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() == 0 + + # Otherwise check environment variables for global rank + # NOTE: need to verify this in SLURM / OpenMPI environments + global_rank = int( + os.environ.get( + "RANK", + os.environ.get( + "GLOBAL_RANK", + os.environ.get( + "SLURM_PROCID", + os.environ.get( + "OMPI_COMM_WORLD_RANK", + "0", + ), + ), + ), + ) + ) + + return global_rank == 0 + + +class TelemetryManager: + """Manages telemetry collection and transmission""" + + _instance = None + _initialized = False + + def __new__(cls): + """ + Telemetry manager constructor. Creates the singleton instance of this class if + it doesn't already exist. + """ + if cls._instance is None: + cls._instance = super(TelemetryManager, cls).__new__(cls) + cls._instance._initialized = False + + return cls._instance + + def __init__(self): + """Telemetry manager initializer""" + if self._initialized: + return + + self.enabled = self._check_telemetry_enabled() + + if self.enabled: + self.run_id = str(uuid.uuid4()) + self.whitelist = self._load_whitelist() + + try: + self.system_info = self._get_system_info() + except Exception as e: # pylint: disable=broad-exception-caught + LOG.warning(f"Error during system info collection: {e}") + self.system_info = None + + self._init_posthog() + + # Register shutdown method to flush posthog telemetry + atexit.register(self.shutdown) + + self._initialized = True + + @classmethod + def get_instance(cls) -> "TelemetryManager": + if cls._instance is None: + cls._instance = TelemetryManager() + + return cls._instance + + def _check_telemetry_enabled(self) -> bool: + """ + Check if telemetry is enabled based on environment variables. We also check + whether this is the main process (for the distributed setting and to avoid + sending duplicate PostHog events per GPU). + + Note: This is enabled by default on an opt-out basis. Set + `AXOLOTL_DO_NOT_TRACK=1` to disable telemetry. For more details, see + https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html. + + Returns: + Boolean denoting whether telemetry is enabled or not. + """ + # Parse relevant env vars + axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK") + do_not_track = os.getenv("DO_NOT_TRACK") + + # Default to enabled (opt-out model) + if axolotl_do_not_track is None or axolotl_do_not_track.lower() not in ( + "0", + "1", + "false", + "true", + ): + # Print opt-out info message for main process only + if is_main_process(): + LOG.warning(OPT_OUT_WARNING) + time.sleep(OPT_OUT_WARNING_SLEEP_SECONDS) + + return True + + # Only rank 0 will send telemetry + if not is_main_process(): + return False + + if do_not_track is None: + do_not_track = "0" + + # Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled + enabled = axolotl_do_not_track.lower() not in ( + "1", + "true", + ) and do_not_track.lower() not in ("1", "true") + + return enabled + + def _load_whitelist(self) -> dict: + """Load HuggingFace Hub organization whitelist""" + with open(WHITELIST_PATH, encoding="utf-8") as f: + whitelist = yaml.safe_load(f) + + # Send org strings to lowercase since model names are case insensitive + whitelist["organizations"] = { + org.lower() for org in whitelist["organizations"] + } + + return whitelist + + def _is_whitelisted(self, value: str) -> bool: + """ + Check if model / dataset / etc. org is in whitelist. + + Args: + value: Value for one of `axolotl.telemetry.manager.FIELDS_WITH_ORGS` + ("base_model", etc.). 
+ + Returns: + Boolean indicating whitelist membership. + """ + # NOTE: This membership-checking logic can be improved. + # What happens when a local model path matches a whitelisted org? + parts = value.split("/") + if len(parts) < 2: + return False + org = parts[0] + whitelisted = org.lower() in self.whitelist["organizations"] + + return whitelisted + + def _init_posthog(self): + """Initialize PostHog client""" + posthog.api_key = POSTHOG_WRITE_KEY + posthog.project_api_key = POSTHOG_WRITE_KEY + posthog.host = POSTHOG_HOST + + def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]: + """ + Redact properties to remove any paths, so as to avoid inadvertently collecting + private or personally identifiable information (PII). We also remove + information related to Wandb, MLflow, etc. configuration. + + Args: + properties: Dictionary of properties to redact. + + Returns: + Properties dictionary with redaction applied. + """ + if not properties: + return {} + + def redact_value(value: Any, key: str = "") -> Any: + """Recursively sanitize values, redacting those with path-like keys""" + if isinstance(key, str) and isinstance(value, str): + # Other redaction special cases + if ( + key in FIELDS_TO_REDACT + or any(prefix in key for prefix in PREFIXES_TO_REDACT) + or any(indicator in key.lower() for indicator in PATH_INDICATORS) + ): + # Fields with whitelisted orgs don't need to be redacted + if not self._is_whitelisted(value): + return "[REDACTED]" + + # Handle nested values + if isinstance(value, dict): + return {k: redact_value(v, k) for k, v in value.items()} + if isinstance(value, list): + return [redact_value(item) for item in value] + + return value + + # Create new dict with redacted values + redacted = {k: redact_value(v, k) for k, v in properties.items()} + + return redacted + + def _get_system_info(self) -> dict[str, Any]: + """Collect system information for various hardware accelerators""" + gpu_info = [] + accelerator_type = "none" + + # NVIDIA GPUs + if torch.cuda.is_available(): + accelerator_type = "cuda" + for i in range(torch.cuda.device_count()): + gpu_info.append( + { + "name": torch.cuda.get_device_name(i), + "memory": torch.cuda.get_device_properties(i).total_memory, + } + ) + + # AMD GPUs + elif hasattr(torch, "hip") and torch.hip.is_available(): + accelerator_type = "hip" + for i in range(torch.hip.device_count()): + gpu_info.append( + { + "name": torch.hip.get_device_name(i), + "memory": ( + torch.hip.get_device_properties(i).total_memory + if hasattr(torch.hip, "get_device_properties") + else None + ), + } + ) + + # Apple Silicon + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + accelerator_type = "mps" + gpu_info.append( + { + "name": "Apple Silicon", + # NOTE: this is memory allocated to this process, not total memory + "memory": torch.mps.driver_allocated_memory(), + } + ) + + # Intel GPUs + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + accelerator_type = "xpu" + for i in range(torch.xpu.device_count()): + memory = None + if hasattr(torch.xpu, "get_device_properties"): + memory = torch.xpu.get_device_properties(i).total_memory + + gpu_info.append( + { + "name": torch.xpu.get_device_name(i), + "memory": memory, + } + ) + + # NPUs + elif hasattr(torch, "npu") and torch.npu.is_available(): + accelerator_type = "npu" + for i in range(torch.npu.device_count()): + memory = None + if hasattr(torch.npu, "get_device_properties"): + memory = torch.npu.get_device_properties(i).total_memory + + gpu_info.append( + { + 
"name": torch.npu.get_device_name(i), + "memory": memory, + } + ) + + # Get relevant package versions + installed_packages = {} + for package in RELEVANT_PACKAGES: + try: + version = importlib.metadata.version(package) + installed_packages[f"{package}_version"] = version + except importlib.metadata.PackageNotFoundError: + pass + + return { + "os": platform.system(), + "python_version": platform.python_version(), + "cpu_count": psutil.cpu_count(), + "memory_total": psutil.virtual_memory().total, + "accelerator_type": accelerator_type, + "accelerator_count": len(gpu_info), + "accelerator_info": gpu_info, + **installed_packages, + } + + def send_event(self, event_type: str, properties: dict[str, Any] | None = None): + """Send a telemetry event""" + if not self.enabled: + return + + if properties is None: + properties = {} + + # Sanitize properties to remove PII + properties = self._redact_paths(properties) + + # Wrap PostHog errors in try / except to not raise errors during Axolotl usage + try: + # Send event via PostHog + posthog.capture( + distinct_id=self.run_id, + event=event_type, + properties=properties, + disable_geoip=True, + ) + except Exception as e: # pylint: disable=broad-exception-caught + LOG.warning(f"Failed to send telemetry event: {e}") + + # Additionally, send system info telemetry when loading config. + # NOTE: Is this the best place for this? + if event_type == "config-loaded": + self.send_system_info() + + def send_system_info(self): + """Helper method for sending system info""" + if self.system_info is not None: + self.send_event(event_type="system-info", properties=self.system_info) + + def shutdown(self): + """Ensure all queued events are processed before shutdown""" + if self.enabled: + posthog.shutdown() diff --git a/src/axolotl/telemetry/runtime_metrics.py b/src/axolotl/telemetry/runtime_metrics.py new file mode 100644 index 000000000..fa83c00a7 --- /dev/null +++ b/src/axolotl/telemetry/runtime_metrics.py @@ -0,0 +1,210 @@ +"""Telemetry utilities for runtime and memory metrics.""" + +import logging +import time +from dataclasses import dataclass, field +from typing import Any + +import psutil +import torch + +from axolotl.telemetry.manager import TelemetryManager + +LOG = logging.getLogger(__name__) + + +@dataclass +class RuntimeMetrics: + """Container for runtime metrics to be tracked throughout training.""" + + # Timing metrics + start_time: float + epoch_start_times: dict[int, float] = field(init=False) + epoch_end_times: dict[int, float] = field(init=False) + + # Memory metrics + peak_cpu_memory: int = 0 + peak_gpu_memory: dict[int, int] = field(init=False) + + # Progress metrics + total_steps: int = 0 + current_epoch: int = 0 + current_step: int = 0 + + def __post_init__(self): + """Initialize empty metric mappings.""" + self.epoch_start_times = {} + self.epoch_end_times = {} + self.peak_gpu_memory = {} + + @property + def elapsed_time(self) -> float: + """Calculate total elapsed time in seconds.""" + return time.time() - self.start_time + + def epoch_time(self, epoch: int) -> float | None: + """Calculate time taken for a specific epoch in seconds.""" + if epoch in self.epoch_start_times and epoch in self.epoch_end_times: + return self.epoch_end_times[epoch] - self.epoch_start_times[epoch] + + return None + + def average_epoch_time(self) -> float | None: + """Calculate average time per epoch in seconds.""" + completed_epochs = [ + epoch for epoch in self.epoch_start_times if epoch in self.epoch_end_times + ] + if not completed_epochs: + return None + + 
total_time = 0.0 + for epoch in completed_epochs: + epoch_time = self.epoch_time(epoch) + if epoch_time is not None: # Check to avoid mypy warning + total_time += epoch_time + + return total_time / len(completed_epochs) + + def steps_per_second(self) -> float | None: + """Calculate average steps per second across all training.""" + if self.total_steps == 0 or self.elapsed_time == 0: + return None + + return self.total_steps / self.elapsed_time + + def to_dict(self) -> dict[str, Any]: + """Convert metrics to a dictionary for telemetry reporting.""" + metrics = { + "total_time_seconds": self.elapsed_time, + "total_steps": self.total_steps, + "steps_per_second": self.steps_per_second(), + "epochs_completed": len( + [ + epoch + for epoch in self.epoch_start_times + if epoch in self.epoch_end_times + ] + ), + "peak_cpu_memory_bytes": self.peak_cpu_memory, + } + + # Add per-epoch timing if available + epoch_times: dict[str, float] = {} + for epoch in sorted(self.epoch_end_times.keys()): + time_taken = self.epoch_time(epoch) + if time_taken is not None: + epoch_times[f"epoch_{epoch}_seconds"] = time_taken + + if epoch_times: + metrics["epoch_times"] = epoch_times # type: ignore + metrics["average_epoch_time_seconds"] = self.average_epoch_time() + + # Add GPU memory metrics if available + if self.peak_gpu_memory: + gpu_metrics: dict[str, int] = {} + for gpu_id, memory in self.peak_gpu_memory.items(): + gpu_metrics[f"gpu_{gpu_id}_peak_memory_bytes"] = memory + metrics["gpu_memory"] = gpu_metrics # type: ignore + + return metrics + + +class RuntimeMetricsTracker: + """Tracker for runtime metrics during training.""" + + update_interval = 100 + + def __init__(self): + """Initialize the runtime metrics tracker.""" + self.metrics = RuntimeMetrics(start_time=time.time()) + self.telemetry_manager = TelemetryManager.get_instance() + self._process = psutil.Process() + + def start_epoch(self, epoch: int): + """Record the start of a new epoch.""" + self.metrics.current_epoch = epoch + self.metrics.epoch_start_times[epoch] = time.time() + self.update_memory_metrics() + + def end_epoch(self, epoch: int): + """Record the end of an epoch.""" + self.metrics.epoch_end_times[epoch] = time.time() + + def update_step(self, step: int): + """Update the current step count.""" + self.metrics.current_step = step + self.metrics.total_steps += 1 + + # Periodically update memory metrics + if step % self.update_interval == 0: + self.update_memory_metrics() + + def _get_allocated_memory(self) -> dict[int, int]: + """ + Helper function for getting accelerator-agnostic allocated memory. 
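+
+        The accelerator-detection branches below intentionally mirror those in
+        `TelemetryManager._get_system_info`.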
+ + Returns: + A dictionary mapping device IDs to allocated memory in bytes + """ + memory_used: dict[int, int] = {} + + # NVIDIA GPUs + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + memory_used[i] = torch.cuda.memory_allocated(i) + + # AMD GPUs + elif hasattr(torch, "hip") and torch.hip.is_available(): + for i in range(torch.hip.device_count()): + if hasattr(torch.hip, "memory_allocated"): + memory_used[i] = torch.hip.memory_allocated(i) + + # Apple Silicon + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + # MPS doesn't have per-device memory stats since there's only one device + if hasattr(torch.mps, "current_allocated_memory"): + memory_used[0] = torch.mps.current_allocated_memory() + + # Intel GPUs + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + for i in range(torch.xpu.device_count()): + if hasattr(torch.xpu, "memory_allocated"): + memory_used[i] = torch.xpu.memory_allocated(i) + + # NPUs + elif hasattr(torch, "npu") and torch.npu.is_available(): + for i in range(torch.npu.device_count()): + if hasattr(torch.npu, "memory_allocated"): + memory_used[i] = torch.npu.memory_allocated(i) + + return memory_used + + def update_memory_metrics(self): + """Update peak memory usage metrics.""" + # CPU memory + cpu_memory = self._process.memory_info().rss + self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory) + + # GPU memory (if available) + memory_used = self._get_allocated_memory() + for i, memory in memory_used.items(): + self.metrics.peak_gpu_memory[i] = max( + self.metrics.peak_gpu_memory.get(i, 0), memory + ) + + def get_memory_metrics(self) -> dict[str, Any]: + """Get the current memory metrics as a dictionary.""" + memory_metrics = { + "cpu_memory_bytes": self._process.memory_info().rss, + "peak_cpu_memory_bytes": self.metrics.peak_cpu_memory, + } + + # GPU memory (if available) + memory_used = self._get_allocated_memory() + for i, memory in memory_used.items(): + memory_metrics[f"gpu_{i}_memory_bytes"] = memory + memory_metrics[f"gpu_{i}_peak_memory_bytes"] = ( + self.metrics.peak_gpu_memory.get(i, 0) + ) + + return memory_metrics diff --git a/src/axolotl/telemetry/whitelist.yaml b/src/axolotl/telemetry/whitelist.yaml new file mode 100644 index 000000000..6c94d6e79 --- /dev/null +++ b/src/axolotl/telemetry/whitelist.yaml @@ -0,0 +1,33 @@ +organizations: + - "axolotl-ai-co" + - "meta-llama" + - "huggingface" + - "nvidia" + - "facebook" + - "google" + - "microsoft" + - "deepseek-ai" + - "HuggingFaceTB" + - "mistralai" + - "Qwen" + - "unsloth" + - "NousResearch" + - "allenai" + - "amd" + - "tiiuae" + - "tencent" + - "zai-org" + - "openai" + - "ibm-granite" + - "arcee-ai" + - "swiss-ai" + - "CohereForAI" + - "deepcogito" + - "THUDM" + - "ai21labs" + - "LiquidAI" + - "canopylabs" + - "state-spaces" + - "mistral-community" + - "llava-hf" + - "ByteDance-Seed" diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 441c50871..cce3b8a6a 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -31,6 +31,8 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module ) from axolotl.integrations.base import PluginManager from axolotl.loaders import ModelLoader, load_processor, load_tokenizer +from axolotl.telemetry.errors import send_errors +from axolotl.telemetry.manager import TelemetryManager from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import 
cleanup_distributed @@ -45,6 +47,9 @@ if typing.TYPE_CHECKING: LOG = get_logger(__name__) +TELEMETRY_MANAGER = TelemetryManager.get_instance() +PLUGIN_MANAGER = PluginManager.get_instance() + def setup_model_and_tokenizer( cfg: DictDefault, @@ -62,7 +67,10 @@ def setup_model_and_tokenizer( `None`), and processor (if multimodal, else `None`). """ # Load tokenizer - LOG.debug(f"Loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") + LOG.debug( + f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", + main_process_only=True, + ) tokenizer = load_tokenizer(cfg) # Load processor for multimodal models if needed @@ -78,6 +86,14 @@ def setup_model_and_tokenizer( if model.generation_config is not None: model.generation_config.do_sample = True + TELEMETRY_MANAGER.send_event( + event_type="model-load", properties=model.config.to_dict() + ) + if peft_config: + TELEMETRY_MANAGER.send_event( + event_type="peft-config-load", properties=peft_config.to_dict() + ) + # Apply freezing if specified if cfg.unfrozen_parameters: freeze_layers_except(model, cfg.unfrozen_parameters) @@ -196,8 +212,7 @@ def execute_training( LOG.info("Starting trainer...") trainer.train(resume_from_checkpoint=resume_from_checkpoint) - plugin_manager = PluginManager.get_instance() - plugin_manager.post_train(cfg, trainer.model) + PLUGIN_MANAGER.post_train(cfg, trainer.model) def save_trained_model( @@ -521,9 +536,7 @@ def setup_model_and_trainer( model_ref=model_ref, peft_config=peft_config, ) - - plugin_manager = PluginManager.get_instance() - plugin_manager.post_trainer_create(cfg, trainer) + PLUGIN_MANAGER.post_trainer_create(cfg, trainer) if cfg.use_ray: try: @@ -545,6 +558,7 @@ def setup_model_and_trainer( ) +@send_errors def train( cfg: DictDefault, dataset_meta: TrainDatasetMeta ) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]: @@ -595,5 +609,6 @@ def train( create_model_card(cfg, trainer) if not cfg.use_ray: cleanup_distributed() + PLUGIN_MANAGER.post_train(cfg, model) return model, tokenizer, trainer diff --git a/src/axolotl/utils/callbacks/dynamic_checkpoint.py b/src/axolotl/utils/callbacks/dynamic_checkpoint.py new file mode 100644 index 000000000..632109225 --- /dev/null +++ b/src/axolotl/utils/callbacks/dynamic_checkpoint.py @@ -0,0 +1,132 @@ +from pathlib import Path + +from transformers import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + +from axolotl.utils.distributed import ( + barrier, + is_distributed, + is_main_process, +) +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +DEFAULT_TRIGGER_FILENAME = "axolotl_checkpoint.save" + + +class DynamicCheckpointCallback(TrainerCallback): + """ + Callback to save checkpoints on-demand during training via: + 1. File-based trigger (works everywhere, rank 0 checks file) + + Thread-safe for multi-GPU distributed training. 
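+
+    Rank 0 performs the file check and deletion; the result is then broadcast to
+    all ranks (with a barrier) so every process saves at the same step.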
+
+    Usage:
+        # File-based:
+        touch /path/to/output_dir/axolotl_checkpoint.save
+    """
+
+    def _get_config_value(self, config, key, default=None):
+        """Helper to get config value from dict or object."""
+        if isinstance(config, dict):
+            return config.get(key, default)
+        return getattr(config, key, default)
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+        if not cfg.dynamic_checkpoint or not cfg.dynamic_checkpoint.enabled:
+            self.enabled = False
+            return
+
+        self.enabled = True
+        dc_config = cfg.dynamic_checkpoint
+
+        trigger_file_path = self._get_config_value(dc_config, "trigger_file_path")
+        self.trigger_filename = (
+            trigger_file_path if trigger_file_path else DEFAULT_TRIGGER_FILENAME
+        )
+
+        check_interval = self._get_config_value(dc_config, "check_interval")
+        self.check_interval = check_interval if check_interval is not None else 100
+        self.should_save_checkpoint = False
+
+        LOG.info(
+            f"Dynamic checkpoint enabled. To trigger checkpoint save:\n"
+            f"  • File: touch {cfg.output_dir}/{self.trigger_filename}\n"
+            f"  • Check interval: every {self.check_interval} steps",
+            main_process_only=True,
+        )
+
+    def on_step_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **_kwargs,
+    ) -> TrainerControl:
+        """
+        Check for checkpoint triggers at the end of each step.
+        ONLY rank 0 checks the file, then all ranks synchronize.
+        """
+        if not self.enabled:
+            return control
+
+        trigger_detected = False
+
+        if state.global_step % self.check_interval == 0:
+            if is_main_process():
+                trigger_path = Path(args.output_dir) / self.trigger_filename
+
+                if trigger_path.exists():
+                    trigger_detected = True
+                    try:
+                        trigger_path.unlink()  # Delete the trigger file
+                        LOG.info(
+                            f"Dynamic checkpoint triggered via file '{self.trigger_filename}' "
+                            f"at step {state.global_step}",
+                            main_process_only=True,
+                        )
+                    except OSError as exc:
+                        LOG.warning(
+                            f"Failed to delete trigger file: {exc}",
+                            main_process_only=True,
+                        )
+
+            if self.should_save_checkpoint:
+                trigger_detected = True
+                self.should_save_checkpoint = False  # Reset flag
+
+            if is_distributed():
+                import torch
+                import torch.distributed as dist
+
+                device = getattr(
+                    args,
+                    "device",
+                    torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+                )
+
+                trigger_tensor = torch.tensor(
+                    1 if trigger_detected else 0,
+                    dtype=torch.long,
+                    device=device,
+                )
+
+                dist.broadcast(trigger_tensor, src=0)
+
+                trigger_detected = bool(trigger_tensor.item())
+
+                barrier()
+
+        if trigger_detected:
+            control.should_save = True
+            LOG.info(
+                f"Saving dynamic checkpoint at step {state.global_step}",
+                main_process_only=True,
+            )
+        return control
diff --git a/src/axolotl/utils/mistral/mistral3_processor.py b/src/axolotl/utils/mistral/mistral3_processor.py
index 85479ca7b..01e8f9f10 100644
--- a/src/axolotl/utils/mistral/mistral3_processor.py
+++ b/src/axolotl/utils/mistral/mistral3_processor.py
@@ -30,6 +30,7 @@ class Mistral3Processor(ProcessorMixin):
     Wraps HFMistralTokenizer and adds image processing capabilities.
""" + # TODO(nano): This should be removed in transformers V5 attributes = ["tokenizer"] tokenizer_class = "HFMistralTokenizer" diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 86b3aa17b..c9b087ea3 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -23,6 +23,7 @@ from axolotl.utils.schemas.datasets import ( StepwiseSupervisedDataset, ) from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters +from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType from axolotl.utils.schemas.fsdp import FSDPConfig from axolotl.utils.schemas.integrations import ( @@ -141,6 +142,13 @@ class AxolotlInputConfig( default=None, json_schema_extra={"description": "Reward modelling: `True` or `False`"}, ) + dynamic_checkpoint: DynamicCheckpointConfig | None = Field( + default=None, + json_schema_extra={ + "description": "Configuration for dynamic checkpointing (trigger by file or signal). " + "Set 'enabled: true' to activate this feature." + }, + ) process_reward_model: bool | None = Field( default=None, json_schema_extra={ @@ -1061,7 +1069,7 @@ class AxolotlInputConfig( class AxolotlConfigWCapabilities(AxolotlInputConfig): - """wrapper to valdiate GPU capabilities with the configured options""" + """Wrapper to valdiate GPU capabilities with the configured options""" capabilities: GPUCapabilities env_capabilities: EnvCapabilities diff --git a/src/axolotl/utils/schemas/dynamic_checkpoint.py b/src/axolotl/utils/schemas/dynamic_checkpoint.py new file mode 100644 index 000000000..e0e1d0c1d --- /dev/null +++ b/src/axolotl/utils/schemas/dynamic_checkpoint.py @@ -0,0 +1,31 @@ +"""Schema for dynamic checkpoint configuration.""" + +from pydantic import BaseModel, Field + + +class DynamicCheckpointConfig(BaseModel): + """Configuration for dynamic checkpoint triggering during training.""" + + enabled: bool = Field( + default=False, + json_schema_extra={ + "description": "Enable dynamic checkpoint triggering during training. " + "Create a file 'axolotl_checkpoint.save' in the configured `output_dir` to trigger. " + }, + ) + check_interval: int = Field( + default=10, + ge=1, + json_schema_extra={ + "description": "Check for trigger file every N steps (reduces I/O overhead). " + "Default: 100" + }, + ) + trigger_file_path: str = Field( + default="", + json_schema_extra={ + "description": "Custom trigger filename (optional). " + "If not specified, defaults to 'axolotl_checkpoint.save'. " + "Specify a filename (not a full path) to override the default." 
+        },
+    )
diff --git a/tests/conftest.py b/tests/conftest.py
index 98847ebad..d3b9407ec 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,4 @@
-"""
-shared pytest fixtures
-"""
+"""Shared pytest fixtures"""
 
 import functools
 import importlib
@@ -582,3 +580,9 @@ def test_load_fixtures(
     download_llama2_model_fixture,
 ):
     pass
+
+
+@pytest.fixture(autouse=True)
+def disable_telemetry(monkeypatch):
+    monkeypatch.setenv("AXOLOTL_DO_NOT_TRACK", "1")
+    yield
diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py
index 55317151e..e50483e6c 100644
--- a/tests/e2e/integrations/test_liger.py
+++ b/tests/e2e/integrations/test_liger.py
@@ -3,6 +3,7 @@ Simple end-to-end test for Liger integration
 """
 
 import pytest
+
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
diff --git a/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py b/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py
index 7de21b940..5866cc367 100644
--- a/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py
+++ b/tests/prompt_strategies/test_chat_templates_tool_call_string_arguments.py
@@ -69,7 +69,7 @@ class TestQwen3IdenticalConversationArgs:
                         {
                             "function": {
                                 "name": function_name,
-                                "arguments": arguments_dict,  # dict格式
+                                "arguments": arguments_dict,  # dict
                             }
                         }
                     ],
@@ -100,7 +100,7 @@ class TestQwen3IdenticalConversationArgs:
                         {
                             "function": {
                                 "name": function_name,
-                                "arguments": arguments_str,  # str格式
+                                "arguments": arguments_str,  # str
                             }
                         }
                     ],
@@ -212,3 +212,294 @@ class TestQwen3IdenticalConversationArgs:
         decoded = qwen3_tokenizer.decode(processed[0]["input_ids"])
         assert "2025-08-01" in decoded, "String time value should be present"
         assert "1690876800" in decoded, "Number time value should be present"
+
+
+class TestQwen3IdenticalToolsParameters:
+    """
+    Test Qwen3 tools parameters handling is identical between JSON string and dict
+    """
+
+    @pytest.fixture(name="tools_dict_params_dataset")
+    def fixture_tools_dict_params_dataset(self):
+        """
+        Provides a dataset with tools where parameters is a dict.
+        """
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "description": "Get weather information",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state",
+                            },
+                            "unit": {
+                                "type": "string",
+                                "enum": ["celsius", "fahrenheit"],
+                            },
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        ]
+
+        data = [
+            {
+                "tools": tools,
+                "messages": [
+                    {"role": "user", "content": "What's the weather?"},
+                    {
+                        "role": "assistant",
+                        "content": "",
+                        "tool_calls": [
+                            {
+                                "type": "function",
+                                "function": {
+                                    "name": "get_weather",
+                                    "arguments": {"location": "Boston, MA"},
+                                },
+                            }
+                        ],
+                    },
+                    {
+                        "role": "tool",
+                        "name": "get_weather",
+                        "content": "72°F and sunny",
+                    },
+                ],
+            }
+        ]
+        return Dataset.from_list(data)
+
+    @pytest.fixture(name="tools_str_params_dataset")
+    def fixture_tools_str_params_dataset(self):
+        """
+        Provides a dataset with tools where parameters is a JSON string.
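+
+        Semantically identical to `tools_dict_params_dataset`; only the encoding
+        of `parameters` differs.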
+ """ + parameters_dict = { + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + } + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information", + "parameters": json.dumps(parameters_dict), + }, + } + ] + + data = [ + { + "tools": tools, + "messages": [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "type": "function", + "function": { + "name": "get_weather", + "arguments": {"location": "Boston, MA"}, + }, + } + ], + }, + { + "role": "tool", + "name": "get_weather", + "content": "72ยฐF and sunny", + }, + ], + } + ] + return Dataset.from_list(data) + + @pytest.fixture(name="tools_mixed_type_params_dataset") + def fixture_tools_mixed_type_params_dataset(self): + """ + Provides a dataset where different tools have the same parameter name with different types. + This tests that JSON string format prevents casting issues. + """ + tools = [ + { + "type": "function", + "function": { + "name": "tool_with_string_arg", + "description": "Tool expecting string argument", + "parameters": json.dumps( + { + "type": "object", + "properties": { + "arg1": { + "type": "string", + "description": "A string parameter", + } + }, + "required": ["arg1"], + } + ), + }, + }, + { + "type": "function", + "function": { + "name": "tool_with_number_arg", + "description": "Tool expecting number argument", + "parameters": json.dumps( + { + "type": "object", + "properties": { + "arg1": { + "type": "number", + "description": "A numeric parameter", + } + }, + "required": ["arg1"], + } + ), + }, + }, + ] + + data = [ + { + "tools": tools, + "messages": [ + {"role": "user", "content": "Use both tools"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "type": "function", + "function": { + "name": "tool_with_string_arg", + "arguments": json.dumps({"arg1": "hello"}), + }, + }, + { + "type": "function", + "function": { + "name": "tool_with_number_arg", + "arguments": json.dumps({"arg1": 42}), + }, + }, + ], + }, + ], + } + ] + return Dataset.from_list(data) + + def test_dict_and_str_params_produce_equivalent_output( + self, + tools_dict_params_dataset, + tools_str_params_dataset, + qwen3_instruct_prompt_strategy, + qwen3_tokenizer, + ): + """ + Tests that after tokenization and decoding, the outputs for both + dict and string `parameters` in tools are semantically equivalent. 
+ """ + import re + + processed_dict_params = tools_dict_params_dataset.map( + qwen3_instruct_prompt_strategy.tokenize_prompt, + batched=True, + remove_columns=["messages", "tools"], + ) + + processed_str_params = tools_str_params_dataset.map( + qwen3_instruct_prompt_strategy.tokenize_prompt, + batched=True, + remove_columns=["messages", "tools"], + ) + + decoded_dict = qwen3_tokenizer.decode(processed_dict_params[0]["input_ids"]) + decoded_str = qwen3_tokenizer.decode(processed_str_params[0]["input_ids"]) + + # Extract the tool JSON from both outputs + tools_pattern = r"\n(.*?)\n" + + dict_tools_match = re.search(tools_pattern, decoded_dict, re.DOTALL) + str_tools_match = re.search(tools_pattern, decoded_str, re.DOTALL) + + assert dict_tools_match and str_tools_match, ( + "Could not find tools section in output" + ) + + # Parse the JSON and compare as objects (order-independent) + dict_tools_json = json.loads(dict_tools_match.group(1)) + str_tools_json = json.loads(str_tools_match.group(1)) + + # Deep comparison of the tool definitions + assert dict_tools_json == str_tools_json, ( + f"Tool definitions are not equivalent:\n" + f"Dict format: {json.dumps(dict_tools_json, indent=2)}\n" + f"String format: {json.dumps(str_tools_json, indent=2)}" + ) + + # Verify the rest of the structure is the same (excluding the tools JSON part) + # The tools JSON can have different order, so we remove it here. + dict_normalized = re.sub( + r".*?", + "TOOLS_PLACEHOLDER", + decoded_dict, + flags=re.DOTALL, + ) + str_normalized = re.sub( + r".*?", + "TOOLS_PLACEHOLDER", + decoded_str, + flags=re.DOTALL, + ) + + assert dict_normalized == str_normalized, ( + "The overall structure differs between dict and string parameter formats" + ) + + def test_str_params_with_mixed_types_no_error( + self, + tools_mixed_type_params_dataset, + qwen3_instruct_prompt_strategy, + qwen3_tokenizer, + ): + """ + Tests that when different tools have the same parameter name with different types, + JSON string format for parameters doesn't cause casting errors. 
+ """ + processed = tools_mixed_type_params_dataset.map( + qwen3_instruct_prompt_strategy.tokenize_prompt, + batched=True, + remove_columns=["messages", "tools"], + ) + + assert len(processed) == 1 + assert "input_ids" in processed[0] + assert len(processed[0]["input_ids"]) > 0 + + decoded = qwen3_tokenizer.decode(processed[0]["input_ids"]) + + # Check that both tools are present + assert "tool_with_string_arg" in decoded + assert "tool_with_number_arg" in decoded + + # Check that both argument values are present + assert "hello" in decoded + assert "42" in decoded diff --git a/tests/telemetry/__init__.py b/tests/telemetry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/telemetry/conftest.py b/tests/telemetry/conftest.py new file mode 100644 index 000000000..47776ce90 --- /dev/null +++ b/tests/telemetry/conftest.py @@ -0,0 +1,9 @@ +"""Shared pytest fixtures for telemetry tests.""" + +import pytest + + +@pytest.fixture(autouse=True) +def del_track_env(monkeypatch): + monkeypatch.delenv("AXOLOTL_DO_NOT_TRACK", raising=False) + yield diff --git a/tests/telemetry/test_callbacks.py b/tests/telemetry/test_callbacks.py new file mode 100644 index 000000000..97d56a9c6 --- /dev/null +++ b/tests/telemetry/test_callbacks.py @@ -0,0 +1,373 @@ +"""Tests for telemetry callback module.""" + +# pylint: disable=redefined-outer-name + +import time +from unittest.mock import MagicMock, patch + +import pytest +from transformers import TrainerControl, TrainerState, TrainingArguments + +from axolotl.telemetry.callbacks import TIME_SINCE_LAST, TelemetryCallback + + +def calc_expected_metrics(step, last_step, current_time, last_time, start_time=900.0): + """Calculate expected metrics values for tests""" + time_diff = current_time - last_time + step_diff = step - last_step + return { + "steps_per_second": ( + step_diff / time_diff if time_diff > 0 and step_diff > 0 else 0 + ), + "time_since_last_report": time_diff, + "elapsed_time": current_time - start_time, + } + + +@pytest.fixture +def mock_time(): + """Mock time.time() to have predictable values in tests""" + with patch("axolotl.telemetry.callbacks.time") as mock_time: + mock_time.time.return_value = 1000.0 + yield mock_time + + +@pytest.fixture +def mock_telemetry_manager(): + """Create a mock TelemetryManager""" + with patch("axolotl.telemetry.callbacks.TelemetryManager") as mock_manager_class: + mock_manager = MagicMock() + mock_manager_class.get_instance.return_value = mock_manager + yield mock_manager + + +@pytest.fixture +def mock_runtime_metrics_tracker(): + """Create a mock RuntimeMetricsTracker""" + with patch( + "axolotl.telemetry.callbacks.RuntimeMetricsTracker" + ) as mock_tracker_class: + mock_tracker = MagicMock() + # Set up metrics property on the tracker + mock_metrics = MagicMock() + mock_metrics.to_dict.return_value = { + "total_steps": 100, + "peak_cpu_memory_bytes": 1024, + } + mock_tracker.metrics = mock_metrics + + # Make the constructor return our mock + mock_tracker_class.return_value = mock_tracker + yield mock_tracker + + +@pytest.fixture +def training_args(): + """Create a minimal TrainingArguments instance""" + return TrainingArguments(output_dir="./output") + + +@pytest.fixture +def trainer_state(): + """Create a mock TrainerState""" + state = MagicMock(spec=TrainerState) + state.global_step = 10 + state.epoch = 0.5 # halfway through first epoch + state.log_history = [{"loss": 2.5, "learning_rate": 5e-5}] + return state + + +@pytest.fixture +def trainer_control(): + """Create a mock TrainerControl""" + 
return MagicMock(spec=TrainerControl) + + +# pylint: disable=unused-argument +@pytest.fixture +def callback(mock_telemetry_manager, mock_runtime_metrics_tracker): + """Create a TelemetryCallback instance with mocked dependencies""" + return TelemetryCallback() + + +class TestTelemetryCallback: + """Tests for the TelemetryCallback class.""" + + def test_initialization(self, callback, mock_runtime_metrics_tracker): + """Test callback initialization.""" + assert callback.current_epoch == -1 + assert callback.tracker == mock_runtime_metrics_tracker + assert callback.last_report_step == 0 + assert hasattr(callback, "start_time") + assert hasattr(callback, "last_report_time") + assert callback.report_interval_steps == 100 + + def test_on_train_begin( + self, + callback, + mock_telemetry_manager, + training_args, + trainer_state, + trainer_control, + ): + """Test on_train_begin sends expected event.""" + callback.on_train_begin(training_args, trainer_state, trainer_control) + + mock_telemetry_manager.send_event.assert_called_once_with( + event_type="train-start" + ) + + def test_on_train_end( + self, + callback, + mock_telemetry_manager, + training_args, + trainer_state, + trainer_control, + ): + """Test on_train_end sends expected event with metrics.""" + callback.on_train_end(training_args, trainer_state, trainer_control) + + mock_telemetry_manager.send_event.assert_called_once() + call_args = mock_telemetry_manager.send_event.call_args[1] + + assert call_args["event_type"] == "train-end" + assert "loss" in call_args["properties"] + assert call_args["properties"]["loss"] == 2.5 + assert "learning_rate" in call_args["properties"] + assert call_args["properties"]["learning_rate"] == 5e-5 + + # Check that metrics from RuntimeMetricsTracker are included + assert "total_steps" in call_args["properties"] + assert call_args["properties"]["total_steps"] == 100 + assert "peak_cpu_memory_bytes" in call_args["properties"] + assert call_args["properties"]["peak_cpu_memory_bytes"] == 1024 + + def test_on_epoch_begin( + self, + callback, + mock_runtime_metrics_tracker, + training_args, + trainer_state, + trainer_control, + ): + """Test on_epoch_begin updates epoch counter and calls tracker.""" + initial_epoch = callback.current_epoch + + callback.on_epoch_begin(training_args, trainer_state, trainer_control) + + assert callback.current_epoch == initial_epoch + 1 + mock_runtime_metrics_tracker.start_epoch.assert_called_once_with( + initial_epoch + 1 + ) + + def test_on_epoch_end( + self, + callback, + mock_runtime_metrics_tracker, + training_args, + trainer_state, + trainer_control, + ): + """Test on_epoch_end calls tracker.""" + # Set current epoch + callback.current_epoch = 2 + + callback.on_epoch_end(training_args, trainer_state, trainer_control) + + mock_runtime_metrics_tracker.end_epoch.assert_called_once_with(2) + + def test_on_step_end_no_report( + self, + callback, + mock_telemetry_manager, + mock_runtime_metrics_tracker, + training_args, + trainer_state, + trainer_control, + ): + """Test on_step_end updates tracker but doesn't report if criteria not met.""" + # Set up state to avoid reporting + trainer_state.global_step = 42 # Not divisible by report_interval_steps + callback.last_report_step = 41 # Just 1 step since last report + callback.last_report_time = time.time() # Just now + + callback.on_step_end(training_args, trainer_state, trainer_control) + + # Should update tracker + mock_runtime_metrics_tracker.update_step.assert_called_once_with(42) + + # Should not send telemetry + 
mock_telemetry_manager.send_event.assert_not_called() + + # Should not update last report time/step + assert callback.last_report_step == 41 + + def test_on_step_end_report_interval_steps( + self, + callback, + mock_telemetry_manager, + mock_runtime_metrics_tracker, + mock_time, + training_args, + trainer_state, + trainer_control, + ): + """Test on_step_end reports when step interval is reached.""" + # Set up state with clear values + current_step = 100 # Exactly matches report_interval_steps + last_step = 0 + start_time = 900.0 + current_time = 1000.0 + time_diff = current_time - start_time # 100 seconds + + # Configure state and callback + trainer_state.global_step = current_step + callback.report_interval_steps = 100 + callback.last_report_step = last_step + callback.start_time = start_time + callback.last_report_time = start_time + + # Mock time.time() to return consistent values + mock_time.time.return_value = current_time + + callback.on_step_end(training_args, trainer_state, trainer_control) + + # Should update tracker + mock_runtime_metrics_tracker.update_step.assert_called_once_with(current_step) + mock_runtime_metrics_tracker.update_memory_metrics.assert_called_once() + + # Should send telemetry + mock_telemetry_manager.send_event.assert_called_once() + call_args = mock_telemetry_manager.send_event.call_args[1] + assert call_args["event_type"] == "train-progress" + + # Properties should include expected values + props = call_args["properties"] + assert props["step"] == current_step + assert props["elapsed_time"] == time_diff # 1000 - 900 = 100 + assert props["time_since_last_report"] == time_diff # 1000 - 900 = 100 + assert props["steps_per_second"] == 1.0 # 100 steps / 100 seconds + + # Should update last report time/step + assert callback.last_report_step == current_step + assert callback.last_report_time == current_time + + def test_on_step_end_report_time_elapsed( + self, + callback, + mock_telemetry_manager, + mock_runtime_metrics_tracker, # pylint: disable=unused-argument + mock_time, + training_args, + trainer_state, + trainer_control, + ): + """Test on_step_end reports when enough time has elapsed.""" + # Set up state with clear values + current_step = 120 + last_step = 10 + start_time = 900.0 + current_time = 1000.0 + time_diff = TIME_SINCE_LAST + 1 # Just over the threshold + + # Configure state and callback + trainer_state.global_step = current_step + callback.report_interval_steps = 100 + callback.last_report_step = last_step + callback.start_time = start_time + callback.last_report_time = current_time - time_diff + + # Mock time.time() to return consistent values + mock_time.time.return_value = current_time + + callback.on_step_end(training_args, trainer_state, trainer_control) + + # Should send telemetry + mock_telemetry_manager.send_event.assert_called_once() + + # Properties should include expected values + props = mock_telemetry_manager.send_event.call_args[1]["properties"] + expected_metrics = calc_expected_metrics( + current_step, last_step, current_time, current_time - time_diff, start_time + ) + assert props["steps_per_second"] == expected_metrics["steps_per_second"] + assert ( + props["time_since_last_report"] + == expected_metrics["time_since_last_report"] + ) + + def test_on_step_end_first_step( + self, + callback, + mock_telemetry_manager, + mock_runtime_metrics_tracker, # pylint: disable=unused-argument + mock_time, + training_args, + trainer_state, + trainer_control, + ): + """Test on_step_end always reports on first step.""" + # Set up state with clear 
values + current_step = 1 # First step + last_step = 0 + start_time = 900.0 + current_time = 1000.0 + last_report_time = 999.0 # Just 1 second ago + + # Configure state and callback + trainer_state.global_step = current_step + callback.report_interval_steps = 100 + callback.last_report_step = last_step + callback.start_time = start_time + callback.last_report_time = last_report_time + + # Mock time.time() to return consistent values + mock_time.time.return_value = current_time + + callback.on_step_end(training_args, trainer_state, trainer_control) + + # Should send telemetry even though not much time has passed + mock_telemetry_manager.send_event.assert_called_once() + + # Properties should include expected values for first step + props = mock_telemetry_manager.send_event.call_args[1]["properties"] + assert props["step"] == current_step + expected_metrics = calc_expected_metrics( + current_step, last_step, current_time, last_report_time, start_time + ) + assert props["steps_per_second"] == expected_metrics["steps_per_second"] + + def test_log_history_empty( + self, + callback, + mock_telemetry_manager, + mock_runtime_metrics_tracker, # pylint: disable=unused-argument + mock_time, + training_args, + trainer_state, + trainer_control, + ): + """Test handling of empty log history.""" + # Set up state with clear values + current_step = 1 + start_time = 900.0 + current_time = 1000.0 + + # Configure state and callback + trainer_state.global_step = current_step + trainer_state.log_history = [] + callback.start_time = start_time + + # Mock time.time() to return consistent values + mock_time.time.return_value = current_time + + callback.on_step_end(training_args, trainer_state, trainer_control) + + # Should still send telemetry + mock_telemetry_manager.send_event.assert_called_once() + + # Properties should have default values for missing log data + props = mock_telemetry_manager.send_event.call_args[1]["properties"] + assert props["loss"] == 0 + assert props["learning_rate"] == 0 diff --git a/tests/telemetry/test_errors.py b/tests/telemetry/test_errors.py new file mode 100644 index 000000000..2f0510b21 --- /dev/null +++ b/tests/telemetry/test_errors.py @@ -0,0 +1,341 @@ +"""Tests for telemetry error utilities""" + +# pylint: disable=redefined-outer-name + +from unittest.mock import MagicMock, patch + +import pytest + +from axolotl.telemetry.errors import sanitize_stack_trace, send_errors + + +@pytest.fixture(autouse=True) +def reset_error_flag(monkeypatch): + """Reset ERROR_HANDLED flag using monkeypatch""" + import axolotl.telemetry.errors + + monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False) + yield + monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False) + + +@pytest.fixture +def example_stack_trace(): + """Provide a sample stack trace with mixed paths""" + return """Traceback (most recent call last): + File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main + trainer = get_trainer(cfg) + File "/home/user/.local/lib/python3.9/site-packages/axolotl/train.py", line 214, in get_trainer + model = get_model(cfg, tokenizer) + File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/models.py", line 120, in get_model + raise ValueError("Model path not found") +ValueError: Model path not found +""" + + +@pytest.fixture +def windows_stack_trace(): + """Provide a sample stack trace with Windows paths""" + return """Traceback (most recent call last): + File 
"C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\cli\\train.py", line 83, in main + trainer = get_trainer(cfg) + File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\train.py", line 214, in get_trainer + model = get_model(cfg, tokenizer) + File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\auto\\modeling_auto.py", line 482, in from_pretrained + raise ValueError(f"Unrecognized configuration class {config.__class__}") +ValueError: Unrecognized configuration class +""" + + +@pytest.fixture +def mixed_stack_trace(): + """Provide a sample stack trace with both axolotl and non-axolotl paths""" + return """Traceback (most recent call last): + File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main + trainer = get_trainer(cfg) + File "/home/user/.local/lib/python3.9/site-packages/transformers/trainer.py", line 520, in train + self._inner_training_loop() + File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/trainer.py", line 75, in _inner_training_loop + super()._inner_training_loop() + File "/home/user/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__ + data = self._next_data() +RuntimeError: CUDA out of memory +""" + + +@pytest.fixture +def venv_stack_trace(): + """Provide a sample stack trace with virtual environment paths""" + return """Traceback (most recent call last): + File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 1729, in train + self._inner_training_loop() + File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 2013, in _inner_training_loop + self.accelerator.backward(loss) + File "/home/user/venv/lib/python3.9/site-packages/accelerate/accelerator.py", line 1851, in backward + self.scaler.scale(loss).backward(**kwargs) + File "/home/user/venv/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward + torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs) +RuntimeError: CUDA out of memory +""" + + +@pytest.fixture +def dist_packages_stack_trace(): + """Provide a sample stack trace with dist-packages paths""" + return """Traceback (most recent call last): + File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__ + data = self._next_data() + File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 675, in _next_data + data = self._dataset_fetcher.fetch(index) + File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch + data = [self.dataset[idx] for idx in possibly_batched_index] + File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2808, in __getitem__ + raise IndexError(f"Index {key} out of range for dataset of length {len(self)}.") +IndexError: Index 10000 out of range for dataset of length 9832. 
+""" + + +@pytest.fixture +def project_stack_trace(): + """Provide a sample stack trace from a project directory (not a virtual env)""" + return """Traceback (most recent call last): + File "/home/user/projects/myproject/run.py", line 25, in + main() + File "/home/user/projects/myproject/src/cli.py", line 45, in main + app.run() + File "/home/user/projects/myproject/src/app.py", line 102, in run + raise ValueError("Configuration missing") +ValueError: Configuration missing +""" + + +def test_sanitize_stack_trace(example_stack_trace): + """Test that sanitize_stack_trace properly preserves axolotl paths""" + sanitized = sanitize_stack_trace(example_stack_trace) + + # Check that personal paths are removed + assert "/home/user" not in sanitized + assert ".local/lib/python3.9" not in sanitized + + # Check that site-packages is preserved + assert "site-packages/axolotl/cli/train.py" in sanitized + assert "site-packages/axolotl/train.py" in sanitized + assert "site-packages/axolotl/utils/models.py" in sanitized + + # Check that error message is preserved + assert "ValueError: Model path not found" in sanitized + + +def test_sanitize_windows_paths(windows_stack_trace): + """Test that sanitize_stack_trace handles Windows paths""" + sanitized = sanitize_stack_trace(windows_stack_trace) + + # Check that personal paths are removed + assert "C:\\Users\\name" not in sanitized + assert "AppData\\Local\\Programs\\Python" not in sanitized + + # Check that both axolotl and transformers packages are preserved + assert ( + "site-packages\\axolotl\\cli\\train.py" in sanitized + or "site-packages/axolotl/cli/train.py" in sanitized + ) + assert ( + "site-packages\\axolotl\\train.py" in sanitized + or "site-packages/axolotl/train.py" in sanitized + ) + assert ( + "site-packages\\transformers\\models\\auto\\modeling_auto.py" in sanitized + or "site-packages/transformers/models/auto/modeling_auto.py" in sanitized + ) + + # Check that error message is preserved + assert "ValueError: Unrecognized configuration class" in sanitized + + +def test_sanitize_mixed_paths(mixed_stack_trace): + """Test that sanitize_stack_trace preserves all package paths""" + sanitized = sanitize_stack_trace(mixed_stack_trace) + + # Check that all package paths are preserved + assert "site-packages/axolotl/cli/train.py" in sanitized + assert "site-packages/transformers/trainer.py" in sanitized + assert "site-packages/axolotl/utils/trainer.py" in sanitized + assert "site-packages/torch/utils/data/dataloader.py" in sanitized + + # Check that error message is preserved + assert "RuntimeError: CUDA out of memory" in sanitized + + +def test_sanitize_venv_paths(venv_stack_trace): + """Test that sanitize_stack_trace preserves virtual environment package paths""" + sanitized = sanitize_stack_trace(venv_stack_trace) + + # Check that personal paths are removed + assert "/home/user/venv" not in sanitized + + # Check that all package paths are preserved + assert "site-packages/transformers/trainer.py" in sanitized + assert "site-packages/accelerate/accelerator.py" in sanitized + assert "site-packages/torch/_tensor.py" in sanitized + + # Check that error message is preserved + assert "RuntimeError: CUDA out of memory" in sanitized + + +def test_sanitize_dist_packages(dist_packages_stack_trace): + """Test that sanitize_stack_trace preserves dist-packages paths""" + sanitized = sanitize_stack_trace(dist_packages_stack_trace) + + # Check that system paths are removed + assert "/usr/local/lib/python3.8" not in sanitized + + # Check that all package paths are 
preserved + assert "dist-packages/torch/utils/data/dataloader.py" in sanitized + assert "dist-packages/torch/utils/data/_utils/fetch.py" in sanitized + assert "dist-packages/datasets/arrow_dataset.py" in sanitized + + # Check that error message is preserved + assert ( + "IndexError: Index 10000 out of range for dataset of length 9832." in sanitized + ) + + +def test_sanitize_project_paths(project_stack_trace): + """Test handling of project paths (non-virtual env)""" + sanitized = sanitize_stack_trace(project_stack_trace) + + # Check that personal paths are removed + assert "/home/user/projects" not in sanitized + + # For non-package paths, we should at least preserve the filename + assert "run.py" in sanitized + assert "cli.py" in sanitized + assert "app.py" in sanitized + + # Check that error message is preserved + assert "ValueError: Configuration missing" in sanitized + + +@pytest.fixture +def mock_telemetry_manager(): + """Create a mock TelemetryManager""" + with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class: + mock_manager = MagicMock() + mock_manager.enabled = True + mock_manager_class.get_instance.return_value = mock_manager + yield mock_manager + + +def test_send_errors_successful_execution(mock_telemetry_manager): + """Test that send_errors doesn't send telemetry for successful function execution""" + + @send_errors + def test_func(): + return "success" + + result = test_func() + assert result == "success" + mock_telemetry_manager.send_event.assert_not_called() + + +def test_send_errors_with_exception(mock_telemetry_manager): + """Test that send_errors sends telemetry when an exception occurs""" + test_error = ValueError("Test error") + + @send_errors + def test_func(): + raise test_error + + with pytest.raises(ValueError) as excinfo: + test_func() + + assert excinfo.value == test_error + mock_telemetry_manager.send_event.assert_called_once() + + # Check that the error info was passed correctly + call_args = mock_telemetry_manager.send_event.call_args[1] + assert "test_func-error" in call_args["event_type"] + assert "Test error" in call_args["properties"]["exception"] + assert "stack_trace" in call_args["properties"] + + +def test_send_errors_nested_calls(mock_telemetry_manager): + """Test that send_errors only sends telemetry once for nested decorated functions""" + + @send_errors + def inner_func(): + raise ValueError("Inner error") + + @send_errors + def outer_func(): + return inner_func() + + with pytest.raises(ValueError): + outer_func() + + # Telemetry should be sent only once for the inner function + assert mock_telemetry_manager.send_event.call_count == 1 + call_args = mock_telemetry_manager.send_event.call_args[1] + assert "inner_func-error" in call_args["event_type"] + + +def test_send_errors_telemetry_disable(): + """Test that send_errors doesn't attempt to send telemetry when disabled""" + + with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class: + mock_manager = MagicMock() + mock_manager.enabled = False + mock_manager_class.get_instance.return_value = mock_manager + + @send_errors + def test_func(): + raise ValueError("Test error") + + with pytest.raises(ValueError): + test_func() + + mock_manager.send_event.assert_not_called() + + +def test_error_handled_reset(): + """Test that ERROR_HANDLED flag is properly reset""" + with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class: + # Create and configure the mock manager + mock_manager = MagicMock() + mock_manager.enabled = True + 
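+        # enabled=True drives send_errors down its telemetry-reporting path.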
mock_manager_class.get_instance.return_value = mock_manager + + from axolotl.telemetry.errors import ERROR_HANDLED + + @send_errors + def test_func(): + raise ValueError("Test error") + + assert not ERROR_HANDLED + + with pytest.raises(ValueError): + test_func() + + from axolotl.telemetry.errors import ERROR_HANDLED + + assert ERROR_HANDLED + + +def test_module_path_resolution(mock_telemetry_manager): + """Test that the module path is correctly resolved for the event type""" + import inspect + + current_module = inspect.getmodule(test_module_path_resolution).__name__ + + @send_errors + def test_func(): + raise ValueError("Test error") + + with pytest.raises(ValueError): + test_func() + + assert mock_telemetry_manager.send_event.called + event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"] + + expected_event_type = f"{current_module}.test_func-error" + assert expected_event_type == event_type diff --git a/tests/telemetry/test_manager.py b/tests/telemetry/test_manager.py new file mode 100644 index 000000000..2eeae2f11 --- /dev/null +++ b/tests/telemetry/test_manager.py @@ -0,0 +1,275 @@ +"""Tests for TelemetryManager class and utilities""" + +# pylint: disable=redefined-outer-name,protected-access + +import os +from unittest.mock import patch + +import pytest +import yaml + +from axolotl.telemetry.manager import TelemetryManager + + +@pytest.fixture +def mock_whitelist(tmp_path): + """Create a temporary whitelist file for testing""" + whitelist_content = { + "organizations": ["meta-llama", "mistralai"], + } + whitelist_file = tmp_path / "whitelist.yaml" + with open(whitelist_file, "w", encoding="utf-8") as f: + yaml.dump(whitelist_content, f) + + return str(whitelist_file) + + +@pytest.fixture +def telemetry_manager_class(): + """Reset the TelemetryManager singleton between tests""" + original_instance = TelemetryManager._instance + original_initialized = TelemetryManager._initialized + TelemetryManager._instance = None + TelemetryManager._initialized = False + yield TelemetryManager + TelemetryManager._instance = original_instance + TelemetryManager._initialized = original_initialized + + +@pytest.fixture +def manager(telemetry_manager_class, mock_whitelist): + """Create a TelemetryManager instance with mocked dependencies""" + with ( + patch("posthog.capture"), + patch("posthog.flush"), + patch("time.sleep"), + patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist), + patch.dict(os.environ, {"RANK": "0"}), + ): + manager = telemetry_manager_class() + # Manually enable for most tests + manager.enabled = True + return manager + + +def test_singleton_instance(telemetry_manager_class): + """Test that TelemetryManager is a singleton""" + with ( + patch("posthog.capture"), + patch("time.sleep"), + patch.dict(os.environ, {"RANK": "0"}), + ): + first = telemetry_manager_class() + second = telemetry_manager_class() + assert first is second + assert telemetry_manager_class.get_instance() is first + + +def test_telemetry_enabled_by_default(telemetry_manager_class): + """Test that telemetry is enabled by default (opt-out)""" + with ( + patch.dict(os.environ, {"RANK": "0"}, clear=True), + patch("time.sleep"), + patch("logging.Logger.info"), + ): + manager = telemetry_manager_class() + assert manager.enabled + + +def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class): + """Test that telemetry is enabled when AXOLOTL_DO_NOT_TRACK=0""" + with ( + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}), + patch("time.sleep"), + ): + manager = 
telemetry_manager_class() + assert manager.enabled + + +def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class): + """Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1""" + with ( + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}), + patch("time.sleep"), + ): + manager = telemetry_manager_class() + assert not manager.enabled + + +def test_telemetry_disabled_with_do_not_track(telemetry_manager_class): + """Test that telemetry is disabled when DO_NOT_TRACK=1""" + with ( + patch.dict( + os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "DO_NOT_TRACK": "1", "RANK": "0"} + ), + patch("time.sleep"), + ): + manager = telemetry_manager_class() + assert not manager.enabled + + +def test_telemetry_disabled_for_non_main_process(telemetry_manager_class): + """Test that telemetry is disabled for non-main processes""" + with ( + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "1"}), + patch("time.sleep"), + ): + manager = telemetry_manager_class() + assert not manager.enabled + + +def test_opt_in_info_displayed(telemetry_manager_class): + """Test that opt-in info is displayed when telemetry is not configured""" + with ( + patch.dict(os.environ, {"RANK": "0"}, clear=True), + patch("logging.Logger.warning") as mock_warning, + patch("time.sleep"), + ): + telemetry_manager_class() + assert any( + "Telemetry is now enabled by default" in str(call) + for call in mock_warning.call_args_list + ) + + +def test_is_whitelisted(telemetry_manager_class, mock_whitelist): + """Test org whitelist functionality""" + with ( + patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist), + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), + ): + manager = telemetry_manager_class() + + # Should match organizations from the mock whitelist + assert manager._is_whitelisted("meta-llama/llama-7b") + assert manager._is_whitelisted("mistralai/mistral-7b-instruct") + # Should not match + assert not manager._is_whitelisted("unknown/model") + # Should handle case insensitively + assert manager._is_whitelisted("META-LLAMA/Llama-7B") + # Should handle empty input + assert not manager._is_whitelisted("") + + +def test_system_info_collection(manager): + """Test system information collection""" + system_info = manager._get_system_info() + + # Check essential keys + assert "os" in system_info + assert "python_version" in system_info + assert "cpu_count" in system_info + assert "memory_total" in system_info + assert "accelerator_count" in system_info + + +def test_send_event(telemetry_manager_class): + """Test basic event sending""" + with ( + patch("posthog.capture") as mock_capture, + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), + ): + manager = telemetry_manager_class() + + # Test with clean properties (no PII) + manager.send_event("test_event", {"key": "value"}) + assert mock_capture.called + assert mock_capture.call_args[1]["event"] == "test_event" + assert mock_capture.call_args[1]["properties"] == {"key": "value"} + assert mock_capture.call_args[1]["distinct_id"] == manager.run_id + + # Test with default properties (None) + mock_capture.reset_mock() + manager.send_event("simple_event") + assert mock_capture.called + assert mock_capture.call_args[1]["properties"] == {} + + +def test_send_system_info(telemetry_manager_class): + """Test sending system info""" + with ( + patch("posthog.capture") as mock_capture, + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), + ): + manager = telemetry_manager_class() + manager.send_system_info() + assert 
mock_capture.called + assert mock_capture.call_args[1]["event"] == "system-info" + assert mock_capture.call_args[1]["properties"] == manager.system_info + + +def test_redacted_properties(telemetry_manager_class): + """Test path redaction in send_event method""" + with ( + patch("posthog.capture") as mock_capture, + patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0"}), + ): + manager = telemetry_manager_class() + # Test with properties containing various paths and non-paths + test_properties = { + "filepath": "/home/user/sensitive/data.txt", + "windows_path": "C:\\Users\\name\\Documents\\project\\file.py", + "output_dir": "/var/lib/data", + "path_to_model": "models/llama/7b", + "message": "Training started", # Should not be redacted + "metrics": {"loss": 0.5, "accuracy": 0.95}, # Should not be redacted + "base_model": "models/local_model", + "nested": { + "model_path": "/models/my_model", + "root_dir": "/home/user/projects", + "stats": {"steps": 1000, "epochs": 3}, # Should not be redacted + }, + } + + manager.send_event("test_event", test_properties) + + # Verify the call was made + assert mock_capture.called + + # Get the sanitized properties that were sent + sanitized = mock_capture.call_args[1]["properties"] + + # Check that path-like and base_model keys were redacted + assert sanitized["filepath"] == "[REDACTED]" + assert sanitized["windows_path"] == "[REDACTED]" + assert sanitized["path_to_model"] == "[REDACTED]" + assert sanitized["base_model"] == "[REDACTED]" + + # Check that non-path values were preserved + assert sanitized["message"] == "Training started" + assert sanitized["metrics"] == {"loss": 0.5, "accuracy": 0.95} + + # Check nested structure handling + assert sanitized["nested"]["model_path"] == "[REDACTED]" + assert sanitized["nested"]["root_dir"] == "[REDACTED]" + assert sanitized["nested"]["stats"] == {"steps": 1000, "epochs": 3} + + +def test_disable_telemetry(manager): + """Test that disabled telemetry doesn't send events""" + with patch("posthog.capture") as mock_capture: + manager.enabled = False + manager.send_event("test_event") + assert not mock_capture.called + + +def test_exception_handling_during_send(manager): + """Test that exceptions in PostHog are handled gracefully""" + with ( + patch("posthog.capture", side_effect=Exception("Test error")), + patch("logging.Logger.warning") as mock_warning, + ): + manager.send_event("test_event") + warning_logged = False + for call in mock_warning.call_args_list: + if "Failed to send telemetry event" in str(call): + warning_logged = True + break + assert warning_logged + + +def test_shutdown(manager): + """Test shutdown behavior""" + with patch("posthog.shutdown") as mock_shutdown: + manager.shutdown() + assert mock_shutdown.called diff --git a/tests/telemetry/test_runtime_metrics.py b/tests/telemetry/test_runtime_metrics.py new file mode 100644 index 000000000..c8916e072 --- /dev/null +++ b/tests/telemetry/test_runtime_metrics.py @@ -0,0 +1,357 @@ +"""Tests for runtime metrics telemetry module""" + +# pylint: disable=redefined-outer-name + +from unittest.mock import MagicMock, patch + +import pytest + +from axolotl.telemetry.runtime_metrics import RuntimeMetrics, RuntimeMetricsTracker + + +@pytest.fixture +def mock_time(): + """Mock time.time() to have predictable values in tests""" + with patch("time.time") as mock_time: + # Start with time 1000.0 and increment by 10 seconds on each call + times = [1000.0 + i * 10 for i in range(10)] + mock_time.side_effect = times + yield mock_time + + +@pytest.fixture +def 
mock_telemetry_manager(): + """Create a mock TelemetryManager""" + with patch( + "axolotl.telemetry.runtime_metrics.TelemetryManager" + ) as mock_manager_class: + mock_manager = MagicMock() + mock_manager.enabled = True + mock_manager_class.get_instance.return_value = mock_manager + yield mock_manager + + +@pytest.fixture +def mock_psutil(): + """Mock psutil for memory information""" + with patch("axolotl.telemetry.runtime_metrics.psutil") as mock_psutil: + mock_process = MagicMock() + mock_memory_info = MagicMock() + # Set initial memory to 1GB + mock_memory_info.rss = 1024 * 1024 * 1024 + mock_process.memory_info.return_value = mock_memory_info + mock_psutil.Process.return_value = mock_process + yield mock_psutil + + +@pytest.fixture +def mock_torch(): + """Mock torch.cuda functions""" + with patch("axolotl.telemetry.runtime_metrics.torch") as mock_torch: + mock_torch.cuda.is_available.return_value = True + mock_torch.cuda.device_count.return_value = 2 + + # Mock memory allocated per device (1GB for device 0, 2GB for device 1) + mock_torch.cuda.memory_allocated.side_effect = ( + lambda device: (device + 1) * 1024 * 1024 * 1024 + ) + + yield mock_torch + + +class TestRuntimeMetrics: + """Tests for RuntimeMetrics class.""" + + def test_initialization(self): + """Test RuntimeMetrics initialization.""" + metrics = RuntimeMetrics(start_time=1000.0) + + assert metrics.start_time == 1000.0 + assert metrics.epoch_start_times == {} + assert metrics.epoch_end_times == {} + assert metrics.peak_gpu_memory == {} + assert metrics.total_steps == 0 + assert metrics.current_epoch == 0 + assert metrics.current_step == 0 + assert metrics.peak_cpu_memory == 0 + + def test_elapsed_time(self, mock_time): + """Test elapsed_time property.""" + metrics = RuntimeMetrics(start_time=1000.0) + + # Mock time.time() to return 1050.0 + mock_time.side_effect = [1050.0] + + assert metrics.elapsed_time == 50.0 + + def test_epoch_time(self): + """Test epoch_time method.""" + metrics = RuntimeMetrics(start_time=1000.0) + + # No epoch data + assert metrics.epoch_time(0) is None + + # Add epoch start but no end + metrics.epoch_start_times[0] = 1000.0 + assert metrics.epoch_time(0) is None + + # Add epoch end + metrics.epoch_end_times[0] = 1060.0 + assert metrics.epoch_time(0) == 60.0 + + def test_average_epoch_time(self): + """Test average_epoch_time method.""" + metrics = RuntimeMetrics(start_time=1000.0) + + # No completed epochs + assert metrics.average_epoch_time() is None + + # Add one completed epoch + metrics.epoch_start_times[0] = 1000.0 + metrics.epoch_end_times[0] = 1060.0 + assert metrics.average_epoch_time() == 60.0 + + # Add second completed epoch + metrics.epoch_start_times[1] = 1060.0 + metrics.epoch_end_times[1] = 1140.0 # 80 seconds + assert metrics.average_epoch_time() == 70.0 # Average of 60 and 80 + + # Add incomplete epoch (should not affect average) + metrics.epoch_start_times[2] = 1140.0 + assert metrics.average_epoch_time() == 70.0 + + def test_steps_per_second(self, mock_time): + """Test steps_per_second method.""" + metrics = RuntimeMetrics(start_time=1000.0) + + # No steps - first call to time.time() + mock_time.side_effect = None + mock_time.return_value = 1050.0 + assert metrics.steps_per_second() is None + + # Add steps - second call to time.time() + metrics.total_steps = 100 + mock_time.return_value = 1050.0 # Keep same time for consistent result + assert metrics.steps_per_second() == 2.0 # 100 steps / 50 seconds + + def test_to_dict_basic(self, mock_time): + """Test to_dict method with basic 
metrics.""" + metrics = RuntimeMetrics(start_time=1000.0) + metrics.total_steps = 100 + metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024 # 2GB + + # Mock elapsed_time + mock_time.side_effect = None + mock_time.return_value = 1050.0 + + result = metrics.to_dict() + + assert result["total_time_seconds"] == 50.0 + assert result["total_steps"] == 100 + assert result["steps_per_second"] == 2.0 + assert result["epochs_completed"] == 0 + assert result["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024 + assert "epoch_times" not in result + assert "gpu_memory" not in result + + def test_to_dict_with_epochs(self, mock_time): + """Test to_dict method with epoch data.""" + metrics = RuntimeMetrics(start_time=1000.0) + metrics.total_steps = 100 + + # Add epoch data + metrics.epoch_start_times[0] = 1000.0 + metrics.epoch_end_times[0] = 1060.0 + metrics.epoch_start_times[1] = 1060.0 + metrics.epoch_end_times[1] = 1140.0 + + # Mock elapsed_time + mock_time.side_effect = None + mock_time.return_value = 1150.0 + + result = metrics.to_dict() + + assert "epoch_times" in result + assert result["epoch_times"]["epoch_0_seconds"] == 60.0 + assert result["epoch_times"]["epoch_1_seconds"] == 80.0 + assert result["average_epoch_time_seconds"] == 70.0 + + def test_to_dict_with_gpu_memory(self, mock_time): + """Test to_dict method with GPU memory data.""" + metrics = RuntimeMetrics(start_time=1000.0) + metrics.peak_gpu_memory = { + 0: 1 * 1024 * 1024 * 1024, # 1GB + 1: 2 * 1024 * 1024 * 1024, # 2GB + } + + # Mock elapsed_time + mock_time.side_effect = [1050.0] + + result = metrics.to_dict() + + assert "gpu_memory" in result + assert result["gpu_memory"]["gpu_0_peak_memory_bytes"] == 1 * 1024 * 1024 * 1024 + assert result["gpu_memory"]["gpu_1_peak_memory_bytes"] == 2 * 1024 * 1024 * 1024 + + +class TestRuntimeMetricsTracker: + """Tests for RuntimeMetricsTracker class.""" + + # pylint: disable=unused-argument + def test_initialization(self, mock_time, mock_telemetry_manager): + """Test RuntimeMetricsTracker initialization.""" + tracker = RuntimeMetricsTracker() + + assert isinstance(tracker.metrics, RuntimeMetrics) + assert tracker.metrics.start_time == 1000.0 # First value from mock_time + + # pylint: disable=unused-argument + def test_start_epoch( + self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager + ): + """Test start_epoch method.""" + tracker = RuntimeMetricsTracker() + + # Reset mock_time to control next value + mock_time.side_effect = [1010.0] + + tracker.start_epoch(0) + + assert tracker.metrics.current_epoch == 0 + assert tracker.metrics.epoch_start_times[0] == 1010.0 + + # Verify memory metrics were updated + assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024 + assert 0 in tracker.metrics.peak_gpu_memory + assert 1 in tracker.metrics.peak_gpu_memory + + # pylint: disable=unused-argument + def test_end_epoch(self, mock_time, mock_telemetry_manager): + """Test end_epoch method.""" + tracker = RuntimeMetricsTracker() + + # Start epoch 0 + mock_time.side_effect = [1010.0] + tracker.start_epoch(0) + + # End epoch 0 + mock_time.side_effect = [1060.0] + tracker.end_epoch(0) + + assert 0 in tracker.metrics.epoch_end_times + assert tracker.metrics.epoch_end_times[0] == 1060.0 + + # pylint: disable=unused-argument + def test_update_step( + self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager + ): + """Test update_step method.""" + tracker = RuntimeMetricsTracker() + + # Update step to a non-multiple of 100 + tracker.update_step(42) + + assert tracker.metrics.current_step == 42 + 
assert tracker.metrics.total_steps == 1 + + # Memory metrics should not be updated for non-multiple of 100 + assert tracker.metrics.peak_cpu_memory == 0 + + # Update step to a multiple of 100 + tracker.update_step(100) + + assert tracker.metrics.current_step == 100 + assert tracker.metrics.total_steps == 2 + + # Memory metrics should be updated for multiple of 100 + assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024 + + # pylint: disable=unused-argument + def test_update_memory_metrics( + self, mock_psutil, mock_torch, mock_telemetry_manager + ): + """Test update_memory_metrics method.""" + tracker = RuntimeMetricsTracker() + + # Initial memory state + assert tracker.metrics.peak_cpu_memory == 0 + assert tracker.metrics.peak_gpu_memory == {} + + # Update memory metrics + tracker.update_memory_metrics() + + # Verify CPU memory + assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024 + + # Verify GPU memory + assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024 + assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024 + + # Change mocked memory values to be lower + mock_process = mock_psutil.Process.return_value + mock_memory_info = mock_process.memory_info.return_value + mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024 # 0.5GB + + mock_torch.cuda.memory_allocated.side_effect = ( + lambda device: (device + 0.5) * 1024 * 1024 * 1024 + ) + + # Update memory metrics again + tracker.update_memory_metrics() + + # Peak values should not decrease + assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024 + assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024 + assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024 + + # Change mocked memory values to be higher + mock_memory_info.rss = 2 * 1024 * 1024 * 1024 # 2GB + + mock_torch.cuda.memory_allocated.side_effect = ( + lambda device: (device + 2) * 1024 * 1024 * 1024 + ) + + # Update memory metrics again + tracker.update_memory_metrics() + + # Peak values should increase + assert tracker.metrics.peak_cpu_memory == 2 * 1024 * 1024 * 1024 + assert tracker.metrics.peak_gpu_memory[0] == 2 * 1024 * 1024 * 1024 + assert tracker.metrics.peak_gpu_memory[1] == 3 * 1024 * 1024 * 1024 + + # pylint: disable=unused-argument + def test_get_memory_metrics(self, mock_psutil, mock_torch, mock_telemetry_manager): + """Test get_memory_metrics method.""" + tracker = RuntimeMetricsTracker() + + # Set peak memory values + tracker.metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024 + tracker.metrics.peak_gpu_memory = { + 0: 3 * 1024 * 1024 * 1024, + 1: 4 * 1024 * 1024 * 1024, + } + + # Get memory metrics + memory_metrics = tracker.get_memory_metrics() + + # Verify CPU memory + assert ( + memory_metrics["cpu_memory_bytes"] == 1 * 1024 * 1024 * 1024 + ) # Current value from mock + assert ( + memory_metrics["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024 + ) # Peak value we set + + # Verify GPU memory + assert ( + memory_metrics["gpu_0_memory_bytes"] == 1 * 1024 * 1024 * 1024 + ) # Current value from mock + assert ( + memory_metrics["gpu_0_peak_memory_bytes"] == 3 * 1024 * 1024 * 1024 + ) # Peak value we set + assert ( + memory_metrics["gpu_1_memory_bytes"] == 2 * 1024 * 1024 * 1024 + ) # Current value from mock + assert ( + memory_metrics["gpu_1_peak_memory_bytes"] == 4 * 1024 * 1024 * 1024 + ) # Peak value we set diff --git a/tests/utils/callbacks/test_dynamic_checkpoint.py b/tests/utils/callbacks/test_dynamic_checkpoint.py new file mode 100644 index 000000000..1fd792102 --- /dev/null +++ 
b/tests/utils/callbacks/test_dynamic_checkpoint.py @@ -0,0 +1,389 @@ +"""Unit tests for dynamic checkpoint callback""" + +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, Mock, patch + +from axolotl.utils.callbacks.dynamic_checkpoint import ( + DEFAULT_TRIGGER_FILENAME, + DynamicCheckpointCallback, +) +from axolotl.utils.dict import DictDefault + + +class TestDynamicCheckpointCallbackInit: + """Test callback initialization""" + + def test_callback_disabled_by_default(self): + """Test that callback is disabled when config.enabled=False""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": False}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + assert callback.enabled is False + + def test_callback_disabled_when_none(self): + """Test that callback is disabled when dynamic_checkpoint is None""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": None, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + assert callback.enabled is False + + def test_callback_enabled_when_configured(self): + """Test that callback is enabled when config.enabled=True""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 10}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + assert callback.enabled is True + assert callback.check_interval == 10 + + def test_default_trigger_filename(self): + """Test that default trigger filename is used""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 10}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + assert callback.trigger_filename == DEFAULT_TRIGGER_FILENAME + + def test_check_interval_default(self): + """Test default check interval""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + assert callback.check_interval == 100 # Default from schema + + +class TestDynamicCheckpointFileDetection: + """Test file-based checkpoint triggering""" + + def test_trigger_file_detected_and_deleted(self): + """Test that trigger file is detected and deleted""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 1}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + assert trigger_file.exists() + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=False, + ): + result = callback.on_step_end(args, state, control) + + assert not trigger_file.exists() + assert result.should_save is True + + def test_check_interval_honored(self): + """Test that file is only checked at check_interval steps""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 10}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + args = 
Mock(output_dir=tmpdir) + control = Mock(should_save=False) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=False, + ): + # Step 5 - shouldn't check (not divisible by 10) + state = Mock(global_step=5) + result = callback.on_step_end(args, state, control) + assert trigger_file.exists() # Still there + assert result.should_save is False + + # Step 10 - should check + state = Mock(global_step=10) + result = callback.on_step_end(args, state, control) + assert not trigger_file.exists() # Deleted + assert result.should_save is True + + def test_no_file_no_trigger(self): + """Test that no trigger occurs when file doesn't exist""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 1}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=False, + ): + result = callback.on_step_end(args, state, control) + + assert result.should_save is False + + def test_file_deletion_error_handling(self): + """Test that file deletion errors are handled gracefully""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 1}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=False, + ): + with patch.object( + Path, "unlink", side_effect=OSError("Permission denied") + ): + result = callback.on_step_end(args, state, control) + + assert result.should_save is True + + +class TestDynamicCheckpointMultiGPU: + """Test multi-GPU synchronization""" + + def test_only_rank_0_checks_file(self): + """Test that only rank 0 checks filesystem in multi-GPU setup""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 1}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + # Rank 1 (not main process) - shouldn't check file + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=False, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=True, + ): + with patch("torch.distributed.broadcast") as mock_broadcast: + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.barrier" + ): + mock_tensor = MagicMock() + mock_tensor.item.return_value = 0 + with patch("torch.tensor", return_value=mock_tensor): + callback.on_step_end(args, state, control) + + assert 
trigger_file.exists() + # Broadcast should have been called + assert mock_broadcast.called + + def test_broadcast_synchronization(self): + """Test that trigger decision is broadcasted to all ranks""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": True, "check_interval": 1}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + # Rank 0 detects file + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=True, + ): + with patch("torch.distributed.broadcast") as mock_broadcast: + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.barrier" + ) as mock_barrier: + mock_tensor = MagicMock() + mock_tensor.item.return_value = 1 + with patch("torch.tensor", return_value=mock_tensor): + with patch("torch.cuda.current_device", return_value=0): + result = callback.on_step_end(args, state, control) + + assert mock_broadcast.called + assert mock_barrier.called + # All ranks should trigger + assert result.should_save is True + + +class TestDynamicCheckpointSignalHandling: + """Test signal-based checkpoint triggering""" + + def test_signal_trigger_via_callback(self): + """Test that signal flag triggers checkpoint save""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": { + "enabled": True, + "check_interval": 1, + "enable_signal": True, + }, + "output_dir": tmpdir, + } + ) + + with patch("signal.signal"): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.hasattr", + return_value=True, + ): + callback = DynamicCheckpointCallback(cfg) + + callback.should_save_checkpoint = True + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = Mock(should_save=False) + + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_main_process", + return_value=True, + ): + with patch( + "axolotl.utils.callbacks.dynamic_checkpoint.is_distributed", + return_value=False, + ): + result = callback.on_step_end(args, state, control) + + assert result.should_save is True + assert callback.should_save_checkpoint is False + + def test_signal_not_registered_when_disabled(self): + """Test that signal handler is not registered when disabled""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": { + "enabled": True, + "check_interval": 10, + "enable_signal": False, + }, + "output_dir": tmpdir, + } + ) + + with patch("signal.signal") as mock_signal_register: + _ = DynamicCheckpointCallback(cfg) + + assert not mock_signal_register.called + + +class TestDynamicCheckpointDisabled: + """Test behavior when callback is disabled""" + + def test_disabled_callback_does_nothing(self): + """Test that disabled callback doesn't check or trigger""" + with tempfile.TemporaryDirectory() as tmpdir: + cfg = DictDefault( + { + "dynamic_checkpoint": {"enabled": False}, + "output_dir": tmpdir, + } + ) + callback = DynamicCheckpointCallback(cfg) + + trigger_file = Path(tmpdir) / DEFAULT_TRIGGER_FILENAME + trigger_file.touch() + + args = Mock(output_dir=tmpdir) + state = Mock(global_step=1) + control = 
Mock(should_save=False) + + result = callback.on_step_end(args, state, control) + + assert trigger_file.exists() + assert result.should_save is False
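+
+
+# Usage sketch (illustrative, inferred from the behavior exercised above):
+# with `dynamic_checkpoint: {enabled: true}` in the training config, an
+# ad-hoc checkpoint can be requested mid-run by creating the trigger file
+# (named DEFAULT_TRIGGER_FILENAME) in the output directory. On the next step
+# that falls on check_interval, the callback deletes the file and sets
+# `control.should_save`, so the trainer writes a checkpoint.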