Compare commits
11 Commits
telemetry
...
optimizers
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76bb09784d | ||
|
|
0542c7dd56 | ||
|
|
0134093acc | ||
|
|
d4de93a7bb | ||
|
|
c8191394e9 | ||
|
|
f18231c653 | ||
|
|
9ed4f6b3aa | ||
|
|
05dddfc41d | ||
|
|
8e30917440 | ||
|
|
d883b11b6f | ||
|
|
f4910dd2ea |
@@ -19,9 +19,6 @@
|
|||||||
<br/>
|
<br/>
|
||||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
|
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
|
||||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
|
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
|
||||||
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
|
|
||||||
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
|
|
||||||
</a>
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
Axolotl is a tool designed to streamline post-training for various AI models.
|
Axolotl is a tool designed to streamline post-training for various AI models.
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ website:
|
|||||||
|
|
||||||
- section: "Deployments"
|
- section: "Deployments"
|
||||||
contents:
|
contents:
|
||||||
|
- docs/docker.qmd
|
||||||
- docs/multi-gpu.qmd
|
- docs/multi-gpu.qmd
|
||||||
- docs/multi-node.qmd
|
- docs/multi-node.qmd
|
||||||
- docs/ray-integration.qmd
|
- docs/ray-integration.qmd
|
||||||
|
|||||||
@@ -163,6 +163,12 @@ datasets:
|
|||||||
system: ["system"]
|
system: ["system"]
|
||||||
tool: ["tool"]
|
tool: ["tool"]
|
||||||
|
|
||||||
|
# Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
|
||||||
|
# This does not drop the default system message from chat_template if it exists. If you wish to,
|
||||||
|
# we recommend using a custom jinja template with the default system message removed or
|
||||||
|
# adding a system turn with empty content.
|
||||||
|
drop_system_message:
|
||||||
|
|
||||||
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
||||||
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
|
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
|
||||||
# See examples at `docs/dataset-formats/conversation.qmd`
|
# See examples at `docs/dataset-formats/conversation.qmd`
|
||||||
@@ -222,8 +228,8 @@ process_reward_model:
|
|||||||
chat_template: tokenizer_default
|
chat_template: tokenizer_default
|
||||||
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
|
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
|
||||||
chat_template_jinja: null
|
chat_template_jinja: null
|
||||||
# Changes the default system message
|
# Changes the default system message. Currently only supports chatml.
|
||||||
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
|
default_system_message: You are a helpful assistant. Please give a long and detailed answer.
|
||||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||||
# subsequent training attempts load faster, relative path
|
# subsequent training attempts load faster, relative path
|
||||||
dataset_prepared_path: data/last_run_prepared
|
dataset_prepared_path: data/last_run_prepared
|
||||||
@@ -445,7 +451,7 @@ gradient_checkpointing: false
|
|||||||
early_stopping_patience: 3
|
early_stopping_patience: 3
|
||||||
|
|
||||||
# Specify a scheduler and kwargs to use with the optimizer
|
# Specify a scheduler and kwargs to use with the optimizer
|
||||||
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
|
||||||
lr_scheduler_kwargs:
|
lr_scheduler_kwargs:
|
||||||
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
|
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
|
||||||
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
|
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
|
||||||
@@ -528,6 +534,8 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
|||||||
sdp_attention:
|
sdp_attention:
|
||||||
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
|
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
|
||||||
s2_attention:
|
s2_attention:
|
||||||
|
# Optional[bool]. Whether to use low_cpu_mem_usage
|
||||||
|
low_cpu_mem_usage:
|
||||||
# Resume from a specific checkpoint dir
|
# Resume from a specific checkpoint dir
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
||||||
|
|||||||
@@ -129,6 +129,7 @@ You can mix and match within each approach or across approaches to train a model
|
|||||||
We suggest this approach when you want to bring your own tokenized dataset.
|
We suggest this approach when you want to bring your own tokenized dataset.
|
||||||
|
|
||||||
Axolotl expects the dataset to have three keys:
|
Axolotl expects the dataset to have three keys:
|
||||||
|
|
||||||
- `input_ids`: from tokenizing formatted prompt
|
- `input_ids`: from tokenizing formatted prompt
|
||||||
- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
|
- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
|
||||||
- `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`.
|
- `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`.
|
||||||
|
|||||||
140
docs/docker.qmd
Normal file
140
docs/docker.qmd
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
---
|
||||||
|
title: "Docker"
|
||||||
|
format:
|
||||||
|
html:
|
||||||
|
toc: true
|
||||||
|
toc-depth: 4
|
||||||
|
---
|
||||||
|
|
||||||
|
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||||
|
|
||||||
|
## Base
|
||||||
|
|
||||||
|
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
|
||||||
|
|
||||||
|
#### Image
|
||||||
|
|
||||||
|
```
|
||||||
|
axolotlai/axolotl-base
|
||||||
|
```
|
||||||
|
|
||||||
|
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
|
||||||
|
|
||||||
|
#### Tags format
|
||||||
|
|
||||||
|
```bash
|
||||||
|
main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||||
|
```
|
||||||
|
|
||||||
|
Tags examples:
|
||||||
|
|
||||||
|
- `main-base-py3.11-cu124-2.6.0`
|
||||||
|
- `main-base-py3.11-cu124-2.5.1`
|
||||||
|
- `main-base-py3.11-cu124-2.4.1`
|
||||||
|
|
||||||
|
## Main
|
||||||
|
|
||||||
|
The main image is the image that is used to run Axolotl. It is based on the `axolotlai/axolotl-base` image and includes the Axolotl codebase, dependencies, and more.
|
||||||
|
|
||||||
|
#### Image
|
||||||
|
|
||||||
|
```
|
||||||
|
axolotlai/axolotl
|
||||||
|
```
|
||||||
|
|
||||||
|
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
|
||||||
|
|
||||||
|
#### Tags format {#sec-main-tags}
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# on push to main
|
||||||
|
main-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||||
|
|
||||||
|
# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
|
||||||
|
main-latest
|
||||||
|
|
||||||
|
# nightly build
|
||||||
|
{branch}-{date_in_YYYYMMDD}-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||||
|
|
||||||
|
# tagged release
|
||||||
|
{version}
|
||||||
|
```
|
||||||
|
|
||||||
|
:::{.callout-tip}
|
||||||
|
|
||||||
|
There may be some extra tags appended to the image, like `-vllm` which installs those packages.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
Tags examples:
|
||||||
|
|
||||||
|
- `main-py3.11-cu124-2.6.0`
|
||||||
|
- `main-py3.11-cu124-2.5.1`
|
||||||
|
- `main-py3.11-cu124-2.4.1`
|
||||||
|
- `main-latest`
|
||||||
|
- `main-20250303-py3.11-cu124-2.6.0`
|
||||||
|
- `main-20250303-py3.11-cu124-2.5.1`
|
||||||
|
- `main-20250303-py3.11-cu124-2.4.1`
|
||||||
|
- `0.7.1`
|
||||||
|
|
||||||
|
## Cloud
|
||||||
|
|
||||||
|
The cloud image is the image that is used to run Axolotl in the cloud. It is based on the `axolotlai/axolotl` image and sets ENV variables like HuggingFace cache directories for volume mounts, tmux, and more for different cloud providers.
|
||||||
|
|
||||||
|
:::{.callout-tip}
|
||||||
|
|
||||||
|
Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variables to disable it.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
#### Image
|
||||||
|
|
||||||
|
```
|
||||||
|
axolotlai/axolotl-cloud
|
||||||
|
```
|
||||||
|
|
||||||
|
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
|
||||||
|
|
||||||
|
#### Tags format
|
||||||
|
|
||||||
|
This uses the same tags as the [`main` image](#sec-main-tags).
|
||||||
|
|
||||||
|
#### Environment variables
|
||||||
|
|
||||||
|
- `JUPYTER_DISABLE`: Disable Jupyter lab.
|
||||||
|
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
|
||||||
|
- `PUBLIC_KEY`: Add a public key for the SSH service.
|
||||||
|
- `SSH_KEY`: Add a private key for the SSH service.
|
||||||
|
|
||||||
|
#### Volume mounts
|
||||||
|
|
||||||
|
:::{.callout-tip}
|
||||||
|
|
||||||
|
We recommend mounting volumes to `/workspace/data` for data persistence. `/workspace/axolotl` contains the source code and is ephemeral.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
- `/workspace/data/axolotl-artifacts`: Directory to store Axolotl artifacts.
|
||||||
|
- `/workspace/data/huggingface-cache`: Directory to store HuggingFace cache.
|
||||||
|
|
||||||
|
## Cloud-no-tmux
|
||||||
|
|
||||||
|
This is the same as the [`cloud` image](#sec-cloud) but without tmux.
|
||||||
|
|
||||||
|
#### Image
|
||||||
|
|
||||||
|
```
|
||||||
|
axolotlai/axolotl-cloud-term
|
||||||
|
```
|
||||||
|
|
||||||
|
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud-term)
|
||||||
|
|
||||||
|
:::{.callout-note}
|
||||||
|
|
||||||
|
The naming may be a bit confusing as it has `-term` appended to the end.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
#### Tags format
|
||||||
|
|
||||||
|
This uses the same tags as the [`cloud` image](#sec-cloud-tags).
|
||||||
@@ -19,7 +19,9 @@ description: Frequently asked questions
|
|||||||
|
|
||||||
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
|
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
|
||||||
|
|
||||||
> A: You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
|
**Q: ModuleNotFoundError: No module named 'mpi4py' using single GPU with deepspeed**
|
||||||
|
|
||||||
|
> A: You may be using deepspeed with single gpu. Please remove the `deepspeed:` section in the yaml file or `--deepspeed` CLI flag.
|
||||||
|
|
||||||
**Q: The codes is stuck on saving preprocessed datasets.**
|
**Q: The codes is stuck on saving preprocessed datasets.**
|
||||||
|
|
||||||
|
|||||||
@@ -65,6 +65,8 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
|
|||||||
```
|
```
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
||||||
|
|
||||||
## Cloud Environments {#sec-cloud}
|
## Cloud Environments {#sec-cloud}
|
||||||
|
|
||||||
### Cloud GPU Providers {#sec-cloud-gpu}
|
### Cloud GPU Providers {#sec-cloud-gpu}
|
||||||
|
|||||||
@@ -1,59 +0,0 @@
|
|||||||
---
|
|
||||||
title: Telemetry
|
|
||||||
description: A description of the opt-out telemetry implementation in Axolotl.
|
|
||||||
---
|
|
||||||
|
|
||||||
# Telemetry in Axolotl
|
|
||||||
|
|
||||||
Axolotl implements anonymous telemetry to help maintainers understand how the library
|
|
||||||
is used and where users encounter issues. This data helps prioritize features, optimize
|
|
||||||
performance, and fix bugs.
|
|
||||||
|
|
||||||
## Data Collection
|
|
||||||
|
|
||||||
We collect:
|
|
||||||
|
|
||||||
- System info: OS, Python version, Axolotl version, PyTorch version, Transformers
|
|
||||||
version, etc.
|
|
||||||
- Hardware info: CPU count, memory, GPU count and models
|
|
||||||
- Runtime metrics: Training progress, memory usage, timing information
|
|
||||||
- Usage patterns: Models (from a whitelist) and configurations used
|
|
||||||
- Error tracking: Stack traces and error messages (sanitized to remove personal
|
|
||||||
information)
|
|
||||||
|
|
||||||
No personally identifiable information (PII) is collected.
|
|
||||||
|
|
||||||
## Implementation
|
|
||||||
|
|
||||||
Telemetry is implemented using PostHog and consists of:
|
|
||||||
|
|
||||||
- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the
|
|
||||||
telemetry system and provides methods for tracking events.
|
|
||||||
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and
|
|
||||||
sends sanitized stack traces.
|
|
||||||
- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks
|
|
||||||
runtime metrics during training.
|
|
||||||
- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends
|
|
||||||
runtime metrics telemetry.
|
|
||||||
|
|
||||||
The telemetry system will block training startup for 15 seconds to ensure users are
|
|
||||||
aware of data collection, unless telemetry is explicitly enabled or disabled.
|
|
||||||
|
|
||||||
## Opt-Out Mechanism
|
|
||||||
|
|
||||||
Telemetry is **enabled by default** on an opt-out basis. To disable it, set either:
|
|
||||||
|
|
||||||
- `AXOLOTL_DO_NOT_TRACK=1` (Axolotl-specific)
|
|
||||||
- `DO_NOT_TRACK=1` (Global standard; see https://consoledonottrack.com/)
|
|
||||||
|
|
||||||
To acknowledge and explicitly enable telemetry (and remove the warning message), set:
|
|
||||||
`AXOLOTL_DO_NOT_TRACK=0`.
|
|
||||||
|
|
||||||
## Privacy
|
|
||||||
|
|
||||||
- All path-like config information is automatically redacted from telemetry data
|
|
||||||
- Model information is only collected for whitelisted organizations
|
|
||||||
- See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
|
|
||||||
- Each run generates a unique anonymous ID
|
|
||||||
- This allows us to link different telemetry events in a single same training run
|
|
||||||
- Telemetry is only sent from the main process to avoid duplicate events
|
|
||||||
@@ -63,6 +63,4 @@ torchao==0.7.0
|
|||||||
schedulefree==1.3.0
|
schedulefree==1.3.0
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.3
|
axolotl-contribs-lgpl==0.0.3
|
||||||
|
axolotl-contribs-mit==0.0.3
|
||||||
# telemetry
|
|
||||||
posthog>=3.15.1
|
|
||||||
|
|||||||
@@ -14,8 +14,6 @@ import yaml
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.telemetry.errors import send_errors
|
|
||||||
from axolotl.telemetry.manager import TelemetryManager
|
|
||||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
normalize_cfg_datasets,
|
normalize_cfg_datasets,
|
||||||
@@ -29,8 +27,6 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
|
|||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
|
||||||
|
|
||||||
|
|
||||||
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
||||||
"""
|
"""
|
||||||
@@ -156,7 +152,6 @@ def prepare_plugins(cfg: DictDefault):
|
|||||||
plugin_manager.register(plugin_name)
|
plugin_manager.register(plugin_name)
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
|
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
|
||||||
"""
|
"""
|
||||||
Loads the `axolotl` configuration stored at `config`, validates it, and performs
|
Loads the `axolotl` configuration stored at `config`, validates it, and performs
|
||||||
@@ -176,7 +171,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
|
|||||||
# Load the config from the yaml file
|
# Load the config from the yaml file
|
||||||
with open(config, encoding="utf-8") as file:
|
with open(config, encoding="utf-8") as file:
|
||||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||||
TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg)
|
|
||||||
|
|
||||||
# If there are any options passed in the cli, if it is something that seems valid
|
# If there are any options passed in the cli, if it is something that seems valid
|
||||||
# from the yaml, then overwrite the value
|
# from the yaml, then overwrite the value
|
||||||
@@ -220,6 +214,4 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
|
|||||||
setup_mlflow_env_vars(cfg)
|
setup_mlflow_env_vars(cfg)
|
||||||
setup_comet_env_vars(cfg)
|
setup_comet_env_vars(cfg)
|
||||||
|
|
||||||
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
|
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from axolotl.cli.args import InferenceCliArgs
|
|||||||
from axolotl.cli.art import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.telemetry.errors import send_errors
|
|
||||||
from axolotl.utils.chat_templates import (
|
from axolotl.utils.chat_templates import (
|
||||||
get_chat_template,
|
get_chat_template,
|
||||||
get_chat_template_from_config,
|
get_chat_template_from_config,
|
||||||
@@ -43,7 +42,6 @@ def get_multi_line_input() -> str:
|
|||||||
return instruction
|
return instruction
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def do_inference(
|
def do_inference(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
@@ -137,7 +135,6 @@ def do_inference(
|
|||||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def do_inference_gradio(
|
def do_inference_gradio(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
|
|||||||
@@ -12,13 +12,11 @@ from axolotl.cli.args import TrainerCliArgs
|
|||||||
from axolotl.cli.art import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.telemetry.errors import send_errors
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def do_merge_lora(*, cfg: DictDefault) -> None:
|
def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||||
"""
|
"""
|
||||||
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
|
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.telemetry.errors import send_errors
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -121,7 +120,6 @@ def _distributed_checkpoint_to_merged_weights(
|
|||||||
return save_path_
|
return save_path_
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def merge_fsdp_weights(
|
def merge_fsdp_weights(
|
||||||
checkpoint_dir: str,
|
checkpoint_dir: str,
|
||||||
output_path: str,
|
output_path: str,
|
||||||
|
|||||||
@@ -18,14 +18,12 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
|||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
from axolotl.telemetry.errors import send_errors
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.trainer import disable_datasets_caching
|
from axolotl.utils.trainer import disable_datasets_caching
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
||||||
"""
|
"""
|
||||||
Preprocesses dataset specified in axolotl config.
|
Preprocesses dataset specified in axolotl config.
|
||||||
|
|||||||
@@ -41,11 +41,12 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
|
||||||
del model
|
del model
|
||||||
del tokenizer
|
del tokenizer
|
||||||
|
del trainer
|
||||||
|
|
||||||
plugin_manager.post_train_unload(cfg)
|
plugin_manager.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from datasets import Dataset
|
|||||||
|
|
||||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||||
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
|
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
|
||||||
from axolotl.telemetry.errors import send_errors
|
|
||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -25,8 +24,8 @@ class TrainDatasetMeta:
|
|||||||
"""Dataclass with fields for training and validation datasets and metadata."""
|
"""Dataclass with fields for training and validation datasets and metadata."""
|
||||||
|
|
||||||
train_dataset: Dataset
|
train_dataset: Dataset
|
||||||
eval_dataset: Optional[Dataset] = None
|
eval_dataset: Dataset | None = None
|
||||||
total_num_steps: Optional[int] = None
|
total_num_steps: int | None = None
|
||||||
|
|
||||||
|
|
||||||
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
||||||
@@ -45,7 +44,6 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def load_datasets(
|
def load_datasets(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
@@ -105,7 +103,6 @@ def load_datasets(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def load_preference_datasets(
|
def load_preference_datasets(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from transformers import (
|
|||||||
EarlyStoppingCallback,
|
EarlyStoppingCallback,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
)
|
)
|
||||||
|
from transformers.training_args import OptimizerNames
|
||||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers.base import (
|
||||||
@@ -61,8 +62,6 @@ from axolotl.core.training_args import (
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||||
from axolotl.telemetry.callbacks import TelemetryCallback
|
|
||||||
from axolotl.telemetry.manager import TelemetryManager
|
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
@@ -86,6 +85,7 @@ from axolotl.utils.collators import (
|
|||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
|
||||||
from axolotl.utils.models import ensure_dtype
|
from axolotl.utils.models import ensure_dtype
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -93,13 +93,11 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""
|
"""Base class for trainer builder."""
|
||||||
Base class for trainer builder
|
|
||||||
"""
|
|
||||||
|
|
||||||
_train_dataset = None
|
_train_dataset = None
|
||||||
_eval_dataset = None
|
_eval_dataset = None
|
||||||
@@ -112,9 +110,9 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.processor = processor
|
self.processor = processor
|
||||||
|
|
||||||
# in case the model supports tagging, add the axolotl tag.
|
# If the model supports tagging, add the axolotl tag.
|
||||||
# This makes sure the tag is correctly pushed even if a user calls
|
# This makes sure the tag is correctly pushed even if a user calls
|
||||||
# model.push_to_hub instad of trainer.push_to_hub.
|
# model.push_to_hub instead of trainer.push_to_hub.
|
||||||
if hasattr(model, "add_model_tags"):
|
if hasattr(model, "add_model_tags"):
|
||||||
model.add_model_tags(["axolotl"])
|
model.add_model_tags(["axolotl"])
|
||||||
|
|
||||||
@@ -178,8 +176,10 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
SaveAxolotlConfigtoMlflowCallback,
|
SaveAxolotlConfigtoMlflowCallback,
|
||||||
)
|
)
|
||||||
|
|
||||||
callbacks.append(
|
callbacks.extend(
|
||||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
[
|
||||||
|
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
|
||||||
|
]
|
||||||
)
|
)
|
||||||
if self.cfg.use_comet and is_comet_available():
|
if self.cfg.use_comet and is_comet_available():
|
||||||
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
||||||
@@ -188,10 +188,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
|
||||||
telemetry_manager = TelemetryManager.get_instance()
|
|
||||||
if telemetry_manager.enabled:
|
|
||||||
callbacks.append(TelemetryCallback())
|
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
@@ -231,8 +227,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
|
|
||||||
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||||
"""
|
"""
|
||||||
Build the HuggingFace training args/trainer for causal models
|
Build the HuggingFace training args/trainer for causal models and reward modeling
|
||||||
and reward modelling using TRL.
|
using TRL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
@@ -555,30 +551,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
|
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||||
else:
|
else:
|
||||||
training_arguments_kwargs["run_name"] = None
|
training_arguments_kwargs["run_name"] = None
|
||||||
training_arguments_kwargs["optim"] = (
|
|
||||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
|
||||||
)
|
|
||||||
if self.cfg.optim_args:
|
|
||||||
if isinstance(self.cfg.optim_args, dict):
|
|
||||||
optim_args = ",".join(
|
|
||||||
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
optim_args = self.cfg.optim_args
|
|
||||||
training_arguments_kwargs["optim_args"] = optim_args
|
|
||||||
if self.cfg.optim_target_modules:
|
|
||||||
training_arguments_kwargs[
|
|
||||||
"optim_target_modules"
|
|
||||||
] = self.cfg.optim_target_modules
|
|
||||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
|
||||||
training_arguments_kwargs[
|
|
||||||
"loraplus_lr_embedding"
|
|
||||||
] = self.cfg.loraplus_lr_embedding
|
|
||||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
|
||||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
|
||||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
|
||||||
|
|
||||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
|
||||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"alternate_lr_scheduler_type"
|
"alternate_lr_scheduler_type"
|
||||||
@@ -662,46 +636,114 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
|
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# Handle custom optimizer
|
||||||
if self.cfg.optimizer in [
|
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
|
||||||
"optimi_adamw",
|
if self.cfg.optimizer in custom_supported_optimizers:
|
||||||
"ao_adamw_4bit",
|
# Common optimizer kwargs
|
||||||
"ao_adamw_8bit",
|
optimizer_kwargs = {
|
||||||
"ao_adamw_fp8",
|
"lr": training_arguments_kwargs.get("learning_rate"),
|
||||||
"adopt_adamw",
|
"weight_decay": training_arguments_kwargs.get("weight_decay"),
|
||||||
]:
|
}
|
||||||
# Set default so transformers doesn't throw
|
|
||||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
|
||||||
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
|
|
||||||
|
|
||||||
if self.cfg.optimizer == "lion_pytorch":
|
# Adam-specific kwargs
|
||||||
from lion_pytorch import Lion
|
adam_kwargs = {}
|
||||||
|
if training_arguments_kwargs.get(
|
||||||
|
"adam_beta1"
|
||||||
|
) and training_arguments_kwargs.get("adam_beta2"):
|
||||||
|
adam_kwargs["betas"] = (
|
||||||
|
training_arguments_kwargs.get("adam_beta1"),
|
||||||
|
training_arguments_kwargs.get("adam_beta2"),
|
||||||
|
)
|
||||||
|
if training_arguments_kwargs.get("adam_epsilon"):
|
||||||
|
adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
|
||||||
|
|
||||||
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
|
if self.cfg.optimizer == "muon":
|
||||||
if "weight_decay" in training_arguments_kwargs:
|
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
|
||||||
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
|
MuonOptimizerFactory,
|
||||||
|
|
||||||
if (
|
|
||||||
"adam_beta1" in training_arguments_kwargs
|
|
||||||
and "adam_beta2" in training_arguments_kwargs
|
|
||||||
):
|
|
||||||
lion_kwargs["betas"] = (
|
|
||||||
training_arguments_kwargs["adam_beta1"],
|
|
||||||
training_arguments_kwargs["adam_beta2"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer_kwargs["optimizers"] = (
|
optimizer_cls = MuonOptimizerFactory
|
||||||
Lion(params=self.model.parameters(), **lion_kwargs),
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
None,
|
elif self.cfg.optimizer == "optimi_adamw":
|
||||||
|
from optimi import AdamW
|
||||||
|
|
||||||
|
optimizer_kwargs["foreach"] = False
|
||||||
|
optimizer_cls = AdamW
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "ao_adamw_4bit":
|
||||||
|
# TODO remove 20250401
|
||||||
|
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||||
|
|
||||||
|
optimizer_cls = AdamW4bit
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
|
||||||
|
LOG.warning(
|
||||||
|
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
|
||||||
|
)
|
||||||
|
elif self.cfg.optimizer == "ao_adamw_8bit":
|
||||||
|
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||||
|
|
||||||
|
optimizer_cls = AdamW8bit
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "ao_adamw_fp8":
|
||||||
|
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||||
|
|
||||||
|
optimizer_cls = AdamWFp8
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "adopt_adamw":
|
||||||
|
from axolotl.utils.optimizers.adopt import ADOPT
|
||||||
|
|
||||||
|
optimizer_cls = ADOPT
|
||||||
|
adam_kwargs["decouple"] = True
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
|
||||||
|
# Parse any additional optimizer args from config
|
||||||
|
if self.cfg.optim_args:
|
||||||
|
if isinstance(self.cfg.optim_args, dict):
|
||||||
|
optimizer_kwargs.update(self.cfg.optim_args)
|
||||||
|
else:
|
||||||
|
# Parse string format "key1=value1,key2=value2"
|
||||||
|
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
|
||||||
|
key, value = mapping.split("=")
|
||||||
|
optimizer_kwargs[key] = value
|
||||||
|
|
||||||
|
trainer_kwargs["optimizer_cls_and_kwargs"] = (
|
||||||
|
optimizer_cls,
|
||||||
|
optimizer_kwargs,
|
||||||
)
|
)
|
||||||
# Set default so transformers doesn't throw
|
else:
|
||||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
# Use transformers' optimizer
|
||||||
|
training_arguments_kwargs["optim"] = self.cfg.optimizer
|
||||||
|
|
||||||
|
# Parse any additional optimizer args from config
|
||||||
|
if self.cfg.optim_args:
|
||||||
|
if isinstance(self.cfg.optim_args, dict):
|
||||||
|
optim_args = ",".join(
|
||||||
|
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
optim_args = self.cfg.optim_args
|
||||||
|
training_arguments_kwargs["optim_args"] = optim_args
|
||||||
|
|
||||||
if self.cfg.optimizer == "adamw_anyprecision":
|
if self.cfg.optimizer == "adamw_anyprecision":
|
||||||
if Path(self.cfg.torchdistx_path).exists():
|
if Path(self.cfg.torchdistx_path).exists():
|
||||||
sys.path.append(self.cfg.torchdistx_path)
|
sys.path.append(self.cfg.torchdistx_path)
|
||||||
importlib.import_module("torchdistx")
|
importlib.import_module("torchdistx")
|
||||||
|
|
||||||
|
if self.cfg.optim_target_modules:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"optim_target_modules"
|
||||||
|
] = self.cfg.optim_target_modules
|
||||||
|
|
||||||
|
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||||
|
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||||
|
|
||||||
|
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"loraplus_lr_embedding"
|
||||||
|
] = self.cfg.loraplus_lr_embedding
|
||||||
|
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||||
|
|
||||||
if self.cfg.accelerator_config:
|
if self.cfg.accelerator_config:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"accelerator_config"
|
"accelerator_config"
|
||||||
@@ -876,9 +918,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
|
|
||||||
class HFRLTrainerBuilder(TrainerBuilderBase):
|
class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||||
"""
|
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
|
||||||
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
callbacks = super().get_callbacks()
|
callbacks = super().get_callbacks()
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from typing import Dict, Literal, Optional
|
|||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
@@ -22,9 +23,11 @@ from transformers.utils import is_sagemaker_mp_enabled
|
|||||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BaseOptimizerFactory
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
from axolotl.utils.schedulers import (
|
from axolotl.utils.schedulers import (
|
||||||
|
RexLR,
|
||||||
get_cosine_schedule_with_min_lr,
|
get_cosine_schedule_with_min_lr,
|
||||||
get_cosine_schedule_with_quadratic_warmup,
|
get_cosine_schedule_with_quadratic_warmup,
|
||||||
get_cosine_schedule_with_warmup_decay_constant,
|
get_cosine_schedule_with_warmup_decay_constant,
|
||||||
@@ -115,6 +118,17 @@ class SchedulerMixin(Trainer):
|
|||||||
**extra_lr_kwargs,
|
**extra_lr_kwargs,
|
||||||
**self.args.lr_scheduler_kwargs,
|
**self.args.lr_scheduler_kwargs,
|
||||||
)
|
)
|
||||||
|
elif self.args.alternate_lr_scheduler_type == "rex":
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
|
||||||
|
self.lr_scheduler = RexLR(
|
||||||
|
optimizer=optimizer,
|
||||||
|
max_lr=self.args.learning_rate,
|
||||||
|
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||||
|
total_steps=num_training_steps,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
)
|
||||||
elif use_cosine_quadratic:
|
elif use_cosine_quadratic:
|
||||||
if use_cosine_min_lr:
|
if use_cosine_min_lr:
|
||||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||||
@@ -154,47 +168,18 @@ class SchedulerMixin(Trainer):
|
|||||||
return self.lr_scheduler
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(SchedulerMixin, Trainer):
|
class OptimizerMixin(Trainer):
|
||||||
"""
|
"""
|
||||||
Extend the base Trainer for axolotl helpers
|
Mixin class for shared handling of building custom optimizers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
tag_names = ["axolotl"]
|
|
||||||
|
|
||||||
def __init__(
|
def create_optimizer_grouped_parameters(
|
||||||
self,
|
self, opt_model, optimizer_kwargs
|
||||||
*_args,
|
) -> list[dict]:
|
||||||
bench_data_collator=None,
|
|
||||||
eval_data_collator=None,
|
|
||||||
dataset_tags=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.bench_data_collator = bench_data_collator
|
|
||||||
self.eval_data_collator = eval_data_collator
|
|
||||||
self.dataset_tags = dataset_tags
|
|
||||||
self._signature_columns = None # workaround for pylint
|
|
||||||
super().__init__(*_args, **kwargs)
|
|
||||||
self.train_data_collator = self.data_collator
|
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
|
||||||
if self.args.orpo_alpha:
|
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True, dataloader=None):
|
|
||||||
if self.args.torch_compile:
|
|
||||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
|
||||||
256
|
|
||||||
)
|
|
||||||
model = torch.compile(
|
|
||||||
model,
|
|
||||||
backend=self.args.torch_compile_backend,
|
|
||||||
mode=self.args.torch_compile_mode,
|
|
||||||
)
|
|
||||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
|
||||||
|
|
||||||
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
|
|
||||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
params = {
|
params: dict = {
|
||||||
"to_weight_decay": {}, # LayerNorm and bias
|
"to_weight_decay": {}, # LayerNorm and bias
|
||||||
"embeddings": {}, # lm_head, embed_tokens,
|
"embeddings": {}, # lm_head, embed_tokens,
|
||||||
"no_weight_decay": {},
|
"no_weight_decay": {},
|
||||||
@@ -281,23 +266,30 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
and self.args.embedding_lr_scale is None
|
and self.args.embedding_lr_scale is None
|
||||||
and self.args.embedding_lr is None
|
and self.args.embedding_lr is None
|
||||||
and self.args.lr_groups is None
|
and self.args.lr_groups is None
|
||||||
and self.args.alternate_optimizer
|
and self.optimizer_cls_and_kwargs is None
|
||||||
not in [
|
|
||||||
"optimi_adamw",
|
|
||||||
"ao_adamw_8bit",
|
|
||||||
"ao_adamw_4bit",
|
|
||||||
"ao_adamw_fp8",
|
|
||||||
"adopt_adamw",
|
|
||||||
]
|
|
||||||
):
|
):
|
||||||
return super().create_optimizer()
|
return super().create_optimizer()
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
if (
|
||||||
self.args,
|
not self.optimizer
|
||||||
opt_model,
|
and self.optimizer_cls_and_kwargs is not None
|
||||||
|
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
||||||
|
):
|
||||||
|
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
self.optimizer = optimizer_factory_cls()(
|
||||||
|
opt_model, self.args, **optimizer_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self.optimizer:
|
||||||
|
if self.optimizer_cls_and_kwargs is not None:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
else:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
||||||
|
self.args, opt_model
|
||||||
|
)
|
||||||
|
|
||||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||||
opt_model, optimizer_kwargs
|
opt_model, optimizer_kwargs
|
||||||
)
|
)
|
||||||
@@ -314,50 +306,47 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
**optimizer_kwargs,
|
**optimizer_kwargs,
|
||||||
)
|
)
|
||||||
elif (
|
else:
|
||||||
self.args.embedding_lr_scale is not None
|
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
or self.args.embedding_lr is not None
|
# e.g. for GaLore optimizer.
|
||||||
or self.args.lr_groups is not None
|
if "params" in optimizer_kwargs:
|
||||||
):
|
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
|
||||||
)
|
|
||||||
elif self.args.alternate_optimizer == "optimi_adamw":
|
|
||||||
from optimi import AdamW
|
|
||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
AdamW(
|
# e.g. for LOMO optimizer.
|
||||||
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
|
if "model" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||||
|
|
||||||
|
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||||
|
# to avoid arguments conflicts.
|
||||||
|
if "optimizer_dict" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
||||||
|
"optimizer_dict"
|
||||||
)
|
)
|
||||||
)
|
|
||||||
elif self.args.alternate_optimizer == "ao_adamw_4bit":
|
|
||||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
|
||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = optimizer_cls(
|
||||||
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
optimizer_grouped_parameters, **optimizer_kwargs
|
||||||
)
|
)
|
||||||
elif self.args.alternate_optimizer == "ao_adamw_8bit":
|
|
||||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
|
||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
if optimizer_cls.__name__ == "Adam8bit":
|
||||||
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
import bitsandbytes
|
||||||
)
|
|
||||||
elif self.args.alternate_optimizer == "ao_adamw_fp8":
|
|
||||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
|
||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||||
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
|
||||||
)
|
|
||||||
elif self.args.alternate_optimizer == "adopt_adamw":
|
|
||||||
from axolotl.utils.optimizers.adopt import ADOPT
|
|
||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
skipped = 0
|
||||||
ADOPT(
|
for module in opt_model.modules():
|
||||||
optimizer_grouped_parameters,
|
if isinstance(module, nn.Embedding):
|
||||||
decouple=True,
|
skipped += sum(
|
||||||
**optimizer_kwargs,
|
{
|
||||||
)
|
p.data_ptr(): p.numel() for p in module.parameters()
|
||||||
)
|
}.values()
|
||||||
|
)
|
||||||
|
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||||
|
manager.register_module_override(
|
||||||
|
module, "weight", {"optim_bits": 32}
|
||||||
|
)
|
||||||
|
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||||
|
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
@@ -366,6 +355,45 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
|
|
||||||
return self.optimizer
|
return self.optimizer
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||||
|
"""
|
||||||
|
Extend the base Trainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
tag_names = ["axolotl"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*_args,
|
||||||
|
bench_data_collator=None,
|
||||||
|
eval_data_collator=None,
|
||||||
|
dataset_tags=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.bench_data_collator = bench_data_collator
|
||||||
|
self.eval_data_collator = eval_data_collator
|
||||||
|
self.dataset_tags = dataset_tags
|
||||||
|
self._signature_columns = None # workaround for pylint
|
||||||
|
super().__init__(*_args, **kwargs)
|
||||||
|
self.train_data_collator = self.data_collator
|
||||||
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
|
if self.args.orpo_alpha:
|
||||||
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
|
if self.args.torch_compile:
|
||||||
|
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||||
|
256
|
||||||
|
)
|
||||||
|
model = torch.compile(
|
||||||
|
model,
|
||||||
|
backend=self.args.torch_compile_backend,
|
||||||
|
mode=self.args.torch_compile_mode,
|
||||||
|
)
|
||||||
|
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
if self.args.multipack_real_batches:
|
if self.args.multipack_real_batches:
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import logging
|
|||||||
from trl.trainer.grpo_trainer import RewardFunc
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -31,30 +32,44 @@ class GRPOStrategy:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def set_training_args_kwargs(cls, cfg):
|
def set_training_args_kwargs(cls, cfg):
|
||||||
grpo_args_kwargs = {}
|
grpo_args_kwargs = {}
|
||||||
if cfg.trl and cfg.trl.use_vllm:
|
|
||||||
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
|
if not hasattr(cfg, "trl") or not cfg.trl:
|
||||||
if cfg.trl and cfg.trl.vllm_device:
|
return grpo_args_kwargs
|
||||||
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
|
|
||||||
else:
|
trl: TRLConfig = cfg.trl # type: ignore
|
||||||
grpo_args_kwargs["vllm_device"] = "auto"
|
|
||||||
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
|
if trl.use_vllm:
|
||||||
|
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||||
|
grpo_args_kwargs["vllm_device"] = (
|
||||||
|
trl.vllm_device if trl.vllm_device else "auto"
|
||||||
|
)
|
||||||
|
|
||||||
|
if trl.vllm_gpu_memory_utilization:
|
||||||
grpo_args_kwargs[
|
grpo_args_kwargs[
|
||||||
"vllm_gpu_memory_utilization"
|
"vllm_gpu_memory_utilization"
|
||||||
] = cfg.trl.vllm_gpu_memory_utilization
|
] = trl.vllm_gpu_memory_utilization
|
||||||
if cfg.trl and cfg.trl.vllm_max_model_len:
|
|
||||||
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
|
if trl.vllm_max_model_len:
|
||||||
if cfg.trl and cfg.trl.num_generations:
|
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
|
||||||
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
|
|
||||||
if cfg.trl and cfg.trl.sync_ref_model:
|
if trl.num_generations:
|
||||||
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
|
grpo_args_kwargs["num_generations"] = trl.num_generations
|
||||||
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
|
|
||||||
grpo_args_kwargs[
|
if trl.sync_ref_model:
|
||||||
"ref_model_mixup_alpha"
|
grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model
|
||||||
] = cfg.trl.ref_model_mixup_alpha
|
|
||||||
if cfg.trl and cfg.trl.ref_model_sync_steps:
|
if trl.ref_model_mixup_alpha:
|
||||||
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
|
grpo_args_kwargs["ref_model_mixup_alpha"] = trl.ref_model_mixup_alpha
|
||||||
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
|
|
||||||
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
|
if trl.ref_model_sync_steps:
|
||||||
|
grpo_args_kwargs["ref_model_sync_steps"] = trl.ref_model_sync_steps
|
||||||
|
|
||||||
|
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
|
||||||
|
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||||
|
|
||||||
|
if trl.reward_weights:
|
||||||
|
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||||
|
|
||||||
return grpo_args_kwargs
|
return grpo_args_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import torch
|
|||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.telemetry.errors import send_errors
|
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -62,7 +61,6 @@ def evaluate_dataset(
|
|||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Evaluate a model on training and validation datasets
|
Evaluate a model on training and validation datasets
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ import importlib
|
|||||||
import logging
|
import logging
|
||||||
from typing import OrderedDict
|
from typing import OrderedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class BasePlugin:
|
class BasePlugin:
|
||||||
"""
|
"""
|
||||||
@@ -469,3 +471,14 @@ class PluginManager:
|
|||||||
"""
|
"""
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins.values():
|
||||||
plugin.post_train_unload(cfg)
|
plugin.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseOptimizerFactory:
|
||||||
|
"""
|
||||||
|
Base class for factories to create custom optimizers
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, opt_model, training_args, **optimizer_kwargs
|
||||||
|
) -> "torch.optim.Optimizer":
|
||||||
|
pass
|
||||||
|
|||||||
@@ -4,6 +4,22 @@ Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy o
|
|||||||
|
|
||||||
See https://github.com/apple/ml-cross-entropy
|
See https://github.com/apple/ml-cross-entropy
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
- PyTorch 2.4.0 or higher
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Run the following command to install `cut_cross_entropy[transformers]` if you don't have it already.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# if you are in dev environment
|
||||||
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|
||||||
|
# if you are not in dev environment
|
||||||
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
|
||||||
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
@@ -1,164 +0,0 @@
|
|||||||
"""Trainer callbacks for reporting runtime metrics at regular intervals."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
|
|
||||||
from transformers import (
|
|
||||||
TrainerCallback,
|
|
||||||
TrainerControl,
|
|
||||||
TrainerState,
|
|
||||||
TrainingArguments,
|
|
||||||
)
|
|
||||||
|
|
||||||
from axolotl.telemetry.manager import TelemetryManager
|
|
||||||
from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
TIME_SINCE_LAST = 30
|
|
||||||
|
|
||||||
|
|
||||||
class TelemetryCallback(TrainerCallback):
|
|
||||||
"""
|
|
||||||
Trainer callback for tracking and reporting runtime metrics.
|
|
||||||
|
|
||||||
This callback tracks training progress, runtime, and memory usage,
|
|
||||||
sending telemetry at configurable intervals.
|
|
||||||
"""
|
|
||||||
|
|
||||||
report_interval_steps: int = 100
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Initialize the metrics callback."""
|
|
||||||
self.tracker = RuntimeMetricsTracker()
|
|
||||||
self.telemetry_manager = TelemetryManager.get_instance()
|
|
||||||
self.current_epoch = -1
|
|
||||||
self.start_time = time.time()
|
|
||||||
self.last_report_time = None
|
|
||||||
self.last_report_step = 0
|
|
||||||
|
|
||||||
def on_train_begin(
|
|
||||||
self,
|
|
||||||
args: TrainingArguments,
|
|
||||||
state: TrainerState, # pylint: disable=unused-argument
|
|
||||||
control: TrainerControl, # pylint: disable=unused-argument
|
|
||||||
**kwargs, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
"""Handle training start."""
|
|
||||||
self.telemetry_manager.send_event(event_type="train-started")
|
|
||||||
|
|
||||||
def on_train_end(
|
|
||||||
self,
|
|
||||||
args: TrainingArguments, # pylint: disable=unused-argument
|
|
||||||
state: TrainerState,
|
|
||||||
control: TrainerControl, # pylint: disable=unused-argument
|
|
||||||
**kwargs, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
"""Handle training end."""
|
|
||||||
# Send training completion event
|
|
||||||
self.telemetry_manager.send_event(
|
|
||||||
event_type="train-ended",
|
|
||||||
properties={
|
|
||||||
"loss": state.log_history[-1].get("loss", 0)
|
|
||||||
if state.log_history
|
|
||||||
else None,
|
|
||||||
"learning_rate": state.log_history[-1].get("learning_rate", 0)
|
|
||||||
if state.log_history
|
|
||||||
else None,
|
|
||||||
}
|
|
||||||
| self.tracker.metrics.to_dict(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_epoch_begin(
|
|
||||||
self,
|
|
||||||
args: TrainingArguments, # pylint: disable=unused-argument
|
|
||||||
state: TrainerState, # pylint: disable=unused-argument
|
|
||||||
control: TrainerControl, # pylint: disable=unused-argument
|
|
||||||
**kwargs, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
"""Handle epoch start."""
|
|
||||||
self.current_epoch += 1
|
|
||||||
self.tracker.start_epoch(self.current_epoch)
|
|
||||||
|
|
||||||
def on_epoch_end(
|
|
||||||
self,
|
|
||||||
args: TrainingArguments, # pylint: disable=unused-argument
|
|
||||||
state: TrainerState, # pylint: disable=unused-argument
|
|
||||||
control: TrainerControl, # pylint: disable=unused-argument
|
|
||||||
**kwargs, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
"""Handle epoch end."""
|
|
||||||
self.tracker.end_epoch(self.current_epoch)
|
|
||||||
|
|
||||||
def on_step_end(
|
|
||||||
self,
|
|
||||||
args: TrainingArguments, # pylint: disable=unused-argument
|
|
||||||
state: TrainerState,
|
|
||||||
control: TrainerControl, # pylint: disable=unused-argument
|
|
||||||
**kwargs, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
"""Handle step end."""
|
|
||||||
step = state.global_step
|
|
||||||
self.tracker.update_step(step)
|
|
||||||
|
|
||||||
# Check if we should report metrics
|
|
||||||
should_report = (
|
|
||||||
step % self.report_interval_steps == 0
|
|
||||||
or step == 1 # Always report first step
|
|
||||||
or step - self.last_report_step >= self.report_interval_steps
|
|
||||||
)
|
|
||||||
|
|
||||||
if should_report:
|
|
||||||
current_time = time.time()
|
|
||||||
if self.last_report_time is not None:
|
|
||||||
time_since_last_report = current_time - self.last_report_time
|
|
||||||
else:
|
|
||||||
time_since_last_report = current_time - self.start_time
|
|
||||||
steps_since_last_report = step - self.last_report_step
|
|
||||||
|
|
||||||
# Only report if enough time has passed to avoid flooding
|
|
||||||
if (
|
|
||||||
step == 1
|
|
||||||
or time_since_last_report >= TIME_SINCE_LAST
|
|
||||||
or steps_since_last_report >= self.report_interval_steps
|
|
||||||
):
|
|
||||||
# Calculate steps per second for this interval
|
|
||||||
if time_since_last_report > 0 and steps_since_last_report > 0:
|
|
||||||
steps_per_second = steps_since_last_report / time_since_last_report
|
|
||||||
else:
|
|
||||||
steps_per_second = 0
|
|
||||||
|
|
||||||
# Update memory metrics
|
|
||||||
self.tracker.update_memory_metrics()
|
|
||||||
|
|
||||||
loss = state.log_history[-1].get("loss", 0) if state.log_history else 0
|
|
||||||
learning_rate = (
|
|
||||||
state.log_history[-1].get("learning_rate", 0)
|
|
||||||
if state.log_history
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare metrics to report
|
|
||||||
metrics = {
|
|
||||||
"step": step,
|
|
||||||
"epoch": self.current_epoch,
|
|
||||||
"progress": state.epoch, # Fractional epoch progress
|
|
||||||
"loss": loss,
|
|
||||||
"learning_rate": learning_rate,
|
|
||||||
"steps_per_second": steps_per_second,
|
|
||||||
"elapsed_time": current_time - self.start_time,
|
|
||||||
"time_since_last_report": time_since_last_report,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add memory metrics
|
|
||||||
memory_metrics = self.tracker.get_memory_metrics()
|
|
||||||
metrics.update({"memory": memory_metrics})
|
|
||||||
|
|
||||||
# Send telemetry
|
|
||||||
self.telemetry_manager.send_event(
|
|
||||||
event_type="train-progress", properties=metrics
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update last report time and step
|
|
||||||
self.last_report_time = current_time
|
|
||||||
self.last_report_step = step
|
|
||||||
@@ -1,160 +0,0 @@
|
|||||||
"""Telemetry utilities for exception and traceback information."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import traceback
|
|
||||||
from functools import wraps
|
|
||||||
from inspect import getmodule
|
|
||||||
from typing import Any, Callable
|
|
||||||
|
|
||||||
from axolotl.telemetry.manager import TelemetryManager
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
ERROR_HANDLED = False
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_stack_trace(stack_trace: str) -> str:
|
|
||||||
"""
|
|
||||||
Remove personal information from stack trace messages while keeping Python package codepaths.
|
|
||||||
|
|
||||||
This function identifies Python packages by looking for common patterns in virtual environment
|
|
||||||
and site-packages directories, preserving the package path while removing user-specific paths.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stack_trace: The original stack trace string.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A sanitized version of the stack trace with Python package paths preserved.
|
|
||||||
"""
|
|
||||||
# Split the stack trace into lines to process each file path separately
|
|
||||||
lines = stack_trace.split("\n")
|
|
||||||
sanitized_lines = []
|
|
||||||
|
|
||||||
# Regular expression to find file paths in the stack trace
|
|
||||||
path_pattern = re.compile(r'(?:File ")(.*?)(?:")')
|
|
||||||
|
|
||||||
# Regular expression to identify paths in site-packages or dist-packages
|
|
||||||
# This matches path segments like "site-packages/package_name" or "dist-packages/package_name"
|
|
||||||
site_packages_pattern = re.compile(
|
|
||||||
r"(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Additional common virtual environment patterns
|
|
||||||
venv_lib_pattern = re.compile(
|
|
||||||
r"(?:lib|Lib)[/\\](?:python\d+(?:\.\d+)?[/\\])?(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
|
|
||||||
)
|
|
||||||
|
|
||||||
for line in lines:
|
|
||||||
# Check if this line contains a file path
|
|
||||||
path_match = path_pattern.search(line)
|
|
||||||
|
|
||||||
if path_match:
|
|
||||||
full_path = path_match.group(1)
|
|
||||||
sanitized_path = ""
|
|
||||||
|
|
||||||
# Try to match site-packages pattern
|
|
||||||
site_packages_match = site_packages_pattern.search(full_path)
|
|
||||||
venv_lib_match = venv_lib_pattern.search(full_path)
|
|
||||||
|
|
||||||
if site_packages_match:
|
|
||||||
# Find the index where the matched pattern starts
|
|
||||||
idx = full_path.find("site-packages")
|
|
||||||
if idx == -1:
|
|
||||||
idx = full_path.find("dist-packages")
|
|
||||||
|
|
||||||
# Keep from 'site-packages' onward
|
|
||||||
if idx >= 0:
|
|
||||||
sanitized_path = full_path[idx:]
|
|
||||||
elif venv_lib_match:
|
|
||||||
# For other virtual environment patterns, find the package directory
|
|
||||||
match_idx = venv_lib_match.start(1)
|
|
||||||
if match_idx > 0:
|
|
||||||
# Keep from the package name onward
|
|
||||||
package_name = venv_lib_match.group(1)
|
|
||||||
idx = full_path.rfind(
|
|
||||||
package_name, 0, match_idx + len(package_name)
|
|
||||||
)
|
|
||||||
if idx >= 0:
|
|
||||||
sanitized_path = full_path[idx:]
|
|
||||||
|
|
||||||
# If we couldn't identify a package pattern but path contains 'axolotl'
|
|
||||||
elif "axolotl" in full_path:
|
|
||||||
idx = full_path.rfind("axolotl")
|
|
||||||
if idx >= 0:
|
|
||||||
sanitized_path = full_path[idx:]
|
|
||||||
|
|
||||||
# Apply the sanitization to the line
|
|
||||||
if sanitized_path:
|
|
||||||
line = line.replace(full_path, sanitized_path)
|
|
||||||
else:
|
|
||||||
# If we couldn't identify a package pattern, just keep the filename
|
|
||||||
filename = os.path.basename(full_path)
|
|
||||||
if filename:
|
|
||||||
line = line.replace(full_path, filename)
|
|
||||||
else:
|
|
||||||
line = line.replace(full_path, "")
|
|
||||||
|
|
||||||
sanitized_lines.append(line)
|
|
||||||
|
|
||||||
return "\n".join(sanitized_lines)
|
|
||||||
|
|
||||||
|
|
||||||
def send_errors(func: Callable) -> Callable:
|
|
||||||
"""
|
|
||||||
Decorator to send exception info in a function. If an exception is raised, we send
|
|
||||||
telemetry containing the stack trace and error message.
|
|
||||||
|
|
||||||
If an error occurs in a decorated function that is called by another decorated
|
|
||||||
function, we'll only send telemetry corresponding to the lower-level function.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func: Function to decorate.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Decorated function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(*args, **kwargs) -> Any:
|
|
||||||
telemetry_manager = TelemetryManager.get_instance()
|
|
||||||
|
|
||||||
if not telemetry_manager.enabled:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
try:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
except Exception as exception:
|
|
||||||
# Only track if we're not already handling an error. This prevents us from
|
|
||||||
# capturing an error more than once in nested decorated function calls.=
|
|
||||||
global ERROR_HANDLED # pylint: disable=global-statement
|
|
||||||
if not ERROR_HANDLED:
|
|
||||||
ERROR_HANDLED = True
|
|
||||||
|
|
||||||
# Get function module path
|
|
||||||
module = getmodule(func)
|
|
||||||
module_path = (
|
|
||||||
f"{module.__name__}.{func.__name__}" if module else func.__name__
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get stack trace
|
|
||||||
stack_trace = "".join(
|
|
||||||
traceback.format_exception(
|
|
||||||
type(exception), exception, exception.__traceback__
|
|
||||||
)
|
|
||||||
)
|
|
||||||
stack_trace = sanitize_stack_trace(stack_trace)
|
|
||||||
|
|
||||||
# Send error telemetry
|
|
||||||
telemetry_manager.send_event(
|
|
||||||
event_type=f"{module_path}-errored",
|
|
||||||
properties={
|
|
||||||
"exception": str(exception),
|
|
||||||
"stack_trace": stack_trace,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
raise
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
@@ -1,399 +0,0 @@
|
|||||||
"""Telemetry manager and associated utilities."""
|
|
||||||
|
|
||||||
import atexit
|
|
||||||
import importlib
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import posthog
|
|
||||||
import psutil
|
|
||||||
import torch
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
POSTHOG_HOST = "https://app.posthog.com"
|
|
||||||
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
|
|
||||||
|
|
||||||
ENABLED_WARNING_SLEEP_SECONDS = 15
|
|
||||||
ENABLED_WARNING = (
|
|
||||||
"\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n"
|
|
||||||
"- Which models and configurations are most commonly used\n"
|
|
||||||
"- What hardware setups need to be supported\n"
|
|
||||||
"- Where users encounter errors\n\n"
|
|
||||||
"This data helps us prioritize features, optimize performance, and fix bugs.\n\n"
|
|
||||||
"To disable telemetry, set either:\n"
|
|
||||||
"- AXOLOTL_DO_NOT_TRACK=1 (Axolotl-specific)\n"
|
|
||||||
"- DO_NOT_TRACK=1 (Global standard; see https://consoledonottrack.com/)\n\n"
|
|
||||||
"To remove this warning and continue with telemetry enabled,"
|
|
||||||
"explicitly set AXOLOTL_DO_NOT_TRACK=0 (and leave DO_NOT_TRACK unset / set to 0)\n\n"
|
|
||||||
"No personally identifiable information is collected."
|
|
||||||
"For details, see: https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html\n\n"
|
|
||||||
f"Sleeping for {ENABLED_WARNING_SLEEP_SECONDS}s..."
|
|
||||||
)
|
|
||||||
|
|
||||||
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
|
|
||||||
|
|
||||||
# NOTE: Keep these up to date with any config schema changes
|
|
||||||
FIELDS_WITH_ORGS = {
|
|
||||||
"base_model",
|
|
||||||
"tokenizer_config",
|
|
||||||
"base_model_config",
|
|
||||||
"pretraining_dataset", # NOTE: this field may be a string or a dictionary
|
|
||||||
}
|
|
||||||
FIELDS_TO_REDACT = {"resume_from_checkpoint", "hub_model_id"}
|
|
||||||
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"}
|
|
||||||
PATH_INDICATORS = {"path", "dir"}
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
RELEVANT_PACKAGES = {
|
|
||||||
"torch",
|
|
||||||
"transformers",
|
|
||||||
"trl",
|
|
||||||
"datasets",
|
|
||||||
"peft",
|
|
||||||
"bitsandbytes",
|
|
||||||
"accelerate",
|
|
||||||
"optimum",
|
|
||||||
"deepspeed",
|
|
||||||
"ray",
|
|
||||||
"axolotl",
|
|
||||||
"triton",
|
|
||||||
"mamba-ssm",
|
|
||||||
"flash-attn",
|
|
||||||
"xformers",
|
|
||||||
"autoawq",
|
|
||||||
"tokenizers",
|
|
||||||
"sentencepiece",
|
|
||||||
"torchao",
|
|
||||||
"lm_eval",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def is_main_process() -> bool:
|
|
||||||
"""
|
|
||||||
Check whether we're running in the main process.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
We're using this function instead of `torch.utils.distributed.is_main_process`
|
|
||||||
causes issues with DeepSpeed world_size since. This function avoids that issue
|
|
||||||
by checking env vars that are set by various launchers.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Whether we're running in the main process.
|
|
||||||
"""
|
|
||||||
# If PyTorch distributed is already initialized, use it
|
|
||||||
if torch.distributed.is_initialized():
|
|
||||||
return torch.distributed.get_rank() == 0
|
|
||||||
|
|
||||||
# Otherwise check environment variables for global rank
|
|
||||||
# NOTE: need to verify this in SLURM / OpenMPI environments
|
|
||||||
global_rank = int(
|
|
||||||
os.environ.get(
|
|
||||||
"RANK",
|
|
||||||
os.environ.get(
|
|
||||||
"GLOBAL_RANK",
|
|
||||||
os.environ.get(
|
|
||||||
"SLURM_PROCID",
|
|
||||||
os.environ.get(
|
|
||||||
"OMPI_COMM_WORLD_RANK",
|
|
||||||
"0",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return global_rank == 0
|
|
||||||
|
|
||||||
|
|
||||||
class TelemetryManager:
|
|
||||||
"""Manages telemetry collection and transmission"""
|
|
||||||
|
|
||||||
_instance = None
|
|
||||||
_initialized = False
|
|
||||||
|
|
||||||
def __new__(cls):
|
|
||||||
"""
|
|
||||||
Telemetry manager constructor. Creates the singleton instance of this class if
|
|
||||||
it doesn't already exist.
|
|
||||||
"""
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super(TelemetryManager, cls).__new__(cls)
|
|
||||||
cls._instance._initialized = False
|
|
||||||
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Telemetry manager initializer"""
|
|
||||||
if self._initialized:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.enabled, self.explicit_enable = self._check_telemetry_enabled()
|
|
||||||
|
|
||||||
if self.enabled:
|
|
||||||
self.run_id = str(uuid.uuid4())
|
|
||||||
self.whitelist = self._load_whitelist()
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.system_info = self._get_system_info()
|
|
||||||
except Exception as e: # pylint: disable=broad-exception-caught
|
|
||||||
LOG.warning(f"Error during system info collection: {e}")
|
|
||||||
self.system_info = None
|
|
||||||
|
|
||||||
self._init_posthog()
|
|
||||||
|
|
||||||
# Register shutdown method to flush posthog telemetry
|
|
||||||
atexit.register(self.shutdown)
|
|
||||||
|
|
||||||
self._initialized = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_instance(cls) -> "TelemetryManager":
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = TelemetryManager()
|
|
||||||
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def _check_telemetry_enabled(self) -> tuple[bool, bool]:
|
|
||||||
"""
|
|
||||||
Check if telemetry is enabled based on environment variables. We also check
|
|
||||||
whether this is the main process (for the distributed setting and to avoid
|
|
||||||
sending duplicate PostHog events per GPU).
|
|
||||||
|
|
||||||
Note: This is enabled by default on an opt-out basis. Set either
|
|
||||||
`AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1` to disable telemetry. For more
|
|
||||||
details, see https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple containing:
|
|
||||||
- Boolean denoting whether telemetry is enabled or disabled.
|
|
||||||
- Boolean denoting whether telemetry is explicitly enabled or not.
|
|
||||||
"""
|
|
||||||
# Parse relevant env vars and fill opt-out default values
|
|
||||||
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
|
|
||||||
do_not_track = os.getenv("DO_NOT_TRACK")
|
|
||||||
|
|
||||||
# If explicitly enabled, we'll disable the telemetry warning message
|
|
||||||
explicit_enabled = axolotl_do_not_track in ["0", "false"]
|
|
||||||
|
|
||||||
if axolotl_do_not_track is None:
|
|
||||||
axolotl_do_not_track = "0"
|
|
||||||
|
|
||||||
if do_not_track is None:
|
|
||||||
do_not_track = "0"
|
|
||||||
|
|
||||||
# Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
|
|
||||||
enabled = axolotl_do_not_track.lower() not in (
|
|
||||||
"1",
|
|
||||||
"true",
|
|
||||||
) and do_not_track.lower() not in ("1", "true")
|
|
||||||
|
|
||||||
# Show warning (and sleep on all ranks) unless explicitly enabled
|
|
||||||
if enabled and not explicit_enabled:
|
|
||||||
if is_main_process():
|
|
||||||
LOG.warning(ENABLED_WARNING)
|
|
||||||
time.sleep(ENABLED_WARNING_SLEEP_SECONDS)
|
|
||||||
|
|
||||||
# Only rank 0 will send telemetry
|
|
||||||
if not is_main_process():
|
|
||||||
return False, False
|
|
||||||
|
|
||||||
return enabled, explicit_enabled
|
|
||||||
|
|
||||||
def _load_whitelist(self) -> dict:
|
|
||||||
"""Load HuggingFace Hub organization whitelist"""
|
|
||||||
with open(WHITELIST_PATH, encoding="utf-8") as f:
|
|
||||||
return yaml.safe_load(f)
|
|
||||||
|
|
||||||
def _is_whitelisted(self, base_model: str) -> bool:
|
|
||||||
"""Check if model org is in whitelist"""
|
|
||||||
if not base_model:
|
|
||||||
return False
|
|
||||||
|
|
||||||
base_model = base_model.lower()
|
|
||||||
return any(
|
|
||||||
org.lower() in base_model for org in self.whitelist.get("organizations", [])
|
|
||||||
)
|
|
||||||
|
|
||||||
def _init_posthog(self):
|
|
||||||
"""Initialize PostHog client"""
|
|
||||||
posthog.host = POSTHOG_HOST
|
|
||||||
posthog.project_api_key = POSTHOG_WRITE_KEY
|
|
||||||
|
|
||||||
def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Redact properties to remove any paths, so as to avoid inadvertently collecting
|
|
||||||
private or personally identifiable information (PII). We also remove
|
|
||||||
information related to Wandb, MLflow, etc. configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
properties: Dictionary of properties to redact.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Properties dictionary with redaction applied.
|
|
||||||
"""
|
|
||||||
if not properties:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def redact_value(value: Any, key: str = "") -> Any:
|
|
||||||
"""Recursively sanitize values, redacting those with path-like keys"""
|
|
||||||
if isinstance(key, str) and isinstance(value, str):
|
|
||||||
# Fields that should be redacted if org is not whitelisted
|
|
||||||
if key in FIELDS_WITH_ORGS:
|
|
||||||
org = value.split("/")[0]
|
|
||||||
if org not in self.whitelist["organizations"]:
|
|
||||||
return "[REDACTED]"
|
|
||||||
|
|
||||||
# Other redaction special cases
|
|
||||||
if (
|
|
||||||
key in FIELDS_TO_REDACT
|
|
||||||
or any(prefix in key for prefix in PREFIXES_TO_REDACT)
|
|
||||||
or any(indicator in key.lower() for indicator in PATH_INDICATORS)
|
|
||||||
):
|
|
||||||
return "[REDACTED]"
|
|
||||||
|
|
||||||
# Handle nested structures
|
|
||||||
if isinstance(value, dict):
|
|
||||||
return {k: redact_value(v, k) for k, v in value.items()}
|
|
||||||
if isinstance(value, list):
|
|
||||||
return [redact_value(item) for item in value]
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
# Create new dict with redacted values
|
|
||||||
redacted = {k: redact_value(v, k) for k, v in properties.items()}
|
|
||||||
|
|
||||||
return redacted
|
|
||||||
|
|
||||||
def _get_system_info(self) -> dict[str, Any]:
|
|
||||||
"""Collect system information for various hardware accelerators"""
|
|
||||||
gpu_info = []
|
|
||||||
accelerator_type = "none"
|
|
||||||
|
|
||||||
# NVIDIA GPUs
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
accelerator_type = "cuda"
|
|
||||||
for i in range(torch.cuda.device_count()):
|
|
||||||
gpu_info.append(
|
|
||||||
{
|
|
||||||
"name": torch.cuda.get_device_name(i),
|
|
||||||
"memory": torch.cuda.get_device_properties(i).total_memory,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# AMD GPUs
|
|
||||||
elif hasattr(torch, "hip") and torch.hip.is_available():
|
|
||||||
accelerator_type = "hip"
|
|
||||||
for i in range(torch.hip.device_count()):
|
|
||||||
gpu_info.append(
|
|
||||||
{
|
|
||||||
"name": torch.hip.get_device_name(i),
|
|
||||||
"memory": torch.hip.get_device_properties(i).total_memory
|
|
||||||
if hasattr(torch.hip, "get_device_properties")
|
|
||||||
else None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apple Silicon
|
|
||||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
|
||||||
accelerator_type = "mps"
|
|
||||||
gpu_info.append(
|
|
||||||
{
|
|
||||||
"name": "Apple Silicon",
|
|
||||||
# NOTE: this is memory allocated to this process, not total memory
|
|
||||||
"memory": torch.mps.driver_allocated_memory(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Intel GPUs
|
|
||||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
accelerator_type = "xpu"
|
|
||||||
for i in range(torch.xpu.device_count()):
|
|
||||||
memory = None
|
|
||||||
if hasattr(torch.xpu, "get_device_properties"):
|
|
||||||
memory = torch.xpu.get_device_properties(i).total_memory
|
|
||||||
|
|
||||||
gpu_info.append(
|
|
||||||
{
|
|
||||||
"name": torch.xpu.get_device_name(i),
|
|
||||||
"memory": memory,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# NPUs
|
|
||||||
elif hasattr(torch, "npu") and torch.npu.is_available():
|
|
||||||
accelerator_type = "npu"
|
|
||||||
for i in range(torch.npu.device_count()):
|
|
||||||
memory = None
|
|
||||||
if hasattr(torch.npu, "get_device_properties"):
|
|
||||||
memory = torch.npu.get_device_properties(i).total_memory
|
|
||||||
|
|
||||||
gpu_info.append(
|
|
||||||
{
|
|
||||||
"name": torch.npu.get_device_name(i),
|
|
||||||
"memory": memory,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get relevant package versions
|
|
||||||
installed_packages = {}
|
|
||||||
for package in RELEVANT_PACKAGES:
|
|
||||||
try:
|
|
||||||
version = importlib.metadata.version(package)
|
|
||||||
installed_packages[f"{package}_version"] = version
|
|
||||||
except importlib.metadata.PackageNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return {
|
|
||||||
"os": platform.system(),
|
|
||||||
"python_version": platform.python_version(),
|
|
||||||
"cpu_count": psutil.cpu_count(),
|
|
||||||
"memory_total": psutil.virtual_memory().total,
|
|
||||||
"accelerator_type": accelerator_type,
|
|
||||||
"accelerator_count": len(gpu_info),
|
|
||||||
"accelerator_info": gpu_info,
|
|
||||||
**installed_packages,
|
|
||||||
}
|
|
||||||
|
|
||||||
def send_event(self, event_type: str, properties: dict[str, Any] | None = None):
|
|
||||||
"""Send a telemetry event"""
|
|
||||||
if not self.enabled:
|
|
||||||
return
|
|
||||||
|
|
||||||
if properties is None:
|
|
||||||
properties = {}
|
|
||||||
|
|
||||||
# Sanitize properties to remove PII
|
|
||||||
properties = self._redact_paths(properties)
|
|
||||||
|
|
||||||
# Wrap PostHog errors in try / except to not raise errors during Axolotl usage
|
|
||||||
try:
|
|
||||||
# Send event via PostHog
|
|
||||||
posthog.capture(
|
|
||||||
distinct_id=self.run_id,
|
|
||||||
event=event_type,
|
|
||||||
properties=properties,
|
|
||||||
disable_geoip=True,
|
|
||||||
)
|
|
||||||
except Exception as e: # pylint: disable=broad-exception-caught
|
|
||||||
LOG.warning(f"Failed to send telemetry event: {e}")
|
|
||||||
|
|
||||||
# Additionally, send system info telemetry when loading config.
|
|
||||||
# NOTE: Is this the best place for this?
|
|
||||||
if event_type == "config-loaded":
|
|
||||||
self.send_system_info()
|
|
||||||
|
|
||||||
def send_system_info(self):
|
|
||||||
"""Helper method for sending system info"""
|
|
||||||
self.send_event(event_type="system-info", properties=self.system_info)
|
|
||||||
|
|
||||||
def shutdown(self):
|
|
||||||
"""Ensure all queued events are processed before shutdown"""
|
|
||||||
if self.enabled:
|
|
||||||
posthog.flush()
|
|
||||||
@@ -1,209 +0,0 @@
|
|||||||
"""Telemetry utilities for runtime and memory metrics."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import psutil
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from axolotl.telemetry.manager import TelemetryManager
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RuntimeMetrics:
|
|
||||||
"""Container for runtime metrics to be tracked throughout training."""
|
|
||||||
|
|
||||||
# Timing metrics
|
|
||||||
start_time: float
|
|
||||||
epoch_start_times: dict[int, float] = field(init=False)
|
|
||||||
epoch_end_times: dict[int, float] = field(init=False)
|
|
||||||
|
|
||||||
# Memory metrics
|
|
||||||
peak_cpu_memory: int = 0
|
|
||||||
peak_gpu_memory: dict[int, int] = field(init=False)
|
|
||||||
|
|
||||||
# Progress metrics
|
|
||||||
total_steps: int = 0
|
|
||||||
current_epoch: int = 0
|
|
||||||
current_step: int = 0
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
"""Initialize empty metric mappings."""
|
|
||||||
self.epoch_start_times = {}
|
|
||||||
self.epoch_end_times = {}
|
|
||||||
self.peak_gpu_memory = {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def elapsed_time(self) -> float:
|
|
||||||
"""Calculate total elapsed time in seconds."""
|
|
||||||
return time.time() - self.start_time
|
|
||||||
|
|
||||||
def epoch_time(self, epoch: int) -> float | None:
|
|
||||||
"""Calculate time taken for a specific epoch in seconds."""
|
|
||||||
if epoch in self.epoch_start_times and epoch in self.epoch_end_times:
|
|
||||||
return self.epoch_end_times[epoch] - self.epoch_start_times[epoch]
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def average_epoch_time(self) -> float | None:
|
|
||||||
"""Calculate average time per epoch in seconds."""
|
|
||||||
completed_epochs = [
|
|
||||||
epoch for epoch in self.epoch_start_times if epoch in self.epoch_end_times
|
|
||||||
]
|
|
||||||
if not completed_epochs:
|
|
||||||
return None
|
|
||||||
|
|
||||||
total_time = 0.0
|
|
||||||
for epoch in completed_epochs:
|
|
||||||
epoch_time = self.epoch_time(epoch)
|
|
||||||
if epoch_time is not None: # Check to avoid mypy warning
|
|
||||||
total_time += epoch_time
|
|
||||||
|
|
||||||
return total_time / len(completed_epochs)
|
|
||||||
|
|
||||||
def steps_per_second(self) -> float | None:
|
|
||||||
"""Calculate average steps per second across all training."""
|
|
||||||
if self.total_steps == 0 or self.elapsed_time == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return self.total_steps / self.elapsed_time
|
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
|
||||||
"""Convert metrics to a dictionary for telemetry reporting."""
|
|
||||||
metrics = {
|
|
||||||
"total_time_seconds": self.elapsed_time,
|
|
||||||
"total_steps": self.total_steps,
|
|
||||||
"steps_per_second": self.steps_per_second(),
|
|
||||||
"epochs_completed": len(
|
|
||||||
[
|
|
||||||
epoch
|
|
||||||
for epoch in self.epoch_start_times
|
|
||||||
if epoch in self.epoch_end_times
|
|
||||||
]
|
|
||||||
),
|
|
||||||
"peak_cpu_memory_bytes": self.peak_cpu_memory,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add per-epoch timing if available
|
|
||||||
epoch_times: dict[str, float] = {}
|
|
||||||
for epoch in sorted(self.epoch_end_times.keys()):
|
|
||||||
time_taken = self.epoch_time(epoch)
|
|
||||||
if time_taken is not None:
|
|
||||||
epoch_times[f"epoch_{epoch}_seconds"] = time_taken
|
|
||||||
|
|
||||||
if epoch_times:
|
|
||||||
metrics["epoch_times"] = epoch_times # type: ignore
|
|
||||||
metrics["average_epoch_time_seconds"] = self.average_epoch_time()
|
|
||||||
|
|
||||||
# Add GPU memory metrics if available
|
|
||||||
if self.peak_gpu_memory:
|
|
||||||
gpu_metrics: dict[str, int] = {}
|
|
||||||
for gpu_id, memory in self.peak_gpu_memory.items():
|
|
||||||
gpu_metrics[f"gpu_{gpu_id}_peak_memory_bytes"] = memory
|
|
||||||
metrics["gpu_memory"] = gpu_metrics # type: ignore
|
|
||||||
|
|
||||||
return metrics
|
|
||||||
|
|
||||||
|
|
||||||
class RuntimeMetricsTracker:
|
|
||||||
"""Tracker for runtime metrics during training."""
|
|
||||||
|
|
||||||
update_interval = 100
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Initialize the runtime metrics tracker."""
|
|
||||||
self.metrics = RuntimeMetrics(start_time=time.time())
|
|
||||||
self.telemetry_manager = TelemetryManager.get_instance()
|
|
||||||
|
|
||||||
def start_epoch(self, epoch: int):
|
|
||||||
"""Record the start of a new epoch."""
|
|
||||||
self.metrics.current_epoch = epoch
|
|
||||||
self.metrics.epoch_start_times[epoch] = time.time()
|
|
||||||
self.update_memory_metrics()
|
|
||||||
|
|
||||||
def end_epoch(self, epoch: int):
|
|
||||||
"""Record the end of an epoch."""
|
|
||||||
self.metrics.epoch_end_times[epoch] = time.time()
|
|
||||||
|
|
||||||
def update_step(self, step: int):
|
|
||||||
"""Update the current step count."""
|
|
||||||
self.metrics.current_step = step
|
|
||||||
self.metrics.total_steps += 1
|
|
||||||
|
|
||||||
# Periodically update memory metrics
|
|
||||||
if step % self.update_interval == 0:
|
|
||||||
self.update_memory_metrics()
|
|
||||||
|
|
||||||
def _get_allocated_memory(self) -> dict[int, int]:
|
|
||||||
"""
|
|
||||||
Helper function for getting accelerator-agnostic allocated memory.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary mapping device IDs to allocated memory in bytes
|
|
||||||
"""
|
|
||||||
memory_used: dict[int, int] = {}
|
|
||||||
|
|
||||||
# NVIDIA GPUs
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
for i in range(torch.cuda.device_count()):
|
|
||||||
memory_used[i] = torch.cuda.memory_allocated(i)
|
|
||||||
|
|
||||||
# AMD GPUs
|
|
||||||
elif hasattr(torch, "hip") and torch.hip.is_available():
|
|
||||||
for i in range(torch.hip.device_count()):
|
|
||||||
if hasattr(torch.hip, "memory_allocated"):
|
|
||||||
memory_used[i] = torch.hip.memory_allocated(i)
|
|
||||||
|
|
||||||
# Apple Silicon
|
|
||||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
|
||||||
# MPS doesn't have per-device memory stats since there's only one device
|
|
||||||
if hasattr(torch.mps, "current_allocated_memory"):
|
|
||||||
memory_used[0] = torch.mps.current_allocated_memory()
|
|
||||||
|
|
||||||
# Intel GPUs
|
|
||||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
|
||||||
for i in range(torch.xpu.device_count()):
|
|
||||||
if hasattr(torch.xpu, "memory_allocated"):
|
|
||||||
memory_used[i] = torch.xpu.memory_allocated(i)
|
|
||||||
|
|
||||||
# NPUs
|
|
||||||
elif hasattr(torch, "npu") and torch.npu.is_available():
|
|
||||||
for i in range(torch.npu.device_count()):
|
|
||||||
if hasattr(torch.npu, "memory_allocated"):
|
|
||||||
memory_used[i] = torch.npu.memory_allocated(i)
|
|
||||||
|
|
||||||
return memory_used
|
|
||||||
|
|
||||||
def update_memory_metrics(self):
|
|
||||||
"""Update peak memory usage metrics."""
|
|
||||||
# CPU memory
|
|
||||||
cpu_memory = psutil.Process().memory_info().rss
|
|
||||||
self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory)
|
|
||||||
|
|
||||||
# GPU memory (if available)
|
|
||||||
memory_used = self._get_allocated_memory()
|
|
||||||
for i, memory in memory_used.items():
|
|
||||||
self.metrics.peak_gpu_memory[i] = max(
|
|
||||||
self.metrics.peak_gpu_memory.get(i, 0), memory
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_memory_metrics(self) -> dict[str, Any]:
|
|
||||||
"""Get the current memory metrics as a dictionary."""
|
|
||||||
memory_metrics = {
|
|
||||||
"cpu_memory_bytes": psutil.Process().memory_info().rss,
|
|
||||||
"peak_cpu_memory_bytes": self.metrics.peak_cpu_memory,
|
|
||||||
}
|
|
||||||
|
|
||||||
# GPU memory (if available)
|
|
||||||
memory_used = self._get_allocated_memory()
|
|
||||||
for i, memory in memory_used.items():
|
|
||||||
memory_metrics[f"gpu_{i}_memory_bytes"] = memory
|
|
||||||
memory_metrics[
|
|
||||||
f"gpu_{i}_peak_memory_bytes"
|
|
||||||
] = self.metrics.peak_gpu_memory.get(i, 0)
|
|
||||||
|
|
||||||
return memory_metrics
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
organizations:
|
|
||||||
- "axolotl-ai-co"
|
|
||||||
- "meta-llama"
|
|
||||||
- "huggingface"
|
|
||||||
- "nvidia"
|
|
||||||
- "facebook"
|
|
||||||
- "google"
|
|
||||||
- "microsoft"
|
|
||||||
- "deepseek-ai"
|
|
||||||
- "HuggingFaceTB"
|
|
||||||
- "mistralai"
|
|
||||||
- "Qwen"
|
|
||||||
- "briaai"
|
|
||||||
- "unsloth"
|
|
||||||
- "NousResearch"
|
|
||||||
- "allenai"
|
|
||||||
- "amd"
|
|
||||||
- "tiiuae"
|
|
||||||
@@ -7,23 +7,24 @@ import signal
|
|||||||
import sys
|
import sys
|
||||||
import weakref
|
import weakref
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple, Union
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers.modelcard
|
import transformers.modelcard
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import save_fsdp_model
|
from accelerate.utils import save_fsdp_model
|
||||||
from peft import PeftModel
|
from datasets import Dataset
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
from peft import PeftConfig, PeftModel
|
||||||
|
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.common.datasets import TrainDatasetMeta
|
from axolotl.common.datasets import TrainDatasetMeta
|
||||||
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
)
|
)
|
||||||
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.telemetry.errors import send_errors
|
|
||||||
from axolotl.telemetry.manager import TelemetryManager
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.freeze import freeze_layers_except
|
from axolotl.utils.freeze import freeze_layers_except
|
||||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||||
@@ -34,20 +35,25 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
BetterTransformer = None
|
BetterTransformer = None
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|
||||||
src_dir = os.path.join(project_root, "src")
|
|
||||||
sys.path.insert(0, src_dir)
|
|
||||||
|
|
||||||
configure_logging()
|
configure_logging()
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
|
||||||
|
|
||||||
|
def setup_model_and_tokenizer(
|
||||||
|
cfg: DictDefault,
|
||||||
|
) -> tuple[
|
||||||
|
PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
Load the tokenizer, processor (for multimodal models), and model based on configuration.
|
||||||
|
|
||||||
@send_errors
|
Args:
|
||||||
def train(
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
|
||||||
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
Returns:
|
||||||
|
Tuple containing model, tokenizer, `peft_config` (if LoRA / QLoRA, else
|
||||||
|
`None`), and processor (if multimodal, else `None`).
|
||||||
|
"""
|
||||||
# Load tokenizer
|
# Load tokenizer
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||||
@@ -60,11 +66,58 @@ def train(
|
|||||||
if cfg.is_multimodal:
|
if cfg.is_multimodal:
|
||||||
processor = load_processor(cfg, tokenizer)
|
processor = load_processor(cfg, tokenizer)
|
||||||
|
|
||||||
# Get datasets
|
# Load the model and peft_config
|
||||||
train_dataset = dataset_meta.train_dataset
|
msg = "loading model"
|
||||||
eval_dataset = dataset_meta.eval_dataset
|
if cfg.adapter:
|
||||||
total_num_steps = dataset_meta.total_num_steps
|
msg += " and peft_config..."
|
||||||
|
LOG.debug(msg)
|
||||||
|
|
||||||
|
model, peft_config = load_model(cfg, tokenizer, processor=processor)
|
||||||
|
if model.generation_config is not None:
|
||||||
|
model.generation_config.do_sample = True
|
||||||
|
|
||||||
|
# Apply freezing if specified
|
||||||
|
if cfg.unfrozen_parameters:
|
||||||
|
freeze_layers_except(model, cfg.unfrozen_parameters)
|
||||||
|
|
||||||
|
return model, tokenizer, peft_config, processor
|
||||||
|
|
||||||
|
|
||||||
|
def setup_reference_model(
|
||||||
|
cfg: DictDefault, tokenizer: PreTrainedTokenizer
|
||||||
|
) -> PreTrainedModel | None:
|
||||||
|
"""
|
||||||
|
Set up the reference model for RL training if needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
tokenizer: The tokenizer to use for the reference model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Reference model if needed for RL training, `None` otherwise.
|
||||||
|
"""
|
||||||
|
model_ref = None
|
||||||
|
if cfg.rl and cfg.rl != "orpo":
|
||||||
|
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||||
|
# use built-in trl autounwrap
|
||||||
|
LOG.debug("Passing model_ref: None to RL trainer")
|
||||||
|
model_ref = None # explicit setting to None
|
||||||
|
else:
|
||||||
|
# load the model again for model_ref/baseline
|
||||||
|
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
|
||||||
|
return model_ref
|
||||||
|
|
||||||
|
|
||||||
|
def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
|
||||||
|
"""
|
||||||
|
Determine the checkpoint to resume from based on configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the checkpoint to resume from, or `None` if not resuming.
|
||||||
|
"""
|
||||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||||
possible_checkpoints = [
|
possible_checkpoints = [
|
||||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||||
@@ -78,85 +131,22 @@ def train(
|
|||||||
LOG.info(
|
LOG.info(
|
||||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||||
)
|
)
|
||||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
return cfg.resume_from_checkpoint
|
||||||
|
|
||||||
# Load model
|
|
||||||
msg = "loading model"
|
|
||||||
if cfg.adapter:
|
|
||||||
msg += " and peft_config..."
|
|
||||||
LOG.debug(msg)
|
|
||||||
model, peft_config = load_model(cfg, tokenizer, processor=processor)
|
|
||||||
if model.generation_config is not None:
|
|
||||||
model.generation_config.do_sample = True
|
|
||||||
|
|
||||||
TELEMETRY_MANAGER.send_event(
|
def setup_signal_handler(
|
||||||
event_type="model-loaded", properties=model.config.to_dict()
|
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
|
||||||
)
|
):
|
||||||
if peft_config:
|
"""
|
||||||
TELEMETRY_MANAGER.send_event(
|
Set up signal handler for graceful termination.
|
||||||
event_type="peft-config-loaded", properties=peft_config.to_dict()
|
|
||||||
)
|
|
||||||
|
|
||||||
model_ref = None
|
Args:
|
||||||
if cfg.rl and cfg.rl != "orpo":
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
model: The model to save on termination
|
||||||
# use built-in trl autounwrap
|
safe_serialization: Whether to use safe serialization when saving
|
||||||
LOG.debug("Passing model_ref: None to RL trainer")
|
"""
|
||||||
model_ref = None # explicit setting to None
|
# ray workers don't have access to this signal
|
||||||
else:
|
if cfg.local_rank == 0 and not cfg.use_ray:
|
||||||
# load the model again for model_ref / baseline
|
|
||||||
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
|
|
||||||
|
|
||||||
safe_serialization = cfg.save_safetensors is True
|
|
||||||
|
|
||||||
if cfg.unfrozen_parameters:
|
|
||||||
freeze_layers_except(model, cfg.unfrozen_parameters)
|
|
||||||
|
|
||||||
trainer = setup_trainer(
|
|
||||||
cfg,
|
|
||||||
train_dataset,
|
|
||||||
eval_dataset,
|
|
||||||
(model, model_ref, peft_config),
|
|
||||||
tokenizer,
|
|
||||||
processor,
|
|
||||||
total_num_steps,
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.fix_untrained_tokens:
|
|
||||||
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
|
||||||
sig = inspect.signature(fix_untrained_tokens)
|
|
||||||
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
|
||||||
if "token_ids_to_fix" in sig.parameters and isinstance(
|
|
||||||
cfg.fix_untrained_tokens, list
|
|
||||||
):
|
|
||||||
fix_untrained_tokens(
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
train_dataset,
|
|
||||||
token_ids_to_fix=cfg.fix_untrained_tokens,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
fix_untrained_tokens(model, tokenizer, train_dataset)
|
|
||||||
if cfg.local_rank == 0:
|
|
||||||
model.save_pretrained(
|
|
||||||
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
|
||||||
)
|
|
||||||
|
|
||||||
# go ahead and presave, so we have the adapter config available to inspect
|
|
||||||
if peft_config:
|
|
||||||
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
|
||||||
peft_config.save_pretrained(cfg.output_dir)
|
|
||||||
# additionally presave the tokenizer and model configs
|
|
||||||
if not Path(cfg.output_dir).is_dir():
|
|
||||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
|
||||||
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
|
||||||
if hasattr(model, "config"):
|
|
||||||
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
|
||||||
|
|
||||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
|
||||||
if (
|
|
||||||
cfg.local_rank == 0 and not cfg.use_ray
|
|
||||||
): # ray workers don't have access to this signal
|
|
||||||
|
|
||||||
def terminate_handler(_, __, model_weakref):
|
def terminate_handler(_, __, model_weakref):
|
||||||
if model_weakref() is not None:
|
if model_weakref() is not None:
|
||||||
@@ -174,15 +164,18 @@ def train(
|
|||||||
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
|
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
|
||||||
)
|
)
|
||||||
|
|
||||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
|
|
||||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
|
||||||
|
|
||||||
if getattr(cfg, "axolotl_config_path"):
|
def execute_training(
|
||||||
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
|
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
||||||
version = importlib.metadata.version("axolotl")
|
):
|
||||||
if raw_axolotl_cfg.is_file():
|
"""
|
||||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
|
Execute the training process with appropriate backend configurations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
trainer: The configured trainer object.
|
||||||
|
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
||||||
|
"""
|
||||||
LOG.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
if cfg.group_by_length:
|
if cfg.group_by_length:
|
||||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||||
@@ -197,13 +190,31 @@ def train(
|
|||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
else:
|
else:
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
|
||||||
|
|
||||||
# post training
|
|
||||||
|
def save_trained_model(
|
||||||
|
cfg: DictDefault,
|
||||||
|
trainer: Any,
|
||||||
|
model: PreTrainedModel,
|
||||||
|
safe_serialization: bool,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Save the trained model according to configuration and training setup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
trainer: The trainer object.
|
||||||
|
model: The trained model to save.
|
||||||
|
safe_serialization: Whether to use safe serialization.
|
||||||
|
"""
|
||||||
|
LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
|
||||||
|
|
||||||
|
# Post training module hooks
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if hasattr(module, "_post_training"):
|
if hasattr(module, "_post_training"):
|
||||||
module._post_training(model, name) # pylint: disable=protected-access
|
module._post_training(model, name) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
# Handle FSDP state dict type
|
||||||
state_dict_type = "FULL_STATE_DICT"
|
state_dict_type = "FULL_STATE_DICT"
|
||||||
if trainer.is_fsdp_enabled:
|
if trainer.is_fsdp_enabled:
|
||||||
if cfg.fsdp_final_state_dict_type:
|
if cfg.fsdp_final_state_dict_type:
|
||||||
@@ -211,16 +222,18 @@ def train(
|
|||||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
|
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
|
||||||
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
|
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
|
||||||
|
|
||||||
|
# Handle ReLoRA early return case
|
||||||
if cfg.relora_steps:
|
if cfg.relora_steps:
|
||||||
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
|
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
|
||||||
model = model.merge_and_unload()
|
model = model.merge_and_unload()
|
||||||
else:
|
else:
|
||||||
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
||||||
return model, tokenizer
|
return
|
||||||
|
|
||||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
|
||||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
|
||||||
if cfg.fsdp:
|
if cfg.fsdp:
|
||||||
|
# TODO: do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||||
|
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple
|
||||||
|
# processes attempt to write the same file
|
||||||
if (
|
if (
|
||||||
state_dict_type == "SHARDED_STATE_DICT"
|
state_dict_type == "SHARDED_STATE_DICT"
|
||||||
and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
|
and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
|
||||||
@@ -252,7 +265,6 @@ def train(
|
|||||||
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
elif cfg.local_rank == 0:
|
elif cfg.local_rank == 0:
|
||||||
if cfg.flash_optimum and BetterTransformer:
|
if cfg.flash_optimum and BetterTransformer:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
@@ -263,40 +275,239 @@ def train(
|
|||||||
)
|
)
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
|
|
||||||
|
def create_model_card(cfg: DictDefault, trainer: Trainer):
|
||||||
|
"""
|
||||||
|
Create a model card for the trained model if needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
trainer: The trainer object with model card creation capabilities.
|
||||||
|
"""
|
||||||
if not cfg.hub_model_id:
|
if not cfg.hub_model_id:
|
||||||
|
# Guard since create_model_card may fail if dataset_tags is empty list
|
||||||
try:
|
try:
|
||||||
model_card_kwarg = {
|
model_card_kwarg = {
|
||||||
"model_name": cfg.output_dir.lstrip("./")
|
"model_name": cfg.output_dir.lstrip("./")
|
||||||
.encode("utf-8")
|
.encode("utf-8")
|
||||||
.decode("utf-8")
|
.decode("utf-8")
|
||||||
}
|
}
|
||||||
if cfg.datasets is not None:
|
|
||||||
if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model:
|
# We check if we're using a TRL trainer; if so, `dataset_tags` is not consumed.
|
||||||
dataset_tags = [
|
rl = cfg.rl is not None or cfg.reward_model or cfg.process_reward_model
|
||||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
if cfg.datasets is not None and not rl:
|
||||||
]
|
dataset_tags = [
|
||||||
dataset_tags = [
|
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
d for d in dataset_tags if not d.startswith("https://")
|
]
|
||||||
]
|
dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
|
||||||
if dataset_tags:
|
|
||||||
# guard as create_model_card may fail if dataset_tags is empty list
|
if dataset_tags:
|
||||||
model_card_kwarg["dataset_name"] = dataset_tags
|
model_card_kwarg["dataset_tags"] = dataset_tags
|
||||||
else:
|
|
||||||
dataset_tags = [
|
|
||||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
|
||||||
]
|
|
||||||
dataset_tags = [
|
|
||||||
d for d in dataset_tags if not d.startswith("https://")
|
|
||||||
]
|
|
||||||
if dataset_tags:
|
|
||||||
# guard as create_model_card may fail if dataset_tags is empty list
|
|
||||||
model_card_kwarg["dataset_tags"] = dataset_tags
|
|
||||||
|
|
||||||
trainer.create_model_card(**model_card_kwarg)
|
trainer.create_model_card(**model_card_kwarg)
|
||||||
except (AttributeError, UnicodeDecodeError):
|
except (AttributeError, UnicodeDecodeError):
|
||||||
pass
|
pass
|
||||||
elif cfg.hub_model_id:
|
elif cfg.hub_model_id:
|
||||||
# defensively push to the hub to ensure the model card is updated
|
# Defensively push to the hub to ensure the model card is updated
|
||||||
trainer.push_to_hub()
|
trainer.push_to_hub()
|
||||||
|
|
||||||
return model, tokenizer
|
|
||||||
|
def save_initial_configs(
|
||||||
|
cfg: DictDefault,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
model: PreTrainedModel,
|
||||||
|
peft_config: PeftConfig | None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Save initial configurations before training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
tokenizer: The tokenizer to save.
|
||||||
|
model: The model to save configuration for.
|
||||||
|
peft_config: The PEFT configuration to save if applicable.
|
||||||
|
"""
|
||||||
|
# Create output_dir if it doesn't already exist
|
||||||
|
output_dir = Path(cfg.output_dir)
|
||||||
|
if not output_dir.is_dir():
|
||||||
|
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Pre-save adapter config so it's available to inspect
|
||||||
|
if peft_config:
|
||||||
|
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}...")
|
||||||
|
peft_config.save_pretrained(cfg.output_dir)
|
||||||
|
|
||||||
|
# Pre-save the tokenizer and model configs
|
||||||
|
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
|
||||||
|
tokenizer.save_pretrained(str(output_dir))
|
||||||
|
if hasattr(model, "config"):
|
||||||
|
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
||||||
|
model.config.save_pretrained(str(output_dir))
|
||||||
|
|
||||||
|
|
||||||
|
def setup_model_card(cfg: DictDefault):
|
||||||
|
"""
|
||||||
|
Set up the Axolotl badge and add the Axolotl config to the model card if available.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
"""
|
||||||
|
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
|
||||||
|
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||||
|
|
||||||
|
if getattr(cfg, "axolotl_config_path"):
|
||||||
|
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
|
||||||
|
version = importlib.metadata.version("axolotl")
|
||||||
|
if raw_axolotl_cfg.is_file():
|
||||||
|
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
|
||||||
|
|
||||||
|
|
||||||
|
def handle_untrained_tokens_fix(
|
||||||
|
cfg: DictDefault,
|
||||||
|
model: PreTrainedModel,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
train_dataset: Dataset,
|
||||||
|
safe_serialization: bool,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Apply fixes for untrained tokens if configured.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
model: The model to apply fixes to.
|
||||||
|
tokenizer: The tokenizer for token identification.
|
||||||
|
train_dataset: The training dataset to use.
|
||||||
|
safe_serialization: Whether to use safe serialization when saving.
|
||||||
|
"""
|
||||||
|
if not cfg.fix_untrained_tokens:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
||||||
|
sig = inspect.signature(fix_untrained_tokens)
|
||||||
|
|
||||||
|
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
||||||
|
if "token_ids_to_fix" in sig.parameters and isinstance(
|
||||||
|
cfg.fix_untrained_tokens, list
|
||||||
|
):
|
||||||
|
fix_untrained_tokens(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
train_dataset,
|
||||||
|
token_ids_to_fix=cfg.fix_untrained_tokens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
fix_untrained_tokens(model, tokenizer, train_dataset)
|
||||||
|
|
||||||
|
if cfg.local_rank == 0:
|
||||||
|
model.save_pretrained(
|
||||||
|
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_model_and_trainer(
|
||||||
|
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||||
|
) -> tuple[
|
||||||
|
HFRLTrainerBuilder | HFCausalTrainerBuilder,
|
||||||
|
PeftModel | PreTrainedModel,
|
||||||
|
PreTrainedTokenizer,
|
||||||
|
PeftConfig | None,
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
|
||||||
|
trainer setup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: The configuration dictionary with training parameters.
|
||||||
|
dataset_meta: Object with training, validation datasets and metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of:
|
||||||
|
- Trainer (Causal or RLHF)
|
||||||
|
- Model
|
||||||
|
- Tokenizer
|
||||||
|
- PEFT config
|
||||||
|
"""
|
||||||
|
# Load tokenizer, processor and model
|
||||||
|
model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
|
||||||
|
|
||||||
|
# Set up reference model for RL if needed
|
||||||
|
model_ref = setup_reference_model(cfg, tokenizer)
|
||||||
|
|
||||||
|
# Get datasets from metadata
|
||||||
|
train_dataset = dataset_meta.train_dataset
|
||||||
|
eval_dataset = dataset_meta.eval_dataset
|
||||||
|
total_num_steps = dataset_meta.total_num_steps
|
||||||
|
|
||||||
|
# Set up trainer
|
||||||
|
trainer = setup_trainer(
|
||||||
|
cfg=cfg,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
processor=processor,
|
||||||
|
total_num_steps=total_num_steps,
|
||||||
|
model_ref=model_ref,
|
||||||
|
peft_config=peft_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
trainer,
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
peft_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def train(
|
||||||
|
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||||
|
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
|
||||||
|
"""
|
||||||
|
Train a model on the given dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: The configuration dictionary with training parameters
|
||||||
|
dataset_meta: Object with training, validation datasets and metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (model, tokenizer) after training
|
||||||
|
"""
|
||||||
|
# Setup model, tokenizer, (causal or RLHF) trainer etc.
|
||||||
|
(
|
||||||
|
trainer,
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
peft_config,
|
||||||
|
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||||
|
|
||||||
|
# Determine if we need to resume from a checkpoint
|
||||||
|
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||||
|
|
||||||
|
# Configuration for saving
|
||||||
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
|
# Handle untrained tokens if configured
|
||||||
|
train_dataset = dataset_meta.train_dataset
|
||||||
|
handle_untrained_tokens_fix(
|
||||||
|
cfg, model, tokenizer, train_dataset, safe_serialization
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save initial configs
|
||||||
|
save_initial_configs(cfg, tokenizer, model, peft_config)
|
||||||
|
|
||||||
|
# Set up signal handler for graceful termination
|
||||||
|
setup_signal_handler(cfg, model, safe_serialization)
|
||||||
|
|
||||||
|
# Set up badges and config info for model card
|
||||||
|
setup_model_card(cfg)
|
||||||
|
|
||||||
|
# Execute the training
|
||||||
|
execute_training(cfg, trainer, resume_from_checkpoint)
|
||||||
|
|
||||||
|
# Save the trained model
|
||||||
|
save_trained_model(cfg, trainer, model, safe_serialization)
|
||||||
|
|
||||||
|
# Create model card
|
||||||
|
create_model_card(cfg, trainer)
|
||||||
|
|
||||||
|
return model, tokenizer, trainer
|
||||||
|
|||||||
@@ -64,6 +64,18 @@ class ChatTemplate(str, Enum):
|
|||||||
metharme = "metharme" # pylint: disable=invalid-name
|
metharme = "metharme" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class CustomSupportedOptimizers(str, Enum):
|
||||||
|
"""Custom supported optimizers"""
|
||||||
|
|
||||||
|
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
|
||||||
|
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
|
||||||
|
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
|
||||||
|
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
||||||
|
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
||||||
|
lion_pytorch = "lion_pytorch" # pylint: disable=invalid-name
|
||||||
|
muon = "muon" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class DeprecatedParameters(BaseModel):
|
class DeprecatedParameters(BaseModel):
|
||||||
"""configurations that are deprecated"""
|
"""configurations that are deprecated"""
|
||||||
|
|
||||||
@@ -494,17 +506,7 @@ class HyperparametersConfig(BaseModel):
|
|||||||
embedding_lr_scale: Optional[float] = None
|
embedding_lr_scale: Optional[float] = None
|
||||||
weight_decay: Optional[float] = 0.0
|
weight_decay: Optional[float] = 0.0
|
||||||
optimizer: Optional[
|
optimizer: Optional[
|
||||||
Union[
|
Union[OptimizerNames, CustomSupportedOptimizers]
|
||||||
OptimizerNames,
|
|
||||||
Literal[
|
|
||||||
"lion_pytorch",
|
|
||||||
"optimi_adamw",
|
|
||||||
"ao_adamw_4bit",
|
|
||||||
"ao_adamw_8bit",
|
|
||||||
"ao_adamw_fp8",
|
|
||||||
"adopt_adamw",
|
|
||||||
],
|
|
||||||
]
|
|
||||||
] = OptimizerNames.ADAMW_HF
|
] = OptimizerNames.ADAMW_HF
|
||||||
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -518,7 +520,7 @@ class HyperparametersConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
torchdistx_path: Optional[str] = None
|
torchdistx_path: Optional[str] = None
|
||||||
lr_scheduler: Optional[
|
lr_scheduler: Optional[
|
||||||
Union[SchedulerType, Literal["one_cycle"]]
|
Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]
|
||||||
] = SchedulerType.COSINE
|
] = SchedulerType.COSINE
|
||||||
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
||||||
lr_quadratic_warmup: Optional[bool] = None
|
lr_quadratic_warmup: Optional[bool] = None
|
||||||
@@ -1177,6 +1179,13 @@ class AxolotlInputConfig(
|
|||||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_lr_groups(cls, data):
|
||||||
|
if data.get("lr_groups") and data.get("loraplus_lr_ratio"):
|
||||||
|
raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.")
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_saves(cls, data):
|
def check_saves(cls, data):
|
||||||
@@ -1683,7 +1692,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
|
|
||||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||||
"""Wrapper to validate GPU capabilities with the config options"""
|
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||||
|
|
||||||
capabilities: GPUCapabilities
|
capabilities: GPUCapabilities
|
||||||
env_capabilities: EnvCapabilities
|
env_capabilities: EnvCapabilities
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class TRLConfig(BaseModel):
|
|||||||
vllm_dtype: Optional[str] = "auto"
|
vllm_dtype: Optional[str] = "auto"
|
||||||
|
|
||||||
reward_funcs: Optional[List[str]] = None
|
reward_funcs: Optional[List[str]] = None
|
||||||
|
reward_weights: Optional[List[float]] = None
|
||||||
num_generations: Optional[int] = None
|
num_generations: Optional[int] = None
|
||||||
log_completions: Optional[bool] = False
|
log_completions: Optional[bool] = False
|
||||||
|
|
||||||
|
|||||||
@@ -54,7 +54,6 @@ from axolotl.monkeypatch.multipack import (
|
|||||||
patch_for_multipack,
|
patch_for_multipack,
|
||||||
)
|
)
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
from axolotl.telemetry.errors import send_errors
|
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -166,7 +165,6 @@ def load_model_config(cfg):
|
|||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def load_tokenizer(cfg):
|
def load_tokenizer(cfg):
|
||||||
model_config = load_model_config(cfg)
|
model_config = load_model_config(cfg)
|
||||||
tokenizer_kwargs = {}
|
tokenizer_kwargs = {}
|
||||||
@@ -320,7 +318,6 @@ def load_tokenizer(cfg):
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||||
processor_kwargs: Dict[str, Any] = {} # do we actually need this?
|
processor_kwargs: Dict[str, Any] = {} # do we actually need this?
|
||||||
|
|
||||||
@@ -1195,17 +1192,18 @@ class ModelLoader:
|
|||||||
return self.model, lora_config
|
return self.model, lora_config
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def load_model(
|
def load_model(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
*,
|
*,
|
||||||
processor: ProcessorMixin = None,
|
processor: ProcessorMixin = None, # pylint: disable=unused-argument
|
||||||
inference: bool = False,
|
inference: bool = False,
|
||||||
reference_model: bool = False,
|
reference_model: bool = False,
|
||||||
**kwargs,
|
**kwargs, # pylint: disable=unused-argument
|
||||||
) -> Tuple[PreTrainedModel, PeftConfig | None]:
|
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||||
"""Load a model for a given configuration and tokenizer"""
|
"""
|
||||||
|
Load a model for a given configuration and tokenizer.
|
||||||
|
"""
|
||||||
loader = ModelLoader(
|
loader = ModelLoader(
|
||||||
cfg,
|
cfg,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -1217,7 +1215,6 @@ def load_model(
|
|||||||
return loader.load_model()
|
return loader.load_model()
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def load_adapter(model, cfg, adapter, inference=False):
|
def load_adapter(model, cfg, adapter, inference=False):
|
||||||
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,80 @@ from torch.optim import Optimizer
|
|||||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||||
|
|
||||||
|
|
||||||
|
class RexLR(LRScheduler):
|
||||||
|
"""
|
||||||
|
Reflected Exponential (REX) learning rate scheduler.
|
||||||
|
|
||||||
|
- Original implementation: https://github.com/IvanVassi/REX_LR
|
||||||
|
- Original license: Apache 2.0
|
||||||
|
- Based on: https://arxiv.org/abs/2107.04197
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer (torch.optim.Optimizer): The optimizer to schedule the learning rate for.
|
||||||
|
max_lr (float): The maximum learning rate.
|
||||||
|
min_lr (float): The minimum learning rate.
|
||||||
|
total_steps (int): The total number of training steps.
|
||||||
|
num_warmup_steps (int): The number of warmup steps.
|
||||||
|
last_step (int): The index of last step.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, optimizer, max_lr, min_lr, total_steps=0, num_warmup_steps=0, last_step=0
|
||||||
|
):
|
||||||
|
if min_lr > max_lr:
|
||||||
|
raise ValueError(
|
||||||
|
f'Value of "min_lr" should be less than value of "max_lr". Got min_lr={min_lr} and max_lr={max_lr}'
|
||||||
|
)
|
||||||
|
if num_warmup_steps > total_steps:
|
||||||
|
raise ValueError(
|
||||||
|
f"num_warmup_steps ({num_warmup_steps}) must be less than or equal to total_steps ({total_steps})."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.min_lr = min_lr
|
||||||
|
self.max_lr = max_lr
|
||||||
|
self.total_steps = total_steps
|
||||||
|
self.num_warmup_steps = num_warmup_steps
|
||||||
|
self.last_step = last_step - 1
|
||||||
|
|
||||||
|
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
|
||||||
|
for group in optimizer.param_groups:
|
||||||
|
group.setdefault("initial_lr", group["lr"])
|
||||||
|
|
||||||
|
# Pass self.last_step as last_epoch to the parent.
|
||||||
|
super().__init__(optimizer, last_epoch=self.last_step)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_step(self):
|
||||||
|
return self.last_epoch
|
||||||
|
|
||||||
|
@last_step.setter
|
||||||
|
def last_step(self, value):
|
||||||
|
self.last_epoch = value
|
||||||
|
|
||||||
|
def get_lr(self):
|
||||||
|
# Warmup phase: if defined, increase lr linearly from 0 to max_lr.
|
||||||
|
if 1 <= self.last_step <= self.num_warmup_steps:
|
||||||
|
return [
|
||||||
|
base_lr * self.last_step / self.num_warmup_steps
|
||||||
|
for base_lr in self.base_lrs
|
||||||
|
]
|
||||||
|
|
||||||
|
# Post-warmup phase: adjust step relative to the end of warmup.
|
||||||
|
step_after = self.last_step - self.num_warmup_steps
|
||||||
|
remaining_steps = self.total_steps - self.num_warmup_steps
|
||||||
|
|
||||||
|
# Avoid LR spiking
|
||||||
|
if step_after >= remaining_steps or step_after == -1 or remaining_steps <= 0:
|
||||||
|
return [self.min_lr for _ in self.base_lrs]
|
||||||
|
|
||||||
|
mod_iter = step_after % remaining_steps
|
||||||
|
z = (remaining_steps - mod_iter) / remaining_steps
|
||||||
|
rex_factor = self.min_lr / self.max_lr + (1.0 - self.min_lr / self.max_lr) * (
|
||||||
|
z / (0.1 + 0.9 * z)
|
||||||
|
)
|
||||||
|
return [base_lr * rex_factor for base_lr in self.base_lrs]
|
||||||
|
|
||||||
|
|
||||||
class InterpolatingLogScheduler(LRScheduler):
|
class InterpolatingLogScheduler(LRScheduler):
|
||||||
"""
|
"""
|
||||||
A scheduler that interpolates learning rates in a logarithmic fashion
|
A scheduler that interpolates learning rates in a logarithmic fashion
|
||||||
|
|||||||
@@ -574,14 +574,40 @@ def prepare_opinionated_env(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def setup_trainer(
|
def setup_trainer(
|
||||||
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
cfg,
|
||||||
|
train_dataset,
|
||||||
|
eval_dataset,
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
processor,
|
||||||
|
total_num_steps,
|
||||||
|
model_ref=None,
|
||||||
|
peft_config=None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Helper method for instantiating and building a (causal or RLHF) trainer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: Axolotl config object containing training parameters.
|
||||||
|
train_dataset: Dataset to use for training.
|
||||||
|
eval_dataset: Dataset to use for evaluation.
|
||||||
|
model: The model to train.
|
||||||
|
tokenizer: Tokenizer for processing text input.
|
||||||
|
processor: Processor for data preparation.
|
||||||
|
total_num_steps: The total number of training steps.
|
||||||
|
model_ref: Optional reference model for RLHF training. Default is None.
|
||||||
|
peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
|
||||||
|
on the provided parameters.
|
||||||
|
"""
|
||||||
if cfg.rl:
|
if cfg.rl:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model_ref
|
||||||
trainer_builder.peft_config = model[2]
|
trainer_builder.peft_config = peft_config
|
||||||
else:
|
else:
|
||||||
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)
|
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer, processor)
|
||||||
|
|
||||||
trainer_builder.train_dataset = train_dataset
|
trainer_builder.train_dataset = train_dataset
|
||||||
trainer_builder.eval_dataset = eval_dataset
|
trainer_builder.eval_dataset = eval_dataset
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("axolotl.cli.train.train") as mock_train:
|
with patch("axolotl.cli.train.train") as mock_train:
|
||||||
mock_train.return_value = (MagicMock(), MagicMock())
|
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
|
||||||
|
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
@@ -48,7 +48,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path = self._test_cli_overrides(tmp_path, valid_test_config)
|
config_path = self._test_cli_overrides(tmp_path, valid_test_config)
|
||||||
|
|
||||||
with patch("axolotl.cli.train.train") as mock_train:
|
with patch("axolotl.cli.train.train") as mock_train:
|
||||||
mock_train.return_value = (MagicMock(), MagicMock())
|
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
|
||||||
|
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Shared pytest fixtures"""
|
"""
|
||||||
|
shared pytest fixtures
|
||||||
|
"""
|
||||||
import functools
|
import functools
|
||||||
import importlib
|
import importlib
|
||||||
import shutil
|
import shutil
|
||||||
@@ -170,9 +171,3 @@ def cleanup_monkeypatches():
|
|||||||
module_globals = module_name_tuple[1]
|
module_globals = module_name_tuple[1]
|
||||||
for module_global in module_globals:
|
for module_global in module_globals:
|
||||||
globals().pop(module_global, None)
|
globals().pop(module_global, None)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def disable_telemetry(monkeypatch):
|
|
||||||
monkeypatch.setenv("AXOLOTL_DO_NOT_TRACK", "1")
|
|
||||||
yield
|
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
@@ -131,7 +131,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
@@ -190,7 +190,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
@@ -249,7 +249,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (
|
assert (
|
||||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||||
== torch.float32
|
== torch.float32
|
||||||
|
|||||||
@@ -65,8 +65,9 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
assert trainer.optimizer.optimizer.__class__.__name__ == "AdamW"
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
@require_torch_2_5_1
|
@require_torch_2_5_1
|
||||||
@@ -111,8 +112,57 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
assert "ADOPT" in trainer.optimizer.optimizer.__class__.__name__
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
@require_torch_2_5_1
|
||||||
|
def test_muon(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 5,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "muon",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"weight_decay": 0.01,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_fft_schedule_free_adamw(self, temp_dir):
|
def test_fft_schedule_free_adamw(self, temp_dir):
|
||||||
|
|||||||
71
tests/e2e/test_schedulers.py
Normal file
71
tests/e2e/test_schedulers.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for custom schedulers using Llama
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCustomSchedulers(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for Llama models using LoRA
|
||||||
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_rex_scheduler(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_hf",
|
||||||
|
"max_steps": 20,
|
||||||
|
"lr_scheduler": "rex",
|
||||||
|
"warmup_steps": 5,
|
||||||
|
"cosine_min_lr_ratio": 0.05,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
"""Shared pytest fixtures for telemetry tests."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def disable_telemetry(monkeypatch):
|
|
||||||
monkeypatch.delenv("AXOLOTL_DO_NOT_TRACK", raising=False)
|
|
||||||
yield
|
|
||||||
@@ -1,372 +0,0 @@
|
|||||||
"""Tests for telemetry callback module."""
|
|
||||||
# pylint: disable=redefined-outer-name
|
|
||||||
|
|
||||||
import time
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from transformers import TrainerControl, TrainerState, TrainingArguments
|
|
||||||
|
|
||||||
from axolotl.telemetry.callbacks import TIME_SINCE_LAST, TelemetryCallback
|
|
||||||
|
|
||||||
|
|
||||||
def calc_expected_metrics(step, last_step, current_time, last_time, start_time=900.0):
|
|
||||||
"""Calculate expected metrics values for tests"""
|
|
||||||
time_diff = current_time - last_time
|
|
||||||
step_diff = step - last_step
|
|
||||||
return {
|
|
||||||
"steps_per_second": step_diff / time_diff
|
|
||||||
if time_diff > 0 and step_diff > 0
|
|
||||||
else 0,
|
|
||||||
"time_since_last_report": time_diff,
|
|
||||||
"elapsed_time": current_time - start_time,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_time():
|
|
||||||
"""Mock time.time() to have predictable values in tests"""
|
|
||||||
with patch("axolotl.telemetry.callbacks.time") as mock_time:
|
|
||||||
mock_time.time.return_value = 1000.0
|
|
||||||
yield mock_time
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_telemetry_manager():
|
|
||||||
"""Create a mock TelemetryManager"""
|
|
||||||
with patch("axolotl.telemetry.callbacks.TelemetryManager") as mock_manager_class:
|
|
||||||
mock_manager = MagicMock()
|
|
||||||
mock_manager_class.get_instance.return_value = mock_manager
|
|
||||||
yield mock_manager
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_runtime_metrics_tracker():
|
|
||||||
"""Create a mock RuntimeMetricsTracker"""
|
|
||||||
with patch(
|
|
||||||
"axolotl.telemetry.callbacks.RuntimeMetricsTracker"
|
|
||||||
) as mock_tracker_class:
|
|
||||||
mock_tracker = MagicMock()
|
|
||||||
# Set up metrics property on the tracker
|
|
||||||
mock_metrics = MagicMock()
|
|
||||||
mock_metrics.to_dict.return_value = {
|
|
||||||
"total_steps": 100,
|
|
||||||
"peak_cpu_memory_bytes": 1024,
|
|
||||||
}
|
|
||||||
mock_tracker.metrics = mock_metrics
|
|
||||||
|
|
||||||
# Make the constructor return our mock
|
|
||||||
mock_tracker_class.return_value = mock_tracker
|
|
||||||
yield mock_tracker
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def training_args():
|
|
||||||
"""Create a minimal TrainingArguments instance"""
|
|
||||||
return TrainingArguments(output_dir="./output")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def trainer_state():
|
|
||||||
"""Create a mock TrainerState"""
|
|
||||||
state = MagicMock(spec=TrainerState)
|
|
||||||
state.global_step = 10
|
|
||||||
state.epoch = 0.5 # halfway through first epoch
|
|
||||||
state.log_history = [{"loss": 2.5, "learning_rate": 5e-5}]
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def trainer_control():
|
|
||||||
"""Create a mock TrainerControl"""
|
|
||||||
return MagicMock(spec=TrainerControl)
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
|
||||||
@pytest.fixture
|
|
||||||
def callback(mock_telemetry_manager, mock_runtime_metrics_tracker):
|
|
||||||
"""Create a TelemetryCallback instance with mocked dependencies"""
|
|
||||||
return TelemetryCallback()
|
|
||||||
|
|
||||||
|
|
||||||
class TestTelemetryCallback:
|
|
||||||
"""Tests for the TelemetryCallback class."""
|
|
||||||
|
|
||||||
def test_initialization(self, callback, mock_runtime_metrics_tracker):
|
|
||||||
"""Test callback initialization."""
|
|
||||||
assert callback.current_epoch == -1
|
|
||||||
assert callback.tracker == mock_runtime_metrics_tracker
|
|
||||||
assert callback.last_report_step == 0
|
|
||||||
assert hasattr(callback, "start_time")
|
|
||||||
assert hasattr(callback, "last_report_time")
|
|
||||||
assert callback.report_interval_steps == 100
|
|
||||||
|
|
||||||
def test_on_train_begin(
|
|
||||||
self,
|
|
||||||
callback,
|
|
||||||
mock_telemetry_manager,
|
|
||||||
training_args,
|
|
||||||
trainer_state,
|
|
||||||
trainer_control,
|
|
||||||
):
|
|
||||||
"""Test on_train_begin sends expected event."""
|
|
||||||
callback.on_train_begin(training_args, trainer_state, trainer_control)
|
|
||||||
|
|
||||||
mock_telemetry_manager.send_event.assert_called_once_with(
|
|
||||||
event_type="train-started"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_on_train_end(
|
|
||||||
self,
|
|
||||||
callback,
|
|
||||||
mock_telemetry_manager,
|
|
||||||
training_args,
|
|
||||||
trainer_state,
|
|
||||||
trainer_control,
|
|
||||||
):
|
|
||||||
"""Test on_train_end sends expected event with metrics."""
|
|
||||||
callback.on_train_end(training_args, trainer_state, trainer_control)
|
|
||||||
|
|
||||||
mock_telemetry_manager.send_event.assert_called_once()
|
|
||||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
|
||||||
|
|
||||||
assert call_args["event_type"] == "train-ended"
|
|
||||||
assert "loss" in call_args["properties"]
|
|
||||||
assert call_args["properties"]["loss"] == 2.5
|
|
||||||
assert "learning_rate" in call_args["properties"]
|
|
||||||
assert call_args["properties"]["learning_rate"] == 5e-5
|
|
||||||
|
|
||||||
# Check that metrics from RuntimeMetricsTracker are included
|
|
||||||
assert "total_steps" in call_args["properties"]
|
|
||||||
assert call_args["properties"]["total_steps"] == 100
|
|
||||||
assert "peak_cpu_memory_bytes" in call_args["properties"]
|
|
||||||
assert call_args["properties"]["peak_cpu_memory_bytes"] == 1024
|
|
||||||
|
|
||||||
def test_on_epoch_begin(
|
|
||||||
self,
|
|
||||||
callback,
|
|
||||||
mock_runtime_metrics_tracker,
|
|
||||||
training_args,
|
|
||||||
trainer_state,
|
|
||||||
trainer_control,
|
|
||||||
):
|
|
||||||
"""Test on_epoch_begin updates epoch counter and calls tracker."""
|
|
||||||
initial_epoch = callback.current_epoch
|
|
||||||
|
|
||||||
callback.on_epoch_begin(training_args, trainer_state, trainer_control)
|
|
||||||
|
|
||||||
assert callback.current_epoch == initial_epoch + 1
|
|
||||||
mock_runtime_metrics_tracker.start_epoch.assert_called_once_with(
|
|
||||||
initial_epoch + 1
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_on_epoch_end(
|
|
||||||
self,
|
|
||||||
callback,
|
|
||||||
mock_runtime_metrics_tracker,
|
|
||||||
training_args,
|
|
||||||
trainer_state,
|
|
||||||
trainer_control,
|
|
||||||
):
|
|
||||||
"""Test on_epoch_end calls tracker."""
|
|
||||||
# Set current epoch
|
|
||||||
callback.current_epoch = 2
|
|
||||||
|
|
||||||
callback.on_epoch_end(training_args, trainer_state, trainer_control)
|
|
||||||
|
|
||||||
mock_runtime_metrics_tracker.end_epoch.assert_called_once_with(2)
|
|
||||||
|
|
||||||
def test_on_step_end_no_report(
|
|
||||||
self,
|
|
||||||
callback,
|
|
||||||
mock_telemetry_manager,
|
|
||||||
mock_runtime_metrics_tracker,
|
|
||||||
training_args,
|
|
||||||
trainer_state,
|
|
||||||
trainer_control,
|
|
||||||
):
|
|
||||||
"""Test on_step_end updates tracker but doesn't report if criteria not met."""
|
|
||||||
# Set up state to avoid reporting
|
|
||||||
trainer_state.global_step = 42 # Not divisible by report_interval_steps
|
|
||||||
callback.last_report_step = 41 # Just 1 step since last report
|
|
||||||
callback.last_report_time = time.time() # Just now
|
|
||||||
|
|
||||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
|
||||||
|
|
||||||
# Should update tracker
|
|
||||||
mock_runtime_metrics_tracker.update_step.assert_called_once_with(42)
|
|
||||||
|
|
||||||
# Should not send telemetry
|
|
||||||
mock_telemetry_manager.send_event.assert_not_called()
|
|
||||||
|
|
||||||
# Should not update last report time/step
|
|
||||||
assert callback.last_report_step == 41
|
|
||||||
|
|
||||||
def test_on_step_end_report_interval_steps(
|
|
||||||
self,
|
|
||||||
callback,
|
|
||||||
mock_telemetry_manager,
|
|
||||||
mock_runtime_metrics_tracker,
|
|
||||||
mock_time,
|
|
||||||
training_args,
|
|
||||||
trainer_state,
|
|
||||||
trainer_control,
|
|
||||||
):
|
|
||||||
"""Test on_step_end reports when step interval is reached."""
|
|
||||||
# Set up state with clear values
|
|
||||||
current_step = 100 # Exactly matches report_interval_steps
|
|
||||||
last_step = 0
|
|
||||||
start_time = 900.0
|
|
||||||
current_time = 1000.0
|
|
||||||
time_diff = current_time - start_time # 100 seconds
|
|
||||||
|
|
||||||
# Configure state and callback
|
|
||||||
trainer_state.global_step = current_step
|
|
||||||
callback.report_interval_steps = 100
|
|
||||||
callback.last_report_step = last_step
|
|
||||||
callback.start_time = start_time
|
|
||||||
callback.last_report_time = start_time
|
|
||||||
|
|
||||||
# Mock time.time() to return consistent values
|
|
||||||
mock_time.time.return_value = current_time
|
|
||||||
|
|
||||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
|
||||||
|
|
||||||
# Should update tracker
|
|
||||||
mock_runtime_metrics_tracker.update_step.assert_called_once_with(current_step)
|
|
||||||
mock_runtime_metrics_tracker.update_memory_metrics.assert_called_once()
|
|
||||||
|
|
||||||
# Should send telemetry
|
|
||||||
mock_telemetry_manager.send_event.assert_called_once()
|
|
||||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
|
||||||
assert call_args["event_type"] == "train-progress"
|
|
||||||
|
|
||||||
# Properties should include expected values
|
|
||||||
props = call_args["properties"]
|
|
||||||
assert props["step"] == current_step
|
|
||||||
assert props["elapsed_time"] == time_diff # 1000 - 900 = 100
|
|
||||||
assert props["time_since_last_report"] == time_diff # 1000 - 900 = 100
|
|
||||||
assert props["steps_per_second"] == 1.0 # 100 steps / 100 seconds
|
|
||||||
|
|
||||||
# Should update last report time/step
|
|
||||||
assert callback.last_report_step == current_step
|
|
||||||
assert callback.last_report_time == current_time
|
|
||||||
|
|
||||||
def test_on_step_end_report_time_elapsed(
|
|
||||||
self,
|
|
||||||
callback,
|
|
||||||
mock_telemetry_manager,
|
|
||||||
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
|
|
||||||
mock_time,
|
|
||||||
training_args,
|
|
||||||
trainer_state,
|
|
||||||
trainer_control,
|
|
||||||
):
|
|
||||||
"""Test on_step_end reports when enough time has elapsed."""
|
|
||||||
# Set up state with clear values
|
|
||||||
current_step = 120
|
|
||||||
last_step = 10
|
|
||||||
start_time = 900.0
|
|
||||||
current_time = 1000.0
|
|
||||||
time_diff = TIME_SINCE_LAST + 1 # Just over the threshold
|
|
||||||
|
|
||||||
# Configure state and callback
|
|
||||||
trainer_state.global_step = current_step
|
|
||||||
callback.report_interval_steps = 100
|
|
||||||
callback.last_report_step = last_step
|
|
||||||
callback.start_time = start_time
|
|
||||||
callback.last_report_time = current_time - time_diff
|
|
||||||
|
|
||||||
# Mock time.time() to return consistent values
|
|
||||||
mock_time.time.return_value = current_time
|
|
||||||
|
|
||||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
|
||||||
|
|
||||||
# Should send telemetry
|
|
||||||
mock_telemetry_manager.send_event.assert_called_once()
|
|
||||||
|
|
||||||
# Properties should include expected values
|
|
||||||
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
|
|
||||||
expected_metrics = calc_expected_metrics(
|
|
||||||
current_step, last_step, current_time, current_time - time_diff, start_time
|
|
||||||
)
|
|
||||||
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
|
|
||||||
assert (
|
|
||||||
props["time_since_last_report"]
|
|
||||||
== expected_metrics["time_since_last_report"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_on_step_end_first_step(
|
|
||||||
self,
|
|
||||||
callback,
|
|
||||||
mock_telemetry_manager,
|
|
||||||
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
|
|
||||||
mock_time,
|
|
||||||
training_args,
|
|
||||||
trainer_state,
|
|
||||||
trainer_control,
|
|
||||||
):
|
|
||||||
"""Test on_step_end always reports on first step."""
|
|
||||||
# Set up state with clear values
|
|
||||||
current_step = 1 # First step
|
|
||||||
last_step = 0
|
|
||||||
start_time = 900.0
|
|
||||||
current_time = 1000.0
|
|
||||||
last_report_time = 999.0 # Just 1 second ago
|
|
||||||
|
|
||||||
# Configure state and callback
|
|
||||||
trainer_state.global_step = current_step
|
|
||||||
callback.report_interval_steps = 100
|
|
||||||
callback.last_report_step = last_step
|
|
||||||
callback.start_time = start_time
|
|
||||||
callback.last_report_time = last_report_time
|
|
||||||
|
|
||||||
# Mock time.time() to return consistent values
|
|
||||||
mock_time.time.return_value = current_time
|
|
||||||
|
|
||||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
|
||||||
|
|
||||||
# Should send telemetry even though not much time has passed
|
|
||||||
mock_telemetry_manager.send_event.assert_called_once()
|
|
||||||
|
|
||||||
# Properties should include expected values for first step
|
|
||||||
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
|
|
||||||
assert props["step"] == current_step
|
|
||||||
expected_metrics = calc_expected_metrics(
|
|
||||||
current_step, last_step, current_time, last_report_time, start_time
|
|
||||||
)
|
|
||||||
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
|
|
||||||
|
|
||||||
def test_log_history_empty(
|
|
||||||
self,
|
|
||||||
callback,
|
|
||||||
mock_telemetry_manager,
|
|
||||||
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
|
|
||||||
mock_time,
|
|
||||||
training_args,
|
|
||||||
trainer_state,
|
|
||||||
trainer_control,
|
|
||||||
):
|
|
||||||
"""Test handling of empty log history."""
|
|
||||||
# Set up state with clear values
|
|
||||||
current_step = 1
|
|
||||||
start_time = 900.0
|
|
||||||
current_time = 1000.0
|
|
||||||
|
|
||||||
# Configure state and callback
|
|
||||||
trainer_state.global_step = current_step
|
|
||||||
trainer_state.log_history = []
|
|
||||||
callback.start_time = start_time
|
|
||||||
|
|
||||||
# Mock time.time() to return consistent values
|
|
||||||
mock_time.time.return_value = current_time
|
|
||||||
|
|
||||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
|
||||||
|
|
||||||
# Should still send telemetry
|
|
||||||
mock_telemetry_manager.send_event.assert_called_once()
|
|
||||||
|
|
||||||
# Properties should have default values for missing log data
|
|
||||||
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
|
|
||||||
assert props["loss"] == 0
|
|
||||||
assert props["learning_rate"] == 0
|
|
||||||
@@ -1,340 +0,0 @@
|
|||||||
"""Tests for telemetry error utilities"""
|
|
||||||
# pylint: disable=redefined-outer-name
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.telemetry.errors import sanitize_stack_trace, send_errors
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def reset_error_flag(monkeypatch):
|
|
||||||
"""Reset ERROR_HANDLED flag using monkeypatch"""
|
|
||||||
import axolotl.telemetry.errors
|
|
||||||
|
|
||||||
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
|
|
||||||
yield
|
|
||||||
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def example_stack_trace():
|
|
||||||
"""Provide a sample stack trace with mixed paths"""
|
|
||||||
return """Traceback (most recent call last):
|
|
||||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
|
|
||||||
trainer = get_trainer(cfg)
|
|
||||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/train.py", line 214, in get_trainer
|
|
||||||
model = get_model(cfg, tokenizer)
|
|
||||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/models.py", line 120, in get_model
|
|
||||||
raise ValueError("Model path not found")
|
|
||||||
ValueError: Model path not found
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def windows_stack_trace():
|
|
||||||
"""Provide a sample stack trace with Windows paths"""
|
|
||||||
return """Traceback (most recent call last):
|
|
||||||
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\cli\\train.py", line 83, in main
|
|
||||||
trainer = get_trainer(cfg)
|
|
||||||
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\train.py", line 214, in get_trainer
|
|
||||||
model = get_model(cfg, tokenizer)
|
|
||||||
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\auto\\modeling_auto.py", line 482, in from_pretrained
|
|
||||||
raise ValueError(f"Unrecognized configuration class {config.__class__}")
|
|
||||||
ValueError: Unrecognized configuration class <class 'transformers.models.llama.configuration_llama.LlamaConfig'>
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mixed_stack_trace():
|
|
||||||
"""Provide a sample stack trace with both axolotl and non-axolotl paths"""
|
|
||||||
return """Traceback (most recent call last):
|
|
||||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
|
|
||||||
trainer = get_trainer(cfg)
|
|
||||||
File "/home/user/.local/lib/python3.9/site-packages/transformers/trainer.py", line 520, in train
|
|
||||||
self._inner_training_loop()
|
|
||||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/trainer.py", line 75, in _inner_training_loop
|
|
||||||
super()._inner_training_loop()
|
|
||||||
File "/home/user/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
|
|
||||||
data = self._next_data()
|
|
||||||
RuntimeError: CUDA out of memory
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def venv_stack_trace():
|
|
||||||
"""Provide a sample stack trace with virtual environment paths"""
|
|
||||||
return """Traceback (most recent call last):
|
|
||||||
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 1729, in train
|
|
||||||
self._inner_training_loop()
|
|
||||||
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 2013, in _inner_training_loop
|
|
||||||
self.accelerator.backward(loss)
|
|
||||||
File "/home/user/venv/lib/python3.9/site-packages/accelerate/accelerator.py", line 1851, in backward
|
|
||||||
self.scaler.scale(loss).backward(**kwargs)
|
|
||||||
File "/home/user/venv/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
|
|
||||||
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
|
|
||||||
RuntimeError: CUDA out of memory
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def dist_packages_stack_trace():
|
|
||||||
"""Provide a sample stack trace with dist-packages paths"""
|
|
||||||
return """Traceback (most recent call last):
|
|
||||||
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__
|
|
||||||
data = self._next_data()
|
|
||||||
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 675, in _next_data
|
|
||||||
data = self._dataset_fetcher.fetch(index)
|
|
||||||
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
|
|
||||||
data = [self.dataset[idx] for idx in possibly_batched_index]
|
|
||||||
File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2808, in __getitem__
|
|
||||||
raise IndexError(f"Index {key} out of range for dataset of length {len(self)}.")
|
|
||||||
IndexError: Index 10000 out of range for dataset of length 9832.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def project_stack_trace():
|
|
||||||
"""Provide a sample stack trace from a project directory (not a virtual env)"""
|
|
||||||
return """Traceback (most recent call last):
|
|
||||||
File "/home/user/projects/myproject/run.py", line 25, in <module>
|
|
||||||
main()
|
|
||||||
File "/home/user/projects/myproject/src/cli.py", line 45, in main
|
|
||||||
app.run()
|
|
||||||
File "/home/user/projects/myproject/src/app.py", line 102, in run
|
|
||||||
raise ValueError("Configuration missing")
|
|
||||||
ValueError: Configuration missing
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def test_sanitize_stack_trace(example_stack_trace):
|
|
||||||
"""Test that sanitize_stack_trace properly preserves axolotl paths"""
|
|
||||||
sanitized = sanitize_stack_trace(example_stack_trace)
|
|
||||||
|
|
||||||
# Check that personal paths are removed
|
|
||||||
assert "/home/user" not in sanitized
|
|
||||||
assert ".local/lib/python3.9" not in sanitized
|
|
||||||
|
|
||||||
# Check that site-packages is preserved
|
|
||||||
assert "site-packages/axolotl/cli/train.py" in sanitized
|
|
||||||
assert "site-packages/axolotl/train.py" in sanitized
|
|
||||||
assert "site-packages/axolotl/utils/models.py" in sanitized
|
|
||||||
|
|
||||||
# Check that error message is preserved
|
|
||||||
assert "ValueError: Model path not found" in sanitized
|
|
||||||
|
|
||||||
|
|
||||||
def test_sanitize_windows_paths(windows_stack_trace):
|
|
||||||
"""Test that sanitize_stack_trace handles Windows paths"""
|
|
||||||
sanitized = sanitize_stack_trace(windows_stack_trace)
|
|
||||||
|
|
||||||
# Check that personal paths are removed
|
|
||||||
assert "C:\\Users\\name" not in sanitized
|
|
||||||
assert "AppData\\Local\\Programs\\Python" not in sanitized
|
|
||||||
|
|
||||||
# Check that both axolotl and transformers packages are preserved
|
|
||||||
assert (
|
|
||||||
"site-packages\\axolotl\\cli\\train.py" in sanitized
|
|
||||||
or "site-packages/axolotl/cli/train.py" in sanitized
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
"site-packages\\axolotl\\train.py" in sanitized
|
|
||||||
or "site-packages/axolotl/train.py" in sanitized
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
"site-packages\\transformers\\models\\auto\\modeling_auto.py" in sanitized
|
|
||||||
or "site-packages/transformers/models/auto/modeling_auto.py" in sanitized
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check that error message is preserved
|
|
||||||
assert "ValueError: Unrecognized configuration class" in sanitized
|
|
||||||
|
|
||||||
|
|
||||||
def test_sanitize_mixed_paths(mixed_stack_trace):
|
|
||||||
"""Test that sanitize_stack_trace preserves all package paths"""
|
|
||||||
sanitized = sanitize_stack_trace(mixed_stack_trace)
|
|
||||||
|
|
||||||
# Check that all package paths are preserved
|
|
||||||
assert "site-packages/axolotl/cli/train.py" in sanitized
|
|
||||||
assert "site-packages/transformers/trainer.py" in sanitized
|
|
||||||
assert "site-packages/axolotl/utils/trainer.py" in sanitized
|
|
||||||
assert "site-packages/torch/utils/data/dataloader.py" in sanitized
|
|
||||||
|
|
||||||
# Check that error message is preserved
|
|
||||||
assert "RuntimeError: CUDA out of memory" in sanitized
|
|
||||||
|
|
||||||
|
|
||||||
def test_sanitize_venv_paths(venv_stack_trace):
|
|
||||||
"""Test that sanitize_stack_trace preserves virtual environment package paths"""
|
|
||||||
sanitized = sanitize_stack_trace(venv_stack_trace)
|
|
||||||
|
|
||||||
# Check that personal paths are removed
|
|
||||||
assert "/home/user/venv" not in sanitized
|
|
||||||
|
|
||||||
# Check that all package paths are preserved
|
|
||||||
assert "site-packages/transformers/trainer.py" in sanitized
|
|
||||||
assert "site-packages/accelerate/accelerator.py" in sanitized
|
|
||||||
assert "site-packages/torch/_tensor.py" in sanitized
|
|
||||||
|
|
||||||
# Check that error message is preserved
|
|
||||||
assert "RuntimeError: CUDA out of memory" in sanitized
|
|
||||||
|
|
||||||
|
|
||||||
def test_sanitize_dist_packages(dist_packages_stack_trace):
|
|
||||||
"""Test that sanitize_stack_trace preserves dist-packages paths"""
|
|
||||||
sanitized = sanitize_stack_trace(dist_packages_stack_trace)
|
|
||||||
|
|
||||||
# Check that system paths are removed
|
|
||||||
assert "/usr/local/lib/python3.8" not in sanitized
|
|
||||||
|
|
||||||
# Check that all package paths are preserved
|
|
||||||
assert "dist-packages/torch/utils/data/dataloader.py" in sanitized
|
|
||||||
assert "dist-packages/torch/utils/data/_utils/fetch.py" in sanitized
|
|
||||||
assert "dist-packages/datasets/arrow_dataset.py" in sanitized
|
|
||||||
|
|
||||||
# Check that error message is preserved
|
|
||||||
assert (
|
|
||||||
"IndexError: Index 10000 out of range for dataset of length 9832." in sanitized
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_sanitize_project_paths(project_stack_trace):
|
|
||||||
"""Test handling of project paths (non-virtual env)"""
|
|
||||||
sanitized = sanitize_stack_trace(project_stack_trace)
|
|
||||||
|
|
||||||
# Check that personal paths are removed
|
|
||||||
assert "/home/user/projects" not in sanitized
|
|
||||||
|
|
||||||
# For non-package paths, we should at least preserve the filename
|
|
||||||
assert "run.py" in sanitized
|
|
||||||
assert "cli.py" in sanitized
|
|
||||||
assert "app.py" in sanitized
|
|
||||||
|
|
||||||
# Check that error message is preserved
|
|
||||||
assert "ValueError: Configuration missing" in sanitized
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_telemetry_manager():
|
|
||||||
"""Create a mock TelemetryManager"""
|
|
||||||
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
|
|
||||||
mock_manager = MagicMock()
|
|
||||||
mock_manager.enabled = True
|
|
||||||
mock_manager_class.get_instance.return_value = mock_manager
|
|
||||||
yield mock_manager
|
|
||||||
|
|
||||||
|
|
||||||
def test_send_errors_successful_execution(mock_telemetry_manager):
|
|
||||||
"""Test that send_errors doesn't send telemetry for successful function execution"""
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def test_func():
|
|
||||||
return "success"
|
|
||||||
|
|
||||||
result = test_func()
|
|
||||||
assert result == "success"
|
|
||||||
mock_telemetry_manager.send_event.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
def test_send_errors_with_exception(mock_telemetry_manager):
|
|
||||||
"""Test that send_errors sends telemetry when an exception occurs"""
|
|
||||||
test_error = ValueError("Test error")
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def test_func():
|
|
||||||
raise test_error
|
|
||||||
|
|
||||||
with pytest.raises(ValueError) as excinfo:
|
|
||||||
test_func()
|
|
||||||
|
|
||||||
assert excinfo.value == test_error
|
|
||||||
mock_telemetry_manager.send_event.assert_called_once()
|
|
||||||
|
|
||||||
# Check that the error info was passed correctly
|
|
||||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
|
||||||
assert "test_func-errored" in call_args["event_type"]
|
|
||||||
assert "Test error" in call_args["properties"]["exception"]
|
|
||||||
assert "stack_trace" in call_args["properties"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_send_errors_nested_calls(mock_telemetry_manager):
|
|
||||||
"""Test that send_errors only sends telemetry once for nested decorated functions"""
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def inner_func():
|
|
||||||
raise ValueError("Inner error")
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def outer_func():
|
|
||||||
return inner_func()
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
outer_func()
|
|
||||||
|
|
||||||
# Telemetry should be sent only once for the inner function
|
|
||||||
assert mock_telemetry_manager.send_event.call_count == 1
|
|
||||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
|
||||||
assert "inner_func-error" in call_args["event_type"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_send_errors_telemetry_disable():
|
|
||||||
"""Test that send_errors doesn't attempt to send telemetry when disabled"""
|
|
||||||
|
|
||||||
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
|
|
||||||
mock_manager = MagicMock()
|
|
||||||
mock_manager.enabled = False
|
|
||||||
mock_manager_class.get_instance.return_value = mock_manager
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def test_func():
|
|
||||||
raise ValueError("Test error")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
test_func()
|
|
||||||
|
|
||||||
mock_manager.send_event.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
def test_error_handled_reset():
|
|
||||||
"""Test that ERROR_HANDLED flag is properly reset"""
|
|
||||||
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
|
|
||||||
# Create and configure the mock manager
|
|
||||||
mock_manager = MagicMock()
|
|
||||||
mock_manager.enabled = True
|
|
||||||
mock_manager_class.get_instance.return_value = mock_manager
|
|
||||||
|
|
||||||
from axolotl.telemetry.errors import ERROR_HANDLED
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def test_func():
|
|
||||||
raise ValueError("Test error")
|
|
||||||
|
|
||||||
assert not ERROR_HANDLED
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
test_func()
|
|
||||||
|
|
||||||
from axolotl.telemetry.errors import ERROR_HANDLED
|
|
||||||
|
|
||||||
assert ERROR_HANDLED
|
|
||||||
|
|
||||||
|
|
||||||
def test_module_path_resolution(mock_telemetry_manager):
|
|
||||||
"""Test that the module path is correctly resolved for the event type"""
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
current_module = inspect.getmodule(test_module_path_resolution).__name__
|
|
||||||
|
|
||||||
@send_errors
|
|
||||||
def test_func():
|
|
||||||
raise ValueError("Test error")
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
test_func()
|
|
||||||
|
|
||||||
assert mock_telemetry_manager.send_event.called
|
|
||||||
event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"]
|
|
||||||
|
|
||||||
expected_event_type = f"{current_module}.test_func-errored"
|
|
||||||
assert expected_event_type == event_type
|
|
||||||
@@ -1,245 +0,0 @@
|
|||||||
"""Tests for TelemetryManager class and utilities"""
|
|
||||||
# pylint: disable=redefined-outer-name,protected-access
|
|
||||||
|
|
||||||
import os
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from axolotl.telemetry.manager import TelemetryManager
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_whitelist(tmp_path):
|
|
||||||
"""Create a temporary whitelist file for testing"""
|
|
||||||
whitelist_content = {
|
|
||||||
"organizations": ["meta-llama", "mistralai"],
|
|
||||||
}
|
|
||||||
whitelist_file = tmp_path / "whitelist.yaml"
|
|
||||||
with open(whitelist_file, "w", encoding="utf-8") as f:
|
|
||||||
yaml.dump(whitelist_content, f)
|
|
||||||
|
|
||||||
return str(whitelist_file)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def telemetry_manager_class():
|
|
||||||
"""Reset the TelemetryManager singleton between tests"""
|
|
||||||
original_instance = TelemetryManager._instance
|
|
||||||
original_initialized = TelemetryManager._initialized
|
|
||||||
TelemetryManager._instance = None
|
|
||||||
TelemetryManager._initialized = False
|
|
||||||
yield TelemetryManager
|
|
||||||
TelemetryManager._instance = original_instance
|
|
||||||
TelemetryManager._initialized = original_initialized
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def manager(telemetry_manager_class, mock_whitelist):
|
|
||||||
"""Create a TelemetryManager instance with mocked dependencies"""
|
|
||||||
with patch("posthog.capture"), patch("posthog.flush"), patch("time.sleep"), patch(
|
|
||||||
"axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist
|
|
||||||
), patch.dict(os.environ, {"RANK": "0"}):
|
|
||||||
manager = telemetry_manager_class()
|
|
||||||
# Manually enable for most tests
|
|
||||||
manager.enabled = True
|
|
||||||
return manager
|
|
||||||
|
|
||||||
|
|
||||||
def test_singleton_instance(telemetry_manager_class):
|
|
||||||
"""Test that TelemetryManager is a singleton"""
|
|
||||||
with patch("posthog.capture"), patch("time.sleep"), patch.dict(
|
|
||||||
os.environ, {"RANK": "0"}
|
|
||||||
):
|
|
||||||
first = telemetry_manager_class()
|
|
||||||
second = telemetry_manager_class()
|
|
||||||
assert first is second
|
|
||||||
assert telemetry_manager_class.get_instance() is first
|
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
|
|
||||||
"""Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
|
|
||||||
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}):
|
|
||||||
manager = telemetry_manager_class()
|
|
||||||
assert not manager.enabled
|
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
|
|
||||||
"""Test that telemetry is disabled when DO_NOT_TRACK=1"""
|
|
||||||
with patch.dict(os.environ, {"DO_NOT_TRACK": "1", "RANK": "0"}):
|
|
||||||
manager = telemetry_manager_class()
|
|
||||||
assert not manager.enabled
|
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
|
|
||||||
"""Test that telemetry is disabled for non-main processes"""
|
|
||||||
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "1"}):
|
|
||||||
manager = telemetry_manager_class()
|
|
||||||
assert not manager.enabled
|
|
||||||
|
|
||||||
|
|
||||||
def test_telemetry_enabled_by_default(telemetry_manager_class):
|
|
||||||
"""Test that telemetry is enabled by default"""
|
|
||||||
with patch.dict(os.environ, {"RANK": "0"}, clear=True), patch("time.sleep"), patch(
|
|
||||||
"logging.Logger.warning"
|
|
||||||
):
|
|
||||||
manager = telemetry_manager_class()
|
|
||||||
assert manager.enabled
|
|
||||||
assert not manager.explicit_enable
|
|
||||||
|
|
||||||
|
|
||||||
def test_explicit_enable_disables_warning(telemetry_manager_class):
|
|
||||||
"""Test that explicit enabling prevents warning"""
|
|
||||||
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}), patch(
|
|
||||||
"logging.Logger.warning"
|
|
||||||
) as mock_warning, patch("time.sleep"):
|
|
||||||
manager = telemetry_manager_class()
|
|
||||||
assert manager.enabled
|
|
||||||
assert manager.explicit_enable
|
|
||||||
for call in mock_warning.call_args_list:
|
|
||||||
assert "Telemetry is enabled" not in str(call)
|
|
||||||
|
|
||||||
|
|
||||||
def test_warning_displayed_for_implicit_enable(telemetry_manager_class):
|
|
||||||
"""Test that warning is displayed when telemetry is implicitly enabled"""
|
|
||||||
with patch.dict(os.environ, {"RANK": "0"}, clear=True), patch(
|
|
||||||
"logging.Logger.warning"
|
|
||||||
) as mock_warning, patch("time.sleep"):
|
|
||||||
manager = telemetry_manager_class()
|
|
||||||
assert manager.enabled
|
|
||||||
assert not manager.explicit_enable
|
|
||||||
warning_displayed = False
|
|
||||||
for call in mock_warning.call_args_list:
|
|
||||||
if "Telemetry is enabled" in str(call):
|
|
||||||
warning_displayed = True
|
|
||||||
break
|
|
||||||
assert warning_displayed
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_whitelisted(manager, mock_whitelist):
|
|
||||||
"""Test org whitelist functionality"""
|
|
||||||
with patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist):
|
|
||||||
# Should match organizations from the mock whitelist
|
|
||||||
assert manager._is_whitelisted("meta-llama/llama-7b")
|
|
||||||
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
|
|
||||||
# Should not match
|
|
||||||
assert not manager._is_whitelisted("unknown/model")
|
|
||||||
# Should handle case insensitively
|
|
||||||
assert manager._is_whitelisted("META-LLAMA/Llama-7B")
|
|
||||||
# Should handle empty input
|
|
||||||
assert not manager._is_whitelisted("")
|
|
||||||
assert not manager._is_whitelisted(None)
|
|
||||||
|
|
||||||
|
|
||||||
def test_system_info_collection(manager):
|
|
||||||
"""Test system information collection"""
|
|
||||||
system_info = manager._get_system_info()
|
|
||||||
|
|
||||||
# Check essential keys
|
|
||||||
assert "os" in system_info
|
|
||||||
assert "python_version" in system_info
|
|
||||||
assert "torch_version" in system_info
|
|
||||||
assert "transformers_version" in system_info
|
|
||||||
assert "axolotl_version" in system_info
|
|
||||||
assert "cpu_count" in system_info
|
|
||||||
assert "memory_total" in system_info
|
|
||||||
assert "accelerator_count" in system_info
|
|
||||||
|
|
||||||
|
|
||||||
def test_send_event(manager):
|
|
||||||
"""Test basic event sending"""
|
|
||||||
with patch("posthog.capture") as mock_capture:
|
|
||||||
# Test with clean properties (no PII)
|
|
||||||
manager.send_event("test_event", {"key": "value"})
|
|
||||||
assert mock_capture.called
|
|
||||||
assert mock_capture.call_args[1]["event"] == "test_event"
|
|
||||||
assert mock_capture.call_args[1]["properties"] == {"key": "value"}
|
|
||||||
assert mock_capture.call_args[1]["distinct_id"] == manager.run_id
|
|
||||||
|
|
||||||
# Test with default properties (None)
|
|
||||||
mock_capture.reset_mock()
|
|
||||||
manager.send_event("simple_event")
|
|
||||||
assert mock_capture.called
|
|
||||||
assert mock_capture.call_args[1]["properties"] == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_send_system_info(manager):
|
|
||||||
"""Test sending system info"""
|
|
||||||
with patch("posthog.capture") as mock_capture:
|
|
||||||
manager.send_system_info()
|
|
||||||
assert mock_capture.called
|
|
||||||
assert mock_capture.call_args[1]["event"] == "system-info"
|
|
||||||
assert mock_capture.call_args[1]["properties"] == manager.system_info
|
|
||||||
|
|
||||||
|
|
||||||
def test_redacted_properties(manager):
|
|
||||||
"""Test path redaction in send_event method"""
|
|
||||||
with patch("posthog.capture") as mock_capture:
|
|
||||||
# Test with properties containing various paths and non-paths
|
|
||||||
test_properties = {
|
|
||||||
"filepath": "/home/user/sensitive/data.txt",
|
|
||||||
"windows_path": "C:\\Users\\name\\Documents\\project\\file.py",
|
|
||||||
"output_dir": "/var/lib/data",
|
|
||||||
"path_to_model": "models/llama/7b",
|
|
||||||
"message": "Training started", # Should not be redacted
|
|
||||||
"metrics": {"loss": 0.5, "accuracy": 0.95}, # Should not be redacted
|
|
||||||
"base_model": "models/local_model",
|
|
||||||
"nested": {
|
|
||||||
"model_path": "/models/my_model",
|
|
||||||
"root_dir": "/home/user/projects",
|
|
||||||
"stats": {"steps": 1000, "epochs": 3}, # Should not be redacted
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager.send_event("test_event", test_properties)
|
|
||||||
|
|
||||||
# Verify the call was made
|
|
||||||
assert mock_capture.called
|
|
||||||
|
|
||||||
# Get the sanitized properties that were sent
|
|
||||||
sanitized = mock_capture.call_args[1]["properties"]
|
|
||||||
|
|
||||||
# Check that path-like and base_model keys were redacted
|
|
||||||
assert sanitized["filepath"] == "[REDACTED]"
|
|
||||||
assert sanitized["windows_path"] == "[REDACTED]"
|
|
||||||
assert sanitized["path_to_model"] == "[REDACTED]"
|
|
||||||
assert sanitized["base_model"] == "[REDACTED]"
|
|
||||||
|
|
||||||
# Check that non-path values were preserved
|
|
||||||
assert sanitized["message"] == "Training started"
|
|
||||||
assert sanitized["metrics"] == {"loss": 0.5, "accuracy": 0.95}
|
|
||||||
|
|
||||||
# Check nested structure handling
|
|
||||||
assert sanitized["nested"]["model_path"] == "[REDACTED]"
|
|
||||||
assert sanitized["nested"]["root_dir"] == "[REDACTED]"
|
|
||||||
assert sanitized["nested"]["stats"] == {"steps": 1000, "epochs": 3}
|
|
||||||
|
|
||||||
|
|
||||||
def test_disable_telemetry(manager):
|
|
||||||
"""Test that disabled telemetry doesn't send events"""
|
|
||||||
with patch("posthog.capture") as mock_capture:
|
|
||||||
manager.enabled = False
|
|
||||||
manager.send_event("test_event")
|
|
||||||
assert not mock_capture.called
|
|
||||||
|
|
||||||
|
|
||||||
def test_exception_handling_during_send(manager):
|
|
||||||
"""Test that exceptions in PostHog are handled gracefully"""
|
|
||||||
with patch("posthog.capture", side_effect=Exception("Test error")), patch(
|
|
||||||
"logging.Logger.warning"
|
|
||||||
) as mock_warning:
|
|
||||||
manager.send_event("test_event")
|
|
||||||
warning_logged = False
|
|
||||||
for call in mock_warning.call_args_list:
|
|
||||||
if "Failed to send telemetry event" in str(call):
|
|
||||||
warning_logged = True
|
|
||||||
break
|
|
||||||
assert warning_logged
|
|
||||||
|
|
||||||
|
|
||||||
def test_shutdown(manager):
|
|
||||||
"""Test shutdown behavior"""
|
|
||||||
with patch("posthog.flush") as mock_flush:
|
|
||||||
manager.shutdown()
|
|
||||||
assert mock_flush.called
|
|
||||||
@@ -1,356 +0,0 @@
|
|||||||
"""Tests for runtime metrics telemetry module"""
|
|
||||||
# pylint: disable=redefined-outer-name
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.telemetry.runtime_metrics import RuntimeMetrics, RuntimeMetricsTracker
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_time():
|
|
||||||
"""Mock time.time() to have predictable values in tests"""
|
|
||||||
with patch("time.time") as mock_time:
|
|
||||||
# Start with time 1000.0 and increment by 10 seconds on each call
|
|
||||||
times = [1000.0 + i * 10 for i in range(10)]
|
|
||||||
mock_time.side_effect = times
|
|
||||||
yield mock_time
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_telemetry_manager():
|
|
||||||
"""Create a mock TelemetryManager"""
|
|
||||||
with patch(
|
|
||||||
"axolotl.telemetry.runtime_metrics.TelemetryManager"
|
|
||||||
) as mock_manager_class:
|
|
||||||
mock_manager = MagicMock()
|
|
||||||
mock_manager.enabled = True
|
|
||||||
mock_manager_class.get_instance.return_value = mock_manager
|
|
||||||
yield mock_manager
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_psutil():
|
|
||||||
"""Mock psutil for memory information"""
|
|
||||||
with patch("axolotl.telemetry.runtime_metrics.psutil") as mock_psutil:
|
|
||||||
mock_process = MagicMock()
|
|
||||||
mock_memory_info = MagicMock()
|
|
||||||
# Set initial memory to 1GB
|
|
||||||
mock_memory_info.rss = 1024 * 1024 * 1024
|
|
||||||
mock_process.memory_info.return_value = mock_memory_info
|
|
||||||
mock_psutil.Process.return_value = mock_process
|
|
||||||
yield mock_psutil
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_torch():
|
|
||||||
"""Mock torch.cuda functions"""
|
|
||||||
with patch("axolotl.telemetry.runtime_metrics.torch") as mock_torch:
|
|
||||||
mock_torch.cuda.is_available.return_value = True
|
|
||||||
mock_torch.cuda.device_count.return_value = 2
|
|
||||||
|
|
||||||
# Mock memory allocated per device (1GB for device 0, 2GB for device 1)
|
|
||||||
mock_torch.cuda.memory_allocated.side_effect = (
|
|
||||||
lambda device: (device + 1) * 1024 * 1024 * 1024
|
|
||||||
)
|
|
||||||
|
|
||||||
yield mock_torch
|
|
||||||
|
|
||||||
|
|
||||||
class TestRuntimeMetrics:
|
|
||||||
"""Tests for RuntimeMetrics class."""
|
|
||||||
|
|
||||||
def test_initialization(self):
|
|
||||||
"""Test RuntimeMetrics initialization."""
|
|
||||||
metrics = RuntimeMetrics(start_time=1000.0)
|
|
||||||
|
|
||||||
assert metrics.start_time == 1000.0
|
|
||||||
assert metrics.epoch_start_times == {}
|
|
||||||
assert metrics.epoch_end_times == {}
|
|
||||||
assert metrics.peak_gpu_memory == {}
|
|
||||||
assert metrics.total_steps == 0
|
|
||||||
assert metrics.current_epoch == 0
|
|
||||||
assert metrics.current_step == 0
|
|
||||||
assert metrics.peak_cpu_memory == 0
|
|
||||||
|
|
||||||
def test_elapsed_time(self, mock_time):
|
|
||||||
"""Test elapsed_time property."""
|
|
||||||
metrics = RuntimeMetrics(start_time=1000.0)
|
|
||||||
|
|
||||||
# Mock time.time() to return 1050.0
|
|
||||||
mock_time.side_effect = [1050.0]
|
|
||||||
|
|
||||||
assert metrics.elapsed_time == 50.0
|
|
||||||
|
|
||||||
def test_epoch_time(self):
|
|
||||||
"""Test epoch_time method."""
|
|
||||||
metrics = RuntimeMetrics(start_time=1000.0)
|
|
||||||
|
|
||||||
# No epoch data
|
|
||||||
assert metrics.epoch_time(0) is None
|
|
||||||
|
|
||||||
# Add epoch start but no end
|
|
||||||
metrics.epoch_start_times[0] = 1000.0
|
|
||||||
assert metrics.epoch_time(0) is None
|
|
||||||
|
|
||||||
# Add epoch end
|
|
||||||
metrics.epoch_end_times[0] = 1060.0
|
|
||||||
assert metrics.epoch_time(0) == 60.0
|
|
||||||
|
|
||||||
def test_average_epoch_time(self):
|
|
||||||
"""Test average_epoch_time method."""
|
|
||||||
metrics = RuntimeMetrics(start_time=1000.0)
|
|
||||||
|
|
||||||
# No completed epochs
|
|
||||||
assert metrics.average_epoch_time() is None
|
|
||||||
|
|
||||||
# Add one completed epoch
|
|
||||||
metrics.epoch_start_times[0] = 1000.0
|
|
||||||
metrics.epoch_end_times[0] = 1060.0
|
|
||||||
assert metrics.average_epoch_time() == 60.0
|
|
||||||
|
|
||||||
# Add second completed epoch
|
|
||||||
metrics.epoch_start_times[1] = 1060.0
|
|
||||||
metrics.epoch_end_times[1] = 1140.0 # 80 seconds
|
|
||||||
assert metrics.average_epoch_time() == 70.0 # Average of 60 and 80
|
|
||||||
|
|
||||||
# Add incomplete epoch (should not affect average)
|
|
||||||
metrics.epoch_start_times[2] = 1140.0
|
|
||||||
assert metrics.average_epoch_time() == 70.0
|
|
||||||
|
|
||||||
def test_steps_per_second(self, mock_time):
|
|
||||||
"""Test steps_per_second method."""
|
|
||||||
metrics = RuntimeMetrics(start_time=1000.0)
|
|
||||||
|
|
||||||
# No steps - first call to time.time()
|
|
||||||
mock_time.side_effect = None
|
|
||||||
mock_time.return_value = 1050.0
|
|
||||||
assert metrics.steps_per_second() is None
|
|
||||||
|
|
||||||
# Add steps - second call to time.time()
|
|
||||||
metrics.total_steps = 100
|
|
||||||
mock_time.return_value = 1050.0 # Keep same time for consistent result
|
|
||||||
assert metrics.steps_per_second() == 2.0 # 100 steps / 50 seconds
|
|
||||||
|
|
||||||
def test_to_dict_basic(self, mock_time):
|
|
||||||
"""Test to_dict method with basic metrics."""
|
|
||||||
metrics = RuntimeMetrics(start_time=1000.0)
|
|
||||||
metrics.total_steps = 100
|
|
||||||
metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024 # 2GB
|
|
||||||
|
|
||||||
# Mock elapsed_time
|
|
||||||
mock_time.side_effect = None
|
|
||||||
mock_time.return_value = 1050.0
|
|
||||||
|
|
||||||
result = metrics.to_dict()
|
|
||||||
|
|
||||||
assert result["total_time_seconds"] == 50.0
|
|
||||||
assert result["total_steps"] == 100
|
|
||||||
assert result["steps_per_second"] == 2.0
|
|
||||||
assert result["epochs_completed"] == 0
|
|
||||||
assert result["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
|
||||||
assert "epoch_times" not in result
|
|
||||||
assert "gpu_memory" not in result
|
|
||||||
|
|
||||||
def test_to_dict_with_epochs(self, mock_time):
|
|
||||||
"""Test to_dict method with epoch data."""
|
|
||||||
metrics = RuntimeMetrics(start_time=1000.0)
|
|
||||||
metrics.total_steps = 100
|
|
||||||
|
|
||||||
# Add epoch data
|
|
||||||
metrics.epoch_start_times[0] = 1000.0
|
|
||||||
metrics.epoch_end_times[0] = 1060.0
|
|
||||||
metrics.epoch_start_times[1] = 1060.0
|
|
||||||
metrics.epoch_end_times[1] = 1140.0
|
|
||||||
|
|
||||||
# Mock elapsed_time
|
|
||||||
mock_time.side_effect = None
|
|
||||||
mock_time.return_value = 1150.0
|
|
||||||
|
|
||||||
result = metrics.to_dict()
|
|
||||||
|
|
||||||
assert "epoch_times" in result
|
|
||||||
assert result["epoch_times"]["epoch_0_seconds"] == 60.0
|
|
||||||
assert result["epoch_times"]["epoch_1_seconds"] == 80.0
|
|
||||||
assert result["average_epoch_time_seconds"] == 70.0
|
|
||||||
|
|
||||||
def test_to_dict_with_gpu_memory(self, mock_time):
|
|
||||||
"""Test to_dict method with GPU memory data."""
|
|
||||||
metrics = RuntimeMetrics(start_time=1000.0)
|
|
||||||
metrics.peak_gpu_memory = {
|
|
||||||
0: 1 * 1024 * 1024 * 1024, # 1GB
|
|
||||||
1: 2 * 1024 * 1024 * 1024, # 2GB
|
|
||||||
}
|
|
||||||
|
|
||||||
# Mock elapsed_time
|
|
||||||
mock_time.side_effect = [1050.0]
|
|
||||||
|
|
||||||
result = metrics.to_dict()
|
|
||||||
|
|
||||||
assert "gpu_memory" in result
|
|
||||||
assert result["gpu_memory"]["gpu_0_peak_memory_bytes"] == 1 * 1024 * 1024 * 1024
|
|
||||||
assert result["gpu_memory"]["gpu_1_peak_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
class TestRuntimeMetricsTracker:
|
|
||||||
"""Tests for RuntimeMetricsTracker class."""
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
|
||||||
def test_initialization(self, mock_time, mock_telemetry_manager):
|
|
||||||
"""Test RuntimeMetricsTracker initialization."""
|
|
||||||
tracker = RuntimeMetricsTracker()
|
|
||||||
|
|
||||||
assert isinstance(tracker.metrics, RuntimeMetrics)
|
|
||||||
assert tracker.metrics.start_time == 1000.0 # First value from mock_time
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
|
||||||
def test_start_epoch(
|
|
||||||
self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
|
|
||||||
):
|
|
||||||
"""Test start_epoch method."""
|
|
||||||
tracker = RuntimeMetricsTracker()
|
|
||||||
|
|
||||||
# Reset mock_time to control next value
|
|
||||||
mock_time.side_effect = [1010.0]
|
|
||||||
|
|
||||||
tracker.start_epoch(0)
|
|
||||||
|
|
||||||
assert tracker.metrics.current_epoch == 0
|
|
||||||
assert tracker.metrics.epoch_start_times[0] == 1010.0
|
|
||||||
|
|
||||||
# Verify memory metrics were updated
|
|
||||||
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
|
|
||||||
assert 0 in tracker.metrics.peak_gpu_memory
|
|
||||||
assert 1 in tracker.metrics.peak_gpu_memory
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
|
||||||
def test_end_epoch(self, mock_time, mock_telemetry_manager):
|
|
||||||
"""Test end_epoch method."""
|
|
||||||
tracker = RuntimeMetricsTracker()
|
|
||||||
|
|
||||||
# Start epoch 0
|
|
||||||
mock_time.side_effect = [1010.0]
|
|
||||||
tracker.start_epoch(0)
|
|
||||||
|
|
||||||
# End epoch 0
|
|
||||||
mock_time.side_effect = [1060.0]
|
|
||||||
tracker.end_epoch(0)
|
|
||||||
|
|
||||||
assert 0 in tracker.metrics.epoch_end_times
|
|
||||||
assert tracker.metrics.epoch_end_times[0] == 1060.0
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
|
||||||
def test_update_step(
|
|
||||||
self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
|
|
||||||
):
|
|
||||||
"""Test update_step method."""
|
|
||||||
tracker = RuntimeMetricsTracker()
|
|
||||||
|
|
||||||
# Update step to a non-multiple of 100
|
|
||||||
tracker.update_step(42)
|
|
||||||
|
|
||||||
assert tracker.metrics.current_step == 42
|
|
||||||
assert tracker.metrics.total_steps == 1
|
|
||||||
|
|
||||||
# Memory metrics should not be updated for non-multiple of 100
|
|
||||||
assert tracker.metrics.peak_cpu_memory == 0
|
|
||||||
|
|
||||||
# Update step to a multiple of 100
|
|
||||||
tracker.update_step(100)
|
|
||||||
|
|
||||||
assert tracker.metrics.current_step == 100
|
|
||||||
assert tracker.metrics.total_steps == 2
|
|
||||||
|
|
||||||
# Memory metrics should be updated for multiple of 100
|
|
||||||
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
|
||||||
def test_update_memory_metrics(
|
|
||||||
self, mock_psutil, mock_torch, mock_telemetry_manager
|
|
||||||
):
|
|
||||||
"""Test update_memory_metrics method."""
|
|
||||||
tracker = RuntimeMetricsTracker()
|
|
||||||
|
|
||||||
# Initial memory state
|
|
||||||
assert tracker.metrics.peak_cpu_memory == 0
|
|
||||||
assert tracker.metrics.peak_gpu_memory == {}
|
|
||||||
|
|
||||||
# Update memory metrics
|
|
||||||
tracker.update_memory_metrics()
|
|
||||||
|
|
||||||
# Verify CPU memory
|
|
||||||
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
|
|
||||||
|
|
||||||
# Verify GPU memory
|
|
||||||
assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
|
|
||||||
assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024
|
|
||||||
|
|
||||||
# Change mocked memory values to be lower
|
|
||||||
mock_process = mock_psutil.Process.return_value
|
|
||||||
mock_memory_info = mock_process.memory_info.return_value
|
|
||||||
mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024 # 0.5GB
|
|
||||||
|
|
||||||
mock_torch.cuda.memory_allocated.side_effect = (
|
|
||||||
lambda device: (device + 0.5) * 1024 * 1024 * 1024
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update memory metrics again
|
|
||||||
tracker.update_memory_metrics()
|
|
||||||
|
|
||||||
# Peak values should not decrease
|
|
||||||
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
|
|
||||||
assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
|
|
||||||
assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024
|
|
||||||
|
|
||||||
# Change mocked memory values to be higher
|
|
||||||
mock_memory_info.rss = 2 * 1024 * 1024 * 1024 # 2GB
|
|
||||||
|
|
||||||
mock_torch.cuda.memory_allocated.side_effect = (
|
|
||||||
lambda device: (device + 2) * 1024 * 1024 * 1024
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update memory metrics again
|
|
||||||
tracker.update_memory_metrics()
|
|
||||||
|
|
||||||
# Peak values should increase
|
|
||||||
assert tracker.metrics.peak_cpu_memory == 2 * 1024 * 1024 * 1024
|
|
||||||
assert tracker.metrics.peak_gpu_memory[0] == 2 * 1024 * 1024 * 1024
|
|
||||||
assert tracker.metrics.peak_gpu_memory[1] == 3 * 1024 * 1024 * 1024
|
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
|
||||||
def test_get_memory_metrics(self, mock_psutil, mock_torch, mock_telemetry_manager):
|
|
||||||
"""Test get_memory_metrics method."""
|
|
||||||
tracker = RuntimeMetricsTracker()
|
|
||||||
|
|
||||||
# Set peak memory values
|
|
||||||
tracker.metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024
|
|
||||||
tracker.metrics.peak_gpu_memory = {
|
|
||||||
0: 3 * 1024 * 1024 * 1024,
|
|
||||||
1: 4 * 1024 * 1024 * 1024,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get memory metrics
|
|
||||||
memory_metrics = tracker.get_memory_metrics()
|
|
||||||
|
|
||||||
# Verify CPU memory
|
|
||||||
assert (
|
|
||||||
memory_metrics["cpu_memory_bytes"] == 1 * 1024 * 1024 * 1024
|
|
||||||
) # Current value from mock
|
|
||||||
assert (
|
|
||||||
memory_metrics["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
|
||||||
) # Peak value we set
|
|
||||||
|
|
||||||
# Verify GPU memory
|
|
||||||
assert (
|
|
||||||
memory_metrics["gpu_0_memory_bytes"] == 1 * 1024 * 1024 * 1024
|
|
||||||
) # Current value from mock
|
|
||||||
assert (
|
|
||||||
memory_metrics["gpu_0_peak_memory_bytes"] == 3 * 1024 * 1024 * 1024
|
|
||||||
) # Peak value we set
|
|
||||||
assert (
|
|
||||||
memory_metrics["gpu_1_memory_bytes"] == 2 * 1024 * 1024 * 1024
|
|
||||||
) # Current value from mock
|
|
||||||
assert (
|
|
||||||
memory_metrics["gpu_1_peak_memory_bytes"] == 4 * 1024 * 1024 * 1024
|
|
||||||
) # Peak value we set
|
|
||||||
Reference in New Issue
Block a user