Compare commits

..

22 Commits

Author SHA1 Message Date
Dan Saunders
f144319697 distributed fix 2025-02-26 03:04:58 +00:00
Dan Saunders
07bb41812b fix issue with tests in ci 2025-02-26 03:04:58 +00:00
Dan Saunders
cae8c7636b fixes 2025-02-26 03:04:58 +00:00
Dan Saunders
09611cea10 remove duplicate info 2025-02-26 03:04:58 +00:00
Dan Saunders
61266ab843 adding runtime metrics / system info additional accelerator support, etc. 2025-02-26 03:04:58 +00:00
Dan Saunders
aea0e760e4 adding runtime metrics / system info additional accelerator support, etc. 2025-02-26 03:04:58 +00:00
Dan Saunders
5afad670da improved redaction, send system info during model config load telemetry, etc. 2025-02-26 03:04:58 +00:00
Dan Saunders
49ac79ed1e doc update 2025-02-26 03:04:58 +00:00
Dan Saunders
c9af72cd7a fix 2025-02-26 03:04:58 +00:00
Dan Saunders
d3d63c1432 adding back in base_model redaction w/ whitelist 2025-02-26 03:04:58 +00:00
Dan Saunders
675b65d711 sleep on all ranks in distributed setting 2025-02-26 03:04:58 +00:00
Dan Saunders
b23187daea simplifying path redaction 2025-02-26 03:04:58 +00:00
Dan Saunders
e373a6b8d0 small update / fix 2025-02-26 03:04:58 +00:00
Dan Saunders
fd5d5aecdc tests for runtime metrics telemetry and assoc. callback 2025-02-26 03:04:58 +00:00
Dan Saunders
3760175440 adding runtime metrics (cpu + gpu memory, steps/s, etc.) 2025-02-26 03:04:58 +00:00
Dan Saunders
7927abff90 updated sanitization logic, tests 2025-02-26 03:04:58 +00:00
Dan Saunders
ec36839316 update error file path sanitization function; adding more error tracking 2025-02-26 03:04:58 +00:00
Dan Saunders
3076b8df00 progress on telemetry: config load, process, model load, train start / end, error tracking 2025-02-26 03:04:58 +00:00
Dan Saunders
c50610375f updates 2025-02-26 03:04:58 +00:00
Dan Saunders
07ffd47f2b updates 2025-02-26 03:04:58 +00:00
Dan Saunders
76d951afd2 adding todo 2025-02-26 03:04:58 +00:00
Dan Saunders
5220e8ccf4 initial telemetry manager impl 2025-02-26 03:04:58 +00:00
66 changed files with 2770 additions and 2050 deletions

View File

@@ -88,11 +88,6 @@ jobs:
pytorch: 2.5.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -80,11 +80,6 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -19,6 +19,9 @@
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
</p>
Axolotl is a tool designed to streamline post-training for various AI models.

View File

@@ -40,7 +40,6 @@ website:
- section: "Deployments"
contents:
- docs/docker.qmd
- docs/multi-gpu.qmd
- docs/multi-node.qmd
- docs/ray-integration.qmd

View File

@@ -14,7 +14,7 @@ COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
RUN apt install --yes --no-install-recommends openssh-server tmux && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \

View File

@@ -154,6 +154,8 @@ datasets:
content: value
# ...
message_property_mappings:
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles:
user: ["human", "user"]
@@ -161,12 +163,6 @@ datasets:
system: ["system"]
tool: ["tool"]
# Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
# This does not drop the default system message from chat_template if it exists. If you wish to,
# we recommend using a custom jinja template with the default system message removed or
# adding a system turn with empty content.
drop_system_message:
# IMPORTANT: The following fields determine which parts of the conversation to train on.
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
# See examples at `docs/dataset-formats/conversation.qmd`
@@ -226,8 +222,8 @@ process_reward_model:
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Changes the default system message. Currently only supports chatml.
default_system_message: You are a helpful assistant. Please give a long and detailed answer.
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
@@ -449,7 +445,7 @@ gradient_checkpointing: false
early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
@@ -532,8 +528,6 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
sdp_attention:
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
s2_attention:
# Optional[bool]. Whether to use low_cpu_mem_usage
low_cpu_mem_usage:
# Resume from a specific checkpoint dir
resume_from_checkpoint:
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
@@ -554,13 +548,6 @@ special_tokens:
# Add extra tokens.
tokens:
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
# Can be checked if they exist in tokenizer.json added_tokens.
added_tokens_overrides: # Dict[int, str]
# 128041: "<|im_start|>"
# 128042: "<|im_end|>"
# FSDP
fsdp:
fsdp_config:

View File

@@ -74,10 +74,6 @@ datasets:
train_on_eos:
```
::: {.callout-tip}
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
:::
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml

View File

@@ -129,7 +129,6 @@ You can mix and match within each approach or across approaches to train a model
We suggest this approach when you want to bring your own tokenized dataset.
Axolotl expects the dataset to have three keys:
- `input_ids`: from tokenizing formatted prompt
- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
- `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`.

View File

@@ -1,140 +0,0 @@
---
title: "Docker"
format:
html:
toc: true
toc-depth: 4
---
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
## Base
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
#### Image
```
axolotlai/axolotl-base
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
#### Tags format
```bash
main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
```
Tags examples:
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1`
## Main
The main image is the image that is used to run Axolotl. It is based on the `axolotlai/axolotl-base` image and includes the Axolotl codebase, dependencies, and more.
#### Image
```
axolotlai/axolotl
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
#### Tags format {#sec-main-tags}
```bash
# on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
main-latest
# nightly build
{branch}-{date_in_YYYYMMDD}-py{python_version}-cu{cuda_version}-{pytorch_version}
# tagged release
{version}
```
:::{.callout-tip}
There may be some extra tags appended to the image, like `-vllm` which installs those packages.
:::
Tags examples:
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `main-20250303-py3.11-cu124-2.4.1`
- `0.7.1`
## Cloud
The cloud image is the image that is used to run Axolotl in the cloud. It is based on the `axolotlai/axolotl` image and sets ENV variables like HuggingFace cache directories for volume mounts, tmux, and more for different cloud providers.
:::{.callout-tip}
Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variables to disable it.
:::
#### Image
```
axolotlai/axolotl-cloud
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
#### Tags format
This uses the same tags as the [`main` image](#sec-main-tags).
#### Environment variables
- `JUPYTER_DISABLE`: Disable Jupyter lab.
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
- `PUBLIC_KEY`: Add a public key for the SSH service.
- `SSH_KEY`: Add a private key for the SSH service.
#### Volume mounts
:::{.callout-tip}
We recommend mounting volumes to `/workspace/data` for data persistence. `/workspace/axolotl` contains the source code and is ephemeral.
:::
- `/workspace/data/axolotl-artifacts`: Directory to store Axolotl artifacts.
- `/workspace/data/huggingface-cache`: Directory to store HuggingFace cache.
## Cloud-no-tmux
This is the same as the [`cloud` image](#sec-cloud) but without tmux.
#### Image
```
axolotlai/axolotl-cloud-term
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud-term)
:::{.callout-note}
The naming may be a bit confusing as it has `-term` appended to the end.
:::
#### Tags format
This uses the same tags as the [`cloud` image](#sec-cloud-tags).

View File

@@ -19,9 +19,7 @@ description: Frequently asked questions
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
**Q: ModuleNotFoundError: No module named 'mpi4py' using single GPU with deepspeed**
> A: You may be using deepspeed with single gpu. Please remove the `deepspeed:` section in the yaml file or `--deepspeed` CLI flag.
> A: You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
**Q: The codes is stuck on saving preprocessed datasets.**
@@ -52,7 +50,3 @@ description: Frequently asked questions
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.

View File

@@ -65,8 +65,6 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
```
:::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
## Cloud Environments {#sec-cloud}
### Cloud GPU Providers {#sec-cloud-gpu}

View File

@@ -28,17 +28,6 @@ val_set_size: 0.1
eval_steps: 100
```
Bradley-Terry chat templates expect single-turn conversations in the following format:
```json
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
```
### Process Reward Models (PRM)
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
@@ -56,5 +45,3 @@ datasets:
val_set_size: 0.1
eval_steps: 100
```
Please see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.

View File

@@ -3,7 +3,6 @@ title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
@@ -529,7 +528,6 @@ trl:
vllm_gpu_memory_utilization: 0.15
num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0]
datasets:
- path: openai/gsm8k
name: main
@@ -538,8 +536,6 @@ datasets:
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
### Using local dataset files
```yaml

59
docs/telemetry.qmd Normal file
View File

@@ -0,0 +1,59 @@
---
title: Telemetry
description: A description of the opt-out telemetry implementation in Axolotl.
---
# Telemetry in Axolotl
Axolotl implements anonymous telemetry to help maintainers understand how the library
is used and where users encounter issues. This data helps prioritize features, optimize
performance, and fix bugs.
## Data Collection
We collect:
- System info: OS, Python version, Axolotl version, PyTorch version, Transformers
version, etc.
- Hardware info: CPU count, memory, GPU count and models
- Runtime metrics: Training progress, memory usage, timing information
- Usage patterns: Models (from a whitelist) and configurations used
- Error tracking: Stack traces and error messages (sanitized to remove personal
information)
No personally identifiable information (PII) is collected.
## Implementation
Telemetry is implemented using PostHog and consists of:
- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the
telemetry system and provides methods for tracking events.
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and
sends sanitized stack traces.
- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks
runtime metrics during training.
- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends
runtime metrics telemetry.
The telemetry system will block training startup for 15 seconds to ensure users are
aware of data collection, unless telemetry is explicitly enabled or disabled.
## Opt-Out Mechanism
Telemetry is **enabled by default** on an opt-out basis. To disable it, set either:
- `AXOLOTL_DO_NOT_TRACK=1` (Axolotl-specific)
- `DO_NOT_TRACK=1` (Global standard; see https://consoledonottrack.com/)
To acknowledge and explicitly enable telemetry (and remove the warning message), set:
`AXOLOTL_DO_NOT_TRACK=0`.
## Privacy
- All path-like config information is automatically redacted from telemetry data
- Model information is only collected for whitelisted organizations
- See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
- Each run generates a unique anonymous ID
- This allows us to link different telemetry events in a single same training run
- Telemetry is only sent from the main process to avoid duplicate events

View File

@@ -62,5 +62,7 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0
schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
axolotl-contribs-lgpl==0.0.3
# telemetry
posthog>=3.15.1

View File

@@ -24,5 +24,5 @@ if cce_spec:
print(
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
)

View File

@@ -113,7 +113,7 @@ class ModalCloud(Cloud):
[
# Random id for cache busting of branch commits
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch} && git pull",
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
]
)
@@ -270,7 +270,6 @@ def _preprocess(config_yaml: str, volumes=None):
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"
@@ -289,7 +288,6 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
def _lm_eval(config_yaml: str, volumes=None):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"

View File

@@ -14,6 +14,8 @@ import yaml
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
@@ -27,6 +29,8 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = logging.getLogger(__name__)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
"""
@@ -152,6 +156,7 @@ def prepare_plugins(cfg: DictDefault):
plugin_manager.register(plugin_name)
@send_errors
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
"""
Loads the `axolotl` configuration stored at `config`, validates it, and performs
@@ -171,6 +176,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
# Load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg)
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
@@ -214,4 +220,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
return cfg

View File

@@ -17,6 +17,7 @@ from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
@@ -42,6 +43,7 @@ def get_multi_line_input() -> str:
return instruction
@send_errors
def do_inference(
*,
cfg: DictDefault,
@@ -135,6 +137,7 @@ def do_inference(
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
@send_errors
def do_inference_gradio(
*,
cfg: DictDefault,

View File

@@ -12,11 +12,13 @@ from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
@send_errors
def do_merge_lora(*, cfg: DictDefault) -> None:
"""
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config

View File

@@ -27,6 +27,7 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.telemetry.errors import send_errors
LOG = logging.getLogger(__name__)
@@ -120,6 +121,7 @@ def _distributed_checkpoint_to_merged_weights(
return save_path_
@send_errors
def merge_fsdp_weights(
checkpoint_dir: str,
output_path: str,

View File

@@ -18,12 +18,14 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger(__name__)
@send_errors
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
"""
Preprocesses dataset specified in axolotl config.

View File

@@ -1,7 +1,6 @@
"""CLI to run training on a model."""
import logging
import os
from pathlib import Path
from typing import Union
@@ -35,20 +34,18 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
print_axolotl_text_art()
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
check_user_token()
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
del trainer
plugin_manager.post_train_unload(cfg)

View File

@@ -10,6 +10,7 @@ from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.telemetry.errors import send_errors
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
@@ -24,8 +25,8 @@ class TrainDatasetMeta:
"""Dataclass with fields for training and validation datasets and metadata."""
train_dataset: Dataset
eval_dataset: Dataset | None = None
total_num_steps: int | None = None
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
@@ -44,6 +45,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
)
@send_errors
def load_datasets(
*,
cfg: DictDefault,
@@ -103,6 +105,7 @@ def load_datasets(
)
@send_errors
def load_preference_datasets(
*,
cfg: DictDefault,

View File

@@ -43,7 +43,7 @@ class TokenizedChatDataset(Dataset):
process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment]
)
num_proc = min(32, process_or_cpu_count)
num_proc = min(64, process_or_cpu_count)
features = data.features.keys()
tokenized_data = data.map(
map_fn,

View File

@@ -35,11 +35,11 @@ from transformers import (
EarlyStoppingCallback,
TrainerCallback,
)
from transformers.training_args import OptimizerNames
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlMambaTrainer,
AxolotlORPOTrainer,
AxolotlPRMTrainer,
@@ -50,7 +50,6 @@ from axolotl.core.trainers.base import (
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.trainers.kto import AxolotlKTOTrainer
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlKTOConfig,
@@ -62,6 +61,8 @@ from axolotl.core.training_args import (
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.telemetry.callbacks import TelemetryCallback
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
@@ -85,7 +86,6 @@ from axolotl.utils.collators import (
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
from axolotl.utils.models import ensure_dtype
try:
@@ -93,11 +93,13 @@ try:
except ImportError:
pass
LOG = logging.getLogger(__name__)
LOG = logging.getLogger("axolotl.core.trainer_builder")
class TrainerBuilderBase(abc.ABC):
"""Base class for trainer builder."""
"""
Base class for trainer builder
"""
_train_dataset = None
_eval_dataset = None
@@ -110,9 +112,9 @@ class TrainerBuilderBase(abc.ABC):
self.tokenizer = tokenizer
self.processor = processor
# If the model supports tagging, add the axolotl tag.
# in case the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
# model.push_to_hub instead of trainer.push_to_hub.
# model.push_to_hub instad of trainer.push_to_hub.
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
@@ -176,10 +178,8 @@ class TrainerBuilderBase(abc.ABC):
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
]
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
@@ -188,6 +188,10 @@ class TrainerBuilderBase(abc.ABC):
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
telemetry_manager = TelemetryManager.get_instance()
if telemetry_manager.enabled:
callbacks.append(TelemetryCallback())
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -227,8 +231,8 @@ class TrainerBuilderBase(abc.ABC):
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for causal models and reward modeling
using TRL.
Build the HuggingFace training args/trainer for causal models
and reward modelling using TRL.
"""
def get_callbacks(self):
@@ -551,8 +555,30 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[
"alternate_lr_scheduler_type"
@@ -636,114 +662,46 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model:
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
# Handle custom optimizer
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
if self.cfg.optimizer in custom_supported_optimizers:
# Common optimizer kwargs
optimizer_kwargs = {
"lr": training_arguments_kwargs.get("learning_rate"),
"weight_decay": training_arguments_kwargs.get("weight_decay"),
}
# pylint: disable=duplicate-code
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
# Adam-specific kwargs
adam_kwargs = {}
if training_arguments_kwargs.get(
"adam_beta1"
) and training_arguments_kwargs.get("adam_beta2"):
adam_kwargs["betas"] = (
training_arguments_kwargs.get("adam_beta1"),
training_arguments_kwargs.get("adam_beta2"),
)
if training_arguments_kwargs.get("adam_epsilon"):
adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "lion_pytorch":
from lion_pytorch import Lion
if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
MuonOptimizerFactory,
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
if "weight_decay" in training_arguments_kwargs:
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
if (
"adam_beta1" in training_arguments_kwargs
and "adam_beta2" in training_arguments_kwargs
):
lion_kwargs["betas"] = (
training_arguments_kwargs["adam_beta1"],
training_arguments_kwargs["adam_beta2"],
)
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optimizer_kwargs.update(self.cfg.optim_args)
else:
# Parse string format "key1=value1,key2=value2"
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optimizer_kwargs[key] = value
trainer_kwargs["optimizer_cls_and_kwargs"] = (
optimizer_cls,
optimizer_kwargs,
trainer_kwargs["optimizers"] = (
Lion(params=self.model.parameters(), **lion_kwargs),
None,
)
else:
# Use transformers' optimizer
training_arguments_kwargs["optim"] = self.cfg.optimizer
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.accelerator_config:
training_arguments_kwargs[
"accelerator_config"
@@ -918,7 +876,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
class HFRLTrainerBuilder(TrainerBuilderBase):
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
"""
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
"""
def get_callbacks(self):
callbacks = super().get_callbacks()

View File

@@ -14,21 +14,17 @@ from typing import Dict, Literal, Optional
import torch
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from axolotl.core.trainers.kto import AxolotlKTOTrainer
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
@@ -119,17 +115,6 @@ class SchedulerMixin(Trainer):
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif self.args.alternate_lr_scheduler_type == "rex":
if use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
@@ -169,18 +154,47 @@ class SchedulerMixin(Trainer):
return self.lr_scheduler
class OptimizerMixin(Trainer):
class AxolotlTrainer(SchedulerMixin, Trainer):
"""
Mixin class for shared handling of building custom optimizers
Extend the base Trainer for axolotl helpers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def create_optimizer_grouped_parameters(
self, opt_model, optimizer_kwargs
) -> list[dict]:
def __init__(
self,
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
decay_parameters = self.get_decay_parameter_names(opt_model)
params: dict = {
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
@@ -267,30 +281,23 @@ class OptimizerMixin(Trainer):
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if (
not self.optimizer
and self.optimizer_cls_and_kwargs is not None
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
):
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
self.optimizer = optimizer_factory_cls()(
opt_model, self.args, **optimizer_kwargs
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
if not self.optimizer:
if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
self.args, opt_model
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
@@ -307,47 +314,50 @@ class OptimizerMixin(Trainer):
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
else:
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop(
"optimizer_dict"
)
self.optimizer = optimizer_cls(
optimizer_grouped_parameters, **optimizer_kwargs
elif (
self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None
or self.args.lr_groups is not None
):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW(
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
elif self.args.alternate_optimizer == "ao_adamw_4bit":
from torchao.prototype.low_bit_optim import AdamW4bit
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters,
decouple=True,
**optimizer_kwargs,
)
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -356,45 +366,6 @@ class OptimizerMixin(Trainer):
return self.optimizer
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def __init__(
self,
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:
@@ -875,6 +846,14 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers

View File

@@ -9,7 +9,6 @@ import logging
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
LOG = logging.getLogger("axolotl")
@@ -32,44 +31,30 @@ class GRPOStrategy:
@classmethod
def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {}
if not hasattr(cfg, "trl") or not cfg.trl:
return grpo_args_kwargs
trl: TRLConfig = cfg.trl # type: ignore
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_device"] = (
trl.vllm_device if trl.vllm_device else "auto"
)
if trl.vllm_gpu_memory_utilization:
if cfg.trl and cfg.trl.use_vllm:
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
if cfg.trl and cfg.trl.vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = trl.vllm_gpu_memory_utilization
if trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
if trl.num_generations:
grpo_args_kwargs["num_generations"] = trl.num_generations
if trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model
if trl.ref_model_mixup_alpha:
grpo_args_kwargs["ref_model_mixup_alpha"] = trl.ref_model_mixup_alpha
if trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
grpo_args_kwargs["log_completions"] = trl.log_completions
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
] = cfg.trl.vllm_gpu_memory_utilization
if cfg.trl and cfg.trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
if cfg.trl and cfg.trl.num_generations:
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
if cfg.trl and cfg.trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
grpo_args_kwargs[
"ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
return grpo_args_kwargs
@classmethod

View File

@@ -1,7 +0,0 @@
"""
KTO package initialization.
"""
from axolotl.core.trainers.kto.trainer import AxolotlKTOTrainer
__all__ = ["AxolotlKTOTrainer"]

View File

@@ -1,512 +0,0 @@
"""
KTO trainer implementation for Axolotl.
"""
from __future__ import annotations
import inspect
import logging
import warnings
from contextlib import nullcontext
from typing import Any, Callable, Literal, Optional, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import Dataset
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (
AutoModelForCausalLM,
BaseImageProcessor,
DataCollator,
FeatureExtractionMixin,
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
TrainerCallback,
TrainingArguments,
)
from transformers.trainer_utils import EvalLoopOutput
from trl import KTOTrainer
from trl.trainer.kto_config import KTOConfig
from trl.trainer.utils import KTODataCollatorWithPadding, pad_to_length
from axolotl.core.trainers.base import SchedulerMixin
# Check if PEFT is available
try:
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training, peft_module_casting_to_bf16
is_peft_available = True
except ImportError:
is_peft_available = False
LOG = logging.getLogger("axolotl.core.trainers.kto")
class AxolotlKTOTrainer(SchedulerMixin, Trainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
def __init__(
self,
model: Union[PreTrainedModel, nn.Module, str] = None,
args: KTOConfig = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
data_collator: Optional[DataCollator] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional[dict] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
dataset_tags=None,
model_adapter_name: Optional[str] = None,
ref_adapter_name: Optional[str] = None,
):
self.dataset_tags = dataset_tags
self._tag_names = ["trl", "kto"]
if hasattr(self, "tag_names"):
self._tag_names.extend(self.tag_names)
if type(args) is TrainingArguments:
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
if args.model_init_kwargs is None:
model_init_kwargs = {}
elif not isinstance(model, str):
raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
else:
model_init_kwargs = args.model_init_kwargs
torch_dtype = model_init_kwargs.get("torch_dtype")
if torch_dtype is not None:
# Convert to `torch.dtype` if an str is passed
if isinstance(torch_dtype, str) and torch_dtype != "auto":
torch_dtype = getattr(torch, torch_dtype)
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
)
model_init_kwargs["torch_dtype"] = torch_dtype
if args.ref_model_init_kwargs is None:
ref_model_init_kwargs = {}
elif not isinstance(ref_model, str):
raise ValueError(
"You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
)
else:
ref_model_init_kwargs = args.ref_model_init_kwargs
torch_dtype = ref_model_init_kwargs.get("torch_dtype")
if torch_dtype is not None:
# Convert to `torch.dtype` if an str is passed
if isinstance(torch_dtype, str) and torch_dtype != "auto":
torch_dtype = getattr(torch, torch_dtype)
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
)
ref_model_init_kwargs["torch_dtype"] = torch_dtype
if isinstance(model, str):
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
if isinstance(ref_model, str):
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
# has been called in order to properly call autocast if needed.
self._peft_has_been_casted_to_bf16 = False
if not is_peft_available() and peft_config is not None:
raise ValueError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
)
elif is_peft_available() and peft_config is not None:
# if model is a peft model and we have a peft_config, we merge and unload it first
if isinstance(model, PeftModel):
model = model.merge_and_unload()
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
_support_gc_kwargs = hasattr(
args, "gradient_checkpointing_kwargs"
) and "gradient_checkpointing_kwargs" in list(
inspect.signature(prepare_model_for_kbit_training).parameters
)
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
if _support_gc_kwargs:
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
elif getattr(args, "gradient_checkpointing", False):
# For backward compatibility with older versions of transformers
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
# get peft model with the given config
model = get_peft_model(model, peft_config)
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
peft_module_casting_to_bf16(model)
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
self._peft_has_been_casted_to_bf16 = True
# For models that use gradient_checkpointing, we need to attach a hook that enables input
# to explicitly have `requires_grad=True`, otherwise training will either silently
# fail or completely fail.
elif getattr(args, "gradient_checkpointing", False):
# For backward compatibility with older versions of transformers
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
raise ValueError(
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
" Please install `wandb` or `comet-ml` to resolve."
)
if model is not None:
self.is_encoder_decoder = model.config.is_encoder_decoder
elif args.is_encoder_decoder is None:
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
else:
self.is_encoder_decoder = args.is_encoder_decoder
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
self.model_adapter_name = model_adapter_name
self.ref_adapter_name = ref_adapter_name
if ref_model:
self.ref_model = ref_model
elif self.is_peft_model or args.precompute_ref_log_probs:
# The `model` with adapters turned off will be used as the reference model
self.ref_model = None
else:
self.ref_model = create_reference_model(model)
if processing_class is None:
raise ValueError(
"max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
)
if args.max_length is None:
warnings.warn(
"When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
" it will be set to `512` by default, but you should do it yourself in the future.",
UserWarning,
)
max_length = 512
if args.max_length is not None:
max_length = args.max_length
if args.max_prompt_length is None:
warnings.warn(
"When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
" it will be set to `128` by default, but you should do it yourself in the future.",
UserWarning,
)
max_prompt_length = 128
if args.max_prompt_length is not None:
max_prompt_length = args.max_prompt_length
max_completion_length = None
if args.max_completion_length is None and self.is_encoder_decoder:
warnings.warn(
"When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
" it will be set to `128` by default, but you should do it yourself in the future.",
UserWarning,
)
max_completion_length = 128
if args.max_completion_length is not None and self.is_encoder_decoder:
max_completion_length = args.max_completion_length
if data_collator is None:
data_collator = DPODataCollatorWithPadding(
pad_token_id=processing_class.pad_token_id,
label_pad_token_id=args.label_pad_token_id,
is_encoder_decoder=self.is_encoder_decoder,
)
if args.remove_unused_columns:
args.remove_unused_columns = False
# warn users
warnings.warn(
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
" we have set it for you, but you should do it yourself in the future.",
UserWarning,
)
self.use_dpo_data_collator = True
else:
self.use_dpo_data_collator = False
# Disable dropout in the model and reference model
if args.disable_dropout:
disable_dropout_in_model(model)
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)
self.loss_type = args.loss_type
self.max_length = max_length
self.generate_during_eval = args.generate_during_eval
self.label_pad_token_id = args.label_pad_token_id
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
self.max_prompt_length = max_prompt_length
self.truncation_mode = args.truncation_mode
self.max_completion_length = max_completion_length
self.processing_class = processing_class
self.precompute_ref_log_probs = args.precompute_ref_log_probs
# Not all losses require a KL calculation
self.calculate_KL = True
if self.loss_type in ["apo_zero_unpaired"]:
self.calculate_KL = False
# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
# keep track of first called to avoid computation of future calls
self._precomputed_train_ref_log_probs = False
self._precomputed_eval_ref_log_probs = False
# metric
self._stored_metrics = defaultdict(lambda: defaultdict(list))
# KTO parameter
self.beta = args.beta
self.desirable_weight = args.desirable_weight
self.undesirable_weight = args.undesirable_weight
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
warnings.warn(
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
"loss.",
UserWarning,
)
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
# "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
# issued.
model.warnings_issued["estimate_tokens"] = True
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
# Extract the prompt if needed
train_dataset = train_dataset.map(
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
)
# Unpair the dataset if needed
train_dataset = maybe_unpair_preference_dataset(
train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
)
# Apply the chat template if needed
train_dataset = train_dataset.map(
maybe_apply_chat_template,
fn_kwargs={"tokenizer": processing_class},
num_proc=args.dataset_num_proc,
desc="Applying chat template to train dataset",
)
if eval_dataset is not None:
eval_dataset = eval_dataset.map(
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
)
eval_dataset = maybe_unpair_preference_dataset(
eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
)
eval_dataset = eval_dataset.map(
maybe_apply_chat_template,
fn_kwargs={"tokenizer": processing_class},
num_proc=args.dataset_num_proc,
desc="Applying chat template to eval dataset",
)
# Tokenize and prepare the training datasets
train_dataset = train_dataset.map(
_tokenize,
batched=True,
fn_kwargs={"tokenizer": self.processing_class},
num_proc=args.dataset_num_proc,
desc="Tokenizing train dataset",
)
fn_kwargs = {
"prefix": "",
"is_encoder_decoder": self.is_encoder_decoder,
"tokenizer": self.processing_class,
"max_length": self.max_length,
"truncation_mode": self.truncation_mode,
"label_pad_token_id": self.label_pad_token_id,
"max_prompt_length": self.max_prompt_length,
"max_completion_length": self.max_completion_length,
}
train_dataset = train_dataset.map(
_process_tokens,
fn_kwargs=fn_kwargs,
num_proc=args.dataset_num_proc,
desc="Processing tokenized train dataset",
)
# Tokenize and prepare the eval datasets
if eval_dataset is not None:
eval_dataset = eval_dataset.map(
_tokenize,
fn_kwargs={"tokenizer": self.processing_class},
batched=True,
num_proc=args.dataset_num_proc,
desc="Tokenizing eval dataset",
)
eval_dataset = eval_dataset.map(
_process_tokens,
fn_kwargs=fn_kwargs,
num_proc=args.dataset_num_proc,
desc="Processing tokenized eval dataset",
)
# Get KL datasets if needed
if self.calculate_KL:
if args.per_device_train_batch_size <= 1:
raise ValueError(
"Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
)
# create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
# i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
train_kl_dataset = train_dataset.map(
_get_kl_dataset,
batched=True,
batch_size=args.per_device_train_batch_size,
num_proc=args.dataset_num_proc,
desc="Extracting KL train dataset",
)
fn_kwargs["prefix"] = "KL_"
train_kl_dataset = train_kl_dataset.map(
_process_tokens,
fn_kwargs=fn_kwargs,
num_proc=args.dataset_num_proc,
remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
desc="Processing tokenized train KL dataset",
)
# merge the datasets
train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
if eval_dataset is not None:
# Get KL dataset
eval_kl_dataset = eval_dataset.map(
_get_kl_dataset,
batched=True,
batch_size=args.per_device_train_batch_size,
num_proc=args.dataset_num_proc,
desc="Extracting eval KL dataset",
)
eval_kl_dataset = eval_kl_dataset.map(
_process_tokens,
fn_kwargs=fn_kwargs,
num_proc=args.dataset_num_proc,
remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
desc="Processing tokenized eval KL dataset",
)
# merge the datasets
eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
# calculate dataset desirability balance
num_desirable = max(sum(train_dataset["label"]), 1)
num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary
if num_desirable != num_undesirable:
# The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
if not (des_weight_in_range or und_weight_in_range):
warnings.warn(
"You have different amounts of desirable/positive and undesirable/negative examples but the "
"weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
f"on your data, we recommend EITHER "
f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
"See the documentation on how to optimally set these weights.",
UserWarning,
)
super().__init__(
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
model_init=model_init,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
# self.model_accepts_loss_kwargs to False to enable scaling.
self.model_accepts_loss_kwargs = False
# Add tags for models that have been loaded with the correct transformers version
if hasattr(self.model, "add_model_tags"):
self.model.add_model_tags(self._tag_names)
if not hasattr(self, "accelerator"):
raise AttributeError(
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
)
# Deepspeed Zero-3 does not support precompute_ref_log_probs
if self.is_deepspeed_enabled:
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
raise ValueError(
"You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
)
if self.ref_model is None:
if not (self.is_peft_model or self.precompute_ref_log_probs):
raise ValueError(
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
)
else:
if self.is_deepspeed_enabled:
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

View File

@@ -10,6 +10,7 @@ import torch
from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
from axolotl.telemetry.errors import send_errors
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
@@ -61,6 +62,7 @@ def evaluate_dataset(
return metrics
@send_errors
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
"""
Evaluate a model on training and validation datasets

View File

@@ -23,8 +23,6 @@ import importlib
import logging
from typing import OrderedDict
import torch
class BasePlugin:
"""
@@ -471,14 +469,3 @@ class PluginManager:
"""
for plugin in self.plugins.values():
plugin.post_train_unload(cfg)
class BaseOptimizerFactory:
"""
Base class for factories to create custom optimizers
"""
def __call__(
self, opt_model, training_args, **optimizer_kwargs
) -> "torch.optim.Optimizer":
pass

View File

@@ -4,22 +4,6 @@ Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy o
See https://github.com/apple/ml-cross-entropy
## Requirements
- PyTorch 2.4.0 or higher
## Installation
Run the following command to install `cut_cross_entropy[transformers]` if you don't have it already.
```bash
# if you are in dev environment
python scripts/cutcrossentropy_install.py | sh
# if you are not in dev environment
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
```
## Usage
```yaml

View File

@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
)

View File

@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
"""
from typing import Optional
from pydantic import BaseModel, model_validator
from pydantic import BaseModel
class SpectrumArgs(BaseModel):
@@ -27,20 +27,3 @@ class SpectrumArgs(BaseModel):
spectrum_top_fraction: Optional[float] = 0.5
spectrum_model_name: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_fsdp_use_orig_params(cls, data):
if (
data.get("fsdp")
and data.get("fsdp_config")
and not data["fsdp_config"].get("use_orig_params")
and data.get("plugins")
and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
):
# would otherwise raise
# ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
raise ValueError(
"FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
)
return data

View File

View File

@@ -0,0 +1,164 @@
"""Trainer callbacks for reporting runtime metrics at regular intervals."""
import logging
import time
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
from axolotl.telemetry.manager import TelemetryManager
from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker
LOG = logging.getLogger(__name__)
TIME_SINCE_LAST = 30
class TelemetryCallback(TrainerCallback):
"""
Trainer callback for tracking and reporting runtime metrics.
This callback tracks training progress, runtime, and memory usage,
sending telemetry at configurable intervals.
"""
report_interval_steps: int = 100
def __init__(self):
"""Initialize the metrics callback."""
self.tracker = RuntimeMetricsTracker()
self.telemetry_manager = TelemetryManager.get_instance()
self.current_epoch = -1
self.start_time = time.time()
self.last_report_time = None
self.last_report_step = 0
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
"""Handle training start."""
self.telemetry_manager.send_event(event_type="train-started")
def on_train_end(
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
"""Handle training end."""
# Send training completion event
self.telemetry_manager.send_event(
event_type="train-ended",
properties={
"loss": state.log_history[-1].get("loss", 0)
if state.log_history
else None,
"learning_rate": state.log_history[-1].get("learning_rate", 0)
if state.log_history
else None,
}
| self.tracker.metrics.to_dict(),
)
def on_epoch_begin(
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
"""Handle epoch start."""
self.current_epoch += 1
self.tracker.start_epoch(self.current_epoch)
def on_epoch_end(
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
"""Handle epoch end."""
self.tracker.end_epoch(self.current_epoch)
def on_step_end(
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
"""Handle step end."""
step = state.global_step
self.tracker.update_step(step)
# Check if we should report metrics
should_report = (
step % self.report_interval_steps == 0
or step == 1 # Always report first step
or step - self.last_report_step >= self.report_interval_steps
)
if should_report:
current_time = time.time()
if self.last_report_time is not None:
time_since_last_report = current_time - self.last_report_time
else:
time_since_last_report = current_time - self.start_time
steps_since_last_report = step - self.last_report_step
# Only report if enough time has passed to avoid flooding
if (
step == 1
or time_since_last_report >= TIME_SINCE_LAST
or steps_since_last_report >= self.report_interval_steps
):
# Calculate steps per second for this interval
if time_since_last_report > 0 and steps_since_last_report > 0:
steps_per_second = steps_since_last_report / time_since_last_report
else:
steps_per_second = 0
# Update memory metrics
self.tracker.update_memory_metrics()
loss = state.log_history[-1].get("loss", 0) if state.log_history else 0
learning_rate = (
state.log_history[-1].get("learning_rate", 0)
if state.log_history
else 0
)
# Prepare metrics to report
metrics = {
"step": step,
"epoch": self.current_epoch,
"progress": state.epoch, # Fractional epoch progress
"loss": loss,
"learning_rate": learning_rate,
"steps_per_second": steps_per_second,
"elapsed_time": current_time - self.start_time,
"time_since_last_report": time_since_last_report,
}
# Add memory metrics
memory_metrics = self.tracker.get_memory_metrics()
metrics.update({"memory": memory_metrics})
# Send telemetry
self.telemetry_manager.send_event(
event_type="train-progress", properties=metrics
)
# Update last report time and step
self.last_report_time = current_time
self.last_report_step = step

View File

@@ -0,0 +1,160 @@
"""Telemetry utilities for exception and traceback information."""
import logging
import os
import re
import traceback
from functools import wraps
from inspect import getmodule
from typing import Any, Callable
from axolotl.telemetry.manager import TelemetryManager
LOG = logging.getLogger(__name__)
ERROR_HANDLED = False
def sanitize_stack_trace(stack_trace: str) -> str:
"""
Remove personal information from stack trace messages while keeping Python package codepaths.
This function identifies Python packages by looking for common patterns in virtual environment
and site-packages directories, preserving the package path while removing user-specific paths.
Args:
stack_trace: The original stack trace string.
Returns:
A sanitized version of the stack trace with Python package paths preserved.
"""
# Split the stack trace into lines to process each file path separately
lines = stack_trace.split("\n")
sanitized_lines = []
# Regular expression to find file paths in the stack trace
path_pattern = re.compile(r'(?:File ")(.*?)(?:")')
# Regular expression to identify paths in site-packages or dist-packages
# This matches path segments like "site-packages/package_name" or "dist-packages/package_name"
site_packages_pattern = re.compile(
r"(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
)
# Additional common virtual environment patterns
venv_lib_pattern = re.compile(
r"(?:lib|Lib)[/\\](?:python\d+(?:\.\d+)?[/\\])?(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
)
for line in lines:
# Check if this line contains a file path
path_match = path_pattern.search(line)
if path_match:
full_path = path_match.group(1)
sanitized_path = ""
# Try to match site-packages pattern
site_packages_match = site_packages_pattern.search(full_path)
venv_lib_match = venv_lib_pattern.search(full_path)
if site_packages_match:
# Find the index where the matched pattern starts
idx = full_path.find("site-packages")
if idx == -1:
idx = full_path.find("dist-packages")
# Keep from 'site-packages' onward
if idx >= 0:
sanitized_path = full_path[idx:]
elif venv_lib_match:
# For other virtual environment patterns, find the package directory
match_idx = venv_lib_match.start(1)
if match_idx > 0:
# Keep from the package name onward
package_name = venv_lib_match.group(1)
idx = full_path.rfind(
package_name, 0, match_idx + len(package_name)
)
if idx >= 0:
sanitized_path = full_path[idx:]
# If we couldn't identify a package pattern but path contains 'axolotl'
elif "axolotl" in full_path:
idx = full_path.rfind("axolotl")
if idx >= 0:
sanitized_path = full_path[idx:]
# Apply the sanitization to the line
if sanitized_path:
line = line.replace(full_path, sanitized_path)
else:
# If we couldn't identify a package pattern, just keep the filename
filename = os.path.basename(full_path)
if filename:
line = line.replace(full_path, filename)
else:
line = line.replace(full_path, "")
sanitized_lines.append(line)
return "\n".join(sanitized_lines)
def send_errors(func: Callable) -> Callable:
"""
Decorator to send exception info in a function. If an exception is raised, we send
telemetry containing the stack trace and error message.
If an error occurs in a decorated function that is called by another decorated
function, we'll only send telemetry corresponding to the lower-level function.
Args:
func: Function to decorate.
Returns:
Decorated function.
"""
@wraps(func)
def wrapper(*args, **kwargs) -> Any:
telemetry_manager = TelemetryManager.get_instance()
if not telemetry_manager.enabled:
return func(*args, **kwargs)
try:
return func(*args, **kwargs)
except Exception as exception:
# Only track if we're not already handling an error. This prevents us from
# capturing an error more than once in nested decorated function calls.=
global ERROR_HANDLED # pylint: disable=global-statement
if not ERROR_HANDLED:
ERROR_HANDLED = True
# Get function module path
module = getmodule(func)
module_path = (
f"{module.__name__}.{func.__name__}" if module else func.__name__
)
# Get stack trace
stack_trace = "".join(
traceback.format_exception(
type(exception), exception, exception.__traceback__
)
)
stack_trace = sanitize_stack_trace(stack_trace)
# Send error telemetry
telemetry_manager.send_event(
event_type=f"{module_path}-errored",
properties={
"exception": str(exception),
"stack_trace": stack_trace,
},
)
raise
return wrapper

View File

@@ -0,0 +1,399 @@
"""Telemetry manager and associated utilities."""
import atexit
import importlib
import logging
import os
import platform
import time
import uuid
from pathlib import Path
from typing import Any
import posthog
import psutil
import torch
import yaml
LOG = logging.getLogger(__name__)
POSTHOG_HOST = "https://app.posthog.com"
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
ENABLED_WARNING_SLEEP_SECONDS = 15
ENABLED_WARNING = (
"\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n"
"- Which models and configurations are most commonly used\n"
"- What hardware setups need to be supported\n"
"- Where users encounter errors\n\n"
"This data helps us prioritize features, optimize performance, and fix bugs.\n\n"
"To disable telemetry, set either:\n"
"- AXOLOTL_DO_NOT_TRACK=1 (Axolotl-specific)\n"
"- DO_NOT_TRACK=1 (Global standard; see https://consoledonottrack.com/)\n\n"
"To remove this warning and continue with telemetry enabled,"
"explicitly set AXOLOTL_DO_NOT_TRACK=0 (and leave DO_NOT_TRACK unset / set to 0)\n\n"
"No personally identifiable information is collected."
"For details, see: https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html\n\n"
f"Sleeping for {ENABLED_WARNING_SLEEP_SECONDS}s..."
)
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
# NOTE: Keep these up to date with any config schema changes
FIELDS_WITH_ORGS = {
"base_model",
"tokenizer_config",
"base_model_config",
"pretraining_dataset", # NOTE: this field may be a string or a dictionary
}
FIELDS_TO_REDACT = {"resume_from_checkpoint", "hub_model_id"}
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"}
PATH_INDICATORS = {"path", "dir"}
# pylint: disable=duplicate-code
RELEVANT_PACKAGES = {
"torch",
"transformers",
"trl",
"datasets",
"peft",
"bitsandbytes",
"accelerate",
"optimum",
"deepspeed",
"ray",
"axolotl",
"triton",
"mamba-ssm",
"flash-attn",
"xformers",
"autoawq",
"tokenizers",
"sentencepiece",
"torchao",
"lm_eval",
}
def is_main_process() -> bool:
"""
Check whether we're running in the main process.
Note:
We're using this function instead of `torch.utils.distributed.is_main_process`
causes issues with DeepSpeed world_size since. This function avoids that issue
by checking env vars that are set by various launchers.
Returns:
Whether we're running in the main process.
"""
# If PyTorch distributed is already initialized, use it
if torch.distributed.is_initialized():
return torch.distributed.get_rank() == 0
# Otherwise check environment variables for global rank
# NOTE: need to verify this in SLURM / OpenMPI environments
global_rank = int(
os.environ.get(
"RANK",
os.environ.get(
"GLOBAL_RANK",
os.environ.get(
"SLURM_PROCID",
os.environ.get(
"OMPI_COMM_WORLD_RANK",
"0",
),
),
),
)
)
return global_rank == 0
class TelemetryManager:
"""Manages telemetry collection and transmission"""
_instance = None
_initialized = False
def __new__(cls):
"""
Telemetry manager constructor. Creates the singleton instance of this class if
it doesn't already exist.
"""
if cls._instance is None:
cls._instance = super(TelemetryManager, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
"""Telemetry manager initializer"""
if self._initialized:
return
self.enabled, self.explicit_enable = self._check_telemetry_enabled()
if self.enabled:
self.run_id = str(uuid.uuid4())
self.whitelist = self._load_whitelist()
try:
self.system_info = self._get_system_info()
except Exception as e: # pylint: disable=broad-exception-caught
LOG.warning(f"Error during system info collection: {e}")
self.system_info = None
self._init_posthog()
# Register shutdown method to flush posthog telemetry
atexit.register(self.shutdown)
self._initialized = True
@classmethod
def get_instance(cls) -> "TelemetryManager":
if cls._instance is None:
cls._instance = TelemetryManager()
return cls._instance
def _check_telemetry_enabled(self) -> tuple[bool, bool]:
"""
Check if telemetry is enabled based on environment variables. We also check
whether this is the main process (for the distributed setting and to avoid
sending duplicate PostHog events per GPU).
Note: This is enabled by default on an opt-out basis. Set either
`AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1` to disable telemetry. For more
details, see https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.
Returns:
Tuple containing:
- Boolean denoting whether telemetry is enabled or disabled.
- Boolean denoting whether telemetry is explicitly enabled or not.
"""
# Parse relevant env vars and fill opt-out default values
axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
do_not_track = os.getenv("DO_NOT_TRACK")
# If explicitly enabled, we'll disable the telemetry warning message
explicit_enabled = axolotl_do_not_track in ["0", "false"]
if axolotl_do_not_track is None:
axolotl_do_not_track = "0"
if do_not_track is None:
do_not_track = "0"
# Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
enabled = axolotl_do_not_track.lower() not in (
"1",
"true",
) and do_not_track.lower() not in ("1", "true")
# Show warning (and sleep on all ranks) unless explicitly enabled
if enabled and not explicit_enabled:
if is_main_process():
LOG.warning(ENABLED_WARNING)
time.sleep(ENABLED_WARNING_SLEEP_SECONDS)
# Only rank 0 will send telemetry
if not is_main_process():
return False, False
return enabled, explicit_enabled
def _load_whitelist(self) -> dict:
"""Load HuggingFace Hub organization whitelist"""
with open(WHITELIST_PATH, encoding="utf-8") as f:
return yaml.safe_load(f)
def _is_whitelisted(self, base_model: str) -> bool:
"""Check if model org is in whitelist"""
if not base_model:
return False
base_model = base_model.lower()
return any(
org.lower() in base_model for org in self.whitelist.get("organizations", [])
)
def _init_posthog(self):
"""Initialize PostHog client"""
posthog.host = POSTHOG_HOST
posthog.project_api_key = POSTHOG_WRITE_KEY
def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]:
"""
Redact properties to remove any paths, so as to avoid inadvertently collecting
private or personally identifiable information (PII). We also remove
information related to Wandb, MLflow, etc. configuration.
Args:
properties: Dictionary of properties to redact.
Returns:
Properties dictionary with redaction applied.
"""
if not properties:
return {}
def redact_value(value: Any, key: str = "") -> Any:
"""Recursively sanitize values, redacting those with path-like keys"""
if isinstance(key, str) and isinstance(value, str):
# Fields that should be redacted if org is not whitelisted
if key in FIELDS_WITH_ORGS:
org = value.split("/")[0]
if org not in self.whitelist["organizations"]:
return "[REDACTED]"
# Other redaction special cases
if (
key in FIELDS_TO_REDACT
or any(prefix in key for prefix in PREFIXES_TO_REDACT)
or any(indicator in key.lower() for indicator in PATH_INDICATORS)
):
return "[REDACTED]"
# Handle nested structures
if isinstance(value, dict):
return {k: redact_value(v, k) for k, v in value.items()}
if isinstance(value, list):
return [redact_value(item) for item in value]
return value
# Create new dict with redacted values
redacted = {k: redact_value(v, k) for k, v in properties.items()}
return redacted
def _get_system_info(self) -> dict[str, Any]:
"""Collect system information for various hardware accelerators"""
gpu_info = []
accelerator_type = "none"
# NVIDIA GPUs
if torch.cuda.is_available():
accelerator_type = "cuda"
for i in range(torch.cuda.device_count()):
gpu_info.append(
{
"name": torch.cuda.get_device_name(i),
"memory": torch.cuda.get_device_properties(i).total_memory,
}
)
# AMD GPUs
elif hasattr(torch, "hip") and torch.hip.is_available():
accelerator_type = "hip"
for i in range(torch.hip.device_count()):
gpu_info.append(
{
"name": torch.hip.get_device_name(i),
"memory": torch.hip.get_device_properties(i).total_memory
if hasattr(torch.hip, "get_device_properties")
else None,
}
)
# Apple Silicon
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
accelerator_type = "mps"
gpu_info.append(
{
"name": "Apple Silicon",
# NOTE: this is memory allocated to this process, not total memory
"memory": torch.mps.driver_allocated_memory(),
}
)
# Intel GPUs
elif hasattr(torch, "xpu") and torch.xpu.is_available():
accelerator_type = "xpu"
for i in range(torch.xpu.device_count()):
memory = None
if hasattr(torch.xpu, "get_device_properties"):
memory = torch.xpu.get_device_properties(i).total_memory
gpu_info.append(
{
"name": torch.xpu.get_device_name(i),
"memory": memory,
}
)
# NPUs
elif hasattr(torch, "npu") and torch.npu.is_available():
accelerator_type = "npu"
for i in range(torch.npu.device_count()):
memory = None
if hasattr(torch.npu, "get_device_properties"):
memory = torch.npu.get_device_properties(i).total_memory
gpu_info.append(
{
"name": torch.npu.get_device_name(i),
"memory": memory,
}
)
# Get relevant package versions
installed_packages = {}
for package in RELEVANT_PACKAGES:
try:
version = importlib.metadata.version(package)
installed_packages[f"{package}_version"] = version
except importlib.metadata.PackageNotFoundError:
pass
return {
"os": platform.system(),
"python_version": platform.python_version(),
"cpu_count": psutil.cpu_count(),
"memory_total": psutil.virtual_memory().total,
"accelerator_type": accelerator_type,
"accelerator_count": len(gpu_info),
"accelerator_info": gpu_info,
**installed_packages,
}
def send_event(self, event_type: str, properties: dict[str, Any] | None = None):
"""Send a telemetry event"""
if not self.enabled:
return
if properties is None:
properties = {}
# Sanitize properties to remove PII
properties = self._redact_paths(properties)
# Wrap PostHog errors in try / except to not raise errors during Axolotl usage
try:
# Send event via PostHog
posthog.capture(
distinct_id=self.run_id,
event=event_type,
properties=properties,
disable_geoip=True,
)
except Exception as e: # pylint: disable=broad-exception-caught
LOG.warning(f"Failed to send telemetry event: {e}")
# Additionally, send system info telemetry when loading config.
# NOTE: Is this the best place for this?
if event_type == "config-loaded":
self.send_system_info()
def send_system_info(self):
"""Helper method for sending system info"""
self.send_event(event_type="system-info", properties=self.system_info)
def shutdown(self):
"""Ensure all queued events are processed before shutdown"""
if self.enabled:
posthog.flush()

View File

@@ -0,0 +1,209 @@
"""Telemetry utilities for runtime and memory metrics."""
import logging
import time
from dataclasses import dataclass, field
from typing import Any
import psutil
import torch
from axolotl.telemetry.manager import TelemetryManager
LOG = logging.getLogger(__name__)
@dataclass
class RuntimeMetrics:
"""Container for runtime metrics to be tracked throughout training."""
# Timing metrics
start_time: float
epoch_start_times: dict[int, float] = field(init=False)
epoch_end_times: dict[int, float] = field(init=False)
# Memory metrics
peak_cpu_memory: int = 0
peak_gpu_memory: dict[int, int] = field(init=False)
# Progress metrics
total_steps: int = 0
current_epoch: int = 0
current_step: int = 0
def __post_init__(self):
"""Initialize empty metric mappings."""
self.epoch_start_times = {}
self.epoch_end_times = {}
self.peak_gpu_memory = {}
@property
def elapsed_time(self) -> float:
"""Calculate total elapsed time in seconds."""
return time.time() - self.start_time
def epoch_time(self, epoch: int) -> float | None:
"""Calculate time taken for a specific epoch in seconds."""
if epoch in self.epoch_start_times and epoch in self.epoch_end_times:
return self.epoch_end_times[epoch] - self.epoch_start_times[epoch]
return None
def average_epoch_time(self) -> float | None:
"""Calculate average time per epoch in seconds."""
completed_epochs = [
epoch for epoch in self.epoch_start_times if epoch in self.epoch_end_times
]
if not completed_epochs:
return None
total_time = 0.0
for epoch in completed_epochs:
epoch_time = self.epoch_time(epoch)
if epoch_time is not None: # Check to avoid mypy warning
total_time += epoch_time
return total_time / len(completed_epochs)
def steps_per_second(self) -> float | None:
"""Calculate average steps per second across all training."""
if self.total_steps == 0 or self.elapsed_time == 0:
return None
return self.total_steps / self.elapsed_time
def to_dict(self) -> dict[str, Any]:
"""Convert metrics to a dictionary for telemetry reporting."""
metrics = {
"total_time_seconds": self.elapsed_time,
"total_steps": self.total_steps,
"steps_per_second": self.steps_per_second(),
"epochs_completed": len(
[
epoch
for epoch in self.epoch_start_times
if epoch in self.epoch_end_times
]
),
"peak_cpu_memory_bytes": self.peak_cpu_memory,
}
# Add per-epoch timing if available
epoch_times: dict[str, float] = {}
for epoch in sorted(self.epoch_end_times.keys()):
time_taken = self.epoch_time(epoch)
if time_taken is not None:
epoch_times[f"epoch_{epoch}_seconds"] = time_taken
if epoch_times:
metrics["epoch_times"] = epoch_times # type: ignore
metrics["average_epoch_time_seconds"] = self.average_epoch_time()
# Add GPU memory metrics if available
if self.peak_gpu_memory:
gpu_metrics: dict[str, int] = {}
for gpu_id, memory in self.peak_gpu_memory.items():
gpu_metrics[f"gpu_{gpu_id}_peak_memory_bytes"] = memory
metrics["gpu_memory"] = gpu_metrics # type: ignore
return metrics
class RuntimeMetricsTracker:
"""Tracker for runtime metrics during training."""
update_interval = 100
def __init__(self):
"""Initialize the runtime metrics tracker."""
self.metrics = RuntimeMetrics(start_time=time.time())
self.telemetry_manager = TelemetryManager.get_instance()
def start_epoch(self, epoch: int):
"""Record the start of a new epoch."""
self.metrics.current_epoch = epoch
self.metrics.epoch_start_times[epoch] = time.time()
self.update_memory_metrics()
def end_epoch(self, epoch: int):
"""Record the end of an epoch."""
self.metrics.epoch_end_times[epoch] = time.time()
def update_step(self, step: int):
"""Update the current step count."""
self.metrics.current_step = step
self.metrics.total_steps += 1
# Periodically update memory metrics
if step % self.update_interval == 0:
self.update_memory_metrics()
def _get_allocated_memory(self) -> dict[int, int]:
"""
Helper function for getting accelerator-agnostic allocated memory.
Returns:
A dictionary mapping device IDs to allocated memory in bytes
"""
memory_used: dict[int, int] = {}
# NVIDIA GPUs
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
memory_used[i] = torch.cuda.memory_allocated(i)
# AMD GPUs
elif hasattr(torch, "hip") and torch.hip.is_available():
for i in range(torch.hip.device_count()):
if hasattr(torch.hip, "memory_allocated"):
memory_used[i] = torch.hip.memory_allocated(i)
# Apple Silicon
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
# MPS doesn't have per-device memory stats since there's only one device
if hasattr(torch.mps, "current_allocated_memory"):
memory_used[0] = torch.mps.current_allocated_memory()
# Intel GPUs
elif hasattr(torch, "xpu") and torch.xpu.is_available():
for i in range(torch.xpu.device_count()):
if hasattr(torch.xpu, "memory_allocated"):
memory_used[i] = torch.xpu.memory_allocated(i)
# NPUs
elif hasattr(torch, "npu") and torch.npu.is_available():
for i in range(torch.npu.device_count()):
if hasattr(torch.npu, "memory_allocated"):
memory_used[i] = torch.npu.memory_allocated(i)
return memory_used
def update_memory_metrics(self):
"""Update peak memory usage metrics."""
# CPU memory
cpu_memory = psutil.Process().memory_info().rss
self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory)
# GPU memory (if available)
memory_used = self._get_allocated_memory()
for i, memory in memory_used.items():
self.metrics.peak_gpu_memory[i] = max(
self.metrics.peak_gpu_memory.get(i, 0), memory
)
def get_memory_metrics(self) -> dict[str, Any]:
"""Get the current memory metrics as a dictionary."""
memory_metrics = {
"cpu_memory_bytes": psutil.Process().memory_info().rss,
"peak_cpu_memory_bytes": self.metrics.peak_cpu_memory,
}
# GPU memory (if available)
memory_used = self._get_allocated_memory()
for i, memory in memory_used.items():
memory_metrics[f"gpu_{i}_memory_bytes"] = memory
memory_metrics[
f"gpu_{i}_peak_memory_bytes"
] = self.metrics.peak_gpu_memory.get(i, 0)
return memory_metrics

View File

@@ -0,0 +1,18 @@
organizations:
- "axolotl-ai-co"
- "meta-llama"
- "huggingface"
- "nvidia"
- "facebook"
- "google"
- "microsoft"
- "deepseek-ai"
- "HuggingFaceTB"
- "mistralai"
- "Qwen"
- "briaai"
- "unsloth"
- "NousResearch"
- "allenai"
- "amd"
- "tiiuae"

View File

@@ -7,24 +7,23 @@ import signal
import sys
import weakref
from pathlib import Path
from typing import Any, Dict
from typing import Tuple, Union
import torch
import transformers.modelcard
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
from peft import PeftConfig, PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.logging_config import configure_logging
from axolotl.telemetry.errors import send_errors
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
@@ -35,25 +34,20 @@ try:
except ImportError:
BetterTransformer = None
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = get_logger(__name__)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
def setup_model_and_tokenizer(
cfg: DictDefault,
) -> tuple[
PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
]:
"""
Load the tokenizer, processor (for multimodal models), and model based on configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Tuple containing model, tokenizer, `peft_config` (if LoRA / QLoRA, else
`None`), and processor (if multimodal, else `None`).
"""
@send_errors
def train(
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
@@ -66,58 +60,11 @@ def setup_model_and_tokenizer(
if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer)
# Load the model and peft_config
msg = "loading model"
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
# Get datasets
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
model, peft_config = load_model(cfg, tokenizer, processor=processor)
if model.generation_config is not None:
model.generation_config.do_sample = True
# Apply freezing if specified
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
return model, tokenizer, peft_config, processor
def setup_reference_model(
cfg: DictDefault, tokenizer: PreTrainedTokenizer
) -> PreTrainedModel | None:
"""
Set up the reference model for RL training if needed.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: The tokenizer to use for the reference model.
Returns:
Reference model if needed for RL training, `None` otherwise.
"""
model_ref = None
if cfg.rl and cfg.rl != "orpo":
if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")
model_ref = None # explicit setting to None
else:
# load the model again for model_ref/baseline
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
return model_ref
def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
"""
Determine the checkpoint to resume from based on configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Path to the checkpoint to resume from, or `None` if not resuming.
"""
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
@@ -131,22 +78,85 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
return cfg.resume_from_checkpoint
resume_from_checkpoint = cfg.resume_from_checkpoint
# Load model
msg = "loading model"
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, processor=processor)
if model.generation_config is not None:
model.generation_config.do_sample = True
def setup_signal_handler(
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
):
"""
Set up signal handler for graceful termination.
TELEMETRY_MANAGER.send_event(
event_type="model-loaded", properties=model.config.to_dict()
)
if peft_config:
TELEMETRY_MANAGER.send_event(
event_type="peft-config-loaded", properties=peft_config.to_dict()
)
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
model: The model to save on termination
safe_serialization: Whether to use safe serialization when saving
"""
# ray workers don't have access to this signal
if cfg.local_rank == 0 and not cfg.use_ray:
model_ref = None
if cfg.rl and cfg.rl != "orpo":
if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")
model_ref = None # explicit setting to None
else:
# load the model again for model_ref / baseline
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
safe_serialization = cfg.save_safetensors is True
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
trainer = setup_trainer(
cfg,
train_dataset,
eval_dataset,
(model, model_ref, peft_config),
tokenizer,
processor,
total_num_steps,
)
if cfg.fix_untrained_tokens:
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# additionally presave the tokenizer and model configs
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
if hasattr(model, "config"):
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if (
cfg.local_rank == 0 and not cfg.use_ray
): # ray workers don't have access to this signal
def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
@@ -164,18 +174,15 @@ def setup_signal_handler(
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
)
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
def execute_training(
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
):
"""
Execute the training process with appropriate backend configurations.
if getattr(cfg, "axolotl_config_path"):
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
version = importlib.metadata.version("axolotl")
if raw_axolotl_cfg.is_file():
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The configured trainer object.
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
@@ -190,31 +197,13 @@ def execute_training(
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
def save_trained_model(
cfg: DictDefault,
trainer: Any,
model: PreTrainedModel,
safe_serialization: bool,
):
"""
Save the trained model according to configuration and training setup.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The trainer object.
model: The trained model to save.
safe_serialization: Whether to use safe serialization.
"""
LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
# Post training module hooks
# post training
for name, module in model.named_modules():
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
# Handle FSDP state dict type
state_dict_type = "FULL_STATE_DICT"
if trainer.is_fsdp_enabled:
if cfg.fsdp_final_state_dict_type:
@@ -222,18 +211,16 @@ def save_trained_model(
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
# Handle ReLoRA early return case
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
else:
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return
return model, tokenizer
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
# TODO: do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple
# processes attempt to write the same file
if (
state_dict_type == "SHARDED_STATE_DICT"
and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
@@ -265,6 +252,7 @@ def save_trained_model(
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError:
pass
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
@@ -275,241 +263,40 @@ def save_trained_model(
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def create_model_card(cfg: DictDefault, trainer: Trainer):
"""
Create a model card for the trained model if needed.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The trainer object with model card creation capabilities.
"""
if not cfg.hub_model_id:
# Guard since create_model_card may fail if dataset_tags is empty list
try:
model_card_kwarg = {
"model_name": cfg.output_dir.lstrip("./")
.encode("utf-8")
.decode("utf-8")
}
# We check if we're using a TRL trainer; if so, `dataset_tags` is not consumed.
rl = cfg.rl is not None or cfg.reward_model or cfg.process_reward_model
if cfg.datasets is not None and not rl:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
if dataset_tags:
model_card_kwarg["dataset_tags"] = dataset_tags
if cfg.datasets is not None:
if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
dataset_tags = [
d for d in dataset_tags if not d.startswith("https://")
]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_name"] = dataset_tags
else:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
dataset_tags = [
d for d in dataset_tags if not d.startswith("https://")
]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_tags"] = dataset_tags
trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError):
pass
elif cfg.hub_model_id:
# Defensively push to the hub to ensure the model card is updated
# defensively push to the hub to ensure the model card is updated
trainer.push_to_hub()
def save_initial_configs(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
peft_config: PeftConfig | None,
):
"""
Save initial configurations before training.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: The tokenizer to save.
model: The model to save configuration for.
peft_config: The PEFT configuration to save if applicable.
"""
# Create output_dir if it doesn't already exist
output_dir = Path(cfg.output_dir)
if not output_dir.is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
# Pre-save adapter config so it's available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}...")
peft_config.save_pretrained(cfg.output_dir)
# Pre-save the tokenizer and model configs
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
tokenizer.save_pretrained(str(output_dir))
if hasattr(model, "config"):
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
model.config.save_pretrained(str(output_dir))
def setup_model_card(cfg: DictDefault):
"""
Set up the Axolotl badge and add the Axolotl config to the model card if available.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
if getattr(cfg, "axolotl_config_path"):
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
version = importlib.metadata.version("axolotl")
if raw_axolotl_cfg.is_file():
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
def handle_untrained_tokens_fix(
cfg: DictDefault,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
train_dataset: Dataset,
safe_serialization: bool,
):
"""
Apply fixes for untrained tokens if configured.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
model: The model to apply fixes to.
tokenizer: The tokenizer for token identification.
train_dataset: The training dataset to use.
safe_serialization: Whether to use safe serialization when saving.
"""
if not cfg.fix_untrained_tokens:
return
is_ds_zero3: bool = False
if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
is_ds_zero3 = True
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
fix_kwargs: Dict[str, Any] = {}
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
if "is_ds_zero3" in sig.parameters:
fix_kwargs["is_ds_zero3"] = is_ds_zero3
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
def setup_model_and_trainer(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[
HFRLTrainerBuilder | HFCausalTrainerBuilder,
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,
]:
"""
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
trainer setup.
Args:
cfg: The configuration dictionary with training parameters.
dataset_meta: Object with training, validation datasets and metadata.
Returns:
Tuple of:
- Trainer (Causal or RLHF)
- Model
- Tokenizer
- PEFT config
"""
# Load tokenizer, processor and model
model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
# Set up reference model for RL if needed
model_ref = setup_reference_model(cfg, tokenizer)
# Get datasets from metadata
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Set up trainer
trainer = setup_trainer(
cfg=cfg,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model=model,
tokenizer=tokenizer,
processor=processor,
total_num_steps=total_num_steps,
model_ref=model_ref,
peft_config=peft_config,
)
return (
trainer,
model,
tokenizer,
peft_config,
)
def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
"""
Train a model on the given dataset.
Args:
cfg: The configuration dictionary with training parameters
dataset_meta: Object with training, validation datasets and metadata
Returns:
Tuple of (model, tokenizer) after training
"""
# Setup model, tokenizer, (causal or RLHF) trainer etc.
(
trainer,
model,
tokenizer,
peft_config,
) = setup_model_and_trainer(cfg, dataset_meta)
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Configuration for saving
safe_serialization = cfg.save_safetensors is True
# Handle untrained tokens if configured
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization
)
# Save initial configs
save_initial_configs(cfg, tokenizer, model, peft_config)
# Set up signal handler for graceful termination
setup_signal_handler(cfg, model, safe_serialization)
# Set up badges and config info for model card
setup_model_card(cfg)
# Execute the training
execute_training(cfg, trainer, resume_from_checkpoint)
# Save the trained model
save_trained_model(cfg, trainer, model, safe_serialization)
# Create model card
create_model_card(cfg, trainer)
return model, tokenizer, trainer
return model, tokenizer

View File

@@ -64,17 +64,6 @@ class ChatTemplate(str, Enum):
metharme = "metharme" # pylint: disable=invalid-name
class CustomSupportedOptimizers(str, Enum):
"""Custom supported optimizers"""
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
@@ -505,7 +494,17 @@ class HyperparametersConfig(BaseModel):
embedding_lr_scale: Optional[float] = None
weight_decay: Optional[float] = 0.0
optimizer: Optional[
Union[OptimizerNames, CustomSupportedOptimizers]
Union[
OptimizerNames,
Literal[
"lion_pytorch",
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
],
]
] = OptimizerNames.ADAMW_HF
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None,
@@ -519,7 +518,7 @@ class HyperparametersConfig(BaseModel):
)
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[
Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]
Union[SchedulerType, Literal["one_cycle"]]
] = SchedulerType.COSINE
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None
@@ -728,7 +727,7 @@ class AxolotlInputConfig(
default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"},
)
dataset_processes: Optional[int] = Field(default=min(32, os.cpu_count())) # type: ignore[type-var]
dataset_processes: Optional[int] = Field(default=os.cpu_count())
dataset_exact_deduplication: Optional[bool] = None
dataset_keep_in_memory: Optional[bool] = None
dataloader_pin_memory: Optional[bool] = None
@@ -779,9 +778,9 @@ class AxolotlInputConfig(
# torch_dtype: Optional[torch.dtype]
gradient_checkpointing: Optional[
Union[Literal["unsloth", "offload"], bool]
] = Field(default=False)
gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
default=False
)
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
unfrozen_parameters: Optional[List[str]] = None
@@ -856,7 +855,6 @@ class AxolotlInputConfig(
special_tokens: Optional[SpecialTokensConfig] = None
tokens: Optional[List[str]] = None
added_tokens_overrides: Optional[Dict[int, str]] = None
torch_compile: Optional[Union[Literal["auto"], bool]] = None
torch_compile_backend: Optional[str] = None
@@ -1155,15 +1153,6 @@ class AxolotlInputConfig(
raise ValueError("gradient_checkpointing is not supported for MPT models")
return self
@model_validator(mode="after")
def check_offload_grad_checkpointing(self):
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
LOG.warning(
"`unsloth` is deprecated for gradient_checkpointing, use `offload`"
)
self.gradient_checkpointing = "offload"
return self
@model_validator(mode="after")
def check_better_transformers(self):
if self.flash_optimum is True:
@@ -1188,13 +1177,6 @@ class AxolotlInputConfig(
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self
@model_validator(mode="before")
@classmethod
def check_lr_groups(cls, data):
if data.get("lr_groups") and data.get("loraplus_lr_ratio"):
raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.")
return data
@model_validator(mode="before")
@classmethod
def check_saves(cls, data):
@@ -1701,7 +1683,7 @@ class AxolotlInputConfig(
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""
"""Wrapper to validate GPU capabilities with the config options"""
capabilities: GPUCapabilities
env_capabilities: EnvCapabilities

View File

@@ -1,8 +1,7 @@
"""
GRPO specific configuration args
"""
from typing import Optional
from typing import List, Optional
from pydantic import BaseModel, Field
@@ -12,10 +11,7 @@ class TRLConfig(BaseModel):
Input args for TRL.
"""
beta: Optional[float] = Field(
default=None,
json_schema_extra={"description": "Beta for RL training"},
)
beta: Optional[float] = None
max_completion_length: Optional[int] = Field(
default=None,
json_schema_extra={
@@ -24,68 +20,16 @@ class TRLConfig(BaseModel):
)
# GRPO specific args
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
use_vllm: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to use VLLM for RL training"},
)
vllm_device: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
)
vllm_gpu_memory_utilization: Optional[float] = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
)
vllm_dtype: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
vllm_max_model_len: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the model context for VLLM"
},
)
use_vllm: Optional[bool] = False
vllm_device: Optional[str] = "auto"
vllm_gpu_memory_utilization: Optional[float] = 0.9
vllm_max_model_len: Optional[int] = None
vllm_dtype: Optional[str] = "auto"
reward_funcs: Optional[list[str]] = Field(
default=None,
json_schema_extra={"description": "List of reward functions to load"},
)
reward_weights: Optional[list[float]] = Field(
default=None,
json_schema_extra={
"description": "Weights for each reward function. Must match the number of reward functions."
},
)
num_generations: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
},
)
log_completions: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to log completions"},
)
sync_ref_model: Optional[bool] = Field(
default=False,
json_schema_extra={
"description": (
"Whether to sync the reference model every `ref_model_sync_steps` "
"steps, using the `ref_model_mixup_alpha` parameter."
)
},
)
ref_model_mixup_alpha: Optional[float] = Field(
default=0.9,
json_schema_extra={
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
},
)
ref_model_sync_steps: Optional[int] = Field(
default=64,
json_schema_extra={
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
},
)
reward_funcs: Optional[List[str]] = None
num_generations: Optional[int] = None
log_completions: Optional[bool] = False
sync_ref_model: Optional[bool] = False
ref_model_mixup_alpha: Optional[float] = 0.9
ref_model_sync_steps: Optional[int] = 64

View File

@@ -121,7 +121,6 @@ def drop_long_rl_seq(
def load_prepare_preference_datasets(cfg):
import pdb; pdb.set_trace()
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
use_auth_token = _cfg.hf_use_auth_token

View File

@@ -79,7 +79,7 @@ def is_main_process():
def is_local_main_process():
return PartialState().is_local_main_process
return PartialState().is_main_process
def get_world_size():

View File

@@ -4,7 +4,7 @@ from axolotl.utils.gradient_checkpointing.unsloth import (
)
def hf_grad_checkpoint_offload_wrapper(
def hf_grad_checkpoint_unsloth_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(

View File

@@ -54,17 +54,12 @@ from axolotl.monkeypatch.multipack import (
patch_for_multipack,
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.telemetry.errors import send_errors
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
barrier,
get_device_count,
get_device_type,
is_local_main_process,
zero_only,
)
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -171,95 +166,8 @@ def load_model_config(cfg):
return model_config
def modify_tokenizer_files(
tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
) -> str:
"""
Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer.
This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab.
Args:
tokenizer_path: Path or name of the original tokenizer
token_mappings: Dict mapping {token_id (int): new_token_string}
output_dir: Directory to save the modified tokenizer
Returns:
Path to the modified tokenizer directory
Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
"""
import json
# Create the tokenizer directory in output_dir if it doesn't exist
tokenizer_dir = os.path.join(output_dir, "tokenizer")
os.makedirs(tokenizer_dir, exist_ok=True)
if is_local_main_process(): # pylint: disable=too-many-nested-blocks
# Load the tokenizer
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
# Save the tokenizer to the output directory
temp_tokenizer.save_pretrained(tokenizer_dir)
# Get the token IDs and map them to their new values
token_id_mappings = {
int(token_id): new_value for token_id, new_value in token_mappings.items()
}
# 1. Update tokenizer_config.json - added_tokens_decoder
config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
if os.path.exists(config_path):
with open(config_path, "r", encoding="utf-8") as f:
config_data = json.load(f)
# Update added_tokens_decoder
if "added_tokens_decoder" in config_data:
for token_id, new_value in token_id_mappings.items():
token_id_str = str(token_id)
if token_id_str in config_data["added_tokens_decoder"]:
config_data["added_tokens_decoder"][token_id_str][
"content"
] = new_value
else:
raise ValueError(
f"Token ID {token_id_str} not found in added_tokens_decoder"
)
# Write the updated config back
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
# 2. Update tokenizer.json - added_tokens
tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
if os.path.exists(tokenizer_path):
with open(tokenizer_path, "r", encoding="utf-8") as f:
tokenizer_data = json.load(f)
# Update added_tokens
if "added_tokens" in tokenizer_data:
for token_id, new_value in token_id_mappings.items():
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
if token_entry["id"] == token_id:
tokenizer_data["added_tokens"][i]["content"] = new_value
break
else:
# Reaching this section means the token_id was not found in tokenizer.json added_tokens
raise ValueError(
f"Token ID {token_id} not found in added_tokens"
)
# Write the updated tokenizer data back
with open(tokenizer_path, "w", encoding="utf-8") as f:
json.dump(tokenizer_data, f, indent=2)
barrier()
return tokenizer_dir
@send_errors
def load_tokenizer(cfg):
"""Load and configure the tokenizer based on the provided config."""
model_config = load_model_config(cfg)
tokenizer_kwargs = {}
use_fast = True # this is the default
@@ -274,18 +182,8 @@ def load_tokenizer(cfg):
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
# Set base tokenizer path
tokenizer_path = cfg.tokenizer_config
# Apply token string overrides if specified
if cfg.added_tokens_overrides:
# Modify tokenizer files and get path to modified tokenizer
tokenizer_path = modify_tokenizer_files(
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
)
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_path,
cfg.tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
@@ -422,6 +320,7 @@ def load_tokenizer(cfg):
return tokenizer
@send_errors
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
processor_kwargs: Dict[str, Any] = {} # do we actually need this?
@@ -493,8 +392,8 @@ class ModelLoader:
patch_fa_peft_integration()
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if self.cfg.flash_attention:
self.patch_attention()
@@ -1296,18 +1195,17 @@ class ModelLoader:
return self.model, lora_config
@send_errors
def load_model(
cfg: DictDefault,
tokenizer: PreTrainedTokenizerBase,
*,
processor: ProcessorMixin = None, # pylint: disable=unused-argument
processor: ProcessorMixin = None,
inference: bool = False,
reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
"""
Load a model for a given configuration and tokenizer.
"""
**kwargs,
) -> Tuple[PreTrainedModel, PeftConfig | None]:
"""Load a model for a given configuration and tokenizer"""
loader = ModelLoader(
cfg,
tokenizer,
@@ -1319,6 +1217,7 @@ def load_model(
return loader.load_model()
@send_errors
def load_adapter(model, cfg, adapter, inference=False):
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]

View File

@@ -6,80 +6,6 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
class RexLR(LRScheduler):
"""
Reflected Exponential (REX) learning rate scheduler.
- Original implementation: https://github.com/IvanVassi/REX_LR
- Original license: Apache 2.0
- Based on: https://arxiv.org/abs/2107.04197
Args:
optimizer (torch.optim.Optimizer): The optimizer to schedule the learning rate for.
max_lr (float): The maximum learning rate.
min_lr (float): The minimum learning rate.
total_steps (int): The total number of training steps.
num_warmup_steps (int): The number of warmup steps.
last_step (int): The index of last step.
"""
def __init__(
self, optimizer, max_lr, min_lr, total_steps=0, num_warmup_steps=0, last_step=0
):
if min_lr > max_lr:
raise ValueError(
f'Value of "min_lr" should be less than value of "max_lr". Got min_lr={min_lr} and max_lr={max_lr}'
)
if num_warmup_steps > total_steps:
raise ValueError(
f"num_warmup_steps ({num_warmup_steps}) must be less than or equal to total_steps ({total_steps})."
)
self.min_lr = min_lr
self.max_lr = max_lr
self.total_steps = total_steps
self.num_warmup_steps = num_warmup_steps
self.last_step = last_step - 1
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups:
group.setdefault("initial_lr", group["lr"])
# Pass self.last_step as last_epoch to the parent.
super().__init__(optimizer, last_epoch=self.last_step)
@property
def last_step(self):
return self.last_epoch
@last_step.setter
def last_step(self, value):
self.last_epoch = value
def get_lr(self):
# Warmup phase: if defined, increase lr linearly from 0 to max_lr.
if 1 <= self.last_step <= self.num_warmup_steps:
return [
base_lr * self.last_step / self.num_warmup_steps
for base_lr in self.base_lrs
]
# Post-warmup phase: adjust step relative to the end of warmup.
step_after = self.last_step - self.num_warmup_steps
remaining_steps = self.total_steps - self.num_warmup_steps
# Avoid LR spiking
if step_after >= remaining_steps or step_after == -1 or remaining_steps <= 0:
return [self.min_lr for _ in self.base_lrs]
mod_iter = step_after % remaining_steps
z = (remaining_steps - mod_iter) / remaining_steps
rex_factor = self.min_lr / self.max_lr + (1.0 - self.min_lr / self.max_lr) * (
z / (0.1 + 0.9 * z)
)
return [base_lr * rex_factor for base_lr in self.base_lrs]
class InterpolatingLogScheduler(LRScheduler):
"""
A scheduler that interpolates learning rates in a logarithmic fashion

View File

@@ -574,40 +574,14 @@ def prepare_opinionated_env(cfg):
def setup_trainer(
cfg,
train_dataset,
eval_dataset,
model,
tokenizer,
processor,
total_num_steps,
model_ref=None,
peft_config=None,
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
"""
Helper method for instantiating and building a (causal or RLHF) trainer.
Args:
cfg: Axolotl config object containing training parameters.
train_dataset: Dataset to use for training.
eval_dataset: Dataset to use for evaluation.
model: The model to train.
tokenizer: Tokenizer for processing text input.
processor: Processor for data preparation.
total_num_steps: The total number of training steps.
model_ref: Optional reference model for RLHF training. Default is None.
peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.
Returns:
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
if cfg.rl:
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.model_ref = model_ref
trainer_builder.peft_config = peft_config
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]
else:
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset

View File

@@ -14,7 +14,7 @@
h1 {
font-family: var(--font-title);
font-weight: 400;
font-size: 5rem;
font-size: 6rem;
line-height: 1.1;
letter-spacing: -0.05em;
font-feature-settings: "ss01" on;

View File

@@ -28,7 +28,7 @@ class TestTrainCommand(BaseCliTest):
config_path.write_text(valid_test_config)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,
@@ -48,7 +48,7 @@ class TestTrainCommand(BaseCliTest):
config_path = self._test_cli_overrides(tmp_path, valid_test_config)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,

View File

@@ -1,6 +1,5 @@
"""
shared pytest fixtures
"""
"""Shared pytest fixtures"""
import functools
import importlib
import shutil
@@ -171,3 +170,9 @@ def cleanup_monkeypatches():
module_globals = module_name_tuple[1]
for module_global in module_globals:
globals().pop(module_global, None)
@pytest.fixture(autouse=True)
def disable_telemetry(monkeypatch):
monkeypatch.setenv("AXOLOTL_DO_NOT_TRACK", "1")
yield

View File

@@ -69,51 +69,6 @@ class TestCutCrossEntropyIntegration:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
# pylint: disable=redefined-outer-name
def test_qwen2_w_cce(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"plugins": [
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
],
"cut_cross_entropy": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"output_dir": temp_dir,
"lr_scheduler": "cosine",
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
}
)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, dataset_meta=dataset_meta)
else:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.parametrize(
"attention_type",
[

View File

@@ -750,66 +750,3 @@ class TestMultiGPULlama:
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"bos_token": "<|custom_im_start|>",
"eos_token": "<|custom_im_end|>",
},
"datasets": [
{
"chat_template": "jinja",
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
"use_tensorboard": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
)

View File

@@ -66,54 +66,6 @@ class TestLlama:
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"bos_token": "<|custom_im_start|>",
"eos_token": "<|custom_im_end|>",
},
"datasets": [
{
"chat_template": "jinja",
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens_already_trained(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{

View File

@@ -75,7 +75,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -131,7 +131,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -190,7 +190,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -249,7 +249,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32

View File

@@ -65,9 +65,8 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert trainer.optimizer.optimizer.__class__.__name__ == "AdamW"
@with_temp_dir
@require_torch_2_5_1
@@ -112,57 +111,8 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "ADOPT" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
@require_torch_2_5_1
def test_muon(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "muon",
"lr_scheduler": "cosine",
"weight_decay": 0.01,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir):

View File

@@ -1,71 +0,0 @@
"""
E2E tests for custom schedulers using Llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestCustomSchedulers(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_rex_scheduler(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_hf",
"max_steps": 20,
"lr_scheduler": "rex",
"warmup_steps": 5,
"cosine_min_lr_ratio": 0.05,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

View File

@@ -0,0 +1,9 @@
"""Shared pytest fixtures for telemetry tests."""
import pytest
@pytest.fixture(autouse=True)
def disable_telemetry(monkeypatch):
monkeypatch.delenv("AXOLOTL_DO_NOT_TRACK", raising=False)
yield

View File

@@ -0,0 +1,372 @@
"""Tests for telemetry callback module."""
# pylint: disable=redefined-outer-name
import time
from unittest.mock import MagicMock, patch
import pytest
from transformers import TrainerControl, TrainerState, TrainingArguments
from axolotl.telemetry.callbacks import TIME_SINCE_LAST, TelemetryCallback
def calc_expected_metrics(step, last_step, current_time, last_time, start_time=900.0):
"""Calculate expected metrics values for tests"""
time_diff = current_time - last_time
step_diff = step - last_step
return {
"steps_per_second": step_diff / time_diff
if time_diff > 0 and step_diff > 0
else 0,
"time_since_last_report": time_diff,
"elapsed_time": current_time - start_time,
}
@pytest.fixture
def mock_time():
"""Mock time.time() to have predictable values in tests"""
with patch("axolotl.telemetry.callbacks.time") as mock_time:
mock_time.time.return_value = 1000.0
yield mock_time
@pytest.fixture
def mock_telemetry_manager():
"""Create a mock TelemetryManager"""
with patch("axolotl.telemetry.callbacks.TelemetryManager") as mock_manager_class:
mock_manager = MagicMock()
mock_manager_class.get_instance.return_value = mock_manager
yield mock_manager
@pytest.fixture
def mock_runtime_metrics_tracker():
"""Create a mock RuntimeMetricsTracker"""
with patch(
"axolotl.telemetry.callbacks.RuntimeMetricsTracker"
) as mock_tracker_class:
mock_tracker = MagicMock()
# Set up metrics property on the tracker
mock_metrics = MagicMock()
mock_metrics.to_dict.return_value = {
"total_steps": 100,
"peak_cpu_memory_bytes": 1024,
}
mock_tracker.metrics = mock_metrics
# Make the constructor return our mock
mock_tracker_class.return_value = mock_tracker
yield mock_tracker
@pytest.fixture
def training_args():
"""Create a minimal TrainingArguments instance"""
return TrainingArguments(output_dir="./output")
@pytest.fixture
def trainer_state():
"""Create a mock TrainerState"""
state = MagicMock(spec=TrainerState)
state.global_step = 10
state.epoch = 0.5 # halfway through first epoch
state.log_history = [{"loss": 2.5, "learning_rate": 5e-5}]
return state
@pytest.fixture
def trainer_control():
"""Create a mock TrainerControl"""
return MagicMock(spec=TrainerControl)
# pylint: disable=unused-argument
@pytest.fixture
def callback(mock_telemetry_manager, mock_runtime_metrics_tracker):
"""Create a TelemetryCallback instance with mocked dependencies"""
return TelemetryCallback()
class TestTelemetryCallback:
"""Tests for the TelemetryCallback class."""
def test_initialization(self, callback, mock_runtime_metrics_tracker):
"""Test callback initialization."""
assert callback.current_epoch == -1
assert callback.tracker == mock_runtime_metrics_tracker
assert callback.last_report_step == 0
assert hasattr(callback, "start_time")
assert hasattr(callback, "last_report_time")
assert callback.report_interval_steps == 100
def test_on_train_begin(
self,
callback,
mock_telemetry_manager,
training_args,
trainer_state,
trainer_control,
):
"""Test on_train_begin sends expected event."""
callback.on_train_begin(training_args, trainer_state, trainer_control)
mock_telemetry_manager.send_event.assert_called_once_with(
event_type="train-started"
)
def test_on_train_end(
self,
callback,
mock_telemetry_manager,
training_args,
trainer_state,
trainer_control,
):
"""Test on_train_end sends expected event with metrics."""
callback.on_train_end(training_args, trainer_state, trainer_control)
mock_telemetry_manager.send_event.assert_called_once()
call_args = mock_telemetry_manager.send_event.call_args[1]
assert call_args["event_type"] == "train-ended"
assert "loss" in call_args["properties"]
assert call_args["properties"]["loss"] == 2.5
assert "learning_rate" in call_args["properties"]
assert call_args["properties"]["learning_rate"] == 5e-5
# Check that metrics from RuntimeMetricsTracker are included
assert "total_steps" in call_args["properties"]
assert call_args["properties"]["total_steps"] == 100
assert "peak_cpu_memory_bytes" in call_args["properties"]
assert call_args["properties"]["peak_cpu_memory_bytes"] == 1024
def test_on_epoch_begin(
self,
callback,
mock_runtime_metrics_tracker,
training_args,
trainer_state,
trainer_control,
):
"""Test on_epoch_begin updates epoch counter and calls tracker."""
initial_epoch = callback.current_epoch
callback.on_epoch_begin(training_args, trainer_state, trainer_control)
assert callback.current_epoch == initial_epoch + 1
mock_runtime_metrics_tracker.start_epoch.assert_called_once_with(
initial_epoch + 1
)
def test_on_epoch_end(
self,
callback,
mock_runtime_metrics_tracker,
training_args,
trainer_state,
trainer_control,
):
"""Test on_epoch_end calls tracker."""
# Set current epoch
callback.current_epoch = 2
callback.on_epoch_end(training_args, trainer_state, trainer_control)
mock_runtime_metrics_tracker.end_epoch.assert_called_once_with(2)
def test_on_step_end_no_report(
self,
callback,
mock_telemetry_manager,
mock_runtime_metrics_tracker,
training_args,
trainer_state,
trainer_control,
):
"""Test on_step_end updates tracker but doesn't report if criteria not met."""
# Set up state to avoid reporting
trainer_state.global_step = 42 # Not divisible by report_interval_steps
callback.last_report_step = 41 # Just 1 step since last report
callback.last_report_time = time.time() # Just now
callback.on_step_end(training_args, trainer_state, trainer_control)
# Should update tracker
mock_runtime_metrics_tracker.update_step.assert_called_once_with(42)
# Should not send telemetry
mock_telemetry_manager.send_event.assert_not_called()
# Should not update last report time/step
assert callback.last_report_step == 41
def test_on_step_end_report_interval_steps(
self,
callback,
mock_telemetry_manager,
mock_runtime_metrics_tracker,
mock_time,
training_args,
trainer_state,
trainer_control,
):
"""Test on_step_end reports when step interval is reached."""
# Set up state with clear values
current_step = 100 # Exactly matches report_interval_steps
last_step = 0
start_time = 900.0
current_time = 1000.0
time_diff = current_time - start_time # 100 seconds
# Configure state and callback
trainer_state.global_step = current_step
callback.report_interval_steps = 100
callback.last_report_step = last_step
callback.start_time = start_time
callback.last_report_time = start_time
# Mock time.time() to return consistent values
mock_time.time.return_value = current_time
callback.on_step_end(training_args, trainer_state, trainer_control)
# Should update tracker
mock_runtime_metrics_tracker.update_step.assert_called_once_with(current_step)
mock_runtime_metrics_tracker.update_memory_metrics.assert_called_once()
# Should send telemetry
mock_telemetry_manager.send_event.assert_called_once()
call_args = mock_telemetry_manager.send_event.call_args[1]
assert call_args["event_type"] == "train-progress"
# Properties should include expected values
props = call_args["properties"]
assert props["step"] == current_step
assert props["elapsed_time"] == time_diff # 1000 - 900 = 100
assert props["time_since_last_report"] == time_diff # 1000 - 900 = 100
assert props["steps_per_second"] == 1.0 # 100 steps / 100 seconds
# Should update last report time/step
assert callback.last_report_step == current_step
assert callback.last_report_time == current_time
def test_on_step_end_report_time_elapsed(
self,
callback,
mock_telemetry_manager,
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
mock_time,
training_args,
trainer_state,
trainer_control,
):
"""Test on_step_end reports when enough time has elapsed."""
# Set up state with clear values
current_step = 120
last_step = 10
start_time = 900.0
current_time = 1000.0
time_diff = TIME_SINCE_LAST + 1 # Just over the threshold
# Configure state and callback
trainer_state.global_step = current_step
callback.report_interval_steps = 100
callback.last_report_step = last_step
callback.start_time = start_time
callback.last_report_time = current_time - time_diff
# Mock time.time() to return consistent values
mock_time.time.return_value = current_time
callback.on_step_end(training_args, trainer_state, trainer_control)
# Should send telemetry
mock_telemetry_manager.send_event.assert_called_once()
# Properties should include expected values
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
expected_metrics = calc_expected_metrics(
current_step, last_step, current_time, current_time - time_diff, start_time
)
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
assert (
props["time_since_last_report"]
== expected_metrics["time_since_last_report"]
)
def test_on_step_end_first_step(
self,
callback,
mock_telemetry_manager,
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
mock_time,
training_args,
trainer_state,
trainer_control,
):
"""Test on_step_end always reports on first step."""
# Set up state with clear values
current_step = 1 # First step
last_step = 0
start_time = 900.0
current_time = 1000.0
last_report_time = 999.0 # Just 1 second ago
# Configure state and callback
trainer_state.global_step = current_step
callback.report_interval_steps = 100
callback.last_report_step = last_step
callback.start_time = start_time
callback.last_report_time = last_report_time
# Mock time.time() to return consistent values
mock_time.time.return_value = current_time
callback.on_step_end(training_args, trainer_state, trainer_control)
# Should send telemetry even though not much time has passed
mock_telemetry_manager.send_event.assert_called_once()
# Properties should include expected values for first step
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
assert props["step"] == current_step
expected_metrics = calc_expected_metrics(
current_step, last_step, current_time, last_report_time, start_time
)
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
def test_log_history_empty(
self,
callback,
mock_telemetry_manager,
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
mock_time,
training_args,
trainer_state,
trainer_control,
):
"""Test handling of empty log history."""
# Set up state with clear values
current_step = 1
start_time = 900.0
current_time = 1000.0
# Configure state and callback
trainer_state.global_step = current_step
trainer_state.log_history = []
callback.start_time = start_time
# Mock time.time() to return consistent values
mock_time.time.return_value = current_time
callback.on_step_end(training_args, trainer_state, trainer_control)
# Should still send telemetry
mock_telemetry_manager.send_event.assert_called_once()
# Properties should have default values for missing log data
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
assert props["loss"] == 0
assert props["learning_rate"] == 0

View File

@@ -0,0 +1,340 @@
"""Tests for telemetry error utilities"""
# pylint: disable=redefined-outer-name
from unittest.mock import MagicMock, patch
import pytest
from axolotl.telemetry.errors import sanitize_stack_trace, send_errors
@pytest.fixture(autouse=True)
def reset_error_flag(monkeypatch):
"""Reset ERROR_HANDLED flag using monkeypatch"""
import axolotl.telemetry.errors
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
yield
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
@pytest.fixture
def example_stack_trace():
"""Provide a sample stack trace with mixed paths"""
return """Traceback (most recent call last):
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
trainer = get_trainer(cfg)
File "/home/user/.local/lib/python3.9/site-packages/axolotl/train.py", line 214, in get_trainer
model = get_model(cfg, tokenizer)
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/models.py", line 120, in get_model
raise ValueError("Model path not found")
ValueError: Model path not found
"""
@pytest.fixture
def windows_stack_trace():
"""Provide a sample stack trace with Windows paths"""
return """Traceback (most recent call last):
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\cli\\train.py", line 83, in main
trainer = get_trainer(cfg)
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\train.py", line 214, in get_trainer
model = get_model(cfg, tokenizer)
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\auto\\modeling_auto.py", line 482, in from_pretrained
raise ValueError(f"Unrecognized configuration class {config.__class__}")
ValueError: Unrecognized configuration class <class 'transformers.models.llama.configuration_llama.LlamaConfig'>
"""
@pytest.fixture
def mixed_stack_trace():
"""Provide a sample stack trace with both axolotl and non-axolotl paths"""
return """Traceback (most recent call last):
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
trainer = get_trainer(cfg)
File "/home/user/.local/lib/python3.9/site-packages/transformers/trainer.py", line 520, in train
self._inner_training_loop()
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/trainer.py", line 75, in _inner_training_loop
super()._inner_training_loop()
File "/home/user/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
data = self._next_data()
RuntimeError: CUDA out of memory
"""
@pytest.fixture
def venv_stack_trace():
"""Provide a sample stack trace with virtual environment paths"""
return """Traceback (most recent call last):
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 1729, in train
self._inner_training_loop()
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 2013, in _inner_training_loop
self.accelerator.backward(loss)
File "/home/user/venv/lib/python3.9/site-packages/accelerate/accelerator.py", line 1851, in backward
self.scaler.scale(loss).backward(**kwargs)
File "/home/user/venv/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
RuntimeError: CUDA out of memory
"""
@pytest.fixture
def dist_packages_stack_trace():
"""Provide a sample stack trace with dist-packages paths"""
return """Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__
data = self._next_data()
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 675, in _next_data
data = self._dataset_fetcher.fetch(index)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2808, in __getitem__
raise IndexError(f"Index {key} out of range for dataset of length {len(self)}.")
IndexError: Index 10000 out of range for dataset of length 9832.
"""
@pytest.fixture
def project_stack_trace():
"""Provide a sample stack trace from a project directory (not a virtual env)"""
return """Traceback (most recent call last):
File "/home/user/projects/myproject/run.py", line 25, in <module>
main()
File "/home/user/projects/myproject/src/cli.py", line 45, in main
app.run()
File "/home/user/projects/myproject/src/app.py", line 102, in run
raise ValueError("Configuration missing")
ValueError: Configuration missing
"""
def test_sanitize_stack_trace(example_stack_trace):
"""Test that sanitize_stack_trace properly preserves axolotl paths"""
sanitized = sanitize_stack_trace(example_stack_trace)
# Check that personal paths are removed
assert "/home/user" not in sanitized
assert ".local/lib/python3.9" not in sanitized
# Check that site-packages is preserved
assert "site-packages/axolotl/cli/train.py" in sanitized
assert "site-packages/axolotl/train.py" in sanitized
assert "site-packages/axolotl/utils/models.py" in sanitized
# Check that error message is preserved
assert "ValueError: Model path not found" in sanitized
def test_sanitize_windows_paths(windows_stack_trace):
"""Test that sanitize_stack_trace handles Windows paths"""
sanitized = sanitize_stack_trace(windows_stack_trace)
# Check that personal paths are removed
assert "C:\\Users\\name" not in sanitized
assert "AppData\\Local\\Programs\\Python" not in sanitized
# Check that both axolotl and transformers packages are preserved
assert (
"site-packages\\axolotl\\cli\\train.py" in sanitized
or "site-packages/axolotl/cli/train.py" in sanitized
)
assert (
"site-packages\\axolotl\\train.py" in sanitized
or "site-packages/axolotl/train.py" in sanitized
)
assert (
"site-packages\\transformers\\models\\auto\\modeling_auto.py" in sanitized
or "site-packages/transformers/models/auto/modeling_auto.py" in sanitized
)
# Check that error message is preserved
assert "ValueError: Unrecognized configuration class" in sanitized
def test_sanitize_mixed_paths(mixed_stack_trace):
"""Test that sanitize_stack_trace preserves all package paths"""
sanitized = sanitize_stack_trace(mixed_stack_trace)
# Check that all package paths are preserved
assert "site-packages/axolotl/cli/train.py" in sanitized
assert "site-packages/transformers/trainer.py" in sanitized
assert "site-packages/axolotl/utils/trainer.py" in sanitized
assert "site-packages/torch/utils/data/dataloader.py" in sanitized
# Check that error message is preserved
assert "RuntimeError: CUDA out of memory" in sanitized
def test_sanitize_venv_paths(venv_stack_trace):
"""Test that sanitize_stack_trace preserves virtual environment package paths"""
sanitized = sanitize_stack_trace(venv_stack_trace)
# Check that personal paths are removed
assert "/home/user/venv" not in sanitized
# Check that all package paths are preserved
assert "site-packages/transformers/trainer.py" in sanitized
assert "site-packages/accelerate/accelerator.py" in sanitized
assert "site-packages/torch/_tensor.py" in sanitized
# Check that error message is preserved
assert "RuntimeError: CUDA out of memory" in sanitized
def test_sanitize_dist_packages(dist_packages_stack_trace):
"""Test that sanitize_stack_trace preserves dist-packages paths"""
sanitized = sanitize_stack_trace(dist_packages_stack_trace)
# Check that system paths are removed
assert "/usr/local/lib/python3.8" not in sanitized
# Check that all package paths are preserved
assert "dist-packages/torch/utils/data/dataloader.py" in sanitized
assert "dist-packages/torch/utils/data/_utils/fetch.py" in sanitized
assert "dist-packages/datasets/arrow_dataset.py" in sanitized
# Check that error message is preserved
assert (
"IndexError: Index 10000 out of range for dataset of length 9832." in sanitized
)
def test_sanitize_project_paths(project_stack_trace):
"""Test handling of project paths (non-virtual env)"""
sanitized = sanitize_stack_trace(project_stack_trace)
# Check that personal paths are removed
assert "/home/user/projects" not in sanitized
# For non-package paths, we should at least preserve the filename
assert "run.py" in sanitized
assert "cli.py" in sanitized
assert "app.py" in sanitized
# Check that error message is preserved
assert "ValueError: Configuration missing" in sanitized
@pytest.fixture
def mock_telemetry_manager():
"""Create a mock TelemetryManager"""
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
mock_manager = MagicMock()
mock_manager.enabled = True
mock_manager_class.get_instance.return_value = mock_manager
yield mock_manager
def test_send_errors_successful_execution(mock_telemetry_manager):
"""Test that send_errors doesn't send telemetry for successful function execution"""
@send_errors
def test_func():
return "success"
result = test_func()
assert result == "success"
mock_telemetry_manager.send_event.assert_not_called()
def test_send_errors_with_exception(mock_telemetry_manager):
"""Test that send_errors sends telemetry when an exception occurs"""
test_error = ValueError("Test error")
@send_errors
def test_func():
raise test_error
with pytest.raises(ValueError) as excinfo:
test_func()
assert excinfo.value == test_error
mock_telemetry_manager.send_event.assert_called_once()
# Check that the error info was passed correctly
call_args = mock_telemetry_manager.send_event.call_args[1]
assert "test_func-errored" in call_args["event_type"]
assert "Test error" in call_args["properties"]["exception"]
assert "stack_trace" in call_args["properties"]
def test_send_errors_nested_calls(mock_telemetry_manager):
"""Test that send_errors only sends telemetry once for nested decorated functions"""
@send_errors
def inner_func():
raise ValueError("Inner error")
@send_errors
def outer_func():
return inner_func()
with pytest.raises(ValueError):
outer_func()
# Telemetry should be sent only once for the inner function
assert mock_telemetry_manager.send_event.call_count == 1
call_args = mock_telemetry_manager.send_event.call_args[1]
assert "inner_func-error" in call_args["event_type"]
def test_send_errors_telemetry_disable():
"""Test that send_errors doesn't attempt to send telemetry when disabled"""
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
mock_manager = MagicMock()
mock_manager.enabled = False
mock_manager_class.get_instance.return_value = mock_manager
@send_errors
def test_func():
raise ValueError("Test error")
with pytest.raises(ValueError):
test_func()
mock_manager.send_event.assert_not_called()
def test_error_handled_reset():
"""Test that ERROR_HANDLED flag is properly reset"""
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
# Create and configure the mock manager
mock_manager = MagicMock()
mock_manager.enabled = True
mock_manager_class.get_instance.return_value = mock_manager
from axolotl.telemetry.errors import ERROR_HANDLED
@send_errors
def test_func():
raise ValueError("Test error")
assert not ERROR_HANDLED
with pytest.raises(ValueError):
test_func()
from axolotl.telemetry.errors import ERROR_HANDLED
assert ERROR_HANDLED
def test_module_path_resolution(mock_telemetry_manager):
"""Test that the module path is correctly resolved for the event type"""
import inspect
current_module = inspect.getmodule(test_module_path_resolution).__name__
@send_errors
def test_func():
raise ValueError("Test error")
with pytest.raises(ValueError):
test_func()
assert mock_telemetry_manager.send_event.called
event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"]
expected_event_type = f"{current_module}.test_func-errored"
assert expected_event_type == event_type

View File

@@ -0,0 +1,245 @@
"""Tests for TelemetryManager class and utilities"""
# pylint: disable=redefined-outer-name,protected-access
import os
from unittest.mock import patch
import pytest
import yaml
from axolotl.telemetry.manager import TelemetryManager
@pytest.fixture
def mock_whitelist(tmp_path):
"""Create a temporary whitelist file for testing"""
whitelist_content = {
"organizations": ["meta-llama", "mistralai"],
}
whitelist_file = tmp_path / "whitelist.yaml"
with open(whitelist_file, "w", encoding="utf-8") as f:
yaml.dump(whitelist_content, f)
return str(whitelist_file)
@pytest.fixture
def telemetry_manager_class():
"""Reset the TelemetryManager singleton between tests"""
original_instance = TelemetryManager._instance
original_initialized = TelemetryManager._initialized
TelemetryManager._instance = None
TelemetryManager._initialized = False
yield TelemetryManager
TelemetryManager._instance = original_instance
TelemetryManager._initialized = original_initialized
@pytest.fixture
def manager(telemetry_manager_class, mock_whitelist):
"""Create a TelemetryManager instance with mocked dependencies"""
with patch("posthog.capture"), patch("posthog.flush"), patch("time.sleep"), patch(
"axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist
), patch.dict(os.environ, {"RANK": "0"}):
manager = telemetry_manager_class()
# Manually enable for most tests
manager.enabled = True
return manager
def test_singleton_instance(telemetry_manager_class):
"""Test that TelemetryManager is a singleton"""
with patch("posthog.capture"), patch("time.sleep"), patch.dict(
os.environ, {"RANK": "0"}
):
first = telemetry_manager_class()
second = telemetry_manager_class()
assert first is second
assert telemetry_manager_class.get_instance() is first
def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
"""Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}):
manager = telemetry_manager_class()
assert not manager.enabled
def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
"""Test that telemetry is disabled when DO_NOT_TRACK=1"""
with patch.dict(os.environ, {"DO_NOT_TRACK": "1", "RANK": "0"}):
manager = telemetry_manager_class()
assert not manager.enabled
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
"""Test that telemetry is disabled for non-main processes"""
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "1"}):
manager = telemetry_manager_class()
assert not manager.enabled
def test_telemetry_enabled_by_default(telemetry_manager_class):
"""Test that telemetry is enabled by default"""
with patch.dict(os.environ, {"RANK": "0"}, clear=True), patch("time.sleep"), patch(
"logging.Logger.warning"
):
manager = telemetry_manager_class()
assert manager.enabled
assert not manager.explicit_enable
def test_explicit_enable_disables_warning(telemetry_manager_class):
"""Test that explicit enabling prevents warning"""
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}), patch(
"logging.Logger.warning"
) as mock_warning, patch("time.sleep"):
manager = telemetry_manager_class()
assert manager.enabled
assert manager.explicit_enable
for call in mock_warning.call_args_list:
assert "Telemetry is enabled" not in str(call)
def test_warning_displayed_for_implicit_enable(telemetry_manager_class):
"""Test that warning is displayed when telemetry is implicitly enabled"""
with patch.dict(os.environ, {"RANK": "0"}, clear=True), patch(
"logging.Logger.warning"
) as mock_warning, patch("time.sleep"):
manager = telemetry_manager_class()
assert manager.enabled
assert not manager.explicit_enable
warning_displayed = False
for call in mock_warning.call_args_list:
if "Telemetry is enabled" in str(call):
warning_displayed = True
break
assert warning_displayed
def test_is_whitelisted(manager, mock_whitelist):
"""Test org whitelist functionality"""
with patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist):
# Should match organizations from the mock whitelist
assert manager._is_whitelisted("meta-llama/llama-7b")
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
# Should not match
assert not manager._is_whitelisted("unknown/model")
# Should handle case insensitively
assert manager._is_whitelisted("META-LLAMA/Llama-7B")
# Should handle empty input
assert not manager._is_whitelisted("")
assert not manager._is_whitelisted(None)
def test_system_info_collection(manager):
"""Test system information collection"""
system_info = manager._get_system_info()
# Check essential keys
assert "os" in system_info
assert "python_version" in system_info
assert "torch_version" in system_info
assert "transformers_version" in system_info
assert "axolotl_version" in system_info
assert "cpu_count" in system_info
assert "memory_total" in system_info
assert "accelerator_count" in system_info
def test_send_event(manager):
"""Test basic event sending"""
with patch("posthog.capture") as mock_capture:
# Test with clean properties (no PII)
manager.send_event("test_event", {"key": "value"})
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "test_event"
assert mock_capture.call_args[1]["properties"] == {"key": "value"}
assert mock_capture.call_args[1]["distinct_id"] == manager.run_id
# Test with default properties (None)
mock_capture.reset_mock()
manager.send_event("simple_event")
assert mock_capture.called
assert mock_capture.call_args[1]["properties"] == {}
def test_send_system_info(manager):
"""Test sending system info"""
with patch("posthog.capture") as mock_capture:
manager.send_system_info()
assert mock_capture.called
assert mock_capture.call_args[1]["event"] == "system-info"
assert mock_capture.call_args[1]["properties"] == manager.system_info
def test_redacted_properties(manager):
"""Test path redaction in send_event method"""
with patch("posthog.capture") as mock_capture:
# Test with properties containing various paths and non-paths
test_properties = {
"filepath": "/home/user/sensitive/data.txt",
"windows_path": "C:\\Users\\name\\Documents\\project\\file.py",
"output_dir": "/var/lib/data",
"path_to_model": "models/llama/7b",
"message": "Training started", # Should not be redacted
"metrics": {"loss": 0.5, "accuracy": 0.95}, # Should not be redacted
"base_model": "models/local_model",
"nested": {
"model_path": "/models/my_model",
"root_dir": "/home/user/projects",
"stats": {"steps": 1000, "epochs": 3}, # Should not be redacted
},
}
manager.send_event("test_event", test_properties)
# Verify the call was made
assert mock_capture.called
# Get the sanitized properties that were sent
sanitized = mock_capture.call_args[1]["properties"]
# Check that path-like and base_model keys were redacted
assert sanitized["filepath"] == "[REDACTED]"
assert sanitized["windows_path"] == "[REDACTED]"
assert sanitized["path_to_model"] == "[REDACTED]"
assert sanitized["base_model"] == "[REDACTED]"
# Check that non-path values were preserved
assert sanitized["message"] == "Training started"
assert sanitized["metrics"] == {"loss": 0.5, "accuracy": 0.95}
# Check nested structure handling
assert sanitized["nested"]["model_path"] == "[REDACTED]"
assert sanitized["nested"]["root_dir"] == "[REDACTED]"
assert sanitized["nested"]["stats"] == {"steps": 1000, "epochs": 3}
def test_disable_telemetry(manager):
"""Test that disabled telemetry doesn't send events"""
with patch("posthog.capture") as mock_capture:
manager.enabled = False
manager.send_event("test_event")
assert not mock_capture.called
def test_exception_handling_during_send(manager):
"""Test that exceptions in PostHog are handled gracefully"""
with patch("posthog.capture", side_effect=Exception("Test error")), patch(
"logging.Logger.warning"
) as mock_warning:
manager.send_event("test_event")
warning_logged = False
for call in mock_warning.call_args_list:
if "Failed to send telemetry event" in str(call):
warning_logged = True
break
assert warning_logged
def test_shutdown(manager):
"""Test shutdown behavior"""
with patch("posthog.flush") as mock_flush:
manager.shutdown()
assert mock_flush.called

View File

@@ -0,0 +1,356 @@
"""Tests for runtime metrics telemetry module"""
# pylint: disable=redefined-outer-name
from unittest.mock import MagicMock, patch
import pytest
from axolotl.telemetry.runtime_metrics import RuntimeMetrics, RuntimeMetricsTracker
@pytest.fixture
def mock_time():
"""Mock time.time() to have predictable values in tests"""
with patch("time.time") as mock_time:
# Start with time 1000.0 and increment by 10 seconds on each call
times = [1000.0 + i * 10 for i in range(10)]
mock_time.side_effect = times
yield mock_time
@pytest.fixture
def mock_telemetry_manager():
"""Create a mock TelemetryManager"""
with patch(
"axolotl.telemetry.runtime_metrics.TelemetryManager"
) as mock_manager_class:
mock_manager = MagicMock()
mock_manager.enabled = True
mock_manager_class.get_instance.return_value = mock_manager
yield mock_manager
@pytest.fixture
def mock_psutil():
"""Mock psutil for memory information"""
with patch("axolotl.telemetry.runtime_metrics.psutil") as mock_psutil:
mock_process = MagicMock()
mock_memory_info = MagicMock()
# Set initial memory to 1GB
mock_memory_info.rss = 1024 * 1024 * 1024
mock_process.memory_info.return_value = mock_memory_info
mock_psutil.Process.return_value = mock_process
yield mock_psutil
@pytest.fixture
def mock_torch():
"""Mock torch.cuda functions"""
with patch("axolotl.telemetry.runtime_metrics.torch") as mock_torch:
mock_torch.cuda.is_available.return_value = True
mock_torch.cuda.device_count.return_value = 2
# Mock memory allocated per device (1GB for device 0, 2GB for device 1)
mock_torch.cuda.memory_allocated.side_effect = (
lambda device: (device + 1) * 1024 * 1024 * 1024
)
yield mock_torch
class TestRuntimeMetrics:
"""Tests for RuntimeMetrics class."""
def test_initialization(self):
"""Test RuntimeMetrics initialization."""
metrics = RuntimeMetrics(start_time=1000.0)
assert metrics.start_time == 1000.0
assert metrics.epoch_start_times == {}
assert metrics.epoch_end_times == {}
assert metrics.peak_gpu_memory == {}
assert metrics.total_steps == 0
assert metrics.current_epoch == 0
assert metrics.current_step == 0
assert metrics.peak_cpu_memory == 0
def test_elapsed_time(self, mock_time):
"""Test elapsed_time property."""
metrics = RuntimeMetrics(start_time=1000.0)
# Mock time.time() to return 1050.0
mock_time.side_effect = [1050.0]
assert metrics.elapsed_time == 50.0
def test_epoch_time(self):
"""Test epoch_time method."""
metrics = RuntimeMetrics(start_time=1000.0)
# No epoch data
assert metrics.epoch_time(0) is None
# Add epoch start but no end
metrics.epoch_start_times[0] = 1000.0
assert metrics.epoch_time(0) is None
# Add epoch end
metrics.epoch_end_times[0] = 1060.0
assert metrics.epoch_time(0) == 60.0
def test_average_epoch_time(self):
"""Test average_epoch_time method."""
metrics = RuntimeMetrics(start_time=1000.0)
# No completed epochs
assert metrics.average_epoch_time() is None
# Add one completed epoch
metrics.epoch_start_times[0] = 1000.0
metrics.epoch_end_times[0] = 1060.0
assert metrics.average_epoch_time() == 60.0
# Add second completed epoch
metrics.epoch_start_times[1] = 1060.0
metrics.epoch_end_times[1] = 1140.0 # 80 seconds
assert metrics.average_epoch_time() == 70.0 # Average of 60 and 80
# Add incomplete epoch (should not affect average)
metrics.epoch_start_times[2] = 1140.0
assert metrics.average_epoch_time() == 70.0
def test_steps_per_second(self, mock_time):
"""Test steps_per_second method."""
metrics = RuntimeMetrics(start_time=1000.0)
# No steps - first call to time.time()
mock_time.side_effect = None
mock_time.return_value = 1050.0
assert metrics.steps_per_second() is None
# Add steps - second call to time.time()
metrics.total_steps = 100
mock_time.return_value = 1050.0 # Keep same time for consistent result
assert metrics.steps_per_second() == 2.0 # 100 steps / 50 seconds
def test_to_dict_basic(self, mock_time):
"""Test to_dict method with basic metrics."""
metrics = RuntimeMetrics(start_time=1000.0)
metrics.total_steps = 100
metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024 # 2GB
# Mock elapsed_time
mock_time.side_effect = None
mock_time.return_value = 1050.0
result = metrics.to_dict()
assert result["total_time_seconds"] == 50.0
assert result["total_steps"] == 100
assert result["steps_per_second"] == 2.0
assert result["epochs_completed"] == 0
assert result["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
assert "epoch_times" not in result
assert "gpu_memory" not in result
def test_to_dict_with_epochs(self, mock_time):
"""Test to_dict method with epoch data."""
metrics = RuntimeMetrics(start_time=1000.0)
metrics.total_steps = 100
# Add epoch data
metrics.epoch_start_times[0] = 1000.0
metrics.epoch_end_times[0] = 1060.0
metrics.epoch_start_times[1] = 1060.0
metrics.epoch_end_times[1] = 1140.0
# Mock elapsed_time
mock_time.side_effect = None
mock_time.return_value = 1150.0
result = metrics.to_dict()
assert "epoch_times" in result
assert result["epoch_times"]["epoch_0_seconds"] == 60.0
assert result["epoch_times"]["epoch_1_seconds"] == 80.0
assert result["average_epoch_time_seconds"] == 70.0
def test_to_dict_with_gpu_memory(self, mock_time):
"""Test to_dict method with GPU memory data."""
metrics = RuntimeMetrics(start_time=1000.0)
metrics.peak_gpu_memory = {
0: 1 * 1024 * 1024 * 1024, # 1GB
1: 2 * 1024 * 1024 * 1024, # 2GB
}
# Mock elapsed_time
mock_time.side_effect = [1050.0]
result = metrics.to_dict()
assert "gpu_memory" in result
assert result["gpu_memory"]["gpu_0_peak_memory_bytes"] == 1 * 1024 * 1024 * 1024
assert result["gpu_memory"]["gpu_1_peak_memory_bytes"] == 2 * 1024 * 1024 * 1024
class TestRuntimeMetricsTracker:
"""Tests for RuntimeMetricsTracker class."""
# pylint: disable=unused-argument
def test_initialization(self, mock_time, mock_telemetry_manager):
"""Test RuntimeMetricsTracker initialization."""
tracker = RuntimeMetricsTracker()
assert isinstance(tracker.metrics, RuntimeMetrics)
assert tracker.metrics.start_time == 1000.0 # First value from mock_time
# pylint: disable=unused-argument
def test_start_epoch(
self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
):
"""Test start_epoch method."""
tracker = RuntimeMetricsTracker()
# Reset mock_time to control next value
mock_time.side_effect = [1010.0]
tracker.start_epoch(0)
assert tracker.metrics.current_epoch == 0
assert tracker.metrics.epoch_start_times[0] == 1010.0
# Verify memory metrics were updated
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
assert 0 in tracker.metrics.peak_gpu_memory
assert 1 in tracker.metrics.peak_gpu_memory
# pylint: disable=unused-argument
def test_end_epoch(self, mock_time, mock_telemetry_manager):
"""Test end_epoch method."""
tracker = RuntimeMetricsTracker()
# Start epoch 0
mock_time.side_effect = [1010.0]
tracker.start_epoch(0)
# End epoch 0
mock_time.side_effect = [1060.0]
tracker.end_epoch(0)
assert 0 in tracker.metrics.epoch_end_times
assert tracker.metrics.epoch_end_times[0] == 1060.0
# pylint: disable=unused-argument
def test_update_step(
self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
):
"""Test update_step method."""
tracker = RuntimeMetricsTracker()
# Update step to a non-multiple of 100
tracker.update_step(42)
assert tracker.metrics.current_step == 42
assert tracker.metrics.total_steps == 1
# Memory metrics should not be updated for non-multiple of 100
assert tracker.metrics.peak_cpu_memory == 0
# Update step to a multiple of 100
tracker.update_step(100)
assert tracker.metrics.current_step == 100
assert tracker.metrics.total_steps == 2
# Memory metrics should be updated for multiple of 100
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
# pylint: disable=unused-argument
def test_update_memory_metrics(
self, mock_psutil, mock_torch, mock_telemetry_manager
):
"""Test update_memory_metrics method."""
tracker = RuntimeMetricsTracker()
# Initial memory state
assert tracker.metrics.peak_cpu_memory == 0
assert tracker.metrics.peak_gpu_memory == {}
# Update memory metrics
tracker.update_memory_metrics()
# Verify CPU memory
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
# Verify GPU memory
assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024
# Change mocked memory values to be lower
mock_process = mock_psutil.Process.return_value
mock_memory_info = mock_process.memory_info.return_value
mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024 # 0.5GB
mock_torch.cuda.memory_allocated.side_effect = (
lambda device: (device + 0.5) * 1024 * 1024 * 1024
)
# Update memory metrics again
tracker.update_memory_metrics()
# Peak values should not decrease
assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024
# Change mocked memory values to be higher
mock_memory_info.rss = 2 * 1024 * 1024 * 1024 # 2GB
mock_torch.cuda.memory_allocated.side_effect = (
lambda device: (device + 2) * 1024 * 1024 * 1024
)
# Update memory metrics again
tracker.update_memory_metrics()
# Peak values should increase
assert tracker.metrics.peak_cpu_memory == 2 * 1024 * 1024 * 1024
assert tracker.metrics.peak_gpu_memory[0] == 2 * 1024 * 1024 * 1024
assert tracker.metrics.peak_gpu_memory[1] == 3 * 1024 * 1024 * 1024
# pylint: disable=unused-argument
def test_get_memory_metrics(self, mock_psutil, mock_torch, mock_telemetry_manager):
"""Test get_memory_metrics method."""
tracker = RuntimeMetricsTracker()
# Set peak memory values
tracker.metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024
tracker.metrics.peak_gpu_memory = {
0: 3 * 1024 * 1024 * 1024,
1: 4 * 1024 * 1024 * 1024,
}
# Get memory metrics
memory_metrics = tracker.get_memory_metrics()
# Verify CPU memory
assert (
memory_metrics["cpu_memory_bytes"] == 1 * 1024 * 1024 * 1024
) # Current value from mock
assert (
memory_metrics["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
) # Peak value we set
# Verify GPU memory
assert (
memory_metrics["gpu_0_memory_bytes"] == 1 * 1024 * 1024 * 1024
) # Current value from mock
assert (
memory_metrics["gpu_0_peak_memory_bytes"] == 3 * 1024 * 1024 * 1024
) # Peak value we set
assert (
memory_metrics["gpu_1_memory_bytes"] == 2 * 1024 * 1024 * 1024
) # Current value from mock
assert (
memory_metrics["gpu_1_peak_memory_bytes"] == 4 * 1024 * 1024 * 1024
) # Peak value we set

View File

@@ -1,7 +1,6 @@
"""
Test cases for the tokenizer loading
"""
import unittest
import pytest
@@ -10,7 +9,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer
class TestTokenizers:
class TestTokenizers(unittest.TestCase):
"""
test class for the load_tokenizer fn
"""
@@ -76,48 +75,12 @@ class TestTokenizers:
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
assert len(tokenizer) == 32001
self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
self.assertEqual(len(tokenizer), 32001)
# ensure reloading the tokenizer again from cfg results in same vocab length
tokenizer = load_tokenizer(cfg)
assert len(tokenizer) == 32001
def test_added_tokens_overrides(self, temp_dir):
cfg = DictDefault(
{
# use with tokenizer that has reserved_tokens in added_tokens
"tokenizer_config": "NousResearch/Llama-3.2-1B",
"added_tokens_overrides": {
128041: "RANDOM_OVERRIDE_1",
128042: "RANDOM_OVERRIDE_2",
},
"output_dir": temp_dir,
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [
128041
]
assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
128042
]
def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
cfg = DictDefault(
{
# use with tokenizer that has reserved_tokens in added_tokens
"tokenizer_config": "NousResearch/Llama-3.2-1B",
"added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
"output_dir": temp_dir,
}
)
with pytest.raises(
ValueError, match=r".*Token ID 1000000 not found in added_tokens.*"
):
load_tokenizer(cfg)
self.assertEqual(len(tokenizer), 32001)
if __name__ == "__main__":