Compare commits

..

74 Commits

Author SHA1 Message Date
Wing Lian
f11227a35a various fixes 2025-01-30 10:39:18 -05:00
Wing Lian
c434951dd6 Always re-normalize teacher distribution 2025-01-29 08:36:40 -05:00
Wing Lian
42d4732aaf kd loss needs to be calculated in full precision 2025-01-28 19:40:35 -05:00
Wing Lian
2c9dfbed2e apply z-score scaling to kd 2025-01-27 14:27:35 -05:00
Wing Lian
4e4a16cd8a fix finding the top-k rather than assuming first position has the correct val 2025-01-21 13:09:20 -05:00
Wing Lian
67c1c8405e use iter instead of tuple 2025-01-21 11:23:38 -05:00
Wing Lian
bded6df509 change up logic so we always truncate to top_k 2025-01-21 11:20:01 -05:00
Wing Lian
bb5e6f4b72 make sure to truncate logprobs if there are more than top_k 2025-01-21 10:26:27 -05:00
Wing Lian
32258c247e no batching for kd chat templates 2025-01-15 08:22:29 -05:00
Wing Lian
04efcb102f don't shift student logits for kd 2025-01-15 01:07:48 -05:00
Wing Lian
483defb9ae try tests for kd on l40s 2025-01-14 23:56:00 -05:00
Wing Lian
35a84f2cb8 more fixes 2025-01-14 22:47:49 -05:00
Wing Lian
510cf45317 improve logprob masking and shift in trainer 2025-01-14 22:47:48 -05:00
Wing Lian
7232cbdeab chore: lint 2025-01-14 22:47:48 -05:00
Wing Lian
e8fceb7091 chore: lint 2025-01-14 22:47:48 -05:00
Wing Lian
a5e0671738 make sure to use tensorboard to capture loss for checks 2025-01-14 22:47:48 -05:00
Wing Lian
b9847553af fix adapter model check 2025-01-14 22:47:48 -05:00
Wing Lian
513ec9e03b make sure to use the correct tokenizer 2025-01-14 22:47:48 -05:00
Wing Lian
530347856d make sure to set tokenizer from l3 70b and save safetensors 2025-01-14 22:47:47 -05:00
Wing Lian
261e4fb619 lower lr 2025-01-14 22:47:47 -05:00
Wing Lian
158071e95f set lora_dropout explicitly 2025-01-14 22:47:47 -05:00
Wing Lian
432f65f5e6 make the kd e2e fit in vram for ci and add lora version 2025-01-14 22:47:47 -05:00
Wing Lian
1d039f5486 rename test files so it gets picked up 2025-01-14 22:47:47 -05:00
Wing Lian
b9a42b396f linting 2025-01-14 22:47:47 -05:00
Wing Lian
ff2fb0fc1b add kd trainer e2e test 2025-01-14 22:47:47 -05:00
Wing Lian
317f290186 reward model doesn't work well with batched 2025-01-14 22:47:46 -05:00
Wing Lian
ab690f3f01 improve check for batched 2025-01-14 22:47:46 -05:00
Wing Lian
47932f21c4 fix reward trainer calls for tokenization 2025-01-14 22:47:46 -05:00
Wing Lian
808328e041 reward can use same batch check 2025-01-14 22:47:46 -05:00
Wing Lian
6784822cfb tweak check for batched prompt data 2025-01-14 22:47:46 -05:00
Wing Lian
684b38291f ensure that batch vs single is done properly 2025-01-14 22:47:46 -05:00
Wing Lian
01896b1bde improve iterable support 2025-01-14 22:47:46 -05:00
Wing Lian
e659c01646 support streaming for processing sft datasts? 2025-01-14 22:47:45 -05:00
Wing Lian
204d6c43b4 make loss torch script compat 2025-01-14 22:47:45 -05:00
Wing Lian
d3c2b7ce9d kd sample packing 2025-01-14 22:47:45 -05:00
Wing Lian
93dfff92f1 be a bit pickier about loading dynamic prompt strategies 2025-01-14 22:47:45 -05:00
Wing Lian
6e409d2d88 more info on preprocess for kd and fix import 2025-01-14 22:47:45 -05:00
Wing Lian
d5bc214300 remove duplicate code 2025-01-14 22:47:45 -05:00
Wing Lian
92c6c1087e add copyrights 2025-01-14 22:47:45 -05:00
Wing Lian
feed96f95e increase logging around loading plugins 2025-01-14 22:47:44 -05:00
Wing Lian
cba6165ae1 make plugin setup concise 2025-01-14 22:47:44 -05:00
Wing Lian
cdfcd69afa remove moved class from import 2025-01-14 22:47:44 -05:00
Wing Lian
885653d52e move more things to kd plugin 2025-01-14 22:47:44 -05:00
Wing Lian
27faacbf5a refactor kd chat template loader 2025-01-14 22:47:44 -05:00
Wing Lian
c51b0337c1 support for custom trainer classes from plugins 2025-01-14 22:47:44 -05:00
Wing Lian
fa055f9f69 handle token/logprob shifting 2025-01-14 22:47:43 -05:00
Wing Lian
f60c623af0 remove references to triton kd for now 2025-01-14 22:47:43 -05:00
Wing Lian
746891eb5c add license block 2025-01-14 22:47:43 -05:00
Wing Lian
f09b5da60b refactor so we can easily add new loss functions 2025-01-14 22:47:43 -05:00
Wing Lian
689e1c10ba chore: lint 2025-01-14 22:47:43 -05:00
Wing Lian
a5c085e003 var naming and add todo 2025-01-14 22:47:43 -05:00
Wing Lian
63146300b7 fix kd loss so it's causal (fixes repeating tokens) 2025-01-14 22:47:43 -05:00
Wing Lian
ca5e397fc5 use kd_alpha in the correct loss method 2025-01-14 22:47:42 -05:00
Wing Lian
3416302b0d hash for temperature too 2025-01-14 22:47:42 -05:00
Wing Lian
7366efc4ca better rescaling for temperatures 2025-01-14 22:47:42 -05:00
Wing Lian
d8d817eaed don't use triton for now 2025-01-14 22:47:42 -05:00
Wing Lian
c0757e8a20 fix kwarg 2025-01-14 22:47:42 -05:00
Wing Lian
e565694914 v3 2025-01-14 22:47:42 -05:00
Wing Lian
081928e55b no torch.tensor 2025-01-14 22:47:42 -05:00
Wing Lian
dc90c93894 no log etc 2025-01-14 22:47:41 -05:00
Wing Lian
18a46c338a no torch.exp inside triton kernel 2025-01-14 22:47:41 -05:00
Wing Lian
119d586cf4 v2 trial 2025-01-14 22:47:41 -05:00
Wing Lian
c73acd7de0 no where support 2025-01-14 22:47:41 -05:00
Wing Lian
0b59a242d4 triton wip 2025-01-14 22:47:41 -05:00
Wing Lian
ed490517da chore: lint 2025-01-14 22:47:41 -05:00
Wing Lian
00ce77e7ef make sure to multiply against the correct loss 2025-01-14 22:47:41 -05:00
Wing Lian
ae545e0165 cross entropy loss coefficient during KD 2025-01-14 22:47:40 -05:00
Wing Lian
b592c05b93 flipped the slice 2025-01-14 22:47:40 -05:00
Wing Lian
7fe0ad088b make it work 2025-01-14 22:47:40 -05:00
Wing Lian
ddcf5c68b3 handle padding/collation for KD datasets 2025-01-14 22:47:40 -05:00
Wing Lian
e633a12dbe make batch smaller 2025-01-14 22:47:40 -05:00
Wing Lian
d584354ee4 filter bad rows 2025-01-14 22:47:40 -05:00
Wing Lian
303cfa71aa KD dataset loading and KD with logprobs 2025-01-14 22:47:40 -05:00
Wing Lian
88b3198894 refactor trainer to prevent circular dependencies later
fix loader default
2025-01-14 22:47:39 -05:00
142 changed files with 1590 additions and 10575 deletions

View File

@@ -15,7 +15,7 @@ First of all, thank you for your interest in contributing to axolotl! We appreci
- [Commit Messages](#commit-messages)
- [Additional Resources](#additional-resources)
## Code of Conduct
## Code of Conductcode
All contributors are expected to adhere to our [Code of Conduct](CODE_OF_CONDUCT.md). Please read it before participating in the axolotl community.

View File

@@ -22,6 +22,18 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "121"
cuda_version: 12.1.1
cudnn_version: 8
python_version: "3.10"
pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.1
cudnn_version: 8
python_version: "3.11"
pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""

View File

@@ -15,6 +15,16 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
axolotl_extras: mamba-ssm
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -72,6 +82,16 @@ jobs:
strategy:
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -125,10 +145,10 @@ jobs:
strategy:
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.3.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -20,6 +20,12 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -12,6 +12,17 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -65,6 +76,17 @@ jobs:
strategy:
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -26,7 +26,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
@@ -98,6 +98,13 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -49,7 +49,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
@@ -228,7 +228,6 @@ jobs:
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
@@ -245,6 +244,12 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -269,7 +274,6 @@ jobs:
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |

View File

@@ -19,7 +19,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
rev: 6.0.0
hooks:
- id: flake8
- repo: https://github.com/PyCQA/pylint

785
README.md
View File

@@ -1,8 +1,8 @@
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_white.svg">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg">
<img alt="Axolotl" src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
<source media="(prefers-color-scheme: dark)" srcset="image/axolotl_logo_digital_white.svg">
<source media="(prefers-color-scheme: light)" srcset="image/axolotl_logo_digital_black.svg">
<img alt="Axolotl" src="image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
</picture>
</p>
@@ -19,99 +19,235 @@
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
</p>
Axolotl is a tool designed to streamline post-training for various AI models.
Post-training refers to any modifications or additional training performed on
pre-trained models - including full model fine-tuning, parameter-efficient tuning (like
LoRA and QLoRA), supervised fine-tuning (SFT), instruction tuning, and alignment
techniques. With support for multiple model architectures and training configurations,
Axolotl makes it easy to get started with these techniques.
Axolotl is designed to work with YAML config files that contain everything you need to
preprocess a dataset, train or fine-tune a model, run model inference or evaluation,
and much more.
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
Features:
- Train various Huggingface models such as llama, pythia, falcon, mpt
- Supports fullfinetune, lora, qlora, relora, and gptq
- Customize configurations using a simple yaml file or CLI overwrite
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
- Integrated with [xformers](https://github.com/facebookresearch/xformers), flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb, mlflow or Comet
- And more!
## 🚀 Quick Start
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.10
- PyTorch ≥2.4.1
<table>
<tr>
<td>
### Installation
## Table of Contents
- [Axolotl](#axolotl)
- [Table of Contents](#table-of-contents)
- [Quickstart ⚡](#quickstart-)
- [Edge Builds](#edge-builds-)
- [Axolotl CLI Usage](#axolotl-cli-usage)
- [Badge ❤🏷️](#badge-)
- [Contributing 🤝](#contributing-)
- [Sponsors 🤝❤](#sponsors-)
- [Axolotl supports](#axolotl-supports)
- [Advanced Setup](#advanced-setup)
- [Environment](#environment)
- [Docker](#docker)
- [Conda/Pip venv](#condapip-venv)
- [Cloud GPU](#cloud-gpu)
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- [LambdaLabs](#lambdalabs)
- [GCP](#gcp)
- [Windows](#windows)
- [Mac](#mac)
- [Google Colab](#google-colab)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
- [Dataset](#dataset)
- [Config](#config)
- [All Config Options](#all-config-options)
- [Train](#train)
- [Preprocess dataset](#preprocess-dataset)
- [Multi-GPU](#multi-gpu)
- [DeepSpeed](#deepspeed)
- [FSDP](#fsdp)
- [FSDP + QLoRA](#fsdp--qlora)
- [Weights \& Biases Logging](#weights--biases-logging)
- [Special Tokens](#special-tokens)
- [Liger Kernel](#liger-kernel)
- [Inference Playground](#inference-playground)
- [Merge LORA to base](#merge-lora-to-base)
- [Common Errors 🧰](#common-errors-)
- [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
- [Need help? 🙋](#need-help-)
```shell
</td>
<td>
<div align="center">
<img src="image/axolotl_symbol_digital_white.svg" alt="axolotl" width="160">
<div>
<p>
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
</p>
<p>
Go ahead and Axolotl questions!!
</p>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
<img alt="PyTest Status" src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
</div>
</div>
</td>
</tr>
</table>
## Quickstart ⚡
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
```bash
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs
# download examples and optionally deepspeed configs to the local path
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
```
Other installation approaches are described [here](https://axolotl-ai-cloud.github.io/axolotl/docs/installation.html).
### Your First Fine-tune
```shell
# Fetch axolotl examples
axolotl fetch examples
# Or, specify a custom path
axolotl fetch examples --dest path/to/folder
# Train a model using LoRA
# finetune using lora
axolotl train examples/llama-3/lora-1b.yml
```
That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/getting-started.html) for a more detailed walkthrough.
### Edge Builds 🏎️
## ✨ Key Features
If you're looking for the latest features and updates between releases, you'll need to install
from source.
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, and more
- **Easy Configuration**: Simple YAML files to control your training setup
- **Performance Optimizations**: Flash Attention, xformers, multi-GPU training
- **Flexible Dataset Handling**: Use various formats and custom datasets
- **Cloud Ready**: Run on cloud platforms or local hardware
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
## 📚 Documentation
### Axolotl CLI Usage
We now support a new, more streamlined CLI using [click](https://click.palletsprojects.com/en/stable/).
- [Installation Options](https://axolotl-ai-cloud.github.io/axolotl/docs/installation.html) - Detailed setup instructions for different environments
- [Configuration Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html) - Full configuration options and examples
- [Dataset Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) - Supported formats and how to use them
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/llama-3/lora-1b.yml
## 🤝 Getting Help
# finetune lora
axolotl train examples/llama-3/lora-1b.yml
- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support
- Check out our [Examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/) directory
- Read our [Debugging Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/debugging.html)
- Need dedicated support? Please contact [wing@axolotl.ai](mailto:wing@axolotl.ai) for options
# inference
axolotl inference examples/llama-3/lora-1b.yml \
--lora-model-dir="./outputs/lora-out"
## 🌟 Contributing
# gradio
axolotl inference examples/llama-3/lora-1b.yml \
--lora-model-dir="./outputs/lora-out" --gradio
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
```
## Supported Models
We've also added a new command for fetching `examples` and `deepspeed_configs` to your
local machine. This will come in handy when installing `axolotl` from PyPI.
```bash
# Fetch example YAML files (stores in "examples/" folder)
axolotl fetch examples
# Fetch deepspeed config files (stores in "deepspeed_configs/" folder)
axolotl fetch deepspeed_configs
# Optionally, specify a destination folder
axolotl fetch examples --dest path/to/folder
```
### Legacy Usage
<details>
<summary>Click to Expand</summary>
While the Axolotl CLI is the preferred method for interacting with axolotl, we
still support the legacy `-m axolotl.cli.*` usage.
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/llama-3/lora-1b.yml
# finetune lora
accelerate launch -m axolotl.cli.train examples/llama-3/lora-1b.yml
# inference
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
--lora_model_dir="./outputs/lora-out"
# gradio
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
--lora_model_dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
```
</details>
## Badge ❤🏷️
Building something cool with Axolotl? Consider adding a badge to your model card.
```markdown
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
```
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
## Sponsors 🤝❤
If you love axolotl, consider sponsoring the project by reaching out directly to [wing@axolotl.ai](mailto:wing@axolotl.ai).
---
- [Modal](https://modal.com/) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.
---
## Contributing 🤝
Please read the [contributing guide](./.github/CONTRIBUTING.md)
Bugs? Please check the [open issues](https://github.com/axolotl-ai-cloud/axolotl/issues/bug) else create a new Issue.
PRs are **greatly welcome**!
Please run the quickstart instructions followed by the below to setup env:
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
# optional: run against all files
pre-commit run --all-files
```
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors">
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
</a>
## Axolotl supports
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
@@ -136,16 +272,523 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co
❌: not supported
❓: untested
## ❤️ Sponsors
## Advanced Setup
Thank you to our sponsors who help make Axolotl possible:
### Environment
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) - Modal lets you run
jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale,
fine-tune large language models, run protein folding simulations, and much more.
#### Docker
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
```bash
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```
## 📜 License
Or run on the current files for development:
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
```sh
docker compose up -d
```
>[!Tip]
> If you want to debug axolotl or prefer to use Docker as your development environment, see the [debugging guide's section on Docker](docs/debugging.qmd#debugging-with-docker).
<details>
<summary>Docker advanced</summary>
A more powerful Docker command to run would be this:
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-latest
```
It additionally:
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
* The `--privileged` flag gives all capabilities to the container.
* The `--shm-size 10g` argument increases the shared memory size. Use this if you see `exitcode: -7` errors using deepspeed.
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
</details>
#### Conda/Pip venv
1. Install python >=**3.10**
2. Install pytorch stable https://pytorch.org/get-started/locally/
3. Install Axolotl along with python dependencies
```bash
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Huggingface to use gated models/datasets.
```bash
huggingface-cli login
```
Get the token at huggingface.co/settings/tokens
#### Cloud GPU
For cloud GPU providers that support docker images, use [`axolotlai/axolotl-cloud:main-latest`](https://hub.docker.com/r/axolotlai/axolotl-cloud/tags)
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
#### Bare Metal Cloud GPU
##### LambdaLabs
<details>
<summary>Click to Expand</summary>
1. Install python
```bash
sudo apt update
sudo apt install -y python3.10
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1
sudo update-alternatives --config python # pick 3.10 if given option
python -V # should be 3.10
```
2. Install pip
```bash
wget https://bootstrap.pypa.io/get-pip.py
python get-pip.py
```
3. Install Pytorch https://pytorch.org/get-started/locally/
4. Follow instructions on quickstart.
5. Run
```bash
pip3 install protobuf==3.20.3
pip3 install -U --ignore-installed requests Pillow psutil scipy
```
6. Set path
```bash
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
```
</details>
##### GCP
<details>
<summary>Click to Expand</summary>
Use a Deeplearning linux OS with cuda and pytorch installed. Then follow instructions on quickstart.
Make sure to run the below to uninstall xla.
```bash
pip uninstall -y torch_xla[tpu]
```
</details>
#### Windows
Please use WSL or Docker!
#### Mac
Use the below instead of the install method in QuickStart.
```
pip3 install --no-build-isolation -e '.'
```
More info: [mac.md](/docs/mac.qmd)
#### Google Colab
Please use this example [notebook](examples/colab-notebooks/colab-axolotl-example.ipynb).
#### Launching on public clouds via SkyPilot
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
```bash
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
sky check
```
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
```
git clone https://github.com/skypilot-org/skypilot.git
cd skypilot/llm/axolotl
```
Use one command to launch:
```bash
# On-demand
HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
# Managed spot (auto-recovery on preemption)
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
```
#### Launching on public clouds via dstack
To launch on GPU instance (both on-demand and spot instances) on public clouds (GCP, AWS, Azure, Lambda Labs, TensorDock, Vast.ai, and CUDO), you can use [dstack](https://dstack.ai/).
Write a job description in YAML as below:
```yaml
# dstack.yaml
type: task
image: axolotlai/axolotl-cloud:main-latest
env:
- HUGGING_FACE_HUB_TOKEN
- WANDB_API_KEY
commands:
- accelerate launch -m axolotl.cli.train config.yaml
ports:
- 6006
resources:
gpu:
memory: 24GB..
count: 2
```
then, simply run the job with `dstack run` command. Append `--spot` option if you want spot instance. `dstack run` command will show you the instance with cheapest price across multi cloud services:
```bash
pip install dstack
HUGGING_FACE_HUB_TOKEN=xxx WANDB_API_KEY=xxx dstack run . -f dstack.yaml # --spot
```
For further and fine-grained use cases, please refer to the official [dstack documents](https://dstack.ai/docs/) and the detailed description of [axolotl example](https://github.com/dstackai/dstack/tree/master/examples/fine-tuning/axolotl) on the official repository.
### Dataset
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
See [the documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
### Config
See [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
- model
```yaml
base_model: ./llama-7b-hf # local or huggingface repo
```
Note: The code will load the right architecture.
- dataset
```yaml
datasets:
# huggingface repo
- path: vicgalle/alpaca-gpt4
type: alpaca
# huggingface repo with specific configuration/subset
- path: EleutherAI/pile
name: enron_emails
type: completion # format from earlier
field: text # Optional[str] default: text, field to use for completion data
# huggingface repo with multiple named configurations/subsets
- path: bigcode/commitpackft
name:
- ruby
- python
- typescript
type: ... # unimplemented custom format
# chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template
- path: ...
type: chat_template
chat_template: chatml # defaults to tokenizer's chat_template
# local
- path: data.jsonl # or json
ds_type: json # see other options below
type: alpaca
# dataset with splits, but no train split
- path: knowrohit07/know_sql
type: context_qa.load_v2
train_on_split: validation
# loading from s3 or gcs
# s3 creds will be loaded from the system default and gcs only supports public access
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
...
# Loading Data From a Public URL
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
ds_type: json # this is the default, see other options below.
```
- loading
```yaml
load_in_4bit: true
load_in_8bit: true
bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically.
fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32
tf32: true # require >=ampere
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
float16: true # use instead of fp16 when you don't want AMP
```
Note: Repo does not do 4-bit quantization.
- lora
```yaml
adapter: lora # 'qlora' or leave blank for full finetune
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
```
#### All Config Options
See [these docs](docs/config.qmd) for all config options.
### Train
Run
```bash
accelerate launch -m axolotl.cli.train your_config.yml
```
> [!TIP]
> You can also reference a config file that is hosted on a public URL, for example `accelerate launch -m axolotl.cli.train https://yourdomain.com/your_config.yml`
#### Preprocess dataset
You can optionally pre-tokenize dataset with the following before finetuning.
This is recommended for large datasets.
- Set `dataset_prepared_path:` to a local folder for saving and loading pre-tokenized dataset.
- (Optional): Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
- (Optional): Use `--debug` to see preprocessed examples.
```bash
python -m axolotl.cli.preprocess your_config.yml
```
#### Multi-GPU
Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
is the recommended multi-GPU option currently because FSDP may experience
[loss instability](https://github.com/huggingface/transformers/issues/26498).
##### DeepSpeed
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
```yaml
deepspeed: deepspeed_configs/zero1.json
```
```shell
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
```
##### FSDP
- llama FSDP
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_offload_params: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
##### FSDP + QLoRA
Axolotl supports training with FSDP and QLoRA, see [these docs](docs/fsdp_qlora.qmd) for more information.
##### Weights & Biases Logging
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
- wandb options
```yaml
wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
```
##### Comet Logging
Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`.
- wandb options
```yaml
use_comet:
comet_api_key:
comet_workspace:
comet_project_name:
comet_experiment_key:
comet_mode:
comet_online:
comet_experiment_config:
```
##### Special Tokens
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
```yml
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
tokens: # these are delimiters
- "<|im_start|>"
- "<|im_end|>"
```
When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
##### Liger Kernel
Liger Kernel: Efficient Triton Kernels for LLM Training
https://github.com/linkedin/Liger-Kernel
Liger (LinkedIn GPU Efficient Runtime) Kernel is a collection of Triton kernels designed specifically for LLM training.
It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. The Liger Kernel
composes well and is compatible with both FSDP and Deepspeed.
```yaml
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```
### Inference Playground
Axolotl allows you to load your model in an interactive terminal playground for quick experimentation.
The config file is the same config file used for training.
Pass the appropriate flag to the inference command, depending upon what kind of model was trained:
- Pretrained LORA:
```bash
python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
```
- Full weights finetune:
```bash
python -m axolotl.cli.inference examples/your_config.yml --base_model="./completed-model"
```
- Full weights finetune w/ a prompt from a text file:
```bash
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
--base_model="./completed-model" --prompter=None --load_in_8bit=True
```
-- With gradio hosting
```bash
python -m axolotl.cli.inference examples/your_config.yml --gradio
```
Please use `--sample_packing False` if you have it on and receive the error similar to below:
> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1
### Merge LORA to base
The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
```bash
python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model"
```
You may need to use the `gpu_memory_limit` and/or `lora_on_cpu` config options to avoid running out of memory. If you still run out of CUDA memory, you can try to merge in system RAM with
```bash
CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
```
although this will be very slow, and using the config options above are recommended instead.
## Common Errors 🧰
See also the [FAQ's](./docs/faq.qmd) and [debugging guide](docs/debugging.qmd).
> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
Please reduce any below
- `micro_batch_size`
- `eval_batch_size`
- `gradient_accumulation_steps`
- `sequence_len`
If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command.
Using adamw_bnb_8bit might also save you some memory.
> `failed (exitcode: -9)`
Usually means your system has run out of system memory.
Similarly, you should consider reducing the same settings as when you run out of VRAM.
Additionally, look into upgrading your system RAM which should be simpler than GPU upgrades.
> RuntimeError: expected scalar type Float but found Half
Try set `fp16: true`
> NotImplementedError: No operator found for `memory_efficient_attention_forward` ...
Try to turn off xformers.
> accelerate config missing
It's safe to ignore it.
> NCCL Timeouts during training
See the [NCCL](docs/nccl.qmd) guide.
### Tokenization Mismatch b/w Inference & Training
For many formats, Axolotl constructs prompts by concatenating token ids _after_ tokenizing strings. The reason for concatenating token ids rather than operating on strings is to maintain precise accounting for attention masks.
If you decode a prompt constructed by axolotl, you might see spaces between tokens (or lack thereof) that you do not expect, especially around delimiters and special tokens. When you are starting out with a new format, you should always do the following:
1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly.
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/finetuning/05_tokenizer_gotchas.html) for a concrete example.
## Debugging Axolotl
See [this debugging guide](docs/debugging.qmd) for tips on debugging Axolotl, along with an example configuration for debugging with VSCode.
## Need help? 🙋
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where our community members can help you.
Need dedicated support? Please contact us at [wing@axolotl.ai](ailto:wing@axolotl.ai) for dedicated support options.

View File

@@ -28,21 +28,16 @@ website:
- section: "How-To Guides"
contents:
# TODO Edit folder structure after we have more docs.
- docs/getting-started.qmd
- docs/installation.qmd
- docs/debugging.qmd
- docs/inference.qmd
- docs/multipack.qmd
- docs/fsdp_qlora.qmd
- docs/input_output.qmd
- docs/rlhf.qmd
- docs/nccl.qmd
- docs/mac.qmd
- docs/multi-gpu.qmd
- docs/multi-node.qmd
- docs/unsloth.qmd
- docs/amd_hpc.qmd
- docs/ray-integration.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Reference"
@@ -50,6 +45,7 @@ website:
- docs/config.qmd
- docs/faq.qmd
format:
html:
theme: materia

View File

@@ -32,9 +32,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -6,6 +6,5 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -23,8 +23,8 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.3.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.3.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),

View File

@@ -23,8 +23,8 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.3.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.3.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
@@ -38,12 +38,16 @@ temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
).env(df_args)
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
)
.env(df_args)
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
app = App("Axolotl CI/CD", secrets=[])

View File

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -20,8 +20,7 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
chmod +x /root/cloud-entrypoint.sh
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"]

View File

@@ -1,256 +0,0 @@
# Axolotl CLI Documentation
The Axolotl CLI provides a streamlined interface for training and fine-tuning large language models. This guide covers
the CLI commands, their usage, and common examples.
### Table of Contents
- Basic Commands
- Command Reference
- fetch
- preprocess
- train
- inference
- merge-lora
- merge-sharded-fsdp-weights
- evaluate
- lm-eval
- Legacy CLI Usage
- Remote Compute with Modal Cloud
- Cloud Configuration
- Running on Modal Cloud
- Cloud Configuration Options
### Basic Commands
All Axolotl commands follow this general structure:
```bash
axolotl <command> [config.yml] [options]
```
The config file can be local or a URL to a raw YAML file.
### Command Reference
#### fetch
Downloads example configurations and deepspeed configs to your local machine.
```bash
# Get example YAML files
axolotl fetch examples
# Get deepspeed config files
axolotl fetch deepspeed_configs
# Specify custom destination
axolotl fetch examples --dest path/to/folder
```
#### preprocess
Preprocesses and tokenizes your dataset before training. This is recommended for large datasets.
```bash
# Basic preprocessing
axolotl preprocess config.yml
# Preprocessing with one GPU
CUDA_VISIBLE_DEVICES="0" axolotl preprocess config.yml
# Debug mode to see processed examples
axolotl preprocess config.yml --debug
# Debug with limited examples
axolotl preprocess config.yml --debug --debug-num-examples 5
```
Configuration options:
```yaml
dataset_prepared_path: Local folder for saving preprocessed data
push_dataset_to_hub: HuggingFace repo to push preprocessed data (optional)
```
#### train
Trains or fine-tunes a model using the configuration specified in your YAML file.
```bash
# Basic training
axolotl train config.yml
# Train and set/override specific options
axolotl train config.yml \
--learning-rate 1e-4 \
--micro-batch-size 2 \
--num-epochs 3
# Training without accelerate
axolotl train config.yml --no-accelerate
# Resume training from checkpoint
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
```
#### inference
Runs inference using your trained model in either CLI or Gradio interface mode.
```bash
# CLI inference with LoRA
axolotl inference config.yml --lora-model-dir="./outputs/lora-out"
# CLI inference with full model
axolotl inference config.yml --base-model="./completed-model"
# Gradio web interface
axolotl inference config.yml --gradio \
--lora-model-dir="./outputs/lora-out"
# Inference with input from file
cat prompt.txt | axolotl inference config.yml \
--base-model="./completed-model"
```
#### merge-lora
Merges trained LoRA adapters into the base model.
```bash
# Basic merge
axolotl merge-lora config.yml
# Specify LoRA directory (usually used with checkpoints)
axolotl merge-lora config.yml --lora-model-dir="./lora-output/checkpoint-100"
# Merge using CPU (if out of GPU memory)
CUDA_VISIBLE_DEVICES="" axolotl merge-lora config.yml
```
Configuration options:
```yaml
gpu_memory_limit: Limit GPU memory usage
lora_on_cpu: Load LoRA weights on CPU
```
#### merge-sharded-fsdp-weights
Merges sharded FSDP model checkpoints into a single combined checkpoint.
```bash
# Basic merge
axolotl merge-sharded-fsdp-weights config.yml
```
#### evaluate
Evaluates a model's performance using metrics specified in the config.
```bash
# Basic evaluation
axolotl evaluate config.yml
```
#### lm-eval
Runs LM Evaluation Harness on your model.
```bash
# Basic evaluation
axolotl lm-eval config.yml
# Evaluate specific tasks
axolotl lm-eval config.yml --tasks arc_challenge,hellaswag
```
Configuration options:
```yaml
lm_eval_tasks: List of tasks to evaluate
lm_eval_batch_size: Batch size for evaluation
output_dir: Directory to save evaluation results
```
### Legacy CLI Usage
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
```bash
# Preprocess
python -m axolotl.cli.preprocess config.yml
# Train
accelerate launch -m axolotl.cli.train config.yml
# Inference
accelerate launch -m axolotl.cli.inference config.yml \
--lora_model_dir="./outputs/lora-out"
# Gradio interface
accelerate launch -m axolotl.cli.inference config.yml \
--lora_model_dir="./outputs/lora-out" --gradio
```
### Remote Compute with Modal Cloud
Axolotl supports running training and inference workloads on Modal cloud infrastructure. This is configured using a
cloud YAML file alongside your regular Axolotl config.
#### Cloud Configuration
Create a cloud config YAML with your Modal settings:
```yaml
# cloud_config.yml
provider: modal
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
gpu_count: 1 # Number of GPUs to use
timeout: 86400 # Maximum runtime in seconds (24 hours)
branch: main # Git branch to use (optional)
volumes: # Persistent storage volumes
- name: axolotl-cache
mount: /workspace/cache
env: # Environment variables
- WANDB_API_KEY
- HF_TOKEN
```
#### Running on Modal Cloud
Commands that support the --cloud flag:
```bash
# Preprocess on cloud
axolotl preprocess config.yml --cloud cloud_config.yml
# Train on cloud
axolotl train config.yml --cloud cloud_config.yml
# Train without accelerate on cloud
axolotl train config.yml --cloud cloud_config.yml --no-accelerate
# Run lm-eval on cloud
axolotl lm-eval config.yml --cloud cloud_config.yml
```
#### Cloud Configuration Options
```yaml
provider: compute provider, currently only `modal` is supported
gpu: GPU type to use
gpu_count: Number of GPUs (default: 1)
memory: RAM in GB (default: 128)
timeout: Maximum runtime in seconds
timeout_preprocess: Preprocessing timeout
branch: Git branch to use
docker_tag: Custom Docker image tag
volumes: List of persistent storage volumes
env: Environment variables to pass
secrets: Secrets to inject
```

View File

@@ -187,12 +187,6 @@ rl:
# whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:
# reward modelling: `True` or `False`
reward_model:
# process reward modelling: `True` or `False`
process_reward_model:
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
@@ -250,8 +244,6 @@ total_num_tokens:
sample_packing_group_size: 100000
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
sample_packing_bin_size: 200
# whether to concatenate samples during pretraining
pretraining_sample_concatenation:
# Use batch flattening for speedups when not using sample_packing
batch_flattening:
@@ -366,11 +358,10 @@ warmup_ratio: 0.05 # cannot use with warmup_steps
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
eval_steps: # Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
eval_strategy: # Set to `"no"` to skip evaluation, `"epoch"` at end of each epoch, leave empty to infer from `eval_steps`.
save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of each epoch, `"best"` when better result is achieved, leave empty to infer from `save_steps`.
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
save_strategy: # Set to `"no"` to skip checkpoint saves
save_steps: # Leave empty to save at each epoch
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
save_total_limit: # Checkpoints saved at a time
# Maximum number of iterations to train for. It precedes num_epochs which means that

View File

@@ -8,12 +8,14 @@ order: 3
IMPORTANT: ShareGPT is deprecated!. Please see `chat_template` section below.
## pygmalion
```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "value": "..."}]}
```
## chat_template
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.

View File

@@ -1,26 +0,0 @@
---
title: Stepwise Supervised Format
description: Format for datasets with stepwise completions and labels
order: 3
---
## Stepwise Supervised
The stepwise supervised format is designed for chain-of-thought (COT) reasoning
datasets where each example contains multiple completion steps and a preference label
for each step.
### Example
Here's a simple example of a stepwise supervised dataset entry:
```json
{
"prompt": "Which number is larger, 9.8 or 9.11?",
"completions": [
"The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
"Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."
],
"labels": [true, false]
}
```

View File

@@ -1,155 +0,0 @@
---
title: "Getting Started with Axolotl"
format:
html:
toc: true
toc-depth: 3
number-sections: true
execute:
enabled: false
---
This guide will walk you through your first model fine-tuning project with Axolotl.
## Quick Example {#sec-quick-example}
Let's start by fine-tuning a small language model using LoRA. This example uses a 1B parameter model to ensure it runs on most GPUs.
Assuming `axolotl` is installed (if not, see our [Installation Guide](installation.qmd))
1. Download example configs:
```shell
axolotl fetch examples
```
2. Run the training:
```shell
axolotl train examples/llama-3/lora-1b.yml
```
That's it! Let's understand what just happened.
## Understanding the Process {#sec-understanding}
### The Configuration File {#sec-config}
The YAML configuration file controls everything about your training. Here's what (part of) our example config looks like:
```yaml
base_model: NousResearch/Llama-3.2-1B
# hub_model_id: username/custom_model_name
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
```
See our [Config options](config.qmd) for more details.
### Training {#sec-training}
When you run `axolotl train`, Axolotl:
1. Downloads the base model
2. (If specified) applies LoRA adapter layers
3. Loads and processes the dataset
4. Runs the training loop
5. Saves the trained model and / or LoRA weights
## Your First Custom Training {#sec-custom}
Let's modify the example for your own data:
1. Create a new config file `my_training.yml`:
```yaml
base_model: NousResearch/Nous-Hermes-llama-1b-v1
adapter: lora
# Training settings
micro_batch_size: 2
num_epochs: 3
learning_rate: 0.0003
# Your dataset
datasets:
- path: my_data.jsonl # Your local data file
type: alpaca # Or other format
```
This specific config is for LoRA fine-tuning a model with instruction tuning data using
the `alpaca` dataset format, which has the following format:
```json
{
"instruction": "Write a description of alpacas.",
"input": "",
"output": "Alpacas are domesticated South American camelids..."
}
```
Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
format them.
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
format):
```json
{"instruction": "Classify this text", "input": "I love this!", "output": "positive"}
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
```
Please consult the supported [Dataset Formats](dataset-formats/) for more details.
3. Run the training:
```shell
axolotl train my_training.yml
```
## Common Tasks {#sec-common-tasks}
### Testing Your Model {#sec-testing}
After training, test your model:
```shell
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
```
### Preprocessing Data {#sec-preprocessing}
For large datasets, preprocess first:
```shell
axolotl preprocess my_training.yml
```
### Using a UI {#sec-ui}
Launch a Gradio interface:
```shell
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
```
## Next Steps {#sec-next-steps}
Now that you have the basics, you might want to:
- Try different model architectures
- Experiment with hyperparameters
- Use more advanced training methods
- Scale up to larger models
Check our other guides for details on these topics:
- [Configuration Guide](config.qmd) - Full configuration options
- [Dataset Formats](dataset-formats) - Working with different data formats
- [Multi-GPU Training](multi-gpu.qmd)
- [Multi-Node Training](multi-node.qmd)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 292 KiB

View File

@@ -1,148 +0,0 @@
---
title: "Inference Guide"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---
This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
## Quick Start {#sec-quickstart}
### Basic Inference {#sec-basic}
::: {.panel-tabset}
## LoRA Models
```{.bash}
axolotl inference your_config.yml --lora-model-dir="./lora-output-dir"
```
## Full Fine-tuned Models
```{.bash}
axolotl inference your_config.yml --base-model="./completed-model"
```
:::
## Advanced Usage {#sec-advanced}
### Gradio Interface {#sec-gradio}
Launch an interactive web interface:
```{.bash}
axolotl inference your_config.yml --gradio
```
### File-based Prompts {#sec-file-prompts}
Process prompts from a text file:
```{.bash}
cat /tmp/prompt.txt | axolotl inference your_config.yml \
--base-model="./completed-model" --prompter=None
```
### Memory Optimization {#sec-memory}
For large models or limited memory:
```{.bash}
axolotl inference your_config.yml --load-in-8bit=True
```
## Merging LoRA Weights {#sec-merging}
Merge LoRA adapters with the base model:
```{.bash}
axolotl merge-lora your_config.yml --lora-model-dir="./completed-model"
```
### Memory Management for Merging {#sec-memory-management}
::: {.panel-tabset}
## Configuration Options
```{.yaml}
gpu_memory_limit: 20GiB # Adjust based on your GPU
lora_on_cpu: true # Process on CPU if needed
```
## Force CPU Merging
```{.bash}
CUDA_VISIBLE_DEVICES="" axolotl merge-lora ...
```
:::
## Tokenization {#sec-tokenization}
### Common Issues {#sec-tokenization-issues}
::: {.callout-warning}
Tokenization mismatches between training and inference are a common source of problems.
:::
To debug:
1. Check training tokenization:
```{.bash}
axolotl preprocess your_config.yml --debug
```
2. Verify inference tokenization by decoding tokens before model input
3. Compare token IDs between training and inference
### Special Tokens {#sec-special-tokens}
Configure special tokens in your YAML:
```{.yaml}
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
tokens:
- "<|im_start|>"
- "<|im_end|>"
```
## Troubleshooting {#sec-troubleshooting}
### Common Problems {#sec-common-problems}
::: {.panel-tabset}
## Memory Issues
- Use 8-bit loading
- Reduce batch sizes
- Try CPU offloading
## Token Issues
- Verify special tokens
- Check tokenizer settings
- Compare training and inference preprocessing
## Performance Issues
- Verify model loading
- Check prompt formatting
- Ensure temperature/sampling settings
:::
For more details, see our [debugging guide](debugging.qmd).

View File

@@ -1,119 +0,0 @@
---
title: "Installation Guide"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---
This guide covers all the ways you can install and set up Axolotl for your environment.
## Requirements {#sec-requirements}
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.10
- PyTorch ≥2.4.1
## Installation Methods {#sec-installation-methods}
### PyPI Installation (Recommended) {#sec-pypi}
```{.bash}
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```
We use `--no-build-isolation` in order to detect the installed PyTorch version (if
installed) in order not to clobber it, and so that we set the correct version of
dependencies that are specific to the PyTorch version or other installed
co-dependencies.
### Edge/Development Build {#sec-edge-build}
For the latest features between releases:
```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
### Docker {#sec-docker}
```{.bash}
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```
For development with Docker:
```{.bash}
docker compose up -d
```
::: {.callout-tip}
### Advanced Docker Configuration
```{.bash}
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
--name axolotl --ipc=host \
--ulimit memlock=-1 --ulimit stack=67108864 \
--mount type=bind,src="${PWD}",target=/workspace/axolotl \
-v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
axolotlai/axolotl:main-latest
```
:::
## Cloud Environments {#sec-cloud}
### Cloud GPU Providers {#sec-cloud-gpu}
For providers supporting Docker:
- Use `axolotlai/axolotl-cloud:main-latest`
- Available on:
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
### Google Colab {#sec-colab}
Use our [example notebook](../examples/colab-notebooks/colab-axolotl-example.ipynb).
## Platform-Specific Instructions {#sec-platform-specific}
### macOS {#sec-macos}
```{.bash}
pip3 install --no-build-isolation -e '.'
```
See @sec-troubleshooting for Mac-specific issues.
### Windows {#sec-windows}
::: {.callout-important}
We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
:::
## Environment Managers {#sec-env-managers}
### Conda/Pip venv {#sec-conda}
1. Install Python ≥3.10
2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl:
```{.bash}
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Hugging Face:
```{.bash}
huggingface-cli login
```
## Troubleshooting {#sec-troubleshooting}
If you encounter installation issues, see our [FAQ](faq.qmd) and [Debugging Guide](debugging.qmd).

View File

@@ -1,29 +0,0 @@
---
title: Learning Rate Groups
description: "Setting different learning rates by module name"
---
## Background
Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or groups of
modules in a model.
## Example
```yaml
lr_groups:
- name: o_proj
modules:
- self_attn.o_proj.weight
lr: 1e-6
- name: q_proj
modules:
- model.layers.2.self_attn.q_proj.weight
lr: 1e-5
learning_rate: 2e-5
```
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning are of 1e-5 to the 3rd layer's
self attention `q_proj` module.

View File

@@ -1,118 +0,0 @@
---
title: "Multi-GPU Training Guide"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---
This guide covers advanced training configurations for multi-GPU setups using Axolotl.
## Overview {#sec-overview}
Axolotl supports several methods for multi-GPU training:
- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- FSDP + QLoRA
## DeepSpeed {#sec-deepspeed}
DeepSpeed is the recommended approach for multi-GPU training due to its stability and performance. It provides various optimization levels through ZeRO stages.
### Configuration {#sec-deepspeed-config}
Add to your YAML config:
```{.yaml}
deepspeed: deepspeed_configs/zero1.json
```
### Usage {#sec-deepspeed-usage}
```{.bash}
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
```
### ZeRO Stages {#sec-zero-stages}
We provide default configurations for:
- ZeRO Stage 1 (`zero1.json`)
- ZeRO Stage 2 (`zero2.json`)
- ZeRO Stage 3 (`zero3.json`)
Choose based on your memory requirements and performance needs.
## FSDP {#sec-fsdp}
### Basic FSDP Configuration {#sec-fsdp-config}
```{.yaml}
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_offload_params: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
## Performance Optimization {#sec-performance}
### Liger Kernel Integration {#sec-liger}
::: {.callout-note}
Liger Kernel provides efficient Triton kernels for LLM training, offering:
- 20% increase in multi-GPU training throughput
- 60% reduction in memory usage
- Compatibility with both FSDP and DeepSpeed
:::
Configuration:
```{.yaml}
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```
## Troubleshooting {#sec-troubleshooting}
### NCCL Issues {#sec-nccl}
For NCCL-related problems, see our [NCCL troubleshooting guide](nccl.qmd).
### Common Problems {#sec-common-problems}
::: {.panel-tabset}
## Memory Issues
- Reduce `micro_batch_size`
- Reduce `eval_batch_size`
- Adjust `gradient_accumulation_steps`
- Consider using a higher ZeRO stage
## Training Instability
- Start with DeepSpeed ZeRO-2
- Monitor loss values
- Check learning rates
:::
For more detailed troubleshooting, see our [debugging guide](debugging.qmd).

View File

@@ -1,93 +0,0 @@
---
title: Ray Train integration
description: How to use Axolotl with Ray Train
---
Axolotl supports using Ray as an alternative to `accelerate` for orchestrating training. This is especially useful for multi-node training since you only have to setup code and dependencies in a single node and launch training as if you were using a single node.
With the `--use-ray` CLI flag, Axolotl will use Ray Train's [`TorchTrainer`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.TorchTrainer.html#ray.train.torch.TorchTrainer) to run training.
## Ray cluster setup
A prerequisite using the Ray Train integration is to setup a Ray cluster on your desired node(s). For a detailed guide on how you can get started with ray clusters, check the official Ray docs here: https://docs.ray.io/en/latest/cluster/getting-started.html
Every Ray cluster has one _head_ node and a set of worker nodes. The head node is just like any other worker node, but it also runs certain special processes related to scheduling and orchestration. Ray-enabled scripts are run on the head node and depending on the resources (number of CPUs, GPUs, etc) they request, will be scheduled to run certain tasks on the worker nodes. For more on key concepts behind a Ray cluster, you can refer this [doc](https://docs.ray.io/en/latest/cluster/key-concepts.html#cluster-key-concepts).
## Sanity check
To run a sanity check on whether your ray cluster is setup properly, execute the following on the head node:
```bash
ray status
```
The output should have a summary of your Ray cluster - list of all the nodes in your cluster, the number of CPUs and GPUs in your cluster, etc. For example, if you have a cluster with 1 CPU-only head node and 2 4xL40S worker nodes, the output can look like this:
```
Node status
---------------------------------------------------------------
Active:
1 head
Idle:
2 4xL40S:48CPU-384GB
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/96.0 CPU
0.0/8.0 GPU
0B/800.00GiB memory
0B/229.57GiB object_store_memory
Demands:
(no resource demands)
```
You should also be able to see the same on the [Ray dashboard](https://docs.ray.io/en/latest/ray-observability/getting-started.html).
## Configuring training with Ray Train
You can find an example configuration at `configs/llama-3/lora-1b-ray.yaml`.
The key parameters to note here are:
```yaml
...
use_ray: true
ray_num_workers: 4
# optional
resources_per_worker:
GPU: 1
...
```
- `use_ray`: This is the flag that enables the Ray Train integration. You can either use the corresponding `--use-ray` flag in the CLI or set `use_ray` in the config file.
- `ray_num_workers`: This is the number of workers/GPUs to use for training.
- `resources_per_worker`: This is the Ray [resource request](https://docs.ray.io/en/latest/ray-core/scheduling/resources.html) for each worker. This can be used to request a specific GPU type or a custom resource for each worker. For example, if your ray cluster has GPUs of different types, and you only want to use NVIDIA L40S GPUs, you can do
```yaml
resources_per_worker:
accelerator_type:L40S: 0.001
```
## Launching training
You can simply run the following command on the head node:
```bash
axolotl train examples/llama-3/lora-1b-ray.yml --use-ray
```
This will launch training on the head node and workers will be scheduled automatically by Ray Train to run on the appropriate head or worker nodes.
You can also monitor training progress on the Ray dashboard.
Coming back to the example on a Ray cluster with 1 head node and 2 4xL40S worker nodes, let's say you want to make use of all 8 GPUs. You would be able to just set `ray_num_workers: 8` and run the previous command. The Cluster tab will show the following:
![Ray dashboard](./images/ray-cluster-dashboard.png)

View File

@@ -1,47 +0,0 @@
---
title: "Reward Modelling"
description: "Reward models are used to guide models towards behaviors which is preferred by humans, by training over large datasets annotated with human preferences. "
---
### Overview
Reward modelling is a technique used to train models to predict the reward or value of a given input. This is particularly useful in reinforcement learning scenarios where the model needs to evaluate the quality of its actions or predictions.
We support the reward modelling techniques supported by `trl`.
### (Outcome) Reward Models
Outcome reward models are trained using data which contains preference annotations for an entire interaction between the user and model (e.g. rather than per-turn or per-step).
```yaml
base_model: google/gemma-2-2b
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
reward_model: true
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.1
eval_steps: 100
```
### Process Reward Models (PRM)
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
```yaml
base_model: Qwen/Qwen2.5-3B
model_type: AutoModelForTokenClassification
num_labels: 2
process_reward_model: true
datasets:
- path: trl-lib/math_shepherd
type: stepwise_supervised
split: train
val_set_size: 0.1
eval_steps: 100
```

View File

@@ -46,7 +46,7 @@ output_dir: ./outputs/btlm-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused
optimizer: adamw_torch
adam_beta2: 0.95
adam_eps: 0.000000001
max_grad_norm: 1.0

View File

@@ -1,28 +0,0 @@
project_name:
volumes:
- name: axolotl-data
mount: /workspace/data
- name: axolotl-artifacts
mount: /workspace/artifacts
# environment variables from local to set as secrets
secrets:
- HF_TOKEN
- WANDB_API_KEY
# Which branch of axolotl to use remotely
branch:
# additional custom commands when building the image
dockerfile_commands:
gpu: h100
gpu_count: 1
# Train specific configurations
memory: 128
timeout: 86400
# Preprocess specific configurations
memory_preprocess: 32
timeout_preprocess: 14400

View File

@@ -27,7 +27,7 @@ wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5

View File

@@ -47,7 +47,7 @@ peft_use_rslora: true
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5

View File

@@ -1,7 +1,6 @@
base_model: google/gemma-2-2b
# optionally might have model_type or tokenizer_type
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

View File

@@ -34,7 +34,7 @@ lora_target_linear: false
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

View File

@@ -42,7 +42,7 @@ output_dir: ./outputs/model-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
adam_beta2: 0.95
adam_eps: 0.00001
max_grad_norm: 1.0

View File

@@ -39,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

View File

@@ -37,7 +37,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5

View File

@@ -1,79 +0,0 @@
base_model: NousResearch/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed: deepspeed_configs/zero3.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"
use_ray: true
ray_num_workers: 4

View File

@@ -30,7 +30,7 @@ lora_target_linear: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

View File

@@ -39,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

View File

@@ -47,7 +47,7 @@ wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -41,7 +41,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -43,7 +43,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

View File

@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

View File

@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

View File

@@ -39,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 12
num_epochs: 2
optimizer: adamw_torch_fused
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

View File

@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_fused
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

View File

@@ -1,72 +0,0 @@
base_model: Qwen/Qwen2.5-3B
# optionally might have model_type or tokenizer_type
model_type: AutoModelForTokenClassification
num_labels: 2
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
process_reward_model: true
chat_template:
datasets:
- path: trl-lib/math_shepherd
type: stepwise_supervised
step_separator: "\n"
max_completion_length:
train_on_last_step_only: false
val_set_size: 0.2
output_dir: ./outputs/out
remove_unused_columns: false
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 8
eval_batch_size: 8
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32:
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
eval_steps: 100
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -37,7 +37,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -1,67 +0,0 @@
base_model: Qwen/Qwen2.5-0.5B
# optionally might have model_type or tokenizer_type
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
reward_model: true
chat_template: qwen_25
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
remove_unused_columns: false
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.1
bitsandbytes==0.45.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.0.post2
@@ -13,9 +13,9 @@ liger-kernel==0.5.2
packaging==23.2
peft==0.14.0
transformers==4.48.1
transformers==4.47.1
tokenizers>=0.21.0
accelerate==1.3.0
accelerate==1.2.1
datasets==3.2.0
deepspeed==0.16.1
trl==0.13.0
@@ -25,7 +25,6 @@ hf_transfer
sentencepiece
gradio==3.50.2
modal==0.70.5
pydantic==2.6.3
addict
fire

View File

@@ -30,7 +30,7 @@ def parse_dataset(dataset=None, split="train"):
)
ds_cfg["field_messages"] = field_messages
message_fields = features[field_messages][0].keys()
message_fields = features["conversations"][0].keys()
message_field_role = None
for key in ["from", "role"]:
if key in message_fields:

View File

@@ -1,15 +1,10 @@
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
dP dP dP
88 88 88
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
88' `88 `8bd8' 88' `88 88 88' `88 88 88
88. .88 .d88b. 88. .88 88 88. .88 88 88
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory ie empty, run the following commands:

View File

@@ -32,6 +32,8 @@ def parse_requirements():
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
triton_version = [req for req in _install_requires if "triton" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
@@ -85,8 +87,24 @@ def parse_requirements():
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(triton_version))
_install_requires.append("triton>=2.3.1")
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")
else:
raise ValueError("axolotl requires torch>=2.4")
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1")
except PackageNotFoundError:
pass
@@ -150,8 +168,5 @@ setup(
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
],
"ray": [
"ray[train]",
],
},
)

View File

@@ -31,8 +31,6 @@ class TrainerCliArgs:
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
main_process_port: Optional[int] = field(default=None)
num_processes: Optional[int] = field(default=None)
@dataclass

View File

@@ -1,56 +0,0 @@
"""
launch axolotl in supported cloud platforms
"""
from pathlib import Path
from typing import Union
import yaml
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
"""Load and validate cloud configuration."""
# Load cloud configuration.
with open(cloud_config, encoding="utf-8") as file:
cloud_cfg: DictDefault = DictDefault(yaml.safe_load(file))
return cloud_cfg
def do_cli_preprocess(
cloud_config: Union[Path, str],
config: Union[Path, str],
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.preprocess(config_yaml)
def do_cli_train(
cloud_config: Union[Path, str],
config: Union[Path, str],
accelerate: bool = True,
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.train(config_yaml, accelerate=accelerate)
def do_cli_lm_eval(
cloud_config: Union[Path, str],
config: Union[Path, str],
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.lm_eval(config_yaml)

View File

@@ -1,18 +0,0 @@
"""
base class for cloud platforms from cli
"""
from abc import ABC, abstractmethod
class Cloud(ABC):
"""
Abstract base class for cloud platforms.
"""
@abstractmethod
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
pass
@abstractmethod
def train(self, config_yaml: str, accelerate: bool = True) -> str:
pass

View File

@@ -1,282 +0,0 @@
"""
Modal Cloud support from CLI
"""
import copy
import json
import os
import subprocess # nosec B404
from pathlib import Path
from random import randint
import modal
from axolotl.cli.cloud.base import Cloud
def run_cmd(cmd: str, run_folder: str, volumes=None):
"""Run a command inside a folder, with Modal Volume reloading before and commit on success."""
# Ensure volumes contain latest files.
if volumes:
for _, vol in volumes.items():
vol.reload()
# modal workaround so it doesn't use the automounted axolotl
new_env = copy.deepcopy(os.environ)
if "PYTHONPATH" in new_env:
del new_env["PYTHONPATH"]
# Propagate errors from subprocess.
if exit_code := subprocess.call( # nosec B603
cmd.split(), cwd=run_folder, env=new_env
):
exit(exit_code) # pylint: disable=consider-using-sys-exit
# Commit writes to volume.
if volumes:
for _, vol in volumes.items():
vol.commit()
class ModalCloud(Cloud):
"""
Modal Cloud implementation.
"""
def __init__(self, config, app=None):
self.config = config
if not app:
app = modal.App()
self.app = app
self.volumes = {}
if config.volumes:
for volume_config in config.volumes:
_, mount, vol = self.create_volume(volume_config)
self.volumes[mount] = (vol, volume_config)
def get_env(self):
res = {
"HF_DATASETS_CACHE": "/workspace/data/huggingface-cache/datasets",
"HF_HUB_CACHE": "/workspace/data/huggingface-cache/hub",
}
for key in self.config.get("env", []):
if isinstance(key, str):
if val := os.environ.get(key, ""):
res[key] = val
elif isinstance(key, dict):
(key_, val) = list(key.items())[0]
res[key_] = val
return res
def get_image(self):
docker_tag = "main-py3.11-cu124-2.5.1"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"
# grab the sha256 hash from docker hub for this image+tag
# this ensures that we always get the latest image for this tag, even if it's already cached
try:
manifest = subprocess.check_output( # nosec B602
f"docker manifest inspect {docker_image}",
shell=True,
).decode("utf-8")
sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
except subprocess.CalledProcessError:
sha256_hash = None
# create the image
if sha256_hash:
image = modal.Image.from_registry(f"axolotlai/axolotl@{sha256_hash}")
else:
image = modal.Image.from_registry(docker_image)
dockerfile_commands = []
if self.config.dockerfile_commands:
dockerfile_commands.extend(self.config.dockerfile_commands)
# branch
if self.config.branch:
dockerfile_commands.extend(
[
# Random id for cache busting of branch commits
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
]
)
if dockerfile_commands:
image = image.dockerfile_commands(dockerfile_commands)
if env := self.get_env():
image = image.env(env)
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
return image
def get_secrets(self):
res = []
if self.config.secrets:
for key in self.config.get("secrets", []):
# pylint: disable=duplicate-code
if isinstance(key, str):
if val := os.environ.get(key, ""):
res.append(modal.Secret.from_dict({key: val}))
elif isinstance(key, dict):
(key_, val) = list(key.items())[0]
res.append(modal.Secret.from_dict({key_: val}))
return res
def create_volume(self, volume_config):
name = volume_config.name
mount = volume_config.mount
return name, mount, modal.Volume.from_name(name, create_if_missing=True)
def get_ephemeral_disk_size(self):
return 1000 * 525 # 1 TiB
def get_preprocess_timeout(self):
if self.config.timeout_preprocess:
return int(self.config.timeout_preprocess)
return 60 * 60 * 3 # 3 hours
def get_preprocess_memory(self):
memory = 128 # default to 128GiB
if self.config.memory:
memory = int(self.config.memory)
if self.config.memory_preprocess:
memory = int(self.config.memory_preprocess)
return 1024 * memory
def get_preprocess_env(self):
return self.app.function(
image=self.get_image(),
volumes={k: v[0] for k, v in self.volumes.items()},
cpu=8.0,
ephemeral_disk=self.get_ephemeral_disk_size(),
memory=self.get_preprocess_memory(),
timeout=self.get_preprocess_timeout(),
secrets=self.get_secrets(),
)
def preprocess(self, config_yaml: str, *args, **kwargs):
modal_fn = self.get_preprocess_env()(_preprocess)
with modal.enable_output():
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
volumes={k: v[0] for k, v in self.volumes.items()},
*args,
**kwargs,
)
def get_train_timeout(self):
if self.config.timeout:
return int(self.config.timeout)
return 60 * 60 * 24 # 24 hours
def get_train_gpu(self): # pylint: disable=too-many-return-statements
count = self.config.gpu_count or 1
family = self.config.gpu.lower() or "l40s"
if family == "l40s":
return modal.gpu.L40S(count=count)
if family in ["a100", "a100-40gb"]:
return modal.gpu.A100(count=count, size="40GB")
if family == "a100-80gb":
return modal.gpu.A100(count=count, size="80GB")
if family in ["a10", "a10g"]:
return modal.gpu.A10G(count=count)
if family == "h100":
return modal.gpu.H100(count=count)
if family == "t4":
return modal.gpu.T4(count=count)
if family == "l4":
return modal.gpu.L4(count=count)
raise ValueError(f"Unsupported GPU family: {family}")
def get_train_memory(self):
memory = 128 # default to 128GiB
if self.config.memory:
memory = int(self.config.memory)
return 1024 * memory
def get_train_env(self):
return self.app.function(
image=self.get_image(),
volumes={k: v[0] for k, v in self.volumes.items()},
cpu=16.0,
gpu=self.get_train_gpu(),
memory=self.get_train_memory(),
timeout=self.get_train_timeout(),
secrets=self.get_secrets(),
)
def train(self, config_yaml: str, accelerate: bool = True):
modal_fn = self.get_train_env()(_train)
with modal.enable_output():
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
accelerate=accelerate,
volumes={k: v[0] for k, v in self.volumes.items()},
)
def lm_eval(self, config_yaml: str):
modal_fn = self.get_train_env()(_lm_eval)
with modal.enable_output():
with self.app.run(detach=True):
if self.config.get("spawn", False):
modal_fn_exec = modal_fn.spawn
else:
modal_fn_exec = modal_fn.remote
modal_fn_exec(
config_yaml,
volumes={k: v[0] for k, v in self.volumes.items()},
)
def _preprocess(config_yaml: str, volumes=None):
Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
run_cmd(
"axolotl preprocess /workspace/artifacts/axolotl/config.yaml --dataset-processes=8",
run_folder,
volumes,
)
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
if accelerate:
accelerate_args = "--accelerate"
else:
accelerate_args = "--no-accelerate"
run_cmd(
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)
def _lm_eval(config_yaml: str, volumes=None):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
run_cmd(
"axolotl lm-eval /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)

View File

@@ -1,135 +0,0 @@
"""CLI to run training on a model."""
import logging
import os
from pathlib import Path
from typing import Union
import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import PluginManager
from axolotl.integrations.lolcats.linear_llama.configuration_linear_llama import (
LinearLlamaConfig,
)
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
LinearLlamaForCausalLM,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
from axolotl.utils.trainer import setup_trainer
LOG = logging.getLogger(__name__)
def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
Convert attention to linear attention and perform attention transfer via distillation.
"""
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
# ensure quantization and peft are turned off (due to how we need to re-apply peft later)
cfg.load_in_8bit = False
cfg.load_in_4bit = False
cfg.adapter = None
# load model
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
# freeze model
for p in model.parameters():
p.requires_grad = False
# convert to linear llama
linear_llama_config = LinearLlamaConfig.from_llama(
model.config, cfg.attention_config
)
model = LinearLlamaForCausalLM.from_llama(
model, config=linear_llama_config, train_attention=True
)
# set save_path, save tokenizer and model config.
save_path = str(os.path.join(cfg.output_dir, "distilled"))
tokenizer.save_pretrained(save_path)
if hasattr(model, "config"):
model.config.save_pretrained(save_path)
# Get datasets
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# toggle attention to be trainable
model.toggle_attention(train=True)
# Setup trainer
trainer = setup_trainer(
cfg=cfg,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model=(model, None, None),
tokenizer=tokenizer,
processor=None,
total_num_steps=total_num_steps,
)
# train
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
# drop base_attention + remove training attn
model.toggle_attention(train=False)
model.remove_base_attention()
# NOTE: If in peft mode, consider whether to auto-merge
# save model
safe_serialization = cfg.save_safetensors is True
# NOTE: may need to consider other ways of saving due to multi-gpu etc
model.save_pretrained(save_path, safe_serialization=safe_serialization)
# cleanup
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
plugin_manager.post_train_unload(cfg)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_train`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# load cfg, force linearize and add plugin to linearize
parsed_cfg = load_cfg(
config,
linearize=True,
plugins=["axolotl.integrations.lolcats.LinearizePlugin"],
**kwargs,
)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
do_linearize(parsed_cfg, parsed_cli_args)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -1,17 +1,10 @@
"""Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name
import logging
import random
import subprocess # nosec B404
import tempfile
from copy import deepcopy
from itertools import product
from pathlib import Path
from typing import Optional
import click
import yaml
import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
@@ -22,81 +15,10 @@ from axolotl.cli.utils import (
fetch_from_github,
filter_none_kwargs,
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
def generate_sweep_configs(base_config, sweeps_config):
"""
Recursively generates all possible configurations by applying sweeps to the base config.
Args:
base_config (dict): The original configuration dictionary
sweeps_config (dict): Dictionary where keys are parameters and values are either:
- lists of values to sweep independently
- or for paired values, a list of dicts under the '_' key
Returns:
list: List of all possible configuration dictionaries
Example:
sweeps_config = {
'learning_rate': [0.1, 0.01],
'_': [
{'load_in_8bit': True, 'adapter': 'lora'},
{'load_in_4bit': True, 'adapter': 'qlora'}
]
}
"""
# Separate paired values from regular sweeps
paired_values = sweeps_config.get("_", [])
regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
# Process regular sweeps
param_names = list(regular_sweeps.keys())
param_values = list(regular_sweeps.values())
# Generate combinations for regular sweeps
regular_combinations = list(product(*param_values)) if param_values else [()]
# Combine regular sweeps with paired values
all_combinations = []
for reg_combo in regular_combinations:
if paired_values:
for paired_set in paired_values:
new_config = {}
# new_config = deepcopy(base_config)
# Combine regular parameters with paired parameters
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
for param_name, param_value in full_combo.items():
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)
else:
# If no paired values, just use regular combinations
# new_config = deepcopy(base_config)
new_config = {}
for param_name, param_value in zip(param_names, reg_combo):
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)
# randomize the order of trials
random.seed(42)
random.shuffle(all_combinations)
# Generate a new config for each combination
result_configs = []
for combination in all_combinations:
new_config = deepcopy(base_config)
for param_name, param_value in combination.items():
new_config[param_name] = param_value
result_configs.append(new_config)
return result_configs
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
@@ -105,28 +27,23 @@ def cli():
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
def preprocess(config: str, **kwargs) -> None:
"""
Preprocess datasets before training.
Args:
config: Path to `axolotl` config YAML file.
cloud: Path to a cloud accelerator configuration file.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if cloud:
from axolotl.cli.cloud import do_cli_preprocess
kwargs = {k: v for k, v in kwargs.items() if v is not None}
do_cli_preprocess(cloud_config=cloud, config=config)
else:
from axolotl.cli.preprocess import do_cli
from axolotl.cli.preprocess import do_cli
do_cli(config=config, **kwargs)
do_cli(config=config, **kwargs)
@cli.command()
@@ -136,99 +53,32 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
default=True,
help="Use accelerate launch for multi-GPU training",
)
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@click.option(
"--sweep",
type=click.Path(exists=True, path_type=str),
help="YAML config for sweeping hyperparameters",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def train(
config: str,
accelerate: bool,
cloud: Optional[str] = None,
sweep: Optional[str] = None,
**kwargs,
) -> None:
def train(config: str, accelerate: bool, **kwargs) -> None:
"""
Train or fine-tune a model.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
cloud: Path to a cloud accelerator configuration file
sweep: Path to YAML config for sweeping hyperparameters.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
from axolotl.cli.cloud import do_cli_train
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
if sweep:
# load the sweep configuration yaml file
with open(sweep, "r", encoding="utf-8") as fin:
sweep_config: dict[str, list] = yaml.safe_load(fin)
with open(config, "r", encoding="utf-8") as fin:
base_config: dict[str, list] = yaml.safe_load(fin)
# generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
def iter_configs():
for perm in permutations:
# open temp directory for temporary configurations
with tempfile.TemporaryDirectory() as temp_dir:
with open(
Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
) as fout:
yaml.dump(perm, fout)
yield str(Path(temp_dir) / "config.yaml")
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.train import do_cli
def iter_configs():
yield config
for cfg_file in iter_configs():
# handle errors from subprocess so we can continue rest of sweeps
try:
if accelerate:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
else:
accelerate_args = []
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port", None)
accelerate_args.append("--main_process_port")
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num-processes")
accelerate_args.append(str(num_processes))
base_cmd = ["accelerate", "launch"]
base_cmd.extend(accelerate_args)
base_cmd.extend(["-m", "axolotl.cli.train"])
if cfg_file:
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
else:
from axolotl.cli.train import do_cli
do_cli(config=cfg_file, **kwargs)
except subprocess.CalledProcessError as exc:
logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:
raise exc
do_cli(config=config, **kwargs)
@cli.command()
@@ -347,6 +197,7 @@ def merge_lora(config: str, **kwargs) -> None:
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
@@ -373,9 +224,6 @@ def fetch(directory: str, dest: Optional[str]) -> None:
fetch_from_github(f"{directory}/", dest)
cli.add_command(lm_eval)
def main():
cli()

View File

@@ -5,7 +5,6 @@ from pathlib import Path
from typing import Union
import fire
from accelerate import Accelerator
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
@@ -16,7 +15,6 @@ from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
@@ -65,47 +63,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
return_remaining_strings=True
)
if parsed_cfg.use_ray:
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
train_loop_config = {"cfg": parsed_cfg.to_dict(), "cli_args": parsed_cli_args}
trainer = TorchTrainer(
ray_train_func,
train_loop_config=train_loop_config,
scaling_config=ScalingConfig(
num_workers=parsed_cfg.ray_num_workers,
resources_per_worker=parsed_cfg.resources_per_worker.to_dict(),
use_gpu=True,
),
run_config=RunConfig(
name=parsed_cfg.ray_run_name,
storage_path=Path(parsed_cfg.output_dir).absolute().as_posix(),
),
)
return trainer.fit()
return do_train(parsed_cfg, parsed_cli_args)
def ray_train_func(kwargs: dict):
# cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
# also renormalize the config now that TorchTrainer has spawned distributed workers
cfg = DictDefault(kwargs["cfg"])
normalize_config(cfg)
# now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype
resolve_dtype(cfg)
# ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
if cfg.deepspeed:
cfg.deepspeed = cfg.deepspeed.to_dict()
# initialize accelerator before model instantiation
Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
kwargs["cfg"] = cfg
do_train(**kwargs)
do_train(parsed_cfg, parsed_cli_args)
if __name__ == "__main__":

View File

@@ -43,7 +43,6 @@ from axolotl.core.trainers.base import (
AxolotlKTOTrainer,
AxolotlMambaTrainer,
AxolotlORPOTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
AxolotlTrainer,
ReLoRATrainer,
@@ -53,7 +52,6 @@ from axolotl.core.training_args import (
AxolotlDPOConfig,
AxolotlKTOConfig,
AxolotlORPOConfig,
AxolotlPRMConfig,
AxolotlRewardConfig,
AxolotlTrainingArguments,
)
@@ -226,8 +224,7 @@ class TrainerBuilderBase(abc.ABC):
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for causal models
and reward modelling using TRL.
Build the HuggingFace training args/trainer for Causal models
"""
def get_callbacks(self):
@@ -307,8 +304,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
if self.cfg.process_reward_model:
return AxolotlPRMTrainer
return AxolotlTrainer
def build(self, total_num_steps):
@@ -565,7 +560,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
@@ -709,17 +703,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"kd_zscore_base_temp"
] = self.cfg.kd_zscore_base_temp
if self.cfg.kd_top_k_before_softmax is not None:
training_arguments_kwargs[
"kd_top_k_before_softmax"
] = self.cfg.kd_top_k_before_softmax
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
elif self.cfg.process_reward_model:
training_args_cls = AxolotlPRMConfig
else:
training_args_cls = AxolotlTrainingArguments
training_args_cls = (
AxolotlTrainingArguments
if not self.cfg.reward_model
else AxolotlRewardConfig
)
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
@@ -753,9 +742,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if eval_data_collator := self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
):
if not (self.cfg.reward_model or self.cfg.process_reward_model):
if not self.cfg.reward_model:
trainer_kwargs["eval_data_collator"] = eval_data_collator
if not (self.cfg.reward_model or self.cfg.process_reward_model):
if not self.cfg.reward_model:
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
@@ -766,10 +755,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs["processing_class"] = self.tokenizer
else:
trainer_kwargs["tokenizer"] = self.tokenizer
if (
not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
and self.cfg.datasets is not None
):
if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
@@ -797,10 +784,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
if self.cfg.micro_batch_size > 1:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None
if self.cfg.model_config_type == "mamba":
@@ -869,7 +852,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
class HFRLTrainerBuilder(TrainerBuilderBase):
"""
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
Trainer factory class for DPO Trainer
"""
def get_callbacks(self):

View File

@@ -21,14 +21,7 @@ from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sequential
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import (
CPOTrainer,
DPOTrainer,
KTOTrainer,
ORPOTrainer,
PRMTrainer,
RewardTrainer,
)
from trl import CPOTrainer, DPOTrainer, KTOTrainer, ORPOTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -201,95 +194,11 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
lr_group_modules = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
]
if lr_groups_lookup and any(lr_group_modules):
lr_group_module = lr_group_modules[0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
return optimizer_grouped_parameters
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
@@ -303,13 +212,59 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
@@ -326,7 +281,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
elif (
self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None
or self.args.lr_groups is not None
):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
@@ -858,7 +812,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
@@ -978,11 +931,3 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"]

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, RewardConfig
@dataclass
@@ -125,10 +125,6 @@ class AxolotlTrainingMixins:
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
@@ -199,13 +195,6 @@ class AxolotlTrainingMixins:
},
)
kd_top_k_before_softmax: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
},
)
@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
@@ -255,10 +244,3 @@ class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
"""
Reward config for Reward training
"""
@dataclass
class AxolotlPRMConfig(AxolotlTrainingMixins, PRMConfig):
"""
PRM config for PRM training
"""

View File

@@ -52,17 +52,12 @@ class TokenizedPromptDataset(Dataset):
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 1_000
if (
hasattr(self.prompt_tokenizer, "filter_rows")
and self.prompt_tokenizer.filter_rows
):
if self.prompt_tokenizer.filter_rows:
dataset = dataset.filter(
self.prompt_tokenizer.filter_rows,
num_proc=num_proc,
desc="Strategy Filtering Rows",
)
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,

View File

@@ -32,6 +32,3 @@ class KDArgs(BaseModel):
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[
bool
] = None # whether to sample top k before softmax during KD

View File

@@ -58,7 +58,6 @@ def loss(
target_mask: torch.Tensor,
num_items_in_batch: int = -1, # Use -1 to indicate "None"
kd_temperature: float = 1.0,
top_k_before_softmax: int = 0,
) -> torch.Tensor:
"""
A KD loss function that is TorchScript-friendly.
@@ -75,8 +74,6 @@ def loss(
num_items_in_batch (int, optional): The number of items in the batch.
kd_temperature (float, optional): The temperature for KD.
Default: 1.0
top_k_before_softmax (int, optional): Flag of whether to apply softmax before gathering student top-k logits
Default: 0
"""
target_logprobs = target_logprobs.float()
@@ -86,46 +83,26 @@ def loss(
# student_logits shape: [B, student_seq_len, vocab_size]
teacher_seq_len = target_token_ids.shape[1]
if top_k_before_softmax:
# Slice student logits to match teacher-provided sequence length
student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, teacher_seq_len, vocab_size]
# Slice student logits to match teacher-provided sequence length
student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, teacher_seq_len, vocab_size]
# Gather student logits for teacher's top-K tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
# Gather student logits for teacher's top-K tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
student_logits_topk = student_logits_topk.float()
student_logits_topk = student_logits_topk.float()
# Apply KD temperature to students logits
if kd_temperature != 1.0:
student_logits_topk = student_logits_topk / kd_temperature
# Apply KD temperature to students logits
if kd_temperature != 1.0:
student_logits_topk = student_logits_topk / kd_temperature
# Convert student top-k logits to logprobs
student_logprobs_topk = student_logits_topk - torch.logsumexp(
student_logits_topk, dim=-1, keepdim=True
) # [B, teacher_seq_len, K]
else:
# Slice student logits to match teacher-provided sequence length
student_logits_for_kd = (
student_logits[:, :teacher_seq_len, :] / kd_temperature
) # [B, teacher_seq_len, vocab_size]
# keep in full precision for numerical stability of loss
student_logits_for_kd = student_logits_for_kd.float()
# Gather student logits for teacher's top-K tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
# Compute logsumexp across full vocabulary
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
# Convert just the top-k logits to logprobs
student_logprobs_topk = student_logits_topk - student_lse
# Convert student top-k logits to logprobs
student_logprobs_topk = student_logits_topk - torch.logsumexp(
student_logits_topk, dim=-1, keepdim=True
) # [B, teacher_seq_len, K]
# Convert teacher_mask to boolean for indexing
# In TorchScript, .bool() is sometimes unsupported, so we do:

View File

@@ -67,7 +67,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
outputs = model(**inputs)
# FIXME: account for tokenizer.padding_side
student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
student_logits = outputs["logits"][:, :seq_len, :].contiguous()
shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
@@ -92,7 +92,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
target_mask_for_loss,
num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature,
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
)
if self.args.kd_ce_alpha > 0:

View File

@@ -2,9 +2,9 @@
Module for the Plugin for LM Eval Harness
"""
import subprocess # nosec
from datetime import datetime
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
@@ -18,20 +18,25 @@ class LMEvalPlugin(BasePlugin):
return "axolotl.integrations.lm_eval.LMEvalArgs"
def post_train_unload(self, cfg):
if cfg.lm_eval_post_train:
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16,
flash_attention=cfg.flash_attention,
output_dir=cfg.output_dir,
batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id,
):
subprocess.run( # nosec
lm_eval_args,
check=True,
)
tasks = ",".join(cfg.lm_eval_tasks)
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
output_path = cfg.output_dir
output_path += "" if cfg.output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
subprocess.run( # nosec
[
"lm_eval",
"--model",
"hf",
"--model_args",
f"pretrained={cfg.output_dir}{fa2}{dtype}",
"--tasks",
tasks,
"--batch_size",
str(cfg.lm_eval_batch_size),
"--output_path",
output_path,
],
check=True,
)

View File

@@ -13,5 +13,3 @@ class LMEvalArgs(BaseModel):
lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8
lm_eval_post_train: Optional[bool] = True
lm_eval_model: Optional[str] = None

View File

@@ -1,119 +0,0 @@
"""
axolotl CLI for running lm_eval tasks
"""
import subprocess # nosec
from collections import defaultdict
from datetime import datetime
from typing import Optional
import click
import yaml
from axolotl.utils.dict import DictDefault
def build_lm_eval_command(
tasks: list[str],
bfloat16=True,
flash_attention=False,
output_dir="./",
batch_size=8,
wandb_project=None,
wandb_entity=None,
wandb_name=None,
model=None,
revision=None,
apply_chat_template=None,
fewshot_as_multiturn=None,
):
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
if isinstance(tasks, str):
tasks = [tasks]
for task in tasks:
num_fewshot = "-1"
task_parts = task.split(":")
task_name = task_parts[0]
if len(task_parts) == 2:
task_name, num_fewshot = task_parts
tasks_by_num_fewshot[str(num_fewshot)].append(task_name)
for num_fewshot, tasks_list in tasks_by_num_fewshot.items():
tasks_str = ",".join(tasks_list)
num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
pretrained = "pretrained="
pretrained += model if model else output_dir
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
revision = f",revision={revision}" if revision else ""
output_path = output_dir
output_path += "" if output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
lm_eval_args = [
"lm_eval",
"--model",
"hf",
"--model_args",
f"{pretrained}{fa2}{dtype}{revision}",
"--tasks",
tasks_str,
"--batch_size",
str(batch_size),
"--output_path",
output_path,
]
wandb_args = []
if wandb_project:
wandb_args.append(f"project={wandb_project}")
if wandb_entity:
wandb_args.append(f"entity={wandb_entity}")
if wandb_name:
wandb_args.append(f"name={wandb_name}")
if wandb_args:
lm_eval_args.append("--wandb_args")
lm_eval_args.append(",".join(wandb_args))
if apply_chat_template:
lm_eval_args.append("--apply_chat_template")
if num_fewshot_val:
lm_eval_args.append("--num_fewshot")
lm_eval_args.append(str(num_fewshot_val))
if apply_chat_template and fewshot_as_multiturn:
lm_eval_args.append("--fewshot_as_multiturn")
yield lm_eval_args
@click.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
def lm_eval(config: str, cloud: Optional[str] = None):
"""
use lm eval to evaluate a trained language model
"""
if cloud:
from axolotl.cli.cloud import do_cli_lm_eval
do_cli_lm_eval(cloud_config=cloud, config=config)
else:
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16,
flash_attention=cfg.flash_attention,
output_dir=cfg.output_dir,
batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id,
revision=cfg.revision,
apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn,
):
subprocess.run( # nosec
lm_eval_args,
check=True,
)

View File

@@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -1,44 +0,0 @@
# Low-rank Linear Conversion via Attention Transfer (LoLCATs)
https://github.com/HazyResearch/lolcats/
### Usage
Install `causal_dot_product` CUDA kernel (check the README in the `csrc` directory):
```bash
cd src/axolotl/integrations/lolcats/linear_llama/csrc
# Edit `setup.py` to point to the correct CUDA capabilities L40-44
# nano setup.py
# Build the CUDA kernel
python setup.py install
```
Step 1:
```yaml
plugins:
- axolotl.integrations.lolcats.LinearizePlugin
linearize: true
```
Run axolotl: `python -m axolotl.cli.convert_linear_attention config.yaml` TODO: change path CLI
Step 2: Remove the config `linearize: true` and finetune with lora with below possible targets.
```yaml
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
# with optional config below but this requires patching axolotl
# to allow this config to work with lora
# unfrozen_parameters: ['.*feature_map_q.mlp.layer.*', '.*feature_map_k.mlp.layer.*', '.*window_factors.*']
```
`axolotl train config.yaml --base-model={output_dir}/distilled --trust-remote-code --learning-rate=0.0001 # --wandb-project="..."`
Step 3: Run inference on the finetuned model
`axolotl inference config.yaml --lora-model-dir="{output_dir}" --trust-remote-code # --prompter="AlpacaPrompter"`

View File

@@ -1,43 +0,0 @@
"""
Module for the Plugin for LoLCATs linear attention integration with Axolotl.
Low-rank Linear Conversion via Attention Transfer
"""
import logging
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lolcats.trainer.distill_attention_xent_mse import (
DistillAttentionXentMSETrainer,
)
from .args import LinearAttentionArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.lolcats")
class LinearizePlugin(BasePlugin):
"""
Plugin for lolcats integration with Axolotl.
"""
def __init__(self):
super().__init__()
# Register the Linear Llama model with transformers
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
register_linear_llama,
)
register_linear_llama()
def get_input_args(self):
return "axolotl.integrations.lolcats.LinearAttentionArgs"
def get_trainer_cls(self, cfg):
# defualt to XentMSE
# TODO: add check to allow MSE_linear
if cfg.linearize:
return DistillAttentionXentMSETrainer
return None

View File

@@ -1,47 +0,0 @@
"""
Module for handling linear attention input arguments.
"""
from typing import Optional
from pydantic import BaseModel
class FeatureMapKwargs(BaseModel):
"""Args for feature map"""
eps: float
mlp: Optional[None] = None
fullspace: bool
class LearnedKernelKwargs(BaseModel):
"""Args for learned kernel"""
feature_dim: int
skip_connection: bool
bias: bool
zero_init: bool
class AttentionConfig(BaseModel):
"""Args for attention config"""
attention_type: str
feature_map: str
feature_map_kwargs: FeatureMapKwargs
layer_idx: Optional[None] = None
learned_kernel: str
learned_kernel_kwargs: LearnedKernelKwargs
tie_qk_kernels: bool
train_qk: bool
class LinearAttentionArgs(BaseModel):
"""
Input args for linear attention
"""
attention_config: AttentionConfig
linearize: Optional[bool] = False

View File

@@ -1,90 +0,0 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Linear LLaMA model configuration"""
from typing import Optional
from transformers import LlamaConfig
class LinearLlamaConfig(LlamaConfig):
"""
This is the configuration class to store the configuration of a [`LinearLlamaModel`].
It is a modified LlamaConfig that includes additional parameters for linear attention.
Args:
attention_config (`dict`):
Dictionary containing the configuration for linear attention mechanism.
Expected contents:
`attention_type` (str):
The type of attention to convert to.
`feature_map` (`str`):
The type of feature map to use for linear attention.
`feature_map_kwargs` (`dict`):
Additional arguments for the feature map.
`learned_kernel` (`str`, *optional*):
Type of learned kernel to use, if any.
`learned_kernel_kwargs` (`dict`, *optional*):
Additional arguments for the learned kernel.
`tie_qk_kernels` (`bool`, *optional*, defaults to False):
Whether to tie query and key kernels.
`rotary_config` (`dict`, *optional*):
Configuration for rotary embeddings.
`train_attention` (`bool`, *optional*, defaults to False):
Whether to train attention to match softmax attention.
`remove_base_attn` (`bool`, *optional*, defaults to True):
Whether to remove base attention after initialization.
`mask_value` (`int`, *optional*, defaults to 0):
Value to use for masking.
`eps` (`float`, *optional*, defaults to 1e-12):
Epsilon value for numerical stability.
`fp32_attention` (`bool`, *optional*, defaults to False):
Whether to use fp32 precision for attention computation.
`track_state_grads` (`bool`, *optional*, defaults to False):
Whether to track gradients of attention states.
**kwargs:
Additional arguments inherited from LlamaConfig.
"""
model_type = "linear_llama"
def __init__(self, attention_config: Optional[dict] = None, **kwargs):
super().__init__(**kwargs)
# Set auto_map
self.auto_map = {
"AutoConfig": "configuration_linear_llama.LinearLlamaConfig",
"AutoModel": "modeling_linear_llama.LinearLlamaModel",
"AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM",
}
# Set default attention config if none provided
self.attention_config = attention_config or {"attention_type": "softmax"}
@classmethod
def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):
"""
Instantiate a LinearLlamaConfig from a LlamaConfig and additional attention config.
Args:
llama_config (:class:`~transformers.LlamaConfig`):
The LlamaConfig to inherit from.
attention_config (`dict`):
Dictionary containing the configuration for linear attention mechanism.
"""
return cls(attention_config=attention_config, **llama_config.to_dict())

View File

@@ -1,30 +0,0 @@
# Causal linear attention CUDA kernel
Usage:
```bash
cd src/axolotl/integrations/lolcats/linear_llama/csrc
# Edit `setup.py` to point to the correct CUDA capabilities L40-44
# nano setup.py
# Build the CUDA kernel
python setup.py install
```
Reference: https://github.com/idiap/fast-transformers/
```bib
@inproceedings{katharopoulos_et_al_2020,
author = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
title = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
year = {2020}
}
@article{vyas_et_al_2020,
author={Vyas, A. and Katharopoulos, A. and Fleuret, F.},
title={Fast Transformers with Clustered Attention},
booktitle = {Proceedings of the International Conference on Neural Information Processing Systems (NeurIPS)},
year={2020}
}
```

View File

@@ -1,6 +0,0 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
from .causal_attention import causal_dot_product

View File

@@ -1,225 +0,0 @@
//
// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
// Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
// Apoorv Vyas <avyas@idiap.ch>
//
#include <torch/extension.h>
/**
* Compute a*b^T and save it into out.
*
* a \in R^A
* b \in R^B
*/
inline void vvt_dot(float *a, float *b, float *out, int A, int B) {
for (int i=0; i<A; i++) {
float * bi = b;
for (int j=0; j<B; j++) {
*out += (*a) * (*bi);
out++;
bi++;
}
a++;
}
}
/**
* Implement a vector matrix product v*m and save it into out.
*
* v \in R^A
* m \in R^{AxB}
*/
inline void vm_dot(float *v, float *m, float *out, int A, int B) {
// TODO: Consider removing the zeroing part and assuming out already
// contains 0s
for (int i=0; i<B; i++) {
out[i] = 0;
}
for (int i=0; i<A; i++) {
float *oi = out;
for (int j=0; j<B; j++) {
*oi += (*v) * (*m);
oi++;
m++;
}
v++;
}
}
/**
* Implement a vector transposed-matrix product and save it into out.
*
* v \in R^B
* m \in R^{AxB}
*/
inline void vmt_dot(float *v, float *m, float *out, int A, int B) {
for (int i=0; i<A; i++) {
float *vi = v;
float s = 0;
for (int j=0; j<B; j++) {
s += (*vi) * (*m);
vi++;
m++;
}
// TODO: Should we be aggregating? See the comment on vm_dot.
*out = s;
out++;
}
}
/**
* Compute the causally masked dot products of queries, keys and values.
*
* Basically compute V_j' = (Q_{0:j} * K_{0:j}^T) * V_{0:j} for all j. The
* computation is done efficiently by changing the order of the dot products.
*/
void causal_dot_product(
const torch::Tensor queries,
const torch::Tensor keys,
const torch::Tensor values,
torch::Tensor product
) {
// Extract some shapes
int N = queries.size(0);
int H = queries.size(1);
int L = queries.size(2);
int E = queries.size(3);
int M = values.size(3);
// Create accessors for all the arguments
auto qa = queries.accessor<float, 4>();
auto ka = keys.accessor<float, 4>();
auto va = values.accessor<float, 4>();
auto pa = product.accessor<float, 4>();
#pragma omp parallel for collapse(2)
for (int n=0; n<N; n++) {
for (int h=0; h<H; h++) {
auto kv = torch::zeros({E, M}, queries.options());
float *kvp = kv.data_ptr<float>();
for (int l=0; l<L; l++) {
vvt_dot(
&ka[n][h][l][0],
&va[n][h][l][0],
kvp,
E,
M
);
vm_dot(
&qa[n][h][l][0],
kvp,
&pa[n][h][l][0],
E,
M
);
}
}
}
}
/**
* Compute the gradients of queries, keys and values given the gradient of the
* causal_dot_product output.
*
* Make sure that everything is computed in O(N D^2) complexity.
*/
void causal_dot_backward(
const torch::Tensor queries,
const torch::Tensor keys,
const torch::Tensor values,
const torch::Tensor grad_out,
torch::Tensor grad_queries,
torch::Tensor grad_keys,
torch::Tensor grad_values
) {
// Extract some shapes
int N = queries.size(0);
int H = queries.size(1);
int L = queries.size(2);
int E = queries.size(3);
int M = values.size(3);
// Create accessors for all the arguments
auto qa = queries.accessor<float, 4>();
auto ka = keys.accessor<float, 4>();
auto va = values.accessor<float, 4>();
auto ga = grad_out.accessor<float, 4>();
auto gqa = grad_queries.accessor<float, 4>();
auto gka = grad_keys.accessor<float, 4>();
auto gva = grad_values.accessor<float, 4>();
#pragma omp parallel for collapse(2)
for (int n=0; n<N; n++) {
for (int h=0; h<H; h++) {
auto kv = torch::zeros({E, M}, queries.options());
float *kvp = kv.data_ptr<float>();
// Compute the gradient wrt the queries
for (int l=0; l<L; l++) {
vvt_dot(
&ka[n][h][l][0],
&va[n][h][l][0],
kvp,
E,
M
);
vmt_dot(
&ga[n][h][l][0],
kvp,
&gqa[n][h][l][0],
E,
M
);
}
// Compute the gradient wrt the keys and values
kv.zero_();
for (int l=L-1; l>=0; l--) {
vvt_dot(
&qa[n][h][l][0],
&ga[n][h][l][0],
kvp,
E,
M
);
vmt_dot(
&va[n][h][l][0],
kvp,
&gka[n][h][l][0],
E,
M
);
vm_dot(
&ka[n][h][l][0],
kvp,
&gva[n][h][l][0],
E,
M
);
}
}
}
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"causal_dot_product",
&causal_dot_product,
"Compute the weighted sum of values but attending only to previous "
"values."
);
m.def(
"causal_dot_backward",
&causal_dot_backward,
"Compute the gradient of queries, keys and values given the gradient "
"of causal_dot_product."
);
}

View File

@@ -1,67 +0,0 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
import torch
try:
from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda
from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda
except ImportError as e:
print(e)
causal_dot_product_cuda = causal_dot_backward_cuda = None
class CausalDotProduct(torch.autograd.Function):
"""Compute the weighted sum of values but attending only to previous
values."""
dot = {
# "cpu": causal_dot_product_cpu,
"cuda": causal_dot_product_cuda
}
dot_backward = {
# "cpu": causal_dot_backward_cpu,
"cuda": causal_dot_backward_cuda
}
@staticmethod
def forward(ctx, Q, K, V):
# Save the inputs for the gradient computation
ctx.save_for_backward(Q, K, V)
# Create the output tensor
device = Q.device
N, H, L, _ = Q.shape
_, _, _, M = V.shape
product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=device)
# Actually perform the dot product
CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product)
# breakpoint()
# CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product)
return product
@staticmethod
def backward(ctx, grad_out):
# Extract the saved tensors
Q, K, V = ctx.saved_tensors
# Allocate memory for the gradients
grad_Q = torch.zeros_like(Q)
grad_K = torch.zeros_like(K)
grad_V = torch.zeros_like(V)
# Actually compute the gradients
CausalDotProduct.dot_backward[Q.device.type](
Q.data, K.data, V.data, grad_out, grad_Q, grad_K, grad_V
)
return grad_Q, grad_K, grad_V
# Alias the autograd functions to python style snake case naming
causal_dot_product = CausalDotProduct.apply

View File

@@ -1,65 +0,0 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
import subprocess # nosec
import torch
from setuptools import setup
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
def get_last_arch_torch():
arch = torch.cuda.get_arch_list()[-1]
print(f"Found arch: {arch} from existing torch installation")
return arch
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True # nosec
)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
arch = get_last_arch_torch()
sm_num = arch[-2:]
cc_flag = ["--generate-code=arch=compute_90,code=compute_90"] # for H100
# cc_flag = ['--generate-code=arch=compute_80,code=compute_80'] # for A100
# cc_flag = ['--generate-code=arch=compute_89,code=compute_89'] # for RTX 6000, 4090
# cc_flag = ['--generate-code=arch=compute_86,code=compute_86'] # for A6000, 3090
# cc_flag = ['--generate-code=arch=compute_75,code=compute_75']
setup(
name="causal_attention_cuda_cpp",
ext_modules=[
CUDAExtension(
"causal_attention_cuda",
[
# 'causal_attention.cpp',
"causal_attention_cuda.cu",
],
extra_compile_args={
"cxx": ["-O3"],
"nvcc": append_nvcc_threads(
["-O3", "-lineinfo", "--use_fast_math", "-std=c++17"] + cc_flag
),
},
)
],
cmdclass={"build_ext": BuildExtension},
)

View File

@@ -1,856 +0,0 @@
"""
Linear attention classes
"""
import copy
from typing import Any, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache
# Causal linear attention dot product CUDA kernel from fast-transformers
try:
from csrc import causal_dot_product as fast_causal_dot_product
except ImportError:
fast_causal_dot_product = None
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
# -------------------
# Attention functions
# -------------------
def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""
Causal linear attention dot product
- If available, use CUDA kernel from fast-transformers
"""
if fast_causal_dot_product is None:
kv = torch.einsum("bhlf,bhld->bhlfd", k, v)
return torch.einsum("bhlf,bhlfd->bhld", q, kv.cumsum(dim=2))
return fast_causal_dot_product(q, k, v)
def linear_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
fp32_attention: bool = False,
eps: float = 1e-12,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Compute linear attention with CUDA kernel implementation from fast-transformers
- https://github.com/idiap/fast-transformers
- Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim);
v is shape (b, h, l, head_dim)
"""
dtype = q.dtype
# Causal mask already applied
y = causal_dot_product(
q.contiguous().to(dtype=torch.float32),
k.contiguous().to(dtype=torch.float32),
v.contiguous().to(dtype=torch.float32),
)
if fp32_attention:
y = (
y
/ (
torch.einsum("bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2)) + eps
)[..., None]
).to(dtype=dtype)
else:
y = y.to(dtype=dtype)
k = k.float().cumsum(dim=2).to(dtype=dtype)
y = y / (torch.einsum("bhld,bhld->bhl", q, k) + eps)[..., None]
return y, None, None
def softmax_attention(
q: torch.Tensor,
k: torch.Tensor,
v: Optional[torch.Tensor] = None,
causal: bool = True,
fp32_attention: bool = True,
):
"""
Standard softmax attention; only compute outputs if v is not None
-> Assume q, k, v are shape (batch_size, num_heads, seq_len, head_dim)
"""
y = None
a = torch.einsum("bhmd,bhnd->bhmn", q, k) * (k.shape[-1] ** -0.5)
if causal: # Apply causal mask
m, n = a.shape[-2:]
causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
n - m + 1
)
a = a.masked_fill(causal_mask, -torch.finfo(a.dtype).max)
if fp32_attention:
a = torch.softmax(a, dim=-1, dtype=torch.float32).to(q.dtype)
else:
a = torch.softmax(a, dim=-1)
if v is not None:
y = torch.einsum("bhmn,bhnd->bhmd", a, v)
return y, a, None
def quadratic_attention(
q: torch.Tensor,
k: torch.Tensor,
v: Optional[torch.Tensor] = None,
causal: bool = True,
fp32_attention: bool = False,
eps: float = 1e-12,
):
"""
Compute attention with feature maps by instantiating L x L matrix of attention weights
-> Use for attention distillation
-> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim)
"""
y = None
dtype = q.dtype
if fp32_attention:
q, k = q.float(), k.float()
a = torch.einsum("bhmd,bhnd->bhmn", q, k) # note we don't scale, tho we could
if causal: # Apply causal mask
m, n = a.shape[-2:]
causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
n - m + 1
)
a = a.masked_fill(causal_mask, 0)
# Normalize to compute attention
a = a / (a.sum(dim=-1, keepdim=True) + eps)
a = a.to(dtype=dtype) if fp32_attention else a
if torch.isnan(a).sum() > 0:
breakpoint()
if v is not None:
y = torch.einsum("bhmn,bhnd->bhmd", a, v)
return y, a, None
# ---------------------
# Attention layer class
# ---------------------
class LolcatsLinearAttention(nn.Module):
"""
LoLCATs attention implementation initialized from a
`LlamaAttention` or `MistralAttention` object (base_attn)
Most of the arguments are directly tied to argparse args
- For now we don't support padding.
"""
def __init__(
self,
base_attn: nn.Module, # like LlamaAttention
feature_map: str,
feature_map_kwargs: dict,
layer_idx: Optional[int] = None,
max_layer_idx: Optional[int] = None,
learned_kernel: Optional[str] = None,
learned_kernel_kwargs: Optional[dict] = None,
tie_qk_kernels: Optional[bool] = False,
rotary_config: Optional[dict] = None,
train_attention: Optional[bool] = False,
remove_base_attn: bool = True,
attention_type: Optional[str] = "lolcats_llama",
mask_value: int = 0,
eps: float = 1e-12,
fp32_attention: bool = False,
track_state_grads: bool = False,
rank: Optional[int] = 0,
**kwargs,
) -> None:
super().__init__()
self.base_config = getattr(base_attn, "config", None)
if self.base_config is not None:
self.base_config = self.base_config.to_dict()
self.attention_type = attention_type
self.mask_value = mask_value
self.eps = eps
self.layer_idx = layer_idx if layer_idx is not None else base_attn.layer_idx
self.max_layer_idx = max_layer_idx
self.tie_qk_kernels = tie_qk_kernels
self.train_attention = train_attention
self.base_inference = False
self.fp32_attention = fp32_attention
self.track_state_grads = track_state_grads
if rank == 0: # multi-gpu
if fp32_attention and layer_idx == 0:
print(f"-> fp32_attention is {fp32_attention}")
if layer_idx == 0 and feature_map_kwargs is not None:
for k, v in feature_map_kwargs.items():
print(f"-> {k}: {v}")
if layer_idx == 0 and learned_kernel_kwargs is not None:
for k, v in learned_kernel_kwargs.items():
print(f"-> {k}: {v}")
self.remove_base_attn = remove_base_attn
self.init_weights_(base_attn, remove_base_attn)
self.init_feature_map_(
feature_map, feature_map_kwargs, learned_kernel, learned_kernel_kwargs
)
def init_feature_map_(
self,
feature_map: str,
feature_map_kwargs: dict,
learned_kernel: Optional[str] = None,
learned_kernel_kwargs: Optional[dict] = None,
):
"""
Initialize MLP-based feature map
"""
self.fmap_gqa = False # Turn True if specified below
if learned_kernel is not None and learned_kernel_kwargs is not None:
# Ensure dict
learned_kernel_kwargs = {k: v for k, v in learned_kernel_kwargs.items()}
learned_kernel_kwargs["num_heads"] = self.num_heads
learned_kernel_kwargs["head_dim"] = self.head_dim
learned_kernel_kwargs["dtype"] = self.q_proj.weight.dtype
learned_kernel_kwargs["device"] = self.q_proj.weight.device
# Create MLP
mlp_learned_kernel = init_learned_kernel(
learned_kernel, **learned_kernel_kwargs
)
# Add "activation"; see src.models.feature_map.py
self.feature_map_q = init_feature_map(
name=feature_map, mlp=mlp_learned_kernel, **feature_map_kwargs
)
if self.tie_qk_kernels: # tie mlp weights for query and key feature maps
self.feature_map_k = self.feature_map_q
else:
self.feature_map_k = copy.deepcopy(self.feature_map_q)
def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True):
"""
Initialize module layers, weights, positional dependencies, etc.
from original softmax attention layer (base_attn)
"""
# Make other attributes accessible
self.attention_dropout = 0 # We don't use dropout
self.hidden_size = base_attn.config.hidden_size
self.num_heads = base_attn.config.num_attention_heads
self.head_dim = base_attn.head_dim
self.num_key_value_heads = base_attn.config.num_key_value_heads
self.num_key_value_groups = base_attn.num_key_value_groups
self.q_shape = [self.num_heads, self.head_dim]
self.k_shape = [self.num_key_value_heads, self.head_dim]
self.v_shape = [self.num_key_value_heads, self.head_dim]
# Copy original model projection layers
self.q_proj = base_attn.q_proj
self.k_proj = base_attn.k_proj
self.v_proj = base_attn.v_proj
self.o_proj = base_attn.o_proj
try: # If wanting to use FA2 for ground-truth inference
self._flash_attn_uses_top_left_mask = (
base_attn._flash_attn_uses_top_left_mask
)
except AttributeError:
pass
if self.remove_base_attn or remove_base_attn:
del base_attn # We don't need to keep these around
else:
self.base_attn = base_attn # For some training runs helpful to just call
def process_qkv(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_value: Optional[Any] = None,
):
"""
Compute queries, keys, and values
"""
b, l, _ = hidden_states.size()
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
kv_seq_len = k.shape[-2]
# Shape is (batch_size, seq_len, num_heads, head_dim)
q = q.view(b, l, *self.q_shape).transpose(1, 2)
k = k.view(b, l, *self.k_shape).transpose(1, 2)
v = v.view(b, l, *self.v_shape).transpose(1, 2)
if (
past_key_value is not None
): # and k.shape[2] > q.shape[2]: # e.g., when generating
past_key_value.window_size = getattr(
self, "decode_window_size", None
) # self.decode_window_size
if isinstance(
past_key_value, Cache
): # In Transformers v4.36+ this is a DynamicCache object
kv_seq_len += past_key_value.get_usable_length(
kv_seq_len, self.layer_idx
)
else:
kv_seq_len += past_key_value[0].shape[-2]
# Apply rotary embeddings
if position_embeddings is not None:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb(q, k, cos, sin)
k = repeat_kv(k, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
return q, k, v, kv_seq_len
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_value: Optional[Any] = None, # "legacy" cache approach
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36)
- Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_embeddings, past_key_value
)
if self.base_inference:
with torch.no_grad():
# 1. Compute "ground-truth" attention output and weights
y_true, _, _ = softmax_attention(q, k, v, causal=True)
y_true = (
y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
)
y_true = self.o_proj(y_true)
attn_weights = (None, None)
elif self.train_attention: # Distilling / learning attentions
# Note for now we assume no padding when distilling; attention masks only enforce causality
assert (
output_attentions is True
), f"When training feature maps, output_attentions should be True but is {output_attentions}"
with torch.no_grad():
# 1. Compute "ground-truth" attention output and weights
_y_true, attn_true, _ = softmax_attention(q, k, v, causal=True)
y_true = (
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
)
y_true = self.o_proj(y_true)
# 2. Compute "predicted" attention (just weights)
q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k)
y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True)
attn_weights = ( # type: ignore
(attn_pred, attn_true),
(y_pred, _y_true),
) # Save both attention weights so we can supervise.
else: # Finetuning
q, k = self.feature_map_q(q), self.feature_map_k(k)
# Apply prefill mask
if attention_mask is not None and q.shape[2] > 1:
if len(attention_mask.shape) == 4:
lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][
..., None
] # b, 1, k_len, 1
else:
lin_attn_mask = attention_mask.bool()[:, None, :, None] # b, 1, k_len, 1
k = k.masked_fill(~lin_attn_mask, 0)
if past_key_value is not None: # Initialize states
if len(past_key_value.kv_states) == self.layer_idx:
b, h, _, f = k.shape
past_key_value.kv_states.append(
torch.zeros(
b, h, f, self.head_dim, dtype=q.dtype, device=q.device
)
)
past_key_value.k_states.append(
torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device)
)
# Generating
if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None:
assert use_cache is True
kv_state, k_state = past_key_value.update(
k, v, self.layer_idx, accumulate_in_fp32=self.fp32_attention
)
if self.fp32_attention:
q = q.float()
y_true = (
torch.einsum("bhlf,bhfd->bhld", q, kv_state.float())
/ torch.einsum("bhlf,bhlf->bhl", q, k_state.float())[
..., None
]
).to(dtype=k.dtype)
else:
y_true = (
torch.einsum("bhlf,bhfd->bhld", q, kv_state)
/ torch.einsum("bhlf,bhlf->bhl", q, k_state)[..., None]
)
else:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
y_true, _, _ = linear_attention(
q, k, v, self.fp32_attention, self.eps
) # Ordinarily the states are ignored
past_key_value.update(
k.detach(),
v.detach(),
self.layer_idx,
accumulate_in_fp32=self.fp32_attention,
)
# doing some unnecessary recomputation here
else:
y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)
# Concatenate heads and apply output projection
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
attn_weights = None
return y_true, attn_weights
class LinearAttentionState(Cache):
"""
Handle the KV and K states for linear attention
- Adopts HF Transformers `past_key_values` convention
- Inherits from `Cache` class
- Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self) -> None:
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.kv_states: List[torch.Tensor] = []
self.k_states: List[torch.Tensor] = []
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""
Returns the sequence length of the cached states. A layer index can be optionally passed.
"""
if layer_idx is None:
raise ValueError("Layer index must not be None")
if len(self._seen_tokens_by_layer) <= layer_idx: # Initializing kv and k states
self._seen_tokens_by_layer.append(0)
return self._seen_tokens_by_layer[layer_idx]
def get_max_length(self) -> Optional[int]:
"""
Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
"""
return None
def get_usable_length(
self, new_seq_length: int, layer_idx: Optional[int] = 0
) -> int:
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
# Cache without size limit -> all cache is usable
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
# length, we will need to evict part of the cache (and thus not all cache is usable)
max_length = self.get_max_length()
previous_seq_length = self.get_seq_length(layer_idx)
if max_length is not None and previous_seq_length + new_seq_length > max_length:
return max_length - new_seq_length
return previous_seq_length
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: Optional[int] = None,
cache_kwargs: Optional[Any] = None,
accumulate_in_fp32: bool = True,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
if layer_idx is None:
raise ValueError("Layer index must not be None")
with torch.no_grad():
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
dtype = key_states.dtype
if accumulate_in_fp32:
key_states, value_states = key_states.float(), value_states.float()
kv_state = torch.einsum(
"bhlf,bhld->bhfd", key_states, value_states
).detach()
k_state = key_states.sum(
dim=-2, keepdim=True
).detach() # b, h, 1, f; note the 1
# Update the cache
if len(self.k_states) <= layer_idx: # Initializing kv and k states
print(
"if len(self.k_states) <= layer_idx: # Initializing kv and k states"
)
self.kv_states.append(kv_state.to(dtype))
self.k_states.append(k_state.to(dtype))
else:
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
dtype
)
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
dtype
)
self.kv_states[layer_idx] = kv_state
self.k_states[layer_idx] = k_state
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
return self.kv_states[layer_idx], self.k_states[layer_idx]
def to_legacy_cache(self):
"""Hack, but just return self"""
return self
def reorder_cache(self, beam_idx: torch.LongTensor):
"""
Reorders the cache for beam search, given the selected beam indices.
-> Copied from transformers/src/transformers/cache_utils.py
"""
raise NotImplementedError(
"Reordering cache not implemented for LinearAttentionState"
)
# -------------------
# feature map functions
# -------------------
def init_feature_map(name: str, mlp: nn.Module, **kwargs):
"""
Initialize feature map final activation for linear attention
"""
return FeatureMap(activation_name=name, mlp=mlp, **kwargs)
def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
"""
Initialize feature map final activation for linear attention
"""
if name == "softmax_dim" and fullspace:
return SoftmaxDim(**kwargs)
elif name == "softmax_dim" and not fullspace:
return SoftmaxDimHalfspace(**kwargs)
elif name == "exp_dim" and fullspace:
return Exp(**kwargs)
elif name == "exp_dim" and not fullspace:
return ExpHalfspace(**kwargs)
elif name == "pos_elu":
return PosELU(**kwargs)
elif name == "relu":
return ReLU(**kwargs)
else:
raise NotImplementedError
def init_learned_kernel(name: str, **kwargs):
"""
Initialize feature map MLP for linear attention
"""
if name == "untied_head_einsum":
return FeatureMapMLP(**kwargs)
elif name == "untied_head_adapter":
return FeatureMapAdapter(**kwargs)
else:
raise NotImplementedError
class FeatureMap(nn.Module):
"""
Final 'activation' of feature map. Can probably be combined with
`FeatureMapMLP` below
Full feature map is like f(xW + b)
-> This is the `f` part
"""
def __init__(
self,
activation_name: str,
head_dim_idx: int = -1,
eps: float = 1e-12,
mlp: Optional[nn.Module] = None,
fullspace: bool = True,
):
super().__init__()
self.head_dim_idx = head_dim_idx
self.eps = eps
self.mlp = mlp if mlp is not None else nn.Identity()
self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)
def forward(self, x: torch.Tensor, *mlp_args, **mlp_kwargs):
"""
Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
"""
return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x)
def q_map(self, *args, **kwargs):
"""
Use for inference in case q and k feature maps differ
"""
return self.forward(*args, **kwargs)
def k_map(self, *args, **kwargs):
"""
Use for inference in case q and k feature maps differ
"""
return self.forward(*args, **kwargs)
# -----------------------
# Feature map activations
# -----------------------
class FeatureMapAct(nn.Module):
"""
Base class for feature map activations
"""
def __init__(self, eps: float = 1e-12):
super().__init__()
self.eps = eps
def forward(self, x: torch.Tensor, *args, **kwargs):
"""
x.shape is (batch_size, n_heads, seq_len, head_dim)
"""
return x
class PosELU(FeatureMapAct):
"""
1 + ELU activation as in https://arxiv.org/abs/2006.16236
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return (1 + F.elu(x)).clamp(min=self.eps)
class ReLU(FeatureMapAct):
"""
ReLU activation as in https://arxiv.org/abs/2103.13076
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return F.relu(x).clamp(min=self.eps)
class SoftmaxDim(FeatureMapAct):
"""
Softmax activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return torch.cat(
[torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)], dim=-1
).clamp(min=self.eps)
class SoftmaxDimHalfspace(FeatureMapAct):
"""
Softmax activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return torch.softmax(x, dim=-1).clamp(min=self.eps)
class Exp(FeatureMapAct):
"""
Exp activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
x_max = torch.amax(x, dim=-1, keepdim=True)
x_min = torch.amin(x, dim=-1, keepdim=True)
return torch.cat([torch.exp(x - x_max), torch.exp(-x + x_min)], dim=-1).clamp(
min=self.eps
)
class ExpHalfspace(FeatureMapAct):
"""
Exp activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
x_max = torch.amax(x, dim=-1, keepdim=True)
return torch.exp(x - x_max).clamp(min=self.eps)
# ----------------
# Feature map MLPs
# ----------------
class FeatureMapMLP(nn.Module):
"""
Learnable MLP in feature map.
Full feature map is like f(xW + b)
-> This is the `W` and (optional) `b` part
"""
def __init__(
self,
num_heads: int,
head_dim: int, # input dim
feature_dim: int, # output dim
dtype: torch.dtype,
device: torch.device,
skip_connection: bool = False,
bias: bool = False,
zero_init: bool = False,
normal_init: bool = False,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.feature_dim = feature_dim
self.dtype = dtype
self.device = device
self.skip_connection = skip_connection
self.bias = bias
self.zero_init = zero_init
self.normal_init = normal_init
self.init_weights_()
if self.zero_init: # Zero-out weights or set as identity post-initialization
self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()
if self.normal_init:
with torch.no_grad():
nn.init.normal_(self.layer)
if self.skip_connection:
assertion_fail = f"If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}"
assert self.head_dim == self.feature_dim, assertion_fail
def init_weights_(self):
"""
Initialize (W)eights and (b)iases
"""
self.layer = nn.Parameter(
torch.zeros(
(self.num_heads, self.head_dim, self.feature_dim),
dtype=self.dtype,
device=self.device,
)
)
nn.init.kaiming_uniform_(self.layer)
if self.bias:
self.bias = nn.Parameter(
torch.zeros(
(1, self.num_heads, 1, 1), # self.feature_dim),
dtype=self.dtype,
device=self.device,
)
)
nn.init.kaiming_uniform_(self.bias)
else:
self.bias = 0.0 # hack
def zero_init_with_skip_(self):
"""
Initialize weights to zero matrix if skip connection
"""
with torch.no_grad():
nn.init.zeros_(self.layer)
def zero_init_(self):
"""
Initialize weights to identity matrix if no skip connection
"""
with torch.no_grad():
for i in range(self.layer.shape[0]):
try:
nn.init.eye_(self.layer[i])
except RuntimeError:
with torch.no_grad():
dtype = self.layer[i].dtype
weight = torch.eye(
*self.layer[i].shape,
requires_grad=self.layer[i].requires_grad,
device=self.layer[i].device,
)
self.layer[i] = weight.to(dtype=dtype)
def forward(self, x: torch.Tensor):
"""
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
"""
_x = torch.einsum("hdf,bhld->bhlf", self.layer, x) + self.bias
return x + _x if self.skip_connection else _x
class FeatureMapAdapter(FeatureMapMLP):
"""
Learnable Feature map with bottleneck adapter
as in https://arxiv.org/abs/1902.00751
We don't use but could be fun to try
"""
def __init__(self, hidden_dim: int, *args, **kwargs):
kwargs["skip_connection"] = True
kwargs["bias"] = True
kwargs["zero_init"] = True
self.hidden_dim = hidden_dim
super().__init__(*args, **kwargs)
def init_weights_(self):
"""
Initialize (W)eights and (b)iases
"""
kwargs = {"dtype": self.dtype, "device": self.device}
self.layer0 = nn.Parameter(
torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
)
self.layer1 = nn.Parameter(
torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
)
nn.init.kaiming_uniform_(self.layer0)
nn.init.kaiming_uniform_(self.layer1)
self.bias0 = nn.Parameter(
torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs)
)
self.bias1 = nn.Parameter(
torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs)
)
nn.init.kaiming_uniform_(self.bias0)
nn.init.kaiming_uniform_(self.bias1)
def zero_init_with_skip_(self):
with torch.no_grad():
nn.init.zeros_(self.layer0)
nn.init.zeros_(self.layer1)
nn.init.zeros_(self.bias0)
nn.init.zeros_(self.bias1)
def zero_init_(self):
raise NotImplementedError
def forward(self, x: torch.Tensor):
"""
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
-> Down-project, apply nonlinearity, up-project; add skip connection
"""
_x = torch.einsum("hde,bhld->bhle", self.layer0, x) + self.bias0
_x = F.relu(_x)
_x = torch.einsum("hef,bhle->bhlf", self.layer1, _x) + self.bias1
return x + _x if self.skip_connection else _x

View File

@@ -1,460 +0,0 @@
"""
Subquadratic attention combining sliding window and linear attentions
- Using "standard" sliding windows
- Didactically computes outputs with n^2 attention weights for now
- Copied + adapted from linear_window_attention_tk.py for single-file reference
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache
from .linear_attention import (
LinearAttentionState,
LolcatsLinearAttention,
softmax_attention,
)
# ----------------------
# Sliding window helpers
# ----------------------
def get_masks(
window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return masks for softmax and linear attention terms
-> 1 is include, 0 is ignore
"""
causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
k_len - q_len
)
linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
k_len - q_len - window_size
)
window_mask = causal_mask - linear_mask
# Return softmax mask (window), linear attention mask
# -> shapes broadcast over (b, h, q_len, k_len)
return window_mask[None, None, ...], linear_mask[None, None, ...]
def hybrid_attention_quadratic(
q: torch.Tensor,
k: torch.Tensor,
f_q: torch.Tensor,
f_k: torch.Tensor,
v: torch.Tensor,
window_factor: torch.Tensor,
linear_factor: torch.Tensor,
window_size: int,
kv_state: Optional[torch.Tensor] = None,
k_state: Optional[torch.Tensor] = None,
eps: float = 1e-12,
mask_value: float = -1e8,
):
"""
Hybrid attention combining sliding window and linear attentions
"""
mask_window, mask_linear = get_masks(
window_size, q.shape[-2], k.shape[-2], q.device
)
# 1. Sliding window (softmax attention)
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# 2. Under window (linear attention)
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
sum_ln = a_ln.sum(dim=-1, keepdim=True)
# 3. Combine
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
# Allow outputs to also depend on prior kv_state and k_state
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
if (
kv_state is not None and k_state is not None
): # Combine with prior kv_state and k_state
y += linear_factor * torch.einsum(
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
)
sum_ln += (
linear_factor
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
)
y = (y / (sum_sm + sum_ln)).to(q.dtype)
return y, a # attention weights only for the last chunk
# ---------------------
# Attention layer class
# ---------------------
class LolcatsSlidingWindowAttention(LolcatsLinearAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(
self,
window_size: int = 64,
decode_window_size: Optional[int] = None,
affine_attention_factors: bool = False,
init_window_factor: float = 0,
train_window_factor: bool = True,
state_grad_enabled: bool = False,
**kwargs,
):
self.window_size = window_size
self.decode_window_size = (
decode_window_size if decode_window_size is not None else window_size
)
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
super().__init__(**kwargs)
self.attention_type = kwargs["attention_type"] # 'hedgehog_llama_window_sw'
# Determine how we compute attentions
self.quadratic_attention = hybrid_attention_quadratic
self.attention_type = kwargs[
"attention_type"
] # 'hedgehog_long_llama_window_sw'
# Learnable factor for combining attentions
self.affine_attention_factors = affine_attention_factors
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
if train_window_factor:
self.window_factors = nn.Parameter(
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
)
else:
self.register_buffer(
"window_factors",
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
)
# Whether we use original flash attention 2 inference (use during attention transfer)
self.base_inference = False
self.state_grad_enabled = state_grad_enabled
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_ids, past_key_value
)
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
k
) # Have to do after repeat for grouped-query attn if we use same fmap
if self.train_attention:
# 1. Compute "ground-truth" attention output and weights
with torch.no_grad():
_y_true, a_true = softmax_attention(q, k, v)[:2]
y_true = (
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
)
y_true = self.o_proj(y_true)
# 2. Compute "predicted" attention outputs
# compute attn weights under sliding window
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
y_pred, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
attn_weights = ((a_pred, a_true), (y_pred, _y_true))
else:
attn_weights = None
# attention_mask = None # For now this is always True
if past_key_value is None: # Regular training
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
attn_weights = a_pred
else:
past_key_value.window_size = self.decode_window_size
if (
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
): # Generating
assert use_cache is True
_kv = past_key_value.update_for_decoding(
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
)
k_cache, v_cache, f_kv_state, f_k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
# Softmax attention terms
a_sm = torch.einsum(
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
) * (k.shape[-1] ** -0.5)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# Combine with linear attention terms
y_true = torch.einsum(
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
) + linear_factors * torch.einsum(
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
)
sum_ln = (
linear_factors
* torch.einsum(
"bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
)[..., None]
)
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
else: # Stateful training
try:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
except IndexError:
kv_state, k_state = None, None
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, _ = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
kv_state=kv_state,
k_state=k_state,
)
# Save and update KV cache and states
# past_key_value.update(k, v.detach(), self.layer_idx,
# fmap_key_states=f_k.detach(),
# accumulate_in_fp32=True)
past_key_value.update(
k,
v,
self.layer_idx,
fmap_key_states=f_k,
accumulate_in_fp32=True,
)
# Concatenate heads and apply output projection
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
return y_true, attn_weights, past_key_value
class LinearAttentionSlidingWindowCache(LinearAttentionState):
"""
Class for `past_key_values`
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self, window_size: int = 64) -> None:
super().__init__()
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.kv_states: List[torch.Tensor] = []
self.k_states: List[torch.Tensor] = []
# Account for sliding windows
self.decode_kv_states: List[torch.Tensor] = []
self.decode_k_states: List[torch.Tensor] = []
self.k_cache: List[torch.Tensor] = []
self.v_cache: List[torch.Tensor] = []
self.window_size = window_size
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: Optional[int] = None,
cache_kwargs: Optional[Any] = None,
accumulate_in_fp32: bool = False,
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
grad_enabled: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update KV, K states; and KV cache during training
- For decoding, use `self.decode_kv_states` to keep track of KV states
up to sliding window terms
- For (chunked) training, use `self.kv_states` to keep track of KV states
up to end of sequence
- Likewise for `self.decode_k_states` and `self.k_states`
"""
if fmap_key_states is None:
raise ValueError("fmap_key_states must not be None")
if layer_idx is None:
raise ValueError("Layer index must not be None")
with torch.set_grad_enabled(grad_enabled):
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
dtype = key_states.dtype
if accumulate_in_fp32:
# key_states = key_states.float()
fmap_key_states = fmap_key_states.float()
value_states = value_states.float()
# Decoding KV state (KV terms up to last window_size)
decode_kv_state = torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, : -self.window_size],
value_states[:, :, : -self.window_size],
)
# KV state
kv_state = decode_kv_state + torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, -self.window_size :],
value_states[:, :, -self.window_size :],
)
# shape is b, h, 1, f; note the 1
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
dim=-2, keepdim=True
)
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
dim=-2, keepdim=True
)
# Update the cache
if len(self.k_states) <= layer_idx: # Initializing kv and k states
self.kv_states.append(kv_state.to(dtype))
self.k_states.append(k_state.to(dtype))
self.decode_kv_states.append(decode_kv_state.to(dtype))
self.decode_k_states.append(decode_k_state.to(dtype))
self.k_cache.append(key_states[:, :, -self.window_size :, :])
self.v_cache.append(
value_states[:, :, -self.window_size :, :].to(dtype)
)
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
else:
# Update kv and k states recurrently
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
dtype
)
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
dtype
)
self.kv_states[layer_idx] = kv_state
self.k_states[layer_idx] = k_state
decode_kv_state = (
self.decode_kv_states[layer_idx].to(kv_state.dtype)
+ decode_kv_state
).to(dtype)
decode_k_state = (
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
).to(dtype)
self.decode_kv_states[layer_idx] = decode_kv_state
self.decode_k_states[layer_idx] = decode_k_state
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
return self.kv_states[layer_idx], self.k_states[layer_idx]
def update_for_decoding(
self,
keys: torch.Tensor,
values: torch.Tensor,
layer_idx: int,
feature_map_k: Callable,
dtype: torch.dtype,
):
"""
Update the decoding KV and K states, and KV cache, during decodeing
"""
with torch.no_grad():
k_cache = self.k_cache[layer_idx]
v_cache = self.v_cache[layer_idx]
if k_cache.shape[-2] < self.window_size: # build window-size cache
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
else:
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
# else:
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
k_state = feature_map_k(k_cache[:, :, :1, :])
v_state = v_cache[:, :, :1, :]
kv_state = torch.einsum(
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
).to(
dtype
) # b, h, f, d
self.decode_kv_states[layer_idx] += kv_state
self.decode_k_states[layer_idx] += k_state
self.k_cache[layer_idx] = torch.cat(
[k_cache[:, :, 1:, :], keys], dim=-2
)
self.v_cache[layer_idx] = torch.cat(
[v_cache[:, :, 1:, :], values], dim=-2
)
if layer_idx == 0:
self._seen_tokens += keys.shape[-2]
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
return (
self.k_cache[layer_idx],
self.v_cache[layer_idx],
self.decode_kv_states[layer_idx],
self.decode_k_states[layer_idx],
)

View File

@@ -1,685 +0,0 @@
"""
Subquadratic attention combining sliding window and linear attentions
- Using "standard" sliding windows
- Didactically computes outputs with n^2 attention weights for now
- Copied + adapted from linear_window_attention_tk.py for single-file reference
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
import logging
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache
try:
from transformers.modeling_flash_attention_utils import _flash_attention_forward
except ModuleNotFoundError:
_flash_attention_forward = None # Transformers v4.36
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
# Causal linear attention dot product CUDA kernel from fast-transformers
from .linear_attention import (
LinearAttentionState,
LolcatsLinearAttention,
causal_dot_product,
)
LOG = logging.getLogger(__name__)
# ----------------------
# Sliding window helpers
# ----------------------
def get_masks(
window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return masks for softmax and linear attention terms
-> 1 is include, 0 is ignore
"""
causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
max(k_len - q_len, 0)
)
linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
max(k_len - q_len, 0) - window_size
)
window_mask = causal_mask - linear_mask
# Return softmax mask (window), linear attention mask
# -> shapes broadcast over (b, h, q_len, k_len)
return window_mask[None, None, ...], linear_mask[None, None, ...]
def hybrid_attention_quadratic(
q: torch.Tensor,
k: torch.Tensor,
f_q: torch.Tensor,
f_k: torch.Tensor,
v: torch.Tensor,
window_factor: torch.Tensor,
linear_factor: torch.Tensor,
window_size: int,
kv_state: Optional[torch.Tensor] = None,
k_state: Optional[torch.Tensor] = None,
eps: float = 1e-12,
mask_value: float = -1e8,
):
"""
Hybrid attention combining sliding window and linear attentions
"""
mask_window, mask_linear = get_masks(
window_size, q.shape[-2], k.shape[-2], q.device
)
# 1. Sliding window (softmax attention)
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# 2. Under window (linear attention)
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
sum_ln = a_ln.sum(dim=-1, keepdim=True)
# 3. Combine
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
# Allow outputs to also depend on prior kv_state and k_state
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
if (
kv_state is not None and k_state is not None
): # Combine with prior kv_state and k_state
y += linear_factor * torch.einsum(
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
)
sum_ln += (
linear_factor
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
)
y = (y / (sum_sm + sum_ln)).to(q.dtype)
return y, a # attention weights only for the last chunk
# ------------------------------
# Hybrid window attention linear
# ------------------------------
def under_window_linear_attention(
f_q: torch.Tensor,
f_k: torch.Tensor,
v: torch.Tensor,
window_size: int,
linear_factor: torch.Tensor,
eps: float = 1e-12,
):
"""Compute hybrid window attention dot product with linear complexity in q_len"""
dtype = f_q.dtype
w = window_size
f_k = F.pad(f_k, (0, 0, w, 0), value=0)[:, :, :-w, :]
v = F.pad(v, (0, 0, w, 0), value=0)[:, :, :-w, :]
qkv = linear_factor * causal_dot_product(
f_q.contiguous().to(dtype=torch.float32),
f_k.contiguous().to(dtype=torch.float32),
v.contiguous().to(dtype=torch.float32),
).to(dtype=dtype)
sum_f_k = f_k.float().cumsum(dim=2).to(dtype=dtype)
sum_qk = linear_factor * torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None]
sum_qk[sum_qk == 0] += eps
return qkv, sum_qk
def sliding_window_softmax_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
window_size: int,
window_factor: torch.Tensor,
mask_value: float = -1e8,
):
"""
Compute sliding window softmax attention without materializing
O(seq_len^2) attention weights
"""
d = q.shape[-1]
# Compute windows for keys
window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
k = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
v = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
# Compute windowed_softmax(qk); causal in its construction
a_sm = torch.einsum("bhld,bhldw->bhlw", q, k) * (d**-0.5)
a_sm[a_sm == 0] = -torch.finfo(
q.dtype
).max # heuristic for zeroing out padding above
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
return torch.einsum("bhlw,bhldw->bhld", a_sm, v), sum_sm
# return torch.einsum('bhlw,bhldw->bhld', torch.softmax(qk, dim=-1), v)
def hybrid_attention_linear(
q: torch.Tensor,
k: torch.Tensor,
f_q: torch.Tensor,
f_k: torch.Tensor,
v: torch.Tensor,
window_factor: Optional[torch.Tensor] = None,
linear_factor: Optional[torch.Tensor] = None,
window_size: int = 64,
kv_state: Optional[torch.Tensor] = None,
k_state: Optional[torch.Tensor] = None,
eps: float = 1e-12,
mask_value: float = -1e8,
):
"""
Alternative hybrid attention combining sliding window and linear attentions
-> Uses O(n) memory if n is sequence length by padding and unfolding windows
"""
# window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
if window_factor is None:
raise ValueError("window_factor must be provided")
if linear_factor is None:
raise ValueError("linear_factor must be provided")
# 1. Sliding window (softmax attention)
with torch.no_grad():
qkv_sm, sum_qk_sm = sliding_window_softmax_attention(
q, k, v, window_size, window_factor, mask_value
)
# 2. Under window (linear attention)
qkv_ln, sum_qk_ln = under_window_linear_attention(
f_q, f_k, v, window_size, linear_factor, eps
)
# 3. Combine
y = (qkv_sm + qkv_ln) / (sum_qk_sm + sum_qk_ln)
return y, None
# ---------------------
# Attention layer class
# ---------------------
class LolcatsLinearSlidingWindowAttention(LolcatsLinearAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(
self,
window_size: int = 64,
decode_window_size: Optional[int] = None,
affine_attention_factors: bool = False,
init_window_factor: float = 0,
train_window_factor: bool = True,
state_grad_enabled: bool = False,
**kwargs,
):
self.window_size = window_size
self.decode_window_size = (
decode_window_size if decode_window_size is not None else window_size
)
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
super().__init__(**kwargs)
# Determine how we compute attentions
self.linear_attention = hybrid_attention_linear
self.attention_type = "lolcats_llama_window_sw"
# Learnable factor for combining attentions
self.affine_attention_factors = affine_attention_factors
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
if train_window_factor:
self.window_factors = nn.Parameter(
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
)
else:
self.register_buffer(
"window_factors",
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
)
# Whether we use original flash attention 2 inference (use during attention transfer)
self.base_inference = False
self.state_grad_enabled = state_grad_enabled
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
if self.train_attention and self.base_inference:
with torch.no_grad():
_y_true = flash_attention_2(
self, # self.base_attn,
hidden_states=hidden_states,
attention_mask=None,
position_ids=position_ids,
past_key_value=None,
output_attentions=False,
use_cache=False,
)[0]
# _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
y_true = _y_true.reshape(b, l, -1).contiguous()
y_true = self.o_proj(y_true)
# layer_io = (hidden_states, _y_true) # hack
layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack
return y_true, layer_io, None
else:
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_ids, past_key_value
)
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
k
) # Have to do after repeat for grouped-query attn if we use same fmap
attn_weights = None
# attention_mask = None # For now this is always True
if past_key_value is None: # Regular training
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, a_pred = self.linear_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
attn_weights = a_pred
else:
past_key_value.window_size = self.decode_window_size
if (
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
): # Generating
assert use_cache is True
_kv = past_key_value.update_for_decoding(
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
)
k_cache, v_cache, f_kv_state, f_k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
# Softmax attention terms
a_sm = torch.einsum(
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
) * (k.shape[-1] ** -0.5)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# Combine with linear attention terms
y_true = torch.einsum(
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
) + linear_factors * torch.einsum(
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
)
sum_ln = (
linear_factors
* torch.einsum(
"bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
)[..., None]
)
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
else: # Stateful training
try:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
except IndexError:
kv_state, k_state = None, None
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, _ = self.linear_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
kv_state=kv_state,
k_state=k_state,
)
# Save and update KV cache and states
# past_key_value.update(k, v.detach(), self.layer_idx,
# fmap_key_states=f_k.detach(),
# accumulate_in_fp32=True)
past_key_value.update(
k,
v,
self.layer_idx,
fmap_key_states=f_k,
accumulate_in_fp32=True,
)
# Concatenate heads and apply output projection
_y_true = y_true.transpose(1, 2).contiguous()
y_true = self.o_proj(_y_true.view(b, l, self.hidden_size))
if self.train_attention:
attn_weights = _y_true # flash_attn outputs are shape (b, l, h, d)
return y_true, attn_weights, past_key_value
class LinearAttentionSlidingWindowCache(LinearAttentionState):
"""
Class for `past_key_values`
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self, window_size: int = 64) -> None:
super().__init__()
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.kv_states: List[torch.Tensor] = []
self.k_states: List[torch.Tensor] = []
# Account for sliding windows
self.decode_kv_states: List[torch.Tensor] = []
self.decode_k_states: List[torch.Tensor] = []
self.k_cache: List[torch.Tensor] = []
self.v_cache: List[torch.Tensor] = []
self.window_size = window_size
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: Optional[int] = None,
cache_kwargs: Optional[Any] = None,
accumulate_in_fp32: bool = False,
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
grad_enabled: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update KV, K states; and KV cache during training
- For decoding, use `self.decode_kv_states` to keep track of KV states
up to sliding window terms
- For (chunked) training, use `self.kv_states` to keep track of KV states
up to end of sequence
- Likewise for `self.decode_k_states` and `self.k_states`
"""
if fmap_key_states is None:
raise ValueError("fmap_key_states must not be None")
if layer_idx is None:
raise ValueError("Layer index must not be None")
with torch.set_grad_enabled(grad_enabled):
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
dtype = key_states.dtype
if accumulate_in_fp32:
# key_states = key_states.float()
fmap_key_states = fmap_key_states.float()
value_states = value_states.float()
# Decoding KV state (KV terms up to last window_size)
decode_kv_state = torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, : -self.window_size],
value_states[:, :, : -self.window_size],
)
# KV state
kv_state = decode_kv_state + torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, -self.window_size :],
value_states[:, :, -self.window_size :],
)
# shape is b, h, 1, f; note the 1
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
dim=-2, keepdim=True
)
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
dim=-2, keepdim=True
)
# Update the cache
if len(self.k_states) <= layer_idx: # Initializing kv and k states
self.kv_states.append(kv_state.to(dtype))
self.k_states.append(k_state.to(dtype))
self.decode_kv_states.append(decode_kv_state.to(dtype))
self.decode_k_states.append(decode_k_state.to(dtype))
self.k_cache.append(key_states[:, :, -self.window_size :, :])
self.v_cache.append(
value_states[:, :, -self.window_size :, :].to(dtype)
)
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
else:
# Update kv and k states recurrently
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
dtype
)
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
dtype
)
self.kv_states[layer_idx] = kv_state
self.k_states[layer_idx] = k_state
decode_kv_state = (
self.decode_kv_states[layer_idx].to(kv_state.dtype)
+ decode_kv_state
).to(dtype)
decode_k_state = (
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
).to(dtype)
self.decode_kv_states[layer_idx] = decode_kv_state
self.decode_k_states[layer_idx] = decode_k_state
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
return self.kv_states[layer_idx], self.k_states[layer_idx]
def update_for_decoding(
self,
keys: torch.Tensor,
values: torch.Tensor,
layer_idx: int,
feature_map_k: Callable,
dtype: torch.dtype,
):
"""
Update the decoding KV and K states, and KV cache, during decodeing
"""
with torch.no_grad():
k_cache = self.k_cache[layer_idx]
v_cache = self.v_cache[layer_idx]
if k_cache.shape[-2] < self.window_size: # build window-size cache
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
else:
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
# else:
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
k_state = feature_map_k(k_cache[:, :, :1, :])
v_state = v_cache[:, :, :1, :]
kv_state = torch.einsum(
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
).to(
dtype
) # b, h, f, d
self.decode_kv_states[layer_idx] += kv_state
self.decode_k_states[layer_idx] += k_state
self.k_cache[layer_idx] = torch.cat(
[k_cache[:, :, 1:, :], keys], dim=-2
)
self.v_cache[layer_idx] = torch.cat(
[v_cache[:, :, 1:, :], values], dim=-2
)
if layer_idx == 0:
self._seen_tokens += keys.shape[-2]
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
return (
self.k_cache[layer_idx],
self.v_cache[layer_idx],
self.decode_kv_states[layer_idx],
self.decode_k_states[layer_idx],
)
# -----------------
# Flash Attention 2
# -----------------
def flash_attention_2(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
):
"""
Wrapper for LlamaFlashAttention2
Copied and modified from HF Transformers v4.36 and v4.43 implementations
- (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
- (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
"""
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
try: # As in Transformers v4.36
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
except Exception: # As in Transformers v4.39
cos, sin = self.rotary_emb(key_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
LOG.debug(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
if getattr(self, "_flash_attention_forward", False):
attn_output = self._flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=True,
)
else:
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=0, # dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=True,
)
return attn_output, past_key_value

View File

@@ -1,24 +0,0 @@
"""
LoLCATs attention combining sliding window and linear attentions
- Using standard sliding window arrangement
- Training over long sequences with fixed memory with recurrent view
- During attention transfer, use Flash Attention to compute softmax attention outputs
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
from .linear_window_attention_sw import hybrid_attention_quadratic
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(self, remove_base_attn=True, **kwargs):
# keep self.base_attn for Flash Attention inference
super().__init__(remove_base_attn=True, **kwargs)
self.quadratic_attention = hybrid_attention_quadratic

View File

@@ -1,466 +0,0 @@
"""
Subquadratic attention combining sliding window and linear attentions
- Using the TK "terracing" arrangement
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
import math
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache
from .linear_attention import (
LinearAttentionState,
LolcatsLinearAttention,
softmax_attention,
)
# ----------------------
# Sliding window helpers
# ----------------------
def get_masks(
window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return masks for softmax and linear attention terms
-> 1 is include, 0 is ignore
"""
win_len = window_size
m = math.ceil(max(q_len, k_len) / window_size)
# Creates an n x n mask where n = window_size^2
mask = torch.block_diag(
*[
torch.ones(
(win_len, win_len),
)
]
* m
)
mask += torch.roll(mask, -win_len, -1) # this adds the terracing
if mask.shape[0] > q_len:
mask = mask[-q_len:]
if mask.shape[1] > k_len:
mask = mask[:, -k_len:]
# Return softmax mask (window), linear attention mask
mask = mask[None, None, ...] # b, h, q_len, k_len
return (
torch.tril(mask).to(device=device, dtype=torch.int),
torch.tril(1 - mask).to(device=device, dtype=torch.int),
)
def hybrid_attention_quadratic(
q: torch.Tensor,
k: torch.Tensor,
f_q: torch.Tensor,
f_k: torch.Tensor,
v: torch.Tensor,
window_factor: torch.Tensor,
linear_factor: torch.Tensor,
window_size: int,
kv_state: Optional[torch.Tensor] = None,
k_state: Optional[torch.Tensor] = None,
eps: float = 1e-12,
mask_value: float = -1e8,
):
"""
Hybrid attention combining sliding window and linear attentions
"""
mask_window, mask_linear = get_masks(
window_size, q.shape[-2], k.shape[-2], q.device
)
# 1. Sliding window (softmax attention)
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# 2. Under window (linear attention)
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
sum_ln = a_ln.sum(dim=-1, keepdim=True)
# 3. Combine
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
# Allow outputs to also depend on prior kv_state and k_state
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
if (
kv_state is not None and k_state is not None
): # Combine with prior kv_state and k_state
y += linear_factor * torch.einsum(
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
)
sum_ln += (
linear_factor
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
)
y = (y / (sum_sm + sum_ln)).to(q.dtype)
return y, a # attention weights only for the last chunk
# ---------------------
# Attention layer class
# ---------------------
class LolcatsTKWindowAttention(LolcatsLinearAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(
self,
window_size: int = 64,
decode_window_size: Optional[int] = None,
affine_attention_factors: bool = False,
init_window_factor: float = 0,
train_window_factor: bool = True,
state_grad_enabled: bool = False,
**kwargs,
):
self.window_size = window_size
self.decode_window_size = (
decode_window_size if decode_window_size is not None else window_size
)
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
super().__init__(**kwargs)
self.attention_type = kwargs["attention_type"] # 'hedgehog_llama_window_tk'
# Determine how we compute attentions
self.quadratic_attention = hybrid_attention_quadratic
self.attention_type = kwargs[
"attention_type"
] # 'hedgehog_long_llama_window_tk'
# Learnable factor for combining attentions
self.affine_attention_factors = affine_attention_factors
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
if train_window_factor:
self.window_factors = nn.Parameter(
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
)
else:
self.register_buffer(
"window_factors",
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
)
# Whether we use original flash attention 2 inference (use during attention transfer)
self.base_inference = False
self.state_grad_enabled = state_grad_enabled
self.window_factor = self.window_factors # legacy naming support
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_ids, past_key_value
)
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
k
) # Have to do after repeat for grouped-query attn if we use same fmap
if self.train_attention:
# 1. Compute "ground-truth" attention output and weights
with torch.no_grad():
_y_true, a_true = softmax_attention(q, k, v)[:2]
y_true = (
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
)
y_true = self.o_proj(y_true)
# 2. Compute "predicted" attention outputs
# compute attn weights under sliding window
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
y_pred, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
attn_weights = ((a_pred, a_true), (y_pred, _y_true))
else:
attn_weights = None
# attention_mask = None # For now this is always True
if past_key_value is None: # Regular training
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
attn_weights = a_pred
else:
past_key_value.window_size = self.decode_window_size
if (
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
): # Generating
assert use_cache is True
_kv = past_key_value.update_for_decoding(
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
)
k_cache, v_cache, f_kv_state, f_k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
# Softmax attention terms
a_sm = torch.einsum(
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
) * (k.shape[-1] ** -0.5)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# Combine with linear attention terms
y_true = torch.einsum(
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
) + linear_factors * torch.einsum(
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
)
sum_ln = (
linear_factors
* torch.einsum(
"bhld,bhnd->bhl", f_q.float(), f_k_state.float()
)[..., None]
)
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
else: # Stateful training
try:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
except IndexError:
kv_state, k_state = None, None
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, _ = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
kv_state=kv_state,
k_state=k_state,
)
# Save and update KV cache and states
# past_key_value.update(k, v.detach(), self.layer_idx,
# fmap_key_states=f_k.detach(),
# accumulate_in_fp32=True)
past_key_value.update(
k,
v,
self.layer_idx,
fmap_key_states=f_k,
accumulate_in_fp32=True,
)
# Concatenate heads and apply output projection
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
return y_true, attn_weights, past_key_value
class LinearAttentionTKWindowCache(LinearAttentionState):
"""
Class for `past_key_values`
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self, window_size: int = 64) -> None:
super().__init__()
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.kv_states: List[torch.Tensor] = []
self.k_states: List[torch.Tensor] = []
# Account for sliding windows
self.decode_kv_states: List[torch.Tensor] = []
self.decode_k_states: List[torch.Tensor] = []
self.k_cache: List[torch.Tensor] = []
self.v_cache: List[torch.Tensor] = []
self.window_size = window_size
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: Optional[int] = None,
cache_kwargs: Optional[Any] = None,
accumulate_in_fp32: bool = False,
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
grad_enabled: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update KV, K states; and KV cache during training
- For decoding, use `self.decode_kv_states` to keep track of KV states
up to sliding window terms
- For (chunked) training, use `self.kv_states` to keep track of KV states
up to end of sequence
- Likewise for `self.decode_k_states` and `self.k_states`
"""
if fmap_key_states is None:
raise ValueError("fmap_key_states should not be None")
if layer_idx is None:
raise ValueError("layer_idx should not be None")
with torch.set_grad_enabled(grad_enabled):
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
dtype = key_states.dtype
if accumulate_in_fp32:
# key_states = key_states.float()
fmap_key_states = fmap_key_states.float()
value_states = value_states.float()
# Decoding KV state (KV terms up to last window_size)
decode_kv_state = torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, : -self.window_size],
value_states[:, :, : -self.window_size],
)
# KV state
kv_state = decode_kv_state + torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, -self.window_size :],
value_states[:, :, -self.window_size :],
)
# shape is b, h, 1, f; note the 1
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
dim=-2, keepdim=True
)
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
dim=-2, keepdim=True
)
# Update the cache
if len(self.k_states) <= layer_idx: # Initializing kv and k states
self.kv_states.append(kv_state.to(dtype))
self.k_states.append(k_state.to(dtype))
self.decode_kv_states.append(decode_kv_state.to(dtype))
self.decode_k_states.append(decode_k_state.to(dtype))
self.k_cache.append(key_states[:, :, -self.window_size :, :])
self.v_cache.append(
value_states[:, :, -self.window_size :, :].to(dtype)
)
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
else:
# Update kv and k states recurrently
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
dtype
)
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
dtype
)
self.kv_states[layer_idx] = kv_state
self.k_states[layer_idx] = k_state
decode_kv_state = (
self.decode_kv_states[layer_idx].to(kv_state.dtype)
+ decode_kv_state
).to(dtype)
decode_k_state = (
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
).to(dtype)
self.decode_kv_states[layer_idx] = decode_kv_state
self.decode_k_states[layer_idx] = decode_k_state
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
return self.kv_states[layer_idx], self.k_states[layer_idx]
def update_for_decoding(
self,
keys: torch.Tensor,
values: torch.Tensor,
layer_idx: int,
feature_map_k: Callable,
dtype: torch.dtype,
):
"""
Update the decoding KV and K states, and KV cache, during decodeing
"""
with torch.no_grad():
k_cache = self.k_cache[layer_idx]
v_cache = self.v_cache[layer_idx]
if k_cache.shape[-2] < self.window_size: # build window-size cache
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
else:
k_state = feature_map_k(k_cache[:, :, :1, :])
v_state = v_cache[:, :, :1, :]
kv_state = torch.einsum(
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
).to(
dtype
) # b, h, f, d
self.decode_kv_states[layer_idx] += kv_state
self.decode_k_states[layer_idx] += k_state
self.k_cache[layer_idx] = torch.cat(
[k_cache[:, :, 1:, :], keys], dim=-2
)
self.v_cache[layer_idx] = torch.cat(
[v_cache[:, :, 1:, :], values], dim=-2
)
if layer_idx == 0:
self._seen_tokens += keys.shape[-2]
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
return (
self.k_cache[layer_idx],
self.v_cache[layer_idx],
self.decode_kv_states[layer_idx],
self.decode_k_states[layer_idx],
)

View File

@@ -1,219 +0,0 @@
"""
LoLCATs + ThunderKittens linear attention + sliding window for generation
"""
import logging
from typing import Any, Callable, List, Optional
import torch
import torch.nn.functional as F
from .linear_attention import LinearAttentionState
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
LOG = logging.getLogger(__name__)
try:
from thunderkittens import hedgehog as tk_window_hedgehog_attention
LOG.debug("Successfully imported ThunderKittens for TK window attention")
except ImportError:
LOG.debug("Failed to import ThunderKittens for TK window attention")
class LolcatsWindowAttentionTKGen(LolcatsTKWindowLongAttention):
def __init__(self, *args, window_size: int = 64, **kwargs):
super().__init__(*args, **kwargs)
self.train_attention = False
self.base_inference = False
self.window_size = 64 # hard-coded support for TK kernel
self.decode_window_size = 64
b, h, l, d = 1, 32, 8192, 128
self.y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device="cuda")
self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device="cuda")
self.k_state = torch.zeros(b, h, d, dtype=torch.float32, device="cuda")
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Any] = None, # “legacy” cache approach
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
assert (
past_key_value is not None
), "past_key_value must be provided for generation"
assert (
self.train_attention is False
), "train_attention is not supported for generation"
assert (
self.base_inference is False
), "base_inference is not supported for generation"
assert use_cache is True, "use_cache must be True for generation"
past_key_value.window_size = self.decode_window_size
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_ids, past_key_value
)
if q.shape[2] == 1 and kv_seq_len > 1: # Generating after prefill
f_q = self.feature_map_q(q)
_kv = past_key_value.update_for_decoding(
k, v, self.layer_idx, self.feature_map_k
)
k_cache, v_cache, kv_state, k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
# Softmax attention terms
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
k.shape[-1] ** -0.5
)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# Combine with linear attention terms
y_true = torch.einsum(
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
) + linear_factors * torch.einsum(
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
)
sum_ln = (
linear_factors
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[
..., None
]
)
self.y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
else: # Process prefill
# Use TK-implemented linear + terrace window attention
b, h, l, d = q.shape
device = q.device
# tk.hedgehog arguments
# y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device=device)
# kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device=device)
# k_state = torch.zeros(b, h, d, dtype=torch.float32, device=device)
betas = F.sigmoid(self.window_factors[0, :, 0, 0].to(dtype=torch.float32))
alphas = (
1 - betas
if self.affine_attention_factors
else torch.ones(betas.shape, dtype=torch.float32, device=device)
)
q_map = self.feature_map_q.mlp.layer
k_map = self.feature_map_k.mlp.layer
# Saves outputs to y_pred, k_state, kv_state, where we fuse:
# 1. f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
# 2. y_pred = attention(q, k, f_q, f_k, v) # b, h, l, d
# 3. kv_state = torch.einsum(bhlf,bhld->bhfd,
# f_k[:, :, :-self.window_size],
# v[:, :, :-self.window_size]) # b, h, f, d
# 4. k_state = f_k[:, :, :-self.window_size].sum(dim=-2) # b, h, d
tk_window_hedgehog_attention(
q.contiguous(),
k.contiguous(),
v.contiguous(),
self.y_true,
self.k_state,
self.kv_state,
q_map,
k_map,
alphas,
betas,
)
past_key_value.update_with_kv(
self.kv_state, self.k_state.unsqueeze(-2), k, v, self.layer_idx
)
# Concatenate heads and apply output projection
y_true = self.y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
return y_true, None, past_key_value
class LinearAttentionTKWindowGenerationCache(LinearAttentionState):
"""
Class for `past_key_values`
-> Alternative to KV cache; here we only maintain a “KV state” and “K state”
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self, window_size: int = 64) -> None:
super().__init__()
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.window_size = window_size
self.decode_kv_states: List[torch.Tensor] = []
self.decode_k_states: List[torch.Tensor] = []
self.k_cache: List[torch.Tensor] = []
self.v_cache: List[torch.Tensor] = []
def update_with_kv(
self,
kv_state: torch.Tensor,
k_state: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_idx: int,
):
"""
Update the cache with new KV and K states
"""
if layer_idx == 0:
self._seen_tokens += k.shape[2]
self._seen_tokens_by_layer.append(k.shape[2])
# Initialize KV and K states
if len(self.decode_k_states) <= layer_idx:
self.decode_kv_states.append(kv_state)
self.decode_k_states.append(k_state)
else: # Update KV and K states
self.decode_kv_states[layer_idx] = (
self.decode_kv_states[layer_idx] + kv_state
)
self.decode_k_states[layer_idx] = self.decode_k_states[layer_idx] + k_state
self.k_cache.append(k[:, :, -self.window_size :, :])
self.v_cache.append(v[:, :, -self.window_size :, :])
def update_for_decoding(
self, k: torch.Tensor, v: torch.Tensor, layer_idx: int, feature_map_k: Callable
):
"""
Update the cache for decoding
"""
k_cache = self.k_cache[layer_idx]
v_cache = self.v_cache[layer_idx]
k_state = feature_map_k(k_cache[:, :, :1, :])
v_state = v_cache[:, :, :1, :]
kv_state = torch.einsum("bhlf,bhld->bhfd", k_state.float(), v_state.float()).to(
k.dtype
)
self.decode_kv_states[layer_idx] += kv_state
self.decode_k_states[layer_idx] += k_state
self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], k], dim=-2)
self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], v], dim=-2)
if layer_idx == 0:
self._seen_tokens += k.shape[-2]
self._seen_tokens_by_layer[layer_idx] += k.shape[-2]
return (
self.k_cache[layer_idx],
self.v_cache[layer_idx],
self.decode_kv_states[layer_idx],
self.decode_k_states[layer_idx],
)

View File

@@ -1,306 +0,0 @@
"""
LoLCATs attention combining sliding window and linear attentions
- Using the TK "terracing" arrangement
- Training over long sequences with fixed memory with recurrent view
- During attention transfer, use Flash Attention to compute softmax attention outputs
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
import logging
from typing import Optional
import torch
import torch.nn.functional as F
from transformers.cache_utils import Cache
try:
from transformers.modeling_flash_attention_utils import _flash_attention_forward
except ModuleNotFoundError:
_flash_attention_forward = None # Transformers v4.36
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from .linear_attention import softmax_attention
from .linear_window_attention_tk import LolcatsTKWindowAttention
LOG = logging.getLogger(
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"
)
class LolcatsTKWindowLongAttention(LolcatsTKWindowAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(self, remove_base_attn=True, **kwargs):
# keep self.base_attn for Flash Attention inference
super().__init__(remove_base_attn=True, **kwargs)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
if self.train_attention and self.base_inference:
with torch.no_grad():
# LOG.debug(hidden_states.shape)
_y_true = flash_attention_2(
self, # self.base_attn,
hidden_states=hidden_states,
attention_mask=None,
position_ids=position_ids,
past_key_value=None,
output_attentions=False,
# output_hidden_states=False,
use_cache=False,
)[0]
# _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
y_true = _y_true.reshape(b, l, -1).contiguous()
y_true = self.o_proj(y_true)
layer_io = (hidden_states, _y_true) # hack
# layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack
return y_true, layer_io, None
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_ids, past_key_value
)
f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
# attention_mask = None # For now this is always True
if past_key_value is None: # Regular training
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
y_pred, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
else:
past_key_value.window_size = self.decode_window_size
if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating
assert use_cache is True
_kv = past_key_value.update_for_decoding(
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
)
k_cache, v_cache, f_kv_state, f_k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
k.shape[-1] ** -0.5
)
# a_sm = torch.softmax(a_sm, dim=-1)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
y_pred = torch.einsum(
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
) + linear_factors * torch.einsum(
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
)
sum_ln = (
linear_factors
* torch.einsum("bhlf,bhnf->bhl", f_q.float(), f_k_state.float())[
..., None
]
)
y_pred = (y_pred / (sum_sm + sum_ln)).to(q.dtype)
else: # Stateful training
if (
self.state_grad_enabled
and self.layer_idx == 0
and position_ids is not None
):
LOG.debug(
f"\n position_ids: [{position_ids[0, 0]}, {position_ids[0, -1]}]"
)
LOG.debug(
f"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}"
)
try:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
except IndexError:
kv_state, k_state = None, None
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_pred, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
kv_state=kv_state,
k_state=k_state,
)
# Save and update KV cache and states
# past_key_value.update(k, v.detach(), self.layer_idx,
# fmap_key_states=f_k.detach(),
# accumulate_in_fp32=True)
past_key_value.update(
k, v, self.layer_idx, fmap_key_states=f_k, accumulate_in_fp32=True
)
# Concatenate heads and apply output projection
_y_pred = y_pred.transpose(1, 2).contiguous()
y_pred = self.o_proj(_y_pred.view(b, l, self.hidden_size))
if self.train_attention:
with torch.no_grad():
a_true = softmax_attention(q, k, None, causal=True)[1]
attn_weights = (_y_pred, (a_pred, a_true))
else:
attn_weights = _y_pred # flash_attn outputs are shape (b, l, h, d)
return y_pred, attn_weights, past_key_value
# -----------------
# Flash Attention 2
# -----------------
def flash_attention_2(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
):
"""
Wrapper for LlamaFlashAttention2
Copied and modified from HF Transformers v4.36 and v4.43 implementations
- (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
- (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
"""
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
try: # As in Transformers v4.36
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
except Exception: # As in Transformers v4.39
cos, sin = self.rotary_emb(key_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
LOG.debug(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
if getattr(self, "_flash_attention_forward", False):
attn_output = self._flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=True,
)
else:
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=0, # dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=True,
)
return attn_output, past_key_value

View File

@@ -1,361 +0,0 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
"""Linear LLaMA model implementation."""
import logging
from functools import partial
from typing import Any, Optional
from torch import nn
from tqdm import tqdm
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
)
from .configuration_linear_llama import LinearLlamaConfig
LOG = logging.getLogger(__name__)
class LinearLlamaDecoderLayer(LlamaDecoderLayer):
"""
Modified LlamaDecoderLayer that uses LinearAttention instead of standard attention.
"""
def __init__(self, config: LinearLlamaConfig, layer_idx: int):
super().__init__(config, layer_idx)
# Replace the attention layer with our custom attention
self.self_attn = convert_llama_attention(
layer=self, attention_config=config.attention_config
)
class LinearLlamaModel(LlamaModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LinearLlamaDecoderLayer`]
Args:
config: LinearLlamaConfig
"""
config_class = LinearLlamaConfig
base_model_prefix = "linear_llama"
def __init__(self, config: LinearLlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
LinearLlamaDecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
class LinearLlamaForCausalLM(LlamaForCausalLM):
"""
Linear LLaMA model for causal language modeling.
"""
config_class = LinearLlamaConfig
base_model_prefix = "linear_llama"
def __init__(self, config):
super().__init__(config)
self.model = LinearLlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
@classmethod
def from_llama(
cls,
model: LlamaForCausalLM,
config: LinearLlamaConfig,
train_attention: bool = False,
remove_base_attn: bool = True,
) -> "LinearLlamaForCausalLM":
"""
Initialize a LinearLlamaForCausalLM from a LlamaModel
"""
if config is None:
raise ValueError("Missing config")
# initialize a new model with config
new_model = cls(config=config)
# remove the default model and lm_head
del new_model.model
del new_model.lm_head
# load converted model, lm_head, and vocab_size from llama model
new_model.model = convert_attention(
model.model,
attention_config=config.attention_config,
train_attention=train_attention,
remove_base_attn=remove_base_attn,
)
new_model.lm_head = model.lm_head
new_model.vocab_size = model.vocab_size
return new_model
def toggle_attention(self, train: bool = True):
"""
Toggle attention to be trainable or not
"""
toggle_attention(self.model, train=train)
def remove_base_attention(self):
"""
Remove base attention after distillation
"""
remove_base_attention(self.model)
def convert_attention(
model: nn.Module,
attention_config: dict,
train_attention: bool = False,
remove_base_attn: bool = True,
):
"""
Call to convert all attention layers
"""
# Get the layers to convert if provided
softmax_attns = attention_config.get("softmax_attentions", [])
# Get the attention to convert to
attention_type = attention_config.get("attention_type")
if attention_type != "softmax":
layers = traverse_layers(model)
for layer_idx, layer in enumerate(
tqdm(layers, desc="Converting attentions...")
):
if layer_idx not in softmax_attns:
layer.self_attn = convert_llama_attention(
layer,
attention_config,
layers,
train_attention,
remove_base_attn,
)
layer.self_attn.converted = True
else:
# Freeze any preserved softmax attention layers
for p in layer.parameters():
p.requires_grad = False
else:
LOG.info(
f"-> attention_config.attention_type is {attention_type}; not converting attentions"
)
return model
def toggle_attention(llama_model: nn.Module, train: bool = False):
"""
Make attentions trainable if train is True
-> Set train_attention = False when finetuning
"""
for layer in traverse_layers(llama_model):
layer.self_attn.train_attention = train
return llama_model
def remove_base_attention(llama_model: nn.Module):
"""
Remove teacher attention after distillation (if we keep it)
"""
for layer in traverse_layers(llama_model):
if getattr(layer.self_attn, "base_attn", False):
del layer.self_attn.base_attn
return llama_model
def traverse_layers(model: nn.Module, verbose: bool = False):
"""
Return list of model layers
"""
try:
layers = model.model.layers
if verbose:
LOG.info("-> Loading from model.model.layers")
except AttributeError as e: # if base model
if verbose:
LOG.info(e)
try:
layers = model.layers
if verbose:
LOG.info("-> Loading from model.layers")
except AttributeError as e1: # If we make a PEFT model
if verbose:
LOG.info(e1)
layers = model.base_model.model.model.layers
if verbose:
LOG.info("-> Loading from model.base_model.model.model.layers")
return layers
def convert_llama_attention(
layer: nn.Module,
attention_config: dict,
layers: Optional[list[nn.Module]] = None, # list of layers
train_attention: bool = False,
remove_base_attn: bool = True,
):
"""
Converts a single layer's attention layer as specified by attention_config
"""
return get_attention(**attention_config)(
base_attn=layer.self_attn,
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
max_layer_idx=len(layers) - 1 if layers else None,
train_attention=train_attention,
remove_base_attn=remove_base_attn,
)
def get_attention(attention_type: str, **kwargs):
"""
Get the linear attention class; either purely linear or linear with sliding window
-> 'linear' == 'lolcats_llama'
-> 'linear and sliding_window' == 'lolcats_llama_window_*'
"""
kwargs["attention_type"] = attention_type
if attention_type == "lolcats_llama":
from .linear_attention import LolcatsLinearAttention
return partial(LolcatsLinearAttention, **kwargs)
elif attention_type == "lolcats_llama_window_tk":
from .linear_window_attention_tk import LolcatsTKWindowAttention
return partial(LolcatsTKWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw":
from .linear_window_attention_sw import LolcatsSlidingWindowAttention
return partial(LolcatsSlidingWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw_linear":
from .linear_window_attention_sw_linear import (
LolcatsLinearSlidingWindowAttention,
)
return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
# Experimental chunked linear attentions below
elif attention_type == "lolcats_long_llama_window_tk":
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
return partial(LolcatsTKWindowLongAttention, **kwargs)
elif attention_type == "lolcats_long_llama_window_sw":
from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
return partial(LolcatsSlidingWindowLongAttention, **kwargs)
# TK generation build (requires Thunderkittens)
elif attention_type == "lolcats_llama_window_tk_gen":
from .linear_window_attention_tk_gen import LolcatsWindowAttentionTKGen
return partial(LolcatsWindowAttentionTKGen, **kwargs)
else:
LOG.info(f"-> attention_type {attention_type} not handled... returning None")
return None
def get_attention_cache(attention_type: str, past_key_values: Any = None):
"""
Determine how we store past keys and values when generating
"""
if attention_type is None:
return past_key_values
# LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
elif "lolcats_llama_window_tk_gen" in attention_type:
from .linear_window_attention_tk_gen import (
LinearAttentionTKWindowGenerationCache,
)
return LinearAttentionTKWindowGenerationCache()
elif "llama_window_tk" in attention_type:
from .linear_window_attention_tk import LinearAttentionTKWindowCache
return LinearAttentionTKWindowCache()
elif "llama_window_sw" in attention_type:
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache()
elif "llama_window_sw_linear" in attention_type:
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache()
# TK generation build (requires Thunderkittens)
elif attention_type == "lolcats_llama_window_tk_gen":
from .linear_window_attention_tk_gen import (
LinearAttentionTKWindowGenerationCache,
)
return LinearAttentionTKWindowGenerationCache()
elif "softmax" in attention_type:
return past_key_values
else:
from .linear_attention import LinearAttentionState
return LinearAttentionState()
def register_linear_llama():
"""
Register Linear LLaMA model with the Transformers library.
"""
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
AutoConfig.register("linear_llama", LinearLlamaConfig)
AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)
# registering for auto classes to save files
LinearLlamaConfig.register_for_auto_class("AutoConfig")
LinearLlamaModel.register_for_auto_class("AutoModel")
LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")

View File

@@ -1,118 +0,0 @@
"""
Custom trainer class for distilling attentions ("attention transfer"). Can substitute for Hugging Face trainer.
In this implementation we support using either just the softmax attention outputs, or the softmax attention weights.
"""
from typing import Any
from torch import Tensor, nn, tensor
from axolotl.core.trainers.base import AxolotlTrainer
class DistillAttentionXentMSETrainer(AxolotlTrainer):
"""
Custom trainer class for distilling attentions.
- We compute and store the attention outputs and/or weights for each head and layer,
for both the "teacher" softmax attentions and "student" learnable subquadratic attentions
- We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights)
"""
def __init__(
self,
model: nn.Module,
mse_factor: float = 1e3,
xent_factor: float = 0,
**kwargs: Any,
):
super().__init__(model=model, **kwargs)
self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
self.criterion_mse = nn.MSELoss(reduction="mean")
self.mse_factor = mse_factor
self.xent_factor = xent_factor
# self.compute_loss_backprop = False # Whether we backprop in self.compute_loss # NOTE: this config seems unnecessary
self.model_accepts_loss_kwargs = False # added to combat explosive loss
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, Tensor],
return_outputs=False,
num_items_in_batch=None,
) -> tuple[Tensor, dict]:
"""
Attention distillation ("attention transfer")
- For each layer and head, get attentions and train to
minimize some combo of MSE and cross-entropy loss
"""
# alias inputs to data
data = inputs
device = model.device
# Filter out labels
inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}
# set num_items_in_batch
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
# Forward pass
outputs = model(**inputs, output_attentions=True, use_cache=False)
outputs = outputs.get("attentions")
# Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]
# n_layers x (predicted_attns, true_attns)
# predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len)
loss_mse = tensor(0.0, device=device)
loss_xent = tensor(0.0, device=device)
n_layers = 0 # Number of layers to distill
softmax_layers = []
for layer_idx, attns in enumerate(outputs):
if attns is not None:
if len(attns) != 2:
attns = attns.cpu()
else:
if self.xent_factor > 0:
# Cross-entropy loss
a_pred, a_true = attns[0]
a_pred = a_pred.clamp(
min=1e-12
).log() # nn.CrossEntropy assumes unnormalized logits
k_len = a_true.shape[-1] # batch, n_heads, q_len, k_len
# Compute mean cross-entropy over all queries
a_pred = a_pred.contiguous().view(-1, k_len)
a_true = a_true.contiguous().view(-1, k_len)
loss_xent += self.criterion_xent(a_pred, a_true)
if self.mse_factor > 0:
loss_mse += self.criterion_mse(*attns[1])
n_layers += 1
else:
softmax_layers.append(layer_idx)
if n_layers > 0:
loss_xent = loss_xent / n_layers * self.xent_factor
loss_mse = loss_mse / n_layers * self.mse_factor
loss = loss_xent + loss_mse
if "position_ids" in data:
outputs = {
"loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
"loss_mse": loss_mse if self.mse_factor > 0 else 0,
"input_len": data["position_ids"].shape[1],
"position_ids": data["position_ids"][0].detach().cpu().numpy(),
"mse_factor": self.mse_factor,
"xent_factor": self.xent_factor,
}
else:
outputs = {
"loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
"loss_mse": loss_mse if self.mse_factor > 0 else 0,
"mse_factor": self.mse_factor,
"xent_factor": self.xent_factor,
}
return (loss, outputs) if return_outputs else loss

View File

@@ -0,0 +1,308 @@
"""
fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/35128
"""
import inspect
import logging
from transformers import LlamaForCausalLM, Trainer
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
"""
PATCHED_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
else:
loss = self.compute_loss(model, inputs)
"""
ORIGINAL_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
"""
PATCHED_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
"""
def get_training_step_code() -> str:
training_step = inspect.getsource(
Trainer.training_step # pylint: disable=protected-access
)
return training_step
def check_training_step_is_patchable() -> bool:
training_step = get_training_step_code()
training_step, _ = detab_code(training_step)
return ORIGINAL_CONTEXT_CODE in training_step
def patch_training_step_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
training_step = get_training_step_code()
except OSError:
return
Trainer._original_training_step = training_step # pylint: disable=protected-access
training_step, _ = detab_code(training_step)
if ORIGINAL_CONTEXT_CODE not in training_step:
return
# assert (
# ORIGINAL_CONTEXT_CODE in training_step
# ), "Original training_step code not found"
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
training_step = training_step.replace(
"def training_step(",
"def _fixed_training_step(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_step:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step")
Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
)
def get_model_forward_code() -> str:
forward = inspect.getsource(
LlamaForCausalLM.forward # pylint: disable=protected-access
)
return forward
def check_forward_is_patchable() -> bool:
forward = get_model_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_LLAMA_FCLM_CODE in forward
def patch_forward_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
forward = get_model_forward_code()
except OSError:
return
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
if ORIGINAL_LLAMA_FCLM_CODE not in forward:
return
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
forward = forward.replace(
"def forward(",
"def _fixed_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward")
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)
ORIGINAL_TRAINER_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
PATCHED_TRAINER_CODE = """
disable_deepspeed_no_sync = (
self.accelerator.distributed_type == DistributedType.DEEPSPEED
# and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
)
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_deepspeed_0_16_x():
"""
monkeypatch for fixing the training loop for deepspeed GA
see https://github.com/huggingface/transformers/pull/35157
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)
def patch_flash_attention_forward():
"""
monkeypatch for fixing the forward pass for flash attention to ignore num_items_in_batch
"""
import transformers.modeling_flash_attention_utils
def proxy_flash_attention_forward(*args, **kwargs):
kwargs.pop("num_items_in_batch", None)
return _flash_attention_forward(*args, **kwargs)
transformers.modeling_flash_attention_utils._flash_attention_forward = ( # pylint: disable=protected-access
proxy_flash_attention_forward
)
transformers.models.llama.modeling_llama._flash_attention_forward = ( # pylint: disable=protected-access
proxy_flash_attention_forward
)

View File

@@ -1,67 +0,0 @@
"""
see https://github.com/huggingface/transformers/pull/35834
"""
import logging
from functools import partial
from typing import Optional
import torch
logger = logging.getLogger(__name__)
def fixed_fa_peft_integration_check(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
target_dtype: Optional[torch.dtype] = None,
preferred_dtype: Optional[torch.dtype] = None,
):
"""
PEFT usually casts the layer norms in float32 for training stability reasons
therefore the input hidden states gets silently casted in float32. Hence, we need
cast them back in float16 / bfloat16 just to be sure everything works as expected.
This might slowdown training & inference so it is recommended to not cast the LayerNorms!
Args:
query (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value (`torch.Tensor`):
Input value states to be passed to Flash Attention API
target_dtype (`torch.dtype`, *optional*):
The dtype to convert the attention tensors to. Conversion can be ignored by
not providing the target dtype.
preferred_dtype (`torch.dtype`, *optional*):
The preferred dtype to convert the attention tensors to regardless of the
target dtype.
"""
if target_dtype is None and preferred_dtype is None:
return query, key, value
if preferred_dtype and target_dtype != preferred_dtype:
target_dtype = preferred_dtype
# check if any of query, key, or value are in float32. If so, cast them back to target dtype.
if any(module.dtype == torch.float32 for module in [query, key, value]):
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query = query.to(target_dtype)
key = key.to(target_dtype)
value = value.to(target_dtype)
return query, key, value
def patch_fa_peft_integration():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils.fa_peft_integration_check = partial(
fixed_fa_peft_integration_check, preferred_dtype=None
)

View File

@@ -1,116 +0,0 @@
"""
Module for stepwise datasets, typically including a prompt and reasoning traces,
and (optionally) per-step, or per-prompt-trace labels for reward modelling.
"""
from itertools import chain
from typing import Dict, List, Optional, Union
from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.prompt_tokenizers import IGNORE_INDEX
from axolotl.utils.dict import DictDefault
class StepwiseSupervisedPromptTokenizingStrategy:
"""
Tokenizing strategy for supervised stepwise datasets, typically used for COT-reasoning.
These datasets should include the following columns:
- prompt: the prompt text
- completions: a list of `n` completion steps
- labels: a list of `n` labels indicating the "correctness" of each step
"""
def __init__(
self,
tokenizer,
sequence_len: int = 2048,
step_separator: str = "\n",
max_completion_length: Optional[int] = None,
train_on_last_step_only: bool = False,
):
self.tokenizer = tokenizer
self.sequence_len = sequence_len
self.step_separator = step_separator
self.max_completion_length = max_completion_length
self.train_on_last_step_only = train_on_last_step_only
def tokenize_prompt(
self, prompt: Dict[str, Union[str, List[str]]]
) -> BatchEncoding:
# Inspired by TRL's PRMTRainer
# https://github.com/huggingface/trl/blob/ed7de87dc766478c024b68f12530d1b0e7c3ff23/trl/trainer/prm_trainer.py#L206
prompt_ids = self.tokenizer(prompt["prompt"], add_special_tokens=False)[
"input_ids"
]
completions_ids = [
self.tokenizer(completion, add_special_tokens=False)["input_ids"]
for completion in prompt["completions"]
]
# Handle labels
if self.train_on_last_step_only:
labels = [IGNORE_INDEX] * (len(prompt["labels"]) - 1) + [
int(prompt["labels"][-1])
]
else:
labels = [int(label) for label in prompt["labels"]]
# Add step separators
separator_ids = self.tokenizer.encode(
self.step_separator, add_special_tokens=False
)
completions_ids = [completion + separator_ids for completion in completions_ids]
# Create step-wise labels
labels = [
[IGNORE_INDEX] * (len(completion) - 1) + [label] # type: ignore
for completion, label in zip(completions_ids, labels)
]
# Join all steps
completion_ids = list(chain(*completions_ids))
labels = list(chain(*labels)) # type: ignore
# Handle max lengths
if self.max_completion_length:
completion_ids = completion_ids[: self.max_completion_length]
labels = labels[: self.max_completion_length]
# Add BOS token if model has one
if self.tokenizer.bos_token_id is not None:
prompt_ids = [self.tokenizer.bos_token_id] + prompt_ids
# Combine prompt and completion
input_ids = prompt_ids + completion_ids
full_labels = [IGNORE_INDEX] * len(prompt_ids) + labels
# Apply max sequence length
if self.sequence_len:
input_ids = input_ids[: self.sequence_len]
full_labels = full_labels[: self.sequence_len]
return {
"input_ids": input_ids,
"labels": full_labels,
"attention_mask": [1] * len(input_ids),
}
@property
def supports_batched(self):
return False
def load(
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
ds_cfg: DictDefault,
) -> StepwiseSupervisedPromptTokenizingStrategy:
return StepwiseSupervisedPromptTokenizingStrategy(
tokenizer,
cfg.sequence_len,
step_separator=ds_cfg.get("step_separator", "\n"),
max_completion_length=ds_cfg.max_completion_length,
train_on_last_step_only=ds_cfg.get("train_on_last_step_only", False),
)

View File

@@ -141,9 +141,7 @@ def train(
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if (
cfg.local_rank == 0 and not cfg.use_ray
): # ray workers don't have access to this signal
if cfg.local_rank == 0:
def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
@@ -261,7 +259,7 @@ def train(
.decode("utf-8")
}
if cfg.datasets is not None:
if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model:
if cfg.rl is not None or cfg.reward_model:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import gc
import logging
import math
import os
import traceback
from shutil import copyfile
@@ -829,6 +830,13 @@ class SaveModelCallback(TrainerCallback):
# Save
if state.global_step >= state.max_steps:
control.should_save = True
elif (
args.save_strategy == IntervalStrategy.STEPS
and state.save_steps < 1.0
and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0
):
# workaround to save model on fractional save_steps
control.should_save = True
def on_train_end( # pylint: disable=unused-argument
self, args, state, control, **kwargs
@@ -846,12 +854,6 @@ class GCCallback(TrainerCallback):
def on_step_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
if state.global_step % self.gc_steps == 0:
torch.cuda.empty_cache()
gc.collect()
def on_epoch_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
torch.cuda.empty_cache()
gc.collect()

Some files were not shown because too many files have changed in this diff Show More