Compare commits
10 Commits: telemetry...topk-logpr

| SHA1 |
|---|
| 68e97d032a |
| 23f029a89c |
| afbb44f08b |
| d753ead033 |
| c011405117 |
| a2e52a29e9 |
| e82268e580 |
| 75e1480c10 |
| 45e1548d59 |
| 165088e7c1 |
@@ -50,14 +50,13 @@ Features:

## 🚀 Quick Start

**Requirements**:

- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.4.1

### Installation

```bash
```shell
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]

# Download example axolotl configs, deepspeed configs

@@ -69,7 +68,7 @@ Other installation approaches are described [here](https://axolotl-ai-cloud.gith

### Your First Fine-tune

```bash
```shell
# Fetch axolotl examples
axolotl fetch examples

61 _quarto.yml
@@ -3,12 +3,10 @@ project:

website:
  title: "Axolotl"
  description: "We make fine-tuning accessible, scalable, and fun"
  description: "Fine-tuning"
  favicon: favicon.jpg

  navbar:
    logo: image/axolotl_logo_digital_white.svg
    title: false
    title: Axolotl
    background: dark
    pinned: false
    collapse: false
@@ -27,58 +25,33 @@ website:
    contents:
      - text: Home
        href: index.qmd

      - section: "Getting Started"
      - section: "How-To Guides"
        contents:
          # TODO Edit folder structure after we have more docs.
          - docs/getting-started.qmd
          - docs/installation.qmd
          - docs/cli.qmd
          - docs/debugging.qmd
          - docs/inference.qmd

      - section: "Dataset Formats"
        contents: docs/dataset-formats/*

      - section: "Deployments"
        contents:
          - docs/multipack.qmd
          - docs/fsdp_qlora.qmd
          - docs/input_output.qmd
          - docs/rlhf.qmd
          - docs/nccl.qmd
          - docs/mac.qmd
          - docs/multi-gpu.qmd
          - docs/multi-node.qmd
          - docs/ray-integration.qmd
          - docs/amd_hpc.qmd
          - docs/mac.qmd

      - section: "How To Guides"
        contents:
          - docs/multimodal.qmd
          - docs/rlhf.qmd
          - docs/reward_modelling.qmd
          - docs/lr_groups.qmd
          - docs/lora_optims.qmd

      - section: "Core Concepts"
        contents:
          - docs/batch_vs_grad.qmd
          - docs/dataset_preprocessing.qmd
          - docs/multipack.qmd

      - section: "Advanced Features"
        contents:
          - docs/fsdp_qlora.qmd
          - docs/unsloth.qmd
          - docs/torchao.qmd
          - docs/custom_integrations.qmd

      - section: "Troubleshooting"
        contents:
          - docs/faq.qmd
          - docs/debugging.qmd
          - docs/nccl.qmd

          - docs/amd_hpc.qmd
          - docs/ray-integration.qmd
      - section: "Dataset Formats"
        contents: docs/dataset-formats/*
      - section: "Reference"
        contents:
          - docs/config.qmd
          - docs/faq.qmd

format:
  html:
    theme: darkly
    theme: materia
    css: styles.css
    toc: true

@@ -1,5 +1,5 @@
---
title: AMD GPUs on HPC Systems
title: Training with AMD GPUs on HPC Systems
description: A comprehensive guide for using Axolotl on distributed systems with AMD GPUs
---

134 docs/cli.qmd
@@ -1,19 +1,28 @@
|
||||
---
|
||||
title: "CLI Reference"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
toc-expand: 2
|
||||
toc-depth: 3
|
||||
execute:
|
||||
enabled: false
|
||||
---
|
||||
# Axolotl CLI Documentation
|
||||
|
||||
The Axolotl CLI provides a streamlined interface for training and fine-tuning large language models. This guide covers
|
||||
the CLI commands, their usage, and common examples.
|
||||
|
||||
### Table of Contents
|
||||
|
||||
## Basic Commands
|
||||
- Basic Commands
|
||||
- Command Reference
|
||||
- fetch
|
||||
- preprocess
|
||||
- train
|
||||
- inference
|
||||
- merge-lora
|
||||
- merge-sharded-fsdp-weights
|
||||
- evaluate
|
||||
- lm-eval
|
||||
- Legacy CLI Usage
|
||||
- Remote Compute with Modal Cloud
|
||||
- Cloud Configuration
|
||||
- Running on Modal Cloud
|
||||
- Cloud Configuration Options
|
||||
|
||||
|
||||
### Basic Commands
|
||||
|
||||
All Axolotl commands follow this general structure:
|
||||
|
||||
@@ -23,9 +32,9 @@ axolotl <command> [config.yml] [options]
|
||||
|
||||
The config file can be local or a URL to a raw YAML file.
|
||||
|
||||
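For example, both of these are accepted (the URL is illustrative — any raw YAML file works):

```bash
# local config file
axolotl train config.yml

# config pulled from a raw URL (illustrative path)
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
```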
## Command Reference
|
||||
### Command Reference
|
||||
|
||||
### fetch
|
||||
#### fetch
|
||||
|
||||
Downloads example configurations and deepspeed configs to your local machine.
|
||||
|
||||
@@ -40,7 +49,7 @@ axolotl fetch deepspeed_configs
|
||||
axolotl fetch examples --dest path/to/folder
|
||||
```
|
||||
|
||||
### preprocess
|
||||
#### preprocess
|
||||
|
||||
Preprocesses and tokenizes your dataset before training. This is recommended for large datasets.
|
||||
|
||||
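For example (the `--debug` variant prints tokens and labels so you can verify masking, as shown in the template-free guide further down):

```bash
# tokenize and cache the dataset ahead of training
axolotl preprocess config.yml

# also print tokens and labels for inspection
axolotl preprocess config.yml --debug
```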
@@ -65,7 +74,7 @@ dataset_prepared_path: Local folder for saving preprocessed data
|
||||
push_dataset_to_hub: HuggingFace repo to push preprocessed data (optional)
|
||||
```
|
||||
|
||||
### train
|
||||
#### train
|
||||
|
||||
Trains or fine-tunes a model using the configuration specified in your YAML file.
|
||||
|
||||
@@ -86,38 +95,7 @@ axolotl train config.yml --no-accelerate
|
||||
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
|
||||
```
|
||||
|
||||
It is possible to run sweeps over multiple hyperparameters by passing in a sweeps config.
|
||||
|
||||
```bash
|
||||
# Basic training with sweeps
|
||||
axolotl train config.yml --sweep path/to/sweep.yaml
|
||||
```
|
||||
|
||||
Example sweep config:
|
||||
```yaml
|
||||
_:
|
||||
# This section is for dependent variables we need to fix
|
||||
- load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
adapter: lora
|
||||
- load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
adapter: lora
|
||||
|
||||
# These are independent variables
|
||||
learning_rate: [0.0003, 0.0006]
|
||||
lora_r:
|
||||
- 16
|
||||
- 32
|
||||
lora_alpha:
|
||||
- 16
|
||||
- 32
|
||||
- 64
|
||||
```
|
||||
|
||||
|
||||
|
||||
### inference
|
||||
#### inference
|
||||
|
||||
Runs inference using your trained model in either CLI or Gradio interface mode.
|
||||
|
||||
@@ -137,7 +115,7 @@ cat prompt.txt | axolotl inference config.yml \
|
||||
--base-model="./completed-model"
|
||||
```
|
||||
|
||||
### merge-lora
|
||||
#### merge-lora
|
||||
|
||||
Merges trained LoRA adapters into the base model.
|
||||
|
||||
@@ -159,7 +137,7 @@ gpu_memory_limit: Limit GPU memory usage
|
||||
lora_on_cpu: Load LoRA weights on CPU
|
||||
```
|
||||
|
||||
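For example — a sketch that assumes `--lora-model-dir` can be overridden from the CLI the same way the other commands on this page override config keys:

```bash
# merge the trained adapter back into the base model
axolotl merge-lora config.yml --lora-model-dir="./outputs/lora-out"
```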
### merge-sharded-fsdp-weights
|
||||
#### merge-sharded-fsdp-weights
|
||||
|
||||
Merges sharded FSDP model checkpoints into a single combined checkpoint.
|
||||
|
||||
@@ -168,7 +146,7 @@ Merges sharded FSDP model checkpoints into a single combined checkpoint.
|
||||
axolotl merge-sharded-fsdp-weights config.yml
|
||||
```
|
||||
|
||||
### evaluate
|
||||
#### evaluate
|
||||
|
||||
Evaluates a model's performance using metrics specified in the config.
|
||||
|
||||
@@ -177,27 +155,27 @@ Evaluates a model's performance using metrics specified in the config.
|
||||
axolotl evaluate config.yml
|
||||
```
|
||||
|
||||
### lm-eval
|
||||
#### lm-eval
|
||||
|
||||
Runs LM Evaluation Harness on your model.
|
||||
|
||||
```bash
|
||||
# Basic evaluation
|
||||
axolotl lm-eval config.yml
|
||||
|
||||
# Evaluate specific tasks
|
||||
axolotl lm-eval config.yml --tasks arc_challenge,hellaswag
|
||||
```
|
||||
|
||||
Configuration options:
|
||||
|
||||
```yaml
|
||||
# List of tasks to evaluate
|
||||
lm_eval_tasks:
|
||||
- arc_challenge
|
||||
- hellaswag
|
||||
lm_eval_batch_size: # Batch size for evaluation
|
||||
output_dir: # Directory to save evaluation results
|
||||
lm_eval_tasks: List of tasks to evaluate
|
||||
lm_eval_batch_size: Batch size for evaluation
|
||||
output_dir: Directory to save evaluation results
|
||||
```
|
||||
|
||||
## Legacy CLI Usage
|
||||
### Legacy CLI Usage
|
||||
|
||||
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
|
||||
|
||||
@@ -217,18 +195,12 @@ accelerate launch -m axolotl.cli.inference config.yml \
|
||||
--lora_model_dir="./outputs/lora-out" --gradio
|
||||
```
|
||||
|
||||
::: {.callout-important}
|
||||
When overriding CLI parameters in the legacy CLI, use the same notation as in the YAML file (e.g., `--lora_model_dir`).
|
||||
|
||||
**Note:** This differs from the new Click-based CLI, which uses dash notation (e.g., `--lora-model-dir`). Keep this in mind if you're referencing newer documentation or switching between CLI versions.
|
||||
:::
|
||||
|
||||
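A side-by-side sketch using the inference examples from this page:

```bash
# legacy module-based CLI: underscore (YAML-style) flags
accelerate launch -m axolotl.cli.inference config.yml \
    --lora_model_dir="./outputs/lora-out"

# new Click-based CLI: dash-style flags
axolotl inference config.yml \
    --lora-model-dir="./outputs/lora-out"
```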
## Remote Compute with Modal Cloud
|
||||
### Remote Compute with Modal Cloud
|
||||
|
||||
Axolotl supports running training and inference workloads on Modal cloud infrastructure. This is configured using a
|
||||
cloud YAML file alongside your regular Axolotl config.
|
||||
|
||||
### Cloud Configuration
|
||||
#### Cloud Configuration
|
||||
|
||||
Create a cloud config YAML with your Modal settings:
|
||||
|
||||
@@ -243,17 +215,13 @@ branch: main # Git branch to use (optional)
|
||||
volumes: # Persistent storage volumes
|
||||
- name: axolotl-cache
|
||||
mount: /workspace/cache
|
||||
- name: axolotl-data
|
||||
mount: /workspace/data
|
||||
- name: axolotl-artifacts
|
||||
mount: /workspace/artifacts
|
||||
|
||||
env: # Environment variables
|
||||
- WANDB_API_KEY
|
||||
- HF_TOKEN
|
||||
```
|
||||
|
||||
### Running on Modal Cloud
|
||||
#### Running on Modal Cloud
|
||||
|
||||
Commands that support the --cloud flag:
|
||||
|
||||
@@ -271,18 +239,18 @@ axolotl train config.yml --cloud cloud_config.yml --no-accelerate
|
||||
axolotl lm-eval config.yml --cloud cloud_config.yml
|
||||
```
|
||||
|
||||
### Cloud Configuration Options
|
||||
#### Cloud Configuration Options
|
||||
|
||||
```yaml
|
||||
provider: # compute provider, currently only `modal` is supported
|
||||
gpu: # GPU type to use
|
||||
gpu_count: # Number of GPUs (default: 1)
|
||||
memory: # RAM in GB (default: 128)
|
||||
timeout: # Maximum runtime in seconds
|
||||
timeout_preprocess: # Preprocessing timeout
|
||||
branch: # Git branch to use
|
||||
docker_tag: # Custom Docker image tag
|
||||
volumes: # List of persistent storage volumes
|
||||
env: # Environment variables to pass
|
||||
secrets: # Secrets to inject
|
||||
provider: compute provider, currently only `modal` is supported
|
||||
gpu: GPU type to use
|
||||
gpu_count: Number of GPUs (default: 1)
|
||||
memory: RAM in GB (default: 128)
|
||||
timeout: Maximum runtime in seconds
|
||||
timeout_preprocess: Preprocessing timeout
|
||||
branch: Git branch to use
|
||||
docker_tag: Custom Docker image tag
|
||||
volumes: List of persistent storage volumes
|
||||
env: Environment variables to pass
|
||||
secrets: Secrets to inject
|
||||
```
|
||||
|
||||
@@ -166,7 +166,7 @@ datasets:
|
||||
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
||||
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
|
||||
# See examples at `docs/dataset-formats/conversation.qmd`
|
||||
# Note: If the below 4 fields are set to empty, defaults to training only on the last message.
|
||||
# Note: If the below 4 fields are empty, defaults to training only on the last message.
|
||||
|
||||
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
|
||||
roles_to_train: ["assistant"] # default
|
||||
@@ -174,7 +174,6 @@ datasets:
|
||||
# - all: train on all EOS tokens
|
||||
# - turn (default): train on the EOS token at the end of each trainable turn
|
||||
# - last: train on the last EOS token in the conversation
|
||||
# TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
|
||||
train_on_eos: last
|
||||
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
|
||||
message_field_training: training
|
||||
|
||||
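Put together, a dataset entry using these controls might look like this sketch (the dataset path is hypothetical; the keys are the ones documented above):

```yaml
datasets:
  - path: my_chat_dataset.jsonl   # hypothetical local dataset
    type: chat_template
    roles_to_train: ["assistant"]        # train on assistant turns
    train_on_eos: turn                   # train on the EOS token of each trainable turn
    message_field_training: training     # per-turn boolean override
```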
@@ -1,57 +0,0 @@
|
||||
---
|
||||
title: Custom Integrations
|
||||
toc: true
|
||||
toc-depth: 3
|
||||
---
|
||||
|
||||
```{python}
|
||||
#| echo: false
|
||||
|
||||
import re
|
||||
|
||||
def process_readme(integration_name):
|
||||
try:
|
||||
path = f'../src/axolotl/integrations/{integration_name}/README.md'
|
||||
with open(path, 'r') as f:
|
||||
txt = f.read()
|
||||
# Remove h1 headings
|
||||
txt = re.sub(r'^# .*\n?', '', txt, flags=re.MULTILINE)
|
||||
# Convert h2 to h3
|
||||
txt = re.sub(r'^## ', '### ', txt, flags=re.MULTILINE)
|
||||
return txt
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
def print_section(name, folder_name):
|
||||
output = f"\n## {name}\n"
|
||||
content = process_readme(folder_name)
|
||||
if content:
|
||||
output += content
|
||||
output += f"\nPlease see reference [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/{folder_name})\n"
|
||||
return output
|
||||
```
|
||||
|
||||
```{python}
|
||||
#| output: asis
|
||||
#| echo: false
|
||||
|
||||
# Introduction text
|
||||
print("""
|
||||
Axolotl adds custom features through `integrations`. They are located within the `src/axolotl/integrations` directory.
|
||||
|
||||
To enable them, please check the respective documentations.
|
||||
""")
|
||||
|
||||
# Sections
|
||||
sections = [
|
||||
("Cut Cross Entropy", "cut_cross_entropy"),
|
||||
("Grokfast", "grokfast"),
|
||||
("Knowledge Distillation (KD)", "kd"),
|
||||
("Liger Kernels", "liger"),
|
||||
("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
|
||||
("Spectrum", "spectrum")
|
||||
]
|
||||
|
||||
for section_name, folder_name in sections:
|
||||
print(print_section(section_name, folder_name))
|
||||
```
|
||||
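Many of these integrations are enabled by adding their plugin class to the `plugins` list in your training config; the sketch below uses the Liger plugin path that also appears in the multi-GPU guide in this changeset, and each integration's own section documents its specific flags.

```yaml
plugins:
  - axolotl.integrations.liger.LigerPlugin
# integration-specific flags then live alongside the plugin entry, e.g.:
liger_rope: true
liger_rms_norm: true
```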
@@ -6,9 +6,7 @@ order: 3
|
||||
|
||||
## sharegpt
|
||||
|
||||
::: {.callout-important}
|
||||
ShareGPT is deprecated! Please see the [chat_template](#chat_template) section below.
|
||||
:::
|
||||
IMPORTANT: ShareGPT is deprecated! Please see the [chat_template](#chat_template) section below.
|
||||
|
||||
## pygmalion
|
||||
|
||||
@@ -104,10 +102,6 @@ datasets:
|
||||
type: chat_template
|
||||
```
|
||||
|
||||
::: {.callout-important}
|
||||
Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
|
||||
:::
|
||||
|
||||
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||
|
||||
For a data sample that looks like:
|
||||
@@ -155,6 +149,4 @@ datasets:
|
||||
message_field_training_detail: train_detail
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
||||
:::
|
||||
Tip: It is not necessary to set both `message_field_training` and `message_field_training_detail` at the same time.
|
||||
|
||||
@@ -13,7 +13,7 @@ As there are a lot of available options in Axolotl, this guide aims to provide a
|
||||
|
||||
Axolotl supports 3 kinds of training methods: pre-training, supervised fine-tuning, and preference-based post-training (e.g. DPO, ORPO, PRMs). Each method has their own dataset format which are described below.
|
||||
|
||||
## Pre-training
|
||||
## [Pre-training](pretraining.qmd)
|
||||
|
||||
When aiming to train on large corpora of text, pre-training is your go-to choice. Due to the size of these datasets, downloading them in full before beginning training would be prohibitively time-consuming. Axolotl supports [streaming](https://huggingface.co/docs/datasets/en/stream) to load only a batch at a time into memory.
|
||||
|
||||
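A minimal streaming pre-training sketch, assuming a Hugging Face dataset id (the id and step count below are illustrative; the `pretraining_dataset` and `type: pretrain` keys appear in the pretraining doc further down):

```yaml
pretraining_dataset:
  - path: my-org/my-pretrain-corpus   # hypothetical HF dataset id, streamed rather than downloaded
    type: pretrain
max_steps: 1000   # streamed datasets have no fixed length, so give the trainer an explicit stopping point
```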
@@ -96,10 +96,6 @@ One step is equal to `sequence_len * micro_batch_size * gradient_accumulation_st
|
||||
|
||||
It is recommended to leave this off if downloading from Hugging Face hub as it would download the entire dataset which can be very large.
|
||||
|
||||
### Reference
|
||||
|
||||
Please see docs [here](pretraining.qmd).
|
||||
|
||||
## Supervised fine-tuning (SFT)
|
||||
|
||||
Supervised fine-tuning is the process of training models to respond to an instruction or chat input.
|
||||
@@ -124,7 +120,7 @@ If you went through the flow chart and did not find one that matches, it is reco
|
||||
You can mix and match within each approach or across approaches to train a model on a variety of datasets.
|
||||
:::
|
||||
|
||||
### Pre-Tokenized Dataset
|
||||
### [Pre-Tokenized Dataset](tokenized.qmd)
|
||||
|
||||
We suggest this approach when you want to bring your own tokenized dataset.
|
||||
|
||||
@@ -149,9 +145,7 @@ datasets:
|
||||
`type: ` is empty!
|
||||
:::
|
||||
|
||||
Reference: [Pre-Tokenized Dataset Documentation](tokenized.qmd).
|
||||
|
||||
### Template Free Dataset
|
||||
### [Template Free Dataset](template_free.qmd)
|
||||
|
||||
We recommend this approach when you want granular control over the prompt formatting, special tokens, and masking, whilst letting Axolotl handle the tokenization. This is very useful if your dataset has unique prompts that differ across samples, where a single general template wouldn't suffice.
|
||||
|
||||
@@ -188,9 +182,7 @@ datasets:
|
||||
type: input_output
|
||||
```
|
||||
|
||||
Reference: [Template Free Documentation](template_free.qmd).
|
||||
|
||||
### Conversation Dataset
|
||||
### [Conversation Dataset](conversation.qmd)
|
||||
|
||||
`conversation` messages are a list of messages which usually contain a `role` and `content` key.
|
||||
|
||||
@@ -266,7 +258,7 @@ Newer conversation datasets usually follow the OpenAI format.
|
||||
|
||||
Axolotl supports both as well as allowing customization of any kind of key.
|
||||
|
||||
#### Chat Template Usage
|
||||
#### [Chat Template Usage](conversation.qmd#chat_template)
|
||||
|
||||
To properly use this method, it is important to identify three things:
|
||||
|
||||
@@ -348,19 +340,9 @@ datasets:
|
||||
narrator: ["narrator"]
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
As chat_templates may use hardcoded EOS/EOT tokens that are different from the tokenizer's EOS, it is highly recommended to set them. For example, `ChatML` uses `<|im_end|>` to end turns.
|
||||
#### Applying `chat_template`
|
||||
|
||||
```yaml
|
||||
special_tokens:
|
||||
eos_token: <|im_end|>
|
||||
```
|
||||
|
||||
:::
|
||||
|
||||
##### Applying `chat_template`
|
||||
|
||||
Once all the above steps are completed, you could combine all these configs together to form a bespoke configuration for your custom dataset.
|
||||
Once all the above steps are completed, you could combine all these configs together to form a bespoke configuration for your custom dataset. The final step would be to correctly set the EOS token in your config:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
@@ -409,17 +391,7 @@ If this config were to be applied to the sample dataset above, the output would
|
||||
|
||||
The first number refers to the label, the second refers to the `token_id`. For example, `-100` labels appear on non-assistant portions, meaning that they are masked during training. For assistant portions, the label is the same as the `token_id`.
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
If during `preprocess`, there are a lot of warnings of `Could not find content __ boundary`, please check the FAQ section for [chat_templates](../faq.qmd#chat-templates).
|
||||
|
||||
:::
|
||||
|
||||
#### Reference
|
||||
|
||||
Please see docs [here](conversation.qmd).
|
||||
|
||||
### Instruction Dataset
|
||||
### [Instruction Dataset](inst_tune.qmd)
|
||||
|
||||
Instruction datasets are used to train instruction-following models and comprise a prompt, containing an instruction, and a single response. In contrast to chat datasets which may be multi-turn, instruct datasets are typically single-turn.
|
||||
|
||||
@@ -451,9 +423,6 @@ datasets:
|
||||
|
||||
Axolotl supports many kinds of instruction datasets. All of them can be found [here](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/inst_tune.html) with their respective type and sample row format.
|
||||
|
||||
|
||||
Reference: [Instruction Dataset Documentation](inst_tune.qmd).
|
||||
|
||||
#### Custom Instruct Prompt Format
|
||||
|
||||
Due to the myriad possibilities of instruction formats, Axolotl allows customizing your own instruction format without having to dive into the code directly.
|
||||
@@ -484,8 +453,6 @@ datasets:
|
||||
|
||||
The config sets that the `field_instruction` is actually named `input`, and the `field_input` is empty as we don't have an `input` in this sample. Generally, `instruction` can be thought as the question to the model, and `input` as the additional information with `output` being the response. It is not necessary to have an `input` nor `system`. In the end, the most important part is to understand what format you want it to look like and how you can customize this to your use case.
|
||||
|
||||
Reference: [Custom Instruct Prompt Format Documentation](inst_tune.qmd#how-to-add-custom-prompt-format).
|
||||
|
||||
## Reinforcement Learning from Human Feedback (RLHF)
|
||||
|
||||
As there are multiple RLHF methods, each with their own dataset requirements, please see the [RLHF documentation](../rlhf.qmd) for more detail.
|
||||
As there are multiple RLHF methods, each with their own dataset requirements, please see the [RLHF datasets](../rlhf.qmd) documentation for more detail.
|
||||
|
||||
@@ -27,6 +27,7 @@ pretraining_dataset:
|
||||
type: pretrain
|
||||
trust_remote_code:
|
||||
skip: # number of rows of data to skip over from the beginning
|
||||
...
|
||||
```
|
||||
|
||||
:::
|
||||
|
||||
@@ -1,239 +1,7 @@
|
||||
---
|
||||
title: Template-Free
|
||||
description: Construct prompts without a template.
|
||||
toc: true
|
||||
toc-depth: 3
|
||||
order: 4
|
||||
---
|
||||
|
||||
## Background {#sec-background}
|
||||
|
||||
### Masking Inputs {#masking-inputs}
|
||||
|
||||
One of the most popular features of
|
||||
[axolotl](https://github.com/axolotl-ai-cloud/axolotl) is
|
||||
setting the following configuration value:
|
||||
|
||||
|
||||
```yaml
|
||||
train_on_inputs: false
|
||||
```
|
||||
|
||||
If you declare a [dataset formats](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#dataset)
|
||||
such as `alpaca` or `chatml`, axolotl knows what is an input
|
||||
(i.e. human) vs. an output (i.e. the assistant) and masks the input
|
||||
labels so that your model can focus on predicting the outputs only.
|
||||
|
||||
### You may not want prompt templates {#sec-you-may-not-want-prompt-templates}
|
||||
|
||||
However, there are many situations where you don't want to use one of
|
||||
these formats or templates. This is because they can:
|
||||
|
||||
- Add unnecessary boilerplate to your prompts.
|
||||
- Create artifacts like special delimiters `<|im_start|>` that can
|
||||
quickly become footguns if you don't include them correctly at
|
||||
inference time.
|
||||
- Enforce a *chat* interface when you do not want one. Sometimes you
|
||||
just want to fine-tune a model to a very specific task and do NOT
|
||||
want multi-turn conversations, roles, etc.
|
||||
- Limit you to only certain roles that the template allows.
|
||||
|
||||
### The `input_output` format {#sec-the-inputoutput-format}
|
||||
|
||||
You can construct your prompts without a template by using the
|
||||
`input_output` format, by setting `type: input_output` in your
|
||||
configuration file like this:
|
||||
|
||||
**config.yml**
|
||||
|
||||
```yaml
|
||||
train_on_inputs: false # Mask segments of your data
|
||||
datasets:
|
||||
- path: output.jsonl
|
||||
type: input_output # use template free prompt construction
|
||||
```
|
||||
|
||||
Unlike `type: completion`, which is also template-free,
|
||||
`type: input_output` allows you to mask segments of your text. More
|
||||
details on how this works are described below.
|
||||
|
||||
## Usage {#sec-usage}
|
||||
|
||||
This is how you can use the `input_output` format:
|
||||
|
||||
### 1. Prepare Data {#sec-1-prepare-data}
|
||||
|
||||
To use the `input_output` format, collect your data in the following
|
||||
format into a jsonl file (below is the first row from the file
|
||||
`output.jsonl` pretty printed):
|
||||
|
||||
```bash
|
||||
$ head -n1 output.jsonl | python -m json.tool
|
||||
```
|
||||
|
||||
:::{.cell-output .cell-output-stdout}
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"label": true,
|
||||
"text": "<s>Hello\n"
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "hi there!. "
|
||||
},
|
||||
{
|
||||
"label": false,
|
||||
"text": "goodbye "
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "farewell</s>"
|
||||
}
|
||||
]
|
||||
}
|
||||
:::
|
||||
|
||||
Set `label:false` when you want to mask a segment of text so that the
|
||||
model isn't trained on it. Some things to keep in mind:
|
||||
|
||||
> [!IMPORTANT]
|
||||
> 1. **EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl
|
||||
concatenates all the segments as-is.** The tokenizer doesn't add
|
||||
anything additional. Notice how I added spaces, newlines, `<s>`
|
||||
(BOS), and `</s>` (EOS) myself.
|
||||
> 2. Make sure you check the materialized output to validate that the
|
||||
prompt is getting assembled how you like.
|
||||
|
||||
### 2. Use `type: input_output` {#sec-2-use-type-inputoutput}
|
||||
|
||||
Let's materialize data with our `output.jsonl` file by setting
|
||||
`type: input_output` in our axolotl config:
|
||||
|
||||
```yaml
|
||||
# training_config.yaml
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
data_seed: 49
|
||||
seed: 49
|
||||
|
||||
datasets:
|
||||
- path: output.jsonl
|
||||
type: input_output
|
||||
val_set_size: 0.1
|
||||
|
||||
sequence_len: 896
|
||||
sample_packing: false
|
||||
|
||||
micro_batch_size: 2
|
||||
gradient_accumulation_steps: 3
|
||||
eval_batch_size: 2
|
||||
num_epochs: 1
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
```
|
||||
|
||||
You can use the following command to materialize your data. The
|
||||
`--debug` flag will print the tokens, along with the labels so you can
|
||||
verify that the correct items are being ignored:
|
||||
|
||||
```bash
|
||||
axolotl preprocess training_config.yaml --debug
|
||||
|
||||
...
|
||||
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
|
||||
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
|
||||
|
||||
```
|
||||
|
||||
The format is `decoded_token`(`label`, `token_id`), for example,
|
||||
`<s>(1, 1)` means that the token is `<s>`, the label is `1` and the
|
||||
token_id is `1`. When the label is `-100` then that token is ignored for
|
||||
training.
|
||||
|
||||
### 3. Check the prompts {#sec-3-check-the-prompts}
|
||||
|
||||
Here is another way to check the materialized output:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
from datasets import load_from_disk
|
||||
import yaml
|
||||
|
||||
directory = !ls last_run_prepared/
|
||||
with open('training_config.yaml', 'r') as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
model_id = cfg['base_model']
|
||||
tok = AutoTokenizer.from_pretrained(model_id)
|
||||
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
|
||||
```
|
||||
|
||||
```python
|
||||
>>> row = ds[0]
|
||||
>>> print(tok.decode(row['input_ids']))
|
||||
<s> Hello
|
||||
hi there!. goodbye farewell</s>
|
||||
```
|
||||
|
||||
We can check that the right tokens are ignored by comparing the labels
|
||||
to each token:
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in
|
||||
zip(row['input_ids'], row['labels'])])
|
||||
```
|
||||
|
||||
| token | label | id |
|
||||
|-------|-------|-------|
|
||||
| 0 | \<s\> | 1 |
|
||||
| 1 | Hello | 22557 |
|
||||
| 2 | \\n | 13 |
|
||||
| 3 | hi | 12014 |
|
||||
| 4 | there | 736 |
|
||||
| 5 | ! | 28808 |
|
||||
| 6 | . | 28723 |
|
||||
| 7 | | 28705 |
|
||||
| 8 | good | -100 |
|
||||
| 9 | bye | -100 |
|
||||
| 10 | | -100 |
|
||||
| 11 | fare | 19111 |
|
||||
| 12 | well | 5458 |
|
||||
| 13 | \</s\>| 2 |
|
||||
|
||||
|
||||
|
||||
If we look at the input data, the above table seems correct! (The jsonl
|
||||
version is repeated below for reference):
|
||||
|
||||
|
||||
```bash
|
||||
$ head -n1 output.jsonl | python -m json.tool
|
||||
```
|
||||
|
||||
:::{.cell-output .cell-output-stdout}
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"label": true,
|
||||
"text": "<s>Hello\n"
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "hi there!. "
|
||||
},
|
||||
{
|
||||
"label": false,
|
||||
"text": "goodbye "
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "farewell</s>"
|
||||
}
|
||||
]
|
||||
}
|
||||
:::
|
||||
See [these docs](../input_output.qmd).
|
||||
|
||||
@@ -3,11 +3,8 @@ title: Dataset Preprocessing
|
||||
description: How datasets are processed
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
||||
the [dataset format](docs/dataset-formats) and prompt strategies to:
|
||||
|
||||
the (dataset format)[../dataset-formats/] and prompt strategies to:
|
||||
- parse the dataset based on the *dataset format*
|
||||
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
||||
- tokenize the dataset based on the configured model & tokenizer
|
||||
@@ -15,12 +12,10 @@ the [dataset format](docs/dataset-formats) and prompt strategies to:
|
||||
|
||||
The processing of the datasets can happen one of two ways:
|
||||
|
||||
1. Before kicking off training by calling `axolotl preprocess config.yaml --debug`
|
||||
1. Before kicking off training by calling `python -m axolotl.cli.preprocess /path/to/your.yaml --debug`
|
||||
2. When training is started
|
||||
|
||||
### What are the benefits of pre-processing?
|
||||
|
||||
When training interactively or for sweeps
|
||||
What are the benefits of pre-processing? When training interactively or for sweeps
|
||||
(e.g. you are restarting the trainer often), processing the datasets can oftentimes be frustratingly
|
||||
slow. Pre-processing will cache the tokenized/formatted datasets according to a hash of dependent
|
||||
training parameters so that it will intelligently pull from its cache when possible.
|
||||
@@ -33,12 +28,8 @@ default path of `./last_run_prepared/`, but will ignore anything already cached
|
||||
setting `dataset_prepared_path: ./last_run_prepared`, the trainer will use whatever pre-processed
|
||||
data is in the cache.
|
||||
|
||||
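In config terms, the caching described above hinges on a single key (the value mirrors the default path mentioned earlier):

```yaml
# re-use the tokenized/formatted dataset across runs instead of re-processing it
dataset_prepared_path: ./last_run_prepared
```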
### What are the edge cases?
|
||||
|
||||
Let's say you are writing a custom prompt strategy or using a user-defined
|
||||
What are the edge cases? Let's say you are writing a custom prompt strategy or using a user-defined
|
||||
prompt template. Because the trainer cannot readily detect these changes, we cannot change the
|
||||
calculated hash value for the pre-processed dataset.
|
||||
|
||||
If you have `dataset_prepared_path: ...` set
|
||||
calculated hash value for the pre-processed dataset. If you have `dataset_prepared_path: ...` set
|
||||
and change your prompt templating logic, it may not pick up the changes you made and you will be
|
||||
training over the old prompt.
|
||||
|
||||
@@ -31,13 +31,11 @@ While debugging it's helpful to simplify your test scenario as much as possible.
|
||||
- Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.
|
||||
- Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`.
|
||||
2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config):
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
dataset:
|
||||
...
|
||||
shards: 20
|
||||
```
|
||||
|
||||
3. **Use a small model**: A good example of a small model is [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0).
|
||||
4. **Minimize iteration time**: Make sure the training loop finishes as fast as possible, with these settings.
|
||||
- `micro_batch_size: 1`
|
||||
@@ -87,7 +85,7 @@ The easiest way to get started is to modify the [.vscode/launch.json](../.vscode
|
||||
|
||||
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
|
||||
|
||||
```json
|
||||
```jsonc
|
||||
// .vscode/launch.json
|
||||
{
|
||||
"version": "0.2.0",
|
||||
@@ -134,7 +132,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
|
||||
|
||||
Below is the [./vscode/tasks.json](../.vscode/tasks.json) file that defines the `cleanup-for-dataprep` task. This task is run before each debugging session when you use the above configuration. Note how there are two tasks that delete the two folders mentioned above. The third task `cleanup-for-dataprep` is a composite task that combines the two tasks. A composite task is necessary because VSCode does not allow you to specify multiple tasks in the `preLaunchTask` argument of the `launch.json` file.
|
||||
|
||||
```json
|
||||
```jsonc
|
||||
// .vscode/tasks.json
|
||||
// this file is used by launch.json
|
||||
{
|
||||
|
||||
23 docs/faq.qmd
@@ -3,7 +3,6 @@ title: FAQ
|
||||
description: Frequently asked questions
|
||||
---
|
||||
|
||||
### General
|
||||
|
||||
**Q: The trainer stopped and hasn't progressed in several minutes.**
|
||||
|
||||
@@ -25,28 +24,6 @@ description: Frequently asked questions
|
||||
|
||||
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
|
||||
|
||||
### Chat templates
|
||||
|
||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||
|
||||
> A: This means that the property mapping for the stated attribute does not exist when building `chat_template` prompt. For example, if `no attribute 'content'`, please check you have added the correct mapping for `content` under `message_property_mappings`.
|
||||
|
||||
**Q: `Empty template generated for turn ___`**
|
||||
|
||||
> A: The `content` is empty for that turn.
|
||||
|
||||
**Q: `Could not find content start/end boundary for turn __`**
|
||||
|
||||
> A: The specific turn's start/end could not be detected. Please ensure you have set the `eos_token` to match your `chat_template`. Otherwise, this could be a `chat_template` that doesn't use proper boundaries for each turn (such as system turns). In rare cases, make sure your content is not `[[dummy_message]]`; please let us know if this happens.
|
||||
|
||||
**Q: `Content end boundary is before start boundary for turn ___`**
|
||||
|
||||
> A: This is an edge case which should not occur. Please create an Issue if this happens.
|
||||
|
||||
**Q: `Content end boundary is the same as start boundary for turn ___. This is likely an empty turn.`**
|
||||
|
||||
> A: This is likely an empty turn.
|
||||
|
||||
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
|
||||
|
||||
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
|
||||
|
||||
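For example, for a ChatML-style template (the same snippet used in the conversation-format docs):

```yaml
special_tokens:
  eos_token: <|im_end|>
```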
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "Quickstart"
|
||||
title: "Getting Started with Axolotl"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
@@ -17,12 +17,12 @@ Let's start by fine-tuning a small language model using LoRA. This example uses
|
||||
Assuming `axolotl` is installed (if not, see our [Installation Guide](installation.qmd))
|
||||
|
||||
1. Download example configs:
|
||||
```bash
|
||||
```shell
|
||||
axolotl fetch examples
|
||||
```
|
||||
|
||||
2. Run the training:
|
||||
```bash
|
||||
```shell
|
||||
axolotl train examples/llama-3/lora-1b.yml
|
||||
```
|
||||
|
||||
@@ -108,7 +108,7 @@ Please consult the supported [Dataset Formats](dataset-formats/) for more detail
|
||||
|
||||
3. Run the training:
|
||||
|
||||
```bash
|
||||
```shell
|
||||
axolotl train my_training.yml
|
||||
```
|
||||
|
||||
@@ -118,7 +118,7 @@ axolotl train my_training.yml
|
||||
|
||||
After training, test your model:
|
||||
|
||||
```bash
|
||||
```shell
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
|
||||
```
|
||||
|
||||
@@ -126,7 +126,7 @@ axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
|
||||
|
||||
For large datasets, preprocess first:
|
||||
|
||||
```bash
|
||||
```shell
|
||||
axolotl preprocess my_training.yml
|
||||
```
|
||||
|
||||
@@ -134,7 +134,7 @@ axolotl preprocess my_training.yml
|
||||
|
||||
Launch a Gradio interface:
|
||||
|
||||
```bash
|
||||
```shell
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
|
||||
```
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
---
|
||||
title: "Inference"
|
||||
title: "Inference Guide"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
toc-depth: 3
|
||||
number-sections: true
|
||||
code-tools: true
|
||||
execute:
|
||||
enabled: false
|
||||
---
|
||||
|
||||
@@ -3,4 +3,263 @@ title: Template-free prompt construction
|
||||
description: "Template-free prompt construction with the `input_output` format"
|
||||
---
|
||||
|
||||
The documentation has moved [here](dataset-formats/template_free.qmd).
|
||||
<!-- TOC -->
|
||||
|
||||
- [Background](#background)
|
||||
- [Masking Inputs](#masking-inputs)
|
||||
- [You may not want prompt templates](#you-may-not-want-prompt-templates)
|
||||
- [The `input_output` format](#the-input_output-format)
|
||||
- [Usage](#usage)
|
||||
- [1. Prepare Data](#1-prepare-data)
|
||||
- [2. Use `type: input_output`](#2-use-type-input_output)
|
||||
- [3. Check the prompts](#3-check-the-prompts)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
<a id="markdown-background" name="background"></a>
|
||||
|
||||
## Background
|
||||
|
||||
<a id="markdown-masking-inputs" name="masking-inputs"></a>
|
||||
|
||||
### Masking Inputs
|
||||
|
||||
One of the most popular features of
|
||||
[axolotl](https://github.com/axolotl-ai-cloud/axolotl) is
|
||||
setting the following configuration value:
|
||||
|
||||
|
||||
```yaml
|
||||
train_on_inputs: false
|
||||
```
|
||||
|
||||
If you declare a [dataset formats](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#dataset)
|
||||
such as `alpaca` or `chatml`, axolotl knows what is an input
|
||||
(i.e. human) vs. an output (i.e. the assistant) and masks the input
|
||||
labels so that your model can focus on predicting the outputs only.
|
||||
|
||||
<a id="markdown-you-may-not-want-prompt-templates" name="you-may-not-want-prompt-templates"></a>
|
||||
|
||||
### You may not want prompt templates
|
||||
|
||||
However, there are many situations where you don't want to use one of
|
||||
these formats or templates. This is because they can:
|
||||
|
||||
- Add unnecessary boilerplate to your prompts.
|
||||
- Create artifacts like special delimiters `<|im_start|>` that can
|
||||
quickly become footguns if you don't include them correctly at
|
||||
inference time.
|
||||
- Enforce a *chat* interface when you do not want one. Sometimes you
|
||||
just want to fine-tune a model to a very specific task and do NOT
|
||||
want multi-turn conversations, roles, etc.
|
||||
- Limit you to only certain roles that the template allows.
|
||||
|
||||
<a id="markdown-the-inputoutput-format" name="the-inputoutput-format"></a>
|
||||
|
||||
### The `input_output` format
|
||||
|
||||
You can construct your prompts without a template by using the
|
||||
`input_output` format, by setting `type: input_output` in your
|
||||
configuration file like this:
|
||||
|
||||
**config.yml**
|
||||
|
||||
```yaml
|
||||
train_on_inputs: false # Mask segments of your data
|
||||
datasets:
|
||||
- path: output.jsonl
|
||||
type: input_output # use template free prompt construction
|
||||
```
|
||||
|
||||
Unlike `type: completion`, which is also template-free,
|
||||
`type: input_output` allows you to mask segments of your text. More
|
||||
details on how this works are described below.
|
||||
|
||||
<a id="markdown-usage" name="usage"></a>
|
||||
|
||||
## Usage
|
||||
|
||||
This is how you can use the `input_output` format:
|
||||
|
||||
<a id="markdown-1-prepare-data" name="1-prepare-data"></a>
|
||||
|
||||
### 1. Prepare Data
|
||||
|
||||
To use the `input_output` format, collect your data in the following
|
||||
format into a jsonl file (below is the first row from the file
|
||||
`output.jsonl` pretty printed):
|
||||
|
||||
```bash
|
||||
$ head -n1 output.jsonl | python -m json.tool
|
||||
```
|
||||
|
||||
:::{.cell-output .cell-output-stdout}
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"label": true,
|
||||
"text": "<s>Hello\n"
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "hi there!. "
|
||||
},
|
||||
{
|
||||
"label": false,
|
||||
"text": "goodbye "
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "farewell</s>"
|
||||
}
|
||||
]
|
||||
}
|
||||
:::
|
||||
|
||||
Set `label:false` when you want to mask a segment of text so that the
|
||||
model isn't trained on it. Some things to keep in mind:
|
||||
|
||||
> [!IMPORTANT]
|
||||
> 1. **EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl
|
||||
concatenates all the segments as-is.** The tokenizer doesn't add
|
||||
anything additional. Notice how I added spaces, newlines, `<s>`
|
||||
(BOS), and `</s>` (EOS) myself.
|
||||
> 2. Make sure you check the materialized output to validate that the
|
||||
prompt is getting assembled how you like.
|
||||
|
||||
<a id="markdown-2-use-type-inputoutput" name="2-use-type-inputoutput"></a>
|
||||
|
||||
### 2. Use `type: input_output`
|
||||
|
||||
Let's materialize data with our `output.jsonl` file by setting
|
||||
`type: input_output` in our axolotl config:
|
||||
|
||||
```yaml
|
||||
# training_config.yaml
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
data_seed: 49
|
||||
seed: 49
|
||||
|
||||
datasets:
|
||||
- path: output.jsonl
|
||||
type: input_output
|
||||
val_set_size: 0.1
|
||||
|
||||
sequence_len: 896
|
||||
sample_packing: false
|
||||
|
||||
micro_batch_size: 2
|
||||
gradient_accumulation_steps: 3
|
||||
eval_batch_size: 2
|
||||
num_epochs: 1
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
```
|
||||
|
||||
You can use the following command to materialize your data. The
|
||||
`--debug` flag will print the tokens, along with the labels so you can
|
||||
verify that the correct items are being ignored:
|
||||
|
||||
```bash
|
||||
$ python -m axolotl.cli.preprocess training_config.yaml --debug
|
||||
|
||||
...
|
||||
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
|
||||
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
|
||||
|
||||
```
|
||||
|
||||
The format is `decoded_token`(`label`, `token_id`), for example,
|
||||
`<s>(1, 1)` means that the token is `<s>`, the label is `1` and the
|
||||
token_id is `1`. When the label is `-100` then that token is ignored for
|
||||
training.
|
||||
|
||||
<a id="markdown-3-check-the-prompts" name="3-check-the-prompts"></a>
|
||||
|
||||
### 3. Check the prompts
|
||||
|
||||
Here is another way to check the materialized output:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
from datasets import load_from_disk
|
||||
import yaml
|
||||
|
||||
directory = !ls last_run_prepared/
|
||||
with open('training_config.yaml', 'r') as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
model_id = cfg['base_model']
|
||||
tok = AutoTokenizer.from_pretrained(model_id)
|
||||
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
|
||||
```
|
||||
|
||||
```python
|
||||
>>> row = ds[0]
|
||||
>>> print(tok.decode(row['input_ids']))
|
||||
<s> Hello
|
||||
hi there!. goodbye farewell</s>
|
||||
```
|
||||
|
||||
We can check that the right tokens are ignored by comparing the labels
|
||||
to each token:
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in
|
||||
zip(row['input_ids'], row['labels'])])
|
||||
```
|
||||
|
||||
| token | label | id |
|
||||
|-------|-------|-------|
|
||||
| 0 | \<s\> | 1 |
|
||||
| 1 | Hello | 22557 |
|
||||
| 2 | \\n | 13 |
|
||||
| 3 | hi | 12014 |
|
||||
| 4 | there | 736 |
|
||||
| 5 | ! | 28808 |
|
||||
| 6 | . | 28723 |
|
||||
| 7 | | 28705 |
|
||||
| 8 | good | -100 |
|
||||
| 9 | bye | -100 |
|
||||
| 10 | | -100 |
|
||||
| 11 | fare | 19111 |
|
||||
| 12 | well | 5458 |
|
||||
| 13 | \</s\>| 2 |
|
||||
|
||||
|
||||
|
||||
If we look at the input data, the above table seems correct! (The jsonl
|
||||
version is repeated below for reference):
|
||||
|
||||
|
||||
```bash
|
||||
$ head -n1 output.jsonl | python -m json.tool
|
||||
```
|
||||
|
||||
:::{.cell-output .cell-output-stdout}
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"label": true,
|
||||
"text": "<s>Hello\n"
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "hi there!. "
|
||||
},
|
||||
{
|
||||
"label": false,
|
||||
"text": "goodbye "
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "farewell</s>"
|
||||
}
|
||||
]
|
||||
}
|
||||
:::
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
---
|
||||
title: "Installation"
|
||||
title: "Installation Guide"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
toc-depth: 3
|
||||
number-sections: true
|
||||
code-tools: true
|
||||
execute:
|
||||
enabled: false
|
||||
---
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
---
|
||||
title: "LoRA Optimizations"
|
||||
description: "Custom autograd functions and Triton kernels in Axolotl for optimized LoRA fine-tuning"
|
||||
description: "Custom autograd functions and Triton kernels in Axolotl for optimized
|
||||
LoRA fine-tuning"
|
||||
---
|
||||
|
||||
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
|
||||
|
||||
@@ -19,5 +19,4 @@ Current support:
|
||||
- [ ] DeepSpeed
|
||||
|
||||
Untested:
|
||||
|
||||
- FSDP
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "Multi-GPU"
|
||||
title: "Multi-GPU Training Guide"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
@@ -35,11 +35,7 @@ deepspeed: deepspeed_configs/zero1.json
|
||||
### Usage {#sec-deepspeed-usage}
|
||||
|
||||
```{.bash}
|
||||
# Passing arg via config
|
||||
axolotl train config.yml
|
||||
|
||||
# Passing arg via cli
|
||||
axolotl train config.yml --deepspeed deepspeed_configs/zero1.json
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
|
||||
```
|
||||
|
||||
### ZeRO Stages {#sec-zero-stages}
|
||||
@@ -74,7 +70,25 @@ For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
|
||||
|
||||
### Liger Kernel Integration {#sec-liger}
|
||||
|
||||
Please see [docs](custom_integrations.qmd#liger) for more info.
|
||||
::: {.callout-note}
|
||||
Liger Kernel provides efficient Triton kernels for LLM training, offering:
|
||||
|
||||
- 20% increase in multi-GPU training throughput
|
||||
- 60% reduction in memory usage
|
||||
- Compatibility with both FSDP and DeepSpeed
|
||||
:::
|
||||
|
||||
Configuration:
|
||||
|
||||
```{.yaml}
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
```
|
||||
|
||||
## Troubleshooting {#sec-troubleshooting}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ You will also need to have the same configuration file for your model on each ma
|
||||
Make sure the main machine is reachable by other machines.
|
||||
:::
|
||||
|
||||
## Accelerate
|
||||
# Accelerate
|
||||
|
||||
You will need to create a configuration for accelerate, either by running `accelerate config` and following the instructions, or by using one of the presets below:
|
||||
|
||||
@@ -51,17 +51,17 @@ fsdp_config:
|
||||
|
||||
All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine.
|
||||
|
||||
## Raytrain
|
||||
# Raytrain
|
||||
|
||||
Please see ray train doc [here](ray-integration.qmd).
|
||||
|
||||
## Torchrun
|
||||
# Torchrun
|
||||
|
||||
If you are using InfiniBand, we recommend torchrun to utilize the full bandwidth.
|
||||
|
||||
Set the following env (change buffersize/socketname depending on your system):
|
||||
|
||||
```bash
|
||||
```yaml
|
||||
export NCCL_IB_DISABLE=0
|
||||
export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
|
||||
export NCCL_BUFFSIZE=2097152
|
||||
|
||||
@@ -13,13 +13,13 @@ Often, this timeout will happen after 30 minutes (the default setting) and is ac
|
||||
|
||||
Forcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. To verify that your configuration is leveraging NVLink run the following command:
|
||||
|
||||
```bash
|
||||
```shell
|
||||
nvidia-smi nvlink --status
|
||||
```
|
||||
|
||||
To force NCCL to use NVLink, simply set this in the environment:
|
||||
|
||||
```bash
|
||||
```shell
|
||||
export NCCL_P2P_LEVEL=NVL
|
||||
```
|
||||
|
||||
@@ -33,13 +33,13 @@ If NVLink is not available in your environment there are other options for ``NCC
|
||||
|
||||
To validate that acceptable data transfer speeds exist for your training job, running [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) can help pinpoint bottlenecks, for example:
|
||||
|
||||
```bash
|
||||
```shell
|
||||
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
|
||||
```
|
||||
|
||||
It can be useful when debugging NCCL communication timeouts to activate additional logging in both PyTorch and NCCL:
|
||||
|
||||
```bash
|
||||
```shell
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_DEBUG_SUBSYS=ALL
|
||||
export TORCH_DISTRIBUTED_DEBUG=INFO
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: Ray Train
|
||||
title: Ray Train integration
|
||||
description: How to use Axolotl with Ray Train
|
||||
---
|
||||
|
||||
@@ -9,7 +9,7 @@ With the `--use-ray` CLI flag, Axolotl will use Ray Train's [`TorchTrainer`](htt
|
||||
|
||||
## Ray cluster setup
|
||||
|
||||
A prerequisite for using the Ray Train integration is to set up a Ray cluster on your desired node(s). For a detailed guide on how you can get started with ray clusters, check the official Ray docs [here](https://docs.ray.io/en/latest/cluster/getting-started.html).
A prerequisite for using the Ray Train integration is to set up a Ray cluster on your desired node(s). For a detailed guide on how you can get started with ray clusters, check the official Ray docs here: https://docs.ray.io/en/latest/cluster/getting-started.html
|
||||
|
||||
Every Ray cluster has one _head_ node and a set of worker nodes. The head node is just like any other worker node, but it also runs certain special processes related to scheduling and orchestration. Ray-enabled scripts are run on the head node and depending on the resources (number of CPUs, GPUs, etc) they request, will be scheduled to run certain tasks on the worker nodes. For more on key concepts behind a Ray cluster, you can refer this [doc](https://docs.ray.io/en/latest/cluster/key-concepts.html#cluster-key-concepts).
|
||||
|
||||
@@ -58,11 +58,13 @@ You can find an example configuration at `configs/llama-3/lora-1b-ray.yaml`.
|
||||
The key parameters to note here are:
|
||||
|
||||
```yaml
|
||||
...
|
||||
use_ray: true
|
||||
ray_num_workers: 4
|
||||
# optional
|
||||
resources_per_worker:
|
||||
GPU: 1
|
||||
...
|
||||
```
|
||||
|
||||
- `use_ray`: This is the flag that enables the Ray Train integration. You can either use the corresponding `--use-ray` flag in the CLI or set `use_ray` in the config file.
|
||||
|
||||
116 docs/rlhf.qmd
@@ -3,22 +3,22 @@ title: "RLHF (Beta)"
|
||||
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
|
||||
back-to-top-navigation: true
|
||||
toc: true
|
||||
toc-depth: 4
|
||||
toc-depth: 3
|
||||
---
|
||||
|
||||
## Overview
|
||||
# Overview
|
||||
|
||||
Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
|
||||
feedback. Various methods include, but are not limited to:
|
||||
|
||||
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
|
||||
- [Direct Preference Optimization (DPO)](#dpo)
|
||||
- [Identity Preference Optimization (IPO)](#ipo)
|
||||
- [Kahneman-Tversky Optimization (KTO)](#kto)
|
||||
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
|
||||
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
|
||||
|
||||
|
||||
## RLHF using Axolotl
|
||||
# RLHF using Axolotl
|
||||
|
||||
::: {.callout-important}
|
||||
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
|
||||
@@ -30,7 +30,7 @@ We rely on the [TRL](https://github.com/huggingface/trl) library for implementat
|
||||
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}` where `{method}` is one of our supported methods. The `type: ` can be retrieved from `{method}.{function_name}`.
|
||||
:::
|
||||
|
||||
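As a rough sketch of how a method, its dataset and its `type:` fit together (the dataset path is hypothetical; `chatml.intel` is one of the DPO types listed below):

```yaml
rl: dpo
datasets:
  - path: my-org/my-preference-pairs   # hypothetical preference dataset
    type: chatml.intel                 # prompt strategy name, in the form described above
    split: train
```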
### DPO
|
||||
## DPO
|
||||
|
||||
Example config:
|
||||
|
||||
@@ -47,7 +47,7 @@ datasets:
|
||||
|
||||
DPO supports the following types with the following dataset format:
|
||||
|
||||
#### chatml.argilla
|
||||
### chatml.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -58,7 +58,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### chatml.argilla_chat
|
||||
### chatml.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -73,7 +73,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### chatml.icr
|
||||
### chatml.icr
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -84,7 +84,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### chatml.intel
|
||||
### chatml.intel
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -95,7 +95,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### chatml.prompt_pairs
|
||||
### chatml.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -106,7 +106,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### chatml.ultra
|
||||
### chatml.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -123,7 +123,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### llama3.argilla
|
||||
### llama3.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -134,7 +134,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### llama3.argilla_chat
|
||||
### llama3.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -149,7 +149,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### llama3.icr
|
||||
### llama3.icr
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -160,7 +160,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### llama3.intel
|
||||
### llama3.intel
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -171,7 +171,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### llama3.prompt_pairs
|
||||
### llama3.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -182,7 +182,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### llama3.ultra
|
||||
### llama3.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -199,7 +199,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### zephyr.nectar
|
||||
### zephyr.nectar
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -218,7 +218,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### chat_template.default
|
||||
### chat_template.default
|
||||
|
||||
```yaml
|
||||
rl: dpo
|
||||
@@ -264,7 +264,7 @@ Sample input format:
|
||||
}
|
||||
```
|
||||
|
||||
#### user_defined.default
|
||||
### user_defined.default
|
||||
|
||||
For custom behaviors,
|
||||
|
||||
@@ -295,7 +295,7 @@ The input format is a simple JSON input with customizable fields based on the ab
|
||||
}
|
||||
```
|
||||
|
||||
### IPO
## IPO

As IPO is just DPO with a different loss function, all supported options for DPO work here.

@@ -303,7 +303,7 @@ As IPO is just DPO with a different loss function, all supported options for DPO
rl: ipo
```

### ORPO
## ORPO

Paper: https://arxiv.org/abs/2403.07691
|
||||
|
||||
@@ -320,7 +320,7 @@ datasets:
|
||||
|
||||
ORPO supports the following types with the following dataset format:
|
||||
|
||||
#### chat_template.argilla
|
||||
### chat_template.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -339,7 +339,7 @@ ORPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### KTO
|
||||
## KTO
|
||||
|
||||
```yaml
|
||||
rl: kto
|
||||
@@ -360,7 +360,7 @@ gradient_checkpointing_kwargs:
|
||||
|
||||
KTO supports the following types with the following dataset format:
|
||||
|
||||
#### chatml.argilla
|
||||
### chatml.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -370,7 +370,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### chatml.argilla_chat
|
||||
### chatml.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -383,7 +383,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### chatml.intel
|
||||
### chatml.intel
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -393,7 +393,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### chatml.prompt_pairs
|
||||
### chatml.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -403,7 +403,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### chatml.ultra
|
||||
### chatml.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -413,7 +413,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### llama3.argilla
|
||||
### llama3.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -423,7 +423,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### llama3.argilla_chat
|
||||
### llama3.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -434,7 +434,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### llama3.intel
|
||||
### llama3.intel
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -444,7 +444,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### llama3.prompt_pairs
|
||||
### llama3.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -454,7 +454,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### llama3.ultra
|
||||
### llama3.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -464,7 +464,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
#### user_defined.default
|
||||
### user_defined.default
|
||||
|
||||
For custom behaviors,
|
||||
|
||||
@@ -494,49 +494,7 @@ The input format is a simple JSON input with customizable fields based on the ab
|
||||
}
|
||||
```
|
||||
|
||||
### GRPO

GRPO uses custom reward functions and transformations. Please have them ready locally.

For example, to load OpenAI's GSM8K dataset and use a random reward for completions:

```python
# rewards.py
import random


def rand_reward_func(completions, **kwargs) -> list[float]:
    return [random.uniform(0, 1) for _ in completions]


def oai_gsm8k_transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [{"role": "user", "content": example["question"]},],
            "answer": label,
        }
    return transform_fn, {"remove_columns": ["question"]}
```

```yaml
rl: grpo

trl:
  beta: 0.001
  max_completion_length: 256
  use_vllm: True
  vllm_device: auto
  vllm_gpu_memory_utilization: 0.15
  num_generations: 4
  reward_funcs: ["rewards.rand_reward_func"]  # format: '{file_name}.{fn_name}'
datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform  # format: '{file_name}.{fn_name}'
```

To see other examples of custom reward functions, please see the [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).

### Using local dataset files
## Using local dataset files
```yaml
datasets:
@@ -547,7 +505,7 @@ datasets:
    type: chatml.intel
```

### TRL auto-unwrapping for PEFT
## TRL auto-unwrapping for PEFT

TRL supports auto-unwrapping PEFT models for RL training paradigms which rely on a reference model. This significantly reduces memory pressure, as an additional reference model does not need to be loaded, and reference model log-probabilities can be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:
|
||||
|
||||
|
||||
@@ -1,59 +0,0 @@
---
title: Telemetry
description: A description of the opt-out telemetry implementation in Axolotl.
---

# Telemetry in Axolotl

Axolotl implements anonymous telemetry to help maintainers understand how the library
is used and where users encounter issues. This data helps prioritize features, optimize
performance, and fix bugs.

## Data Collection

We collect:

- System info: OS, Python version, Axolotl version, PyTorch version, Transformers
version, etc.
- Hardware info: CPU count, memory, GPU count and models
- Runtime metrics: Training progress, memory usage, timing information
- Usage patterns: Models (from a whitelist) and configurations used
- Error tracking: Stack traces and error messages (sanitized to remove personal
information)

No personally identifiable information (PII) is collected.

## Implementation

Telemetry is implemented using PostHog and consists of:

- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the
telemetry system and provides methods for tracking events.
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and
sends sanitized stack traces.
- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks
runtime metrics during training.
- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends
runtime metrics telemetry.

The telemetry system will block training startup for 15 seconds to ensure users are
aware of data collection, unless telemetry is explicitly enabled or disabled.

## Opt-Out Mechanism

Telemetry is **enabled by default** on an opt-out basis. To disable it, set either:

- `AXOLOTL_DO_NOT_TRACK=1` (Axolotl-specific)
- `DO_NOT_TRACK=1` (Global standard; see https://consoledonottrack.com/)

To acknowledge and explicitly enable telemetry (and remove the warning message), set:
`AXOLOTL_DO_NOT_TRACK=0`.
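For example, one way to apply the opt-out for a single Python process (a minimal sketch; exporting the variable in your shell before running the CLI is equivalent) is:

```python
import os

# Opt out of Axolotl telemetry for this process. Set the variable before
# importing axolotl so the telemetry manager can pick it up when it initializes.
os.environ["AXOLOTL_DO_NOT_TRACK"] = "1"

import axolotl  # noqa: E402
```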
## Privacy

- All path-like config information is automatically redacted from telemetry data
- Model information is only collected for whitelisted organizations
  - See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
- Each run generates a unique anonymous ID
  - This allows us to link different telemetry events in a single training run
- Telemetry is only sent from the main process to avoid duplicate events
@@ -3,12 +3,6 @@ title: "PyTorch ao"
description: "Custom data types and layouts for training and inference"
---

To use experimental optimizers (`AdamWFp8`, `AdamW4bit`, `AdamW8bit`) from PyTorch AO, please install the package as shown below.

::: {.callout-tip}
Some experimental optimizers are already present in regular PyTorch, so please check whether you actually need this package!
:::

### Installation

Stable Release from the PyTorch index
@@ -8,12 +8,6 @@ description: "Hyper-optimized QLoRA finetuning for single GPUs"
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and reduce VRAM usage over
standard industry baselines.

::: {.callout-important}
Due to breaking changes in transformers `v4.48.0`, users will need to downgrade to `<=v4.47.1` to use this patch.

This will later be deprecated in favor of [LoRA Optimizations](lora_optims.qmd).
:::


### Installation

@@ -23,7 +17,7 @@ The following will install the correct unsloth and extras from source.
python scripts/unsloth_install.py | sh
```

### Usage
### Using unsloth w Axolotl

Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
---
|
||||
# toc-location: right-body
|
||||
# toc-title: Table Of Contents
|
||||
# toc-expand: 2
|
||||
toc-location: right-body
|
||||
toc-title: Table Of Contents
|
||||
toc-expand: 2
|
||||
---
|
||||
|
||||
```{python}
|
||||
|
||||
@@ -7,7 +7,7 @@ mamba-ssm==1.2.0.post1
|
||||
flash-attn==2.7.4.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.5.3
|
||||
liger-kernel==0.5.2
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
@@ -63,6 +63,3 @@ torchao==0.7.0
|
||||
schedulefree==1.3.0
|
||||
|
||||
axolotl-contribs-lgpl==0.0.3
|
||||
|
||||
# telemetry
|
||||
posthog>=3.15.1
|
||||
|
||||
@@ -258,21 +258,25 @@ class ModalCloud(Cloud):
|
||||
|
||||
|
||||
def _preprocess(config_yaml: str, volumes=None):
|
||||
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||
Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/mounts"
|
||||
run_folder = "/workspace/artifacts/axolotl"
|
||||
run_cmd(
|
||||
"axolotl preprocess /workspace/mounts/config.yaml --dataset-processes=8",
|
||||
"axolotl preprocess /workspace/artifacts/axolotl/config.yaml --dataset-processes=8",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
|
||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/mounts"
|
||||
run_folder = "/workspace/artifacts/axolotl"
|
||||
if accelerate:
|
||||
accelerate_args = "--accelerate"
|
||||
else:
|
||||
@@ -281,18 +285,20 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||
if num_processes := kwargs.pop("num_processes", None):
|
||||
num_processes_args = f"--num-processes {num_processes}"
|
||||
run_cmd(
|
||||
f"axolotl train {accelerate_args} {num_processes_args} /workspace/mounts/config.yaml",
|
||||
f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
|
||||
def _lm_eval(config_yaml: str, volumes=None):
|
||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/mounts"
|
||||
run_folder = "/workspace/artifacts/axolotl"
|
||||
run_cmd(
|
||||
"axolotl lm-eval /workspace/mounts/config.yaml",
|
||||
"axolotl lm-eval /workspace/artifacts/axolotl/config.yaml",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
@@ -14,8 +14,6 @@ import yaml
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
from axolotl.utils.config import (
|
||||
normalize_cfg_datasets,
|
||||
@@ -29,8 +27,6 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||
|
||||
|
||||
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
||||
"""
|
||||
@@ -156,7 +152,6 @@ def prepare_plugins(cfg: DictDefault):
|
||||
plugin_manager.register(plugin_name)
|
||||
|
||||
|
||||
@send_errors
|
||||
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
|
||||
"""
|
||||
Loads the `axolotl` configuration stored at `config`, validates it, and performs
|
||||
@@ -176,7 +171,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
|
||||
# Load the config from the yaml file
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
TELEMETRY_MANAGER.send_event(event_type="config-loaded", properties=cfg)
|
||||
|
||||
# If there are any options passed in the cli, if it is something that seems valid
|
||||
# from the yaml, then overwrite the value
|
||||
@@ -220,6 +214,4 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
|
||||
setup_mlflow_env_vars(cfg)
|
||||
setup_comet_env_vars(cfg)
|
||||
|
||||
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
|
||||
|
||||
return cfg
|
||||
|
||||
@@ -17,7 +17,6 @@ from axolotl.cli.args import InferenceCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.chat_templates import (
|
||||
get_chat_template,
|
||||
get_chat_template_from_config,
|
||||
@@ -43,7 +42,6 @@ def get_multi_line_input() -> str:
|
||||
return instruction
|
||||
|
||||
|
||||
@send_errors
|
||||
def do_inference(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
@@ -137,7 +135,6 @@ def do_inference(
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
@send_errors
|
||||
def do_inference_gradio(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
|
||||
@@ -12,13 +12,11 @@ from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@send_errors
|
||||
def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||
"""
|
||||
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
|
||||
|
||||
@@ -27,7 +27,6 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@@ -121,7 +120,6 @@ def _distributed_checkpoint_to_merged_weights(
|
||||
return save_path_
|
||||
|
||||
|
||||
@send_errors
|
||||
def merge_fsdp_weights(
|
||||
checkpoint_dir: str,
|
||||
output_path: str,
|
||||
|
||||
@@ -18,14 +18,12 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.trainer import disable_datasets_caching
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@send_errors
|
||||
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
||||
"""
|
||||
Preprocesses dataset specified in axolotl config.
|
||||
|
||||
@@ -10,7 +10,6 @@ from datasets import Dataset
|
||||
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -45,7 +44,6 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
||||
)
|
||||
|
||||
|
||||
@send_errors
|
||||
def load_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
@@ -105,7 +103,6 @@ def load_datasets(
|
||||
)
|
||||
|
||||
|
||||
@send_errors
|
||||
def load_preference_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
|
||||
@@ -61,8 +61,6 @@ from axolotl.core.training_args import (
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||
from axolotl.telemetry.callbacks import TelemetryCallback
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
@@ -178,8 +176,10 @@ class TrainerBuilderBase(abc.ABC):
|
||||
SaveAxolotlConfigtoMlflowCallback,
|
||||
)
|
||||
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||
callbacks.extend(
|
||||
[
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
|
||||
]
|
||||
)
|
||||
if self.cfg.use_comet and is_comet_available():
|
||||
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
||||
@@ -188,10 +188,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
telemetry_manager = TelemetryManager.get_instance()
|
||||
if telemetry_manager.enabled:
|
||||
callbacks.append(TelemetryCallback())
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
|
||||
@@ -10,7 +10,6 @@ import torch
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -62,7 +61,6 @@ def evaluate_dataset(
|
||||
return metrics
|
||||
|
||||
|
||||
@send_errors
|
||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
||||
"""
|
||||
Evaluate a model on training and validation datasets
|
||||
|
||||
@@ -1,10 +1,6 @@
# Cut Cross Entropy

Cut Cross Entropy reduces VRAM usage by optimizing the cross-entropy operation during loss calculation.

See https://github.com/apple/ml-cross-entropy

## Usage
### Usage

```yaml
plugins:
@@ -12,19 +8,3 @@ plugins:

cut_cross_entropy: true
```

## Citation

```bib
@article{wijmans2024cut,
  author = {Erik Wijmans and
            Brody Huval and
            Alexander Hertzberg and
            Vladlen Koltun and
            Philipp Kr\"ahenb\"uhl},
  title = {Cut Your Losses in Large-Vocabulary Language Models},
  journal = {arXiv},
  year = {2024},
  url = {https://arxiv.org/abs/2411.09009},
}
```
@@ -2,7 +2,7 @@

See https://github.com/ironjr/grokfast

## Usage
### Usage

```yaml
plugins:
@@ -11,14 +11,3 @@ plugins:
grokfast_alpha: 2.0
grokfast_lamb: 0.98
```

## Citation

```bib
@article{lee2024grokfast,
  title={{Grokfast}: Accelerated Grokking by Amplifying Slow Gradients},
  author={Lee, Jaerin and Kang, Bong Gyun and Kim, Kihoon and Lee, Kyoung Mu},
  journal={arXiv preprint arXiv:2405.20233},
  year={2024}
}
```
@@ -1,23 +0,0 @@
# Knowledge Distillation

## Usage

```yaml
plugins:
  - "axolotl.integrations.kd.KDPlugin"

kd_trainer: True
kd_ce_alpha: 0.1
kd_alpha: 0.9
kd_temperature: 1.0

torch_compile: True  # torch>=2.5.1, recommended to reduce vram

datasets:
  - path: ...
    type: "axolotl.integrations.kd.chat_template"
    field_messages: "messages_combined"
    logprobs_field: "llm_text_generation_vllm_logprobs"  # for kd only, field of logprobs
```

An example dataset can be found at [`axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample`](https://huggingface.co/datasets/axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample)
391
src/axolotl/integrations/kd/topk_logprob/bench_kl.py
Normal file
@@ -0,0 +1,391 @@
|
||||
"""
|
||||
Benchmark utilities for the KL divergence Triton kernel.
|
||||
"""
|
||||
import gc
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torch.utils.benchmark import Timer
|
||||
|
||||
from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
|
||||
from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss
|
||||
|
||||
|
||||
# pylint: disable=cell-var-from-loop
|
||||
def benchmark_kl_div_loss_with_backward():
|
||||
# Test configurations
|
||||
batch_sizes = [1, 4]
|
||||
seq_lens = [64, 512, 2048, 4096, 8192]
|
||||
vocab_size = 32000
|
||||
top_k = 64
|
||||
|
||||
# Store results
|
||||
results = []
|
||||
|
||||
# Run benchmarks
|
||||
for batch_size in batch_sizes:
|
||||
for seq_len in seq_lens:
|
||||
# Generate random test data
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create tensors with gradients
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
target_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||
)
|
||||
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
|
||||
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
|
||||
target_mask = torch.randint(
|
||||
0, 2, (batch_size, seq_len, top_k), device="cuda"
|
||||
).float()
|
||||
|
||||
# Clone student_logits for the two implementations
|
||||
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
|
||||
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
|
||||
|
||||
# Define functions for timing that include both forward and backward passes
|
||||
def run_reference():
|
||||
# Forward pass
|
||||
loss_ref = eager_loss(
|
||||
student_logits_ref, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
# Backward pass
|
||||
loss_ref.backward()
|
||||
|
||||
def run_triton():
|
||||
# Forward pass
|
||||
# pylint: disable=duplicate-code
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
)
|
||||
# Backward pass
|
||||
loss_triton.backward()
|
||||
|
||||
# Benchmark reference implementation (forward + backward)
|
||||
t0 = Timer(
|
||||
stmt="run_reference()",
|
||||
globals={
|
||||
"run_reference": run_reference,
|
||||
},
|
||||
)
|
||||
# Reset gradients before timing
|
||||
student_logits_ref.grad = None
|
||||
ref_time = t0.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Benchmark Triton implementation (forward + backward)
|
||||
t1 = Timer(
|
||||
stmt="run_triton()",
|
||||
globals={
|
||||
"run_triton": run_triton,
|
||||
},
|
||||
)
|
||||
# Reset gradients before timing
|
||||
student_logits_triton.grad = None
|
||||
triton_time = t1.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Compute speedup
|
||||
speedup = ref_time / triton_time if triton_time > 0 else float("inf")
|
||||
|
||||
# Store results
|
||||
results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"seq_len": seq_len,
|
||||
"reference_time_ms": ref_time,
|
||||
"triton_time_ms": triton_time,
|
||||
"speedup": speedup,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Batch size: {batch_size}, Seq len: {seq_len}")
|
||||
print(f" Reference time (fwd+bwd): {ref_time:.2f} ms")
|
||||
print(f" Triton time (fwd+bwd): {triton_time:.2f} ms")
|
||||
print(f" Speedup: {speedup:.2f}x")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def benchmark_forward_backward_separately():
|
||||
"""
|
||||
Benchmark forward and backward passes separately to identify where the speedup comes from.
|
||||
"""
|
||||
# Test configurations
|
||||
batch_sizes = [1, 4, 8]
|
||||
seq_lens = [64, 512, 2048]
|
||||
vocab_size = 32000
|
||||
top_k = 64
|
||||
|
||||
# Store results
|
||||
detailed_results = []
|
||||
|
||||
# Run benchmarks
|
||||
for batch_size in batch_sizes:
|
||||
for seq_len in seq_lens:
|
||||
# Generate random test data
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create tensors with gradients
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
target_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||
)
|
||||
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
|
||||
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
|
||||
target_mask = torch.randint(
|
||||
0, 2, (batch_size, seq_len, top_k), device="cuda"
|
||||
).float()
|
||||
|
||||
# Clone student_logits for the two implementations
|
||||
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
|
||||
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
|
||||
|
||||
# Forward-only reference
|
||||
def run_reference_forward():
|
||||
with torch.no_grad():
|
||||
return eager_loss(
|
||||
student_logits_ref,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
)
|
||||
|
||||
# Forward-only triton
|
||||
def run_triton_forward():
|
||||
with torch.no_grad():
|
||||
return triton_loss(
|
||||
student_logits_triton,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
)
|
||||
|
||||
# Benchmark forward pass only
|
||||
|
||||
t0_fwd = Timer(
|
||||
stmt="run_reference_forward()",
|
||||
globals={
|
||||
"run_reference_forward": run_reference_forward,
|
||||
},
|
||||
)
|
||||
ref_fwd_time = t0_fwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
t1_fwd = Timer(
|
||||
stmt="run_triton_forward()",
|
||||
globals={
|
||||
"run_triton_forward": run_triton_forward,
|
||||
},
|
||||
)
|
||||
triton_fwd_time = t1_fwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Pre-compute losses for backward pass benchmarking
|
||||
loss_ref = eager_loss(
|
||||
student_logits_ref, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
|
||||
# Backward-only reference
|
||||
def run_reference_backward():
|
||||
student_logits_ref.grad = None
|
||||
loss_ref.backward()
|
||||
|
||||
# Backward-only triton
|
||||
def run_triton_backward():
|
||||
student_logits_triton.grad = None
|
||||
loss_triton.backward()
|
||||
|
||||
# Benchmark backward pass only
|
||||
t0_bwd = Timer(
|
||||
stmt="run_reference_backward()",
|
||||
globals={
|
||||
"run_reference_backward": run_reference_backward,
|
||||
},
|
||||
)
|
||||
ref_bwd_time = t0_bwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
t1_bwd = Timer(
|
||||
stmt="run_triton_backward()",
|
||||
globals={
|
||||
"run_triton_backward": run_triton_backward,
|
||||
},
|
||||
)
|
||||
triton_bwd_time = t1_bwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Compute speedups
|
||||
fwd_speedup = (
|
||||
ref_fwd_time / triton_fwd_time if triton_fwd_time > 0 else float("inf")
|
||||
)
|
||||
bwd_speedup = (
|
||||
ref_bwd_time / triton_bwd_time if triton_bwd_time > 0 else float("inf")
|
||||
)
|
||||
total_ref_time = ref_fwd_time + ref_bwd_time
|
||||
total_triton_time = triton_fwd_time + triton_bwd_time
|
||||
total_speedup = (
|
||||
total_ref_time / total_triton_time
|
||||
if total_triton_time > 0
|
||||
else float("inf")
|
||||
)
|
||||
|
||||
# Store results
|
||||
detailed_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"seq_len": seq_len,
|
||||
"ref_forward_ms": ref_fwd_time,
|
||||
"triton_forward_ms": triton_fwd_time,
|
||||
"forward_speedup": fwd_speedup,
|
||||
"ref_backward_ms": ref_bwd_time,
|
||||
"triton_backward_ms": triton_bwd_time,
|
||||
"backward_speedup": bwd_speedup,
|
||||
"total_ref_ms": total_ref_time,
|
||||
"total_triton_ms": total_triton_time,
|
||||
"total_speedup": total_speedup,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Batch size: {batch_size}, Seq len: {seq_len}")
|
||||
print(
|
||||
f" Forward: Reference={ref_fwd_time:.2f}ms, Triton={triton_fwd_time:.2f}ms, Speedup={fwd_speedup:.2f}x"
|
||||
)
|
||||
print(
|
||||
f" Backward: Reference={ref_bwd_time:.2f}ms, Triton={triton_bwd_time:.2f}ms, Speedup={bwd_speedup:.2f}x"
|
||||
)
|
||||
print(
|
||||
f" Total: Reference={total_ref_time:.2f}ms, Triton={total_triton_time:.2f}ms, Speedup={total_speedup:.2f}x"
|
||||
)
|
||||
|
||||
return detailed_results
|
||||
|
||||
|
||||
def benchmark_memory_usage_with_backward():
|
||||
# Test configurations
|
||||
batch_sizes = [1, 2]
|
||||
seq_len = 8192
|
||||
vocab_size = 128000
|
||||
top_k = 64
|
||||
|
||||
# Store results
|
||||
mem_results = []
|
||||
|
||||
# Run benchmarks
|
||||
for batch_size in batch_sizes:
|
||||
# Generate random test data
|
||||
torch.manual_seed(42)
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
||||
)
|
||||
target_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||
)
|
||||
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
|
||||
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
|
||||
target_mask = torch.randint(
|
||||
0, 2, (batch_size, seq_len, top_k), device="cuda"
|
||||
).float()
|
||||
|
||||
# Clone student_logits for the implementations
|
||||
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
|
||||
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
|
||||
|
||||
# Measure PyTorch memory usage (forward + backward)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
loss_ref = eager_loss(
|
||||
student_logits_ref, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
loss_ref.backward()
|
||||
torch.cuda.synchronize()
|
||||
pytorch_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
|
||||
|
||||
# Measure Triton memory usage (forward + backward)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
loss_triton.backward()
|
||||
torch.cuda.synchronize()
|
||||
triton_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
|
||||
|
||||
# Measure Triton memory usage with different chunk sizes (forward + backward)
|
||||
for n_chunks in [1, 2, 4, 8]:
|
||||
student_logits_chunk = student_logits.clone().detach().requires_grad_(True)
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
loss_chunk = triton_loss(
|
||||
student_logits_chunk,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
)
|
||||
loss_chunk.backward()
|
||||
torch.cuda.synchronize()
|
||||
chunk_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
|
||||
|
||||
mem_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"implementation": f"Triton (chunks={n_chunks})",
|
||||
"memory_mb": chunk_mem,
|
||||
}
|
||||
)
|
||||
|
||||
# Store results
|
||||
mem_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"implementation": "PyTorch",
|
||||
"memory_mb": pytorch_mem,
|
||||
}
|
||||
)
|
||||
|
||||
mem_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"implementation": "Triton (default)",
|
||||
"memory_mb": triton_mem,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Batch size: {batch_size} (with backward pass)")
|
||||
print(f" PyTorch memory: {pytorch_mem:.2f} MB")
|
||||
print(f" Triton memory: {triton_mem:.2f} MB")
|
||||
print(f" Memory reduction: {(1 - triton_mem/pytorch_mem)*100:.2f}%")
|
||||
|
||||
return mem_results
|
||||
|
||||
|
||||
def main():
|
||||
print("Running benchmarks with forward and backward passes...")
|
||||
benchmark_kl_div_loss_with_backward()
|
||||
clean()
|
||||
|
||||
print("\nRunning detailed forward/backward benchmarks...")
|
||||
# benchmark_forward_backward_separately()
|
||||
# clean()
|
||||
|
||||
print("\nRunning memory usage benchmarks with backward passes...")
|
||||
benchmark_memory_usage_with_backward()
|
||||
clean()
|
||||
|
||||
|
||||
def clean():
|
||||
for _ in range(5):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
750
src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
Normal file
@@ -0,0 +1,750 @@
|
||||
"""
|
||||
Optimized Triton kernel for KL divergence loss between teacher and student models.
|
||||
"""
|
||||
# pylint: disable=invalid-name,unused-argument
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_logsumexp_logprobs_kernel(
|
||||
student_logits_ptr, # Input logits in original dtype
|
||||
student_logprobs_ptr, # Output logprobs (float32)
|
||||
token_ids_ptr, # Token IDs for top-k
|
||||
B,
|
||||
S,
|
||||
V,
|
||||
K, # batch size, seq len, vocab size, top-k
|
||||
temperature,
|
||||
stride_l_b,
|
||||
stride_l_s,
|
||||
stride_l_v,
|
||||
stride_lp_b,
|
||||
stride_lp_s,
|
||||
stride_lp_k,
|
||||
stride_t_b,
|
||||
stride_t_s,
|
||||
stride_t_k,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Fused kernel that computes logsumexp and logprobs for topk tokens.
|
||||
All computations are done in float32 for numerical stability.
|
||||
"""
|
||||
# Program ID
|
||||
pid = tl.program_id(0)
|
||||
batch_idx = pid // S
|
||||
seq_idx = pid % S
|
||||
|
||||
# Bounds check
|
||||
if batch_idx >= B or seq_idx >= S:
|
||||
return
|
||||
|
||||
# Compute logsumexp over the vocabulary
|
||||
max_val = -float("inf")
|
||||
|
||||
# Phase 1: Find max value across vocabulary
|
||||
for v_offset in range(0, V, BLOCK_SIZE):
|
||||
# Create block indices and mask
|
||||
block_size = min(BLOCK_SIZE, V - v_offset)
|
||||
block_idx = tl.arange(0, BLOCK_SIZE)
|
||||
mask = block_idx < block_size
|
||||
|
||||
# Load logits block and convert to float32 in-place
|
||||
ptrs = (
|
||||
student_logits_ptr
|
||||
+ batch_idx * stride_l_b
|
||||
+ seq_idx * stride_l_s
|
||||
+ (v_offset + block_idx) * stride_l_v
|
||||
)
|
||||
block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
|
||||
|
||||
# Apply temperature scaling if needed
|
||||
if temperature != 1.0:
|
||||
block_logits = block_logits / temperature
|
||||
|
||||
# Update max value
|
||||
block_max = tl.max(block_logits, axis=0)
|
||||
max_val = tl.maximum(max_val, block_max)
|
||||
|
||||
# Phase 2: Compute sum of exp(logits - max_val)
|
||||
sum_exp = 0.0
|
||||
|
||||
for v_offset in range(0, V, BLOCK_SIZE):
|
||||
# Create block indices and mask
|
||||
block_size = min(BLOCK_SIZE, V - v_offset)
|
||||
block_idx = tl.arange(0, BLOCK_SIZE)
|
||||
mask = block_idx < block_size
|
||||
|
||||
# Load logits block and convert to float32 in-place
|
||||
ptrs = (
|
||||
student_logits_ptr
|
||||
+ batch_idx * stride_l_b
|
||||
+ seq_idx * stride_l_s
|
||||
+ (v_offset + block_idx) * stride_l_v
|
||||
)
|
||||
block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
|
||||
|
||||
# Apply temperature scaling if needed
|
||||
if temperature != 1.0:
|
||||
block_logits = block_logits / temperature
|
||||
|
||||
# Compute exp(logits - max_val) and add to sum
|
||||
block_exp = tl.exp(block_logits - max_val)
|
||||
sum_exp += tl.sum(block_exp * mask, axis=0)
|
||||
|
||||
# Compute final logsumexp
|
||||
logsumexp = max_val + tl.log(sum_exp)
|
||||
|
||||
# Phase 3: Compute and store logprobs for the top-k tokens
|
||||
token_ids_base = token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
|
||||
logprobs_base = (
|
||||
student_logprobs_ptr + batch_idx * stride_lp_b + seq_idx * stride_lp_s
|
||||
)
|
||||
|
||||
for k in range(K):
|
||||
# Load token ID for position k
|
||||
token_id = tl.load(token_ids_base + k * stride_t_k)
|
||||
|
||||
# Load the corresponding logit and convert to float32
|
||||
token_logit_ptr = (
|
||||
student_logits_ptr
|
||||
+ batch_idx * stride_l_b
|
||||
+ seq_idx * stride_l_s
|
||||
+ token_id * stride_l_v
|
||||
)
|
||||
token_logit = tl.load(token_logit_ptr).to(tl.float32)
|
||||
|
||||
# Apply temperature scaling if needed
|
||||
if temperature != 1.0:
|
||||
token_logit = token_logit / temperature
|
||||
|
||||
# Compute logprob directly: logit - logsumexp
|
||||
token_logprob = token_logit - logsumexp
|
||||
|
||||
# Store the result
|
||||
tl.store(logprobs_base + k * stride_lp_k, token_logprob)
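# For reference, the per-position computation above is equivalent to this eager
# PyTorch sketch (all math in float32):
#   logits = student_logits[b, s].float() / temperature
#   logsumexp = torch.logsumexp(logits, dim=-1)
#   logprobs[b, s, k] = logits[token_ids[b, s, k]] - logsumexp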
|
||||
|
||||
|
||||
@triton.jit
|
||||
def grad_softmax_kernel(
|
||||
grad_student_logits_ptr,
|
||||
target_token_ids_ptr,
|
||||
teacher_probs_ptr,
|
||||
student_probs_ptr,
|
||||
mask_ptr,
|
||||
B,
|
||||
S,
|
||||
V,
|
||||
K, # batch size, seq len, vocab size, top-k
|
||||
scale,
|
||||
stride_gl_b,
|
||||
stride_gl_s,
|
||||
stride_gl_v,
|
||||
stride_t_b,
|
||||
stride_t_s,
|
||||
stride_t_k,
|
||||
stride_p_b,
|
||||
stride_p_s,
|
||||
stride_p_k,
|
||||
stride_sp_b,
|
||||
stride_sp_s,
|
||||
stride_sp_k,
|
||||
stride_m_b,
|
||||
stride_m_s,
|
||||
stride_m_k,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# Program ID
|
||||
pid = tl.program_id(0)
|
||||
batch_idx = pid // S
|
||||
seq_idx = pid % S
|
||||
|
||||
# Bounds check
|
||||
if batch_idx >= B or seq_idx >= S:
|
||||
return
|
||||
|
||||
# Base pointers for this (batch, seq) pair
|
||||
grad_logits_base = (
|
||||
grad_student_logits_ptr + batch_idx * stride_gl_b + seq_idx * stride_gl_s
|
||||
)
|
||||
token_ids_base = (
|
||||
target_token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
|
||||
)
|
||||
teacher_probs_base = (
|
||||
teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
|
||||
)
|
||||
student_probs_base = (
|
||||
student_probs_ptr + batch_idx * stride_sp_b + seq_idx * stride_sp_s
|
||||
)
|
||||
mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s
|
||||
|
||||
# Process each teacher probability one at a time, computing all gradients for it
|
||||
for k in range(0, K):
|
||||
# Load data for current position k
|
||||
teacher_prob = tl.load(teacher_probs_base + k * stride_p_k)
|
||||
student_prob_k = tl.load(student_probs_base + k * stride_sp_k)
|
||||
mask_val = tl.load(mask_base + k * stride_m_k)
|
||||
|
||||
# Precompute the self-influence term (multiplied by scale)
|
||||
self_term = teacher_prob * (1.0 - student_prob_k) * scale
|
||||
|
||||
# Calculate gradient contributions for all positions j
|
||||
for j in range(0, K):
|
||||
token_id_j = tl.load(token_ids_base + j * stride_t_k)
|
||||
student_prob_j = tl.load(student_probs_base + j * stride_sp_k)
|
||||
mask_j = tl.load(mask_base + j * stride_m_k)
|
||||
|
||||
# Calculate the masking factor
|
||||
combined_mask = mask_val * mask_j
|
||||
|
||||
# Determine if this is a diagonal or off-diagonal term
|
||||
is_k_equals_j = tl.where(k == j, 1.0, 0.0)
|
||||
|
||||
# Compute the gradient contribution
|
||||
# For diagonal (k==j): -teacher_prob * (1-student_prob_k) * scale * mask
|
||||
# For off-diagonal: -(-teacher_prob * student_prob_j) * scale * mask
|
||||
grad_contribution = (
|
||||
-(
|
||||
self_term * is_k_equals_j
|
||||
- teacher_prob * student_prob_j * scale * (1.0 - is_k_equals_j)
|
||||
)
|
||||
* combined_mask
|
||||
)
|
||||
|
||||
# Atomically update the gradient for this token
|
||||
tl.atomic_add(
|
||||
grad_logits_base + token_id_j * stride_gl_v, grad_contribution
|
||||
)
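# Note: summed over k and ignoring the masks, the contributions accumulated above
# reduce to the familiar softmax cross-entropy gradient for each top-k token j:
#   d(loss)/d(logit_j) = scale * (student_prob_j * sum_k(teacher_prob_k) - teacher_prob_j)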
|
||||
|
||||
|
||||
@triton.jit
|
||||
def grad_topk_softmax_kernel(
|
||||
grad_student_logits_ptr,
|
||||
student_logits_ptr,
|
||||
target_token_ids_ptr,
|
||||
teacher_probs_ptr,
|
||||
student_probs_ptr,
|
||||
mask_ptr,
|
||||
B,
|
||||
S,
|
||||
V,
|
||||
K, # batch size, seq len, vocab size, top-k
|
||||
scale,
|
||||
stride_gl_b,
|
||||
stride_gl_s,
|
||||
stride_gl_v,
|
||||
stride_l_b,
|
||||
stride_l_s,
|
||||
stride_l_v,
|
||||
stride_t_b,
|
||||
stride_t_s,
|
||||
stride_t_k,
|
||||
stride_p_b,
|
||||
stride_p_s,
|
||||
stride_p_k,
|
||||
stride_sp_b,
|
||||
stride_sp_s,
|
||||
stride_sp_k,
|
||||
stride_m_b,
|
||||
stride_m_s,
|
||||
stride_m_k,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# Program ID
|
||||
pid = tl.program_id(0)
|
||||
batch_idx = pid // S
|
||||
seq_idx = pid % S
|
||||
|
||||
# Bounds check
|
||||
if batch_idx >= B or seq_idx >= S:
|
||||
return
|
||||
|
||||
# Base pointers for this (batch, seq) pair
|
||||
grad_logits_base = (
|
||||
grad_student_logits_ptr + batch_idx * stride_gl_b + seq_idx * stride_gl_s
|
||||
)
|
||||
# logits_base = student_logits_ptr + batch_idx * stride_l_b + seq_idx * stride_l_s
|
||||
token_ids_base = (
|
||||
target_token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
|
||||
)
|
||||
teacher_probs_base = (
|
||||
teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
|
||||
)
|
||||
student_probs_base = (
|
||||
student_probs_ptr + batch_idx * stride_sp_b + seq_idx * stride_sp_s
|
||||
)
|
||||
mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s
|
||||
|
||||
# Load all token IDs, probs and masks for this position
|
||||
token_ids = tl.zeros([K], dtype=tl.int32)
|
||||
teacher_probs = tl.zeros([K], dtype=tl.float32)
|
||||
student_probs = tl.zeros([K], dtype=tl.float32)
|
||||
masks = tl.zeros([K], dtype=tl.float32)
|
||||
|
||||
for k in range(K):
|
||||
token_ids[k] = tl.load(token_ids_base + k * stride_t_k)
|
||||
teacher_probs[k] = tl.load(teacher_probs_base + k * stride_p_k)
|
||||
student_probs[k] = tl.load(student_probs_base + k * stride_sp_k)
|
||||
masks[k] = tl.load(mask_base + k * stride_m_k)
|
||||
|
||||
# Process gradients for all tokens in this position
|
||||
for k in range(K):
|
||||
# token_id = token_ids[k]
|
||||
mask_k = masks[k]
|
||||
|
||||
# Skip computation if mask is zero by multiplying gradient by mask
|
||||
for j in range(K):
|
||||
other_token_id = token_ids[j]
|
||||
mask_j = masks[j]
|
||||
combined_mask = mask_k * mask_j
|
||||
|
||||
# Compute gradient differently for diagonal vs off-diagonal entries
|
||||
# Using * 1.0 to convert boolean to float
|
||||
is_diagonal = tl.where(j == k, 1.0, 0.0)
|
||||
|
||||
# Self influence: gradient = teacher_prob * (1 - student_prob)
|
||||
self_grad = teacher_probs[k] * (1.0 - student_probs[k]) * is_diagonal
|
||||
|
||||
# Cross influence: gradient = -teacher_prob[k] * student_prob[j]
|
||||
cross_grad = -teacher_probs[k] * student_probs[j] * (1.0 - is_diagonal)
|
||||
|
||||
# Combined gradient scaled by mask
|
||||
grad_val = (self_grad + cross_grad) * scale * combined_mask
|
||||
|
||||
tl.atomic_add(grad_logits_base + other_token_id * stride_gl_v, grad_val)
|
||||
|
||||
|
||||
# Triton-accelerated implementation of KL divergence loss for top-k tokens
|
||||
# Chunking helper functions for handling long sequences
|
||||
def chunk_tensor(
|
||||
tensor: torch.Tensor, max_seq_len: int
|
||||
) -> Tuple[torch.Tensor, Optional[int]]:
|
||||
"""Split a tensor along sequence dimension if needed."""
|
||||
_, seq_len, *__ = tensor.shape
|
||||
|
||||
if seq_len <= max_seq_len:
|
||||
return tensor, None
|
||||
|
||||
num_chunks = (seq_len + max_seq_len - 1) // max_seq_len
|
||||
chunks = []
|
||||
|
||||
for i in range(num_chunks):
|
||||
start_idx = i * max_seq_len
|
||||
end_idx = min((i + 1) * max_seq_len, seq_len)
|
||||
chunks.append(tensor[:, start_idx:end_idx, ...])
|
||||
|
||||
return chunks, num_chunks
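# Example (a sketch): with max_seq_len=8192, a tensor of shape (B, 20000, ...) is
# split into three chunks with sequence lengths 8192, 8192 and 3616. When no chunking
# is needed, the original tensor is returned unchanged with num_chunks=None.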
|
||||
|
||||
|
||||
def merge_chunks(chunks: list, original_shape: torch.Size):
|
||||
"""Merge chunks back into a single tensor with original shape."""
|
||||
return torch.cat(chunks, dim=1)
|
||||
|
||||
|
||||
# Triton-accelerated implementation of KL divergence loss for top-k tokens
|
||||
class TopKKLDivergence(torch.autograd.Function):
|
||||
"""
|
||||
Autograd function for KL divergence loss between top-k logprobs
|
||||
with support for chunking to handle very long sequences.
|
||||
"""
|
||||
|
||||
# Max sequence length to process in a single kernel launch
|
||||
# This is a tunable parameter that might need adjustment based on GPU memory
|
||||
MAX_SEQ_LEN = 8192
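# Typical invocation (a sketch), as with any torch.autograd.Function:
#   loss = TopKKLDivergence.apply(
#       student_logits, target_token_ids, target_logprobs, target_mask,
#       num_items_in_batch, kd_temperature, top_k_before_softmax,
#   )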
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
student_logits,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
num_items_in_batch=-1,
|
||||
kd_temperature=1.0,
|
||||
top_k_before_softmax=0,
|
||||
):
|
||||
"""
|
||||
Forward pass for KL divergence loss between top-k logprobs with chunking.
|
||||
"""
|
||||
# Only convert target_logprobs to float, leave student_logits as is
|
||||
target_logprobs = target_logprobs.float()
|
||||
|
||||
# Get dimensions
|
||||
batch_size, _, vocab_size = student_logits.shape
|
||||
_, teacher_seq_len, top_k = target_token_ids.shape
|
||||
|
||||
# Slice student logits to match teacher sequence length
|
||||
student_logits_for_kd = student_logits[:, :teacher_seq_len, :]
|
||||
|
||||
# Store original values for backward pass
|
||||
ctx.original_seq_len = teacher_seq_len
|
||||
ctx.original_dtype = student_logits.dtype
|
||||
|
||||
# Apply chunking for long sequences
|
||||
if teacher_seq_len > TopKKLDivergence.MAX_SEQ_LEN:
|
||||
# Chunk the inputs
|
||||
student_logits_chunks, num_chunks = chunk_tensor(
|
||||
student_logits_for_kd, TopKKLDivergence.MAX_SEQ_LEN
|
||||
)
|
||||
target_token_ids_chunks, _ = chunk_tensor(
|
||||
target_token_ids, TopKKLDivergence.MAX_SEQ_LEN
|
||||
)
|
||||
# target_logprobs_chunks, _ = chunk_tensor(
|
||||
# target_logprobs, TopKKLDivergence.MAX_SEQ_LEN
|
||||
# )
|
||||
# target_mask_chunks, _ = chunk_tensor(
|
||||
# target_mask, TopKKLDivergence.MAX_SEQ_LEN
|
||||
# )
|
||||
|
||||
# Process each chunk
|
||||
student_logprobs_chunks = []
|
||||
student_probs_chunks = []
|
||||
|
||||
for i in range(num_chunks):
|
||||
chunk_logits = student_logits_chunks[i]
|
||||
chunk_token_ids = target_token_ids_chunks[i]
|
||||
chunk_seq_len = chunk_logits.shape[1]
|
||||
|
||||
if top_k_before_softmax:
|
||||
# Apply temperature to student logits
|
||||
if kd_temperature != 1.0:
|
||||
chunk_logits = chunk_logits / kd_temperature
|
||||
|
||||
# Gather student logits for top-k tokens
|
||||
chunk_logits_topk = torch.gather(
|
||||
chunk_logits, dim=-1, index=chunk_token_ids
|
||||
)
|
||||
|
||||
# Compute softmax over gathered logits
|
||||
chunk_logprobs_topk = torch.log_softmax(chunk_logits_topk, dim=-1)
|
||||
chunk_probs_topk = torch.exp(chunk_logprobs_topk)
|
||||
else:
|
||||
# Allocate output tensor for logprobs directly (always in float32)
|
||||
chunk_logprobs_topk = torch.empty(
|
||||
(batch_size, chunk_seq_len, top_k),
|
||||
dtype=torch.float32,
|
||||
device=chunk_logits.device,
|
||||
)
|
||||
|
||||
# Launch fused kernel directly
|
||||
grid = (batch_size * chunk_seq_len,)
|
||||
fused_logsumexp_logprobs_kernel[grid](
|
||||
chunk_logits.contiguous(),
|
||||
chunk_logprobs_topk,
|
||||
chunk_token_ids.contiguous(),
|
||||
batch_size,
|
||||
chunk_seq_len,
|
||||
vocab_size,
|
||||
top_k,
|
||||
kd_temperature,
|
||||
chunk_logits.stride(0),
|
||||
chunk_logits.stride(1),
|
||||
chunk_logits.stride(2),
|
||||
chunk_logprobs_topk.stride(0),
|
||||
chunk_logprobs_topk.stride(1),
|
||||
chunk_logprobs_topk.stride(2),
|
||||
chunk_token_ids.stride(0),
|
||||
chunk_token_ids.stride(1),
|
||||
chunk_token_ids.stride(2),
|
||||
min(1024, triton.next_power_of_2(vocab_size)),
|
||||
)
|
||||
|
||||
# Calculate probs from logprobs
|
||||
chunk_probs_topk = torch.exp(chunk_logprobs_topk)
|
||||
|
||||
# Store results
|
||||
student_logprobs_chunks.append(chunk_logprobs_topk)
|
||||
student_probs_chunks.append(chunk_probs_topk)
|
||||
|
||||
# Merge results
|
||||
student_logprobs_topk = torch.cat(student_logprobs_chunks, dim=1)
|
||||
student_probs_topk = torch.cat(student_probs_chunks, dim=1)
|
||||
|
||||
# Save chunking info for backward pass
|
||||
ctx.used_chunking = True
|
||||
ctx.num_chunks = num_chunks
|
||||
|
||||
else:
|
||||
# Original code path for shorter sequences
|
||||
if top_k_before_softmax:
|
||||
# Apply temperature to student logits
|
||||
if kd_temperature != 1.0:
|
||||
student_logits_for_kd = student_logits_for_kd / kd_temperature
|
||||
|
||||
# Gather student logits for top-k tokens
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
)
|
||||
|
||||
# Compute softmax over gathered logits
|
||||
student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1)
|
||||
student_probs_topk = torch.exp(student_logprobs_topk)
|
||||
else:
|
||||
# Allocate output tensor for logprobs directly (always in float32)
|
||||
student_logprobs_topk = torch.empty(
|
||||
(batch_size, teacher_seq_len, top_k),
|
||||
dtype=torch.float32,
|
||||
device=student_logits.device,
|
||||
)
|
||||
|
||||
# Launch fused kernel directly
|
||||
grid = (batch_size * teacher_seq_len,)
|
||||
fused_logsumexp_logprobs_kernel[grid](
|
||||
student_logits_for_kd.contiguous(),
|
||||
student_logprobs_topk,
|
||||
target_token_ids.contiguous(),
|
||||
batch_size,
|
||||
teacher_seq_len,
|
||||
vocab_size,
|
||||
top_k,
|
||||
kd_temperature,
|
||||
student_logits_for_kd.stride(0),
|
||||
student_logits_for_kd.stride(1),
|
||||
student_logits_for_kd.stride(2),
|
||||
student_logprobs_topk.stride(0),
|
||||
student_logprobs_topk.stride(1),
|
||||
student_logprobs_topk.stride(2),
|
||||
target_token_ids.stride(0),
|
||||
target_token_ids.stride(1),
|
||||
target_token_ids.stride(2),
|
||||
min(1024, triton.next_power_of_2(vocab_size)),
|
||||
)
|
||||
|
||||
# Calculate probs from logprobs
|
||||
student_probs_topk = torch.exp(student_logprobs_topk)
|
||||
|
||||
# No chunking used
|
||||
ctx.used_chunking = False
|
||||
|
||||
# Save tensors for backward pass
|
||||
ctx.save_for_backward(
|
||||
student_logits_for_kd,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
student_probs_topk,
|
||||
)
|
||||
ctx.kd_temperature = kd_temperature
|
||||
ctx.top_k_before_softmax = top_k_before_softmax
|
||||
ctx.num_items_in_batch = num_items_in_batch
|
||||
|
||||
# Convert mask to boolean
|
||||
valid_mask = target_mask.bool()
|
||||
|
||||
# Extract valid tokens only - this is where the error was happening
|
||||
# Use cloned contiguous tensors and explicit indexing for safety
|
||||
student_logprobs_flat = student_logprobs_topk.view(-1, top_k)
|
||||
target_logprobs_flat = target_logprobs.view(-1, top_k)
|
||||
valid_mask_flat = valid_mask.view(-1, top_k)
|
||||
|
||||
# Gather valid indices explicitly to avoid illegal memory access
|
||||
valid_indices = torch.nonzero(valid_mask_flat.view(-1)).squeeze(-1)
|
||||
student_logprobs_valid = torch.index_select(
|
||||
student_logprobs_flat.view(-1), 0, valid_indices
|
||||
)
|
||||
target_logprobs_valid = torch.index_select(
|
||||
target_logprobs_flat.view(-1), 0, valid_indices
|
||||
)
|
||||
|
||||
# Convert teacher logprobs to probabilities
|
||||
teacher_probs_valid = torch.exp(target_logprobs_valid)
|
||||
|
||||
# Compute KL divergence loss
|
||||
token_losses = teacher_probs_valid * (
|
||||
target_logprobs_valid - student_logprobs_valid
|
||||
)
|
||||
kd_loss = token_losses.sum()
|
||||
|
||||
# Apply temperature scaling
|
||||
# pylint: disable=duplicate-code
|
||||
if kd_temperature != 1.0:
|
||||
kd_loss = kd_loss * (kd_temperature**2)
|
||||
|
||||
# Normalize by number of items or valid tokens
|
||||
if num_items_in_batch > 0:
|
||||
kd_loss = kd_loss / float(num_items_in_batch)
|
||||
else:
|
||||
num_valid_tokens = valid_indices.numel()
|
||||
kd_loss = kd_loss / float(num_valid_tokens if num_valid_tokens > 0 else 1)
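# In symbols (a sketch): kd_loss = (kd_temperature**2 / N) * sum over valid top-k
# entries of teacher_prob * (teacher_logprob - student_logprob), i.e. the forward
# KL divergence KL(teacher || student) restricted to the teacher's top-k tokens.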
|
||||
|
||||
return kd_loss
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
"""
|
||||
Optimized backward pass for KL divergence loss with proper dtype handling and chunking.
|
||||
"""
|
||||
(
|
||||
student_logits,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
student_probs,
|
||||
) = ctx.saved_tensors
|
||||
kd_temperature = ctx.kd_temperature
|
||||
num_items_in_batch = ctx.num_items_in_batch
|
||||
original_dtype = ctx.original_dtype
|
||||
|
||||
# Get dimensions
|
||||
batch_size, _, vocab_size = student_logits.shape
|
||||
_, teacher_seq_len, top_k = target_token_ids.shape
|
||||
|
||||
# Initialize gradient tensor in float32 to support atomic operations
|
||||
grad_student_logits = torch.zeros_like(student_logits, dtype=torch.float32)
|
||||
|
||||
# Compute scaling factor
|
||||
scale = grad_output.item()
|
||||
|
||||
# Apply temperature scaling from forward pass
|
||||
if kd_temperature != 1.0:
|
||||
scale = scale * (kd_temperature**2)
|
||||
|
||||
# Normalize by number of items or valid tokens
|
||||
if num_items_in_batch > 0:
|
||||
scale = scale / float(num_items_in_batch)
|
||||
else:
|
||||
scale = scale / float(target_mask.sum().item())
|
||||
|
||||
# Apply chain rule for temperature scaling (1/temperature)
|
||||
if kd_temperature != 1.0:
|
||||
scale = scale / kd_temperature
|
||||
|
||||
# Convert teacher logprobs to probabilities
|
||||
teacher_probs = torch.exp(target_logprobs)
|
||||
|
||||
# Use chunking for the backward pass if used in forward
|
||||
if getattr(ctx, "used_chunking", False):
|
||||
num_chunks = ctx.num_chunks
|
||||
max_seq = TopKKLDivergence.MAX_SEQ_LEN
|
||||
|
||||
# Process each chunk
|
||||
for i in range(num_chunks):
|
||||
start_idx = i * max_seq
|
||||
end_idx = min((i + 1) * max_seq, teacher_seq_len)
|
||||
chunk_len = end_idx - start_idx
|
||||
|
||||
# Get chunk slices
|
||||
# student_logits_chunk = student_logits[:, start_idx:end_idx, :]
|
||||
target_token_ids_chunk = target_token_ids[:, start_idx:end_idx, :]
|
||||
teacher_probs_chunk = teacher_probs[:, start_idx:end_idx, :]
|
||||
student_probs_chunk = student_probs[:, start_idx:end_idx, :]
|
||||
target_mask_chunk = target_mask[:, start_idx:end_idx, :]
|
||||
grad_student_logits_chunk = grad_student_logits[:, start_idx:end_idx, :]
|
||||
|
||||
# Launch gradient computation kernel for this chunk
|
||||
grid = (batch_size * chunk_len,)
|
||||
grad_softmax_kernel[grid](
|
||||
grad_student_logits_chunk.contiguous(),
|
||||
target_token_ids_chunk.contiguous(),
|
||||
teacher_probs_chunk.contiguous(),
|
||||
student_probs_chunk.contiguous(),
|
||||
target_mask_chunk.contiguous(),
|
||||
batch_size,
|
||||
chunk_len,
|
||||
vocab_size,
|
||||
top_k,
|
||||
scale,
|
||||
grad_student_logits_chunk.stride(0),
|
||||
grad_student_logits_chunk.stride(1),
|
||||
grad_student_logits_chunk.stride(2),
|
||||
target_token_ids_chunk.stride(0),
|
||||
target_token_ids_chunk.stride(1),
|
||||
target_token_ids_chunk.stride(2),
|
||||
teacher_probs_chunk.stride(0),
|
||||
teacher_probs_chunk.stride(1),
|
||||
teacher_probs_chunk.stride(2),
|
||||
student_probs_chunk.stride(0),
|
||||
student_probs_chunk.stride(1),
|
||||
student_probs_chunk.stride(2),
|
||||
target_mask_chunk.stride(0),
|
||||
target_mask_chunk.stride(1),
|
||||
target_mask_chunk.stride(2),
|
||||
min(1024, triton.next_power_of_2(top_k)),
|
||||
)
|
||||
|
||||
# Update the gradient tensor (already in-place)
|
||||
else:
|
||||
# Original code path for shorter sequences
|
||||
# Launch gradient computation kernel
|
||||
grid = (batch_size * teacher_seq_len,)
|
||||
grad_softmax_kernel[grid](
|
||||
grad_student_logits.contiguous(),
|
||||
target_token_ids.contiguous(),
|
||||
teacher_probs.contiguous(),
|
||||
student_probs.contiguous(),
|
||||
target_mask.contiguous(),
|
||||
batch_size,
|
||||
teacher_seq_len,
|
||||
vocab_size,
|
||||
top_k,
|
||||
scale,
|
||||
grad_student_logits.stride(0),
|
||||
grad_student_logits.stride(1),
|
||||
grad_student_logits.stride(2),
|
||||
target_token_ids.stride(0),
|
||||
target_token_ids.stride(1),
|
||||
target_token_ids.stride(2),
|
||||
teacher_probs.stride(0),
|
||||
teacher_probs.stride(1),
|
||||
teacher_probs.stride(2),
|
||||
student_probs.stride(0),
|
||||
student_probs.stride(1),
|
||||
student_probs.stride(2),
|
||||
target_mask.stride(0),
|
||||
target_mask.stride(1),
|
||||
target_mask.stride(2),
|
||||
min(1024, triton.next_power_of_2(top_k)),
|
||||
)
|
||||
|
||||
# Convert gradient back to original dtype if needed
|
||||
if original_dtype != torch.float32:
|
||||
grad_student_logits = grad_student_logits.to(original_dtype)
|
||||
|
||||
# Return gradients for student_logits and None for other inputs
|
||||
return grad_student_logits, None, None, None, None, None, None
|
||||
|
||||
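For orientation, the quantity that `grad_softmax_kernel` accumulates (before the `scale` factor derived above) matches the standard forward-KL gradient. A sketch of the derivation, assuming the retained teacher top-k probabilities are approximately normalized at each position:

```latex
% Per-position loss over the teacher's top-k entries:
%   L = \sum_{k \in \text{top-}k} p^{T}_{k}\,\bigl(\log p^{T}_{k} - \log p^{S}_{k}\bigr)
% Differentiating w.r.t. a student logit z_j (softmax over the full vocab):
\frac{\partial L}{\partial z_j}
  = p^{S}_{j}\sum_{k} p^{T}_{k} \;-\; p^{T}_{j}
  \;\approx\; p^{S}_{j} - p^{T}_{j}
```

The exact handling of masked entries and of the `top_k_before_softmax` mode is kernel-specific; this is only the textbook form the per-element update corresponds to.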

# Wrapper function for chunked computation
def loss(
    student_logits: torch.Tensor,
    target_token_ids: torch.Tensor,
    target_logprobs: torch.Tensor,
    target_mask: torch.Tensor,
    num_items_in_batch: int = -1,
    kd_temperature: float = 1.0,
    top_k_before_softmax: int = 0,
    max_seq_len: Optional[int] = None,
):
    """
    Triton-accelerated Memory-efficient KL divergence loss computation for knowledge distillation
    with support for very long sequences.

    Args:
        student_logits: Student logits [B, seq_len, vocab_size]
        target_token_ids: Teacher token IDs [B, seq_len, top_k]
        target_logprobs: Teacher logprobs [B, seq_len, top_k]
        target_mask: Token mask [B, seq_len, top_k]
        num_items_in_batch: Number of items for normalization (-1 for auto)
        kd_temperature: Temperature for KD
        top_k_before_softmax: Flag for softmax application order
        max_seq_len: Override default MAX_SEQ_LEN value for chunking
    """
    # Allow overriding the max sequence length
    if max_seq_len is not None and max_seq_len > 0:
        TopKKLDivergence.MAX_SEQ_LEN = max_seq_len

    total_loss = TopKKLDivergence.apply(
        student_logits,
        target_token_ids,
        target_logprobs,
        target_mask,
        -1 if num_items_in_batch <= 0 else num_items_in_batch,
        kd_temperature,
        top_k_before_softmax,
    )

    return total_loss
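For reference, a minimal sketch of how this wrapper might be called. The shapes follow the docstring above; the dimensions and tensors here are made-up stand-ins for the teacher top-k data that the KD dataset pipeline would normally provide:

```python
import torch

from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as topk_kd_loss_triton

# Hypothetical sizes: batch 2, sequence length 512, 32k vocab, teacher top-64
student_logits = torch.randn(2, 512, 32000, device="cuda", requires_grad=True)
target_token_ids = torch.randint(0, 32000, (2, 512, 64), device="cuda")
target_logprobs = torch.log_softmax(torch.randn(2, 512, 64, device="cuda"), dim=-1)
target_mask = torch.ones(2, 512, 64, device="cuda")

kd_loss = topk_kd_loss_triton(
    student_logits,
    target_token_ids,
    target_logprobs,
    target_mask,
    num_items_in_batch=2 * 512,
    kd_temperature=1.0,
)
kd_loss.backward()  # gradients flow back to student_logits via the custom backward
```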
67 src/axolotl/integrations/kd/topk_logprob/logsumexp.py Normal file
@@ -0,0 +1,67 @@
"""
Optimized Triton kernels for logsumexp
"""
# pylint: disable=invalid-name,unused-argument
import triton
import triton.language as tl


# Helper function for computing logsumexp
@triton.jit
def logsumexp_kernel(
    logits_ptr,
    output_ptr,
    B,
    S,
    V,  # batch size, seq len, vocab size
    stride_b,
    stride_s,
    stride_v,
    out_stride_b,
    out_stride_s,
    BLOCK_SIZE: tl.constexpr,
):
    # Program ID
    # pylint: disable=duplicate-code
    pid = tl.program_id(0)
    batch_idx = pid // S
    seq_idx = pid % S

    # Bounds check
    if batch_idx >= B or seq_idx >= S:
        return

    # Pointers
    logits_base = logits_ptr + batch_idx * stride_b + seq_idx * stride_s

    # Find maximum for numerical stability
    max_val = -float("inf")
    for v_offset in range(0, V, BLOCK_SIZE):
        v_size = min(BLOCK_SIZE, V - v_offset)
        mask = tl.arange(0, BLOCK_SIZE) < v_size

        logits_block = tl.load(
            logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
            mask=mask,
            other=-float("inf"),
        )
        max_val = tl.maximum(max_val, tl.max(logits_block, axis=0))

    # Compute sum of exp(logit - max_val)
    sum_exp = 0.0
    for v_offset in range(0, V, BLOCK_SIZE):
        v_size = min(BLOCK_SIZE, V - v_offset)
        mask = tl.arange(0, BLOCK_SIZE) < v_size

        logits_block = tl.load(
            logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
            mask=mask,
            other=-float("inf"),
        )
        sum_exp += tl.sum(tl.exp(logits_block - max_val), axis=0)

    # Compute logsumexp
    result = max_val + tl.log(sum_exp)

    # Store result
    tl.store(output_ptr + batch_idx * out_stride_b + seq_idx * out_stride_s, result)
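A hypothetical host-side wrapper for this kernel could look like the sketch below; the grid of `B * S` programs and the stride/`BLOCK_SIZE` arguments mirror the kernel signature above. This is illustrative only and is not part of the diff:

```python
import torch
import triton

from axolotl.integrations.kd.topk_logprob.logsumexp import logsumexp_kernel


def logsumexp(logits: torch.Tensor) -> torch.Tensor:
    """Compute logsumexp over the vocab dimension of a [B, S, V] float32 tensor."""
    B, S, V = logits.shape
    logits = logits.contiguous()
    out = torch.empty(B, S, dtype=torch.float32, device=logits.device)
    grid = (B * S,)  # one program per (batch, sequence) position
    logsumexp_kernel[grid](
        logits,
        out,
        B,
        S,
        V,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        out.stride(0),
        out.stride(1),
        BLOCK_SIZE=min(1024, triton.next_power_of_2(V)),
    )
    return out
```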
@@ -20,6 +20,7 @@ from axolotl.core.trainers.base import AxolotlTrainer

from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
from .topk_logprob.forward_kl_triton import loss as topk_kd_loss_triton


class AxolotlKDTrainer(AxolotlTrainer):
@@ -85,7 +86,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
                num_items_in_batch=num_items_in_batch,
            )
        else:
            loss_kd = topk_kd_loss(
            loss_fn = (
                topk_kd_loss
                if self.args.kd_top_k_before_softmax
                else topk_kd_loss_triton
            )
            loss_kd = loss_fn(
                shift_logits,
                target_token_ids_for_loss,
                target_logprobs_for_loss,
@@ -1,36 +0,0 @@
# Liger Kernel Integration

Liger Kernel provides efficient Triton kernels for LLM training, offering:

- 20% increase in multi-GPU training throughput
- 60% reduction in memory usage
- Compatibility with both FSDP and DeepSpeed

See https://github.com/linkedin/Liger-Kernel

## Usage

```yaml
plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```

## Citation

```bib
@article{hsu2024ligerkernelefficienttriton,
  title={Liger Kernel: Efficient Triton Kernels for LLM Training},
  author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
  year={2024},
  eprint={2410.10989},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2410.10989},
  journal={arXiv preprint arXiv:2410.10989},
}
```
@@ -1,10 +1,6 @@
# LM Eval Harness

Run evaluation on model using the popular lm-evaluation-harness library.

See https://github.com/EleutherAI/lm-evaluation-harness

## Usage
### Usage

```yaml
plugins:
@@ -14,22 +10,4 @@ lm_eval_tasks:
  - gsm8k
  - hellaswag
  - arc_easy

lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```

## Citation

```bib
@misc{eval-harness,
  author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy},
  title = {A framework for few-shot language model evaluation},
  month = 07,
  year = 2024,
  publisher = {Zenodo},
  version = {v0.4.3},
  doi = {10.5281/zenodo.12608602},
  url = {https://zenodo.org/records/12608602}
}
```
@@ -1,17 +1,15 @@
# Spectrum: Targeted Training on Signal to Noise Ratio
## Spectrum: Targeted Training on Signal to Noise Ratio

by Eric Hartford, Lucas Atkins, Fernando Fernandes, David Golchinfar

This plugin contains code to freeze the bottom fraction of modules in a model, based on the Signal-to-Noise Ratio (SNR).

See https://github.com/cognitivecomputations/spectrum

## Overview
### Overview

Spectrum is a tool for scanning and evaluating the Signal-to-Noise Ratio (SNR) of layers in large language models.
By identifying the top n% of layers with the highest SNR, you can optimize training efficiency.

## Usage
### Usage

```yaml
plugins:
@@ -21,17 +19,3 @@ spectrum_top_fraction: 0.5
# Optional if using a pre-scanned model as your base_model. Useful if using a model mirror
spectrum_model_name: meta-llama/Meta-Llama-3.1-8B
```

## Citation

```bib
@misc{hartford2024spectrumtargetedtrainingsignal,
  title={Spectrum: Targeted Training on Signal to Noise Ratio},
  author={Eric Hartford and Lucas Atkins and Fernando Fernandes Neto and David Golchinfar},
  year={2024},
  eprint={2406.06623},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2406.06623},
}
```
@@ -25,7 +25,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
    "gemmoe",
    "starcoder2",
    "deepseek_v2",
    "deepseek_v3",
]

@@ -1,164 +0,0 @@
"""Trainer callbacks for reporting runtime metrics at regular intervals."""

import logging
import time

from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)

from axolotl.telemetry.manager import TelemetryManager
from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker

LOG = logging.getLogger(__name__)

TIME_SINCE_LAST = 30


class TelemetryCallback(TrainerCallback):
    """
    Trainer callback for tracking and reporting runtime metrics.

    This callback tracks training progress, runtime, and memory usage,
    sending telemetry at configurable intervals.
    """

    report_interval_steps: int = 100

    def __init__(self):
        """Initialize the metrics callback."""
        self.tracker = RuntimeMetricsTracker()
        self.telemetry_manager = TelemetryManager.get_instance()
        self.current_epoch = -1
        self.start_time = time.time()
        self.last_report_time = None
        self.last_report_step = 0

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,  # pylint: disable=unused-argument
        control: TrainerControl,  # pylint: disable=unused-argument
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Handle training start."""
        self.telemetry_manager.send_event(event_type="train-started")

    def on_train_end(
        self,
        args: TrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,
        control: TrainerControl,  # pylint: disable=unused-argument
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Handle training end."""
        # Send training completion event
        self.telemetry_manager.send_event(
            event_type="train-ended",
            properties={
                "loss": state.log_history[-1].get("loss", 0)
                if state.log_history
                else None,
                "learning_rate": state.log_history[-1].get("learning_rate", 0)
                if state.log_history
                else None,
            }
            | self.tracker.metrics.to_dict(),
        )

    def on_epoch_begin(
        self,
        args: TrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,  # pylint: disable=unused-argument
        control: TrainerControl,  # pylint: disable=unused-argument
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Handle epoch start."""
        self.current_epoch += 1
        self.tracker.start_epoch(self.current_epoch)

    def on_epoch_end(
        self,
        args: TrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,  # pylint: disable=unused-argument
        control: TrainerControl,  # pylint: disable=unused-argument
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Handle epoch end."""
        self.tracker.end_epoch(self.current_epoch)

    def on_step_end(
        self,
        args: TrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,
        control: TrainerControl,  # pylint: disable=unused-argument
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Handle step end."""
        step = state.global_step
        self.tracker.update_step(step)

        # Check if we should report metrics
        should_report = (
            step % self.report_interval_steps == 0
            or step == 1  # Always report first step
            or step - self.last_report_step >= self.report_interval_steps
        )

        if should_report:
            current_time = time.time()
            if self.last_report_time is not None:
                time_since_last_report = current_time - self.last_report_time
            else:
                time_since_last_report = current_time - self.start_time
            steps_since_last_report = step - self.last_report_step

            # Only report if enough time has passed to avoid flooding
            if (
                step == 1
                or time_since_last_report >= TIME_SINCE_LAST
                or steps_since_last_report >= self.report_interval_steps
            ):
                # Calculate steps per second for this interval
                if time_since_last_report > 0 and steps_since_last_report > 0:
                    steps_per_second = steps_since_last_report / time_since_last_report
                else:
                    steps_per_second = 0

                # Update memory metrics
                self.tracker.update_memory_metrics()

                loss = state.log_history[-1].get("loss", 0) if state.log_history else 0
                learning_rate = (
                    state.log_history[-1].get("learning_rate", 0)
                    if state.log_history
                    else 0
                )

                # Prepare metrics to report
                metrics = {
                    "step": step,
                    "epoch": self.current_epoch,
                    "progress": state.epoch,  # Fractional epoch progress
                    "loss": loss,
                    "learning_rate": learning_rate,
                    "steps_per_second": steps_per_second,
                    "elapsed_time": current_time - self.start_time,
                    "time_since_last_report": time_since_last_report,
                }

                # Add memory metrics
                memory_metrics = self.tracker.get_memory_metrics()
                metrics.update({"memory": memory_metrics})

                # Send telemetry
                self.telemetry_manager.send_event(
                    event_type="train-progress", properties=metrics
                )

                # Update last report time and step
                self.last_report_time = current_time
                self.last_report_step = step
@@ -1,160 +0,0 @@
"""Telemetry utilities for exception and traceback information."""

import logging
import os
import re
import traceback
from functools import wraps
from inspect import getmodule
from typing import Any, Callable

from axolotl.telemetry.manager import TelemetryManager

LOG = logging.getLogger(__name__)

ERROR_HANDLED = False


def sanitize_stack_trace(stack_trace: str) -> str:
    """
    Remove personal information from stack trace messages while keeping Python package codepaths.

    This function identifies Python packages by looking for common patterns in virtual environment
    and site-packages directories, preserving the package path while removing user-specific paths.

    Args:
        stack_trace: The original stack trace string.

    Returns:
        A sanitized version of the stack trace with Python package paths preserved.
    """
    # Split the stack trace into lines to process each file path separately
    lines = stack_trace.split("\n")
    sanitized_lines = []

    # Regular expression to find file paths in the stack trace
    path_pattern = re.compile(r'(?:File ")(.*?)(?:")')

    # Regular expression to identify paths in site-packages or dist-packages
    # This matches path segments like "site-packages/package_name" or "dist-packages/package_name"
    site_packages_pattern = re.compile(
        r"(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
    )

    # Additional common virtual environment patterns
    venv_lib_pattern = re.compile(
        r"(?:lib|Lib)[/\\](?:python\d+(?:\.\d+)?[/\\])?(?:site-packages|dist-packages)[/\\]([\w\-\.]+)"
    )

    for line in lines:
        # Check if this line contains a file path
        path_match = path_pattern.search(line)

        if path_match:
            full_path = path_match.group(1)
            sanitized_path = ""

            # Try to match site-packages pattern
            site_packages_match = site_packages_pattern.search(full_path)
            venv_lib_match = venv_lib_pattern.search(full_path)

            if site_packages_match:
                # Find the index where the matched pattern starts
                idx = full_path.find("site-packages")
                if idx == -1:
                    idx = full_path.find("dist-packages")

                # Keep from 'site-packages' onward
                if idx >= 0:
                    sanitized_path = full_path[idx:]
            elif venv_lib_match:
                # For other virtual environment patterns, find the package directory
                match_idx = venv_lib_match.start(1)
                if match_idx > 0:
                    # Keep from the package name onward
                    package_name = venv_lib_match.group(1)
                    idx = full_path.rfind(
                        package_name, 0, match_idx + len(package_name)
                    )
                    if idx >= 0:
                        sanitized_path = full_path[idx:]

            # If we couldn't identify a package pattern but path contains 'axolotl'
            elif "axolotl" in full_path:
                idx = full_path.rfind("axolotl")
                if idx >= 0:
                    sanitized_path = full_path[idx:]

            # Apply the sanitization to the line
            if sanitized_path:
                line = line.replace(full_path, sanitized_path)
            else:
                # If we couldn't identify a package pattern, just keep the filename
                filename = os.path.basename(full_path)
                if filename:
                    line = line.replace(full_path, filename)
                else:
                    line = line.replace(full_path, "")

        sanitized_lines.append(line)

    return "\n".join(sanitized_lines)


def send_errors(func: Callable) -> Callable:
    """
    Decorator to send exception info in a function. If an exception is raised, we send
    telemetry containing the stack trace and error message.

    If an error occurs in a decorated function that is called by another decorated
    function, we'll only send telemetry corresponding to the lower-level function.

    Args:
        func: Function to decorate.

    Returns:
        Decorated function.
    """

    @wraps(func)
    def wrapper(*args, **kwargs) -> Any:
        telemetry_manager = TelemetryManager.get_instance()

        if not telemetry_manager.enabled:
            return func(*args, **kwargs)

        try:
            return func(*args, **kwargs)
        except Exception as exception:
            # Only track if we're not already handling an error. This prevents us from
            # capturing an error more than once in nested decorated function calls.=
            global ERROR_HANDLED  # pylint: disable=global-statement
            if not ERROR_HANDLED:
                ERROR_HANDLED = True

                # Get function module path
                module = getmodule(func)
                module_path = (
                    f"{module.__name__}.{func.__name__}" if module else func.__name__
                )

                # Get stack trace
                stack_trace = "".join(
                    traceback.format_exception(
                        type(exception), exception, exception.__traceback__
                    )
                )
                stack_trace = sanitize_stack_trace(stack_trace)

                # Send error telemetry
                telemetry_manager.send_event(
                    event_type=f"{module_path}-errored",
                    properties={
                        "exception": str(exception),
                        "stack_trace": stack_trace,
                    },
                )

            raise

    return wrapper
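For context, the `send_errors` decorator removed above was applied to top-level entry points (see the `train.py` and `models.py` hunks later in this diff). A minimal sketch of its behavior, using a hypothetical function name:

```python
from axolotl.telemetry.errors import send_errors


@send_errors
def risky_step():
    raise RuntimeError("boom")

# With telemetry enabled, the exception would be reported as a
# "<module>.risky_step-errored" event with a sanitized stack trace,
# and then re-raised to the caller unchanged.
```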
@@ -1,399 +0,0 @@
"""Telemetry manager and associated utilities."""

import atexit
import importlib
import logging
import os
import platform
import time
import uuid
from pathlib import Path
from typing import Any

import posthog
import psutil
import torch
import yaml

LOG = logging.getLogger(__name__)

POSTHOG_HOST = "https://app.posthog.com"
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"

ENABLED_WARNING_SLEEP_SECONDS = 15
ENABLED_WARNING = (
    "\nTelemetry is enabled. This helps Axolotl's maintainers by providing insights into:\n"
    "- Which models and configurations are most commonly used\n"
    "- What hardware setups need to be supported\n"
    "- Where users encounter errors\n\n"
    "This data helps us prioritize features, optimize performance, and fix bugs.\n\n"
    "To disable telemetry, set either:\n"
    "- AXOLOTL_DO_NOT_TRACK=1 (Axolotl-specific)\n"
    "- DO_NOT_TRACK=1 (Global standard; see https://consoledonottrack.com/)\n\n"
    "To remove this warning and continue with telemetry enabled,"
    "explicitly set AXOLOTL_DO_NOT_TRACK=0 (and leave DO_NOT_TRACK unset / set to 0)\n\n"
    "No personally identifiable information is collected."
    "For details, see: https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html\n\n"
    f"Sleeping for {ENABLED_WARNING_SLEEP_SECONDS}s..."
)

WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")

# NOTE: Keep these up to date with any config schema changes
FIELDS_WITH_ORGS = {
    "base_model",
    "tokenizer_config",
    "base_model_config",
    "pretraining_dataset",  # NOTE: this field may be a string or a dictionary
}
FIELDS_TO_REDACT = {"resume_from_checkpoint", "hub_model_id"}
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"}
PATH_INDICATORS = {"path", "dir"}

# pylint: disable=duplicate-code
RELEVANT_PACKAGES = {
    "torch",
    "transformers",
    "trl",
    "datasets",
    "peft",
    "bitsandbytes",
    "accelerate",
    "optimum",
    "deepspeed",
    "ray",
    "axolotl",
    "triton",
    "mamba-ssm",
    "flash-attn",
    "xformers",
    "autoawq",
    "tokenizers",
    "sentencepiece",
    "torchao",
    "lm_eval",
}


def is_main_process() -> bool:
    """
    Check whether we're running in the main process.

    Note:
        We're using this function instead of `torch.utils.distributed.is_main_process`
        causes issues with DeepSpeed world_size since. This function avoids that issue
        by checking env vars that are set by various launchers.

    Returns:
        Whether we're running in the main process.
    """
    # If PyTorch distributed is already initialized, use it
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank() == 0

    # Otherwise check environment variables for global rank
    # NOTE: need to verify this in SLURM / OpenMPI environments
    global_rank = int(
        os.environ.get(
            "RANK",
            os.environ.get(
                "GLOBAL_RANK",
                os.environ.get(
                    "SLURM_PROCID",
                    os.environ.get(
                        "OMPI_COMM_WORLD_RANK",
                        "0",
                    ),
                ),
            ),
        )
    )

    return global_rank == 0


class TelemetryManager:
    """Manages telemetry collection and transmission"""

    _instance = None
    _initialized = False

    def __new__(cls):
        """
        Telemetry manager constructor. Creates the singleton instance of this class if
        it doesn't already exist.
        """
        if cls._instance is None:
            cls._instance = super(TelemetryManager, cls).__new__(cls)
            cls._instance._initialized = False

        return cls._instance

    def __init__(self):
        """Telemetry manager initializer"""
        if self._initialized:
            return

        self.enabled, self.explicit_enable = self._check_telemetry_enabled()

        if self.enabled:
            self.run_id = str(uuid.uuid4())
            self.whitelist = self._load_whitelist()

            try:
                self.system_info = self._get_system_info()
            except Exception as e:  # pylint: disable=broad-exception-caught
                LOG.warning(f"Error during system info collection: {e}")
                self.system_info = None

            self._init_posthog()

            # Register shutdown method to flush posthog telemetry
            atexit.register(self.shutdown)

        self._initialized = True

    @classmethod
    def get_instance(cls) -> "TelemetryManager":
        if cls._instance is None:
            cls._instance = TelemetryManager()

        return cls._instance

    def _check_telemetry_enabled(self) -> tuple[bool, bool]:
        """
        Check if telemetry is enabled based on environment variables. We also check
        whether this is the main process (for the distributed setting and to avoid
        sending duplicate PostHog events per GPU).

        Note: This is enabled by default on an opt-out basis. Set either
        `AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1` to disable telemetry. For more
        details, see https://axolotl-ai-cloud.github.io/axolotl/docs/telemetry.html.

        Returns:
            Tuple containing:
                - Boolean denoting whether telemetry is enabled or disabled.
                - Boolean denoting whether telemetry is explicitly enabled or not.
        """
        # Parse relevant env vars and fill opt-out default values
        axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
        do_not_track = os.getenv("DO_NOT_TRACK")

        # If explicitly enabled, we'll disable the telemetry warning message
        explicit_enabled = axolotl_do_not_track in ["0", "false"]

        if axolotl_do_not_track is None:
            axolotl_do_not_track = "0"

        if do_not_track is None:
            do_not_track = "0"

        # Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
        enabled = axolotl_do_not_track.lower() not in (
            "1",
            "true",
        ) and do_not_track.lower() not in ("1", "true")

        # Show warning (and sleep on all ranks) unless explicitly enabled
        if enabled and not explicit_enabled:
            if is_main_process():
                LOG.warning(ENABLED_WARNING)
            time.sleep(ENABLED_WARNING_SLEEP_SECONDS)

        # Only rank 0 will send telemetry
        if not is_main_process():
            return False, False

        return enabled, explicit_enabled

    def _load_whitelist(self) -> dict:
        """Load HuggingFace Hub organization whitelist"""
        with open(WHITELIST_PATH, encoding="utf-8") as f:
            return yaml.safe_load(f)

    def _is_whitelisted(self, base_model: str) -> bool:
        """Check if model org is in whitelist"""
        if not base_model:
            return False

        base_model = base_model.lower()
        return any(
            org.lower() in base_model for org in self.whitelist.get("organizations", [])
        )

    def _init_posthog(self):
        """Initialize PostHog client"""
        posthog.host = POSTHOG_HOST
        posthog.project_api_key = POSTHOG_WRITE_KEY

    def _redact_paths(self, properties: dict[str, Any]) -> dict[str, Any]:
        """
        Redact properties to remove any paths, so as to avoid inadvertently collecting
        private or personally identifiable information (PII). We also remove
        information related to Wandb, MLflow, etc. configuration.

        Args:
            properties: Dictionary of properties to redact.

        Returns:
            Properties dictionary with redaction applied.
        """
        if not properties:
            return {}

        def redact_value(value: Any, key: str = "") -> Any:
            """Recursively sanitize values, redacting those with path-like keys"""
            if isinstance(key, str) and isinstance(value, str):
                # Fields that should be redacted if org is not whitelisted
                if key in FIELDS_WITH_ORGS:
                    org = value.split("/")[0]
                    if org not in self.whitelist["organizations"]:
                        return "[REDACTED]"

                # Other redaction special cases
                if (
                    key in FIELDS_TO_REDACT
                    or any(prefix in key for prefix in PREFIXES_TO_REDACT)
                    or any(indicator in key.lower() for indicator in PATH_INDICATORS)
                ):
                    return "[REDACTED]"

            # Handle nested structures
            if isinstance(value, dict):
                return {k: redact_value(v, k) for k, v in value.items()}
            if isinstance(value, list):
                return [redact_value(item) for item in value]

            return value

        # Create new dict with redacted values
        redacted = {k: redact_value(v, k) for k, v in properties.items()}

        return redacted

    def _get_system_info(self) -> dict[str, Any]:
        """Collect system information for various hardware accelerators"""
        gpu_info = []
        accelerator_type = "none"

        # NVIDIA GPUs
        if torch.cuda.is_available():
            accelerator_type = "cuda"
            for i in range(torch.cuda.device_count()):
                gpu_info.append(
                    {
                        "name": torch.cuda.get_device_name(i),
                        "memory": torch.cuda.get_device_properties(i).total_memory,
                    }
                )

        # AMD GPUs
        elif hasattr(torch, "hip") and torch.hip.is_available():
            accelerator_type = "hip"
            for i in range(torch.hip.device_count()):
                gpu_info.append(
                    {
                        "name": torch.hip.get_device_name(i),
                        "memory": torch.hip.get_device_properties(i).total_memory
                        if hasattr(torch.hip, "get_device_properties")
                        else None,
                    }
                )

        # Apple Silicon
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            accelerator_type = "mps"
            gpu_info.append(
                {
                    "name": "Apple Silicon",
                    # NOTE: this is memory allocated to this process, not total memory
                    "memory": torch.mps.driver_allocated_memory(),
                }
            )

        # Intel GPUs
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            accelerator_type = "xpu"
            for i in range(torch.xpu.device_count()):
                memory = None
                if hasattr(torch.xpu, "get_device_properties"):
                    memory = torch.xpu.get_device_properties(i).total_memory

                gpu_info.append(
                    {
                        "name": torch.xpu.get_device_name(i),
                        "memory": memory,
                    }
                )

        # NPUs
        elif hasattr(torch, "npu") and torch.npu.is_available():
            accelerator_type = "npu"
            for i in range(torch.npu.device_count()):
                memory = None
                if hasattr(torch.npu, "get_device_properties"):
                    memory = torch.npu.get_device_properties(i).total_memory

                gpu_info.append(
                    {
                        "name": torch.npu.get_device_name(i),
                        "memory": memory,
                    }
                )

        # Get relevant package versions
        installed_packages = {}
        for package in RELEVANT_PACKAGES:
            try:
                version = importlib.metadata.version(package)
                installed_packages[f"{package}_version"] = version
            except importlib.metadata.PackageNotFoundError:
                pass

        return {
            "os": platform.system(),
            "python_version": platform.python_version(),
            "cpu_count": psutil.cpu_count(),
            "memory_total": psutil.virtual_memory().total,
            "accelerator_type": accelerator_type,
            "accelerator_count": len(gpu_info),
            "accelerator_info": gpu_info,
            **installed_packages,
        }

    def send_event(self, event_type: str, properties: dict[str, Any] | None = None):
        """Send a telemetry event"""
        if not self.enabled:
            return

        if properties is None:
            properties = {}

        # Sanitize properties to remove PII
        properties = self._redact_paths(properties)

        # Wrap PostHog errors in try / except to not raise errors during Axolotl usage
        try:
            # Send event via PostHog
            posthog.capture(
                distinct_id=self.run_id,
                event=event_type,
                properties=properties,
                disable_geoip=True,
            )
        except Exception as e:  # pylint: disable=broad-exception-caught
            LOG.warning(f"Failed to send telemetry event: {e}")

        # Additionally, send system info telemetry when loading config.
        # NOTE: Is this the best place for this?
        if event_type == "config-loaded":
            self.send_system_info()

    def send_system_info(self):
        """Helper method for sending system info"""
        self.send_event(event_type="system-info", properties=self.system_info)

    def shutdown(self):
        """Ensure all queued events are processed before shutdown"""
        if self.enabled:
            posthog.flush()
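For context, a sketch of how the removed manager was driven by callers (the event type and property here are hypothetical but follow the names used in the code above):

```python
from axolotl.telemetry.manager import TelemetryManager

manager = TelemetryManager.get_instance()  # singleton; respects *_DO_NOT_TRACK env vars
manager.send_event(
    event_type="config-loaded",
    properties={"base_model": "meta-llama/Meta-Llama-3.1-8B"},  # example property
)
# Properties pass through _redact_paths() before being sent to PostHog,
# and "config-loaded" additionally triggers a one-off "system-info" event.
```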
@@ -1,209 +0,0 @@
"""Telemetry utilities for runtime and memory metrics."""

import logging
import time
from dataclasses import dataclass, field
from typing import Any

import psutil
import torch

from axolotl.telemetry.manager import TelemetryManager

LOG = logging.getLogger(__name__)


@dataclass
class RuntimeMetrics:
    """Container for runtime metrics to be tracked throughout training."""

    # Timing metrics
    start_time: float
    epoch_start_times: dict[int, float] = field(init=False)
    epoch_end_times: dict[int, float] = field(init=False)

    # Memory metrics
    peak_cpu_memory: int = 0
    peak_gpu_memory: dict[int, int] = field(init=False)

    # Progress metrics
    total_steps: int = 0
    current_epoch: int = 0
    current_step: int = 0

    def __post_init__(self):
        """Initialize empty metric mappings."""
        self.epoch_start_times = {}
        self.epoch_end_times = {}
        self.peak_gpu_memory = {}

    @property
    def elapsed_time(self) -> float:
        """Calculate total elapsed time in seconds."""
        return time.time() - self.start_time

    def epoch_time(self, epoch: int) -> float | None:
        """Calculate time taken for a specific epoch in seconds."""
        if epoch in self.epoch_start_times and epoch in self.epoch_end_times:
            return self.epoch_end_times[epoch] - self.epoch_start_times[epoch]

        return None

    def average_epoch_time(self) -> float | None:
        """Calculate average time per epoch in seconds."""
        completed_epochs = [
            epoch for epoch in self.epoch_start_times if epoch in self.epoch_end_times
        ]
        if not completed_epochs:
            return None

        total_time = 0.0
        for epoch in completed_epochs:
            epoch_time = self.epoch_time(epoch)
            if epoch_time is not None:  # Check to avoid mypy warning
                total_time += epoch_time

        return total_time / len(completed_epochs)

    def steps_per_second(self) -> float | None:
        """Calculate average steps per second across all training."""
        if self.total_steps == 0 or self.elapsed_time == 0:
            return None

        return self.total_steps / self.elapsed_time

    def to_dict(self) -> dict[str, Any]:
        """Convert metrics to a dictionary for telemetry reporting."""
        metrics = {
            "total_time_seconds": self.elapsed_time,
            "total_steps": self.total_steps,
            "steps_per_second": self.steps_per_second(),
            "epochs_completed": len(
                [
                    epoch
                    for epoch in self.epoch_start_times
                    if epoch in self.epoch_end_times
                ]
            ),
            "peak_cpu_memory_bytes": self.peak_cpu_memory,
        }

        # Add per-epoch timing if available
        epoch_times: dict[str, float] = {}
        for epoch in sorted(self.epoch_end_times.keys()):
            time_taken = self.epoch_time(epoch)
            if time_taken is not None:
                epoch_times[f"epoch_{epoch}_seconds"] = time_taken

        if epoch_times:
            metrics["epoch_times"] = epoch_times  # type: ignore
            metrics["average_epoch_time_seconds"] = self.average_epoch_time()

        # Add GPU memory metrics if available
        if self.peak_gpu_memory:
            gpu_metrics: dict[str, int] = {}
            for gpu_id, memory in self.peak_gpu_memory.items():
                gpu_metrics[f"gpu_{gpu_id}_peak_memory_bytes"] = memory
            metrics["gpu_memory"] = gpu_metrics  # type: ignore

        return metrics


class RuntimeMetricsTracker:
    """Tracker for runtime metrics during training."""

    update_interval = 100

    def __init__(self):
        """Initialize the runtime metrics tracker."""
        self.metrics = RuntimeMetrics(start_time=time.time())
        self.telemetry_manager = TelemetryManager.get_instance()

    def start_epoch(self, epoch: int):
        """Record the start of a new epoch."""
        self.metrics.current_epoch = epoch
        self.metrics.epoch_start_times[epoch] = time.time()
        self.update_memory_metrics()

    def end_epoch(self, epoch: int):
        """Record the end of an epoch."""
        self.metrics.epoch_end_times[epoch] = time.time()

    def update_step(self, step: int):
        """Update the current step count."""
        self.metrics.current_step = step
        self.metrics.total_steps += 1

        # Periodically update memory metrics
        if step % self.update_interval == 0:
            self.update_memory_metrics()

    def _get_allocated_memory(self) -> dict[int, int]:
        """
        Helper function for getting accelerator-agnostic allocated memory.

        Returns:
            A dictionary mapping device IDs to allocated memory in bytes
        """
        memory_used: dict[int, int] = {}

        # NVIDIA GPUs
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                memory_used[i] = torch.cuda.memory_allocated(i)

        # AMD GPUs
        elif hasattr(torch, "hip") and torch.hip.is_available():
            for i in range(torch.hip.device_count()):
                if hasattr(torch.hip, "memory_allocated"):
                    memory_used[i] = torch.hip.memory_allocated(i)

        # Apple Silicon
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            # MPS doesn't have per-device memory stats since there's only one device
            if hasattr(torch.mps, "current_allocated_memory"):
                memory_used[0] = torch.mps.current_allocated_memory()

        # Intel GPUs
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            for i in range(torch.xpu.device_count()):
                if hasattr(torch.xpu, "memory_allocated"):
                    memory_used[i] = torch.xpu.memory_allocated(i)

        # NPUs
        elif hasattr(torch, "npu") and torch.npu.is_available():
            for i in range(torch.npu.device_count()):
                if hasattr(torch.npu, "memory_allocated"):
                    memory_used[i] = torch.npu.memory_allocated(i)

        return memory_used

    def update_memory_metrics(self):
        """Update peak memory usage metrics."""
        # CPU memory
        cpu_memory = psutil.Process().memory_info().rss
        self.metrics.peak_cpu_memory = max(self.metrics.peak_cpu_memory, cpu_memory)

        # GPU memory (if available)
        memory_used = self._get_allocated_memory()
        for i, memory in memory_used.items():
            self.metrics.peak_gpu_memory[i] = max(
                self.metrics.peak_gpu_memory.get(i, 0), memory
            )

    def get_memory_metrics(self) -> dict[str, Any]:
        """Get the current memory metrics as a dictionary."""
        memory_metrics = {
            "cpu_memory_bytes": psutil.Process().memory_info().rss,
            "peak_cpu_memory_bytes": self.metrics.peak_cpu_memory,
        }

        # GPU memory (if available)
        memory_used = self._get_allocated_memory()
        for i, memory in memory_used.items():
            memory_metrics[f"gpu_{i}_memory_bytes"] = memory
            memory_metrics[
                f"gpu_{i}_peak_memory_bytes"
            ] = self.metrics.peak_gpu_memory.get(i, 0)

        return memory_metrics
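For context, a short sketch of how the removed tracker was exercised by the `TelemetryCallback` shown earlier in this diff (standalone usage like this is illustrative, not from the PR):

```python
from axolotl.telemetry.runtime_metrics import RuntimeMetricsTracker

tracker = RuntimeMetricsTracker()
tracker.start_epoch(0)
for step in range(1, 201):
    tracker.update_step(step)  # refreshes memory metrics every `update_interval` steps
tracker.end_epoch(0)

print(tracker.metrics.to_dict())     # runtime, steps/sec, peak CPU/GPU memory
print(tracker.get_memory_metrics())  # current and peak memory per device
```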
@@ -1,18 +0,0 @@
organizations:
  - "axolotl-ai-co"
  - "meta-llama"
  - "huggingface"
  - "nvidia"
  - "facebook"
  - "google"
  - "microsoft"
  - "deepseek-ai"
  - "HuggingFaceTB"
  - "mistralai"
  - "Qwen"
  - "briaai"
  - "unsloth"
  - "NousResearch"
  - "allenai"
  - "amd"
  - "tiiuae"
@@ -1,6 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""

import importlib
import inspect
import os
import signal
@@ -14,6 +13,7 @@ import transformers.modelcard
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from peft import PeftModel
from pkg_resources import get_distribution  # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

@@ -22,8 +22,6 @@ from axolotl.contribs.lgpl.unsloth import (  # pylint: disable = no-name-in-modu
    fix_untrained_tokens,
)
from axolotl.logging_config import configure_logging
from axolotl.telemetry.errors import send_errors
from axolotl.telemetry.manager import TelemetryManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
@@ -41,10 +39,7 @@ sys.path.insert(0, src_dir)
configure_logging()
LOG = get_logger(__name__)

TELEMETRY_MANAGER = TelemetryManager.get_instance()


@send_errors
def train(
    *, cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
@@ -80,7 +75,7 @@ def train(
    )
    resume_from_checkpoint = cfg.resume_from_checkpoint

    # Load model
    # Load the model and tokenizer
    msg = "loading model"
    if cfg.adapter:
        msg += " and peft_config..."
@@ -89,14 +84,6 @@ def train(
    if model.generation_config is not None:
        model.generation_config.do_sample = True

    TELEMETRY_MANAGER.send_event(
        event_type="model-loaded", properties=model.config.to_dict()
    )
    if peft_config:
        TELEMETRY_MANAGER.send_event(
            event_type="peft-config-loaded", properties=peft_config.to_dict()
        )

    model_ref = None
    if cfg.rl and cfg.rl != "orpo":
        if cfg.adapter and not cfg.rl_adapter_ref_model:
@@ -104,7 +91,7 @@ def train(
            LOG.debug("Passing model_ref: None to RL trainer")
            model_ref = None  # explicit setting to None
        else:
            # load the model again for model_ref / baseline
            # load the model again for model_ref/baseline
            model_ref, _ = load_model(cfg, tokenizer, reference_model=True)

    safe_serialization = cfg.save_safetensors is True
@@ -179,7 +166,7 @@ def train(

    if getattr(cfg, "axolotl_config_path"):
        raw_axolotl_cfg = Path(cfg.axolotl_config_path)
        version = importlib.metadata.version("axolotl")
        version = get_distribution("axolotl").version
        if raw_axolotl_cfg.is_file():
            transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"

@@ -187,6 +174,8 @@ def train(
    if cfg.group_by_length:
        LOG.info("hang tight... sorting dataset for group_by_length")

    pretrain_hooks(cfg, trainer)

    if cfg.flash_optimum:
        with torch.backends.cuda.sdp_kernel(
            # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
@@ -197,6 +186,9 @@ def train(
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    else:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    post_train_hooks(cfg, trainer)

    LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

    # post training
@@ -300,3 +292,21 @@ def train(
        trainer.push_to_hub()

    return model, tokenizer


def pretrain_hooks(_cfg, _trainer):
    """
    Run hooks right before kicking off the training
    :param cfg:
    :param trainer:
    :return:
    """


def post_train_hooks(_cfg, _trainer):
    """
    Run hooks right after training completes
    :param cfg:
    :param trainer:
    :return:
    """
File diff suppressed because one or more lines are too long
@@ -55,7 +55,6 @@ class ChatTemplate(str, Enum):
    phi_3 = "phi_3"  # pylint: disable=invalid-name
    phi_35 = "phi_35"  # pylint: disable=invalid-name
    deepseek_v2 = "deepseek_v2"  # pylint: disable=invalid-name
    deepseek_v3 = "deepseek_v3"  # pylint: disable=invalid-name
    jamba = "jamba"  # pylint: disable=invalid-name
    jinja = "jinja"  # pylint: disable=invalid-name
    qwen_25 = "qwen_25"  # pylint: disable=invalid-name
@@ -1683,7 +1682,7 @@ class AxolotlInputConfig(


class AxolotlConfigWCapabilities(AxolotlInputConfig):
    """Wrapper to validate GPU capabilities with the config options"""
    """wrapper to valdiate gpu capabilities with the configured options"""

    capabilities: GPUCapabilities
    env_capabilities: EnvCapabilities

@@ -54,7 +54,6 @@ from axolotl.monkeypatch.multipack import (
    patch_for_multipack,
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.telemetry.errors import send_errors
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
@@ -166,7 +165,6 @@ def load_model_config(cfg):
    return model_config


@send_errors
def load_tokenizer(cfg):
    model_config = load_model_config(cfg)
    tokenizer_kwargs = {}
@@ -320,7 +318,6 @@ def load_tokenizer(cfg):
    return tokenizer


@send_errors
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
    processor_kwargs: Dict[str, Any] = {}  # do we actually need this?

@@ -1195,17 +1192,18 @@ class ModelLoader:
        return self.model, lora_config


@send_errors
def load_model(
    cfg: DictDefault,
    tokenizer: PreTrainedTokenizerBase,
    *,
    processor: ProcessorMixin = None,
    processor: ProcessorMixin = None,  # pylint: disable=unused-argument
    inference: bool = False,
    reference_model: bool = False,
    **kwargs,
) -> Tuple[PreTrainedModel, PeftConfig | None]:
    """Load a model for a given configuration and tokenizer"""
    **kwargs,  # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
    """
    Load a model for a given configuration and tokenizer.
    """
    loader = ModelLoader(
        cfg,
        tokenizer,
@@ -1217,7 +1215,6 @@ def load_model(
    return loader.load_model()


@send_errors
def load_adapter(model, cfg, adapter, inference=False):
    # type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]

194 styles.css
@@ -1,193 +1,5 @@
/* TYPOGRAPHY SECTION */
/* css styles */

/* Import fonts */
@import url('https://fonts.googleapis.com/css2?family=Be+Vietnam+Pro:wght@400;500&display=swap');
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400&display=swap');

/* Typography hierarchy */
:root {
  --font-title: 'Be Vietnam Pro', sans-serif;
  --font-body: 'JetBrains Mono', monospace;
}

/* Title (h1) */
h1 {
  font-family: var(--font-title);
  font-weight: 400;
  font-size: 6rem;
  line-height: 1.1;
  letter-spacing: -0.05em;
  font-feature-settings: "ss01" on;
}

/* Heading (h2) */
h2 {
  font-family: var(--font-title);
  font-weight: 500;
  font-size: 2rem;
  line-height: 1.2;
  letter-spacing: -0.03em;
  font-feature-settings: "ss01" on;
}

/* Subtitle/Preamble */
h3,
h4 {
  font-family: var(--font-body);
  font-weight: 400;
  font-size: 1.5rem;
  line-height: 1.5;
  letter-spacing: -0.02em;
}

/* Body text */
body {
  font-family: var(--font-body);
  font-weight: 400;
  font-size: 1rem;
  line-height: 1.5;
  letter-spacing: -0.02em;
}

/* Links */
a {
  font-family: var(--font-body);
  font-weight: 400;
  font-size: 0.875rem;
  line-height: 1;
  letter-spacing: -0.02em;
}

/* NAV BAR SECTION */

/* Navbar logo styling */
.navbar-brand img {
  height: 32px;
  margin-right: 10px;
}

/* COLORS SECTION */

/* Brand colors */
:root {
  --white: #ffffff;
  --greige-300: #EEEEE7;
  --greige-600: #CCCAC0;
  --black: #141310;
  --lime: #E3F8A8;
  --cyan: #A0F4EA;
  --purple: #C8D0F8;
}

/* Base styles */
body {
  background-color: var(--black);
  color: var(--greige-300);
}

/* Navigation */
.navbar {
  background-color: var(--black) !important;
}

.navbar-dark .navbar-nav .nav-link {
  color: var(--greige-300);
}

.navbar-dark .navbar-nav .nav-link:hover {
  color: var(--lime);
}

/* Sidebar */
.sidebar-navigation {
  background-color: var(--black);
  border-right: 1px solid var(--greige-600);
}

.sidebar nav[role="doc-toc"] ul>li>a {
  color: var(--greige-300);
}

.sidebar nav[role="doc-toc"] ul>li>a:hover {
  color: var(--lime);
}

/* Links */
a {
  color: var(--lime);
}

a:hover {
  color: var(--cyan);
}

/* Headers */
h1,
h2,
h3,
h4,
h5,
h6 {
  color: var(--white);
}

/* Code blocks */
pre {
  background-color: #1a1a1a !important;
  border: 1px solid var(--greige-600);
}

/* Tables */
.table {
  color: var(--greige-300);
}

/* TOC */
#toc-title {
  color: var(--white);
}

.toc-active {
  color: var(--lime) !important;
}

/* Buttons */
.btn-primary {
  background-color: var(--lime);
  color: var(--black);
  border: none;
}

.btn-primary:hover {
  background-color: var(--cyan);
  color: var(--black);
}

/* For inline code (single backtick) */
code {
  background-color: #1a1a1a !important;
  color: var(--lime) !important;
  padding: 2px 4px;
  border-radius: 4px;
}

/* For inline code that is also a link */
a code {
  color: var(--cyan) !important;
}

/* For code blocks (triple backtick) */
pre.sourceCode {
  background-color: #1a1a1a !important;
}

/* Make comments in bash/shell scripts green */
code span.co {
  color: #5cb85c !important;
}

/* Remove underlines from JSON comments and make them green */
code span.er {
  color: #5cb85c !important;
  text-decoration: none !important;
img[alt="Axolotl"] {
  content: url("https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg") !important;
}

@@ -1,5 +1,6 @@
"""Shared pytest fixtures"""

"""
shared pytest fixtures
"""
import functools
import importlib
import shutil
@@ -170,9 +171,3 @@ def cleanup_monkeypatches():
        module_globals = module_name_tuple[1]
        for module_global in module_globals:
            globals().pop(module_global, None)


@pytest.fixture(autouse=True)
def disable_telemetry(monkeypatch):
    monkeypatch.setenv("AXOLOTL_DO_NOT_TRACK", "1")
    yield

@@ -90,6 +90,12 @@ class TestKnowledgeDistillation:
         check_tensorboard(
             temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
         )
+        check_tensorboard(
+            temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
+        )
+        check_tensorboard(
+            temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
+        )

     @pytest.mark.parametrize(
         "load_in_8bit",
@@ -121,3 +127,9 @@ class TestKnowledgeDistillation:
         check_tensorboard(
             temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
         )
+        check_tensorboard(
+            temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
+        )
+        check_tensorboard(
+            temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
+        )
tests/e2e/integrations/test_kl_loss.py (new file, 163 lines)
@@ -0,0 +1,163 @@
"""
sanity checks on kl loss and gradients
"""
import torch

# Import both implementations
from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss


def test_kl_loss_gradient():
    """Test that the gradient of the Triton implementation matches the eager implementation."""

    # Set the random seed for reproducibility
    torch.manual_seed(42)

    # Create random inputs
    batch_size = 2
    seq_len = 3
    vocab_size = 100
    top_k = 5

    # Generate random student logits
    student_logits = torch.randn(
        batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
    )
    student_logits_triton = student_logits.detach().clone().requires_grad_(True)

    # Generate random target token IDs, ensuring they're valid indices
    # pylint: disable=duplicate-code
    target_token_ids = torch.randint(
        0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
    )

    # Generate random target logprobs (before normalization)
    target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")

    # Normalize the target logprobs to ensure they form a valid distribution
    target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)

    # Create a random mask with some tokens masked out
    target_mask = torch.randint(
        0, 2, (batch_size, seq_len, top_k), device="cuda"
    ).float()

    # Additional parameters
    num_items_in_batch = batch_size * seq_len
    kd_temperature = 1.0
    top_k_before_softmax = 0  # Test both modes

    # Compute the loss and gradients with eager implementation
    loss_eager = eager_loss(
        student_logits,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch,
        kd_temperature,
        top_k_before_softmax,
    )
    loss_eager.backward()
    grad_eager = student_logits.grad.clone()

    # Reset gradients
    student_logits.grad.zero_()

    # Compute the loss and gradients with Triton implementation
    loss_triton = triton_loss(
        student_logits_triton,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch,
        kd_temperature,
        top_k_before_softmax,
    )
    loss_triton.backward()
    grad_triton = student_logits_triton.grad.clone()

    # Compare loss values
    print(f"Eager loss: {loss_eager.item()}")
    print(f"Triton loss: {loss_triton.item()}")
    loss_diff = abs(loss_eager.item() - loss_triton.item())
    print(f"Loss difference: {loss_diff}")
    assert loss_diff < 1e-5, "Loss values differ significantly!"

    # Compare gradients
    grad_diff = (grad_eager - grad_triton).abs().max().item()
    print(f"Max gradient difference: {grad_diff}")

    # Print some sample gradients
    sample_idx = (0, 0, 0)  # (batch, seq, vocab)
    print(f"Sample eager gradient: {grad_eager[sample_idx].item()}")
    print(f"Sample triton gradient: {grad_triton[sample_idx].item()}")

    # Compute relative difference for non-zero gradients
    mask = grad_eager.abs() > 1e-10
    if mask.sum() > 0:
        rel_diff = (
            (
                (grad_eager[mask] - grad_triton[mask]).abs()
                / (grad_eager[mask].abs() + 1e-10)
            )
            .max()
            .item()
        )
        print(f"Max relative gradient difference: {rel_diff}")
        assert rel_diff < 1e-3, "Gradients differ significantly!"

    # Also test top_k_before_softmax = 1 mode
    top_k_before_softmax = 1

    # Reset the gradients
    student_logits = torch.randn(
        batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
    )
    student_logits_triton = student_logits.detach().clone().requires_grad_(True)

    # Compute the loss and gradients with eager implementation
    loss_eager = eager_loss(
        student_logits,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch,
        kd_temperature,
        top_k_before_softmax,
    )
    loss_eager.backward()
    grad_eager = student_logits.grad.clone()

    # Compute the loss and gradients with Triton implementation
    loss_triton = triton_loss(
        student_logits_triton,
        target_token_ids,
        target_logprobs,
        target_mask,
        num_items_in_batch,
        kd_temperature,
        top_k_before_softmax,
    )
    loss_triton.backward()
    grad_triton = student_logits_triton.grad.clone()

    # Compare gradients for top_k_before_softmax = 1
    grad_diff = (grad_eager - grad_triton).abs().max().item()
    print("\nWith top_k_before_softmax=1:")
    print(f"Max gradient difference: {grad_diff}")

    # Compute relative difference for non-zero gradients
    mask = grad_eager.abs() > 1e-10
    if mask.sum() > 0:
        rel_diff = (
            (
                (grad_eager[mask] - grad_triton[mask]).abs()
                / (grad_eager[mask].abs() + 1e-10)
            )
            .max()
            .item()
        )
        assert (
            rel_diff < 1e-3
        ), f"Gradients differ significantly, Max relative gradient difference: {rel_diff}"
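The test above only checks that the eager and Triton paths agree; neither implementation is spelled out in this diff. For orientation, here is a minimal eager sketch of a top-k forward-KL distillation loss with the same argument shapes the test constructs. This is an illustrative assumption, not axolotl's actual `forward_kl` code: the name `reference_topk_forward_kl` is hypothetical, and the real loss additionally handles the `top_k_before_softmax` mode and may normalize differently.

```python
import torch
import torch.nn.functional as F


def reference_topk_forward_kl(
    student_logits,      # (batch, seq, vocab) raw student logits
    target_token_ids,    # (batch, seq, k) teacher top-k token ids
    target_logprobs,     # (batch, seq, k) teacher log-probs over its top-k
    target_mask,         # (batch, seq, k) 1.0 where the entry should count
    num_items_in_batch,  # normalizer, e.g. number of supervised tokens
    kd_temperature=1.0,
):
    # Student log-probs over the full vocabulary, temperature-scaled
    student_logprobs = F.log_softmax(student_logits / kd_temperature, dim=-1)
    # Pick out the student log-probs at the teacher's top-k token ids
    student_topk = torch.gather(student_logprobs, dim=-1, index=target_token_ids)
    # Forward KL restricted to the teacher's top-k support:
    # sum_k p_teacher * (log p_teacher - log p_student)
    teacher_probs = target_logprobs.exp()
    kl = (teacher_probs * (target_logprobs - student_topk)) * target_mask
    return kl.sum() / num_items_in_batch
```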
tests/e2e/integrations/test_logsumexp.py (new file, 204 lines)
@@ -0,0 +1,204 @@
"""
sanity checks on logsumexp kernel validity
"""
import torch
import triton

from axolotl.integrations.kd.topk_logprob.logsumexp import logsumexp_kernel


# PyTorch implementation of logsumexp for reference
def torch_logsumexp(logits):
    """PyTorch implementation of logsumexp over last dimension"""
    return torch.logsumexp(logits, dim=-1)


# Wrapper function for Triton logsumexp kernel
def triton_logsumexp(logits):
    """Triton implementation of logsumexp over last dimension"""
    B, S, V = logits.shape  # pylint: disable=invalid-name
    output = torch.empty((B, S), dtype=torch.float32, device=logits.device)

    grid = (B * S,)
    logsumexp_kernel[grid](
        logits.contiguous(),
        output,
        B,
        S,
        V,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        output.stride(0),
        output.stride(1),
        min(1024, triton.next_power_of_2(V)),
    )

    return output


class TritonLogSumExp(torch.autograd.Function):
    """
    Wrap a custom autograd function to use the Triton logsumexp for gradient testing
    """

    @staticmethod
    def forward(ctx, logits):
        B, S, V = logits.shape  # pylint: disable=invalid-name
        output = torch.empty((B, S), dtype=torch.float32, device=logits.device)

        # Save inputs for backward pass
        ctx.save_for_backward(logits)
        ctx.shape = logits.shape

        grid = (B * S,)
        logsumexp_kernel[grid](
            logits.contiguous(),
            output,
            B,
            S,
            V,
            logits.stride(0),
            logits.stride(1),
            logits.stride(2),
            output.stride(0),
            output.stride(1),
            min(1024, triton.next_power_of_2(V)),
        )

        return output

    @staticmethod
    def backward(ctx, grad_output):
        (logits,) = ctx.saved_tensors

        # For logsumexp, the gradient is softmax(input) * grad_output
        # First compute the logsumexp
        lse = TritonLogSumExp.apply(logits)

        # Compute softmax by exponentiating differences
        softmax_output = torch.exp(logits - lse.unsqueeze(-1))

        # Compute gradient of logsumexp by multiplying the softmax output by the gradient
        grad_input = softmax_output * grad_output.unsqueeze(-1)

        return grad_input


def test_logsumexp_values():
    """Test that the Triton logsumexp implementation matches PyTorch's"""
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Test with various input shapes
    test_shapes = [
        (2, 3, 10),  # small vocab
        (4, 5, 100),  # medium vocab
        (2, 2, 32000),  # large vocab (typical for LLMs)
    ]

    for shape in test_shapes:
        # Create random input tensors
        logits = torch.randn(shape, device="cuda", requires_grad=False)

        # Compute logsumexp using both implementations
        torch_result = torch_logsumexp(logits)
        triton_result = triton_logsumexp(logits)

        # Compare results
        max_diff = (torch_result - triton_result).abs().max().item()
        print(f"Shape {shape}, Max diff: {max_diff}")

        # Assert that the results are very close
        assert max_diff < 1e-5, f"Results differ for shape {shape}: max diff {max_diff}"


def test_logsumexp_edge_cases():
    """Test edge cases for numerical stability"""
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Case 1: Very large values that might cause overflow
    logits_large = torch.ones(2, 3, 100, device="cuda") * 1000

    # Case 2: Very small values that might cause underflow
    logits_small = torch.ones(2, 3, 100, device="cuda") * -1000

    # Case 3: Mix of large and small values
    logits_mixed = torch.zeros(2, 3, 100, device="cuda")
    logits_mixed[:, :, 0] = 1000  # One very large value

    # Case 4: All identical values
    logits_identical = torch.ones(2, 3, 100, device="cuda") * 5

    # Case 5: Extreme values with NaN check
    logits_extreme = torch.cat(
        [
            torch.full((1, 3, 50), 1e10, device="cuda"),
            torch.full((1, 3, 50), -1e10, device="cuda"),
        ],
        dim=0,
    )

    for i, logits in enumerate(
        [logits_large, logits_small, logits_mixed, logits_identical, logits_extreme]
    ):
        # Compute logsumexp using both implementations
        torch_result = torch_logsumexp(logits)
        triton_result = triton_logsumexp(logits)

        # Check for NaNs
        assert not torch.isnan(
            torch_result
        ).any(), f"PyTorch produced NaNs for case {i+1}"
        assert not torch.isnan(
            triton_result
        ).any(), f"Triton produced NaNs for case {i+1}"

        # Compare results
        max_diff = (torch_result - triton_result).abs().max().item()
        print(f"Edge case {i+1}, Max diff: {max_diff}")

        # For very extreme values, allow a bit more tolerance
        if i == 4:  # extreme case
            assert max_diff < 1e-2, f"Results differ too much for edge case {i+1}"
        else:
            assert max_diff < 1e-5, f"Results differ too much for edge case {i+1}"


def test_logsumexp_gradients():
    """Test that the gradients of Triton logsumexp match PyTorch's"""
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Create input tensors with gradients enabled
    shapes = [(2, 3, 10), (4, 5, 100)]

    for shape in shapes:
        # Create two identical tensors for PyTorch and Triton
        logits_torch = torch.randn(shape, device="cuda", requires_grad=True)
        logits_triton = logits_torch.clone().detach().requires_grad_(True)

        # Forward pass
        torch_output = torch_logsumexp(logits_torch)
        triton_output = TritonLogSumExp.apply(logits_triton)

        # Compare forward pass values
        max_diff_forward = (torch_output - triton_output).abs().max().item()
        assert max_diff_forward < 1e-5, f"Forward pass values differ for shape {shape}"

        # Create random gradient
        grad_output = torch.randn_like(torch_output)

        # Backward pass
        torch_output.backward(grad_output)
        triton_output.backward(grad_output)

        # Compare gradients
        max_diff_grad = (logits_torch.grad - logits_triton.grad).abs().max().item()
        print(f"Shape {shape}, Max gradient diff: {max_diff_grad}")

        # Assert that gradients are very close
        assert (
            max_diff_grad < 1e-5
        ), f"Gradients differ for shape {shape}: max diff {max_diff_grad}"
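The custom `backward` in `TritonLogSumExp` relies on the standard gradient identity for logsumexp, which is why it can rebuild a softmax from the kernel's forward output instead of saving extra state:

```latex
\frac{\partial}{\partial x_i} \log \sum_j e^{x_j}
  = \frac{e^{x_i}}{\sum_j e^{x_j}}
  = \operatorname{softmax}(x)_i,
\qquad
\frac{\partial}{\partial x_i}\bigl(g \cdot \operatorname{logsumexp}(x)\bigr)
  = g \cdot \operatorname{softmax}(x)_i .
```

The line `softmax_output * grad_output.unsqueeze(-1)` is exactly this identity broadcast over the vocabulary dimension.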
@@ -1,130 +0,0 @@
|
||||
"""
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestDeepseekV3:
|
||||
"""
|
||||
Test case for DeepseekV3 models
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_packing",
|
||||
[True, False],
|
||||
)
|
||||
def test_lora_deepseekv3(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/DeepSeek-V3-11M",
|
||||
"trust_remote_code": True,
|
||||
"sample_packing": sample_packing,
|
||||
"flash_attention": True,
|
||||
"sequence_len": 2048,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mlabonne/FineTome-100k",
|
||||
"type": "chat_template",
|
||||
"field_messages": "conversations",
|
||||
"message_property_mappings": {
|
||||
"role": "from",
|
||||
"content": "value",
|
||||
},
|
||||
"drop_system_message": True,
|
||||
"split": "train[:1%]",
|
||||
},
|
||||
],
|
||||
"special_tokens": {
|
||||
"bos_token": "<|begin▁of▁sentence|>",
|
||||
"eos_token": "<|end▁of▁sentence|>",
|
||||
},
|
||||
"chat_template": "deepseek_v3",
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_packing",
|
||||
[True, False],
|
||||
)
|
||||
def test_fft_deepseekv3(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/DeepSeek-V3-11M",
|
||||
"trust_remote_code": True,
|
||||
"sample_packing": sample_packing,
|
||||
"flash_attention": True,
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mlabonne/FineTome-100k",
|
||||
"type": "chat_template",
|
||||
"field_messages": "conversations",
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
"split": "train[:1%]",
|
||||
},
|
||||
],
|
||||
"chat_template": "deepseek_v3",
|
||||
"special_tokens": {
|
||||
"bos_token": "<|begin▁of▁sentence|>",
|
||||
"eos_token": "<|end▁of▁sentence|>",
|
||||
},
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
@@ -102,7 +102,11 @@ def is_hopper():


 def check_tensorboard(
-    temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
+    temp_run_dir: str,
+    tag: str,
+    comparison_val: float,
+    assertion_err: str,
+    lt: bool = True,
 ) -> None:
     """
     helper function to parse and check tensorboard logs
@@ -112,10 +116,20 @@ def check_tensorboard(
     reader = SummaryReader(event_file)
     df = reader.scalars  # pylint: disable=invalid-name
     df = df[(df.tag == tag)]  # pylint: disable=invalid-name
-    if "%s" in assertion_err:
-        assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
+    if lt:
+        if "%s" in assertion_err:
+            assert df.value.values[-1] < comparison_val, (
+                assertion_err % df.value.values[-1]
+            )
+        else:
+            assert df.value.values[-1] < comparison_val, assertion_err
     else:
-        assert df.value.values[-1] < lt_val, assertion_err
+        if "%s" in assertion_err:
+            assert df.value.values[-1] > comparison_val, (
+                assertion_err % df.value.values[-1]
+            )
+        else:
+            assert df.value.values[-1] > comparison_val, assertion_err


 def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
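With this change, `check_tensorboard` takes a generic `comparison_val` plus an `lt` flag instead of a pure upper bound. The knowledge-distillation tests earlier in this diff exercise both modes; a condensed usage sketch (assuming the helper and a `temp_dir` are in scope, as in those tests):

```python
# Upper bound (default, lt=True): final train/loss must end up below 1.0
check_tensorboard(temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high")

# Lower bound (lt=False): final train/loss must end up above 0.0,
# guarding against a degenerate, all-zero loss curve
check_tensorboard(
    temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
)
```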
@@ -1,9 +0,0 @@
-"""Shared pytest fixtures for telemetry tests."""
-
-import pytest
-
-
-@pytest.fixture(autouse=True)
-def disable_telemetry(monkeypatch):
-    monkeypatch.delenv("AXOLOTL_DO_NOT_TRACK", raising=False)
-    yield
@@ -1,372 +0,0 @@
|
||||
"""Tests for telemetry callback module."""
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from transformers import TrainerControl, TrainerState, TrainingArguments
|
||||
|
||||
from axolotl.telemetry.callbacks import TIME_SINCE_LAST, TelemetryCallback
|
||||
|
||||
|
||||
def calc_expected_metrics(step, last_step, current_time, last_time, start_time=900.0):
|
||||
"""Calculate expected metrics values for tests"""
|
||||
time_diff = current_time - last_time
|
||||
step_diff = step - last_step
|
||||
return {
|
||||
"steps_per_second": step_diff / time_diff
|
||||
if time_diff > 0 and step_diff > 0
|
||||
else 0,
|
||||
"time_since_last_report": time_diff,
|
||||
"elapsed_time": current_time - start_time,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_time():
|
||||
"""Mock time.time() to have predictable values in tests"""
|
||||
with patch("axolotl.telemetry.callbacks.time") as mock_time:
|
||||
mock_time.time.return_value = 1000.0
|
||||
yield mock_time
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_telemetry_manager():
|
||||
"""Create a mock TelemetryManager"""
|
||||
with patch("axolotl.telemetry.callbacks.TelemetryManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
yield mock_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runtime_metrics_tracker():
|
||||
"""Create a mock RuntimeMetricsTracker"""
|
||||
with patch(
|
||||
"axolotl.telemetry.callbacks.RuntimeMetricsTracker"
|
||||
) as mock_tracker_class:
|
||||
mock_tracker = MagicMock()
|
||||
# Set up metrics property on the tracker
|
||||
mock_metrics = MagicMock()
|
||||
mock_metrics.to_dict.return_value = {
|
||||
"total_steps": 100,
|
||||
"peak_cpu_memory_bytes": 1024,
|
||||
}
|
||||
mock_tracker.metrics = mock_metrics
|
||||
|
||||
# Make the constructor return our mock
|
||||
mock_tracker_class.return_value = mock_tracker
|
||||
yield mock_tracker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def training_args():
|
||||
"""Create a minimal TrainingArguments instance"""
|
||||
return TrainingArguments(output_dir="./output")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trainer_state():
|
||||
"""Create a mock TrainerState"""
|
||||
state = MagicMock(spec=TrainerState)
|
||||
state.global_step = 10
|
||||
state.epoch = 0.5 # halfway through first epoch
|
||||
state.log_history = [{"loss": 2.5, "learning_rate": 5e-5}]
|
||||
return state
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trainer_control():
|
||||
"""Create a mock TrainerControl"""
|
||||
return MagicMock(spec=TrainerControl)
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@pytest.fixture
|
||||
def callback(mock_telemetry_manager, mock_runtime_metrics_tracker):
|
||||
"""Create a TelemetryCallback instance with mocked dependencies"""
|
||||
return TelemetryCallback()
|
||||
|
||||
|
||||
class TestTelemetryCallback:
|
||||
"""Tests for the TelemetryCallback class."""
|
||||
|
||||
def test_initialization(self, callback, mock_runtime_metrics_tracker):
|
||||
"""Test callback initialization."""
|
||||
assert callback.current_epoch == -1
|
||||
assert callback.tracker == mock_runtime_metrics_tracker
|
||||
assert callback.last_report_step == 0
|
||||
assert hasattr(callback, "start_time")
|
||||
assert hasattr(callback, "last_report_time")
|
||||
assert callback.report_interval_steps == 100
|
||||
|
||||
def test_on_train_begin(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_train_begin sends expected event."""
|
||||
callback.on_train_begin(training_args, trainer_state, trainer_control)
|
||||
|
||||
mock_telemetry_manager.send_event.assert_called_once_with(
|
||||
event_type="train-started"
|
||||
)
|
||||
|
||||
def test_on_train_end(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_train_end sends expected event with metrics."""
|
||||
callback.on_train_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
|
||||
assert call_args["event_type"] == "train-ended"
|
||||
assert "loss" in call_args["properties"]
|
||||
assert call_args["properties"]["loss"] == 2.5
|
||||
assert "learning_rate" in call_args["properties"]
|
||||
assert call_args["properties"]["learning_rate"] == 5e-5
|
||||
|
||||
# Check that metrics from RuntimeMetricsTracker are included
|
||||
assert "total_steps" in call_args["properties"]
|
||||
assert call_args["properties"]["total_steps"] == 100
|
||||
assert "peak_cpu_memory_bytes" in call_args["properties"]
|
||||
assert call_args["properties"]["peak_cpu_memory_bytes"] == 1024
|
||||
|
||||
def test_on_epoch_begin(
|
||||
self,
|
||||
callback,
|
||||
mock_runtime_metrics_tracker,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_epoch_begin updates epoch counter and calls tracker."""
|
||||
initial_epoch = callback.current_epoch
|
||||
|
||||
callback.on_epoch_begin(training_args, trainer_state, trainer_control)
|
||||
|
||||
assert callback.current_epoch == initial_epoch + 1
|
||||
mock_runtime_metrics_tracker.start_epoch.assert_called_once_with(
|
||||
initial_epoch + 1
|
||||
)
|
||||
|
||||
def test_on_epoch_end(
|
||||
self,
|
||||
callback,
|
||||
mock_runtime_metrics_tracker,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_epoch_end calls tracker."""
|
||||
# Set current epoch
|
||||
callback.current_epoch = 2
|
||||
|
||||
callback.on_epoch_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
mock_runtime_metrics_tracker.end_epoch.assert_called_once_with(2)
|
||||
|
||||
def test_on_step_end_no_report(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end updates tracker but doesn't report if criteria not met."""
|
||||
# Set up state to avoid reporting
|
||||
trainer_state.global_step = 42 # Not divisible by report_interval_steps
|
||||
callback.last_report_step = 41 # Just 1 step since last report
|
||||
callback.last_report_time = time.time() # Just now
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should update tracker
|
||||
mock_runtime_metrics_tracker.update_step.assert_called_once_with(42)
|
||||
|
||||
# Should not send telemetry
|
||||
mock_telemetry_manager.send_event.assert_not_called()
|
||||
|
||||
# Should not update last report time/step
|
||||
assert callback.last_report_step == 41
|
||||
|
||||
def test_on_step_end_report_interval_steps(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker,
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end reports when step interval is reached."""
|
||||
# Set up state with clear values
|
||||
current_step = 100 # Exactly matches report_interval_steps
|
||||
last_step = 0
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
time_diff = current_time - start_time # 100 seconds
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
callback.report_interval_steps = 100
|
||||
callback.last_report_step = last_step
|
||||
callback.start_time = start_time
|
||||
callback.last_report_time = start_time
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should update tracker
|
||||
mock_runtime_metrics_tracker.update_step.assert_called_once_with(current_step)
|
||||
mock_runtime_metrics_tracker.update_memory_metrics.assert_called_once()
|
||||
|
||||
# Should send telemetry
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
assert call_args["event_type"] == "train-progress"
|
||||
|
||||
# Properties should include expected values
|
||||
props = call_args["properties"]
|
||||
assert props["step"] == current_step
|
||||
assert props["elapsed_time"] == time_diff # 1000 - 900 = 100
|
||||
assert props["time_since_last_report"] == time_diff # 1000 - 900 = 100
|
||||
assert props["steps_per_second"] == 1.0 # 100 steps / 100 seconds
|
||||
|
||||
# Should update last report time/step
|
||||
assert callback.last_report_step == current_step
|
||||
assert callback.last_report_time == current_time
|
||||
|
||||
def test_on_step_end_report_time_elapsed(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end reports when enough time has elapsed."""
|
||||
# Set up state with clear values
|
||||
current_step = 120
|
||||
last_step = 10
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
time_diff = TIME_SINCE_LAST + 1 # Just over the threshold
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
callback.report_interval_steps = 100
|
||||
callback.last_report_step = last_step
|
||||
callback.start_time = start_time
|
||||
callback.last_report_time = current_time - time_diff
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should send telemetry
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Properties should include expected values
|
||||
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
|
||||
expected_metrics = calc_expected_metrics(
|
||||
current_step, last_step, current_time, current_time - time_diff, start_time
|
||||
)
|
||||
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
|
||||
assert (
|
||||
props["time_since_last_report"]
|
||||
== expected_metrics["time_since_last_report"]
|
||||
)
|
||||
|
||||
def test_on_step_end_first_step(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test on_step_end always reports on first step."""
|
||||
# Set up state with clear values
|
||||
current_step = 1 # First step
|
||||
last_step = 0
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
last_report_time = 999.0 # Just 1 second ago
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
callback.report_interval_steps = 100
|
||||
callback.last_report_step = last_step
|
||||
callback.start_time = start_time
|
||||
callback.last_report_time = last_report_time
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should send telemetry even though not much time has passed
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Properties should include expected values for first step
|
||||
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
|
||||
assert props["step"] == current_step
|
||||
expected_metrics = calc_expected_metrics(
|
||||
current_step, last_step, current_time, last_report_time, start_time
|
||||
)
|
||||
assert props["steps_per_second"] == expected_metrics["steps_per_second"]
|
||||
|
||||
def test_log_history_empty(
|
||||
self,
|
||||
callback,
|
||||
mock_telemetry_manager,
|
||||
mock_runtime_metrics_tracker, # pylint: disable=unused-argument
|
||||
mock_time,
|
||||
training_args,
|
||||
trainer_state,
|
||||
trainer_control,
|
||||
):
|
||||
"""Test handling of empty log history."""
|
||||
# Set up state with clear values
|
||||
current_step = 1
|
||||
start_time = 900.0
|
||||
current_time = 1000.0
|
||||
|
||||
# Configure state and callback
|
||||
trainer_state.global_step = current_step
|
||||
trainer_state.log_history = []
|
||||
callback.start_time = start_time
|
||||
|
||||
# Mock time.time() to return consistent values
|
||||
mock_time.time.return_value = current_time
|
||||
|
||||
callback.on_step_end(training_args, trainer_state, trainer_control)
|
||||
|
||||
# Should still send telemetry
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Properties should have default values for missing log data
|
||||
props = mock_telemetry_manager.send_event.call_args[1]["properties"]
|
||||
assert props["loss"] == 0
|
||||
assert props["learning_rate"] == 0
|
||||
@@ -1,340 +0,0 @@
|
||||
"""Tests for telemetry error utilities"""
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.telemetry.errors import sanitize_stack_trace, send_errors
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_error_flag(monkeypatch):
|
||||
"""Reset ERROR_HANDLED flag using monkeypatch"""
|
||||
import axolotl.telemetry.errors
|
||||
|
||||
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
|
||||
yield
|
||||
monkeypatch.setattr(axolotl.telemetry.errors, "ERROR_HANDLED", False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_stack_trace():
|
||||
"""Provide a sample stack trace with mixed paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
|
||||
trainer = get_trainer(cfg)
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/train.py", line 214, in get_trainer
|
||||
model = get_model(cfg, tokenizer)
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/models.py", line 120, in get_model
|
||||
raise ValueError("Model path not found")
|
||||
ValueError: Model path not found
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def windows_stack_trace():
|
||||
"""Provide a sample stack trace with Windows paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\cli\\train.py", line 83, in main
|
||||
trainer = get_trainer(cfg)
|
||||
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\axolotl\\train.py", line 214, in get_trainer
|
||||
model = get_model(cfg, tokenizer)
|
||||
File "C:\\Users\\name\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\transformers\\models\\auto\\modeling_auto.py", line 482, in from_pretrained
|
||||
raise ValueError(f"Unrecognized configuration class {config.__class__}")
|
||||
ValueError: Unrecognized configuration class <class 'transformers.models.llama.configuration_llama.LlamaConfig'>
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixed_stack_trace():
|
||||
"""Provide a sample stack trace with both axolotl and non-axolotl paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/cli/train.py", line 83, in main
|
||||
trainer = get_trainer(cfg)
|
||||
File "/home/user/.local/lib/python3.9/site-packages/transformers/trainer.py", line 520, in train
|
||||
self._inner_training_loop()
|
||||
File "/home/user/.local/lib/python3.9/site-packages/axolotl/utils/trainer.py", line 75, in _inner_training_loop
|
||||
super()._inner_training_loop()
|
||||
File "/home/user/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
|
||||
data = self._next_data()
|
||||
RuntimeError: CUDA out of memory
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def venv_stack_trace():
|
||||
"""Provide a sample stack trace with virtual environment paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 1729, in train
|
||||
self._inner_training_loop()
|
||||
File "/home/user/venv/lib/python3.9/site-packages/transformers/trainer.py", line 2013, in _inner_training_loop
|
||||
self.accelerator.backward(loss)
|
||||
File "/home/user/venv/lib/python3.9/site-packages/accelerate/accelerator.py", line 1851, in backward
|
||||
self.scaler.scale(loss).backward(**kwargs)
|
||||
File "/home/user/venv/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
|
||||
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
|
||||
RuntimeError: CUDA out of memory
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dist_packages_stack_trace():
|
||||
"""Provide a sample stack trace with dist-packages paths"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__
|
||||
data = self._next_data()
|
||||
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 675, in _next_data
|
||||
data = self._dataset_fetcher.fetch(index)
|
||||
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
|
||||
data = [self.dataset[idx] for idx in possibly_batched_index]
|
||||
File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2808, in __getitem__
|
||||
raise IndexError(f"Index {key} out of range for dataset of length {len(self)}.")
|
||||
IndexError: Index 10000 out of range for dataset of length 9832.
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project_stack_trace():
|
||||
"""Provide a sample stack trace from a project directory (not a virtual env)"""
|
||||
return """Traceback (most recent call last):
|
||||
File "/home/user/projects/myproject/run.py", line 25, in <module>
|
||||
main()
|
||||
File "/home/user/projects/myproject/src/cli.py", line 45, in main
|
||||
app.run()
|
||||
File "/home/user/projects/myproject/src/app.py", line 102, in run
|
||||
raise ValueError("Configuration missing")
|
||||
ValueError: Configuration missing
|
||||
"""
|
||||
|
||||
|
||||
def test_sanitize_stack_trace(example_stack_trace):
|
||||
"""Test that sanitize_stack_trace properly preserves axolotl paths"""
|
||||
sanitized = sanitize_stack_trace(example_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "/home/user" not in sanitized
|
||||
assert ".local/lib/python3.9" not in sanitized
|
||||
|
||||
# Check that site-packages is preserved
|
||||
assert "site-packages/axolotl/cli/train.py" in sanitized
|
||||
assert "site-packages/axolotl/train.py" in sanitized
|
||||
assert "site-packages/axolotl/utils/models.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "ValueError: Model path not found" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_windows_paths(windows_stack_trace):
|
||||
"""Test that sanitize_stack_trace handles Windows paths"""
|
||||
sanitized = sanitize_stack_trace(windows_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "C:\\Users\\name" not in sanitized
|
||||
assert "AppData\\Local\\Programs\\Python" not in sanitized
|
||||
|
||||
# Check that both axolotl and transformers packages are preserved
|
||||
assert (
|
||||
"site-packages\\axolotl\\cli\\train.py" in sanitized
|
||||
or "site-packages/axolotl/cli/train.py" in sanitized
|
||||
)
|
||||
assert (
|
||||
"site-packages\\axolotl\\train.py" in sanitized
|
||||
or "site-packages/axolotl/train.py" in sanitized
|
||||
)
|
||||
assert (
|
||||
"site-packages\\transformers\\models\\auto\\modeling_auto.py" in sanitized
|
||||
or "site-packages/transformers/models/auto/modeling_auto.py" in sanitized
|
||||
)
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "ValueError: Unrecognized configuration class" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_mixed_paths(mixed_stack_trace):
|
||||
"""Test that sanitize_stack_trace preserves all package paths"""
|
||||
sanitized = sanitize_stack_trace(mixed_stack_trace)
|
||||
|
||||
# Check that all package paths are preserved
|
||||
assert "site-packages/axolotl/cli/train.py" in sanitized
|
||||
assert "site-packages/transformers/trainer.py" in sanitized
|
||||
assert "site-packages/axolotl/utils/trainer.py" in sanitized
|
||||
assert "site-packages/torch/utils/data/dataloader.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "RuntimeError: CUDA out of memory" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_venv_paths(venv_stack_trace):
|
||||
"""Test that sanitize_stack_trace preserves virtual environment package paths"""
|
||||
sanitized = sanitize_stack_trace(venv_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "/home/user/venv" not in sanitized
|
||||
|
||||
# Check that all package paths are preserved
|
||||
assert "site-packages/transformers/trainer.py" in sanitized
|
||||
assert "site-packages/accelerate/accelerator.py" in sanitized
|
||||
assert "site-packages/torch/_tensor.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "RuntimeError: CUDA out of memory" in sanitized
|
||||
|
||||
|
||||
def test_sanitize_dist_packages(dist_packages_stack_trace):
|
||||
"""Test that sanitize_stack_trace preserves dist-packages paths"""
|
||||
sanitized = sanitize_stack_trace(dist_packages_stack_trace)
|
||||
|
||||
# Check that system paths are removed
|
||||
assert "/usr/local/lib/python3.8" not in sanitized
|
||||
|
||||
# Check that all package paths are preserved
|
||||
assert "dist-packages/torch/utils/data/dataloader.py" in sanitized
|
||||
assert "dist-packages/torch/utils/data/_utils/fetch.py" in sanitized
|
||||
assert "dist-packages/datasets/arrow_dataset.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert (
|
||||
"IndexError: Index 10000 out of range for dataset of length 9832." in sanitized
|
||||
)
|
||||
|
||||
|
||||
def test_sanitize_project_paths(project_stack_trace):
|
||||
"""Test handling of project paths (non-virtual env)"""
|
||||
sanitized = sanitize_stack_trace(project_stack_trace)
|
||||
|
||||
# Check that personal paths are removed
|
||||
assert "/home/user/projects" not in sanitized
|
||||
|
||||
# For non-package paths, we should at least preserve the filename
|
||||
assert "run.py" in sanitized
|
||||
assert "cli.py" in sanitized
|
||||
assert "app.py" in sanitized
|
||||
|
||||
# Check that error message is preserved
|
||||
assert "ValueError: Configuration missing" in sanitized
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_telemetry_manager():
|
||||
"""Create a mock TelemetryManager"""
|
||||
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.enabled = True
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
yield mock_manager
|
||||
|
||||
|
||||
def test_send_errors_successful_execution(mock_telemetry_manager):
|
||||
"""Test that send_errors doesn't send telemetry for successful function execution"""
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
return "success"
|
||||
|
||||
result = test_func()
|
||||
assert result == "success"
|
||||
mock_telemetry_manager.send_event.assert_not_called()
|
||||
|
||||
|
||||
def test_send_errors_with_exception(mock_telemetry_manager):
|
||||
"""Test that send_errors sends telemetry when an exception occurs"""
|
||||
test_error = ValueError("Test error")
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise test_error
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
test_func()
|
||||
|
||||
assert excinfo.value == test_error
|
||||
mock_telemetry_manager.send_event.assert_called_once()
|
||||
|
||||
# Check that the error info was passed correctly
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
assert "test_func-errored" in call_args["event_type"]
|
||||
assert "Test error" in call_args["properties"]["exception"]
|
||||
assert "stack_trace" in call_args["properties"]
|
||||
|
||||
|
||||
def test_send_errors_nested_calls(mock_telemetry_manager):
|
||||
"""Test that send_errors only sends telemetry once for nested decorated functions"""
|
||||
|
||||
@send_errors
|
||||
def inner_func():
|
||||
raise ValueError("Inner error")
|
||||
|
||||
@send_errors
|
||||
def outer_func():
|
||||
return inner_func()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
outer_func()
|
||||
|
||||
# Telemetry should be sent only once for the inner function
|
||||
assert mock_telemetry_manager.send_event.call_count == 1
|
||||
call_args = mock_telemetry_manager.send_event.call_args[1]
|
||||
assert "inner_func-error" in call_args["event_type"]
|
||||
|
||||
|
||||
def test_send_errors_telemetry_disable():
|
||||
"""Test that send_errors doesn't attempt to send telemetry when disabled"""
|
||||
|
||||
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.enabled = False
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise ValueError("Test error")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
|
||||
mock_manager.send_event.assert_not_called()
|
||||
|
||||
|
||||
def test_error_handled_reset():
|
||||
"""Test that ERROR_HANDLED flag is properly reset"""
|
||||
with patch("axolotl.telemetry.errors.TelemetryManager") as mock_manager_class:
|
||||
# Create and configure the mock manager
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.enabled = True
|
||||
mock_manager_class.get_instance.return_value = mock_manager
|
||||
|
||||
from axolotl.telemetry.errors import ERROR_HANDLED
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise ValueError("Test error")
|
||||
|
||||
assert not ERROR_HANDLED
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
|
||||
from axolotl.telemetry.errors import ERROR_HANDLED
|
||||
|
||||
assert ERROR_HANDLED
|
||||
|
||||
|
||||
def test_module_path_resolution(mock_telemetry_manager):
|
||||
"""Test that the module path is correctly resolved for the event type"""
|
||||
import inspect
|
||||
|
||||
current_module = inspect.getmodule(test_module_path_resolution).__name__
|
||||
|
||||
@send_errors
|
||||
def test_func():
|
||||
raise ValueError("Test error")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
test_func()
|
||||
|
||||
assert mock_telemetry_manager.send_event.called
|
||||
event_type = mock_telemetry_manager.send_event.call_args[1]["event_type"]
|
||||
|
||||
expected_event_type = f"{current_module}.test_func-errored"
|
||||
assert expected_event_type == event_type
|
||||
@@ -1,245 +0,0 @@
|
||||
"""Tests for TelemetryManager class and utilities"""
|
||||
# pylint: disable=redefined-outer-name,protected-access
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from axolotl.telemetry.manager import TelemetryManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_whitelist(tmp_path):
|
||||
"""Create a temporary whitelist file for testing"""
|
||||
whitelist_content = {
|
||||
"organizations": ["meta-llama", "mistralai"],
|
||||
}
|
||||
whitelist_file = tmp_path / "whitelist.yaml"
|
||||
with open(whitelist_file, "w", encoding="utf-8") as f:
|
||||
yaml.dump(whitelist_content, f)
|
||||
|
||||
return str(whitelist_file)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def telemetry_manager_class():
|
||||
"""Reset the TelemetryManager singleton between tests"""
|
||||
original_instance = TelemetryManager._instance
|
||||
original_initialized = TelemetryManager._initialized
|
||||
TelemetryManager._instance = None
|
||||
TelemetryManager._initialized = False
|
||||
yield TelemetryManager
|
||||
TelemetryManager._instance = original_instance
|
||||
TelemetryManager._initialized = original_initialized
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager(telemetry_manager_class, mock_whitelist):
|
||||
"""Create a TelemetryManager instance with mocked dependencies"""
|
||||
with patch("posthog.capture"), patch("posthog.flush"), patch("time.sleep"), patch(
|
||||
"axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist
|
||||
), patch.dict(os.environ, {"RANK": "0"}):
|
||||
manager = telemetry_manager_class()
|
||||
# Manually enable for most tests
|
||||
manager.enabled = True
|
||||
return manager
|
||||
|
||||
|
||||
def test_singleton_instance(telemetry_manager_class):
|
||||
"""Test that TelemetryManager is a singleton"""
|
||||
with patch("posthog.capture"), patch("time.sleep"), patch.dict(
|
||||
os.environ, {"RANK": "0"}
|
||||
):
|
||||
first = telemetry_manager_class()
|
||||
second = telemetry_manager_class()
|
||||
assert first is second
|
||||
assert telemetry_manager_class.get_instance() is first
|
||||
|
||||
|
||||
def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
|
||||
"""Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
|
||||
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}):
|
||||
manager = telemetry_manager_class()
|
||||
assert not manager.enabled
|
||||
|
||||
|
||||
def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
|
||||
"""Test that telemetry is disabled when DO_NOT_TRACK=1"""
|
||||
with patch.dict(os.environ, {"DO_NOT_TRACK": "1", "RANK": "0"}):
|
||||
manager = telemetry_manager_class()
|
||||
assert not manager.enabled
|
||||
|
||||
|
||||
def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
|
||||
"""Test that telemetry is disabled for non-main processes"""
|
||||
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "1"}):
|
||||
manager = telemetry_manager_class()
|
||||
assert not manager.enabled
|
||||
|
||||
|
||||
def test_telemetry_enabled_by_default(telemetry_manager_class):
|
||||
"""Test that telemetry is enabled by default"""
|
||||
with patch.dict(os.environ, {"RANK": "0"}, clear=True), patch("time.sleep"), patch(
|
||||
"logging.Logger.warning"
|
||||
):
|
||||
manager = telemetry_manager_class()
|
||||
assert manager.enabled
|
||||
assert not manager.explicit_enable
|
||||
|
||||
|
||||
def test_explicit_enable_disables_warning(telemetry_manager_class):
|
||||
"""Test that explicit enabling prevents warning"""
|
||||
with patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}), patch(
|
||||
"logging.Logger.warning"
|
||||
) as mock_warning, patch("time.sleep"):
|
||||
manager = telemetry_manager_class()
|
||||
assert manager.enabled
|
||||
assert manager.explicit_enable
|
||||
for call in mock_warning.call_args_list:
|
||||
assert "Telemetry is enabled" not in str(call)
|
||||
|
||||
|
||||
def test_warning_displayed_for_implicit_enable(telemetry_manager_class):
|
||||
"""Test that warning is displayed when telemetry is implicitly enabled"""
|
||||
with patch.dict(os.environ, {"RANK": "0"}, clear=True), patch(
|
||||
"logging.Logger.warning"
|
||||
) as mock_warning, patch("time.sleep"):
|
||||
manager = telemetry_manager_class()
|
||||
assert manager.enabled
|
||||
assert not manager.explicit_enable
|
||||
warning_displayed = False
|
||||
for call in mock_warning.call_args_list:
|
||||
if "Telemetry is enabled" in str(call):
|
||||
warning_displayed = True
|
||||
break
|
||||
assert warning_displayed
|
||||
|
||||
|
||||
def test_is_whitelisted(manager, mock_whitelist):
|
||||
"""Test org whitelist functionality"""
|
||||
with patch("axolotl.telemetry.manager.WHITELIST_PATH", mock_whitelist):
|
||||
# Should match organizations from the mock whitelist
|
||||
assert manager._is_whitelisted("meta-llama/llama-7b")
|
||||
assert manager._is_whitelisted("mistralai/mistral-7b-instruct")
|
||||
# Should not match
|
||||
assert not manager._is_whitelisted("unknown/model")
|
||||
# Should handle case insensitively
|
||||
assert manager._is_whitelisted("META-LLAMA/Llama-7B")
|
||||
# Should handle empty input
|
||||
assert not manager._is_whitelisted("")
|
||||
assert not manager._is_whitelisted(None)
|
||||
|
||||
|
||||
def test_system_info_collection(manager):
|
||||
"""Test system information collection"""
|
||||
system_info = manager._get_system_info()
|
||||
|
||||
# Check essential keys
|
||||
assert "os" in system_info
|
||||
assert "python_version" in system_info
|
||||
assert "torch_version" in system_info
|
||||
assert "transformers_version" in system_info
|
||||
assert "axolotl_version" in system_info
|
||||
assert "cpu_count" in system_info
|
||||
assert "memory_total" in system_info
|
||||
assert "accelerator_count" in system_info
|
||||
|
||||
|
||||
def test_send_event(manager):
|
||||
"""Test basic event sending"""
|
||||
with patch("posthog.capture") as mock_capture:
|
||||
# Test with clean properties (no PII)
|
||||
manager.send_event("test_event", {"key": "value"})
|
||||
assert mock_capture.called
|
||||
assert mock_capture.call_args[1]["event"] == "test_event"
|
||||
assert mock_capture.call_args[1]["properties"] == {"key": "value"}
|
||||
assert mock_capture.call_args[1]["distinct_id"] == manager.run_id
|
||||
|
||||
# Test with default properties (None)
|
||||
mock_capture.reset_mock()
|
||||
manager.send_event("simple_event")
|
||||
assert mock_capture.called
|
||||
assert mock_capture.call_args[1]["properties"] == {}
|
||||
|
||||
|
||||
def test_send_system_info(manager):
|
||||
"""Test sending system info"""
|
||||
with patch("posthog.capture") as mock_capture:
|
||||
manager.send_system_info()
|
||||
assert mock_capture.called
|
||||
assert mock_capture.call_args[1]["event"] == "system-info"
|
||||
assert mock_capture.call_args[1]["properties"] == manager.system_info
|
||||
|
||||
|
||||
def test_redacted_properties(manager):
|
||||
"""Test path redaction in send_event method"""
|
||||
with patch("posthog.capture") as mock_capture:
|
||||
# Test with properties containing various paths and non-paths
|
||||
test_properties = {
|
||||
"filepath": "/home/user/sensitive/data.txt",
|
||||
"windows_path": "C:\\Users\\name\\Documents\\project\\file.py",
|
||||
"output_dir": "/var/lib/data",
|
||||
"path_to_model": "models/llama/7b",
|
||||
"message": "Training started", # Should not be redacted
|
||||
"metrics": {"loss": 0.5, "accuracy": 0.95}, # Should not be redacted
|
||||
"base_model": "models/local_model",
|
||||
"nested": {
|
||||
"model_path": "/models/my_model",
|
||||
"root_dir": "/home/user/projects",
|
||||
"stats": {"steps": 1000, "epochs": 3}, # Should not be redacted
|
||||
},
|
||||
}
|
||||
|
||||
manager.send_event("test_event", test_properties)
|
||||
|
||||
# Verify the call was made
|
||||
assert mock_capture.called
|
||||
|
||||
# Get the sanitized properties that were sent
|
||||
sanitized = mock_capture.call_args[1]["properties"]
|
||||
|
||||
# Check that path-like and base_model keys were redacted
|
||||
assert sanitized["filepath"] == "[REDACTED]"
|
||||
assert sanitized["windows_path"] == "[REDACTED]"
|
||||
assert sanitized["path_to_model"] == "[REDACTED]"
|
||||
assert sanitized["base_model"] == "[REDACTED]"
|
||||
|
||||
# Check that non-path values were preserved
|
||||
assert sanitized["message"] == "Training started"
|
||||
assert sanitized["metrics"] == {"loss": 0.5, "accuracy": 0.95}
|
||||
|
||||
# Check nested structure handling
|
||||
assert sanitized["nested"]["model_path"] == "[REDACTED]"
|
||||
assert sanitized["nested"]["root_dir"] == "[REDACTED]"
|
||||
assert sanitized["nested"]["stats"] == {"steps": 1000, "epochs": 3}
|
||||
|
||||
|
||||
def test_disable_telemetry(manager):
|
||||
"""Test that disabled telemetry doesn't send events"""
|
||||
with patch("posthog.capture") as mock_capture:
|
||||
manager.enabled = False
|
||||
manager.send_event("test_event")
|
||||
assert not mock_capture.called
|
||||
|
||||
|
||||
def test_exception_handling_during_send(manager):
|
||||
"""Test that exceptions in PostHog are handled gracefully"""
|
||||
with patch("posthog.capture", side_effect=Exception("Test error")), patch(
|
||||
"logging.Logger.warning"
|
||||
) as mock_warning:
|
||||
manager.send_event("test_event")
|
||||
warning_logged = False
|
||||
for call in mock_warning.call_args_list:
|
||||
if "Failed to send telemetry event" in str(call):
|
||||
warning_logged = True
|
||||
break
|
||||
assert warning_logged
|
||||
|
||||
|
||||
def test_shutdown(manager):
|
||||
"""Test shutdown behavior"""
|
||||
with patch("posthog.flush") as mock_flush:
|
||||
manager.shutdown()
|
||||
assert mock_flush.called
|
||||
@@ -1,356 +0,0 @@
"""Tests for runtime metrics telemetry module"""
# pylint: disable=redefined-outer-name

from unittest.mock import MagicMock, patch

import pytest

from axolotl.telemetry.runtime_metrics import RuntimeMetrics, RuntimeMetricsTracker


@pytest.fixture
def mock_time():
    """Mock time.time() to have predictable values in tests"""
    with patch("time.time") as mock_time:
        # Start with time 1000.0 and increment by 10 seconds on each call
        times = [1000.0 + i * 10 for i in range(10)]
        mock_time.side_effect = times
        yield mock_time


@pytest.fixture
def mock_telemetry_manager():
    """Create a mock TelemetryManager"""
    with patch(
        "axolotl.telemetry.runtime_metrics.TelemetryManager"
    ) as mock_manager_class:
        mock_manager = MagicMock()
        mock_manager.enabled = True
        mock_manager_class.get_instance.return_value = mock_manager
        yield mock_manager


@pytest.fixture
def mock_psutil():
    """Mock psutil for memory information"""
    with patch("axolotl.telemetry.runtime_metrics.psutil") as mock_psutil:
        mock_process = MagicMock()
        mock_memory_info = MagicMock()
        # Set initial memory to 1GB
        mock_memory_info.rss = 1024 * 1024 * 1024
        mock_process.memory_info.return_value = mock_memory_info
        mock_psutil.Process.return_value = mock_process
        yield mock_psutil


@pytest.fixture
def mock_torch():
    """Mock torch.cuda functions"""
    with patch("axolotl.telemetry.runtime_metrics.torch") as mock_torch:
        mock_torch.cuda.is_available.return_value = True
        mock_torch.cuda.device_count.return_value = 2

        # Mock memory allocated per device (1GB for device 0, 2GB for device 1)
        mock_torch.cuda.memory_allocated.side_effect = (
            lambda device: (device + 1) * 1024 * 1024 * 1024
        )

        yield mock_torch


class TestRuntimeMetrics:
    """Tests for RuntimeMetrics class."""

    def test_initialization(self):
        """Test RuntimeMetrics initialization."""
        metrics = RuntimeMetrics(start_time=1000.0)

        assert metrics.start_time == 1000.0
        assert metrics.epoch_start_times == {}
        assert metrics.epoch_end_times == {}
        assert metrics.peak_gpu_memory == {}
        assert metrics.total_steps == 0
        assert metrics.current_epoch == 0
        assert metrics.current_step == 0
        assert metrics.peak_cpu_memory == 0

    def test_elapsed_time(self, mock_time):
        """Test elapsed_time property."""
        metrics = RuntimeMetrics(start_time=1000.0)

        # Mock time.time() to return 1050.0
        mock_time.side_effect = [1050.0]

        assert metrics.elapsed_time == 50.0

    def test_epoch_time(self):
        """Test epoch_time method."""
        metrics = RuntimeMetrics(start_time=1000.0)

        # No epoch data
        assert metrics.epoch_time(0) is None

        # Add epoch start but no end
        metrics.epoch_start_times[0] = 1000.0
        assert metrics.epoch_time(0) is None

        # Add epoch end
        metrics.epoch_end_times[0] = 1060.0
        assert metrics.epoch_time(0) == 60.0

    def test_average_epoch_time(self):
        """Test average_epoch_time method."""
        metrics = RuntimeMetrics(start_time=1000.0)

        # No completed epochs
        assert metrics.average_epoch_time() is None

        # Add one completed epoch
        metrics.epoch_start_times[0] = 1000.0
        metrics.epoch_end_times[0] = 1060.0
        assert metrics.average_epoch_time() == 60.0

        # Add second completed epoch
        metrics.epoch_start_times[1] = 1060.0
        metrics.epoch_end_times[1] = 1140.0  # 80 seconds
        assert metrics.average_epoch_time() == 70.0  # Average of 60 and 80

        # Add incomplete epoch (should not affect average)
        metrics.epoch_start_times[2] = 1140.0
        assert metrics.average_epoch_time() == 70.0

    def test_steps_per_second(self, mock_time):
        """Test steps_per_second method."""
        metrics = RuntimeMetrics(start_time=1000.0)

        # No steps - first call to time.time()
        mock_time.side_effect = None
        mock_time.return_value = 1050.0
        assert metrics.steps_per_second() is None

        # Add steps - second call to time.time()
        metrics.total_steps = 100
        mock_time.return_value = 1050.0  # Keep same time for consistent result
        assert metrics.steps_per_second() == 2.0  # 100 steps / 50 seconds

    def test_to_dict_basic(self, mock_time):
        """Test to_dict method with basic metrics."""
        metrics = RuntimeMetrics(start_time=1000.0)
        metrics.total_steps = 100
        metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024  # 2GB

        # Mock elapsed_time
        mock_time.side_effect = None
        mock_time.return_value = 1050.0

        result = metrics.to_dict()

        assert result["total_time_seconds"] == 50.0
        assert result["total_steps"] == 100
        assert result["steps_per_second"] == 2.0
        assert result["epochs_completed"] == 0
        assert result["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
        assert "epoch_times" not in result
        assert "gpu_memory" not in result

    def test_to_dict_with_epochs(self, mock_time):
        """Test to_dict method with epoch data."""
        metrics = RuntimeMetrics(start_time=1000.0)
        metrics.total_steps = 100

        # Add epoch data
        metrics.epoch_start_times[0] = 1000.0
        metrics.epoch_end_times[0] = 1060.0
        metrics.epoch_start_times[1] = 1060.0
        metrics.epoch_end_times[1] = 1140.0

        # Mock elapsed_time
        mock_time.side_effect = None
        mock_time.return_value = 1150.0

        result = metrics.to_dict()

        assert "epoch_times" in result
        assert result["epoch_times"]["epoch_0_seconds"] == 60.0
        assert result["epoch_times"]["epoch_1_seconds"] == 80.0
        assert result["average_epoch_time_seconds"] == 70.0

    def test_to_dict_with_gpu_memory(self, mock_time):
        """Test to_dict method with GPU memory data."""
        metrics = RuntimeMetrics(start_time=1000.0)
        metrics.peak_gpu_memory = {
            0: 1 * 1024 * 1024 * 1024,  # 1GB
            1: 2 * 1024 * 1024 * 1024,  # 2GB
        }

        # Mock elapsed_time
        mock_time.side_effect = [1050.0]

        result = metrics.to_dict()

        assert "gpu_memory" in result
        assert result["gpu_memory"]["gpu_0_peak_memory_bytes"] == 1 * 1024 * 1024 * 1024
        assert result["gpu_memory"]["gpu_1_peak_memory_bytes"] == 2 * 1024 * 1024 * 1024


class TestRuntimeMetricsTracker:
    """Tests for RuntimeMetricsTracker class."""

    # pylint: disable=unused-argument
    def test_initialization(self, mock_time, mock_telemetry_manager):
        """Test RuntimeMetricsTracker initialization."""
        tracker = RuntimeMetricsTracker()

        assert isinstance(tracker.metrics, RuntimeMetrics)
        assert tracker.metrics.start_time == 1000.0  # First value from mock_time

    # pylint: disable=unused-argument
    def test_start_epoch(
        self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
    ):
        """Test start_epoch method."""
        tracker = RuntimeMetricsTracker()

        # Reset mock_time to control next value
        mock_time.side_effect = [1010.0]

        tracker.start_epoch(0)

        assert tracker.metrics.current_epoch == 0
        assert tracker.metrics.epoch_start_times[0] == 1010.0

        # Verify memory metrics were updated
        assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
        assert 0 in tracker.metrics.peak_gpu_memory
        assert 1 in tracker.metrics.peak_gpu_memory

    # pylint: disable=unused-argument
    def test_end_epoch(self, mock_time, mock_telemetry_manager):
        """Test end_epoch method."""
        tracker = RuntimeMetricsTracker()

        # Start epoch 0
        mock_time.side_effect = [1010.0]
        tracker.start_epoch(0)

        # End epoch 0
        mock_time.side_effect = [1060.0]
        tracker.end_epoch(0)

        assert 0 in tracker.metrics.epoch_end_times
        assert tracker.metrics.epoch_end_times[0] == 1060.0

    # pylint: disable=unused-argument
    def test_update_step(
        self, mock_time, mock_psutil, mock_torch, mock_telemetry_manager
    ):
        """Test update_step method."""
        tracker = RuntimeMetricsTracker()

        # Update step to a non-multiple of 100
        tracker.update_step(42)

        assert tracker.metrics.current_step == 42
        assert tracker.metrics.total_steps == 1

        # Memory metrics should not be updated for non-multiple of 100
        assert tracker.metrics.peak_cpu_memory == 0

        # Update step to a multiple of 100
        tracker.update_step(100)

        assert tracker.metrics.current_step == 100
        assert tracker.metrics.total_steps == 2

        # Memory metrics should be updated for multiple of 100
        assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024

    # pylint: disable=unused-argument
    def test_update_memory_metrics(
        self, mock_psutil, mock_torch, mock_telemetry_manager
    ):
        """Test update_memory_metrics method."""
        tracker = RuntimeMetricsTracker()

        # Initial memory state
        assert tracker.metrics.peak_cpu_memory == 0
        assert tracker.metrics.peak_gpu_memory == {}

        # Update memory metrics
        tracker.update_memory_metrics()

        # Verify CPU memory
        assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024

        # Verify GPU memory
        assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
        assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024

        # Change mocked memory values to be lower
        mock_process = mock_psutil.Process.return_value
        mock_memory_info = mock_process.memory_info.return_value
        mock_memory_info.rss = 0.5 * 1024 * 1024 * 1024  # 0.5GB

        mock_torch.cuda.memory_allocated.side_effect = (
            lambda device: (device + 0.5) * 1024 * 1024 * 1024
        )

        # Update memory metrics again
        tracker.update_memory_metrics()

        # Peak values should not decrease
        assert tracker.metrics.peak_cpu_memory == 1 * 1024 * 1024 * 1024
        assert tracker.metrics.peak_gpu_memory[0] == 1 * 1024 * 1024 * 1024
        assert tracker.metrics.peak_gpu_memory[1] == 2 * 1024 * 1024 * 1024

        # Change mocked memory values to be higher
        mock_memory_info.rss = 2 * 1024 * 1024 * 1024  # 2GB

        mock_torch.cuda.memory_allocated.side_effect = (
            lambda device: (device + 2) * 1024 * 1024 * 1024
        )

        # Update memory metrics again
        tracker.update_memory_metrics()

        # Peak values should increase
        assert tracker.metrics.peak_cpu_memory == 2 * 1024 * 1024 * 1024
        assert tracker.metrics.peak_gpu_memory[0] == 2 * 1024 * 1024 * 1024
        assert tracker.metrics.peak_gpu_memory[1] == 3 * 1024 * 1024 * 1024

    # pylint: disable=unused-argument
    def test_get_memory_metrics(self, mock_psutil, mock_torch, mock_telemetry_manager):
        """Test get_memory_metrics method."""
        tracker = RuntimeMetricsTracker()

        # Set peak memory values
        tracker.metrics.peak_cpu_memory = 2 * 1024 * 1024 * 1024
        tracker.metrics.peak_gpu_memory = {
            0: 3 * 1024 * 1024 * 1024,
            1: 4 * 1024 * 1024 * 1024,
        }

        # Get memory metrics
        memory_metrics = tracker.get_memory_metrics()

        # Verify CPU memory
        assert (
            memory_metrics["cpu_memory_bytes"] == 1 * 1024 * 1024 * 1024
        )  # Current value from mock
        assert (
            memory_metrics["peak_cpu_memory_bytes"] == 2 * 1024 * 1024 * 1024
        )  # Peak value we set

        # Verify GPU memory
        assert (
            memory_metrics["gpu_0_memory_bytes"] == 1 * 1024 * 1024 * 1024
        )  # Current value from mock
        assert (
            memory_metrics["gpu_0_peak_memory_bytes"] == 3 * 1024 * 1024 * 1024
        )  # Peak value we set
        assert (
            memory_metrics["gpu_1_memory_bytes"] == 2 * 1024 * 1024 * 1024
        )  # Current value from mock
        assert (
            memory_metrics["gpu_1_peak_memory_bytes"] == 4 * 1024 * 1024 * 1024
        )  # Peak value we set