Compare commits

..

10 Commits

Author SHA1 Message Date
Wing Lian
68e97d032a chunk to prevent overflows in kernel 2025-02-26 04:44:24 -05:00
Wing Lian
23f029a89c lint and additional train metric checks for kd 2025-02-26 03:19:42 -05:00
Wing Lian
afbb44f08b more optims 2025-02-26 01:49:47 -05:00
Wing Lian
d753ead033 optimize and include bench util 2025-02-26 01:17:50 -05:00
Wing Lian
c011405117 fix gradients 2025-02-25 23:34:27 -05:00
Wing Lian
a2e52a29e9 chore: lint 2025-02-25 07:29:46 -05:00
Wing Lian
e82268e580 use triton for kd-loss in trainer 2025-02-24 22:58:35 -05:00
Wing Lian
75e1480c10 chunking not necessary 2025-02-24 22:56:15 -05:00
Wing Lian
45e1548d59 fix the kernels 2025-02-24 22:38:55 -05:00
Wing Lian
165088e7c1 triton kernel for top-k logprob kd 2025-02-24 22:13:26 -05:00
85 changed files with 2486 additions and 2733 deletions

View File

@@ -88,11 +88,6 @@ jobs:
pytorch: 2.5.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -80,11 +80,6 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -19,6 +19,9 @@
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
</p>
Axolotl is a tool designed to streamline post-training for various AI models.
@@ -47,14 +50,13 @@ Features:
## 🚀 Quick Start
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.4.1
### Installation
```bash
```shell
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs
@@ -66,7 +68,7 @@ Other installation approaches are described [here](https://axolotl-ai-cloud.gith
### Your First Fine-tune
```bash
```shell
# Fetch axolotl examples
axolotl fetch examples

View File

@@ -3,12 +3,10 @@ project:
website:
title: "Axolotl"
description: "We make fine-tuning accessible, scalable, and fun"
description: "Fine-tuning"
favicon: favicon.jpg
navbar:
logo: image/axolotl_logo_digital_white.svg
title: false
title: Axolotl
background: dark
pinned: false
collapse: false
@@ -27,59 +25,33 @@ website:
contents:
- text: Home
href: index.qmd
- section: "Getting Started"
- section: "How-To Guides"
contents:
# TODO Edit folder structure after we have more docs.
- docs/getting-started.qmd
- docs/installation.qmd
- docs/cli.qmd
- docs/debugging.qmd
- docs/inference.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Deployments"
contents:
- docs/docker.qmd
- docs/multipack.qmd
- docs/fsdp_qlora.qmd
- docs/input_output.qmd
- docs/rlhf.qmd
- docs/nccl.qmd
- docs/mac.qmd
- docs/multi-gpu.qmd
- docs/multi-node.qmd
- docs/ray-integration.qmd
- docs/amd_hpc.qmd
- docs/mac.qmd
- section: "How To Guides"
contents:
- docs/multimodal.qmd
- docs/rlhf.qmd
- docs/reward_modelling.qmd
- docs/lr_groups.qmd
- docs/lora_optims.qmd
- section: "Core Concepts"
contents:
- docs/batch_vs_grad.qmd
- docs/dataset_preprocessing.qmd
- docs/multipack.qmd
- section: "Advanced Features"
contents:
- docs/fsdp_qlora.qmd
- docs/unsloth.qmd
- docs/torchao.qmd
- docs/custom_integrations.qmd
- section: "Troubleshooting"
contents:
- docs/faq.qmd
- docs/debugging.qmd
- docs/nccl.qmd
- docs/amd_hpc.qmd
- docs/ray-integration.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Reference"
contents:
- docs/config.qmd
- docs/faq.qmd
format:
html:
theme: darkly
theme: materia
css: styles.css
toc: true

View File

@@ -14,7 +14,7 @@ COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
RUN apt install --yes --no-install-recommends openssh-server tmux && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \

View File

@@ -1,5 +1,5 @@
---
title: AMD GPUs on HPC Systems
title: Training with AMD GPUs on HPC Systems
description: A comprehensive guide for using Axolotl on distributed systems with AMD GPUs
---

View File

@@ -1,19 +1,28 @@
---
title: "CLI Reference"
format:
html:
toc: true
toc-expand: 2
toc-depth: 3
execute:
enabled: false
---
# Axolotl CLI Documentation
The Axolotl CLI provides a streamlined interface for training and fine-tuning large language models. This guide covers
the CLI commands, their usage, and common examples.
### Table of Contents
## Basic Commands
- Basic Commands
- Command Reference
- fetch
- preprocess
- train
- inference
- merge-lora
- merge-sharded-fsdp-weights
- evaluate
- lm-eval
- Legacy CLI Usage
- Remote Compute with Modal Cloud
- Cloud Configuration
- Running on Modal Cloud
- Cloud Configuration Options
### Basic Commands
All Axolotl commands follow this general structure:
@@ -23,9 +32,9 @@ axolotl <command> [config.yml] [options]
The config file can be local or a URL to a raw YAML file.
## Command Reference
### Command Reference
### fetch
#### fetch
Downloads example configurations and deepspeed configs to your local machine.
@@ -40,7 +49,7 @@ axolotl fetch deepspeed_configs
axolotl fetch examples --dest path/to/folder
```
### preprocess
#### preprocess
Preprocesses and tokenizes your dataset before training. This is recommended for large datasets.
@@ -65,7 +74,7 @@ dataset_prepared_path: Local folder for saving preprocessed data
push_dataset_to_hub: HuggingFace repo to push preprocessed data (optional)
```
### train
#### train
Trains or fine-tunes a model using the configuration specified in your YAML file.
@@ -86,38 +95,7 @@ axolotl train config.yml --no-accelerate
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
```
It is possible to run sweeps over multiple hyperparameters by passing in a sweeps config.
```bash
# Basic training with sweeps
axolotl train config.yml --sweep path/to/sweep.yaml
```
Example sweep config:
```yaml
_:
# This section is for dependent variables we need to fix
- load_in_8bit: false
load_in_4bit: false
adapter: lora
- load_in_8bit: true
load_in_4bit: false
adapter: lora
# These are independent variables
learning_rate: [0.0003, 0.0006]
lora_r:
- 16
- 32
lora_alpha:
- 16
- 32
- 64
```
### inference
#### inference
Runs inference using your trained model in either CLI or Gradio interface mode.
@@ -137,7 +115,7 @@ cat prompt.txt | axolotl inference config.yml \
--base-model="./completed-model"
```
### merge-lora
#### merge-lora
Merges trained LoRA adapters into the base model.
@@ -159,7 +137,7 @@ gpu_memory_limit: Limit GPU memory usage
lora_on_cpu: Load LoRA weights on CPU
```
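For illustration, a merge invocation might look like the sketch below, assuming the adapter-directory override works the same way here as it does for `inference` (the path is a placeholder):

```shell
# Merge the trained LoRA adapter into the base model (sketch; adapter path is a placeholder)
axolotl merge-lora config.yml --lora-model-dir="./outputs/lora-out"
```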
### merge-sharded-fsdp-weights
#### merge-sharded-fsdp-weights
Merges sharded FSDP model checkpoints into a single combined checkpoint.
@@ -168,7 +146,7 @@ Merges sharded FSDP model checkpoints into a single combined checkpoint.
axolotl merge-sharded-fsdp-weights config.yml
```
### evaluate
#### evaluate
Evaluates a model's performance using metrics specified in the config.
@@ -177,27 +155,27 @@ Evaluates a model's performance using metrics specified in the config.
axolotl evaluate config.yml
```
### lm-eval
#### lm-eval
Runs LM Evaluation Harness on your model.
```bash
# Basic evaluation
axolotl lm-eval config.yml
# Evaluate specific tasks
axolotl lm-eval config.yml --tasks arc_challenge,hellaswag
```
Configuration options:
```yaml
# List of tasks to evaluate
lm_eval_tasks:
- arc_challenge
- hellaswag
lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
lm_eval_tasks: List of tasks to evaluate
lm_eval_batch_size: Batch size for evaluation
output_dir: Directory to save evaluation results
```
## Legacy CLI Usage
### Legacy CLI Usage
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
@@ -217,18 +195,12 @@ accelerate launch -m axolotl.cli.inference config.yml \
--lora_model_dir="./outputs/lora-out" --gradio
```
::: {.callout-important}
When overriding CLI parameters in the legacy CLI, use the same notation as in the yaml file (e.g., `--lora_model_dir`).
**Note:** This differs from the new Click-based CLI, which uses dash notation (e.g., `--lora-model-dir`). Keep this in mind if you're referencing newer documentation or switching between CLI versions.
:::
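For example, the same adapter-directory override in both CLIs (paths are illustrative):

```shell
# Legacy CLI: yaml-style underscore notation
accelerate launch -m axolotl.cli.inference config.yml --lora_model_dir="./outputs/lora-out"

# New Click-based CLI: dash notation
axolotl inference config.yml --lora-model-dir="./outputs/lora-out"
```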
## Remote Compute with Modal Cloud
### Remote Compute with Modal Cloud
Axolotl supports running training and inference workloads on Modal cloud infrastructure. This is configured using a
cloud YAML file alongside your regular Axolotl config.
### Cloud Configuration
#### Cloud Configuration
Create a cloud config YAML with your Modal settings:
@@ -243,17 +215,13 @@ branch: main # Git branch to use (optional)
volumes: # Persistent storage volumes
- name: axolotl-cache
mount: /workspace/cache
- name: axolotl-data
mount: /workspace/data
- name: axolotl-artifacts
mount: /workspace/artifacts
env: # Environment variables
- WANDB_API_KEY
- HF_TOKEN
```
### Running on Modal Cloud
#### Running on Modal Cloud
Commands that support the --cloud flag:
@@ -271,18 +239,18 @@ axolotl train config.yml --cloud cloud_config.yml --no-accelerate
axolotl lm-eval config.yml --cloud cloud_config.yml
```
### Cloud Configuration Options
#### Cloud Configuration Options
```yaml
provider: # compute provider, currently only `modal` is supported
gpu: # GPU type to use
gpu_count: # Number of GPUs (default: 1)
memory: # RAM in GB (default: 128)
timeout: # Maximum runtime in seconds
timeout_preprocess: # Preprocessing timeout
branch: # Git branch to use
docker_tag: # Custom Docker image tag
volumes: # List of persistent storage volumes
env: # Environment variables to pass
secrets: # Secrets to inject
provider: compute provider, currently only `modal` is supported
gpu: GPU type to use
gpu_count: Number of GPUs (default: 1)
memory: RAM in GB (default: 128)
timeout: Maximum runtime in seconds
timeout_preprocess: Preprocessing timeout
branch: Git branch to use
docker_tag: Custom Docker image tag
volumes: List of persistent storage volumes
env: Environment variables to pass
secrets: Secrets to inject
```

View File

@@ -154,6 +154,8 @@ datasets:
content: value
# ...
message_property_mappings:
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles:
user: ["human", "user"]
@@ -161,16 +163,10 @@ datasets:
system: ["system"]
tool: ["tool"]
# Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
# This does not drop the default system message from chat_template if it exists. If you wish to,
# we recommend using a custom jinja template with the default system message removed or
# adding a system turn with empty content.
drop_system_message:
# IMPORTANT: The following fields determine which parts of the conversation to train on.
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
# See examples at `docs/dataset-formats/conversation.qmd`
# Note: If the below 4 fields are set to empty, defaults to training only on the last message.
# Note: If the below 4 fields are empty, defaults to training only on the last message.
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
roles_to_train: ["assistant"] # default
@@ -178,7 +174,6 @@ datasets:
# - all: train on all EOS tokens
# - turn (default): train on the EOS token at the end of each trainable turn
# - last: train on the last EOS token in the conversation
# TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
train_on_eos: last
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
message_field_training: training
@@ -226,8 +221,8 @@ process_reward_model:
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Changes the default system message. Currently only supports chatml.
default_system_message: You are a helpful assistant. Please give a long and detailed answer.
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
@@ -449,7 +444,7 @@ gradient_checkpointing: false
early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
@@ -532,8 +527,6 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
sdp_attention:
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
s2_attention:
# Optional[bool]. Whether to use low_cpu_mem_usage
low_cpu_mem_usage:
# Resume from a specific checkpoint dir
resume_from_checkpoint:
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
@@ -554,13 +547,6 @@ special_tokens:
# Add extra tokens.
tokens:
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
# Can be checked if they exist in tokenizer.json added_tokens.
added_tokens_overrides: # Dict[int, str]
# 128041: "<|im_start|>"
# 128042: "<|im_end|>"
# FSDP
fsdp:
fsdp_config:

View File

@@ -1,57 +0,0 @@
---
title: Custom Integrations
toc: true
toc-depth: 3
---
```{python}
#| echo: false
import re
def process_readme(integration_name):
    try:
        path = f'../src/axolotl/integrations/{integration_name}/README.md'
        with open(path, 'r') as f:
            txt = f.read()
        # Remove h1 headings
        txt = re.sub(r'^# .*\n?', '', txt, flags=re.MULTILINE)
        # Convert h2 to h3
        txt = re.sub(r'^## ', '### ', txt, flags=re.MULTILINE)
        return txt
    except FileNotFoundError:
        return None

def print_section(name, folder_name):
    output = f"\n## {name}\n"
    content = process_readme(folder_name)
    if content:
        output += content
    output += f"\nPlease see reference [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/{folder_name})\n"
    return output
```
```{python}
#| output: asis
#| echo: false
# Introduction text
print("""
Axolotl adds custom features through `integrations`. They are located within the `src/axolotl/integrations` directory.
To enable them, please check the respective documentations.
""")
# Sections
sections = [
    ("Cut Cross Entropy", "cut_cross_entropy"),
    ("Grokfast", "grokfast"),
    ("Knowledge Distillation (KD)", "kd"),
    ("Liger Kernels", "liger"),
    ("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
    ("Spectrum", "spectrum")
]

for section_name, folder_name in sections:
    print(print_section(section_name, folder_name))
```

View File

@@ -6,9 +6,7 @@ order: 3
## sharegpt
::: {.callout-important}
ShareGPT is deprecated! Please see the [chat_template](#chat_template) section below.
:::
IMPORTANT: ShareGPT is deprecated! Please see the [chat_template](#chat_template) section below.
## pygmalion
@@ -74,10 +72,6 @@ datasets:
train_on_eos:
```
::: {.callout-tip}
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
:::
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml
@@ -108,10 +102,6 @@ datasets:
type: chat_template
```
::: {.callout-important}
Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
:::
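For example, if your template ends turns with `<|im_end|>` (as ChatML does), set:

```yaml
special_tokens:
  eos_token: <|im_end|>
```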
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
For a data sample that looks like:
@@ -159,6 +149,4 @@ datasets:
message_field_training_detail: train_detail
```
::: {.callout-tip}
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::
Tip: It is not necessary to set both `message_field_training` and `message_field_training_detail` at the same time.

View File

@@ -13,7 +13,7 @@ As there are a lot of available options in Axolotl, this guide aims to provide a
Axolotl supports 3 kinds of training methods: pre-training, supervised fine-tuning, and preference-based post-training (e.g. DPO, ORPO, PRMs). Each method has its own dataset format, which is described below.
## Pre-training
## [Pre-training](pretraining.qmd)
When aiming to train on large corpora of text datasets, pre-training is your go-to choice. Due to the size of these datasets, downloading them entirely before beginning training would be prohibitively time-consuming. Axolotl supports [streaming](https://huggingface.co/docs/datasets/en/stream) to load only a batch at a time into memory.
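A minimal sketch of a streaming pre-training dataset entry, assuming a hypothetical dataset path (the full list of `pretraining_dataset` keys is in the pretraining reference below):

```yaml
pretraining_dataset:
  - path: my-org/my-text-corpus  # hypothetical streaming dataset
    type: pretrain
```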
@@ -96,10 +96,6 @@ One step is equal to `sequence_len * micro_batch_size * gradient_accumulation_st
It is recommended to leave this off if downloading from Hugging Face hub as it would download the entire dataset which can be very large.
### Reference
Please see docs [here](pretraining.qmd).
## Supervised fine-tuning (SFT)
Supervised fine-tuning is the process of training models to respond to an instruction or chat input.
@@ -124,12 +120,11 @@ If you went through the flow chart and did not find one that matches, it is reco
You can mix and match within each approach or across approaches to train a model on a variety of datasets.
:::
### Pre-Tokenized Dataset
### [Pre-Tokenized Dataset](tokenized.qmd)
We suggest this approach when you want to bring your own tokenized dataset.
Axolotl expects the dataset to have three keys:
- `input_ids`: from tokenizing formatted prompt
- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
- `labels`: this is the same as `input_ids`; however, if you want to mask certain tokens, set those indices to `-100` (see the example row below).
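For illustration, a single row might look like the following; the ids are placeholders, and `-100` marks positions excluded from the loss:

```json
{
  "input_ids": [1, 22557, 2],
  "attention_mask": [1, 1, 1],
  "labels": [-100, 22557, 2]
}
```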
@@ -150,9 +145,7 @@ datasets:
`type: ` is empty!
:::
Reference: [Pre-Tokenized Dataset Documentation](tokenized.qmd).
### Template Free Dataset
### [Template Free Dataset](template_free.qmd)
We recommend this approach when you want granular control over the prompt formatting, special tokens, and masking, whilst letting Axolotl handle the tokenization. This is very useful if your dataset has unique prompts that differ across samples and where one single general template wouldn't suffice.
@@ -189,9 +182,7 @@ datasets:
type: input_output
```
Reference: [Template Free Documentation](template_free.qmd).
### Conversation Dataset
### [Conversation Dataset](conversation.qmd)
`conversation` messages are a list of messages which usually contain a `role` and `content` key.
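For example, a single row in the OpenAI messages format might look like:

```json
{
  "messages": [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"}
  ]
}
```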
@@ -267,7 +258,7 @@ Newer conversation datasets usually follow the OpenAI format.
Axolotl supports both, as well as customization of any kind of key.
#### Chat Template Usage
#### [Chat Template Usage](conversation.qmd#chat_template)
To properly use this method, it is important to identify three things:
@@ -349,19 +340,9 @@ datasets:
narrator: ["narrator"]
```
::: {.callout-tip}
As chat_templates may use hardcoded EOS/EOT tokens that are different from the tokenizer's EOS, it is highly recommended to set them. For example, `ChatML` uses `<|im_end|>` to end turns.
#### Applying `chat_template`
```yaml
special_tokens:
eos_token: <|im_end|>
```
:::
##### Applying `chat_template`
Once all the above steps are completed, you could combine all these configs together to form a bespoke configuration for your custom dataset.
Once all the above steps are completed, you could combine all these configs together to form a bespoke configuration for your custom dataset. The final step would be to correctly set the EOS token in your config:
```yaml
datasets:
@@ -410,17 +391,7 @@ If this config were to be applied to the sample dataset above, the output would
The first number refers to the label, the second refers to the `token_id`. For example, `-100` labels appear on non-assistant portions, meaning that they are masked during training. For assistant portions, the label is the same as the `token_id`.
::: {.callout-note}
If during `preprocess`, there are a lot of warnings of `Could not find content __ boundary`, please check the FAQ section for [chat_templates](../faq.qmd#chat-templates).
:::
#### Reference
Please see docs [here](conversation.qmd).
### Instruction Dataset
### [Instruction Dataset](inst_tune.qmd)
Instruction datasets are used to train instruction-following models and comprise a prompt, containing an instruction, and a single response. In contrast to chat datasets which may be multi-turn, instruct datasets are typically single-turn.
@@ -452,9 +423,6 @@ datasets:
Axolotl supports many kinds of instruction datasets. All of them can be found [here](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/inst_tune.html) with their respective type and sample row format.
Reference: [Instruction Dataset Documentation](inst_tune.qmd).
#### Custom Instruct Prompt Format
Due to the myriad possibilities of instruction formats, Axolotl allows customizing your own instruction format without having to dive into the code directly.
@@ -485,8 +453,6 @@ datasets:
The config sets that the `field_instruction` is actually named `input`, and the `field_input` is empty as we don't have an `input` in this sample. Generally, `instruction` can be thought of as the question to the model, and `input` as the additional information, with `output` being the response. It is not necessary to have an `input` or a `system`. In the end, the most important part is to understand what format you want it to look like and how you can customize this to your use case.
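A minimal sketch of such a config, with a hypothetical dataset path and an illustrative format string:

```yaml
datasets:
  - path: my_data.jsonl            # hypothetical local dataset
    type:
      field_instruction: input     # the instruction lives under the `input` key
      field_output: output
      format: "[INST] {instruction} [/INST]"  # illustrative prompt format
```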
Reference: [Custom Instruct Prompt Format Documentation](inst_tune.qmd#how-to-add-custom-prompt-format).
## Reinforcement Learning from Human Feedback (RLHF)
As there are multiple RLHF methods, each with their own dataset requirements, please see the [RLHF documentation](../rlhf.qmd) for more detail.
As there are multiple RLHF methods, each with their own dataset requirements, please see the [RLHF datasets](../rlhf.qmd) documentation for more detail.

View File

@@ -27,6 +27,7 @@ pretraining_dataset:
type: pretrain
trust_remote_code:
skip: # number of rows of data to skip over from the beginning
...
```
:::

View File

@@ -1,239 +1,7 @@
---
title: Template-Free
description: Construct prompts without a template.
toc: true
toc-depth: 3
order: 4
---
## Background {#sec-background}
### Masking Inputs {#masking-inputs}
One of the most popular features of
[axolotl](https://github.com/axolotl-ai-cloud/axolotl) is
setting the following configuration value:
```yaml
train_on_inputs: false
```
If you declare a [dataset formats](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#dataset)
such as `alpaca` or `chatml`, axolotl knows what is an input
(i.e. human) vs. an output (i.e. the assistant) and masks the input
labels so that your model can focus on predicting the outputs only.
### You may not want prompt templates {#sec-you-may-not-want-prompt-templates}
However, there are many situations where you don't want to use one of
these formats or templates. This is because they can:
- Add unnecessary boilerplate to your prompts.
- Create artifacts like special delimiters `<|im_start|>` that can
quickly become footguns if you don't include them correctly at
inference time.
- Enforce a *chat* interface when you do not want one. Sometimes you
just want to fine-tune a model to a very specific task and do NOT
want multi-turn conversations, roles, etc.
- Limit you to only certain roles that the template allows.
### The `input_output` format {#sec-the-inputoutput-format}
You can construct your prompts without a template by using the
`input_output` format, by setting `type: input_output` in your
configuration file like this:
**config.yml**
```yaml
train_on_inputs: false # Mask segments of your data
datasets:
- path: output.jsonl
type: input_output # use template free prompt construction
```
Unlike `type: completion`, which is also template-free,
`type: input_output` allows you to mask segments of your text. More
details on how this works are described below.
## Usage {#sec-usage}
This is how you can use the `input_output` format:
### 1. Prepare Data {#sec-1-prepare-data}
To use the `input_output` format, collect your data in the following
format into a jsonl file (below is the first row from the file
`output.jsonl`, pretty printed):
```bash
$ head -n1 output.jsonl | python -m json.tool
```
:::{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
:::
Set `label:false` when you want to mask a segment of text so that the
model isn't trained on it. Some things to keep in mind:
> [!IMPORTANT]
> 1. **EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl
concatenates all the segments as-is.** The tokenizer doesn't add
anything additional. Notice how I added spaces, newlines, `<s>`
(BOS), and `</s>` (EOS) myself.
> 2. Make sure you check the materialized output to validate that the
prompt is getting assembled how you like.
### 2. Use `type: input_output` {#sec-2-use-type-inputoutput}
Let's materialize data with our `output.jsonl` file by setting
`type: input_output` in our axolotl config:
```yaml
# training_config.yaml
base_model: mistralai/Mistral-7B-v0.1
data_seed: 49
seed: 49
datasets:
- path: output.jsonl
type: input_output
val_set_size: 0.1
sequence_len: 896
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 3
eval_batch_size: 2
num_epochs: 1
learning_rate: 0.0002
train_on_inputs: false
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
```
You can use the following command to materialize your data. The
`--debug` flag will print the tokens, along with the labels so you can
verify that the correct items are being ignored:
```bash
axolotl preprocess training_config.yaml --debug
...
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
```
The format is `decoded_token`(`label`, `token_id`), for example,
`<s>(1, 1)` means that the token is `<s>`, the label is `1` and the
token_id is `1`. When the label is `-100` then that token is ignored for
training.
### 3. Check the prompts {#sec-3-check-the-prompts}
Here is another way to check the materialized output:
```python
from transformers import AutoTokenizer
from datasets import load_from_disk
import yaml
directory = !ls last_run_prepared/
with open('training_config.yaml', 'r') as f:
    cfg = yaml.safe_load(f)
model_id = cfg['base_model']
tok = AutoTokenizer.from_pretrained(model_id)
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
```
```python
>>> row = ds[0]
>>> print(tok.decode(row['input_ids']))
<s> Hello
hi there!. goodbye farewell</s>
```
We can check that the right tokens are ignored by comparing the labels
to each token:
```python
import pandas as pd
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in
zip(row['input_ids'], row['labels'])])
```
|    | token  | label | id    |
|----|--------|-------|-------|
| 0  | \<s\>  | 1     | 1     |
| 1  | Hello  | 22557 | 22557 |
| 2  | \\n    | 13    | 13    |
| 3  | hi     | 12014 | 12014 |
| 4  | there  | 736   | 736   |
| 5  | !      | 28808 | 28808 |
| 6  | .      | 28723 | 28723 |
| 7  |        | 28705 | 28705 |
| 8  | good   | -100  | 1179  |
| 9  | bye    | -100  | 17664 |
| 10 |        | -100  | 28705 |
| 11 | fare   | 19111 | 19111 |
| 12 | well   | 5458  | 5458  |
| 13 | \</s\> | 2     | 2     |
If we look at the input data, the above table seems correct! (The jsonl
version is repeated below for reference):
```bash
$ head -n1 output.jsonl | python -m json.tool
```
:::{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
:::
See [these docs](../input_output.qmd).

View File

@@ -3,11 +3,8 @@ title: Dataset Preprocessing
description: How datasets are processed
---
## Overview
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
the [dataset format](docs/dataset-formats) and prompt strategies to:
the [dataset format](../dataset-formats/) and prompt strategies to:
- parse the dataset based on the *dataset format*
- transform the dataset to how you would interact with the model based on the *prompt strategy*
- tokenize the dataset based on the configured model & tokenizer
@@ -15,12 +12,10 @@ the [dataset format](docs/dataset-formats) and prompt strategies to:
The processing of the datasets can happen one of two ways:
1. Before kicking off training by calling `axolotl preprocess config.yaml --debug`
1. Before kicking off training by calling `python -m axolotl.cli.preprocess /path/to/your.yaml --debug`
2. When training is started
### What are the benefits of pre-processing?
When training interactively or for sweeps
What are the benefits of pre-processing? When training interactively or for sweeps
(e.g. you are restarting the trainer often), processing the datasets can oftentimes be frustratingly
slow. Pre-processing will cache the tokenized/formatted datasets according to a hash of dependent
training parameters so that it will intelligently pull from its cache when possible.
@@ -33,12 +28,8 @@ default path of `./last_run_prepared/`, but will ignore anything already cached
setting `dataset_prepared_path: ./last_run_prepared`, the trainer will use whatever pre-processed
data is in the cache.
### What are the edge cases?
Let's say you are writing a custom prompt strategy or using a user-defined
What are the edge cases? Let's say you are writing a custom prompt strategy or using a user-defined
prompt template. Because the trainer cannot readily detect these changes, we cannot change the
calculated hash value for the pre-processed dataset.
If you have `dataset_prepared_path: ...` set
calculated hash value for the pre-processed dataset. If you have `dataset_prepared_path: ...` set
and change your prompt templating logic, it may not pick up the changes you made and you will be
training over the old prompt.

View File

@@ -31,13 +31,11 @@ While debugging it's helpful to simplify your test scenario as much as possible.
- Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.
- Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`.
2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use one from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training). For example, to shard the dataset into 20 pieces, add the following to your axolotl config:
```yaml
datasets:
dataset:
...
shards: 20
```
3. **Use a small model**: A good example of a small model is [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0).
4. **Minimize iteration time**: Make sure the training loop finishes as fast as possible, with these settings.
- `micro_batch_size: 1`
@@ -87,7 +85,7 @@ The easiest way to get started is to modify the [.vscode/launch.json](../.vscode
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
```json
```jsonc
// .vscode/launch.json
{
"version": "0.2.0",
@@ -134,7 +132,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
Below is the [./vscode/tasks.json](../.vscode/tasks.json) file that defines the `cleanup-for-dataprep` task. This task is run before each debugging session when you use the above configuration. Note how there are two tasks that delete the two folders mentioned above. The third task `cleanup-for-dataprep` is a composite task that combines the two tasks. A composite task is necessary because VSCode does not allow you to specify multiple tasks in the `preLaunchTask` argument of the `launch.json` file.
```json
```jsonc
// .vscode/tasks.json
// this file is used by launch.json
{

View File

@@ -1,140 +0,0 @@
---
title: "Docker"
format:
html:
toc: true
toc-depth: 4
---
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
## Base
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
#### Image
```
axolotlai/axolotl-base
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
#### Tags format
```bash
main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
```
Tags examples:
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1`
## Main
The main image is the image that is used to run Axolotl. It is based on the `axolotlai/axolotl-base` image and includes the Axolotl codebase, dependencies, and more.
#### Image
```
axolotlai/axolotl
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
#### Tags format {#sec-main-tags}
```bash
# on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
main-latest
# nightly build
{branch}-{date_in_YYYYMMDD}-py{python_version}-cu{cuda_version}-{pytorch_version}
# tagged release
{version}
```
:::{.callout-tip}
There may be some extra tags appended to the image, like `-vllm` which installs those packages.
:::
Tags examples:
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `main-20250303-py3.11-cu124-2.4.1`
- `0.7.1`
## Cloud
The cloud image is the image that is used to run Axolotl in the cloud. It is based on the `axolotlai/axolotl` image and sets ENV variables like HuggingFace cache directories for volume mounts, tmux, and more for different cloud providers.
:::{.callout-tip}
Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variables to disable it.
:::
#### Image
```
axolotlai/axolotl-cloud
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
#### Tags format
This uses the same tags as the [`main` image](#sec-main-tags).
#### Environment variables
- `JUPYTER_DISABLE`: Disable Jupyter lab.
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
- `PUBLIC_KEY`: Add a public key for the SSH service.
- `SSH_KEY`: Add a private key for the SSH service.
#### Volume mounts
:::{.callout-tip}
We recommend mounting volumes to `/workspace/data` for data persistence. `/workspace/axolotl` contains the source code and is ephemeral.
:::
- `/workspace/data/axolotl-artifacts`: Directory to store Axolotl artifacts.
- `/workspace/data/huggingface-cache`: Directory to store HuggingFace cache.
## Cloud-no-tmux
This is the same as the [`cloud` image](#sec-cloud) but without tmux.
#### Image
```
axolotlai/axolotl-cloud-term
```
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud-term)
:::{.callout-note}
The naming may be a bit confusing as it has `-term` appended to the end.
:::
#### Tags format
This uses the same tags as the [`cloud` image](#sec-cloud-tags).

View File

@@ -3,7 +3,6 @@ title: FAQ
description: Frequently asked questions
---
### General
**Q: The trainer stopped and hasn't progressed in several minutes.**
@@ -19,40 +18,12 @@ description: Frequently asked questions
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
**Q: ModuleNotFoundError: No module named 'mpi4py' using single GPU with deepspeed**
> A: You may be using deepspeed with single gpu. Please remove the `deepspeed:` section in the yaml file or `--deepspeed` CLI flag.
> A: You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
**Q: The code is stuck on saving preprocessed datasets.**
> A: This is usually an issue with the GPU. It can often be resolved by setting the environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on RunPod, this is usually a pod issue; starting a new pod should take care of it.
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
> A: This means that the property mapping for the stated attribute does not exist when building `chat_template` prompt. For example, if `no attribute 'content'`, please check you have added the correct mapping for `content` under `message_property_mappings`.
**Q: `Empty template generated for turn ___`**
> A: The `content` is empty for that turn.
**Q: `Could not find content start/end boundary for turn __`**
> A: The specific turn's start/end could not be detected. Please ensure you have set the `eos_token` to match your `chat_template`. Otherwise, this could be a `chat_template` which doesn't use proper boundaries for each turn (like system). In the rare case that your content is `[[dummy_message]]`, please let us know by opening an issue.
**Q: `Content end boundary is before start boundary for turn ___`**
> A: This is an edge case which should not occur. Please create an Issue if this happens.
**Q: `Content end boundary is the same as start boundary for turn ___. This is likely an empty turn.`**
> A: This is likely an empty turn.
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
> A: This is because of a mismatch between `tokenizer.eos_token` and the EOS/EOT token in the template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in the template.
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.

View File

@@ -1,5 +1,5 @@
---
title: "Quickstart"
title: "Getting Started with Axolotl"
format:
html:
toc: true
@@ -17,12 +17,12 @@ Let's start by fine-tuning a small language model using LoRA. This example uses
Assuming `axolotl` is installed (if not, see our [Installation Guide](installation.qmd))
1. Download example configs:
```bash
```shell
axolotl fetch examples
```
2. Run the training:
```bash
```shell
axolotl train examples/llama-3/lora-1b.yml
```
@@ -108,7 +108,7 @@ Please consult the supported [Dataset Formats](dataset-formats/) for more detail
3. Run the training:
```bash
```shell
axolotl train my_training.yml
```
@@ -118,7 +118,7 @@ axolotl train my_training.yml
After training, test your model:
```bash
```shell
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
```
@@ -126,7 +126,7 @@ axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
For large datasets, preprocess first:
```bash
```shell
axolotl preprocess my_training.yml
```
@@ -134,7 +134,7 @@ axolotl preprocess my_training.yml
Launch a Gradio interface:
```bash
```shell
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
```

View File

@@ -1,10 +1,11 @@
---
title: "Inference"
title: "Inference Guide"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---

View File

@@ -3,4 +3,263 @@ title: Template-free prompt construction
description: "Template-free prompt construction with the `input_output` format"
---
The documentation moved to [here](dataset-formats/template_free.qmd).
<!-- TOC -->
- [Background](#background)
- [Masking Inputs](#masking-inputs)
- [You may not want prompt templates](#you-may-not-want-prompt-templates)
- [The `input_output` format](#the-input_output-format)
- [Usage](#usage)
- [1. Prepare Data](#1-prepare-data)
- [2. Use `type: input_output`](#2-use-type-input_output)
- [3. Check the prompts](#3-check-the-prompts)
<!-- /TOC -->
<a id="markdown-background" name="background"></a>
## Background
<a id="markdown-masking-inputs" name="masking-inputs"></a>
### Masking Inputs
One of the most popular features of
[axolotl](https://github.com/axolotl-ai-cloud/axolotl) is
setting the following configuration value:
```yaml
train_on_inputs: false
```
If you declare a [dataset formats](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#dataset)
such as `alpaca` or `chatml`, axolotl knows what is an input
(i.e. human) vs. an output (i.e. the assistant) and masks the input
labels so that your model can focus on predicting the outputs only.
<a id="markdown-you-may-not-want-prompt-templates" name="you-may-not-want-prompt-templates"></a>
### You may not want prompt templates
However, there are many situations where you don't want to use one of
these formats or templates. This is because they can:
- Add unnecessary boilerplate to your prompts.
- Create artifacts like special delimiters `<|im_start|>` that can
quickly become footguns if you don't include them correctly at
inference time.
- Enforce a *chat* interface when you do not want one. Sometimes you
just want to fine-tune a model to a very specific task and do NOT
want multi-turn conversations, roles, etc.
- Limit you to only certain roles that the template allows.
<a id="markdown-the-inputoutput-format" name="the-inputoutput-format"></a>
### The `input_output` format
You can construct your prompts without a template by using the
`input_output` format, by setting `type: input_output` in your
configuration file like this:
**config.yml**
```yaml
train_on_inputs: false # Mask segments of your data
datasets:
- path: output.jsonl
type: input_output # use template free prompt construction
```
Unlike `type: completion`, which is also template-free,
`type: input_output` allows you to mask segments of your text. More
details on how this works are described below.
<a id="markdown-usage" name="usage"></a>
## Usage
This is how you can use the `input_output` format:
<a id="markdown-1-prepare-data" name="1-prepare-data"></a>
### 1. Prepare Data
To use the `input_output` format, collect your data in the following
format into a jsonl file (below is the first row from the file
`output.jsonl`, pretty printed):
```bash
$ head -n1 output.jsonl | python -m json.tool
```
:::{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
:::
Set `label:false` when you want to mask a segment of text so that the
model isn't trained on it. Some things to keep in mind:
> [!IMPORTANT]
> 1. **EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl
concatenates all the segments as-is.** The tokenizer doesn't add
anything additional. Notice how I added spaces, newlines, `<s>`
(BOS), and `</s>` (EOS) myself.
> 2. Make sure you check the materialized output to validate that the
prompt is getting assembled how you like.
<a id="markdown-2-use-type-inputoutput" name="2-use-type-inputoutput"></a>
### 2. Use `type: input_output`
Let's materialize data with our `output.jsonl` file by setting
`type: input_output` in our axolotl config:
```yaml
# training_config.yaml
base_model: mistralai/Mistral-7B-v0.1
data_seed: 49
seed: 49
datasets:
- path: output.jsonl
type: input_output
val_set_size: 0.1
sequence_len: 896
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 3
eval_batch_size: 2
num_epochs: 1
learning_rate: 0.0002
train_on_inputs: false
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
```
You can use the following command to materialize your data. The
`--debug` flag will print the tokens, along with the labels so you can
verify that the correct items are being ignored:
```bash
$ python -m axolotl.cli.preprocess training_config.yaml --debug
...
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
```
The format is `decoded_token`(`label`, `token_id`), for example,
`<s>(1, 1)` means that the token is `<s>`, the label is `1` and the
token_id is `1`. When the label is `-100` then that token is ignored for
training.
<a id="markdown-3-check-the-prompts" name="3-check-the-prompts"></a>
### 3. Check the prompts
Here is another way to check the materialized output:
```python
from transformers import AutoTokenizer
from datasets import load_from_disk
import yaml
directory = !ls last_run_prepared/
with open('training_config.yaml', 'r') as f:
    cfg = yaml.safe_load(f)
model_id = cfg['base_model']
tok = AutoTokenizer.from_pretrained(model_id)
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
```
```python
>>> row = ds[0]
>>> print(tok.decode(row['input_ids']))
<s> Hello
hi there!. goodbye farewell</s>
```
We can check that the right tokens are ignored by comparing the labels
to each token:
```python
import pandas as pd
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in
zip(row['input_ids'], row['labels'])])
```
|    | token  | label | id    |
|----|--------|-------|-------|
| 0  | \<s\>  | 1     | 1     |
| 1  | Hello  | 22557 | 22557 |
| 2  | \\n    | 13    | 13    |
| 3  | hi     | 12014 | 12014 |
| 4  | there  | 736   | 736   |
| 5  | !      | 28808 | 28808 |
| 6  | .      | 28723 | 28723 |
| 7  |        | 28705 | 28705 |
| 8  | good   | -100  | 1179  |
| 9  | bye    | -100  | 17664 |
| 10 |        | -100  | 28705 |
| 11 | fare   | 19111 | 19111 |
| 12 | well   | 5458  | 5458  |
| 13 | \</s\> | 2     | 2     |
If we look at the input data, the above table seems correct! (The jsonl
version is repeated below for reference):
```bash
$ head -n1 output.jsonl | python -m json.tool
```
:::{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
:::

View File

@@ -1,10 +1,11 @@
---
title: "Installation"
title: "Installation Guide"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---
@@ -65,8 +66,6 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
```
:::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
## Cloud Environments {#sec-cloud}
### Cloud GPU Providers {#sec-cloud-gpu}

View File

@@ -1,6 +1,7 @@
---
title: "LoRA Optimizations"
description: "Custom autograd functions and Triton kernels in Axolotl for optimized LoRA fine-tuning"
description: "Custom autograd functions and Triton kernels in Axolotl for optimized
LoRA fine-tuning"
---
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two

View File

@@ -19,5 +19,4 @@ Current support:
- [ ] DeepSpeed
Untested:
- FSDP

View File

@@ -1,5 +1,5 @@
---
title: "Multi-GPU"
title: "Multi-GPU Training Guide"
format:
html:
toc: true
@@ -35,11 +35,7 @@ deepspeed: deepspeed_configs/zero1.json
### Usage {#sec-deepspeed-usage}
```{.bash}
# Passing arg via config
axolotl train config.yml
# Passing arg via cli
axolotl train config.yml --deepspeed deepspeed_configs/zero1.json
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
```
### ZeRO Stages {#sec-zero-stages}
@@ -74,7 +70,25 @@ For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
### Liger Kernel Integration {#sec-liger}
Please see [docs](custom_integrations.qmd#liger) for more info.
::: {.callout-note}
Liger Kernel provides efficient Triton kernels for LLM training, offering:
- 20% increase in multi-GPU training throughput
- 60% reduction in memory usage
- Compatibility with both FSDP and DeepSpeed
:::
Configuration:
```{.yaml}
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```
## Troubleshooting {#sec-troubleshooting}

View File

@@ -13,7 +13,7 @@ You will also need to have the same configuration file for your model on each ma
Make sure the main machine is reachable by other machines.
:::
## Accelerate
# Accelerate
You will need to create a configuration for accelerate, either by running `accelerate config` and following the instructions, or by using one of the presets below:
@@ -51,17 +51,17 @@ fsdp_config:
All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine.
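For reference, the per-machine launch is the usual accelerate command (sketch):

```shell
# Run on every machine once the accelerate config above is in place
accelerate launch -m axolotl.cli.train config.yml
```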
## Raytrain
# Raytrain
Please see ray train doc [here](ray-integration.qmd).
## Torchrun
# Torchrun
If you are using Infiniband, we recommend torchrun to utilize the full bandwidth.
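A minimal two-node launch sketch; the node count, rank, and master address are placeholders, and the NCCL variables listed below should be exported first:

```shell
# Run on each node; --node_rank is 0 on the first node, 1 on the second
torchrun --nnodes 2 --nproc_per_node 8 --node_rank 0 \
  --master_addr "10.0.0.1" --master_port 29500 \
  -m axolotl.cli.train config.yml
```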
Set the following env (change buffersize/socketname depending on your system):
```bash
```shell
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
export NCCL_BUFFSIZE=2097152

View File

@@ -13,13 +13,13 @@ Often, this timeout will happen after 30 minutes (the default setting) and is ac
Forcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. To verify that your configuration is leveraging NVLink run the following command:
```bash
```shell
nvidia-smi nvlink --status
```
To force NCCL to use NVLink, simply set this in the environment:
```bash
```shell
export NCCL_P2P_LEVEL=NVL
```
@@ -33,13 +33,13 @@ If NVLink is not available in your environment there are other options for ``NCC
To validate that acceptable data transfer speeds exist for your training job, running [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) can help pinpoint bottlenecks, for example:
```bash
```shell
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
```
It can be useful when debugging NCCL communication timeouts to activate additional logging in both PyTorch and NCCL:
```bash
```shell
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL
export TORCH_DISTRIBUTED_DEBUG=INFO

View File

@@ -1,5 +1,5 @@
---
title: Ray Train
title: Ray Train integration
description: How to use Axolotl with Ray Train
---
@@ -9,7 +9,7 @@ With the `--use-ray` CLI flag, Axolotl will use Ray Train's [`TorchTrainer`](htt
## Ray cluster setup
A prerequisite to using the Ray Train integration is to set up a Ray cluster on your desired node(s). For a detailed guide on how you can get started with ray clusters, check the official Ray docs [here](https://docs.ray.io/en/latest/cluster/getting-started.html).
A prerequisite to using the Ray Train integration is to set up a Ray cluster on your desired node(s). For a detailed guide on how you can get started with ray clusters, check the official Ray docs here: https://docs.ray.io/en/latest/cluster/getting-started.html
Every Ray cluster has one _head_ node and a set of worker nodes. The head node is just like any other worker node, but it also runs certain special processes related to scheduling and orchestration. Ray-enabled scripts are run on the head node and depending on the resources (number of CPUs, GPUs, etc) they request, will be scheduled to run certain tasks on the worker nodes. For more on key concepts behind a Ray cluster, you can refer this [doc](https://docs.ray.io/en/latest/cluster/key-concepts.html#cluster-key-concepts).
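A minimal cluster bring-up sketch, assuming Ray is already installed on every node (the address and port are placeholders; see the Ray docs linked above for details):

```shell
# On the head node
ray start --head --port=6379

# On each worker node, pointing at the head node's address
ray start --address="10.0.0.1:6379"
```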
@@ -58,11 +58,13 @@ You can find an example configuration at `configs/llama-3/lora-1b-ray.yaml`.
The key parameters to note here are:
```yaml
...
use_ray: true
ray_num_workers: 4
# optional
resources_per_worker:
GPU: 1
...
```
- `use_ray`: This is the flag that enables the Ray Train integration. You can either use the corresponding `--use-ray` flag in the CLI or set `use_ray` in the config file.

View File

@@ -28,17 +28,6 @@ val_set_size: 0.1
eval_steps: 100
```
Bradley-Terry chat templates expect single-turn conversations in the following format:
```json
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
```
### Process Reward Models (PRM)
Process reward models are trained on data that contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide a reward signal at each step of a reasoning trace and are used for downstream reinforcement learning.
@@ -56,5 +45,3 @@ datasets:
val_set_size: 0.1
eval_steps: 100
```
Please see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.

View File

@@ -3,23 +3,22 @@ title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
toc-depth: 3
---
## Overview
# Overview
Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
feedback. Various methods include, but are not limited to:
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- [Direct Preference Optimization (DPO)](#dpo)
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
## RLHF using Axolotl
# RLHF using Axolotl
::: {.callout-important}
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
@@ -31,7 +30,7 @@ We rely on the [TRL](https://github.com/huggingface/trl) library for implementat
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}` where `{method}` is one of our supported methods. The `type: ` can be retrieved from `{method}.{function_name}`.
:::
### DPO
## DPO
Example config:
@@ -48,7 +47,7 @@ datasets:
DPO supports the following types with the following dataset format:
#### chatml.argilla
### chatml.argilla
```json
{
@@ -59,7 +58,7 @@ DPO supports the following types with the following dataset format:
}
```
#### chatml.argilla_chat
### chatml.argilla_chat
```json
{
@@ -74,7 +73,7 @@ DPO supports the following types with the following dataset format:
}
```
#### chatml.icr
### chatml.icr
```json
{
@@ -85,7 +84,7 @@ DPO supports the following types with the following dataset format:
}
```
#### chatml.intel
### chatml.intel
```json
{
@@ -96,7 +95,7 @@ DPO supports the following types with the following dataset format:
}
```
#### chatml.prompt_pairs
### chatml.prompt_pairs
```json
{
@@ -107,7 +106,7 @@ DPO supports the following types with the following dataset format:
}
```
#### chatml.ultra
### chatml.ultra
```json
{
@@ -124,7 +123,7 @@ DPO supports the following types with the following dataset format:
}
```
#### llama3.argilla
### llama3.argilla
```json
{
@@ -135,7 +134,7 @@ DPO supports the following types with the following dataset format:
}
```
#### llama3.argilla_chat
### llama3.argilla_chat
```json
{
@@ -150,7 +149,7 @@ DPO supports the following types with the following dataset format:
}
```
#### llama3.icr
### llama3.icr
```json
{
@@ -161,7 +160,7 @@ DPO supports the following types with the following dataset format:
}
```
#### llama3.intel
### llama3.intel
```json
{
@@ -172,7 +171,7 @@ DPO supports the following types with the following dataset format:
}
```
#### llama3.prompt_pairs
### llama3.prompt_pairs
```json
{
@@ -183,7 +182,7 @@ DPO supports the following types with the following dataset format:
}
```
#### llama3.ultra
### llama3.ultra
```json
{
@@ -200,7 +199,7 @@ DPO supports the following types with the following dataset format:
}
```
#### zephyr.nectar
### zephyr.nectar
```json
{
@@ -219,7 +218,7 @@ DPO supports the following types with the following dataset format:
}
```
#### chat_template.default
### chat_template.default
```yaml
rl: dpo
@@ -265,7 +264,7 @@ Sample input format:
}
```
#### user_defined.default
### user_defined.default
For custom behaviors,
@@ -296,7 +295,7 @@ The input format is a simple JSON input with customizable fields based on the ab
}
```
### IPO
## IPO
As IPO is just DPO with a different loss function, all supported options for DPO work here as well.
@@ -304,7 +303,7 @@ As IPO is just DPO with a different loss function, all supported options for DPO
rl: ipo
```
### ORPO
## ORPO
Paper: https://arxiv.org/abs/2403.07691
@@ -321,7 +320,7 @@ datasets:
ORPO supports the following types with the following dataset format:
#### chat_template.argilla
### chat_template.argilla
```json
{
@@ -340,7 +339,7 @@ ORPO supports the following types with the following dataset format:
}
```
### KTO
## KTO
```yaml
rl: kto
@@ -361,7 +360,7 @@ gradient_checkpointing_kwargs:
KTO supports the following types with the following dataset format:
#### chatml.argilla
### chatml.argilla
```json
{
@@ -371,7 +370,7 @@ KTO supports the following types with the following dataset format:
}
```
#### chatml.argilla_chat
### chatml.argilla_chat
```json
{
@@ -384,7 +383,7 @@ KTO supports the following types with the following dataset format:
}
```
#### chatml.intel
### chatml.intel
```json
{
@@ -394,7 +393,7 @@ KTO supports the following types with the following dataset format:
}
```
#### chatml.prompt_pairs
### chatml.prompt_pairs
```json
{
@@ -404,7 +403,7 @@ KTO supports the following types with the following dataset format:
}
```
#### chatml.ultra
### chatml.ultra
```json
{
@@ -414,7 +413,7 @@ KTO supports the following types with the following dataset format:
}
```
#### llama3.argilla
### llama3.argilla
```json
{
@@ -424,7 +423,7 @@ KTO supports the following types with the following dataset format:
}
```
#### llama3.argilla_chat
### llama3.argilla_chat
```json
{
@@ -435,7 +434,7 @@ KTO supports the following types with the following dataset format:
}
```
#### llama3.intel
### llama3.intel
```json
{
@@ -445,7 +444,7 @@ KTO supports the following types with the following dataset format:
}
```
#### llama3.prompt_pairs
### llama3.prompt_pairs
```json
{
@@ -455,7 +454,7 @@ KTO supports the following types with the following dataset format:
}
```
#### llama3.ultra
### llama3.ultra
```json
{
@@ -465,7 +464,7 @@ KTO supports the following types with the following dataset format:
}
```
#### user_defined.default
### user_defined.default
For custom behaviors,
@@ -495,52 +494,7 @@ The input format is a simple JSON input with customizable fields based on the ab
}
```
### GRPO
GRPO uses custom reward functions and transformations. Please have them ready locally.
For example, to load OpenAI's GSM8K and use a random reward for completions:
```python
# rewards.py
import random
def rand_reward_func(completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]},],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
```
```yaml
rl: grpo
trl:
beta: 0.001
max_completion_length: 256
use_vllm: True
vllm_device: auto
vllm_gpu_memory_utilization: 0.15
num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0]
datasets:
- path: openai/gsm8k
name: main
type: rewards.oai_gsm8k_transform # format: '{file_name}.{fn_name}'
```
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
For a description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
### Using local dataset files
## Using local dataset files
```yaml
datasets:
@@ -551,7 +505,7 @@ datasets:
type: chatml.intel
```
### TRL auto-unwrapping for PEFT
## TRL auto-unwrapping for PEFT
TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, as an additional reference model does not need to be loaded; reference-model log-probabilities can instead be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:

View File

@@ -3,12 +3,6 @@ title: "PyTorch ao"
description: "Custom data types and layouts for training and inference"
---
To use experimental optimizers (`AdamWFp8`, `AdamW4bit`, `AdamW8bit`) from PyTorch ao, please install the package as shown below.
::: {.callout-tip}
Some experimental optimizers are already present in regular PyTorch, so please double-check whether you actually need this package!
:::
### Installation
Stable Release from the PyTorch index

View File

@@ -8,12 +8,6 @@ description: "Hyper-optimized QLoRA finetuning for single GPUs"
Unsloth provides hand-written, optimized kernels for LLM finetuning that modestly improve speed and VRAM usage over
standard industry baselines.
::: {.callout-important}
Due to breaking changes in transformers `v4.48.0`, users will need to downgrade to `<=v4.47.1` to use this patch.
This will later be deprecated in favor of [LoRA Optimizations](lora_optims.qmd).
:::
### Installation
@@ -23,7 +17,7 @@ The following will install the correct unsloth and extras from source.
python scripts/unsloth_install.py | sh
```
### Usage
### Using unsloth w Axolotl
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.

View File

@@ -1,7 +1,7 @@
---
# toc-location: right-body
# toc-title: Table Of Contents
# toc-expand: 2
toc-location: right-body
toc-title: Table Of Contents
toc-expand: 2
---
```{python}

View File

@@ -7,7 +7,7 @@ mamba-ssm==1.2.0.post1
flash-attn==2.7.4.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.3
liger-kernel==0.5.2
# END section
packaging==23.2
@@ -62,5 +62,4 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0
schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
axolotl-contribs-lgpl==0.0.3

View File

@@ -24,5 +24,5 @@ if cce_spec:
print(
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
)

View File

@@ -113,7 +113,7 @@ class ModalCloud(Cloud):
[
# Random id for cache busting of branch commits
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch} && git pull",
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
]
)
@@ -258,22 +258,25 @@ class ModalCloud(Cloud):
def _preprocess(config_yaml: str, volumes=None):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"
run_folder = "/workspace/artifacts/axolotl"
run_cmd(
"axolotl preprocess /workspace/mounts/config.yaml --dataset-processes=8",
"axolotl preprocess /workspace/artifacts/axolotl/config.yaml --dataset-processes=8",
run_folder,
volumes,
)
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"
run_folder = "/workspace/artifacts/axolotl"
if accelerate:
accelerate_args = "--accelerate"
else:
@@ -282,19 +285,20 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
if num_processes := kwargs.pop("num_processes", None):
num_processes_args = f"--num-processes {num_processes}"
run_cmd(
f"axolotl train {accelerate_args} {num_processes_args} /workspace/mounts/config.yaml",
f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)
def _lm_eval(config_yaml: str, volumes=None):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"
run_folder = "/workspace/artifacts/axolotl"
run_cmd(
"axolotl lm-eval /workspace/mounts/config.yaml",
"axolotl lm-eval /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)

View File

@@ -1,7 +1,6 @@
"""CLI to run training on a model."""
import logging
import os
from pathlib import Path
from typing import Union
@@ -35,20 +34,18 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
print_axolotl_text_art()
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
check_user_token()
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
del trainer
plugin_manager.post_train_unload(cfg)

View File

@@ -24,8 +24,8 @@ class TrainDatasetMeta:
"""Dataclass with fields for training and validation datasets and metadata."""
train_dataset: Dataset
eval_dataset: Dataset | None = None
total_num_steps: int | None = None
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:

View File

@@ -43,7 +43,7 @@ class TokenizedChatDataset(Dataset):
process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment]
)
num_proc = min(32, process_or_cpu_count)
num_proc = min(64, process_or_cpu_count)
features = data.features.keys()
tokenized_data = data.map(
map_fn,

View File

@@ -35,7 +35,6 @@ from transformers import (
EarlyStoppingCallback,
TrainerCallback,
)
from transformers.training_args import OptimizerNames
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import (
@@ -85,7 +84,6 @@ from axolotl.utils.collators import (
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
from axolotl.utils.models import ensure_dtype
try:
@@ -93,11 +91,13 @@ try:
except ImportError:
pass
LOG = logging.getLogger(__name__)
LOG = logging.getLogger("axolotl.core.trainer_builder")
class TrainerBuilderBase(abc.ABC):
"""Base class for trainer builder."""
"""
Base class for trainer builder
"""
_train_dataset = None
_eval_dataset = None
@@ -110,9 +110,9 @@ class TrainerBuilderBase(abc.ABC):
self.tokenizer = tokenizer
self.processor = processor
# If the model supports tagging, add the axolotl tag.
# in case the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
# model.push_to_hub instead of trainer.push_to_hub.
# model.push_to_hub instad of trainer.push_to_hub.
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
@@ -227,8 +227,8 @@ class TrainerBuilderBase(abc.ABC):
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for causal models and reward modeling
using TRL.
Build the HuggingFace training args/trainer for causal models
and reward modelling using TRL.
"""
def get_callbacks(self):
@@ -551,8 +551,30 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[
"alternate_lr_scheduler_type"
@@ -636,114 +658,46 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model:
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
# Handle custom optimizer
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
if self.cfg.optimizer in custom_supported_optimizers:
# Common optimizer kwargs
optimizer_kwargs = {
"lr": training_arguments_kwargs.get("learning_rate"),
"weight_decay": training_arguments_kwargs.get("weight_decay"),
}
# pylint: disable=duplicate-code
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
# Adam-specific kwargs
adam_kwargs = {}
if training_arguments_kwargs.get(
"adam_beta1"
) and training_arguments_kwargs.get("adam_beta2"):
adam_kwargs["betas"] = (
training_arguments_kwargs.get("adam_beta1"),
training_arguments_kwargs.get("adam_beta2"),
)
if training_arguments_kwargs.get("adam_epsilon"):
adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "lion_pytorch":
from lion_pytorch import Lion
if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
MuonOptimizerFactory,
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
if "weight_decay" in training_arguments_kwargs:
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
if (
"adam_beta1" in training_arguments_kwargs
and "adam_beta2" in training_arguments_kwargs
):
lion_kwargs["betas"] = (
training_arguments_kwargs["adam_beta1"],
training_arguments_kwargs["adam_beta2"],
)
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optimizer_kwargs.update(self.cfg.optim_args)
else:
# Parse string format "key1=value1,key2=value2"
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optimizer_kwargs[key] = value
trainer_kwargs["optimizer_cls_and_kwargs"] = (
optimizer_cls,
optimizer_kwargs,
trainer_kwargs["optimizers"] = (
Lion(params=self.model.parameters(), **lion_kwargs),
None,
)
else:
# Use transformers' optimizer
training_arguments_kwargs["optim"] = self.cfg.optimizer
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.accelerator_config:
training_arguments_kwargs[
"accelerator_config"
@@ -751,12 +705,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_ce_alpha_end is not None:
training_arguments_kwargs["kd_ce_alpha_end"] = self.cfg.kd_ce_alpha_end
if self.cfg.kd_alpha is not None:
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
if self.cfg.kd_alpha_end is not None:
training_arguments_kwargs["kd_alpha_end"] = self.cfg.kd_alpha_end
if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:
@@ -922,7 +872,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
class HFRLTrainerBuilder(TrainerBuilderBase):
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
"""
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
"""
def get_callbacks(self):
callbacks = super().get_callbacks()

View File

@@ -14,7 +14,6 @@ from typing import Dict, Literal, Optional
import torch
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import Trainer
@@ -23,11 +22,9 @@ from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
@@ -118,17 +115,6 @@ class SchedulerMixin(Trainer):
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif self.args.alternate_lr_scheduler_type == "rex":
if use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
@@ -168,18 +154,47 @@ class SchedulerMixin(Trainer):
return self.lr_scheduler
class OptimizerMixin(Trainer):
class AxolotlTrainer(SchedulerMixin, Trainer):
"""
Mixin class for shared handling of building custom optimizers
Extend the base Trainer for axolotl helpers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def create_optimizer_grouped_parameters(
self, opt_model, optimizer_kwargs
) -> list[dict]:
def __init__(
self,
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
decay_parameters = self.get_decay_parameter_names(opt_model)
params: dict = {
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
@@ -266,30 +281,23 @@ class OptimizerMixin(Trainer):
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if (
not self.optimizer
and self.optimizer_cls_and_kwargs is not None
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
):
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
self.optimizer = optimizer_factory_cls()(
opt_model, self.args, **optimizer_kwargs
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
if not self.optimizer:
if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
self.args, opt_model
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
@@ -306,47 +314,50 @@ class OptimizerMixin(Trainer):
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
else:
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop(
"optimizer_dict"
)
self.optimizer = optimizer_cls(
optimizer_grouped_parameters, **optimizer_kwargs
elif (
self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None
or self.args.lr_groups is not None
):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW(
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
elif self.args.alternate_optimizer == "ao_adamw_4bit":
from torchao.prototype.low_bit_optim import AdamW4bit
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters,
decouple=True,
**optimizer_kwargs,
)
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -355,45 +366,6 @@ class OptimizerMixin(Trainer):
return self.optimizer
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def __init__(
self,
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:

View File

@@ -9,7 +9,6 @@ import logging
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
LOG = logging.getLogger("axolotl")
@@ -32,44 +31,30 @@ class GRPOStrategy:
@classmethod
def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {}
if not hasattr(cfg, "trl") or not cfg.trl:
return grpo_args_kwargs
trl: TRLConfig = cfg.trl # type: ignore
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_device"] = (
trl.vllm_device if trl.vllm_device else "auto"
)
if trl.vllm_gpu_memory_utilization:
if cfg.trl and cfg.trl.use_vllm:
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
if cfg.trl and cfg.trl.vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = trl.vllm_gpu_memory_utilization
if trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
if trl.num_generations:
grpo_args_kwargs["num_generations"] = trl.num_generations
if trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model
if trl.ref_model_mixup_alpha:
grpo_args_kwargs["ref_model_mixup_alpha"] = trl.ref_model_mixup_alpha
if trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
grpo_args_kwargs["log_completions"] = trl.log_completions
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
] = cfg.trl.vllm_gpu_memory_utilization
if cfg.trl and cfg.trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
if cfg.trl and cfg.trl.num_generations:
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
if cfg.trl and cfg.trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
grpo_args_kwargs[
"ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
return grpo_args_kwargs
@classmethod

View File

@@ -23,8 +23,6 @@ import importlib
import logging
from typing import OrderedDict
import torch
class BasePlugin:
"""
@@ -471,14 +469,3 @@ class PluginManager:
"""
for plugin in self.plugins.values():
plugin.post_train_unload(cfg)
class BaseOptimizerFactory:
"""
Base class for factories to create custom optimizers
"""
def __call__(
self, opt_model, training_args, **optimizer_kwargs
) -> "torch.optim.Optimizer":
pass

View File

@@ -1,26 +1,6 @@
# Cut Cross Entropy
Cut Cross Entropy reduces VRAM usage by optimizing the cross-entropy operation during loss calculation.
See https://github.com/apple/ml-cross-entropy
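For intuition, the upstream library exposes a fused `linear_cross_entropy` that computes the loss directly from the last hidden states and the `lm_head` weight, so the full `[batch, seq, vocab]` logits tensor never needs to be materialized. A minimal sketch of the standalone kernel with illustrative shapes (this is not Axolotl's integration code):
```python
import torch
from cut_cross_entropy import linear_cross_entropy

# Hypothetical shapes: batch=4, seq=512, hidden=4096, vocab=32000
hidden = torch.randn(4, 512, 4096, device="cuda", dtype=torch.bfloat16)
lm_head_weight = torch.randn(32000, 4096, device="cuda", dtype=torch.bfloat16)
labels = torch.randint(0, 32000, (4, 512), device="cuda")

# The [4, 512, 32000] logits tensor is never allocated; the logsumexp and the
# gathered target logits are computed blockwise inside the kernel.
loss = linear_cross_entropy(hidden, lm_head_weight, labels)
```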
## Requirements
- PyTorch 2.4.0 or higher
## Installation
Run the following command to install `cut_cross_entropy[transformers]` if you don't have it already.
```bash
# if you are in dev environment
python scripts/cutcrossentropy_install.py | sh
# if you are not in dev environment
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
```
## Usage
### Usage
```yaml
plugins:
@@ -28,19 +8,3 @@ plugins:
cut_cross_entropy: true
```
## Citation
```bib
@article{wijmans2024cut,
author = {Erik Wijmans and
Brody Huval and
Alexander Hertzberg and
Vladlen Koltun and
Philipp Kr\"ahenb\"uhl},
title = {Cut Your Losses in Large-Vocabulary Language Models},
journal = {arXiv},
year = {2024},
url = {https://arxiv.org/abs/2411.09009},
}
```

View File

@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
)

View File

@@ -2,7 +2,7 @@
See https://github.com/ironjr/grokfast
## Usage
### Usage
```yaml
plugins:
@@ -11,14 +11,3 @@ plugins:
grokfast_alpha: 2.0
grokfast_lamb: 0.98
```
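For intuition, Grokfast maintains an exponential moving average (EMA) of each parameter's gradient and adds the slow-moving component back into the gradient before the optimizer step. A rough sketch of the idea (not the plugin's actual implementation; how `grokfast_alpha`/`grokfast_lamb` map onto the decay and amplification factors below is left to the plugin):
```python
import torch

def grokfast_ema_filter(model: torch.nn.Module, ema: dict, decay: float, amplify: float) -> dict:
    """Apply Grokfast-style EMA gradient filtering in place, after backward()."""
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        prev = ema.get(name, torch.zeros_like(param.grad))
        ema[name] = decay * prev + (1.0 - decay) * param.grad   # slow gradient component
        param.grad = param.grad + amplify * ema[name]           # amplify the slow component
    return ema
```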
## Citation
```bib
@article{lee2024grokfast,
title={{Grokfast}: Accelerated Grokking by Amplifying Slow Gradients},
author={Lee, Jaerin and Kang, Bong Gyun and Kim, Kihoon and Lee, Kyoung Mu},
journal={arXiv preprint arXiv:2405.20233},
year={2024}
}
```

View File

@@ -1,23 +0,0 @@
# Knowledge Distillation
## Usage
```yaml
plugins:
- "axolotl.integrations.kd.KDPlugin"
kd_trainer: True
kd_ce_alpha: 0.1
kd_alpha: 0.9
kd_temperature: 1.0
torch_compile: True # torch>=2.5.1, recommended to reduce vram
datasets:
- path: ...
type: "axolotl.integrations.kd.chat_template"
field_messages: "messages_combined"
logprobs_field: "llm_text_generation_vllm_logprobs" # for kd only, field of logprobs
```
An example dataset can be found at [`axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample`](https://huggingface.co/datasets/axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample)
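Conceptually, the distillation term is a forward KL divergence between the teacher's stored top-k log-probabilities and the student's log-probabilities at the same token ids, blended with the regular cross-entropy loss via `kd_ce_alpha` and `kd_alpha`. A minimal eager sketch of the top-k KD term (illustrative only, not the plugin's exact implementation):
```python
import torch
import torch.nn.functional as F

def topk_forward_kl(student_logits, target_token_ids, target_logprobs, target_mask, temperature=1.0):
    # student_logits: [B, S, V]; target_token_ids/target_logprobs/target_mask: [B, S, K]
    student_logprobs = F.log_softmax(student_logits.float() / temperature, dim=-1)
    student_topk = torch.gather(student_logprobs, dim=-1, index=target_token_ids)
    teacher_probs = target_logprobs.exp()
    # Forward KL restricted to the teacher's top-k tokens, masked to supervised positions.
    kl = teacher_probs * (target_logprobs - student_topk) * target_mask
    return kl.sum() / target_mask.sum().clamp(min=1)
```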

View File

@@ -34,12 +34,3 @@ class KDPlugin(BasePlugin):
return AxolotlKDTrainer
return None
def add_callbacks_post_trainer(self, cfg, trainer):
callbacks = []
if cfg.kd_trainer:
from .callbacks import KDAlphaSchedulerCallback
callbacks.append(KDAlphaSchedulerCallback())
return callbacks

View File

@@ -30,8 +30,6 @@ class KDArgs(BaseModel):
float
] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_ce_alpha_end: Optional[float] = None # end value for kd_ce_alpha
kd_alpha_end: Optional[float] = None # end value for kd_alpha
kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[

View File

@@ -1,28 +0,0 @@
from transformers import TrainerCallback
class KDAlphaSchedulerCallback(TrainerCallback):
"""Callback to for scheduling KD alpha during training."""
def on_epoch_begin(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if int(state.epoch) == 0:
state.kd_alpha = args.kd_alpha
state.kd_ce_alpha = args.kd_ce_alpha
elif int(state.epoch) == state.num_train_epochs - 1:
if args.kd_alpha_end is not None:
control.kd_alpha = args.kd_alpha_end
if args.kd_ce_alpha_end is not None:
control.kd_ce_alpha = args.kd_ce_alpha_end
else:
epoch_steps = state.num_train_epochs - 1
scale = int(state.epoch) / epoch_steps
if args.kd_alpha_end is not None:
control.kd_alpha = (
args.kd_alpha + (args.kd_alpha_end - args.kd_alpha) * scale
)
if args.kd_ce_alpha_end is not None:
control.kd_ce_alpha = (
args.kd_ce_alpha + (args.kd_ce_alpha_end - args.kd_ce_alpha) * scale
)

View File

@@ -62,16 +62,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
Transform logprobs to target format for KD training
"""
if "target_logprobs" in sample.keys() and "target_token_ids" in sample.keys():
logprobs = sample.pop("target_logprobs")
token_ids = sample.pop("target_token_ids")
else:
logprobs = sample.pop(self.logprobs_field)
token_ids = [None] * len(logprobs)
logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
target_padding_len = input_seq_len - target_seq_len
input_padding_len = input_seq_len - target_seq_len
# get non-zero top-k (prune None logprobs from vllm data step)
top_k_vals = [
len(logprobs[i])
@@ -88,11 +82,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_token_ids = []
target_mask = []
if target_padding_len < 0:
if input_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
target_padding_len = 0
input_padding_len = 0
# target_seq_len = input_seq_len
# truncate the second dimension of the logprobs to top_k
@@ -104,37 +98,33 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer
shift = 0
for _ in range(shift, target_padding_len):
for _ in range(shift, input_padding_len):
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for position in range(target_padding_len, input_seq_len):
for position in range(input_padding_len, input_seq_len):
if sample["labels"][position] == -100:
target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for token_pos_logprobs, token_pos_token_ids in zip(logprobs, token_ids):
for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids
position_logprobs = []
position_token_ids = []
# Process each token probability entry
if token_pos_token_ids is None:
for entry in token_pos_logprobs:
# Extract logprob value
logprob = entry["logprob"]
for entry in token_pos_logprobs:
# Extract logprob value
logprob = entry["logprob"]
# Parse token_id from the "token_id:###" format
token_id = int(entry["token"].split(":")[1])
# Parse token_id from the "token_id:###" format
token_id = int(entry["token"].split(":")[1])
# Append to our collections
position_logprobs.append(logprob)
position_token_ids.append(token_id)
else:
position_logprobs = token_pos_logprobs
position_token_ids = token_pos_token_ids
# Append to our collections
position_logprobs.append(logprob)
position_token_ids.append(token_id)
# Convert to a tensor for easier manipulation
position_logprobs_tensor = torch.tensor(
@@ -153,7 +143,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
teacher_probs_t2 = teacher_probs_t1**exponent
else:
teacher_probs_t2 = teacher_probs_t1
# Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True

View File

@@ -0,0 +1,391 @@
"""
Benchmark utility for the top-k KL divergence Triton kernel.
"""
import gc
import time
import torch
from torch.utils.benchmark import Timer
from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss
# pylint: disable=cell-var-from-loop
def benchmark_kl_div_loss_with_backward():
# Test configurations
batch_sizes = [1, 4]
seq_lens = [64, 512, 2048, 4096, 8192]
vocab_size = 32000
top_k = 64
# Store results
results = []
# Run benchmarks
for batch_size in batch_sizes:
for seq_len in seq_lens:
# Generate random test data
torch.manual_seed(42)
# Create tensors with gradients
student_logits = torch.randn(
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
)
# pylint: disable=duplicate-code
target_token_ids = torch.randint(
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
)
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
target_mask = torch.randint(
0, 2, (batch_size, seq_len, top_k), device="cuda"
).float()
# Clone student_logits for the two implementations
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
# Define functions for timing that include both forward and backward passes
def run_reference():
# Forward pass
loss_ref = eager_loss(
student_logits_ref, target_token_ids, target_logprobs, target_mask
)
# Backward pass
loss_ref.backward()
def run_triton():
# Forward pass
# pylint: disable=duplicate-code
loss_triton = triton_loss(
student_logits_triton,
target_token_ids,
target_logprobs,
target_mask,
)
# Backward pass
loss_triton.backward()
# Benchmark reference implementation (forward + backward)
t0 = Timer(
stmt="run_reference()",
globals={
"run_reference": run_reference,
},
)
# Reset gradients before timing
student_logits_ref.grad = None
ref_time = t0.timeit(10).median * 1000 # Convert to ms
# Benchmark Triton implementation (forward + backward)
t1 = Timer(
stmt="run_triton()",
globals={
"run_triton": run_triton,
},
)
# Reset gradients before timing
student_logits_triton.grad = None
triton_time = t1.timeit(10).median * 1000 # Convert to ms
# Compute speedup
speedup = ref_time / triton_time if triton_time > 0 else float("inf")
# Store results
results.append(
{
"batch_size": batch_size,
"seq_len": seq_len,
"reference_time_ms": ref_time,
"triton_time_ms": triton_time,
"speedup": speedup,
}
)
print(f"Batch size: {batch_size}, Seq len: {seq_len}")
print(f" Reference time (fwd+bwd): {ref_time:.2f} ms")
print(f" Triton time (fwd+bwd): {triton_time:.2f} ms")
print(f" Speedup: {speedup:.2f}x")
return results
def benchmark_forward_backward_separately():
"""
Benchmark forward and backward passes separately to identify where the speedup comes from.
"""
# Test configurations
batch_sizes = [1, 4, 8]
seq_lens = [64, 512, 2048]
vocab_size = 32000
top_k = 64
# Store results
detailed_results = []
# Run benchmarks
for batch_size in batch_sizes:
for seq_len in seq_lens:
# Generate random test data
torch.manual_seed(42)
# Create tensors with gradients
student_logits = torch.randn(
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
)
# pylint: disable=duplicate-code
target_token_ids = torch.randint(
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
)
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
target_mask = torch.randint(
0, 2, (batch_size, seq_len, top_k), device="cuda"
).float()
# Clone student_logits for the two implementations
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
# Forward-only reference
def run_reference_forward():
with torch.no_grad():
return eager_loss(
student_logits_ref,
target_token_ids,
target_logprobs,
target_mask,
)
# Forward-only triton
def run_triton_forward():
with torch.no_grad():
return triton_loss(
student_logits_triton,
target_token_ids,
target_logprobs,
target_mask,
)
# Benchmark forward pass only
t0_fwd = Timer(
stmt="run_reference_forward()",
globals={
"run_reference_forward": run_reference_forward,
},
)
ref_fwd_time = t0_fwd.timeit(10).median * 1000 # Convert to ms
t1_fwd = Timer(
stmt="run_triton_forward()",
globals={
"run_triton_forward": run_triton_forward,
},
)
triton_fwd_time = t1_fwd.timeit(10).median * 1000 # Convert to ms
# Pre-compute losses for backward pass benchmarking
loss_ref = eager_loss(
student_logits_ref, target_token_ids, target_logprobs, target_mask
)
loss_triton = triton_loss(
student_logits_triton, target_token_ids, target_logprobs, target_mask
)
# Backward-only reference
def run_reference_backward():
student_logits_ref.grad = None
loss_ref.backward(retain_graph=True)  # retain the graph so the timed backward can run repeatedly
# Backward-only triton
def run_triton_backward():
student_logits_triton.grad = None
loss_triton.backward(retain_graph=True)  # retain the graph so the timed backward can run repeatedly
# Benchmark backward pass only
t0_bwd = Timer(
stmt="run_reference_backward()",
globals={
"run_reference_backward": run_reference_backward,
},
)
ref_bwd_time = t0_bwd.timeit(10).median * 1000 # Convert to ms
t1_bwd = Timer(
stmt="run_triton_backward()",
globals={
"run_triton_backward": run_triton_backward,
},
)
triton_bwd_time = t1_bwd.timeit(10).median * 1000 # Convert to ms
# Compute speedups
fwd_speedup = (
ref_fwd_time / triton_fwd_time if triton_fwd_time > 0 else float("inf")
)
bwd_speedup = (
ref_bwd_time / triton_bwd_time if triton_bwd_time > 0 else float("inf")
)
total_ref_time = ref_fwd_time + ref_bwd_time
total_triton_time = triton_fwd_time + triton_bwd_time
total_speedup = (
total_ref_time / total_triton_time
if total_triton_time > 0
else float("inf")
)
# Store results
detailed_results.append(
{
"batch_size": batch_size,
"seq_len": seq_len,
"ref_forward_ms": ref_fwd_time,
"triton_forward_ms": triton_fwd_time,
"forward_speedup": fwd_speedup,
"ref_backward_ms": ref_bwd_time,
"triton_backward_ms": triton_bwd_time,
"backward_speedup": bwd_speedup,
"total_ref_ms": total_ref_time,
"total_triton_ms": total_triton_time,
"total_speedup": total_speedup,
}
)
print(f"Batch size: {batch_size}, Seq len: {seq_len}")
print(
f" Forward: Reference={ref_fwd_time:.2f}ms, Triton={triton_fwd_time:.2f}ms, Speedup={fwd_speedup:.2f}x"
)
print(
f" Backward: Reference={ref_bwd_time:.2f}ms, Triton={triton_bwd_time:.2f}ms, Speedup={bwd_speedup:.2f}x"
)
print(
f" Total: Reference={total_ref_time:.2f}ms, Triton={total_triton_time:.2f}ms, Speedup={total_speedup:.2f}x"
)
return detailed_results
def benchmark_memory_usage_with_backward():
# Test configurations
batch_sizes = [1, 2]
seq_len = 8192
vocab_size = 128000
top_k = 64
# Store results
mem_results = []
# Run benchmarks
for batch_size in batch_sizes:
# Generate random test data
torch.manual_seed(42)
student_logits = torch.randn(
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
)
target_token_ids = torch.randint(
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
)
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
target_mask = torch.randint(
0, 2, (batch_size, seq_len, top_k), device="cuda"
).float()
# Clone student_logits for the implementations
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
# Measure PyTorch memory usage (forward + backward)
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
loss_ref = eager_loss(
student_logits_ref, target_token_ids, target_logprobs, target_mask
)
loss_ref.backward()
torch.cuda.synchronize()
pytorch_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
# Measure Triton memory usage (forward + backward)
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
loss_triton = triton_loss(
student_logits_triton, target_token_ids, target_logprobs, target_mask
)
loss_triton.backward()
torch.cuda.synchronize()
triton_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
# Measure Triton memory usage with different chunk sizes (forward + backward)
for n_chunks in [1, 2, 4, 8]:
student_logits_chunk = student_logits.clone().detach().requires_grad_(True)
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
loss_chunk = triton_loss(
student_logits_chunk,
target_token_ids,
target_logprobs,
target_mask,
)
loss_chunk.backward()
torch.cuda.synchronize()
chunk_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
mem_results.append(
{
"batch_size": batch_size,
"implementation": f"Triton (chunks={n_chunks})",
"memory_mb": chunk_mem,
}
)
# Store results
mem_results.append(
{
"batch_size": batch_size,
"implementation": "PyTorch",
"memory_mb": pytorch_mem,
}
)
mem_results.append(
{
"batch_size": batch_size,
"implementation": "Triton (default)",
"memory_mb": triton_mem,
}
)
print(f"Batch size: {batch_size} (with backward pass)")
print(f" PyTorch memory: {pytorch_mem:.2f} MB")
print(f" Triton memory: {triton_mem:.2f} MB")
print(f" Memory reduction: {(1 - triton_mem/pytorch_mem)*100:.2f}%")
return mem_results
def main():
print("Running benchmarks with forward and backward passes...")
benchmark_kl_div_loss_with_backward()
clean()
print("\nRunning detailed forward/backward benchmarks...")
# benchmark_forward_backward_separately()
# clean()
print("\nRunning memory usage benchmarks with backward passes...")
benchmark_memory_usage_with_backward()
clean()
def clean():
for _ in range(5):
gc.collect()
torch.cuda.empty_cache()
time.sleep(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,750 @@
"""
Optimized Triton kernel for KL divergence loss between teacher and student models.
"""
# pylint: disable=invalid-name,unused-argument
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
@triton.jit
def fused_logsumexp_logprobs_kernel(
student_logits_ptr, # Input logits in original dtype
student_logprobs_ptr, # Output logprobs (float32)
token_ids_ptr, # Token IDs for top-k
B,
S,
V,
K, # batch size, seq len, vocab size, top-k
temperature,
stride_l_b,
stride_l_s,
stride_l_v,
stride_lp_b,
stride_lp_s,
stride_lp_k,
stride_t_b,
stride_t_s,
stride_t_k,
BLOCK_SIZE: tl.constexpr,
):
"""
Fused kernel that computes logsumexp and logprobs for topk tokens.
All computations are done in float32 for numerical stability.
"""
# Program ID
pid = tl.program_id(0)
batch_idx = pid // S
seq_idx = pid % S
# Bounds check
if batch_idx >= B or seq_idx >= S:
return
# Compute logsumexp over the vocabulary
max_val = -float("inf")
# Phase 1: Find max value across vocabulary
for v_offset in range(0, V, BLOCK_SIZE):
# Create block indices and mask
block_size = min(BLOCK_SIZE, V - v_offset)
block_idx = tl.arange(0, BLOCK_SIZE)
mask = block_idx < block_size
# Load logits block and convert to float32 in-place
ptrs = (
student_logits_ptr
+ batch_idx * stride_l_b
+ seq_idx * stride_l_s
+ (v_offset + block_idx) * stride_l_v
)
block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
# Apply temperature scaling if needed
if temperature != 1.0:
block_logits = block_logits / temperature
# Update max value
block_max = tl.max(block_logits, axis=0)
max_val = tl.maximum(max_val, block_max)
# Phase 2: Compute sum of exp(logits - max_val)
sum_exp = 0.0
for v_offset in range(0, V, BLOCK_SIZE):
# Create block indices and mask
block_size = min(BLOCK_SIZE, V - v_offset)
block_idx = tl.arange(0, BLOCK_SIZE)
mask = block_idx < block_size
# Load logits block and convert to float32 in-place
ptrs = (
student_logits_ptr
+ batch_idx * stride_l_b
+ seq_idx * stride_l_s
+ (v_offset + block_idx) * stride_l_v
)
block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
# Apply temperature scaling if needed
if temperature != 1.0:
block_logits = block_logits / temperature
# Compute exp(logits - max_val) and add to sum
block_exp = tl.exp(block_logits - max_val)
sum_exp += tl.sum(block_exp * mask, axis=0)
# Compute final logsumexp
logsumexp = max_val + tl.log(sum_exp)
# Phase 3: Compute and store logprobs for the top-k tokens
token_ids_base = token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
logprobs_base = (
student_logprobs_ptr + batch_idx * stride_lp_b + seq_idx * stride_lp_s
)
for k in range(K):
# Load token ID for position k
token_id = tl.load(token_ids_base + k * stride_t_k)
# Load the corresponding logit and convert to float32
token_logit_ptr = (
student_logits_ptr
+ batch_idx * stride_l_b
+ seq_idx * stride_l_s
+ token_id * stride_l_v
)
token_logit = tl.load(token_logit_ptr).to(tl.float32)
# Apply temperature scaling if needed
if temperature != 1.0:
token_logit = token_logit / temperature
# Compute logprob directly: logit - logsumexp
token_logprob = token_logit - logsumexp
# Store the result
tl.store(logprobs_base + k * stride_lp_k, token_logprob)
@triton.jit
def grad_softmax_kernel(
grad_student_logits_ptr,
target_token_ids_ptr,
teacher_probs_ptr,
student_probs_ptr,
mask_ptr,
B,
S,
V,
K, # batch size, seq len, vocab size, top-k
scale,
stride_gl_b,
stride_gl_s,
stride_gl_v,
stride_t_b,
stride_t_s,
stride_t_k,
stride_p_b,
stride_p_s,
stride_p_k,
stride_sp_b,
stride_sp_s,
stride_sp_k,
stride_m_b,
stride_m_s,
stride_m_k,
BLOCK_SIZE: tl.constexpr,
):
# Program ID
pid = tl.program_id(0)
batch_idx = pid // S
seq_idx = pid % S
# Bounds check
if batch_idx >= B or seq_idx >= S:
return
# Base pointers for this (batch, seq) pair
grad_logits_base = (
grad_student_logits_ptr + batch_idx * stride_gl_b + seq_idx * stride_gl_s
)
token_ids_base = (
target_token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
)
teacher_probs_base = (
teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
)
student_probs_base = (
student_probs_ptr + batch_idx * stride_sp_b + seq_idx * stride_sp_s
)
mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s
# Process each teacher probability one at a time, computing all gradients for it
for k in range(0, K):
# Load data for current position k
teacher_prob = tl.load(teacher_probs_base + k * stride_p_k)
student_prob_k = tl.load(student_probs_base + k * stride_sp_k)
mask_val = tl.load(mask_base + k * stride_m_k)
# Precompute the self-influence term (multiplied by scale)
self_term = teacher_prob * (1.0 - student_prob_k) * scale
# Calculate gradient contributions for all positions j
for j in range(0, K):
token_id_j = tl.load(token_ids_base + j * stride_t_k)
student_prob_j = tl.load(student_probs_base + j * stride_sp_k)
mask_j = tl.load(mask_base + j * stride_m_k)
# Calculate the masking factor
combined_mask = mask_val * mask_j
# Determine if this is a diagonal or off-diagonal term
is_k_equals_j = tl.where(k == j, 1.0, 0.0)
# Compute the gradient contribution
# For diagonal (k==j): -teacher_prob * (1-student_prob_k) * scale * mask
# For off-diagonal: -(-teacher_prob * student_prob_j) * scale * mask
grad_contribution = (
-(
self_term * is_k_equals_j
- teacher_prob * student_prob_j * scale * (1.0 - is_k_equals_j)
)
* combined_mask
)
# Atomically update the gradient for this token
tl.atomic_add(
grad_logits_base + token_id_j * stride_gl_v, grad_contribution
)
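Summing the kernel's per-pair contributions over `k` for a fixed position `j` gives, assuming a binary mask, the usual softmax-KL gradient restricted to the top-k tokens. The function below is a hypothetical dense reference for a single (batch, seq) position, included only to make the accumulated quantity explicit:

```python
import torch

def grad_softmax_reference(token_ids, teacher_probs, student_probs, mask, vocab_size, scale):
    """token_ids, teacher_probs, student_probs, mask: [K] tensors for one position."""
    m = mask.float()
    # d/d logit_j of sum_k teacher_k * (log teacher_k - log student_k) over the top-k softmax:
    #   scale * (student_j * sum_k teacher_k - teacher_j), with masked entries zeroed out
    grad_topk = scale * (student_probs * (teacher_probs * m).sum() - teacher_probs * m) * m
    grad = torch.zeros(vocab_size)
    grad.scatter_add_(0, token_ids.long(), grad_topk)
    return grad
```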
@triton.jit
def grad_topk_softmax_kernel(
grad_student_logits_ptr,
student_logits_ptr,
target_token_ids_ptr,
teacher_probs_ptr,
student_probs_ptr,
mask_ptr,
B,
S,
V,
K, # batch size, seq len, vocab size, top-k
scale,
stride_gl_b,
stride_gl_s,
stride_gl_v,
stride_l_b,
stride_l_s,
stride_l_v,
stride_t_b,
stride_t_s,
stride_t_k,
stride_p_b,
stride_p_s,
stride_p_k,
stride_sp_b,
stride_sp_s,
stride_sp_k,
stride_m_b,
stride_m_s,
stride_m_k,
BLOCK_SIZE: tl.constexpr,
):
# Program ID
pid = tl.program_id(0)
batch_idx = pid // S
seq_idx = pid % S
# Bounds check
if batch_idx >= B or seq_idx >= S:
return
# Base pointers for this (batch, seq) pair
grad_logits_base = (
grad_student_logits_ptr + batch_idx * stride_gl_b + seq_idx * stride_gl_s
)
# logits_base = student_logits_ptr + batch_idx * stride_l_b + seq_idx * stride_l_s
token_ids_base = (
target_token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
)
teacher_probs_base = (
teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
)
student_probs_base = (
student_probs_ptr + batch_idx * stride_sp_b + seq_idx * stride_sp_s
)
mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s
# Triton tensors do not support in-place indexed assignment, so per-position
# values are loaded directly inside the loops rather than cached up front
# Process gradients for all pairs of top-k positions at this (batch, seq) location
for k in range(K):
teacher_prob_k = tl.load(teacher_probs_base + k * stride_p_k)
student_prob_k = tl.load(student_probs_base + k * stride_sp_k)
mask_k = tl.load(mask_base + k * stride_m_k)
for j in range(K):
other_token_id = tl.load(token_ids_base + j * stride_t_k)
student_prob_j = tl.load(student_probs_base + j * stride_sp_k)
mask_j = tl.load(mask_base + j * stride_m_k)
# Zero out the contribution if either position is masked
combined_mask = mask_k * mask_j
# Diagonal vs off-diagonal selector: 1.0 when j == k, else 0.0
is_diagonal = tl.where(j == k, 1.0, 0.0)
# Self influence: teacher_prob_k * (1 - student_prob_k)
self_grad = teacher_prob_k * (1.0 - student_prob_k) * is_diagonal
# Cross influence: -teacher_prob_k * student_prob_j
cross_grad = -teacher_prob_k * student_prob_j * (1.0 - is_diagonal)
# Combined gradient scaled by the loss scale and mask
grad_val = (self_grad + cross_grad) * scale * combined_mask
tl.atomic_add(grad_logits_base + other_token_id * stride_gl_v, grad_val)
# Triton-accelerated implementation of KL divergence loss for top-k tokens
# Chunking helper functions for handling long sequences
def chunk_tensor(
tensor: torch.Tensor, max_seq_len: int
) -> Tuple[torch.Tensor, Optional[int]]:
"""Split a tensor along sequence dimension if needed."""
_, seq_len, *__ = tensor.shape
if seq_len <= max_seq_len:
return tensor, None
num_chunks = (seq_len + max_seq_len - 1) // max_seq_len
chunks = []
for i in range(num_chunks):
start_idx = i * max_seq_len
end_idx = min((i + 1) * max_seq_len, seq_len)
chunks.append(tensor[:, start_idx:end_idx, ...])
return chunks, num_chunks
def merge_chunks(chunks: list, original_shape: torch.Size):
"""Merge chunks back into a single tensor with original shape."""
return torch.cat(chunks, dim=1)
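A small usage sketch of the chunking helpers (the toy shapes below are illustrative; `original_shape` is currently unused by `merge_chunks`):

```python
import torch

x = torch.randn(1, 20, 8)                 # [B, S, V] toy tensor with S > max_seq_len
chunks, num_chunks = chunk_tensor(x, max_seq_len=8)
assert num_chunks == 3 and [c.shape[1] for c in chunks] == [8, 8, 4]
merged = merge_chunks(chunks, x.shape)    # concatenates back along the sequence dim
assert torch.equal(merged, x)
```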
# Triton-accelerated implementation of KL divergence loss for top-k tokens
class TopKKLDivergence(torch.autograd.Function):
"""
Autograd function for KL divergence loss between top-k logprobs
with support for chunking to handle very long sequences.
"""
# Max sequence length to process in a single kernel launch
# This is a tunable parameter that might need adjustment based on GPU memory
MAX_SEQ_LEN = 8192
@staticmethod
def forward(
ctx,
student_logits,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch=-1,
kd_temperature=1.0,
top_k_before_softmax=0,
):
"""
Forward pass for KL divergence loss between top-k logprobs with chunking.
"""
# Only convert target_logprobs to float, leave student_logits as is
target_logprobs = target_logprobs.float()
# Get dimensions
batch_size, _, vocab_size = student_logits.shape
_, teacher_seq_len, top_k = target_token_ids.shape
# Slice student logits to match teacher sequence length
student_logits_for_kd = student_logits[:, :teacher_seq_len, :]
# Store original values for backward pass
ctx.original_seq_len = teacher_seq_len
ctx.original_dtype = student_logits.dtype
# Apply chunking for long sequences
if teacher_seq_len > TopKKLDivergence.MAX_SEQ_LEN:
# Chunk the inputs
student_logits_chunks, num_chunks = chunk_tensor(
student_logits_for_kd, TopKKLDivergence.MAX_SEQ_LEN
)
target_token_ids_chunks, _ = chunk_tensor(
target_token_ids, TopKKLDivergence.MAX_SEQ_LEN
)
# target_logprobs_chunks, _ = chunk_tensor(
# target_logprobs, TopKKLDivergence.MAX_SEQ_LEN
# )
# target_mask_chunks, _ = chunk_tensor(
# target_mask, TopKKLDivergence.MAX_SEQ_LEN
# )
# Process each chunk
student_logprobs_chunks = []
student_probs_chunks = []
for i in range(num_chunks):
chunk_logits = student_logits_chunks[i]
chunk_token_ids = target_token_ids_chunks[i]
chunk_seq_len = chunk_logits.shape[1]
if top_k_before_softmax:
# Apply temperature to student logits
if kd_temperature != 1.0:
chunk_logits = chunk_logits / kd_temperature
# Gather student logits for top-k tokens
chunk_logits_topk = torch.gather(
chunk_logits, dim=-1, index=chunk_token_ids
)
# Compute softmax over gathered logits
chunk_logprobs_topk = torch.log_softmax(chunk_logits_topk, dim=-1)
chunk_probs_topk = torch.exp(chunk_logprobs_topk)
else:
# Allocate output tensor for logprobs directly (always in float32)
chunk_logprobs_topk = torch.empty(
(batch_size, chunk_seq_len, top_k),
dtype=torch.float32,
device=chunk_logits.device,
)
# Launch fused kernel directly
grid = (batch_size * chunk_seq_len,)
fused_logsumexp_logprobs_kernel[grid](
chunk_logits.contiguous(),
chunk_logprobs_topk,
chunk_token_ids.contiguous(),
batch_size,
chunk_seq_len,
vocab_size,
top_k,
kd_temperature,
chunk_logits.stride(0),
chunk_logits.stride(1),
chunk_logits.stride(2),
chunk_logprobs_topk.stride(0),
chunk_logprobs_topk.stride(1),
chunk_logprobs_topk.stride(2),
chunk_token_ids.stride(0),
chunk_token_ids.stride(1),
chunk_token_ids.stride(2),
min(1024, triton.next_power_of_2(vocab_size)),
)
# Calculate probs from logprobs
chunk_probs_topk = torch.exp(chunk_logprobs_topk)
# Store results
student_logprobs_chunks.append(chunk_logprobs_topk)
student_probs_chunks.append(chunk_probs_topk)
# Merge results
student_logprobs_topk = torch.cat(student_logprobs_chunks, dim=1)
student_probs_topk = torch.cat(student_probs_chunks, dim=1)
# Save chunking info for backward pass
ctx.used_chunking = True
ctx.num_chunks = num_chunks
else:
# Original code path for shorter sequences
if top_k_before_softmax:
# Apply temperature to student logits
if kd_temperature != 1.0:
student_logits_for_kd = student_logits_for_kd / kd_temperature
# Gather student logits for top-k tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
)
# Compute softmax over gathered logits
student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1)
student_probs_topk = torch.exp(student_logprobs_topk)
else:
# Allocate output tensor for logprobs directly (always in float32)
student_logprobs_topk = torch.empty(
(batch_size, teacher_seq_len, top_k),
dtype=torch.float32,
device=student_logits.device,
)
# Launch fused kernel directly
grid = (batch_size * teacher_seq_len,)
fused_logsumexp_logprobs_kernel[grid](
student_logits_for_kd.contiguous(),
student_logprobs_topk,
target_token_ids.contiguous(),
batch_size,
teacher_seq_len,
vocab_size,
top_k,
kd_temperature,
student_logits_for_kd.stride(0),
student_logits_for_kd.stride(1),
student_logits_for_kd.stride(2),
student_logprobs_topk.stride(0),
student_logprobs_topk.stride(1),
student_logprobs_topk.stride(2),
target_token_ids.stride(0),
target_token_ids.stride(1),
target_token_ids.stride(2),
min(1024, triton.next_power_of_2(vocab_size)),
)
# Calculate probs from logprobs
student_probs_topk = torch.exp(student_logprobs_topk)
# No chunking used
ctx.used_chunking = False
# Save tensors for backward pass
ctx.save_for_backward(
student_logits_for_kd,
target_token_ids,
target_logprobs,
target_mask,
student_probs_topk,
)
ctx.kd_temperature = kd_temperature
ctx.top_k_before_softmax = top_k_before_softmax
ctx.num_items_in_batch = num_items_in_batch
# Convert mask to boolean
valid_mask = target_mask.bool()
# Extract valid tokens only; explicit index_select over flattened tensors
# avoids the illegal memory access previously hit with boolean-mask indexing
student_logprobs_flat = student_logprobs_topk.view(-1, top_k)
target_logprobs_flat = target_logprobs.view(-1, top_k)
valid_mask_flat = valid_mask.view(-1, top_k)
# Gather valid indices explicitly to avoid illegal memory access
valid_indices = torch.nonzero(valid_mask_flat.view(-1)).squeeze(-1)
student_logprobs_valid = torch.index_select(
student_logprobs_flat.view(-1), 0, valid_indices
)
target_logprobs_valid = torch.index_select(
target_logprobs_flat.view(-1), 0, valid_indices
)
# Convert teacher logprobs to probabilities
teacher_probs_valid = torch.exp(target_logprobs_valid)
# Compute KL divergence loss
token_losses = teacher_probs_valid * (
target_logprobs_valid - student_logprobs_valid
)
kd_loss = token_losses.sum()
# Apply temperature scaling
# pylint: disable=duplicate-code
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Normalize by number of items or valid tokens
if num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
else:
num_valid_tokens = valid_indices.numel()
kd_loss = kd_loss / float(num_valid_tokens if num_valid_tokens > 0 else 1)
return kd_loss
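In dense PyTorch terms, the value returned above reduces to roughly the following (a sketch; `normalizer` stands in for either `num_items_in_batch` or the count of valid top-k entries):

```python
def forward_kl_reference(student_logprobs, teacher_logprobs, mask, temperature, normalizer):
    """All tensors are [B, S, K]; mask selects the valid top-k entries."""
    teacher_probs = teacher_logprobs.exp()
    per_token = teacher_probs * (teacher_logprobs - student_logprobs)
    return (per_token * mask.float()).sum() * (temperature ** 2) / normalizer
```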
@staticmethod
def backward(ctx, grad_output):
"""
Optimized backward pass for KL divergence loss with proper dtype handling and chunking.
"""
(
student_logits,
target_token_ids,
target_logprobs,
target_mask,
student_probs,
) = ctx.saved_tensors
kd_temperature = ctx.kd_temperature
num_items_in_batch = ctx.num_items_in_batch
original_dtype = ctx.original_dtype
# Get dimensions
batch_size, _, vocab_size = student_logits.shape
_, teacher_seq_len, top_k = target_token_ids.shape
# Initialize gradient tensor in float32 to support atomic operations
grad_student_logits = torch.zeros_like(student_logits, dtype=torch.float32)
# Compute scaling factor
scale = grad_output.item()
# Apply temperature scaling from forward pass
if kd_temperature != 1.0:
scale = scale * (kd_temperature**2)
# Normalize by number of items or valid tokens
if num_items_in_batch > 0:
scale = scale / float(num_items_in_batch)
else:
scale = scale / float(target_mask.sum().item())
# Apply chain rule for temperature scaling (1/temperature)
if kd_temperature != 1.0:
scale = scale / kd_temperature
# Convert teacher logprobs to probabilities
teacher_probs = torch.exp(target_logprobs)
# Use chunking for the backward pass if used in forward
if getattr(ctx, "used_chunking", False):
num_chunks = ctx.num_chunks
max_seq = TopKKLDivergence.MAX_SEQ_LEN
# Process each chunk
for i in range(num_chunks):
start_idx = i * max_seq
end_idx = min((i + 1) * max_seq, teacher_seq_len)
chunk_len = end_idx - start_idx
# Get chunk slices
# student_logits_chunk = student_logits[:, start_idx:end_idx, :]
target_token_ids_chunk = target_token_ids[:, start_idx:end_idx, :]
teacher_probs_chunk = teacher_probs[:, start_idx:end_idx, :]
student_probs_chunk = student_probs[:, start_idx:end_idx, :]
target_mask_chunk = target_mask[:, start_idx:end_idx, :]
grad_student_logits_chunk = grad_student_logits[:, start_idx:end_idx, :]
# Launch gradient computation kernel for this chunk
grid = (batch_size * chunk_len,)
grad_softmax_kernel[grid](
grad_student_logits_chunk.contiguous(),
target_token_ids_chunk.contiguous(),
teacher_probs_chunk.contiguous(),
student_probs_chunk.contiguous(),
target_mask_chunk.contiguous(),
batch_size,
chunk_len,
vocab_size,
top_k,
scale,
grad_student_logits_chunk.stride(0),
grad_student_logits_chunk.stride(1),
grad_student_logits_chunk.stride(2),
target_token_ids_chunk.stride(0),
target_token_ids_chunk.stride(1),
target_token_ids_chunk.stride(2),
teacher_probs_chunk.stride(0),
teacher_probs_chunk.stride(1),
teacher_probs_chunk.stride(2),
student_probs_chunk.stride(0),
student_probs_chunk.stride(1),
student_probs_chunk.stride(2),
target_mask_chunk.stride(0),
target_mask_chunk.stride(1),
target_mask_chunk.stride(2),
min(1024, triton.next_power_of_2(top_k)),
)
# Update the gradient tensor (already in-place)
else:
# Original code path for shorter sequences
# Launch gradient computation kernel
grid = (batch_size * teacher_seq_len,)
grad_softmax_kernel[grid](
grad_student_logits.contiguous(),
target_token_ids.contiguous(),
teacher_probs.contiguous(),
student_probs.contiguous(),
target_mask.contiguous(),
batch_size,
teacher_seq_len,
vocab_size,
top_k,
scale,
grad_student_logits.stride(0),
grad_student_logits.stride(1),
grad_student_logits.stride(2),
target_token_ids.stride(0),
target_token_ids.stride(1),
target_token_ids.stride(2),
teacher_probs.stride(0),
teacher_probs.stride(1),
teacher_probs.stride(2),
student_probs.stride(0),
student_probs.stride(1),
student_probs.stride(2),
target_mask.stride(0),
target_mask.stride(1),
target_mask.stride(2),
min(1024, triton.next_power_of_2(top_k)),
)
# Convert gradient back to original dtype if needed
if original_dtype != torch.float32:
grad_student_logits = grad_student_logits.to(original_dtype)
# Return gradients for student_logits and None for other inputs
return grad_student_logits, None, None, None, None, None, None
# Wrapper function for chunked computation
def loss(
student_logits: torch.Tensor,
target_token_ids: torch.Tensor,
target_logprobs: torch.Tensor,
target_mask: torch.Tensor,
num_items_in_batch: int = -1,
kd_temperature: float = 1.0,
top_k_before_softmax: int = 0,
max_seq_len: Optional[int] = None,
):
"""
Triton-accelerated, memory-efficient KL divergence loss for knowledge distillation,
with support for very long sequences via optional chunking.
Args:
student_logits: Student logits [B, seq_len, vocab_size]
target_token_ids: Teacher token IDs [B, seq_len, top_k]
target_logprobs: Teacher logprobs [B, seq_len, top_k]
target_mask: Token mask [B, seq_len, top_k]
num_items_in_batch: Number of items for normalization (-1 for auto)
kd_temperature: Temperature for KD
top_k_before_softmax: Flag for softmax application order
max_seq_len: Override default MAX_SEQ_LEN value for chunking
"""
# Allow overriding the max sequence length
if max_seq_len is not None and max_seq_len > 0:
TopKKLDivergence.MAX_SEQ_LEN = max_seq_len
total_loss = TopKKLDivergence.apply(
student_logits,
target_token_ids,
target_logprobs,
target_mask,
-1 if num_items_in_batch <= 0 else num_items_in_batch,
kd_temperature,
top_k_before_softmax,
)
return total_loss
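A minimal end-to-end usage sketch of this wrapper (random placeholder tensors; a CUDA device is assumed since the kernels are Triton, and all shapes are illustrative):

```python
import torch

B, S, V, K = 2, 128, 32000, 64
student_logits = torch.randn(B, S, V, device="cuda", requires_grad=True)
target_token_ids = torch.randint(0, V, (B, S, K), device="cuda")
target_logprobs = torch.log_softmax(torch.randn(B, S, K, device="cuda"), dim=-1)
target_mask = torch.ones(B, S, K, device="cuda")

kd_loss = loss(
    student_logits,
    target_token_ids,
    target_logprobs,
    target_mask,
    num_items_in_batch=B * S,
    kd_temperature=1.0,
)
kd_loss.backward()  # gradients flow only into student_logits
```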

View File

@@ -0,0 +1,67 @@
"""
Optimized Triton kernels for logsumexp
"""
# pylint: disable=invalid-name,unused-argument
import triton
import triton.language as tl
# Helper function for computing logsumexp
@triton.jit
def logsumexp_kernel(
logits_ptr,
output_ptr,
B,
S,
V, # batch size, seq len, vocab size
stride_b,
stride_s,
stride_v,
out_stride_b,
out_stride_s,
BLOCK_SIZE: tl.constexpr,
):
# Program ID
# pylint: disable=duplicate-code
pid = tl.program_id(0)
batch_idx = pid // S
seq_idx = pid % S
# Bounds check
if batch_idx >= B or seq_idx >= S:
return
# Pointers
logits_base = logits_ptr + batch_idx * stride_b + seq_idx * stride_s
# Find maximum for numerical stability
max_val = -float("inf")
for v_offset in range(0, V, BLOCK_SIZE):
v_size = min(BLOCK_SIZE, V - v_offset)
mask = tl.arange(0, BLOCK_SIZE) < v_size
logits_block = tl.load(
logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
mask=mask,
other=-float("inf"),
)
max_val = tl.maximum(max_val, tl.max(logits_block, axis=0))
# Compute sum of exp(logit - max_val)
sum_exp = 0.0
for v_offset in range(0, V, BLOCK_SIZE):
v_size = min(BLOCK_SIZE, V - v_offset)
mask = tl.arange(0, BLOCK_SIZE) < v_size
logits_block = tl.load(
logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
mask=mask,
other=-float("inf"),
)
sum_exp += tl.sum(tl.exp(logits_block - max_val), axis=0)
# Compute logsumexp
result = max_val + tl.log(sum_exp)
# Store result
tl.store(output_ptr + batch_idx * out_stride_b + seq_idx * out_stride_s, result)
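A possible host-side wrapper for launching this kernel (a sketch; the grid layout of one program per (batch, seq) position and the block-size heuristic mirror the callers above, but the function name and defaults here are assumptions):

```python
import torch
import triton

def logsumexp_triton(logits: torch.Tensor) -> torch.Tensor:
    """logits: [B, S, V] CUDA tensor; returns [B, S] logsumexp over the vocab dim."""
    B, S, V = logits.shape
    logits = logits.contiguous()
    out = torch.empty((B, S), dtype=torch.float32, device=logits.device)
    grid = (B * S,)
    logsumexp_kernel[grid](
        logits,
        out,
        B,
        S,
        V,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        out.stride(0),
        out.stride(1),
        BLOCK_SIZE=min(1024, triton.next_power_of_2(V)),
    )
    return out
```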

View File

@@ -16,22 +16,11 @@
KD trainer
"""
from transformers import TrainerControl
from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
class AxolotlKDTrainerControl(TrainerControl):
kd_alpha: float = 1.0
kd_ce_alpha: float = 0.0
def state(self) -> dict:
state_val = super().state()
state_val["args"]["kd_alpha"] = self.kd_alpha
state_val["args"]["kd_ce_alpha"] = self.kd_ce_alpha
from .topk_logprob.forward_kl_triton import loss as topk_kd_loss_triton
class AxolotlKDTrainer(AxolotlTrainer):
@@ -39,12 +28,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
Custom trainer subclass for Knowledge Distillation (KD)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.kd_alpha = self.args.kd_alpha
self.kd_ce_alpha = self.args.kd_ce_alpha
self.control = AxolotlKDTrainerControl()
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
columns_to_add = []
@@ -103,7 +86,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
num_items_in_batch=num_items_in_batch,
)
else:
loss_kd = topk_kd_loss(
loss_fn = (
topk_kd_loss
if self.args.kd_top_k_before_softmax
else topk_kd_loss_triton
)
loss_kd = loss_fn(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
@@ -113,8 +101,9 @@ class AxolotlKDTrainer(AxolotlTrainer):
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
)
if self.kd_ce_alpha > 0:
loss = self.kd_ce_alpha * outputs["loss"] + self.kd_alpha * loss_kd
if self.args.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else:
loss = loss_kd
# Save past state if it exists

View File

@@ -1,36 +0,0 @@
# Liger Kernel Integration
Liger Kernel provides efficient Triton kernels for LLM training, offering:
- 20% increase in multi-GPU training throughput
- 60% reduction in memory usage
- Compatibility with both FSDP and DeepSpeed
See https://github.com/linkedin/Liger-Kernel
## Usage
```yaml
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```
## Citation
```bib
@article{hsu2024ligerkernelefficienttriton,
title={Liger Kernel: Efficient Triton Kernels for LLM Training},
author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
year={2024},
eprint={2410.10989},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2410.10989},
journal={arXiv preprint arXiv:2410.10989},
}
```

View File

@@ -1,10 +1,6 @@
# LM Eval Harness
Run evaluation on model using the popular lm-evaluation-harness library.
See https://github.com/EleutherAI/lm-evaluation-harness
## Usage
### Usage
```yaml
plugins:
@@ -14,22 +10,4 @@ lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
## Citation
```bib
@misc{eval-harness,
author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy},
title = {A framework for few-shot language model evaluation},
month = 07,
year = 2024,
publisher = {Zenodo},
version = {v0.4.3},
doi = {10.5281/zenodo.12608602},
url = {https://zenodo.org/records/12608602}
}
```

View File

@@ -1,17 +1,15 @@
# Spectrum: Targeted Training on Signal to Noise Ratio
## Spectrum: Targeted Training on Signal to Noise Ratio
by Eric Hartford, Lucas Atkins, Fernando Fernandes, David Golchinfar
This plugin contains code to freeze the bottom fraction of modules in a model, based on the Signal-to-Noise Ratio (SNR).
See https://github.com/cognitivecomputations/spectrum
## Overview
### Overview
Spectrum is a tool for scanning and evaluating the Signal-to-Noise Ratio (SNR) of layers in large language models.
By identifying the top n% of layers with the highest SNR, you can optimize training efficiency.
## Usage
### Usage
```yaml
plugins:
@@ -21,17 +19,3 @@ spectrum_top_fraction: 0.5
# Optional if using a pre-scanned model as your base_model. Useful if using a model mirror
spectrum_model_name: meta-llama/Meta-Llama-3.1-8B
```
## Citation
```bib
@misc{hartford2024spectrumtargetedtrainingsignal,
title={Spectrum: Targeted Training on Signal to Noise Ratio},
author={Eric Hartford and Lucas Atkins and Fernando Fernandes Neto and David Golchinfar},
year={2024},
eprint={2406.06623},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2406.06623},
}
```

View File

@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
"""
from typing import Optional
from pydantic import BaseModel, model_validator
from pydantic import BaseModel
class SpectrumArgs(BaseModel):
@@ -27,20 +27,3 @@ class SpectrumArgs(BaseModel):
spectrum_top_fraction: Optional[float] = 0.5
spectrum_model_name: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_fsdp_use_orig_params(cls, data):
if (
data.get("fsdp")
and data.get("fsdp_config")
and not data["fsdp_config"].get("use_orig_params")
and data.get("plugins")
and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
):
# would otherwise raise
# ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
raise ValueError(
"FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
)
return data

View File

@@ -25,7 +25,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"gemmoe",
"starcoder2",
"deepseek_v2",
"deepseek_v3",
]

View File

@@ -1,29 +1,26 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import inspect
import os
import signal
import sys
import weakref
from pathlib import Path
from typing import Any, Dict
from typing import Tuple, Union
import torch
import transformers.modelcard
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
from peft import PeftConfig, PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
@@ -35,25 +32,17 @@ try:
except ImportError:
BetterTransformer = None
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = get_logger(__name__)
def setup_model_and_tokenizer(
cfg: DictDefault,
) -> tuple[
PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
]:
"""
Load the tokenizer, processor (for multimodal models), and model based on configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Tuple containing model, tokenizer, `peft_config` (if LoRA / QLoRA, else
`None`), and processor (if multimodal, else `None`).
"""
def train(
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
@@ -66,58 +55,11 @@ def setup_model_and_tokenizer(
if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer)
# Load the model and peft_config
msg = "loading model"
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
# Get datasets
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
model, peft_config = load_model(cfg, tokenizer, processor=processor)
if model.generation_config is not None:
model.generation_config.do_sample = True
# Apply freezing if specified
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
return model, tokenizer, peft_config, processor
def setup_reference_model(
cfg: DictDefault, tokenizer: PreTrainedTokenizer
) -> PreTrainedModel | None:
"""
Set up the reference model for RL training if needed.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: The tokenizer to use for the reference model.
Returns:
Reference model if needed for RL training, `None` otherwise.
"""
model_ref = None
if cfg.rl and cfg.rl != "orpo":
if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")
model_ref = None # explicit setting to None
else:
# load the model again for model_ref/baseline
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
return model_ref
def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
"""
Determine the checkpoint to resume from based on configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Path to the checkpoint to resume from, or `None` if not resuming.
"""
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
@@ -131,22 +73,77 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
return cfg.resume_from_checkpoint
resume_from_checkpoint = cfg.resume_from_checkpoint
# Load the model and tokenizer
msg = "loading model"
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, processor=processor)
if model.generation_config is not None:
model.generation_config.do_sample = True
def setup_signal_handler(
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
):
"""
Set up signal handler for graceful termination.
model_ref = None
if cfg.rl and cfg.rl != "orpo":
if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")
model_ref = None # explicit setting to None
else:
# load the model again for model_ref/baseline
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
model: The model to save on termination
safe_serialization: Whether to use safe serialization when saving
"""
# ray workers don't have access to this signal
if cfg.local_rank == 0 and not cfg.use_ray:
safe_serialization = cfg.save_safetensors is True
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
trainer = setup_trainer(
cfg,
train_dataset,
eval_dataset,
(model, model_ref, peft_config),
tokenizer,
processor,
total_num_steps,
)
if cfg.fix_untrained_tokens:
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# additionally presave the tokenizer and model configs
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
if hasattr(model, "config"):
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if (
cfg.local_rank == 0 and not cfg.use_ray
): # ray workers don't have access to this signal
def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
@@ -164,22 +161,21 @@ def setup_signal_handler(
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
)
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
def execute_training(
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
):
"""
Execute the training process with appropriate backend configurations.
if getattr(cfg, "axolotl_config_path"):
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
version = get_distribution("axolotl").version
if raw_axolotl_cfg.is_file():
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The configured trainer object.
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
pretrain_hooks(cfg, trainer)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
@@ -191,30 +187,15 @@ def execute_training(
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
post_train_hooks(cfg, trainer)
def save_trained_model(
cfg: DictDefault,
trainer: Any,
model: PreTrainedModel,
safe_serialization: bool,
):
"""
Save the trained model according to configuration and training setup.
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The trainer object.
model: The trained model to save.
safe_serialization: Whether to use safe serialization.
"""
LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
# Post training module hooks
# post training
for name, module in model.named_modules():
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
# Handle FSDP state dict type
state_dict_type = "FULL_STATE_DICT"
if trainer.is_fsdp_enabled:
if cfg.fsdp_final_state_dict_type:
@@ -222,18 +203,16 @@ def save_trained_model(
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
# Handle ReLoRA early return case
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
else:
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return
return model, tokenizer
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
# TODO: do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple
# processes attempt to write the same file
if (
state_dict_type == "SHARDED_STATE_DICT"
and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
@@ -265,6 +244,7 @@ def save_trained_model(
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError:
pass
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
@@ -275,241 +255,58 @@ def save_trained_model(
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def create_model_card(cfg: DictDefault, trainer: Trainer):
"""
Create a model card for the trained model if needed.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
trainer: The trainer object with model card creation capabilities.
"""
if not cfg.hub_model_id:
# Guard since create_model_card may fail if dataset_tags is empty list
try:
model_card_kwarg = {
"model_name": cfg.output_dir.lstrip("./")
.encode("utf-8")
.decode("utf-8")
}
# We check if we're using a TRL trainer; if so, `dataset_tags` is not consumed.
rl = cfg.rl is not None or cfg.reward_model or cfg.process_reward_model
if cfg.datasets is not None and not rl:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
if dataset_tags:
model_card_kwarg["dataset_tags"] = dataset_tags
if cfg.datasets is not None:
if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
dataset_tags = [
d for d in dataset_tags if not d.startswith("https://")
]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_name"] = dataset_tags
else:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
dataset_tags = [
d for d in dataset_tags if not d.startswith("https://")
]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_tags"] = dataset_tags
trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError):
pass
elif cfg.hub_model_id:
# Defensively push to the hub to ensure the model card is updated
# defensively push to the hub to ensure the model card is updated
trainer.push_to_hub()
return model, tokenizer
def save_initial_configs(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
peft_config: PeftConfig | None,
):
def pretrain_hooks(_cfg, _trainer):
"""
Save initial configurations before training.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: The tokenizer to save.
model: The model to save configuration for.
peft_config: The PEFT configuration to save if applicable.
Run hooks right before kicking off the training
:param cfg:
:param trainer:
:return:
"""
# Create output_dir if it doesn't already exist
output_dir = Path(cfg.output_dir)
if not output_dir.is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
# Pre-save adapter config so it's available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}...")
peft_config.save_pretrained(cfg.output_dir)
# Pre-save the tokenizer and model configs
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
tokenizer.save_pretrained(str(output_dir))
if hasattr(model, "config"):
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
model.config.save_pretrained(str(output_dir))
def setup_model_card(cfg: DictDefault):
def post_train_hooks(_cfg, _trainer):
"""
Set up the Axolotl badge and add the Axolotl config to the model card if available.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Run hooks right after training completes
:param cfg:
:param trainer:
:return:
"""
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
if getattr(cfg, "axolotl_config_path"):
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
version = importlib.metadata.version("axolotl")
if raw_axolotl_cfg.is_file():
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
def handle_untrained_tokens_fix(
cfg: DictDefault,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
train_dataset: Dataset,
safe_serialization: bool,
):
"""
Apply fixes for untrained tokens if configured.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
model: The model to apply fixes to.
tokenizer: The tokenizer for token identification.
train_dataset: The training dataset to use.
safe_serialization: Whether to use safe serialization when saving.
"""
if not cfg.fix_untrained_tokens:
return
is_ds_zero3: bool = False
if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
is_ds_zero3 = True
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
fix_kwargs: Dict[str, Any] = {}
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
if "is_ds_zero3" in sig.parameters:
fix_kwargs["is_ds_zero3"] = is_ds_zero3
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
def setup_model_and_trainer(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[
HFRLTrainerBuilder | HFCausalTrainerBuilder,
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,
]:
"""
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
trainer setup.
Args:
cfg: The configuration dictionary with training parameters.
dataset_meta: Object with training, validation datasets and metadata.
Returns:
Tuple of:
- Trainer (Causal or RLHF)
- Model
- Tokenizer
- PEFT config
"""
# Load tokenizer, processor and model
model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
# Set up reference model for RL if needed
model_ref = setup_reference_model(cfg, tokenizer)
# Get datasets from metadata
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Set up trainer
trainer = setup_trainer(
cfg=cfg,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model=model,
tokenizer=tokenizer,
processor=processor,
total_num_steps=total_num_steps,
model_ref=model_ref,
peft_config=peft_config,
)
return (
trainer,
model,
tokenizer,
peft_config,
)
def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
"""
Train a model on the given dataset.
Args:
cfg: The configuration dictionary with training parameters
dataset_meta: Object with training, validation datasets and metadata
Returns:
Tuple of (model, tokenizer) after training
"""
# Setup model, tokenizer, (causal or RLHF) trainer etc.
(
trainer,
model,
tokenizer,
peft_config,
) = setup_model_and_trainer(cfg, dataset_meta)
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Configuration for saving
safe_serialization = cfg.save_safetensors is True
# Handle untrained tokens if configured
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization
)
# Save initial configs
save_initial_configs(cfg, tokenizer, model, peft_config)
# Set up signal handler for graceful termination
setup_signal_handler(cfg, model, safe_serialization)
# Set up badges and config info for model card
setup_model_card(cfg)
# Execute the training
execute_training(cfg, trainer, resume_from_checkpoint)
# Save the trained model
save_trained_model(cfg, trainer, model, safe_serialization)
# Create model card
create_model_card(cfg, trainer)
return model, tokenizer, trainer

View File

@@ -813,15 +813,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
# TODO if using deepspeed and it's a file, save deepspeed config too
if args.deepspeed and os.path.isfile(args.deepspeed):
LOG.info(f"DeepSpeed config has been saved to the WandB run.")
artifact = wandb.Artifact(
f"deepspeed-{wandb.run.id}", type="deepspeed-config"
)
artifact.add_file(args.deepspeed)
wandb.log_artifact(artifact)
wandb.save(args.deepspeed)
return control

File diff suppressed because one or more lines are too long

View File

@@ -173,16 +173,10 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
]
out_features[i][feature] = np.concatenate(arrays)
else:
try:
arrays = [
np.array(item[feature])
for item in features_
if feature in item
]
if arrays[0].dtype != "object":
out_features[i][feature] = np.concatenate(arrays)
except ValueError:
pass
arrays = [
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
return super().__call__(out_features, return_tensors=return_tensors)

View File

@@ -55,7 +55,6 @@ class ChatTemplate(str, Enum):
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
deepseek_v3 = "deepseek_v3" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
@@ -64,17 +63,6 @@ class ChatTemplate(str, Enum):
metharme = "metharme" # pylint: disable=invalid-name
class CustomSupportedOptimizers(str, Enum):
"""Custom supported optimizers"""
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
@@ -505,7 +493,17 @@ class HyperparametersConfig(BaseModel):
embedding_lr_scale: Optional[float] = None
weight_decay: Optional[float] = 0.0
optimizer: Optional[
Union[OptimizerNames, CustomSupportedOptimizers]
Union[
OptimizerNames,
Literal[
"lion_pytorch",
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
],
]
] = OptimizerNames.ADAMW_HF
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None,
@@ -519,7 +517,7 @@ class HyperparametersConfig(BaseModel):
)
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[
Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]
Union[SchedulerType, Literal["one_cycle"]]
] = SchedulerType.COSINE
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None
@@ -728,7 +726,7 @@ class AxolotlInputConfig(
default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"},
)
dataset_processes: Optional[int] = Field(default=min(32, os.cpu_count())) # type: ignore[type-var]
dataset_processes: Optional[int] = Field(default=os.cpu_count())
dataset_exact_deduplication: Optional[bool] = None
dataset_keep_in_memory: Optional[bool] = None
dataloader_pin_memory: Optional[bool] = None
@@ -779,9 +777,9 @@ class AxolotlInputConfig(
# torch_dtype: Optional[torch.dtype]
gradient_checkpointing: Optional[
Union[Literal["unsloth", "offload"], bool]
] = Field(default=False)
gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
default=False
)
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
unfrozen_parameters: Optional[List[str]] = None
@@ -856,7 +854,6 @@ class AxolotlInputConfig(
special_tokens: Optional[SpecialTokensConfig] = None
tokens: Optional[List[str]] = None
added_tokens_overrides: Optional[Dict[int, str]] = None
torch_compile: Optional[Union[Literal["auto"], bool]] = None
torch_compile_backend: Optional[str] = None
@@ -1155,15 +1152,6 @@ class AxolotlInputConfig(
raise ValueError("gradient_checkpointing is not supported for MPT models")
return self
@model_validator(mode="after")
def check_offload_grad_checkpointing(self):
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
LOG.warning(
"`unsloth` is deprecated for gradient_checkpointing, use `offload`"
)
self.gradient_checkpointing = "offload"
return self
@model_validator(mode="after")
def check_better_transformers(self):
if self.flash_optimum is True:
@@ -1188,13 +1176,6 @@ class AxolotlInputConfig(
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self
@model_validator(mode="before")
@classmethod
def check_lr_groups(cls, data):
if data.get("lr_groups") and data.get("loraplus_lr_ratio"):
raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.")
return data
@model_validator(mode="before")
@classmethod
def check_saves(cls, data):

View File

@@ -1,8 +1,7 @@
"""
GRPO specific configuration args
"""
from typing import Optional
from typing import List, Optional
from pydantic import BaseModel, Field
@@ -12,10 +11,7 @@ class TRLConfig(BaseModel):
Input args for TRL.
"""
beta: Optional[float] = Field(
default=None,
json_schema_extra={"description": "Beta for RL training"},
)
beta: Optional[float] = None
max_completion_length: Optional[int] = Field(
default=None,
json_schema_extra={
@@ -24,68 +20,16 @@ class TRLConfig(BaseModel):
)
# GRPO specific args
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
use_vllm: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to use VLLM for RL training"},
)
vllm_device: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
)
vllm_gpu_memory_utilization: Optional[float] = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
)
vllm_dtype: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
vllm_max_model_len: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the model context for VLLM"
},
)
use_vllm: Optional[bool] = False
vllm_device: Optional[str] = "auto"
vllm_gpu_memory_utilization: Optional[float] = 0.9
vllm_max_model_len: Optional[int] = None
vllm_dtype: Optional[str] = "auto"
reward_funcs: Optional[list[str]] = Field(
default=None,
json_schema_extra={"description": "List of reward functions to load"},
)
reward_weights: Optional[list[float]] = Field(
default=None,
json_schema_extra={
"description": "Weights for each reward function. Must match the number of reward functions."
},
)
num_generations: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
},
)
log_completions: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to log completions"},
)
sync_ref_model: Optional[bool] = Field(
default=False,
json_schema_extra={
"description": (
"Whether to sync the reference model every `ref_model_sync_steps` "
"steps, using the `ref_model_mixup_alpha` parameter."
)
},
)
ref_model_mixup_alpha: Optional[float] = Field(
default=0.9,
json_schema_extra={
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
},
)
ref_model_sync_steps: Optional[int] = Field(
default=64,
json_schema_extra={
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
},
)
reward_funcs: Optional[List[str]] = None
num_generations: Optional[int] = None
log_completions: Optional[bool] = False
sync_ref_model: Optional[bool] = False
ref_model_mixup_alpha: Optional[float] = 0.9
ref_model_sync_steps: Optional[int] = 64

View File

@@ -79,7 +79,7 @@ def is_main_process():
def is_local_main_process():
return PartialState().is_local_main_process
return PartialState().is_main_process
def get_world_size():

View File

@@ -4,7 +4,7 @@ from axolotl.utils.gradient_checkpointing.unsloth import (
)
def hf_grad_checkpoint_offload_wrapper(
def hf_grad_checkpoint_unsloth_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(

View File

@@ -24,6 +24,7 @@ from peft import (
PeftModelForCausalLM,
prepare_model_for_kbit_training,
)
from peft.tuners.lora import QuantLinear
from torch import nn
from transformers import ( # noqa: F401
AddedToken,
@@ -56,14 +57,8 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
barrier,
get_device_count,
get_device_type,
is_local_main_process,
zero_only,
)
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -170,95 +165,7 @@ def load_model_config(cfg):
return model_config
def modify_tokenizer_files(
tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
) -> str:
"""
Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer.
This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab.
Args:
tokenizer_path: Path or name of the original tokenizer
token_mappings: Dict mapping {token_id (int): new_token_string}
output_dir: Directory to save the modified tokenizer
Returns:
Path to the modified tokenizer directory
Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
"""
import json
# Create the tokenizer directory in output_dir if it doesn't exist
tokenizer_dir = os.path.join(output_dir, "tokenizer")
os.makedirs(tokenizer_dir, exist_ok=True)
if is_local_main_process(): # pylint: disable=too-many-nested-blocks
# Load the tokenizer
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
# Save the tokenizer to the output directory
temp_tokenizer.save_pretrained(tokenizer_dir)
# Get the token IDs and map them to their new values
token_id_mappings = {
int(token_id): new_value for token_id, new_value in token_mappings.items()
}
# 1. Update tokenizer_config.json - added_tokens_decoder
config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
if os.path.exists(config_path):
with open(config_path, "r", encoding="utf-8") as f:
config_data = json.load(f)
# Update added_tokens_decoder
if "added_tokens_decoder" in config_data:
for token_id, new_value in token_id_mappings.items():
token_id_str = str(token_id)
if token_id_str in config_data["added_tokens_decoder"]:
config_data["added_tokens_decoder"][token_id_str][
"content"
] = new_value
else:
raise ValueError(
f"Token ID {token_id_str} not found in added_tokens_decoder"
)
# Write the updated config back
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
# 2. Update tokenizer.json - added_tokens
tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
if os.path.exists(tokenizer_path):
with open(tokenizer_path, "r", encoding="utf-8") as f:
tokenizer_data = json.load(f)
# Update added_tokens
if "added_tokens" in tokenizer_data:
for token_id, new_value in token_id_mappings.items():
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
if token_entry["id"] == token_id:
tokenizer_data["added_tokens"][i]["content"] = new_value
break
else:
# Reaching this section means the token_id was not found in tokenizer.json added_tokens
raise ValueError(
f"Token ID {token_id} not found in added_tokens"
)
# Write the updated tokenizer data back
with open(tokenizer_path, "w", encoding="utf-8") as f:
json.dump(tokenizer_data, f, indent=2)
barrier()
return tokenizer_dir
def load_tokenizer(cfg):
"""Load and configure the tokenizer based on the provided config."""
model_config = load_model_config(cfg)
tokenizer_kwargs = {}
use_fast = True # this is the default
@@ -273,18 +180,8 @@ def load_tokenizer(cfg):
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
# Set base tokenizer path
tokenizer_path = cfg.tokenizer_config
# Apply token string overrides if specified
if cfg.added_tokens_overrides:
# Modify tokenizer files and get path to modified tokenizer
tokenizer_path = modify_tokenizer_files(
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
)
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_path,
cfg.tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
@@ -492,8 +389,8 @@ class ModelLoader:
patch_fa_peft_integration()
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if self.cfg.flash_attention:
self.patch_attention()
@@ -1359,7 +1256,7 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
lora_module_names = set()
for name, module in model.named_modules():
if (

View File

@@ -6,80 +6,6 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
class RexLR(LRScheduler):
"""
Reflected Exponential (REX) learning rate scheduler.
- Original implementation: https://github.com/IvanVassi/REX_LR
- Original license: Apache 2.0
- Based on: https://arxiv.org/abs/2107.04197
Args:
optimizer (torch.optim.Optimizer): The optimizer to schedule the learning rate for.
max_lr (float): The maximum learning rate.
min_lr (float): The minimum learning rate.
total_steps (int): The total number of training steps.
num_warmup_steps (int): The number of warmup steps.
last_step (int): The index of last step.
"""
def __init__(
self, optimizer, max_lr, min_lr, total_steps=0, num_warmup_steps=0, last_step=0
):
if min_lr > max_lr:
raise ValueError(
f'Value of "min_lr" should be less than value of "max_lr". Got min_lr={min_lr} and max_lr={max_lr}'
)
if num_warmup_steps > total_steps:
raise ValueError(
f"num_warmup_steps ({num_warmup_steps}) must be less than or equal to total_steps ({total_steps})."
)
self.min_lr = min_lr
self.max_lr = max_lr
self.total_steps = total_steps
self.num_warmup_steps = num_warmup_steps
self.last_step = last_step - 1
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups:
group.setdefault("initial_lr", group["lr"])
# Pass self.last_step as last_epoch to the parent.
super().__init__(optimizer, last_epoch=self.last_step)
@property
def last_step(self):
return self.last_epoch
@last_step.setter
def last_step(self, value):
self.last_epoch = value
def get_lr(self):
# Warmup phase: if defined, increase lr linearly from 0 to max_lr.
if 1 <= self.last_step <= self.num_warmup_steps:
return [
base_lr * self.last_step / self.num_warmup_steps
for base_lr in self.base_lrs
]
# Post-warmup phase: adjust step relative to the end of warmup.
step_after = self.last_step - self.num_warmup_steps
remaining_steps = self.total_steps - self.num_warmup_steps
# Avoid LR spiking
if step_after >= remaining_steps or step_after == -1 or remaining_steps <= 0:
return [self.min_lr for _ in self.base_lrs]
mod_iter = step_after % remaining_steps
z = (remaining_steps - mod_iter) / remaining_steps
rex_factor = self.min_lr / self.max_lr + (1.0 - self.min_lr / self.max_lr) * (
z / (0.1 + 0.9 * z)
)
return [base_lr * rex_factor for base_lr in self.base_lrs]
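As an aside (not part of the diff): a minimal usage sketch for the RexLR scheduler defined above; the import path, optimizer, and step counts are illustrative assumptions.
import torch
from axolotl.utils.schedulers import RexLR  # assumed module path for the class above
params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in parameters
optimizer = torch.optim.AdamW(params, lr=1e-4)
# max_lr matches the optimizer lr; warmup ramps linearly, then REX decays toward min_lr
scheduler = RexLR(optimizer, max_lr=1e-4, min_lr=1e-6, total_steps=1000, num_warmup_steps=100)
for _ in range(1000):
    optimizer.step()
    scheduler.step()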
class InterpolatingLogScheduler(LRScheduler):
"""
A scheduler that interpolates learning rates in a logarithmic fashion

View File

@@ -574,40 +574,14 @@ def prepare_opinionated_env(cfg):
def setup_trainer(
cfg,
train_dataset,
eval_dataset,
model,
tokenizer,
processor,
total_num_steps,
model_ref=None,
peft_config=None,
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
"""
Helper method for instantiating and building a (causal or RLHF) trainer.
Args:
cfg: Axolotl config object containing training parameters.
train_dataset: Dataset to use for training.
eval_dataset: Dataset to use for evaluation.
model: The model to train.
tokenizer: Tokenizer for processing text input.
processor: Processor for data preparation.
total_num_steps: The total number of training steps.
model_ref: Optional reference model for RLHF training. Default is None.
peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.
Returns:
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
if cfg.rl:
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.model_ref = model_ref
trainer_builder.peft_config = peft_config
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]
else:
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset
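As an aside (not part of the diff): the hunk above changes setup_trainer to take a single model bundle instead of separate model_ref and peft_config arguments. A hedged call-shape sketch, assuming model, model_ref, and peft_config come from the caller's model-loading step:
# sketch only: under the new signature, the fourth argument is indexed inside setup_trainer as
# model[0] (model), model[1] (model_ref), and model[2] (peft_config)
trainer = setup_trainer(
    cfg,
    train_dataset,
    eval_dataset,
    (model, model_ref, peft_config),
    tokenizer,
    processor,
    total_num_steps,
)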

View File

@@ -1,193 +1,5 @@
/* TYPOGRAPHY SECTION */
/* css styles */
/* Import fonts */
@import url('https://fonts.googleapis.com/css2?family=Be+Vietnam+Pro:wght@400;500&display=swap');
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400&display=swap');
/* Typography hierarchy */
:root {
--font-title: 'Be Vietnam Pro', sans-serif;
--font-body: 'JetBrains Mono', monospace;
}
/* Title (h1) */
h1 {
font-family: var(--font-title);
font-weight: 400;
font-size: 5rem;
line-height: 1.1;
letter-spacing: -0.05em;
font-feature-settings: "ss01" on;
}
/* Heading (h2) */
h2 {
font-family: var(--font-title);
font-weight: 500;
font-size: 2rem;
line-height: 1.2;
letter-spacing: -0.03em;
font-feature-settings: "ss01" on;
}
/* Subtitle/Preamble */
h3,
h4 {
font-family: var(--font-body);
font-weight: 400;
font-size: 1.5rem;
line-height: 1.5;
letter-spacing: -0.02em;
}
/* Body text */
body {
font-family: var(--font-body);
font-weight: 400;
font-size: 1rem;
line-height: 1.5;
letter-spacing: -0.02em;
}
/* Links */
a {
font-family: var(--font-body);
font-weight: 400;
font-size: 0.875rem;
line-height: 1;
letter-spacing: -0.02em;
}
/* NAV BAR SECTION */
/* Navbar logo styling */
.navbar-brand img {
height: 32px;
margin-right: 10px;
}
/* COLORS SECTION */
/* Brand colors */
:root {
--white: #ffffff;
--greige-300: #EEEEE7;
--greige-600: #CCCAC0;
--black: #141310;
--lime: #E3F8A8;
--cyan: #A0F4EA;
--purple: #C8D0F8;
}
/* Base styles */
body {
background-color: var(--black);
color: var(--greige-300);
}
/* Navigation */
.navbar {
background-color: var(--black) !important;
}
.navbar-dark .navbar-nav .nav-link {
color: var(--greige-300);
}
.navbar-dark .navbar-nav .nav-link:hover {
color: var(--lime);
}
/* Sidebar */
.sidebar-navigation {
background-color: var(--black);
border-right: 1px solid var(--greige-600);
}
.sidebar nav[role="doc-toc"] ul>li>a {
color: var(--greige-300);
}
.sidebar nav[role="doc-toc"] ul>li>a:hover {
color: var(--lime);
}
/* Links */
a {
color: var(--lime);
}
a:hover {
color: var(--cyan);
}
/* Headers */
h1,
h2,
h3,
h4,
h5,
h6 {
color: var(--white);
}
/* Code blocks */
pre {
background-color: #1a1a1a !important;
border: 1px solid var(--greige-600);
}
/* Tables */
.table {
color: var(--greige-300);
}
/* TOC */
#toc-title {
color: var(--white);
}
.toc-active {
color: var(--lime) !important;
}
/* Buttons */
.btn-primary {
background-color: var(--lime);
color: var(--black);
border: none;
}
.btn-primary:hover {
background-color: var(--cyan);
color: var(--black);
}
/* For inline code (single backtick) */
code {
background-color: #1a1a1a !important;
color: var(--lime) !important;
padding: 2px 4px;
border-radius: 4px;
}
/* For inline code that is also a link */
a code {
color: var(--cyan) !important;
}
/* For code blocks (triple backtick) */
pre.sourceCode {
background-color: #1a1a1a !important;
}
/* Make comments in bash/shell scripts green */
code span.co {
color: #5cb85c !important;
}
/* Remove underlines from JSON comments and make them green */
code span.er {
color: #5cb85c !important;
text-decoration: none !important;
}
img[alt="Axolotl"] {
content: url("https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg") !important;
}

View File

@@ -28,7 +28,7 @@ class TestTrainCommand(BaseCliTest):
config_path.write_text(valid_test_config)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,
@@ -48,7 +48,7 @@ class TestTrainCommand(BaseCliTest):
config_path = self._test_cli_overrides(tmp_path, valid_test_config)
with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
mock_train.return_value = (MagicMock(), MagicMock())
result = cli_runner.invoke(
cli,

View File

@@ -25,8 +25,8 @@ def fixture_cfg():
"optimizer": "adamw_torch_fused",
"sequence_len": 2048,
"rl": True,
"adam_beta1": 0.91,
"adam_beta2": 0.998,
"adam_beta1": 0.998,
"adam_beta2": 0.9,
"adam_epsilon": 0.00001,
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
@@ -60,8 +60,8 @@ class TestHFRLTrainerBuilder:
def test_build_training_arguments(self, cfg, model, tokenizer):
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
training_arguments = builder.build_training_arguments(100)
assert training_arguments.adam_beta1 == 0.91
assert training_arguments.adam_beta2 == 0.998
assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.9
assert training_arguments.adam_epsilon == 0.00001
assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True

View File

@@ -69,51 +69,6 @@ class TestCutCrossEntropyIntegration:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
# pylint: disable=redefined-outer-name
def test_qwen2_w_cce(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"plugins": [
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
],
"cut_cross_entropy": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"output_dir": temp_dir,
"lr_scheduler": "cosine",
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
}
)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, dataset_meta=dataset_meta)
else:
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.parametrize(
"attention_type",
[

View File

@@ -90,6 +90,12 @@ class TestKnowledgeDistillation:
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
)
check_tensorboard(
temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
)
@pytest.mark.parametrize(
"load_in_8bit",
@@ -121,3 +127,9 @@ class TestKnowledgeDistillation:
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
)
check_tensorboard(
temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
)

View File

@@ -0,0 +1,163 @@
"""
sanity checks on kl loss and gradients
"""
import torch
# Import both implementations
from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss
def test_kl_loss_gradient():
"""Test that the gradient of the Triton implementation matches the eager implementation."""
# Set the random seed for reproducibility
torch.manual_seed(42)
# Create random inputs
batch_size = 2
seq_len = 3
vocab_size = 100
top_k = 5
# Generate random student logits
student_logits = torch.randn(
batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
)
student_logits_triton = student_logits.detach().clone().requires_grad_(True)
# Generate random target token IDs, ensuring they're valid indices
# pylint: disable=duplicate-code
target_token_ids = torch.randint(
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
)
# Generate random target logprobs (before normalization)
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
# Normalize the target logprobs to ensure they form a valid distribution
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
# Create a random mask with some tokens masked out
target_mask = torch.randint(
0, 2, (batch_size, seq_len, top_k), device="cuda"
).float()
# Additional parameters
num_items_in_batch = batch_size * seq_len
kd_temperature = 1.0
top_k_before_softmax = 0 # Test both modes
# Compute the loss and gradients with eager implementation
loss_eager = eager_loss(
student_logits,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch,
kd_temperature,
top_k_before_softmax,
)
loss_eager.backward()
grad_eager = student_logits.grad.clone()
# Reset gradients
student_logits.grad.zero_()
# Compute the loss and gradients with Triton implementation
loss_triton = triton_loss(
student_logits_triton,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch,
kd_temperature,
top_k_before_softmax,
)
loss_triton.backward()
grad_triton = student_logits_triton.grad.clone()
# Compare loss values
print(f"Eager loss: {loss_eager.item()}")
print(f"Triton loss: {loss_triton.item()}")
loss_diff = abs(loss_eager.item() - loss_triton.item())
print(f"Loss difference: {loss_diff}")
assert loss_diff < 1e-5, "Loss values differ significantly!"
# Compare gradients
grad_diff = (grad_eager - grad_triton).abs().max().item()
print(f"Max gradient difference: {grad_diff}")
# Print some sample gradients
sample_idx = (0, 0, 0) # (batch, seq, vocab)
print(f"Sample eager gradient: {grad_eager[sample_idx].item()}")
print(f"Sample triton gradient: {grad_triton[sample_idx].item()}")
# Compute relative difference for non-zero gradients
mask = grad_eager.abs() > 1e-10
if mask.sum() > 0:
rel_diff = (
(
(grad_eager[mask] - grad_triton[mask]).abs()
/ (grad_eager[mask].abs() + 1e-10)
)
.max()
.item()
)
print(f"Max relative gradient difference: {rel_diff}")
assert rel_diff < 1e-3, "Gradients differ significantly!"
# Also test top_k_before_softmax = 1 mode
top_k_before_softmax = 1
# Reset the gradients
student_logits = torch.randn(
batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
)
student_logits_triton = student_logits.detach().clone().requires_grad_(True)
# Compute the loss and gradients with eager implementation
loss_eager = eager_loss(
student_logits,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch,
kd_temperature,
top_k_before_softmax,
)
loss_eager.backward()
grad_eager = student_logits.grad.clone()
# Compute the loss and gradients with Triton implementation
loss_triton = triton_loss(
student_logits_triton,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch,
kd_temperature,
top_k_before_softmax,
)
loss_triton.backward()
grad_triton = student_logits_triton.grad.clone()
# Compare gradients for top_k_before_softmax = 1
grad_diff = (grad_eager - grad_triton).abs().max().item()
print("\nWith top_k_before_softmax=1:")
print(f"Max gradient difference: {grad_diff}")
# Compute relative difference for non-zero gradients
mask = grad_eager.abs() > 1e-10
if mask.sum() > 0:
rel_diff = (
(
(grad_eager[mask] - grad_triton[mask]).abs()
/ (grad_eager[mask].abs() + 1e-10)
)
.max()
.item()
)
assert (
rel_diff < 1e-3
), f"Gradients differ significantly, Max relative gradient difference: {rel_diff}"

View File

@@ -0,0 +1,204 @@
"""
sanity checks on logsumexp kernel validity
"""
import torch
import triton
from axolotl.integrations.kd.topk_logprob.logsumexp import logsumexp_kernel
# PyTorch implementation of logsumexp for reference
def torch_logsumexp(logits):
"""PyTorch implementation of logsumexp over last dimension"""
return torch.logsumexp(logits, dim=-1)
# Wrapper function for Triton logsumexp kernel
def triton_logsumexp(logits):
"""Triton implementation of logsumexp over last dimension"""
B, S, V = logits.shape # pylint: disable=invalid-name
output = torch.empty((B, S), dtype=torch.float32, device=logits.device)
grid = (B * S,)
logsumexp_kernel[grid](
logits.contiguous(),
output,
B,
S,
V,
logits.stride(0),
logits.stride(1),
logits.stride(2),
output.stride(0),
output.stride(1),
min(1024, triton.next_power_of_2(V)),
)
return output
class TritonLogSumExp(torch.autograd.Function):
"""
Custom autograd Function that wraps the Triton logsumexp kernel for gradient testing
"""
@staticmethod
def forward(ctx, logits):
B, S, V = logits.shape # pylint: disable=invalid-name
output = torch.empty((B, S), dtype=torch.float32, device=logits.device)
# Save inputs for backward pass
ctx.save_for_backward(logits)
ctx.shape = logits.shape
grid = (B * S,)
logsumexp_kernel[grid](
logits.contiguous(),
output,
B,
S,
V,
logits.stride(0),
logits.stride(1),
logits.stride(2),
output.stride(0),
output.stride(1),
min(1024, triton.next_power_of_2(V)),
)
return output
@staticmethod
def backward(ctx, grad_output):
(logits,) = ctx.saved_tensors
# For logsumexp, the gradient is softmax(input) * grad_output
# First compute the logsumexp
lse = TritonLogSumExp.apply(logits)
# Compute softmax by exponentiating differences
softmax_output = torch.exp(logits - lse.unsqueeze(-1))
# Compute gradient of logsumexp by multiplying the softmax output by the gradient
grad_input = softmax_output * grad_output.unsqueeze(-1)
return grad_input
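For reference, the backward pass above uses the standard logsumexp gradient identity
\frac{\partial}{\partial x_i} \log \sum_j e^{x_j} = \frac{e^{x_i}}{\sum_j e^{x_j}} = \operatorname{softmax}(x)_i,
so grad_input is the softmax of the logits scaled by the incoming grad_output.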
def test_logsumexp_values():
"""Test that the Triton logsumexp implementation matches PyTorch's"""
# Set random seed for reproducibility
torch.manual_seed(42)
# Test with various input shapes
test_shapes = [
(2, 3, 10), # small vocab
(4, 5, 100), # medium vocab
(2, 2, 32000), # large vocab (typical for LLMs)
]
for shape in test_shapes:
# Create random input tensors
logits = torch.randn(shape, device="cuda", requires_grad=False)
# Compute logsumexp using both implementations
torch_result = torch_logsumexp(logits)
triton_result = triton_logsumexp(logits)
# Compare results
max_diff = (torch_result - triton_result).abs().max().item()
print(f"Shape {shape}, Max diff: {max_diff}")
# Assert that the results are very close
assert max_diff < 1e-5, f"Results differ for shape {shape}: max diff {max_diff}"
def test_logsumexp_edge_cases():
"""Test edge cases for numerical stability"""
# Set random seed for reproducibility
torch.manual_seed(42)
# Case 1: Very large values that might cause overflow
logits_large = torch.ones(2, 3, 100, device="cuda") * 1000
# Case 2: Very small values that might cause underflow
logits_small = torch.ones(2, 3, 100, device="cuda") * -1000
# Case 3: Mix of large and small values
logits_mixed = torch.zeros(2, 3, 100, device="cuda")
logits_mixed[:, :, 0] = 1000 # One very large value
# Case 4: All identical values
logits_identical = torch.ones(2, 3, 100, device="cuda") * 5
# Case 5: Extreme values with NaN check
logits_extreme = torch.cat(
[
torch.full((1, 3, 50), 1e10, device="cuda"),
torch.full((1, 3, 50), -1e10, device="cuda"),
],
dim=0,
)
for i, logits in enumerate(
[logits_large, logits_small, logits_mixed, logits_identical, logits_extreme]
):
# Compute logsumexp using both implementations
torch_result = torch_logsumexp(logits)
triton_result = triton_logsumexp(logits)
# Check for NaNs
assert not torch.isnan(
torch_result
).any(), f"PyTorch produced NaNs for case {i+1}"
assert not torch.isnan(
triton_result
).any(), f"Triton produced NaNs for case {i+1}"
# Compare results
max_diff = (torch_result - triton_result).abs().max().item()
print(f"Edge case {i+1}, Max diff: {max_diff}")
# For very extreme values, allow a bit more tolerance
if i == 4: # extreme case
assert max_diff < 1e-2, f"Results differ too much for edge case {i+1}"
else:
assert max_diff < 1e-5, f"Results differ too much for edge case {i+1}"
def test_logsumexp_gradients():
"""Test that the gradients of Triton logsumexp match PyTorch's"""
# Set random seed for reproducibility
torch.manual_seed(42)
# Create input tensors with gradients enabled
shapes = [(2, 3, 10), (4, 5, 100)]
for shape in shapes:
# Create two identical tensors for PyTorch and Triton
logits_torch = torch.randn(shape, device="cuda", requires_grad=True)
logits_triton = logits_torch.clone().detach().requires_grad_(True)
# Forward pass
torch_output = torch_logsumexp(logits_torch)
triton_output = TritonLogSumExp.apply(logits_triton)
# Compare forward pass values
max_diff_forward = (torch_output - triton_output).abs().max().item()
assert max_diff_forward < 1e-5, f"Forward pass values differ for shape {shape}"
# Create random gradient
grad_output = torch.randn_like(torch_output)
# Backward pass
torch_output.backward(grad_output)
triton_output.backward(grad_output)
# Compare gradients
max_diff_grad = (logits_torch.grad - logits_triton.grad).abs().max().item()
print(f"Shape {shape}, Max gradient diff: {max_diff_grad}")
# Assert that gradients are very close
assert (
max_diff_grad < 1e-5
), f"Gradients differ for shape {shape}: max diff {max_diff_grad}"

View File

@@ -750,66 +750,3 @@ class TestMultiGPULlama:
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"bos_token": "<|custom_im_start|>",
"eos_token": "<|custom_im_end|>",
},
"datasets": [
{
"chat_template": "jinja",
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
"use_tensorboard": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
)

View File

@@ -1,130 +0,0 @@
"""
E2E tests for DeepSeek V3 (LoRA and FFT)
"""
import logging
import os
from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestDeepseekV3:
"""
Test case for DeepseekV3 models
"""
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
def test_lora_deepseekv3(self, temp_dir, sample_packing):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/DeepSeek-V3-11M",
"trust_remote_code": True,
"sample_packing": sample_packing,
"flash_attention": True,
"sequence_len": 2048,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0,
"datasets": [
{
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"field_messages": "conversations",
"message_property_mappings": {
"role": "from",
"content": "value",
},
"drop_system_message": True,
"split": "train[:1%]",
},
],
"special_tokens": {
"bos_token": "<begin▁of▁sentence>",
"eos_token": "<end▁of▁sentence>",
},
"chat_template": "deepseek_v3",
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
def test_fft_deepseekv3(self, temp_dir, sample_packing):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/DeepSeek-V3-11M",
"trust_remote_code": True,
"sample_packing": sample_packing,
"flash_attention": True,
"sequence_len": 2048,
"val_set_size": 0,
"datasets": [
{
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
"split": "train[:1%]",
},
],
"chat_template": "deepseek_v3",
"special_tokens": {
"bos_token": "<begin▁of▁sentence>",
"eos_token": "<end▁of▁sentence>",
},
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -66,54 +66,6 @@ class TestLlama:
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"bos_token": "<|custom_im_start|>",
"eos_token": "<|custom_im_end|>",
},
"datasets": [
{
"chat_template": "jinja",
"chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens_already_trained(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{

View File

@@ -75,7 +75,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -131,7 +131,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -190,7 +190,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
@@ -249,7 +249,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32

View File

@@ -65,9 +65,8 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert trainer.optimizer.optimizer.__class__.__name__ == "AdamW"
@with_temp_dir
@require_torch_2_5_1
@@ -112,57 +111,8 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "ADOPT" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
@require_torch_2_5_1
def test_muon(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "muon",
"lr_scheduler": "cosine",
"weight_decay": 0.01,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir):

View File

@@ -1,71 +0,0 @@
"""
E2E tests for custom schedulers using Llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestCustomSchedulers(unittest.TestCase):
"""
Test case for custom learning rate schedulers using Llama
"""
@with_temp_dir
def test_rex_scheduler(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_hf",
"max_steps": 20,
"lr_scheduler": "rex",
"warmup_steps": 5,
"cosine_min_lr_ratio": 0.05,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -102,7 +102,11 @@ def is_hopper():
def check_tensorboard(
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
temp_run_dir: str,
tag: str,
comparison_val: float,
assertion_err: str,
lt: bool = True,
) -> None:
"""
helper function to parse and check tensorboard logs
@@ -112,10 +116,20 @@ def check_tensorboard(
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == tag)] # pylint: disable=invalid-name
if "%s" in assertion_err:
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
if lt:
if "%s" in assertion_err:
assert df.value.values[-1] < comparison_val, (
assertion_err % df.value.values[-1]
)
else:
assert df.value.values[-1] < comparison_val, assertion_err
else:
assert df.value.values[-1] < lt_val, assertion_err
if "%s" in assertion_err:
assert df.value.values[-1] > comparison_val, (
assertion_err % df.value.values[-1]
)
else:
assert df.value.values[-1] > comparison_val, assertion_err
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:

View File

@@ -1,7 +1,6 @@
"""
Test cases for the tokenizer loading
"""
import unittest
import pytest
@@ -10,7 +9,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer
class TestTokenizers:
class TestTokenizers(unittest.TestCase):
"""
test class for the load_tokenizer fn
"""
@@ -76,48 +75,12 @@ class TestTokenizers:
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
assert len(tokenizer) == 32001
self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
self.assertEqual(len(tokenizer), 32001)
# ensure reloading the tokenizer again from cfg results in same vocab length
tokenizer = load_tokenizer(cfg)
assert len(tokenizer) == 32001
def test_added_tokens_overrides(self, temp_dir):
cfg = DictDefault(
{
# use with tokenizer that has reserved_tokens in added_tokens
"tokenizer_config": "NousResearch/Llama-3.2-1B",
"added_tokens_overrides": {
128041: "RANDOM_OVERRIDE_1",
128042: "RANDOM_OVERRIDE_2",
},
"output_dir": temp_dir,
}
)
tokenizer = load_tokenizer(cfg)
assert tokenizer.encode("RANDOM_OVERRIDE_1", add_special_tokens=False) == [
128041
]
assert tokenizer.encode("RANDOM_OVERRIDE_2", add_special_tokens=False) == [
128042
]
def test_added_tokens_overrides_with_toolargeid(self, temp_dir):
cfg = DictDefault(
{
# use with tokenizer that has reserved_tokens in added_tokens
"tokenizer_config": "NousResearch/Llama-3.2-1B",
"added_tokens_overrides": {1000000: "BROKEN_RANDOM_OVERRIDE_1"},
"output_dir": temp_dir,
}
)
with pytest.raises(
ValueError, match=r".*Token ID 1000000 not found in added_tokens.*"
):
load_tokenizer(cfg)
self.assertEqual(len(tokenizer), 32001)
if __name__ == "__main__":