Compare commits
51 commits: chat-datas ... feat/linea

| SHA1 |
|---|
| 13d458d0ae |
| ebd406af1d |
| caa49a9d7d |
| c15ea6b956 |
| 578fa764c8 |
| 0e6efaa10c |
| c4cb622590 |
| 0f82bd2d18 |
| 49746b184f |
| 9e1c4de13c |
| 2d5f692fc0 |
| 2fd5c45c2e |
| 8294e6218f |
| 253dcdd0cf |
| 4cc60df876 |
| 2bc7833a4e |
| 1fb8d86396 |
| adeefc1991 |
| fb88269dcb |
| 433cf4a8c7 |
| 0b7b58c8be |
| 81731adc1d |
| a1715aa317 |
| ce0cd470f7 |
| 311d6eb5da |
| 158330ab60 |
| 80e1468b8d |
| a20f17689b |
| 78ce268848 |
| d425d5d3c3 |
| cf17649ef3 |
| 6f294c3d8d |
| 6f713226dd |
| 1063d82b51 |
| ac471a697a |
| 8779997ba5 |
| 268543a3be |
| 54dd7abfc1 |
| c071a530f7 |
| c015a76a23 |
| 067b442596 |
| 0b52f06227 |
| 887513285d |
| 20620771f1 |
| 6086162488 |
| b2774af66c |
| 74f9782fc3 |
| 8a7a0b07dc |
| 8fb72cbc0b |
| bb9d4102c4 |
| af727eedf7 |
.github/CONTRIBUTING.md (2 changes)

@@ -15,7 +15,7 @@ First of all, thank you for your interest in contributing to axolotl! We appreci
- [Commit Messages](#commit-messages)
- [Additional Resources](#additional-resources)

## Code of Conductcode
## Code of Conduct

All contributors are expected to adhere to our [Code of Conduct](CODE_OF_CONDUCT.md). Please read it before participating in the axolotl community.

.github/workflows/base.yml (12 changes)

@@ -22,18 +22,6 @@ jobs:
      fail-fast: false
      matrix:
        include:
          - cuda: "121"
            cuda_version: 12.1.1
            cudnn_version: 8
            python_version: "3.10"
            pytorch: 2.3.1
            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
          - cuda: "121"
            cuda_version: 12.1.1
            cudnn_version: 8
            python_version: "3.11"
            pytorch: 2.3.1
            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
          - cuda: "124"
            cuda_version: 12.4.1
            cudnn_version: ""

.github/workflows/main.yml (26 changes)

@@ -15,16 +15,6 @@ jobs:
      fail-fast: false
      matrix:
        include:
          - cuda: 121
            cuda_version: 12.1.1
            python_version: "3.10"
            pytorch: 2.3.1
            axolotl_extras: mamba-ssm
          - cuda: 121
            cuda_version: 12.1.1
            python_version: "3.11"
            pytorch: 2.3.1
            axolotl_extras: mamba-ssm
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"

@@ -82,16 +72,6 @@ jobs:
    strategy:
      matrix:
        include:
          - cuda: 121
            cuda_version: 12.1.1
            python_version: "3.10"
            pytorch: 2.3.1
            axolotl_extras:
          - cuda: 121
            cuda_version: 12.1.1
            python_version: "3.11"
            pytorch: 2.3.1
            axolotl_extras:
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"

@@ -145,10 +125,10 @@ jobs:
    strategy:
      matrix:
        include:
          - cuda: 121
            cuda_version: 12.1.1
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"
            pytorch: 2.3.1
            pytorch: 2.4.1
            axolotl_extras:
    runs-on: axolotl-gpu-runner
    steps:

.github/workflows/multi-gpu-e2e.yml (6 changes)

@@ -20,12 +20,6 @@ jobs:
      fail-fast: false
      matrix:
        include:
          - cuda: 121
            cuda_version: 12.1.1
            python_version: "3.11"
            pytorch: 2.3.1
            axolotl_extras:
            num_gpus: 2
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"

.github/workflows/nightlies.yml (22 changes)

@@ -12,17 +12,6 @@ jobs:
      fail-fast: false
      matrix:
        include:
          - cuda: 121
            cuda_version: 12.1.1
            python_version: "3.10"
            pytorch: 2.3.1
            axolotl_extras:
          - cuda: 121
            cuda_version: 12.1.1
            python_version: "3.11"
            pytorch: 2.3.1
            axolotl_extras:
            is_latest: true
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"

@@ -76,17 +65,6 @@ jobs:
    strategy:
      matrix:
        include:
          - cuda: 121
            cuda_version: 12.1.1
            python_version: "3.10"
            pytorch: 2.3.1
            axolotl_extras:
          - cuda: 121
            cuda_version: 12.1.1
            python_version: "3.11"
            pytorch: 2.3.1
            axolotl_extras:
            is_latest: true
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"

.github/workflows/tests-nightly.yml (9 changes)

@@ -26,7 +26,7 @@ jobs:
      max-parallel: 2
      matrix:
        python_version: ["3.10", "3.11"]
        pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
        pytorch_version: ["2.4.1", "2.5.1"]
        exclude:
          - python_version: "3.10"
            pytorch_version: "2.4.1"

@@ -98,13 +98,6 @@ jobs:
      fail-fast: false
      matrix:
        include:
          - cuda: 121
            cuda_version: 12.1.1
            python_version: "3.10"
            pytorch: 2.3.1
            num_gpus: 1
            axolotl_extras: mamba-ssm
            nightly_build: "true"
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"

||||
90
.github/workflows/tests.yml
vendored
90
.github/workflows/tests.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
python_version: ["3.10", "3.11"]
|
||||
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
|
||||
pytorch_version: ["2.4.1", "2.5.1"]
|
||||
exclude:
|
||||
- python_version: "3.10"
|
||||
pytorch_version: "2.4.1"
|
||||
@@ -204,52 +204,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==0.71.8 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.tests
|
||||
|
||||
docker-e2e-tests:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 90
|
||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.1
|
||||
python_version: "3.10"
|
||||
pytorch: 2.3.1
|
||||
num_gpus: 1
|
||||
axolotl_extras: mamba-ssm
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
@@ -274,6 +228,48 @@ jobs:
|
||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.tests
|
||||
|
||||
docker-e2e-tests:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 90
|
||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==0.71.8 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
|
||||
@@ -19,7 +19,7 @@ repos:
    hooks:
      - id: isort
  - repo: https://github.com/PyCQA/flake8
    rev: 6.0.0
    rev: 6.1.0
    hooks:
      - id: flake8
  - repo: https://github.com/PyCQA/pylint

README.md (775 changes)
@@ -1,8 +1,8 @@
<p align="center">
    <picture>
        <source media="(prefers-color-scheme: dark)" srcset="image/axolotl_logo_digital_white.svg">
        <source media="(prefers-color-scheme: light)" srcset="image/axolotl_logo_digital_black.svg">
        <img alt="Axolotl" src="image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
        <source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_white.svg">
        <source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg">
        <img alt="Axolotl" src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
    </picture>
</p>

@@ -19,235 +19,99 @@
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
</p>

Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
Axolotl is a tool designed to streamline post-training for various AI models.
Post-training refers to any modifications or additional training performed on
pre-trained models - including full model fine-tuning, parameter-efficient tuning (like
LoRA and QLoRA), supervised fine-tuning (SFT), instruction tuning, and alignment
techniques. With support for multiple model architectures and training configurations,
Axolotl makes it easy to get started with these techniques.

Axolotl is designed to work with YAML config files that contain everything you need to
preprocess a dataset, train or fine-tune a model, run model inference or evaluation,
and much more.

Features:

- Train various Huggingface models such as llama, pythia, falcon, mpt
- Supports fullfinetune, lora, qlora, relora, and gptq
- Customize configurations using a simple yaml file or CLI overwrite
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Integrated with [xformers](https://github.com/facebookresearch/xformers), flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb, mlflow or Comet
- And more!

<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
## 🚀 Quick Start

<table>
<tr>
<td>
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.10
- PyTorch ≥2.4.1

## Table of Contents
- [Axolotl](#axolotl)
  - [Table of Contents](#table-of-contents)
  - [Quickstart ⚡](#quickstart-)
    - [Edge Builds](#edge-builds-)
    - [Axolotl CLI Usage](#axolotl-cli-usage)
  - [Badge ❤🏷️](#badge-️)
  - [Contributing 🤝](#contributing-)
  - [Sponsors 🤝❤](#sponsors-)
  - [Axolotl supports](#axolotl-supports)
  - [Advanced Setup](#advanced-setup)
    - [Environment](#environment)
      - [Docker](#docker)
      - [Conda/Pip venv](#condapip-venv)
      - [Cloud GPU](#cloud-gpu)
      - [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
        - [LambdaLabs](#lambdalabs)
        - [GCP](#gcp)
      - [Windows](#windows)
      - [Mac](#mac)
      - [Google Colab](#google-colab)
      - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
      - [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
    - [Dataset](#dataset)
    - [Config](#config)
      - [All Config Options](#all-config-options)
    - [Train](#train)
      - [Preprocess dataset](#preprocess-dataset)
      - [Multi-GPU](#multi-gpu)
        - [DeepSpeed](#deepspeed)
        - [FSDP](#fsdp)
        - [FSDP + QLoRA](#fsdp--qlora)
        - [Weights \& Biases Logging](#weights--biases-logging)
        - [Special Tokens](#special-tokens)
        - [Liger Kernel](#liger-kernel)
    - [Inference Playground](#inference-playground)
    - [Merge LORA to base](#merge-lora-to-base)
    - [Common Errors 🧰](#common-errors-)
      - [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training)
    - [Debugging Axolotl](#debugging-axolotl)
  - [Need help? 🙋](#need-help-)
### Installation

</td>
<td>

<div align="center">
<img src="image/axolotl_symbol_digital_white.svg" alt="axolotl" width="160">
<div>
<p>
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
</p>
<p>
Go ahead and Axolotl questions!!
</p>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
<img alt="PyTest Status" src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
</div>
</div>

</td>
</tr>
</table>

## Quickstart ⚡

Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.

**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.

```bash
```shell
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]

# download examples and optionally deepspeed configs to the local path
# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL

# finetune using lora
axolotl train examples/llama-3/lora-1b.yml
```

### Edge Builds 🏎️
Other installation approaches are described [here](https://axolotl-ai-cloud.github.io/axolotl/docs/installation.html).

If you're looking for the latest features and updates between releases, you'll need to install
from source.
### Your First Fine-tune

```bash
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```

### Axolotl CLI Usage
We now support a new, more streamlined CLI using [click](https://click.palletsprojects.com/en/stable/).

```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/llama-3/lora-1b.yml

# finetune lora
axolotl train examples/llama-3/lora-1b.yml

# inference
axolotl inference examples/llama-3/lora-1b.yml \
    --lora-model-dir="./outputs/lora-out"

# gradio
axolotl inference examples/llama-3/lora-1b.yml \
    --lora-model-dir="./outputs/lora-out" --gradio

# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
```

We've also added a new command for fetching `examples` and `deepspeed_configs` to your
local machine. This will come in handy when installing `axolotl` from PyPI.

```bash
# Fetch example YAML files (stores in "examples/" folder)
```shell
# Fetch axolotl examples
axolotl fetch examples

# Fetch deepspeed config files (stores in "deepspeed_configs/" folder)
axolotl fetch deepspeed_configs

# Optionally, specify a destination folder
# Or, specify a custom path
axolotl fetch examples --dest path/to/folder

# Train a model using LoRA
axolotl train examples/llama-3/lora-1b.yml
```

### Legacy Usage
<details>
That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/getting-started.html) for a more detailed walkthrough.

<summary>Click to Expand</summary>
## ✨ Key Features

While the Axolotl CLI is the preferred method for interacting with axolotl, we
still support the legacy `-m axolotl.cli.*` usage.
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, and more
- **Easy Configuration**: Simple YAML files to control your training setup
- **Performance Optimizations**: Flash Attention, xformers, multi-GPU training
- **Flexible Dataset Handling**: Use various formats and custom datasets
- **Cloud Ready**: Run on cloud platforms or local hardware

```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/llama-3/lora-1b.yml
## 📚 Documentation

# finetune lora
accelerate launch -m axolotl.cli.train examples/llama-3/lora-1b.yml
- [Installation Options](https://axolotl-ai-cloud.github.io/axolotl/docs/installation.html) - Detailed setup instructions for different environments
- [Configuration Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html) - Full configuration options and examples
- [Dataset Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) - Supported formats and how to use them
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions

# inference
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
    --lora_model_dir="./outputs/lora-out"
## 🤝 Getting Help

# gradio
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
    --lora_model_dir="./outputs/lora-out" --gradio
- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support
- Check out our [Examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/) directory
- Read our [Debugging Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/debugging.html)
- Need dedicated support? Please contact [✉️wing@axolotl.ai](mailto:wing@axolotl.ai) for options

# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
```
## 🌟 Contributing

</details>
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.

## Badge ❤🏷️

Building something cool with Axolotl? Consider adding a badge to your model card.

```markdown
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
```

[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)

## Sponsors 🤝❤

If you love axolotl, consider sponsoring the project by reaching out directly to [wing@axolotl.ai](mailto:wing@axolotl.ai).

---

- [Modal](https://modal.com/) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.

---

## Contributing 🤝

Please read the [contributing guide](./.github/CONTRIBUTING.md)

Bugs? Please check the [open issues](https://github.com/axolotl-ai-cloud/axolotl/issues/bug) else create a new Issue.

PRs are **greatly welcome**!

Please run the quickstart instructions followed by the below to set up the env:
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install

# test
pytest tests/

# optional: run against all files
pre-commit run --all-files
```

Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.

<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors">
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
</a>

## Axolotl supports
## Supported Models

|             | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
@@ -272,523 +136,16 @@ Thanks to all of our contributors to date. Help drive open source AI progress fo
❌: not supported
❓: untested

## Advanced Setup
## ❤️ Sponsors

### Environment
Thank you to our sponsors who help make Axolotl possible:

#### Docker
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) - Modal lets you run
jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale,
fine-tune large language models, run protein folding simulations, and much more.

```bash
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)

Or run on the current files for development:
## 📜 License

```sh
docker compose up -d
```

>[!Tip]
> If you want to debug axolotl or prefer to use Docker as your development environment, see the [debugging guide's section on Docker](docs/debugging.qmd#debugging-with-docker).

<details>

<summary>Docker advanced</summary>

A more powerful Docker command to run would be this:

```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-latest
```

It additionally:
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
* The `--privileged` flag gives all capabilities to the container.
* The `--shm-size 10g` argument increases the shared memory size. Use this if you see `exitcode: -7` errors using deepspeed.

[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)

</details>

#### Conda/Pip venv
1. Install python >=**3.10**

2. Install pytorch stable https://pytorch.org/get-started/locally/

3. Install Axolotl along with python dependencies
```bash
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Huggingface to use gated models/datasets.
```bash
huggingface-cli login
```
Get the token at huggingface.co/settings/tokens

#### Cloud GPU

For cloud GPU providers that support docker images, use [`axolotlai/axolotl-cloud:main-latest`](https://hub.docker.com/r/axolotlai/axolotl-cloud/tags)

- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)

#### Bare Metal Cloud GPU

##### LambdaLabs

<details>

<summary>Click to Expand</summary>

1. Install python
```bash
sudo apt update
sudo apt install -y python3.10

sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1
sudo update-alternatives --config python # pick 3.10 if given option
python -V # should be 3.10
```

2. Install pip
```bash
wget https://bootstrap.pypa.io/get-pip.py
python get-pip.py
```

3. Install Pytorch https://pytorch.org/get-started/locally/

4. Follow instructions on quickstart.

5. Run
```bash
pip3 install protobuf==3.20.3
pip3 install -U --ignore-installed requests Pillow psutil scipy
```

6. Set path
```bash
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
```
</details>

##### GCP

<details>

<summary>Click to Expand</summary>

Use a Deeplearning linux OS with cuda and pytorch installed. Then follow instructions on quickstart.

Make sure to run the below to uninstall xla.
```bash
pip uninstall -y torch_xla[tpu]
```

</details>

#### Windows
Please use WSL or Docker!

#### Mac

Use the below instead of the install method in QuickStart.
```
pip3 install --no-build-isolation -e '.'
```
More info: [mac.md](/docs/mac.qmd)

#### Google Colab

Please use this example [notebook](examples/colab-notebooks/colab-axolotl-example.ipynb).

#### Launching on public clouds via SkyPilot
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):

```bash
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
sky check
```

Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
```
git clone https://github.com/skypilot-org/skypilot.git
cd skypilot/llm/axolotl
```

Use one command to launch:
```bash
# On-demand
HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN

# Managed spot (auto-recovery on preemption)
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
```

#### Launching on public clouds via dstack
To launch on a GPU instance (both on-demand and spot instances) on public clouds (GCP, AWS, Azure, Lambda Labs, TensorDock, Vast.ai, and CUDO), you can use [dstack](https://dstack.ai/).

Write a job description in YAML as below:

```yaml
# dstack.yaml
type: task

image: axolotlai/axolotl-cloud:main-latest

env:
  - HUGGING_FACE_HUB_TOKEN
  - WANDB_API_KEY

commands:
  - accelerate launch -m axolotl.cli.train config.yaml

ports:
  - 6006

resources:
  gpu:
    memory: 24GB..
    count: 2
```

Then simply run the job with the `dstack run` command. Append the `--spot` option if you want a spot instance. The `dstack run` command will show you the instance with the cheapest price across multiple cloud services:

```bash
pip install dstack
HUGGING_FACE_HUB_TOKEN=xxx WANDB_API_KEY=xxx dstack run . -f dstack.yaml # --spot
```

For further and fine-grained use cases, please refer to the official [dstack documents](https://dstack.ai/docs/) and the detailed description of the [axolotl example](https://github.com/dstackai/dstack/tree/master/examples/fine-tuning/axolotl) on the official repository.

### Dataset

Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.

See [the documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
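For instance, in the alpaca format each training example is one JSON object per line with `instruction`, `input`, and `output` fields. The Python sketch below illustrates the shape of such a file; the file name and the strings in the record are made up for illustration:

```python
import json

# Minimal sketch: append one alpaca-format record to a JSONL file.
# "train.jsonl" and the field values are hypothetical example data.
record = {
    "instruction": "Summarize the following text.",
    "input": "Axolotl streamlines post-training for language models.",
    "output": "Axolotl makes it easy to fine-tune language models.",
}

with open("train.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record) + "\n")
```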
### Config

See [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:

- model
```yaml
base_model: ./llama-7b-hf # local or huggingface repo
```
Note: The code will load the right architecture.

- dataset
```yaml
datasets:
  # huggingface repo
  - path: vicgalle/alpaca-gpt4
    type: alpaca

  # huggingface repo with specific configuration/subset
  - path: EleutherAI/pile
    name: enron_emails
    type: completion # format from earlier
    field: text # Optional[str] default: text, field to use for completion data

  # huggingface repo with multiple named configurations/subsets
  - path: bigcode/commitpackft
    name:
      - ruby
      - python
      - typescript
    type: ... # unimplemented custom format

  # chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template
  - path: ...
    type: chat_template
    chat_template: chatml # defaults to tokenizer's chat_template

  # local
  - path: data.jsonl # or json
    ds_type: json # see other options below
    type: alpaca

  # dataset with splits, but no train split
  - path: knowrohit07/know_sql
    type: context_qa.load_v2
    train_on_split: validation

  # loading from s3 or gcs
  # s3 creds will be loaded from the system default and gcs only supports public access
  - path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
    ...

  # Loading Data From a Public URL
  # - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
  - path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
    ds_type: json # this is the default, see other options below.
```

- loading
```yaml
load_in_4bit: true
load_in_8bit: true

bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically.
fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32
tf32: true # require >=ampere

bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
float16: true # use instead of fp16 when you don't want AMP
```
Note: Repo does not do 4-bit quantization.

- lora
```yaml
adapter: lora # 'qlora' or leave blank for full finetune
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - v_proj
```

#### All Config Options

See [these docs](docs/config.qmd) for all config options.

### Train

Run
```bash
accelerate launch -m axolotl.cli.train your_config.yml
```

> [!TIP]
> You can also reference a config file that is hosted on a public URL, for example `accelerate launch -m axolotl.cli.train https://yourdomain.com/your_config.yml`

#### Preprocess dataset

You can optionally pre-tokenize your dataset with the following before finetuning.
This is recommended for large datasets.

- Set `dataset_prepared_path:` to a local folder for saving and loading the pre-tokenized dataset.
- (Optional): Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
- (Optional): Use `--debug` to see preprocessed examples.

```bash
python -m axolotl.cli.preprocess your_config.yml
```

#### Multi-GPU

Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
is the recommended multi-GPU option currently because FSDP may experience
[loss instability](https://github.com/huggingface/transformers/issues/26498).

##### DeepSpeed

Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated

We provide several default deepspeed JSON configurations for ZeRO stages 1, 2, and 3.

```yaml
deepspeed: deepspeed_configs/zero1.json
```

```shell
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
```

##### FSDP

- llama FSDP
```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_offload_params: true
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```

##### FSDP + QLoRA

Axolotl supports training with FSDP and QLoRA, see [these docs](docs/fsdp_qlora.qmd) for more information.

##### Weights & Biases Logging

Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you log in to wandb with `wandb login`.

- wandb options
```yaml
wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
```

##### Comet Logging

Make sure your `COMET_API_KEY` environment variable is set (recommended) or you log in to Comet with `comet login`.

- comet options
```yaml
use_comet:
comet_api_key:
comet_workspace:
comet_project_name:
comet_experiment_key:
comet_mode:
comet_online:
comet_experiment_config:
```

##### Special Tokens

It is important to have special tokens like delimiters, end-of-sequence, and beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:

```yml
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
tokens: # these are delimiters
  - "<|im_start|>"
  - "<|im_end|>"
```

When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
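Under the hood this follows the standard Hugging Face pattern. A rough sketch of the equivalent manual steps using plain `transformers` APIs (the model name is just an example; axolotl handles all of this for you from the config above):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Example base model; substitute your config's base_model.
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-hf")

# Register the special tokens and the extra delimiter tokens.
tokenizer.add_special_tokens(
    {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>"}
)
tokenizer.add_tokens(["<|im_start|>", "<|im_end|>"])

# The embedding matrix must grow to cover any newly added vocabulary entries.
model.resize_token_embeddings(len(tokenizer))
```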
##### Liger Kernel

Liger Kernel: Efficient Triton Kernels for LLM Training

https://github.com/linkedin/Liger-Kernel

Liger (LinkedIn GPU Efficient Runtime) Kernel is a collection of Triton kernels designed specifically for LLM training.
It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. The Liger Kernel
composes well and is compatible with both FSDP and Deepspeed.

```yaml
plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```

### Inference Playground

Axolotl allows you to load your model in an interactive terminal playground for quick experimentation.
The config file is the same config file used for training.

Pass the appropriate flag to the inference command, depending upon what kind of model was trained:

- Pretrained LORA:
```bash
python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
```
- Full weights finetune:
```bash
python -m axolotl.cli.inference examples/your_config.yml --base_model="./completed-model"
```
- Full weights finetune w/ a prompt from a text file:
```bash
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
  --base_model="./completed-model" --prompter=None --load_in_8bit=True
```
- With gradio hosting
```bash
python -m axolotl.cli.inference examples/your_config.yml --gradio
```

Please use `--sample_packing False` if you have it on and receive an error similar to the one below:

> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1

### Merge LORA to base

The following command will merge your LORA adapter with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved; otherwise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.

```bash
python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model"
```

You may need to use the `gpu_memory_limit` and/or `lora_on_cpu` config options to avoid running out of memory. If you still run out of CUDA memory, you can try to merge in system RAM with

```bash
CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
```

although this will be very slow; using the config options above is recommended instead.

## Common Errors 🧰

See also the [FAQ's](./docs/faq.qmd) and [debugging guide](docs/debugging.qmd).

> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:

Please reduce any of the below:
- `micro_batch_size`
- `eval_batch_size`
- `gradient_accumulation_steps`
- `sequence_len`

If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command.

Using adamw_bnb_8bit might also save you some memory.

> `failed (exitcode: -9)`

Usually means your system has run out of system memory.
Similarly, you should consider reducing the same settings as when you run out of VRAM.
Additionally, look into upgrading your system RAM, which should be simpler than GPU upgrades.

> RuntimeError: expected scalar type Float but found Half

Try setting `fp16: true`

> NotImplementedError: No operator found for `memory_efficient_attention_forward` ...

Try to turn off xformers.

> accelerate config missing

It's safe to ignore it.

> NCCL Timeouts during training

See the [NCCL](docs/nccl.qmd) guide.


### Tokenization Mismatch b/w Inference & Training

For many formats, Axolotl constructs prompts by concatenating token ids _after_ tokenizing strings. The reason for concatenating token ids rather than operating on strings is to maintain precise accounting for attention masks.

If you decode a prompt constructed by axolotl, you might see spaces between tokens (or lack thereof) that you do not expect, especially around delimiters and special tokens. When you are starting out with a new format, you should always do the following (a sketch of the decoding check follows this list):

1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly.
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.

Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/finetuning/05_tokenizer_gotchas.html) for a concrete example.
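A minimal sketch of steps 1 and 2, assuming the `datasets` and `transformers` libraries; the model name and prepared-dataset path below are placeholders, so point them at your config's `base_model` and `dataset_prepared_path`:

```python
from datasets import load_from_disk
from transformers import AutoTokenizer

# Placeholders: substitute your config's base_model and dataset_prepared_path.
BASE_MODEL = "NousResearch/Llama-2-7b-hf"
PREPARED_PATH = "last_run_prepared/your-dataset-dir"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
train_ds = load_from_disk(PREPARED_PATH)

# Step 1: decode a few materialized training rows.
for row in train_ds.select(range(3)):
    print(repr(tokenizer.decode(row["input_ids"])))

# Step 2: at inference time, decode the exact token ids you are about to
# send to the model and compare against the strings printed above -
# spaces, newlines, and special tokens must match exactly.
```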
## Debugging Axolotl

See [this debugging guide](docs/debugging.qmd) for tips on debugging Axolotl, along with an example configuration for debugging with VSCode.

## Need help? 🙋

Join our [Discord server](https://discord.gg/HhrNrHJPRb) where our community members can help you.

Need dedicated support? Please contact us at [✉️wing@axolotl.ai](mailto:wing@axolotl.ai) for dedicated support options.
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.

@@ -28,16 +28,21 @@ website:
    - section: "How-To Guides"
      contents:
        # TODO Edit folder structure after we have more docs.
        - docs/getting-started.qmd
        - docs/installation.qmd
        - docs/debugging.qmd
        - docs/inference.qmd
        - docs/multipack.qmd
        - docs/fsdp_qlora.qmd
        - docs/input_output.qmd
        - docs/rlhf.qmd
        - docs/nccl.qmd
        - docs/mac.qmd
        - docs/multi-gpu.qmd
        - docs/multi-node.qmd
        - docs/unsloth.qmd
        - docs/amd_hpc.qmd
        - docs/ray-integration.qmd
    - section: "Dataset Formats"
      contents: docs/dataset-formats/*
    - section: "Reference"

@@ -45,7 +50,6 @@ website:
        - docs/config.qmd
        - docs/faq.qmd

format:
  html:
    theme: materia

@@ -32,9 +32,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
    fi

RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
    else \
        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
    fi

RUN python scripts/unsloth_install.py | sh

@@ -6,5 +6,6 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

@@ -23,8 +23,8 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
    "AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
    "AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
    "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.3.1"),
    "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.3.1"),
    "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
    "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
    "CUDA": os.environ.get("CUDA", "121"),
    "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
    "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),

@@ -23,8 +23,8 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
    "AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
    "AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
    "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.3.1"),
    "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.3.1"),
    "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
    "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
    "CUDA": os.environ.get("CUDA", "121"),
    "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
    "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),

@@ -38,16 +38,12 @@ temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
    f.write(dockerfile_contents)

cicd_image = (
    Image.from_dockerfile(
        pathlib.Path(temp_dir) / "Dockerfile",
        context_mount=None,
        force_build=True,
        gpu="A10G",
    )
    .env(df_args)
    .pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
cicd_image = Image.from_dockerfile(
    pathlib.Path(temp_dir) / "Dockerfile",
    context_mount=None,
    force_build=True,
    gpu="A10G",
).env(df_args)

app = App("Axolotl CI/CD", secrets=[])

@@ -59,7 +55,7 @@ VOLUME_CONFIG = {
}

N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)


def run_cmd(cmd: str, run_folder: str):

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl

# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
    else \
        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
    fi

RUN python scripts/unsloth_install.py | sh

@@ -20,7 +20,8 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
    printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
    printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
    chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
    chmod +x /root/cloud-entrypoint.sh
    chmod +x /root/cloud-entrypoint.sh && \
    echo 'set-option -g history-limit 5000' >> ~/.tmux.conf

ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"]
256
docs/cli.qmd
Normal file
@@ -0,0 +1,256 @@
# Axolotl CLI Documentation

The Axolotl CLI provides a streamlined interface for training and fine-tuning large language models. This guide covers
the CLI commands, their usage, and common examples.

### Table of Contents

- Basic Commands
- Command Reference
  - fetch
  - preprocess
  - train
  - inference
  - merge-lora
  - merge-sharded-fsdp-weights
  - evaluate
  - lm-eval
- Legacy CLI Usage
- Remote Compute with Modal Cloud
  - Cloud Configuration
  - Running on Modal Cloud
  - Cloud Configuration Options

### Basic Commands

All Axolotl commands follow this general structure:

```bash
axolotl <command> [config.yml] [options]
```

The config file can be local or a URL to a raw YAML file.
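For example, you can point any command at a raw config hosted online (illustrative URL; any reachable raw YAML location works):

```bash
# Train directly from a remote config without downloading it first
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
```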
### Command Reference

#### fetch

Downloads example configurations and deepspeed configs to your local machine.

```bash
# Get example YAML files
axolotl fetch examples

# Get deepspeed config files
axolotl fetch deepspeed_configs

# Specify custom destination
axolotl fetch examples --dest path/to/folder
```

#### preprocess

Preprocesses and tokenizes your dataset before training. This is recommended for large datasets.

```bash
# Basic preprocessing
axolotl preprocess config.yml

# Preprocessing with one GPU
CUDA_VISIBLE_DEVICES="0" axolotl preprocess config.yml

# Debug mode to see processed examples
axolotl preprocess config.yml --debug

# Debug with limited examples
axolotl preprocess config.yml --debug --debug-num-examples 5
```

Configuration options:

```yaml
dataset_prepared_path: Local folder for saving preprocessed data
push_dataset_to_hub: HuggingFace repo to push preprocessed data (optional)
```
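A common pattern is to preprocess once up front and then reuse the cached dataset across training runs (a sketch, assuming `dataset_prepared_path` is set in `config.yml`):

```bash
# Tokenize once; the result lands in dataset_prepared_path
axolotl preprocess config.yml

# Subsequent training runs pick up the cached, pre-tokenized dataset
axolotl train config.yml
```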
#### train

Trains or fine-tunes a model using the configuration specified in your YAML file.

```bash
# Basic training
axolotl train config.yml

# Train and set/override specific options
axolotl train config.yml \
    --learning-rate 1e-4 \
    --micro-batch-size 2 \
    --num-epochs 3

# Training without accelerate
axolotl train config.yml --no-accelerate

# Resume training from checkpoint
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
```
#### inference

Runs inference using your trained model in either CLI or Gradio interface mode.

```bash
# CLI inference with LoRA
axolotl inference config.yml --lora-model-dir="./outputs/lora-out"

# CLI inference with full model
axolotl inference config.yml --base-model="./completed-model"

# Gradio web interface
axolotl inference config.yml --gradio \
    --lora-model-dir="./outputs/lora-out"

# Inference with input from file
cat prompt.txt | axolotl inference config.yml \
    --base-model="./completed-model"
```
#### merge-lora

Merges trained LoRA adapters into the base model.

```bash
# Basic merge
axolotl merge-lora config.yml

# Specify LoRA directory (usually used with checkpoints)
axolotl merge-lora config.yml --lora-model-dir="./lora-output/checkpoint-100"

# Merge using CPU (if out of GPU memory)
CUDA_VISIBLE_DEVICES="" axolotl merge-lora config.yml
```

Configuration options:

```yaml
gpu_memory_limit: Limit GPU memory usage
lora_on_cpu: Load LoRA weights on CPU
```
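As a sketch of an end-to-end flow (the `merged/` output directory shown here is the typical location merged weights are written to, but verify against your axolotl version):

```bash
# Merge the adapter into the base model...
axolotl merge-lora config.yml --lora-model-dir="./outputs/lora-out"

# ...then run inference against the merged weights
axolotl inference config.yml --base-model="./outputs/lora-out/merged"
```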
#### merge-sharded-fsdp-weights

Merges sharded FSDP model checkpoints into a single combined checkpoint.

```bash
# Basic merge
axolotl merge-sharded-fsdp-weights config.yml
```
#### evaluate

Evaluates a model's performance using metrics specified in the config.

```bash
# Basic evaluation
axolotl evaluate config.yml
```
#### lm-eval

Runs LM Evaluation Harness on your model.

```bash
# Basic evaluation
axolotl lm-eval config.yml

# Evaluate specific tasks
axolotl lm-eval config.yml --tasks arc_challenge,hellaswag
```

Configuration options:

```yaml
lm_eval_tasks: List of tasks to evaluate
lm_eval_batch_size: Batch size for evaluation
output_dir: Directory to save evaluation results
```
### Legacy CLI Usage

While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:

```bash
# Preprocess
python -m axolotl.cli.preprocess config.yml

# Train
accelerate launch -m axolotl.cli.train config.yml

# Inference
accelerate launch -m axolotl.cli.inference config.yml \
    --lora_model_dir="./outputs/lora-out"

# Gradio interface
accelerate launch -m axolotl.cli.inference config.yml \
    --lora_model_dir="./outputs/lora-out" --gradio
```
### Remote Compute with Modal Cloud

Axolotl supports running training and inference workloads on Modal cloud infrastructure. This is configured using a
cloud YAML file alongside your regular Axolotl config.

#### Cloud Configuration

Create a cloud config YAML with your Modal settings:

```yaml
# cloud_config.yml
provider: modal
gpu: a100        # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
gpu_count: 1     # Number of GPUs to use
timeout: 86400   # Maximum runtime in seconds (24 hours)
branch: main     # Git branch to use (optional)

volumes:         # Persistent storage volumes
  - name: axolotl-cache
    mount: /workspace/cache

env:             # Environment variables
  - WANDB_API_KEY
  - HF_TOKEN
```
#### Running on Modal Cloud

Commands that support the `--cloud` flag:

```bash
# Preprocess on cloud
axolotl preprocess config.yml --cloud cloud_config.yml

# Train on cloud
axolotl train config.yml --cloud cloud_config.yml

# Train without accelerate on cloud
axolotl train config.yml --cloud cloud_config.yml --no-accelerate

# Run lm-eval on cloud
axolotl lm-eval config.yml --cloud cloud_config.yml
```
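Values listed under `env` and `secrets` are read from your local shell environment when the job is launched, so export them beforehand (a sketch; the placeholder values are yours to fill in):

```bash
# Export credentials locally; the cloud launcher forwards them to the remote job
export HF_TOKEN=hf_...        # placeholder token value
export WANDB_API_KEY=...      # placeholder key
axolotl train config.yml --cloud cloud_config.yml
```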
#### Cloud Configuration Options

```yaml
provider: compute provider, currently only `modal` is supported
gpu: GPU type to use
gpu_count: Number of GPUs (default: 1)
memory: RAM in GB (default: 128)
timeout: Maximum runtime in seconds
timeout_preprocess: Preprocessing timeout
branch: Git branch to use
docker_tag: Custom Docker image tag
volumes: List of persistent storage volumes
env: Environment variables to pass
secrets: Secrets to inject
```
@@ -187,6 +187,12 @@ rl:
# whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:

# reward modelling: `True` or `False`
reward_model:

# process reward modelling: `True` or `False`
process_reward_model:

# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
@@ -244,6 +250,8 @@ total_num_tokens:
sample_packing_group_size: 100000
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
sample_packing_bin_size: 200
# whether to concatenate samples during pretraining
pretraining_sample_concatenation:

# Use batch flattening for speedups when not using sample_packing
batch_flattening:
@@ -358,10 +366,11 @@ warmup_ratio: 0.05 # cannot use with warmup_steps
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
eval_steps: # Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
save_strategy: # Set to `"no"` to skip checkpoint saves
save_steps: # Leave empty to save at each epoch
eval_strategy: # Set to `"no"` to skip evaluation, `"epoch"` at end of each epoch, leave empty to infer from `eval_steps`.
save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of each epoch, `"best"` when better result is achieved, leave empty to infer from `save_steps`.
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
save_total_limit: # Checkpoints saved at a time
# Maximum number of iterations to train for. It precedes num_epochs which means that
@@ -8,14 +8,12 @@ order: 3

IMPORTANT: ShareGPT is deprecated! Please see the `chat_template` section below.


## pygmalion

```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "value": "..."}]}
```


## chat_template

The Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. It supports using the tokenizer's template, a supported template, or custom jinja2.
26
docs/dataset-formats/stepwise_supervised.qmd
Normal file
@@ -0,0 +1,26 @@
---
title: Stepwise Supervised Format
description: Format for datasets with stepwise completions and labels
order: 3
---

## Stepwise Supervised

The stepwise supervised format is designed for chain-of-thought (COT) reasoning
datasets where each example contains multiple completion steps and a preference label
for each step.

### Example

Here's a simple example of a stepwise supervised dataset entry:

```json
{
  "prompt": "Which number is larger, 9.8 or 9.11?",
  "completions": [
    "The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
    "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."
  ],
  "labels": [true, false]
}
```

Each label marks whether the corresponding step is correct: here the first step is labelled `true`, while the second is labelled `false` because its comparison is wrong (0.8 is greater than 0.11).
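If your dataset is stored as JSONL (one entry per line), a quick sanity check that every entry carries one label per completion might look like this (a sketch, assuming `jq` is installed and the file is named `data.jsonl`):

```bash
# Print any malformed entries where labels and completions differ in length
jq -c 'select((.completions | length) != (.labels | length))' data.jsonl
```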
155
docs/getting-started.qmd
Normal file
@@ -0,0 +1,155 @@
---
title: "Getting Started with Axolotl"
format:
  html:
    toc: true
    toc-depth: 3
    number-sections: true
execute:
  enabled: false
---

This guide will walk you through your first model fine-tuning project with Axolotl.

## Quick Example {#sec-quick-example}

Let's start by fine-tuning a small language model using LoRA. This example uses a 1B parameter model to ensure it runs on most GPUs.
We assume `axolotl` is installed (if not, see our [Installation Guide](installation.qmd)).

1. Download example configs:
```shell
axolotl fetch examples
```

2. Run the training:
```shell
axolotl train examples/llama-3/lora-1b.yml
```

That's it! Let's understand what just happened.
## Understanding the Process {#sec-understanding}

### The Configuration File {#sec-config}

The YAML configuration file controls everything about your training. Here's what (part of) our example config looks like:

```yaml
base_model: NousResearch/Llama-3.2-1B
# hub_model_id: username/custom_model_name

datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out

adapter: lora
lora_model_dir:
```

See our [Config options](config.qmd) for more details.
### Training {#sec-training}

When you run `axolotl train`, Axolotl:

1. Downloads the base model
2. (If specified) applies LoRA adapter layers
3. Loads and processes the dataset
4. Runs the training loop
5. Saves the trained model and / or LoRA weights (see the sketch below)
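A rough sketch of what the output directory may contain after the quick-example run (names like `adapter_model.safetensors` and `checkpoint-*` follow the usual PEFT/Trainer conventions and are not guaranteed to match your versions exactly):

```shell
ls ./outputs/lora-out
# adapter_config.json        - LoRA adapter definition
# adapter_model.safetensors  - trained LoRA weights
# checkpoint-*/              - intermediate training checkpoints
# tokenizer_config.json ...  - tokenizer files needed for inference
```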
## Your First Custom Training {#sec-custom}

Let's modify the example for your own data:

1. Create a new config file `my_training.yml`:

```yaml
base_model: NousResearch/Nous-Hermes-llama-1b-v1
adapter: lora

# Training settings
micro_batch_size: 2
num_epochs: 3
learning_rate: 0.0003

# Your dataset
datasets:
  - path: my_data.jsonl  # Your local data file
    type: alpaca         # Or other format
```

This specific config is for LoRA fine-tuning a model on instruction-tuning data in the `alpaca` dataset format, which looks like:

```json
{
  "instruction": "Write a description of alpacas.",
  "input": "",
  "output": "Alpacas are domesticated South American camelids..."
}
```

Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
format them.

2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca` format):

```json
{"instruction": "Classify this text", "input": "I love this!", "output": "positive"}
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
```

Please consult the supported [Dataset Formats](dataset-formats/) for more details.

3. Run the training:

```shell
axolotl train my_training.yml
```
## Common Tasks {#sec-common-tasks}

### Testing Your Model {#sec-testing}

After training, test your model:

```shell
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
```

### Preprocessing Data {#sec-preprocessing}

For large datasets, preprocess first:

```shell
axolotl preprocess my_training.yml
```

### Using a UI {#sec-ui}

Launch a Gradio interface:

```shell
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
```

## Next Steps {#sec-next-steps}

Now that you have the basics, you might want to:

- Try different model architectures
- Experiment with hyperparameters
- Use more advanced training methods
- Scale up to larger models

Check our other guides for details on these topics:

- [Configuration Guide](config.qmd) - Full configuration options
- [Dataset Formats](dataset-formats) - Working with different data formats
- [Multi-GPU Training](multi-gpu.qmd)
- [Multi-Node Training](multi-node.qmd)
BIN
docs/images/ray-cluster-dashboard.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 292 KiB
148
docs/inference.qmd
Normal file
@@ -0,0 +1,148 @@
---
title: "Inference Guide"
format:
  html:
    toc: true
    toc-depth: 3
    number-sections: true
    code-tools: true
execute:
  enabled: false
---

This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.

## Quick Start {#sec-quickstart}

### Basic Inference {#sec-basic}

::: {.panel-tabset}

## LoRA Models

```{.bash}
axolotl inference your_config.yml --lora-model-dir="./lora-output-dir"
```

## Full Fine-tuned Models

```{.bash}
axolotl inference your_config.yml --base-model="./completed-model"
```

:::
## Advanced Usage {#sec-advanced}

### Gradio Interface {#sec-gradio}

Launch an interactive web interface:

```{.bash}
axolotl inference your_config.yml --gradio
```

### File-based Prompts {#sec-file-prompts}

Process prompts from a text file:

```{.bash}
cat /tmp/prompt.txt | axolotl inference your_config.yml \
    --base-model="./completed-model" --prompter=None
```

### Memory Optimization {#sec-memory}

For large models or limited memory:

```{.bash}
axolotl inference your_config.yml --load-in-8bit=True
```
## Merging LoRA Weights {#sec-merging}

Merge LoRA adapters with the base model:

```{.bash}
axolotl merge-lora your_config.yml --lora-model-dir="./completed-model"
```

### Memory Management for Merging {#sec-memory-management}

::: {.panel-tabset}

## Configuration Options

```{.yaml}
gpu_memory_limit: 20GiB  # Adjust based on your GPU
lora_on_cpu: true        # Process on CPU if needed
```

## Force CPU Merging

```{.bash}
CUDA_VISIBLE_DEVICES="" axolotl merge-lora ...
```

:::
## Tokenization {#sec-tokenization}

### Common Issues {#sec-tokenization-issues}

::: {.callout-warning}
Tokenization mismatches between training and inference are a common source of problems.
:::

To debug:

1. Check training tokenization:
```{.bash}
axolotl preprocess your_config.yml --debug
```

2. Verify inference tokenization by decoding tokens before model input

3. Compare token IDs between training and inference

### Special Tokens {#sec-special-tokens}

Configure special tokens in your YAML:

```{.yaml}
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
tokens:
  - "<|im_start|>"
  - "<|im_end|>"
```
## Troubleshooting {#sec-troubleshooting}

### Common Problems {#sec-common-problems}

::: {.panel-tabset}

## Memory Issues

- Use 8-bit loading
- Reduce batch sizes
- Try CPU offloading

## Token Issues

- Verify special tokens
- Check tokenizer settings
- Compare training and inference preprocessing

## Performance Issues

- Verify model loading
- Check prompt formatting
- Check temperature/sampling settings

:::

For more details, see our [debugging guide](debugging.qmd).
119
docs/installation.qmd
Normal file
@@ -0,0 +1,119 @@
---
title: "Installation Guide"
format:
  html:
    toc: true
    toc-depth: 3
    number-sections: true
    code-tools: true
execute:
  enabled: false
---

This guide covers all the ways you can install and set up Axolotl for your environment.

## Requirements {#sec-requirements}

- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.10
- PyTorch ≥2.4.1

## Installation Methods {#sec-installation-methods}

### PyPI Installation (Recommended) {#sec-pypi}

```{.bash}
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```

We use `--no-build-isolation` so that the build can detect an already-installed PyTorch version rather than clobbering it, and so that the correct versions of PyTorch-specific dependencies and other installed co-dependencies are selected.
### Edge/Development Build {#sec-edge-build}

For the latest features between releases:

```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```

### Docker {#sec-docker}

```{.bash}
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```

For development with Docker:

```{.bash}
docker compose up -d
```

::: {.callout-tip}
### Advanced Docker Configuration
```{.bash}
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
    --name axolotl --ipc=host \
    --ulimit memlock=-1 --ulimit stack=67108864 \
    --mount type=bind,src="${PWD}",target=/workspace/axolotl \
    -v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
    axolotlai/axolotl:main-latest
```
:::
## Cloud Environments {#sec-cloud}

### Cloud GPU Providers {#sec-cloud-gpu}

For providers supporting Docker:

- Use `axolotlai/axolotl-cloud:main-latest`
- Available on:
  - [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
  - [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
  - [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)

### Google Colab {#sec-colab}

Use our [example notebook](../examples/colab-notebooks/colab-axolotl-example.ipynb).

## Platform-Specific Instructions {#sec-platform-specific}

### macOS {#sec-macos}

```{.bash}
pip3 install --no-build-isolation -e '.'
```

See @sec-troubleshooting for Mac-specific issues.

### Windows {#sec-windows}

::: {.callout-important}
We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
:::

## Environment Managers {#sec-env-managers}

### Conda/Pip venv {#sec-conda}

1. Install Python ≥3.10
2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl:
```{.bash}
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Hugging Face:
```{.bash}
huggingface-cli login
```

## Troubleshooting {#sec-troubleshooting}

If you encounter installation issues, see our [FAQ](faq.qmd) and [Debugging Guide](debugging.qmd).
29
docs/lr_groups.qmd
Normal file
@@ -0,0 +1,29 @@
---
title: Learning Rate Groups
description: "Setting different learning rates by module name"
---

## Background

Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or group of
modules in a model.

## Example

```yaml
lr_groups:
  - name: o_proj
    modules:
      - self_attn.o_proj.weight
    lr: 1e-6
  - name: q_proj
    modules:
      - model.layers.2.self_attn.q_proj.weight
    lr: 1e-5

learning_rate: 2e-5
```

In this example, we have a default learning rate of 2e-5 across the entire model, but a separate learning rate
of 1e-6 for all the self-attention `o_proj` modules across all layers, and a learning rate of 1e-5 for the 3rd layer's
self-attention `q_proj` module.
118
docs/multi-gpu.qmd
Normal file
@@ -0,0 +1,118 @@
---
title: "Multi-GPU Training Guide"
format:
  html:
    toc: true
    toc-depth: 3
    number-sections: true
    code-tools: true
execute:
  enabled: false
---

This guide covers advanced training configurations for multi-GPU setups using Axolotl.

## Overview {#sec-overview}

Axolotl supports several methods for multi-GPU training:

- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- FSDP + QLoRA

## DeepSpeed {#sec-deepspeed}

DeepSpeed is the recommended approach for multi-GPU training due to its stability and performance. It provides various optimization levels through ZeRO stages.

### Configuration {#sec-deepspeed-config}

Add to your YAML config:

```{.yaml}
deepspeed: deepspeed_configs/zero1.json
```

### Usage {#sec-deepspeed-usage}

```{.bash}
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
```

### ZeRO Stages {#sec-zero-stages}

We provide default configurations for:

- ZeRO Stage 1 (`zero1.json`)
- ZeRO Stage 2 (`zero2.json`)
- ZeRO Stage 3 (`zero3.json`)

Choose based on your memory requirements and performance needs.
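If you don't have the stock ZeRO configs locally yet, they can be pulled with the `fetch` command documented in the CLI guide (a sketch of the combined flow):

```{.bash}
# Download the default ZeRO configs shipped with axolotl
axolotl fetch deepspeed_configs

# Launch training against one of them
accelerate launch -m axolotl.cli.train your_config.yml --deepspeed deepspeed_configs/zero2.json
```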
## FSDP {#sec-fsdp}

### Basic FSDP Configuration {#sec-fsdp-config}

```{.yaml}
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_offload_params: true
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```

### FSDP + QLoRA {#sec-fsdp-qlora}

For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
## Performance Optimization {#sec-performance}

### Liger Kernel Integration {#sec-liger}

::: {.callout-note}
Liger Kernel provides efficient Triton kernels for LLM training, offering:

- 20% increase in multi-GPU training throughput
- 60% reduction in memory usage
- Compatibility with both FSDP and DeepSpeed
:::

Configuration:

```{.yaml}
plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```
|
||||
|
||||
### NCCL Issues {#sec-nccl}
|
||||
|
||||
For NCCL-related problems, see our [NCCL troubleshooting guide](nccl.qmd).
|
||||
|
||||
### Common Problems {#sec-common-problems}
|
||||
|
||||
::: {.panel-tabset}
|
||||
|
||||
## Memory Issues
|
||||
|
||||
- Reduce `micro_batch_size`
|
||||
- Reduce `eval_batch_size`
|
||||
- Adjust `gradient_accumulation_steps`
|
||||
- Consider using a higher ZeRO stage
|
||||
|
||||
## Training Instability
|
||||
|
||||
- Start with DeepSpeed ZeRO-2
|
||||
- Monitor loss values
|
||||
- Check learning rates
|
||||
|
||||
:::
|
||||
|
||||
For more detailed troubleshooting, see our [debugging guide](debugging.qmd).
|
||||
93
docs/ray-integration.qmd
Normal file
@@ -0,0 +1,93 @@
---
title: Ray Train integration
description: How to use Axolotl with Ray Train
---

Axolotl supports using Ray as an alternative to `accelerate` for orchestrating training. This is especially useful for multi-node training, since you only have to set up code and dependencies on a single node and launch training as if you were using a single node.

With the `--use-ray` CLI flag, Axolotl will use Ray Train's [`TorchTrainer`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.TorchTrainer.html#ray.train.torch.TorchTrainer) to run training.

## Ray cluster setup

A prerequisite for using the Ray Train integration is to set up a Ray cluster on your desired node(s). For a detailed guide on how you can get started with Ray clusters, check the official Ray docs here: https://docs.ray.io/en/latest/cluster/getting-started.html

Every Ray cluster has one _head_ node and a set of worker nodes. The head node is just like any other worker node, but it also runs certain special processes related to scheduling and orchestration. Ray-enabled scripts are run on the head node and, depending on the resources (number of CPUs, GPUs, etc.) they request, will be scheduled to run certain tasks on the worker nodes. For more on the key concepts behind a Ray cluster, you can refer to this [doc](https://docs.ray.io/en/latest/cluster/key-concepts.html#cluster-key-concepts).
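A minimal sketch of bringing such a cluster up by hand (assuming the same axolotl environment is installed on every node; managed setups such as KubeRay differ):

```bash
# On the head node
ray start --head --port=6379

# On each worker node, pointing at the head node's address
ray start --address=<head-node-ip>:6379
```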
## Sanity check

To run a sanity check on whether your Ray cluster is set up properly, execute the following on the head node:

```bash
ray status
```

The output should have a summary of your Ray cluster - a list of all the nodes in your cluster, the number of CPUs and GPUs in your cluster, etc. For example, if you have a cluster with 1 CPU-only head node and 2 4xL40S worker nodes, the output can look like this:

```
Node status
---------------------------------------------------------------
Active:
 1 head
Idle:
 2 4xL40S:48CPU-384GB
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/96.0 CPU
 0.0/8.0 GPU
 0B/800.00GiB memory
 0B/229.57GiB object_store_memory

Demands:
 (no resource demands)
```

You should also be able to see the same on the [Ray dashboard](https://docs.ray.io/en/latest/ray-observability/getting-started.html).
## Configuring training with Ray Train

You can find an example configuration at `examples/llama-3/lora-1b-ray.yml`.

The key parameters to note here are:

```yaml
...
use_ray: true
ray_num_workers: 4
# optional
resources_per_worker:
  GPU: 1
...
```

- `use_ray`: This is the flag that enables the Ray Train integration. You can either use the corresponding `--use-ray` flag in the CLI or set `use_ray` in the config file.
- `ray_num_workers`: This is the number of workers/GPUs to use for training.
- `resources_per_worker`: This is the Ray [resource request](https://docs.ray.io/en/latest/ray-core/scheduling/resources.html) for each worker. This can be used to request a specific GPU type or a custom resource for each worker. For example, if your Ray cluster has GPUs of different types, and you only want to use NVIDIA L40S GPUs, you can do

```yaml
resources_per_worker:
  accelerator_type:L40S: 0.001
```
## Launching training

You can simply run the following command on the head node:

```bash
axolotl train examples/llama-3/lora-1b-ray.yml --use-ray
```

This will launch training on the head node, and workers will be scheduled automatically by Ray Train to run on the appropriate head or worker nodes.

You can also monitor training progress on the Ray dashboard.

Coming back to the example of a Ray cluster with 1 head node and 2 4xL40S worker nodes, say you want to make use of all 8 GPUs: just set `ray_num_workers: 8` and run the previous command. The Cluster tab will show the following:

![Ray cluster dashboard](images/ray-cluster-dashboard.png)
47
docs/reward_modelling.qmd
Normal file
@@ -0,0 +1,47 @@
---
title: "Reward Modelling"
description: "Reward models are used to guide models towards behaviors which are preferred by humans, by training over large datasets annotated with human preferences."
---

### Overview

Reward modelling is a technique used to train models to predict the reward or value of a given input. This is particularly useful in reinforcement learning scenarios where the model needs to evaluate the quality of its actions or predictions.
We support the reward modelling techniques supported by `trl`.

### (Outcome) Reward Models

Outcome reward models are trained using data which contains preference annotations for an entire interaction between the user and model (e.g. rather than per-turn or per-step).

```yaml
base_model: google/gemma-2-2b
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer

reward_model: true
chat_template: gemma
datasets:
  - path: argilla/distilabel-intel-orca-dpo-pairs
    type: bradley_terry.chat_template

val_set_size: 0.1
eval_steps: 100
```

### Process Reward Models (PRM)

Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.

```yaml
base_model: Qwen/Qwen2.5-3B
model_type: AutoModelForTokenClassification
num_labels: 2

process_reward_model: true
datasets:
  - path: trl-lib/math_shepherd
    type: stepwise_supervised
    split: train

val_set_size: 0.1
eval_steps: 100
```
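The example configs added alongside these docs can be launched directly (paths as added in this changeset):

```bash
# Outcome reward model (Bradley-Terry) example
axolotl train examples/qwen2/reward-model.yaml

# Process reward model example
axolotl train examples/qwen2/prm.yaml
```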
@@ -29,7 +29,7 @@ datasets:
    type: chatml.intel
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: chatml.argilla
    type: chatml
```

#### IPO
@@ -46,7 +46,7 @@ output_dir: ./outputs/btlm-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
optimizer: adamw_torch_fused
adam_beta2: 0.95
adam_eps: 0.000000001
max_grad_norm: 1.0
28
examples/cloud/modal.yaml
Normal file
@@ -0,0 +1,28 @@
project_name:
volumes:
  - name: axolotl-data
    mount: /workspace/data
  - name: axolotl-artifacts
    mount: /workspace/artifacts

# environment variables from local to set as secrets
secrets:
  - HF_TOKEN
  - WANDB_API_KEY

# Which branch of axolotl to use remotely
branch:

# additional custom commands when building the image
dockerfile_commands:

gpu: h100
gpu_count: 1

# Train specific configurations
memory: 128
timeout: 86400

# Preprocess specific configurations
memory_preprocess: 32
timeout_preprocess: 14400
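To use this template (a sketch; fill in the blank fields and export the listed secrets in your shell first):

```bash
# Launch a training job on Modal with this cloud config
axolotl train config.yml --cloud examples/cloud/modal.yaml
```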
@@ -27,7 +27,7 @@ wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5

@@ -47,7 +47,7 @@ peft_use_rslora: true
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
@@ -1,6 +1,7 @@
base_model: google/gemma-2-2b
# optionally might have model_type or tokenizer_type
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -34,7 +34,7 @@ lora_target_linear: false
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.00001

@@ -42,7 +42,7 @@ output_dir: ./outputs/model-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
adam_beta2: 0.95
adam_eps: 0.00001
max_grad_norm: 1.0

@@ -39,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.00001

@@ -37,7 +37,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
79
examples/llama-3/lora-1b-ray.yml
Normal file
@@ -0,0 +1,79 @@
base_model: NousResearch/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out

adapter: lora
lora_model_dir:

sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true

lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed: deepspeed_configs/zero3.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<|end_of_text|>"

use_ray: true
ray_num_workers: 4
@@ -30,7 +30,7 @@ lora_target_linear: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.00001

@@ -39,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.00001

@@ -47,7 +47,7 @@ wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002

@@ -41,7 +41,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002

@@ -43,7 +43,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002

@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

@@ -39,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 12
num_epochs: 2
optimizer: adamw_torch
optimizer: adamw_torch_fused
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
optimizer: adamw_torch_fused
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
72
examples/qwen2/prm.yaml
Normal file
@@ -0,0 +1,72 @@
base_model: Qwen/Qwen2.5-3B
# optionally might have model_type or tokenizer_type
model_type: AutoModelForTokenClassification
num_labels: 2
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

process_reward_model: true
chat_template:
datasets:
  - path: trl-lib/math_shepherd
    type: stepwise_supervised
    step_separator: "\n"
    max_completion_length:
    train_on_last_step_only: false

val_set_size: 0.2
output_dir: ./outputs/out
remove_unused_columns: false

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 1
micro_batch_size: 8
eval_batch_size: 8
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32:
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
eval_steps: 100
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
@@ -37,7 +37,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002
67
examples/qwen2/reward-model.yaml
Normal file
@@ -0,0 +1,67 @@
base_model: Qwen/Qwen2.5-0.5B
# optionally might have model_type or tokenizer_type
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

reward_model: true
chat_template: qwen_25
datasets:
  - path: argilla/distilabel-intel-orca-dpo-pairs
    type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
remove_unused_columns: false

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002
@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.0
bitsandbytes==0.45.1
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.0.post2

@@ -13,9 +13,9 @@ liger-kernel==0.5.2
packaging==23.2

peft==0.14.0
transformers==4.47.1
transformers==4.48.1
tokenizers>=0.21.0
accelerate==1.2.1
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.13.0

@@ -25,6 +25,7 @@ hf_transfer
sentencepiece
gradio==3.50.2

modal==0.70.5
pydantic==2.6.3
addict
fire
17
scripts/motd
@@ -1,10 +1,15 @@
dP dP dP
88 88 88
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
88' `88 `8bd8' 88' `88 88 88' `88 88 88
88. .88 .d88b. 88. .88 88 88. .88 88 88
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@

Welcome to the axolotl cloud image! If you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands:
23
setup.py
@@ -32,8 +32,6 @@ def parse_requirements():
            _install_requires.append(line)
    try:
        xformers_version = [req for req in _install_requires if "xformers" in req][0]
        triton_version = [req for req in _install_requires if "triton" in req][0]
        torchao_version = [req for req in _install_requires if "torchao" in req][0]
        autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
        if "Darwin" in platform.system():
            # skip packages not compatible with OSX

@@ -87,24 +85,8 @@ def parse_requirements():
            else:
                _install_requires.pop(_install_requires.index(xformers_version))
                _install_requires.append("xformers==0.0.28.post1")
        elif (major, minor) >= (2, 3):
            _install_requires.pop(_install_requires.index(torchao_version))
            _install_requires.pop(_install_requires.index(triton_version))
            _install_requires.append("triton>=2.3.1")
            if patch == 0:
                _install_requires.pop(_install_requires.index(xformers_version))
                _install_requires.append("xformers>=0.0.26.post1")
            else:
                _install_requires.pop(_install_requires.index(xformers_version))
                _install_requires.append("xformers>=0.0.27")
        elif (major, minor) >= (2, 2):
            _install_requires.pop(_install_requires.index(torchao_version))
            _install_requires.pop(_install_requires.index(xformers_version))
            _install_requires.append("xformers>=0.0.25.post1")
        else:
            _install_requires.pop(_install_requires.index(torchao_version))
            _install_requires.pop(_install_requires.index(xformers_version))
            _install_requires.append("xformers>=0.0.23.post1")
        raise ValueError("axolotl requires torch>=2.4")

    except PackageNotFoundError:
        pass

@@ -168,5 +150,8 @@ setup(
            "lomo-optim==0.1.1",
            "torch-optimi==0.2.1",
        ],
        "ray": [
            "ray[train]",
        ],
    },
)
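The new `ray` extra added above can be pulled in with the same editable-install pattern used elsewhere in this changeset:

```bash
# Install axolotl with the ray extra (mirrors the updated Dockerfile install line)
pip3 install --no-build-isolation -e '.[deepspeed,flash-attn,optimizers,ray]'
```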
@@ -13,6 +13,12 @@ class PreprocessCliArgs:
    debug_num_examples: int = field(default=1)
    prompter: Optional[str] = field(default=None)
    download: Optional[bool] = field(default=True)
    iterable: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Use IterableDataset for streaming processing of large datasets"
        },
    )


@dataclass
@@ -25,6 +31,8 @@ class TrainerCliArgs:
    merge_lora: bool = field(default=False)
    prompter: Optional[str] = field(default=None)
    shard: bool = field(default=False)
    main_process_port: Optional[int] = field(default=None)
    num_processes: Optional[int] = field(default=None)


@dataclass
56
src/axolotl/cli/cloud/__init__.py
Normal file
@@ -0,0 +1,56 @@
"""
|
||||
launch axolotl in supported cloud platforms
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.cloud.modal_ import ModalCloud
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
|
||||
"""Load and validate cloud configuration."""
|
||||
# Load cloud configuration.
|
||||
with open(cloud_config, encoding="utf-8") as file:
|
||||
cloud_cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
return cloud_cfg
|
||||
|
||||
|
||||
def do_cli_preprocess(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
config_yaml = file.read()
|
||||
cloud.preprocess(config_yaml)
|
||||
|
||||
|
||||
def do_cli_train(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
accelerate: bool = True,
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
config_yaml = file.read()
|
||||
cloud.train(config_yaml, accelerate=accelerate)
|
||||
|
||||
|
||||
def do_cli_lm_eval(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
config_yaml = file.read()
|
||||
cloud.lm_eval(config_yaml)
|
||||
18
src/axolotl/cli/cloud/base.py
Normal file
@@ -0,0 +1,18 @@
"""
|
||||
base class for cloud platforms from cli
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Cloud(ABC):
|
||||
"""
|
||||
Abstract base class for cloud platforms.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def train(self, config_yaml: str, accelerate: bool = True) -> str:
|
||||
pass
|
||||
282
src/axolotl/cli/cloud/modal_.py
Normal file
@@ -0,0 +1,282 @@
"""
|
||||
Modal Cloud support from CLI
|
||||
"""
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import subprocess # nosec B404
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
|
||||
import modal
|
||||
|
||||
from axolotl.cli.cloud.base import Cloud
|
||||
|
||||
|
||||
def run_cmd(cmd: str, run_folder: str, volumes=None):
|
||||
"""Run a command inside a folder, with Modal Volume reloading before and commit on success."""
|
||||
# Ensure volumes contain latest files.
|
||||
if volumes:
|
||||
for _, vol in volumes.items():
|
||||
vol.reload()
|
||||
|
||||
# modal workaround so it doesn't use the automounted axolotl
|
||||
new_env = copy.deepcopy(os.environ)
|
||||
if "PYTHONPATH" in new_env:
|
||||
del new_env["PYTHONPATH"]
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call( # nosec B603
|
||||
cmd.split(), cwd=run_folder, env=new_env
|
||||
):
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
|
||||
# Commit writes to volume.
|
||||
if volumes:
|
||||
for _, vol in volumes.items():
|
||||
vol.commit()
|
||||
|
||||
|
||||
class ModalCloud(Cloud):
|
||||
"""
|
||||
Modal Cloud implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, config, app=None):
|
||||
self.config = config
|
||||
if not app:
|
||||
app = modal.App()
|
||||
self.app = app
|
||||
|
||||
self.volumes = {}
|
||||
if config.volumes:
|
||||
for volume_config in config.volumes:
|
||||
_, mount, vol = self.create_volume(volume_config)
|
||||
self.volumes[mount] = (vol, volume_config)
|
||||
|
||||
def get_env(self):
|
||||
res = {
|
||||
"HF_DATASETS_CACHE": "/workspace/data/huggingface-cache/datasets",
|
||||
"HF_HUB_CACHE": "/workspace/data/huggingface-cache/hub",
|
||||
}
|
||||
|
||||
for key in self.config.get("env", []):
|
||||
if isinstance(key, str):
|
||||
if val := os.environ.get(key, ""):
|
||||
res[key] = val
|
||||
elif isinstance(key, dict):
|
||||
(key_, val) = list(key.items())[0]
|
||||
res[key_] = val
|
||||
return res
|
||||
|
||||
def get_image(self):
|
||||
docker_tag = "main-py3.11-cu124-2.5.1"
|
||||
if self.config.docker_tag:
|
||||
docker_tag = self.config.docker_tag
|
||||
docker_image = f"axolotlai/axolotl:{docker_tag}"
|
||||
|
||||
# grab the sha256 hash from docker hub for this image+tag
|
||||
# this ensures that we always get the latest image for this tag, even if it's already cached
|
||||
try:
|
||||
manifest = subprocess.check_output( # nosec B602
|
||||
f"docker manifest inspect {docker_image}",
|
||||
shell=True,
|
||||
).decode("utf-8")
|
||||
sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
|
||||
except subprocess.CalledProcessError:
|
||||
sha256_hash = None
|
||||
|
||||
# create the image
|
||||
if sha256_hash:
|
||||
image = modal.Image.from_registry(f"axolotlai/axolotl@{sha256_hash}")
|
||||
else:
|
||||
image = modal.Image.from_registry(docker_image)
|
||||
|
||||
dockerfile_commands = []
|
||||
if self.config.dockerfile_commands:
|
||||
dockerfile_commands.extend(self.config.dockerfile_commands)
|
||||
|
||||
# branch
|
||||
if self.config.branch:
|
||||
dockerfile_commands.extend(
|
||||
[
|
||||
# Random id for cache busting of branch commits
|
||||
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
|
||||
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
|
||||
]
|
||||
)
|
||||
|
||||
if dockerfile_commands:
|
||||
image = image.dockerfile_commands(dockerfile_commands)
|
||||
|
||||
if env := self.get_env():
|
||||
image = image.env(env)
|
||||
|
||||
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||
|
||||
return image
|
||||
|
||||
def get_secrets(self):
|
||||
res = []
|
||||
if self.config.secrets:
|
||||
for key in self.config.get("secrets", []):
|
||||
# pylint: disable=duplicate-code
|
||||
if isinstance(key, str):
|
||||
if val := os.environ.get(key, ""):
|
||||
res.append(modal.Secret.from_dict({key: val}))
|
||||
elif isinstance(key, dict):
|
||||
(key_, val) = list(key.items())[0]
|
||||
res.append(modal.Secret.from_dict({key_: val}))
|
||||
return res
|
||||
|
||||
def create_volume(self, volume_config):
|
||||
name = volume_config.name
|
||||
mount = volume_config.mount
|
||||
return name, mount, modal.Volume.from_name(name, create_if_missing=True)
|
||||
|
||||
def get_ephemeral_disk_size(self):
|
||||
return 1000 * 525 # 1 TiB
|
||||
|
||||
def get_preprocess_timeout(self):
|
||||
if self.config.timeout_preprocess:
|
||||
return int(self.config.timeout_preprocess)
|
||||
return 60 * 60 * 3 # 3 hours
|
||||
|
||||
def get_preprocess_memory(self):
|
||||
memory = 128 # default to 128GiB
|
||||
if self.config.memory:
|
||||
memory = int(self.config.memory)
|
||||
if self.config.memory_preprocess:
|
||||
memory = int(self.config.memory_preprocess)
|
||||
return 1024 * memory
|
||||
|
||||
def get_preprocess_env(self):
|
||||
return self.app.function(
|
||||
image=self.get_image(),
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
cpu=8.0,
|
||||
ephemeral_disk=self.get_ephemeral_disk_size(),
|
||||
memory=self.get_preprocess_memory(),
|
||||
timeout=self.get_preprocess_timeout(),
|
||||
secrets=self.get_secrets(),
|
||||
)
|
||||
|
||||
def preprocess(self, config_yaml: str, *args, **kwargs):
|
||||
modal_fn = self.get_preprocess_env()(_preprocess)
|
||||
with modal.enable_output():
|
||||
with self.app.run(detach=True):
|
||||
modal_fn.remote(
|
||||
config_yaml,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_train_timeout(self):
|
||||
if self.config.timeout:
|
||||
return int(self.config.timeout)
|
||||
return 60 * 60 * 24 # 24 hours
|
||||
|
||||
def get_train_gpu(self): # pylint: disable=too-many-return-statements
|
||||
count = self.config.gpu_count or 1
|
||||
family = self.config.gpu.lower() or "l40s"
|
||||
|
||||
if family == "l40s":
|
||||
return modal.gpu.L40S(count=count)
|
||||
if family in ["a100", "a100-40gb"]:
|
||||
return modal.gpu.A100(count=count, size="40GB")
|
||||
if family == "a100-80gb":
|
||||
return modal.gpu.A100(count=count, size="80GB")
|
||||
if family in ["a10", "a10g"]:
|
||||
return modal.gpu.A10G(count=count)
|
||||
if family == "h100":
|
||||
return modal.gpu.H100(count=count)
|
||||
if family == "t4":
|
||||
return modal.gpu.T4(count=count)
|
||||
if family == "l4":
|
||||
return modal.gpu.L4(count=count)
|
||||
raise ValueError(f"Unsupported GPU family: {family}")
|
||||
|
||||
def get_train_memory(self):
|
||||
memory = 128 # default to 128GiB
|
||||
if self.config.memory:
|
||||
memory = int(self.config.memory)
|
||||
return 1024 * memory
|
||||
|
||||
def get_train_env(self):
|
||||
return self.app.function(
|
||||
image=self.get_image(),
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
cpu=16.0,
|
||||
gpu=self.get_train_gpu(),
|
||||
memory=self.get_train_memory(),
|
||||
timeout=self.get_train_timeout(),
|
||||
secrets=self.get_secrets(),
|
||||
)
|
||||
|
||||
def train(self, config_yaml: str, accelerate: bool = True):
|
||||
modal_fn = self.get_train_env()(_train)
|
||||
with modal.enable_output():
|
||||
with self.app.run(detach=True):
|
||||
modal_fn.remote(
|
||||
config_yaml,
|
||||
accelerate=accelerate,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
)
|
||||
|
||||
def lm_eval(self, config_yaml: str):
|
||||
modal_fn = self.get_train_env()(_lm_eval)
|
||||
with modal.enable_output():
|
||||
with self.app.run(detach=True):
|
||||
if self.config.get("spawn", False):
|
||||
modal_fn_exec = modal_fn.spawn
|
||||
else:
|
||||
modal_fn_exec = modal_fn.remote
|
||||
modal_fn_exec(
|
||||
config_yaml,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
)
|
||||
|
||||
|
||||
def _preprocess(config_yaml: str, volumes=None):
|
||||
Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/artifacts/axolotl"
|
||||
run_cmd(
|
||||
"axolotl preprocess /workspace/artifacts/axolotl/config.yaml --dataset-processes=8",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
|
||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/artifacts/axolotl"
|
||||
if accelerate:
|
||||
accelerate_args = "--accelerate"
|
||||
else:
|
||||
accelerate_args = "--no-accelerate"
|
||||
run_cmd(
|
||||
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
|
||||
def _lm_eval(config_yaml: str, volumes=None):
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/artifacts/axolotl"
|
||||
run_cmd(
|
||||
"axolotl lm-eval /workspace/artifacts/axolotl/config.yaml",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
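
To make the expected shape of the cloud config concrete, here is a hedged sketch of the keys `ModalCloud` actually reads above; every value is illustrative, not a shipped default:

from axolotl.utils.dict import DictDefault

# Illustrative values only; the keys mirror what ModalCloud reads above.
cloud_cfg = DictDefault(
    {
        "gpu": "h100",           # one of: l40s, a100, a100-40gb, a100-80gb, a10, a10g, h100, t4, l4
        "gpu_count": 2,
        "memory": 256,           # GiB; memory_preprocess overrides this for preprocessing
        "timeout": 60 * 60 * 8,  # seconds; timeout_preprocess is tracked separately
        "docker_tag": "main-py3.11-cu124-2.5.1",
        "branch": "my-feature-branch",  # checked out inside the image, with cache busting
        "env": ["WANDB_PROJECT", {"HF_HOME": "/workspace/data/hf"}],
        "secrets": ["HF_TOKEN", "WANDB_API_KEY"],  # pulled from the local environment
        "volumes": [{"name": "axolotl-data", "mount": "/workspace/data"}],
    }
)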
135 src/axolotl/cli/convert_linear_attention.py Normal file
@@ -0,0 +1,135 @@
"""CLI to convert a model's attention to linear attention and distill it."""

import logging
import os
from pathlib import Path
from typing import Union

import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser

from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import PluginManager
from axolotl.integrations.lolcats.linear_llama.configuration_linear_llama import (
    LinearLlamaConfig,
)
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
    LinearLlamaForCausalLM,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
from axolotl.utils.trainer import setup_trainer

LOG = logging.getLogger(__name__)


def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
    """
    Convert attention to linear attention and perform attention transfer via distillation.
    """
    print_axolotl_text_art()
    check_accelerate_default_config()
    check_user_token()

    # ensure quantization and peft are turned off (due to how we need to re-apply peft later)
    cfg.load_in_8bit = False
    cfg.load_in_4bit = False
    cfg.adapter = None

    # load model
    model, tokenizer = load_model_and_tokenizer(cfg=cfg)

    # freeze model
    for p in model.parameters():
        p.requires_grad = False

    # convert to linear llama
    linear_llama_config = LinearLlamaConfig.from_llama(
        model.config, cfg.attention_config
    )
    model = LinearLlamaForCausalLM.from_llama(
        model, config=linear_llama_config, train_attention=True
    )

    # set save_path, save tokenizer and model config.
    save_path = str(os.path.join(cfg.output_dir, "distilled"))
    tokenizer.save_pretrained(save_path)
    if hasattr(model, "config"):
        model.config.save_pretrained(save_path)

    # Get datasets
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
    train_dataset = dataset_meta.train_dataset
    eval_dataset = dataset_meta.eval_dataset
    total_num_steps = dataset_meta.total_num_steps

    # toggle attention to be trainable
    model.toggle_attention(train=True)

    # Setup trainer
    trainer = setup_trainer(
        cfg=cfg,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        model=(model, None, None),
        tokenizer=tokenizer,
        processor=None,
        total_num_steps=total_num_steps,
    )

    # train
    trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)

    # drop base_attention + remove training attn
    model.toggle_attention(train=False)
    model.remove_base_attention()

    # NOTE: If in peft mode, consider whether to auto-merge

    # save model
    safe_serialization = cfg.save_safetensors is True
    # NOTE: may need to consider other ways of saving due to multi-gpu etc
    model.save_pretrained(save_path, safe_serialization=safe_serialization)

    # cleanup
    plugin_manager = PluginManager.get_instance()

    del model
    del tokenizer

    plugin_manager.post_train_unload(cfg)


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
    """
    Parses `axolotl` config, CLI args, and calls `do_linearize`.

    Args:
        config: Path to `axolotl` config YAML file.
        kwargs: Additional keyword arguments to override config file values.
    """
    # load cfg, force linearize and add plugin to linearize
    parsed_cfg = load_cfg(
        config,
        linearize=True,
        plugins=["axolotl.integrations.lolcats.LinearizePlugin"],
        **kwargs,
    )

    parser = HfArgumentParser(TrainerCliArgs)
    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )

    do_linearize(parsed_cfg, parsed_cli_args)


if __name__ == "__main__":
    load_dotenv()
    fire.Fire(do_cli)
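
Assuming a config that carries the `attention_config` block consumed by `LinearLlamaConfig.from_llama` above, the module can also be driven programmatically; the path and override below are placeholders:

# Hypothetical invocation; "linearize.yaml" is a placeholder config path.
from axolotl.cli.convert_linear_attention import do_cli

do_cli(config="linearize.yaml", output_dir="./outputs/linear-llama")
# Per do_linearize above, the distilled model lands in <output_dir>/distilled.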
src/axolotl/cli/main.py
@@ -1,10 +1,17 @@
 """Click CLI definitions for various axolotl commands."""
+# pylint: disable=redefined-outer-name
+
+import logging
+import random
 import subprocess  # nosec B404
+import tempfile
+from copy import deepcopy
+from itertools import product
 from pathlib import Path
 from typing import Optional

 import click
+import yaml

 import axolotl
 from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
@@ -15,10 +22,81 @@ from axolotl.cli.utils import (
     fetch_from_github,
     filter_none_kwargs,
 )
 from axolotl.integrations.lm_eval.cli import lm_eval
 from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig


+def generate_sweep_configs(base_config, sweeps_config):
+    """
+    Recursively generates all possible configurations by applying sweeps to the base config.
+
+    Args:
+        base_config (dict): The original configuration dictionary
+        sweeps_config (dict): Dictionary where keys are parameters and values are either:
+            - lists of values to sweep independently
+            - or for paired values, a list of dicts under the '_' key
+
+    Returns:
+        list: List of all possible configuration dictionaries
+
+    Example:
+        sweeps_config = {
+            'learning_rate': [0.1, 0.01],
+            '_': [
+                {'load_in_8bit': True, 'adapter': 'lora'},
+                {'load_in_4bit': True, 'adapter': 'qlora'}
+            ]
+        }
+    """
+    # Separate paired values from regular sweeps
+    paired_values = sweeps_config.get("_", [])
+    regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
+
+    # Process regular sweeps
+    param_names = list(regular_sweeps.keys())
+    param_values = list(regular_sweeps.values())
+
+    # Generate combinations for regular sweeps
+    regular_combinations = list(product(*param_values)) if param_values else [()]
+
+    # Combine regular sweeps with paired values
+    all_combinations = []
+    for reg_combo in regular_combinations:
+        if paired_values:
+            for paired_set in paired_values:
+                # Combine regular parameters with paired parameters
+                new_config = {}
+                full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
+                for param_name, param_value in full_combo.items():
+                    new_config[param_name] = param_value
+                all_combinations.append(new_config)
+        else:
+            # If no paired values, just use regular combinations
+            new_config = {}
+            for param_name, param_value in zip(param_names, reg_combo):
+                new_config[param_name] = param_value
+            all_combinations.append(new_config)
+
+    # randomize the order of trials
+    random.seed(42)
+    random.shuffle(all_combinations)
+
+    # Generate a new config for each combination
+    result_configs = []
+    for combination in all_combinations:
+        new_config = deepcopy(base_config)
+        for param_name, param_value in combination.items():
+            new_config[param_name] = param_value
+        result_configs.append(new_config)
+
+    return result_configs
+
+
 @click.group()
 @click.version_option(version=axolotl.__version__, prog_name="axolotl")
 def cli():
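
To pin down the semantics, a small worked example of the expansion (the `base` values are illustrative): independent sweep values cross-multiply, entries under `_` are applied as indivisible pairs, and the final order is shuffled with the fixed seed above.

base = {"micro_batch_size": 1}  # illustrative base config
sweeps = {
    "learning_rate": [0.1, 0.01],
    "_": [
        {"load_in_8bit": True, "adapter": "lora"},
        {"load_in_4bit": True, "adapter": "qlora"},
    ],
}
# 2 learning rates x 2 paired sets -> 4 configs, each a deepcopy of `base`, e.g.
# {"micro_batch_size": 1, "learning_rate": 0.1, "load_in_8bit": True, "adapter": "lora"}
assert len(generate_sweep_configs(base, sweeps)) == 4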
@@ -27,21 +105,28 @@ def cli():

 @cli.command()
 @click.argument("config", type=click.Path(exists=True, path_type=str))
+@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
 @add_options_from_dataclass(PreprocessCliArgs)
 @add_options_from_config(AxolotlInputConfig)
 @filter_none_kwargs
-def preprocess(config: str, **kwargs) -> None:
+def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
     """
     Preprocess datasets before training.

     Args:
         config: Path to `axolotl` config YAML file.
+        cloud: Path to a cloud accelerator configuration file.
         kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
             config options.
     """
-    from axolotl.cli.preprocess import do_cli
+    if cloud:
+        from axolotl.cli.cloud import do_cli_preprocess

-    do_cli(config=config, **kwargs)
+        do_cli_preprocess(cloud_config=cloud, config=config)
+    else:
+        from axolotl.cli.preprocess import do_cli
+
+        do_cli(config=config, **kwargs)


 @cli.command()
@@ -51,32 +136,99 @@ def preprocess(config: str, **kwargs) -> None:
     default=True,
     help="Use accelerate launch for multi-GPU training",
 )
+@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
+@click.option(
+    "--sweep",
+    type=click.Path(exists=True, path_type=str),
+    help="YAML config for sweeping hyperparameters",
+)
 @add_options_from_dataclass(TrainerCliArgs)
 @add_options_from_config(AxolotlInputConfig)
 @filter_none_kwargs
-def train(config: str, accelerate: bool, **kwargs) -> None:
+def train(
+    config: str,
+    accelerate: bool,
+    cloud: Optional[str] = None,
+    sweep: Optional[str] = None,
+    **kwargs,
+) -> None:
     """
     Train or fine-tune a model.

     Args:
         config: Path to `axolotl` config YAML file.
         accelerate: Whether to use `accelerate` launcher.
+        cloud: Path to a cloud accelerator configuration file.
+        sweep: Path to YAML config for sweeping hyperparameters.
         kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
             config options.
     """
     # Enable expandable segments for cuda allocation to improve VRAM usage
     set_pytorch_cuda_alloc_conf()
+    from axolotl.cli.cloud import do_cli_train

     if "use_ray" in kwargs and kwargs["use_ray"]:
         accelerate = False
+    if sweep:
+        # load the sweep configuration yaml file
+        with open(sweep, "r", encoding="utf-8") as fin:
+            sweep_config: dict[str, list] = yaml.safe_load(fin)
+        with open(config, "r", encoding="utf-8") as fin:
+            base_config: dict[str, list] = yaml.safe_load(fin)
+
+        # generate all possible configurations
+        permutations = generate_sweep_configs(base_config, sweep_config)
+
+        def iter_configs():
+            for perm in permutations:
+                # open temp directory for temporary configurations
+                with tempfile.TemporaryDirectory() as temp_dir:
+                    with open(
+                        Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
+                    ) as fout:
+                        yaml.dump(perm, fout)
+                    yield str(Path(temp_dir) / "config.yaml")
+
-    if accelerate:
-        base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
-        if config:
-            base_cmd.append(config)
-        cmd = build_command(base_cmd, kwargs)
-        subprocess.run(cmd, check=True)  # nosec B603
     else:
-        from axolotl.cli.train import do_cli
-
-        do_cli(config=config, **kwargs)
+
+        def iter_configs():
+            yield config
+
+    for cfg_file in iter_configs():
+        # handle errors from subprocess so we can continue rest of sweeps
+        try:
+            if accelerate:
+                if cloud:
+                    do_cli_train(cloud_config=cloud, config=config, accelerate=True)
+                else:
+                    accelerate_args = []
+                    if "main_process_port" in kwargs:
+                        main_process_port = kwargs.pop("main_process_port", None)
+                        accelerate_args.append("--main_process_port")
+                        accelerate_args.append(str(main_process_port))
+                    if "num_processes" in kwargs:
+                        num_processes = kwargs.pop("num_processes", None)
+                        accelerate_args.append("--num-processes")
+                        accelerate_args.append(str(num_processes))
+
+                    base_cmd = ["accelerate", "launch"]
+                    base_cmd.extend(accelerate_args)
+                    base_cmd.extend(["-m", "axolotl.cli.train"])
+                    if cfg_file:
+                        base_cmd.append(cfg_file)
+                    cmd = build_command(base_cmd, kwargs)
+                    subprocess.run(cmd, check=True)  # nosec B603
+            else:
+                if cloud:
+                    do_cli_train(cloud_config=cloud, config=config, accelerate=False)
+                else:
+                    from axolotl.cli.train import do_cli
+
+                    do_cli(config=cfg_file, **kwargs)
+        except subprocess.CalledProcessError as exc:
+            logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
+            if not sweep:
+                raise exc


 @cli.command()
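
One behavioral detail worth keeping in mind from the hunk above: with `--sweep`, every generated permutation is written to a temporary `config.yaml` and trained in turn, and a failed run is only logged so the remaining sweeps continue; without `--sweep`, the `subprocess.CalledProcessError` is re-raised and the command fails.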
@@ -195,7 +347,6 @@ def merge_lora(config: str, **kwargs) -> None:

     Args:
         config: Path to `axolotl` config YAML file.
-        accelerate: Whether to use `accelerate` launcher.
         kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
             config options.
     """
@@ -222,6 +373,9 @@ def fetch(directory: str, dest: Optional[str]) -> None:
     fetch_from_github(f"{directory}/", dest)


+cli.add_command(lm_eval)
+
+
 def main():
     cli()

src/axolotl/cli/preprocess.py
@@ -75,7 +75,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
     )


-def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
+def do_cli(
+    config: Union[Path, str] = Path("examples/"),
+    **kwargs,
+) -> None:
     """
     Parses `axolotl` config, CLI args, and calls `do_preprocess`.
src/axolotl/cli/train.py
@@ -5,6 +5,7 @@ from pathlib import Path
 from typing import Union

 import fire
+from accelerate import Accelerator
 from dotenv import load_dotenv
 from transformers.hf_argparser import HfArgumentParser

@@ -15,6 +16,7 @@ from axolotl.cli.config import load_cfg
 from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.integrations.base import PluginManager
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, resolve_dtype
+from axolotl.utils.dict import DictDefault

 LOG = logging.getLogger(__name__)
@@ -63,7 +65,47 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
         return_remaining_strings=True
     )

-    do_train(parsed_cfg, parsed_cli_args)
+    if parsed_cfg.use_ray:
+        from ray.train import RunConfig, ScalingConfig
+        from ray.train.torch import TorchTrainer
+
+        train_loop_config = {"cfg": parsed_cfg.to_dict(), "cli_args": parsed_cli_args}
+        trainer = TorchTrainer(
+            ray_train_func,
+            train_loop_config=train_loop_config,
+            scaling_config=ScalingConfig(
+                num_workers=parsed_cfg.ray_num_workers,
+                resources_per_worker=parsed_cfg.resources_per_worker.to_dict(),
+                use_gpu=True,
+            ),
+            run_config=RunConfig(
+                name=parsed_cfg.ray_run_name,
+                storage_path=Path(parsed_cfg.output_dir).absolute().as_posix(),
+            ),
+        )
+        return trainer.fit()
+    return do_train(parsed_cfg, parsed_cli_args)
+
+
+def ray_train_func(kwargs: dict):
+    # cast `cfg` back to DictDefault (Ray's deepcopy has issues with DictDefault, so it travels as a plain dict)
+    # also renormalize the config now that TorchTrainer has spawned distributed workers
+    cfg = DictDefault(kwargs["cfg"])
+    normalize_config(cfg)
+
+    # now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype
+    resolve_dtype(cfg)
+
+    # Ray serialization drops the frozen attribute - HF expects a dict, not a DictDefault
+    if cfg.deepspeed:
+        cfg.deepspeed = cfg.deepspeed.to_dict()
+
+    # initialize accelerator before model instantiation
+    Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
+
+    kwargs["cfg"] = cfg
+
+    do_train(**kwargs)
+
+
 if __name__ == "__main__":
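
A sketch of the Ray-related knobs the branch above reads from the parsed config; the field names come from the code, the values are purely illustrative:

# Illustrative values; the keys mirror what do_cli and ray_train_func read above.
ray_overrides = {
    "use_ray": True,
    "ray_num_workers": 4,                 # one Ray worker per GPU
    "resources_per_worker": {"GPU": 1},
    "ray_run_name": "axolotl-ray-demo",
}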
src/axolotl/common/datasets.py
@@ -63,11 +63,17 @@ def load_datasets(
     """
     tokenizer = load_tokenizer(cfg)
     processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
+    preprocess_iterable = (
+        hasattr(cli_args, "iterable")
+        and cli_args.iterable is not None
+        and cli_args.iterable
+    )

     train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
         cfg,
         tokenizer,
         processor=processor,
+        preprocess_iterable=preprocess_iterable,
     )

     if (

File diff suppressed because it is too large
988 src/axolotl/core/trainers/base.py Normal file
@@ -0,0 +1,988 @@
"""
module for customized trainers
"""

from __future__ import annotations

# pylint: disable=too-many-lines
import gc
import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Any, Dict, Literal, Optional, Union

import torch
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import (
    CPOTrainer,
    DPOTrainer,
    KTOTrainer,
    ORPOTrainer,
    PRMTrainer,
    RewardTrainer,
)
from trl.trainer.utils import pad_to_length

from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
    get_cosine_schedule_with_min_lr,
    get_cosine_schedule_with_quadratic_warmup,
    get_cosine_schedule_with_warmup_decay_constant,
)

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

LOG = logging.getLogger("axolotl.core.trainer_builder")


def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
    if isinstance(tag_names, str):
        tag_names = [tag_names]

    if kwargs is not None:
        if "tags" not in kwargs:
            kwargs["tags"] = tag_names
        elif "tags" in kwargs and isinstance(kwargs["tags"], list):
            kwargs["tags"].extend(tag_names)
        elif "tags" in kwargs and isinstance(kwargs["tags"], str):
            tag_names.append(kwargs["tags"])
            kwargs["tags"] = tag_names

    return kwargs


def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
    if isinstance(dataset_tags, str):
        dataset_tags = [dataset_tags]

    if (dataset_tags is not None) and (kwargs is not None):
        if "dataset_tags" not in kwargs:
            kwargs["dataset_tags"] = dataset_tags
        elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
            kwargs["dataset_tags"].extend(dataset_tags)
        elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
            dataset_tags.append(kwargs["dataset_tags"])
            kwargs["dataset_tags"] = dataset_tags

    return kwargs


class SchedulerMixin(Trainer):
    """
    Mixin class for scheduler setup in CausalTrainer.
    """

    args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]

    def create_scheduler(
        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
    ):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up
        either before this method is called or passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
            optimizer (torch.optim.Optimizer): The training optimizer
        """
        use_cosine_quadratic = (
            self.args.lr_scheduler_type == "cosine"
            and self.args.lr_quadratic_warmup is True
        )

        use_cosine_min_lr = (
            self.args.lr_scheduler_type == "cosine"
            and self.args.cosine_min_lr_ratio is not None
        )

        # fmt: off
        if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
            # fmt: on
            if self.args.alternate_lr_scheduler_type == "one_cycle":
                num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
                pct_start = num_warmup_steps / num_training_steps
                extra_lr_kwargs = {}
                if "pct_start" not in self.args.lr_scheduler_kwargs:
                    extra_lr_kwargs["pct_start"] = pct_start
                if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
                    extra_lr_kwargs["anneal_strategy"] = "cos"

                self.lr_scheduler = OneCycleLR(
                    optimizer,
                    max_lr=self.args.learning_rate,
                    total_steps=num_training_steps,
                    **extra_lr_kwargs,
                    **self.args.lr_scheduler_kwargs,
                )
            elif use_cosine_quadratic:
                if use_cosine_min_lr:
                    LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")

                self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                )
            elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
                assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
                assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
                self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                    min_lr_ratio=self.args.cosine_min_lr_ratio,
                    constant_lr_ratio=self.args.cosine_constant_lr_ratio,
                )
            elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
                assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
                self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                    min_lr_ratio=self.args.cosine_min_lr_ratio,
                )
            else:
                return super().create_scheduler(num_training_steps, optimizer=optimizer)
        else:
            if use_cosine_quadratic:
                LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")

            if use_cosine_min_lr:
                LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")

        return self.lr_scheduler
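

# How create_scheduler picks a branch (illustrative settings, not defaults):
#   alternate_lr_scheduler_type: "one_cycle"                 -> torch OneCycleLR
#   lr_scheduler_type: "cosine" + lr_quadratic_warmup: true  -> quadratic-warmup cosine
#   lr_scheduler_type: "cosine" + cosine_min_lr_ratio: 0.1   -> cosine decaying to 10% of peak lr
# Anything else falls through to transformers' stock create_scheduler.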


class AxolotlTrainer(SchedulerMixin, Trainer):
    """
    Extend the base Trainer for axolotl helpers
    """

    args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
    tag_names = ["axolotl"]

    def __init__(
        self,
        *_args,
        bench_data_collator=None,
        eval_data_collator=None,
        dataset_tags=None,
        **kwargs,
    ):
        self.bench_data_collator = bench_data_collator
        self.eval_data_collator = eval_data_collator
        self.dataset_tags = dataset_tags
        self._signature_columns = None  # workaround for pylint
        super().__init__(*_args, **kwargs)
        self.train_data_collator = self.data_collator
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        if self.args.orpo_alpha:
            self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

    def _wrap_model(self, model, training=True, dataloader=None):
        if self.args.torch_compile:
            torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
                256
            )
            model = torch.compile(
                model,
                backend=self.args.torch_compile_backend,
                mode=self.args.torch_compile_mode,
            )
        return super()._wrap_model(model, training=training, dataloader=dataloader)

    def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
        decay_parameters = self.get_decay_parameter_names(opt_model)
        params = {
            "to_weight_decay": {},  # parameters that get weight decay applied
            "embeddings": {},  # lm_head, embed_tokens,
            "no_weight_decay": {},
        }
        lr_groups_lookup = {}
        lr_groups_learning_rates = {}
        if self.args.lr_groups:
            for lr_group in self.args.lr_groups:
                group_name = lr_group["name"]
                group_modules = lr_group["modules"]
                for module in group_modules:
                    lr_groups_lookup[module] = group_name
                lr_groups_learning_rates[group_name] = lr_group["lr"]
                params[f"to_weight_decay_{group_name}"] = {}

        for name, param in opt_model.named_parameters():
            if not param.requires_grad:
                continue
            if name.endswith("modules_to_save.default.weight") or any(
                embed_name in name for embed_name in ["embed_tokens", "lm_head"]
            ):
                params["embeddings"][name] = param
            elif name in decay_parameters:
                lr_group_modules = [
                    group_modules
                    for group_modules in lr_groups_lookup
                    if group_modules in name
                ]
                if lr_groups_lookup and any(lr_group_modules):
                    lr_group_module = lr_group_modules[0]
                    group_name = lr_groups_lookup[lr_group_module]
                    params[f"to_weight_decay_{group_name}"][name] = param
                else:
                    params["to_weight_decay"][name] = param
            else:
                params["no_weight_decay"][name] = param
        optimizer_grouped_parameters = []
        if params["to_weight_decay"]:
            optimizer_grouped_parameters.append(
                {
                    "params": list(params["to_weight_decay"].values()),
                    "weight_decay": self.args.weight_decay,
                    "lr": optimizer_kwargs["lr"],
                }
            )
        if params["embeddings"]:
            lr = optimizer_kwargs["lr"]  # pylint: disable=invalid-name
            if self.args.embedding_lr_scale:
                lr *= self.args.embedding_lr_scale  # pylint: disable=invalid-name
            elif self.args.embedding_lr:
                lr = self.args.embedding_lr  # pylint: disable=invalid-name
            optimizer_grouped_parameters.append(
                {
                    "params": list(params["embeddings"].values()),
                    "weight_decay": 0.0,
                    "lr": lr,
                }
            )
        if params["no_weight_decay"]:
            optimizer_grouped_parameters.append(
                {
                    "params": list(params["no_weight_decay"].values()),
                    "weight_decay": 0.0,
                    "lr": optimizer_kwargs["lr"],
                }
            )
        for group_name, group_lr in lr_groups_learning_rates.items():
            if params[f"to_weight_decay_{group_name}"]:
                optimizer_grouped_parameters.append(
                    {
                        "params": list(
                            params[f"to_weight_decay_{group_name}"].values()
                        ),
                        "weight_decay": self.args.weight_decay,
                        "lr": group_lr,
                    }
                )

        return optimizer_grouped_parameters
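
    # Example (illustrative config, not a shipped default): with
    #     args.lr_groups = [{"name": "attn_slow", "modules": ["q_proj", "k_proj"], "lr": 1e-5}]
    # every weight-decayed parameter whose name contains "q_proj" or "k_proj" is
    # placed in its own "attn_slow" group at lr=1e-5, while other groups keep
    # optimizer_kwargs["lr"]; note embed_tokens/lm_head are routed to the
    # embeddings group before lr_groups are considered.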

    def create_optimizer(self):
        if (
            self.args.loraplus_lr_ratio is None
            and self.args.embedding_lr_scale is None
            and self.args.embedding_lr is None
            and self.args.lr_groups is None
            and self.args.alternate_optimizer
            not in [
                "optimi_adamw",
                "ao_adamw_8bit",
                "ao_adamw_4bit",
                "ao_adamw_fp8",
                "adopt_adamw",
            ]
        ):
            return super().create_optimizer()

        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if self.optimizer is None:  # pylint: disable=access-member-before-definition
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                self.args,
                opt_model,
            )
            optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
                opt_model, optimizer_kwargs
            )

            if self.args.loraplus_lr_ratio is not None:
                loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
                loraplus_lr_embedding = getattr(
                    self.args, "loraplus_lr_embedding", 1e-6
                )
                self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
                    opt_model,
                    optimizer_cls,
                    loraplus_lr_ratio=loraplus_lr_ratio,
                    loraplus_lr_embedding=loraplus_lr_embedding,
                    **optimizer_kwargs,
                )
            elif (
                self.args.embedding_lr_scale is not None
                or self.args.embedding_lr is not None
                or self.args.lr_groups is not None
            ):
                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
                    optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
                )
            elif self.args.alternate_optimizer == "optimi_adamw":
                from optimi import AdamW

                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
                    AdamW(
                        optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
                    )
                )
            elif self.args.alternate_optimizer == "ao_adamw_4bit":
                from torchao.prototype.low_bit_optim import AdamW4bit

                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
                    AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
                )
            elif self.args.alternate_optimizer == "ao_adamw_8bit":
                from torchao.prototype.low_bit_optim import AdamW8bit

                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
                    AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
                )
            elif self.args.alternate_optimizer == "ao_adamw_fp8":
                from torchao.prototype.low_bit_optim import AdamWFp8

                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
                    AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
                )
            elif self.args.alternate_optimizer == "adopt_adamw":
                from axolotl.utils.optimizers.adopt import ADOPT

                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
                    ADOPT(
                        optimizer_grouped_parameters,
                        decouple=True,
                        **optimizer_kwargs,
                    )
                )

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
                self.optimizer
            )

        return self.optimizer

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.args.sample_packing and not self.args.pretraining:
            if self.args.multipack_real_batches:
                batch_size = self.args.per_device_train_batch_size
                batch_max_len = self.args.max_seq_length
            else:
                batch_size = 1
                train_batch_size = (
                    self.state.train_batch_size or self.args.per_device_train_batch_size
                )
                batch_max_len = train_batch_size * self.args.max_seq_length

            if self.args.curriculum_sampling:
                sampler = SequentialSampler(self.train_dataset)
            else:
                sampler = RandomSampler(self.train_dataset)

            return MultipackBatchSampler(
                sampler,
                lengths=get_dataset_lengths(self.train_dataset),
                packing_efficiency_estimate=self.args.sample_packing_efficiency,
                batch_max_len=batch_max_len,
                batch_size=batch_size,
                group_size=self.args.sample_packing_group_size,
                bin_size=self.args.sample_packing_bin_size,
                drop_last=True,
            )
        if self.args.curriculum_sampling:
            return SequentialSampler(self.train_dataset)
        return super()._get_train_sampler()

    def _get_eval_sampler(
        self, eval_dataset: Dataset
    ) -> Optional[torch.utils.data.Sampler]:
        if self.args.sample_packing and self.args.eval_sample_packing is not False:
            if self.args.multipack_real_batches:
                batch_size = self.args.per_device_eval_batch_size
                batch_max_len = self.args.max_seq_length
            else:
                batch_size = 1
                batch_max_len = (
                    self.args.per_device_eval_batch_size * self.args.max_seq_length
                )
            return MultipackBatchSampler(
                SequentialSampler(eval_dataset),
                lengths=get_dataset_lengths(self.eval_dataset),
                packing_efficiency_estimate=self.args.sample_packing_efficiency,
                batch_max_len=batch_max_len,
                batch_size=batch_size,
                group_size=self.args.sample_packing_group_size,
                bin_size=self.args.sample_packing_bin_size,
                drop_last=True,
            )
        return super()._get_eval_sampler(eval_dataset)

    def get_train_dataloader(self) -> DataLoader:
        if self.args.sample_packing and not self.args.pretraining:
            train_dataset = self.train_dataset
            if "length" in train_dataset.features.keys():
                train_dataset = train_dataset.remove_columns(["length"])
            data_collator = self.data_collator
            dataloader_params = {
                "batch_size": self._train_batch_size,
                "collate_fn": data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
            }
            if self.args.dataloader_prefetch_factor:
                dataloader_params[
                    "prefetch_factor"
                ] = self.args.dataloader_prefetch_factor

            sampler = self._get_train_sampler()
            if isinstance(sampler, BatchSampler):
                dataloader_params["batch_sampler"] = sampler
                del dataloader_params["batch_size"]
            else:
                dataloader_params["sampler"] = sampler
                dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker

            self.accelerator.even_batches = False
            return self.accelerator.prepare_data_loader(
                DataLoader(train_dataset, **dataloader_params)
            )
        return super().get_train_dataloader()

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        if self.args.sample_packing and self.args.eval_sample_packing is False:
            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                self.eval_data_collator
            )
            if eval_dataset:
                eval_dataset = eval_dataset.remove_columns(["length"])
            dataloader = super().get_eval_dataloader(eval_dataset)
            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                self.train_data_collator
            )
            return dataloader

        if self.args.sample_packing and self.args.eval_sample_packing is not False:
            eval_dataset = (
                eval_dataset if eval_dataset is not None else self.eval_dataset
            )

            eval_sampler = self._get_eval_sampler(eval_dataset)
            eval_dataset = eval_dataset.remove_columns(["length"])
            data_collator = self.data_collator
            dataloader_params = {
                "batch_size": self.args.eval_batch_size,
                "collate_fn": data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
            }
            if self.args.dataloader_prefetch_factor:
                dataloader_params[
                    "prefetch_factor"
                ] = self.args.dataloader_prefetch_factor

            if isinstance(eval_sampler, BatchSampler):
                dataloader_params["batch_sampler"] = eval_sampler
                del dataloader_params["batch_size"]
            else:
                dataloader_params["sampler"] = eval_sampler
                dataloader_params["drop_last"] = self.args.dataloader_drop_last

            self.accelerator.even_batches = False
            return self.accelerator.prepare_data_loader(
                DataLoader(eval_dataset, **dataloader_params)
            )

        return super().get_eval_dataloader(eval_dataset)

    def _get_bench_sampler(
        self, bench_dataset: Dataset
    ) -> Optional[torch.utils.data.Sampler]:
        if self.args.world_size <= 1:
            return SequentialSampler(bench_dataset)
        return None

    def get_bench_dataloader(
        self,
        bench_dataset: Dataset,
    ) -> DataLoader:
        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": self.bench_data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }
        if self.args.dataloader_prefetch_factor:
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last

        return DataLoader(bench_dataset, **dataloader_params)
        # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        # use one's weighted cross entropy loss calc
        # if self.args.sample_packing:
        #     labels = inputs.pop("labels")
        #     outputs = model(**inputs)
        #     loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
        #     return (loss, outputs) if return_outputs else loss
        if self.args.orpo_alpha:
            return self.orpo_compute_loss(
                model,
                inputs,
                return_outputs=return_outputs,
                num_items_in_batch=num_items_in_batch,
            )
        return super().compute_loss(
            model,
            inputs,
            return_outputs=return_outputs,
            num_items_in_batch=num_items_in_batch,
        )

    @staticmethod
    def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
        concatenated_batch = {}

        max_length = max(
            inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1]
        )
        # Concatenate positive and negative inputs
        concatenated_batch["input_ids"] = pad_to_length(
            inputs["input_ids"], max_length, pad_token
        )
        concatenated_batch["rejected_input_ids"] = pad_to_length(
            inputs["rejected_input_ids"], max_length, pad_token
        )
        concatenated_batch["labels"] = pad_to_length(
            inputs["labels"], max_length, label_pad_token
        )
        concatenated_batch["rejected_labels"] = pad_to_length(
            inputs["rejected_labels"], max_length, label_pad_token
        )
        concatenated_batch["attention_mask"] = pad_to_length(
            inputs["attention_mask"], max_length, 0
        )
        concatenated_batch["rejected_attention_mask"] = pad_to_length(
            inputs["rejected_attention_mask"], max_length, 0
        )
        concatenated_batch["prompt_attention_mask"] = pad_to_length(
            inputs["prompt_attention_mask"], max_length, 0
        ).to(device=device)

        input_ids = torch.cat(
            [concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]],
            dim=0,
        ).to(device=device)
        attention_mask = torch.cat(
            [
                concatenated_batch["attention_mask"],
                concatenated_batch["rejected_attention_mask"],
            ],
            dim=0,
        ).to(device=device)
        labels = torch.cat(
            [concatenated_batch["labels"], concatenated_batch["rejected_labels"]], dim=0
        ).to(device=device)

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "prompt_attention_mask": concatenated_batch["prompt_attention_mask"],
        }

    def orpo_compute_custom_loss(self, logits, labels):
        logits = logits.contiguous()
        loss = 0.0

        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten the tokens
            loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(
                dim=-1
            )

        return loss

    def orpo_compute_logps(
        self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits
    ):
        # Get the shape of chosen_attention_mask[:, :-1]
        chosen_shape = chosen_attention_mask[:, :-1].shape

        # Calculate the padding size
        pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1)

        # Pad prompt_attention_mask with zeros to match the desired shape
        prompt_attention_mask_padded = torch.nn.functional.pad(
            prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
        )

        # Perform the subtraction operation
        mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded

        per_token_logps = torch.gather(
            logits[:, :-1, :].log_softmax(-1),
            dim=2,
            index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
        ).squeeze(2)
        return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)

    def orpo_compute_loss(
        self,
        model,
        inputs,
        return_outputs=False,
        num_items_in_batch=None,  # pylint: disable=unused-argument
    ):
        concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
            inputs,
            label_pad_token=-100,
            pad_token=self.tokenizer.pad_token_id,
            device=self.accelerator.device,
        )

        # Perform a single forward pass
        outputs = model(
            **{
                "input_ids": concat_inputs["input_ids"],
                "attention_mask": concat_inputs["attention_mask"],
                "labels": concat_inputs["labels"],
            },
            output_hidden_states=True,
        )

        # Split the outputs for positive and negative examples
        outputs_pos, outputs_neg = outputs.logits.chunk(2)

        # Calculate NLL loss
        pos_loss = self.orpo_compute_custom_loss(
            logits=outputs_pos, labels=concat_inputs["input_ids"].chunk(2)[0]
        )

        # Calculate Log Probability
        pos_prob = self.orpo_compute_logps(
            prompt_attention_mask=concat_inputs["prompt_attention_mask"],
            chosen_inputs=concat_inputs["input_ids"].chunk(2)[0],
            chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[0],
            logits=outputs_pos,
        )
        neg_prob = self.orpo_compute_logps(
            prompt_attention_mask=concat_inputs["prompt_attention_mask"],
            chosen_inputs=concat_inputs["input_ids"].chunk(2)[1],
            chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[1],
            logits=outputs_neg,
        )

        # Calculate log odds
        log_odds = (pos_prob - neg_prob) - (
            torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
        )
        sig_ratio = torch.nn.functional.sigmoid(log_odds)
        ratio = torch.log(sig_ratio)

        # Calculate the Final Loss
        loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(
            dtype=torch.bfloat16
        )

        metrics = {}
        metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item()
        metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item()
        metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item()
        metrics["log_odds"] = torch.mean(log_odds).cpu().item()
        self.store_metrics(metrics, train_eval="train")

        return (loss, outputs_pos) if return_outputs else loss
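
    # Note on the math above: pos_prob/neg_prob are length-normalized log-probs, so
    # exp(pos_prob) is the geometric-mean token probability P+ of the chosen answer.
    # log_odds therefore equals log[(P+/(1-P+)) / (P-/(1-P-))], and the final loss is
    # NLL(chosen) - orpo_alpha * log(sigmoid(log_odds)), i.e. the ORPO odds-ratio objective.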

    @wraps(Trainer.push_to_hub)
    def push_to_hub(self, *args, **kwargs) -> str:
        """
        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        kwargs = _sanitize_kwargs_for_ds_tagging(
            dataset_tags=self.dataset_tags, kwargs=kwargs
        )
        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)

        return super().push_to_hub(*args, **kwargs)

    @wraps(Trainer.create_accelerator_and_postprocess)
    def create_accelerator_and_postprocess(self):
        res = super().create_accelerator_and_postprocess()

        if self.is_fsdp_enabled:
            if (
                "limit_all_gathers" in self.args.fsdp_config
                and self.args.fsdp_config["limit_all_gathers"]
            ):
                self.accelerator.state.fsdp_plugin.limit_all_gathers = True

        return res

    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
            start_time (`Optional[float]`):
                The start of training.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]

        return super().log(logs, start_time)

    def store_metrics(
        self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
    ) -> None:
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def _save_checkpoint(self, model, trial, **kwargs):
        # make sure the checkpoint dir exists, since trainer is flakey
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        os.makedirs(output_dir, exist_ok=True)
        return super()._save_checkpoint(model, trial, **kwargs)


class AxolotlMambaTrainer(AxolotlTrainer):
    """
    Mamba specific trainer to handle loss calculation
    """

    tag_names = ["axolotl", "mamba"]

    def compute_loss(
        self,
        model,
        inputs,
        return_outputs=False,  # pylint: disable=unused-argument
        num_items_in_batch=None,  # pylint: disable=unused-argument
    ):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
        )

        return lm_loss


class ReLoRATrainer(AxolotlTrainer):
    """
    Trainer subclass that uses the OneCycleLR scheduler
    """

    tag_names = ["axolotl", "relora"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr_scheduler = None

    def create_scheduler(
        self,
        num_training_steps: int,
        optimizer: Optional[torch.optim.Optimizer] = None,
    ):
        optimizer = self.optimizer if optimizer is None else optimizer
        lr_scheduler = super().create_scheduler(num_training_steps, optimizer)

        if self.args.relora_steps:
            warmup_steps = (
                self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
            )
            anneal_steps = (
                self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
            )
            self.lr_scheduler = ReLoRAScheduler(
                optimizer,
                lr_scheduler,
                self.args.relora_steps,
                anneal_steps,
                warmup_steps,
            )
        else:
            self.lr_scheduler = lr_scheduler

        return self.lr_scheduler


class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
    """
    Extend the base DPOTrainer for axolotl helpers
    """

    tag_names = ["axolotl", "dpo"]

    def __init__(self, *args, dataset_tags=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_tags = dataset_tags
        self.optimizer = None
        self.model_accepts_loss_kwargs = False

    def create_optimizer(self):
        if self.args.loraplus_lr_ratio is None:
            return super().create_optimizer()

        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if self.optimizer is None:  # pylint: disable=access-member-before-definition
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                self.args,
                opt_model,
            )

            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
            if loraplus_lr_ratio:
                print("Using lora+")
            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
                opt_model,
                optimizer_cls,
                loraplus_lr_ratio=loraplus_lr_ratio,
                loraplus_lr_embedding=loraplus_lr_embedding,
                **optimizer_kwargs,
            )

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
                self.optimizer
            )

        return self.optimizer

    @wraps(DPOTrainer.push_to_hub)
    def push_to_hub(self, *args, **kwargs) -> str:
        """
        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        kwargs = _sanitize_kwargs_for_ds_tagging(
            dataset_tags=self.dataset_tags, kwargs=kwargs
        )
        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)

        return super().push_to_hub(*args, **kwargs)

    @staticmethod
    def tokenize_row(
        features,
        processing_class,
        max_prompt_length,
        max_completion_length,
        add_special_tokens,
    ) -> Dict:
        res = DPOTrainer.tokenize_row(
            features,
            processing_class,
            max_prompt_length,
            max_completion_length,
            add_special_tokens,
        )
        # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
        if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
            for key in res.keys():
                res[key] = res[key][1:]

        if processing_class.bos_token and processing_class.bos_token_id is not None:
            # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
            if res["chosen_input_ids"][0] == processing_class.bos_token_id:
                res["chosen_input_ids"] = res["chosen_input_ids"][1:]
                res["chosen_labels"] = res["chosen_labels"][1:]
                res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
            if res["rejected_input_ids"][0] == processing_class.bos_token_id:
                res["rejected_input_ids"] = res["rejected_input_ids"][1:]
                res["rejected_labels"] = res["rejected_labels"][1:]
|
||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||
|
||||
return res
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
num_items_in_batch=None,
|
||||
) -> torch.Tensor:
|
||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
"""
|
||||
Extend the base CPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "cpo"]
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
|
||||
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
||||
"""
|
||||
Extend the base trl.PRMTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "prm"]
|
||||
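As a hedged sketch of how the scheduler wrapping in ReLoRATrainer.create_scheduler above composes (the ReLoRAScheduler import path and the inner LambdaLR are illustrative assumptions, not confirmed API):

import torch
from axolotl.monkeypatch.relora import ReLoRAScheduler  # assumed import path

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)
inner = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
# argument order mirrors the call in create_scheduler above:
# (optimizer, inner scheduler, relora_steps, anneal_steps, warmup_steps)
scheduler = ReLoRAScheduler(optimizer, inner, 200, 50, 10)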
264
src/axolotl/core/training_args.py
Normal file
@@ -0,0 +1,264 @@
"""
extra axolotl specific training args
"""
from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig


@dataclass
class AxolotlTrainingMixins:
    """
    Mixin class for the Axolotl training args.
    """

    # pylint: disable=duplicate-code
    model_type: Optional[str] = field(
        default=None, metadata={"help": "HF model configuration model_type."}
    )
    lr_quadratic_warmup: bool = field(
        default=False,
        metadata={"help": "Use quadratic warmup for cosine scheduling."},
    )
    pretraining: bool = field(
        default=False,
        metadata={
            "help": "Indicates to trainer whether we are doing continued pretraining."
        },
    )
    sample_packing: bool = field(
        default=False,
        metadata={"help": "Use sample packing for efficient training."},
    )
    multipack_real_batches: bool = field(
        default=False,
        metadata={"help": "Use real batches for efficient training."},
    )
    eval_sample_packing: Optional[bool] = field(
        default=None,
        metadata={"help": "Use sample packing for efficient evals."},
    )
    sample_packing_efficiency: float = field(
        default=1.0,
        metadata={"help": "Sample packing efficiency for calculating batch length."},
    )
    sample_packing_bin_size: int = field(
        default=200,
        metadata={
            "help": "The max number of samples that a packed sample can contain after packing. Increase for better packing."
        },
    )
    sample_packing_group_size: int = field(
        default=100000,
        metadata={
            "help": "The number of samples to group together for packing. Increase for better packing."
        },
    )
    max_seq_length: int = field(
        default=2048,
        metadata={"help": "The maximum sequence length the model can handle"},
    )
    relora_steps: Optional[int] = field(
        default=None,
        metadata={"help": "how often to reset for ReLoRA"},
    )
    relora_warmup_steps: Optional[int] = field(
        default=None,
        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
    )
    relora_anneal_steps: Optional[int] = field(
        default=None,
        metadata={"help": "how many steps to anneal the learning rate for ReLoRA"},
    )
    relora_prune_ratio: Optional[float] = field(
        default=0.9,
        metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
    )
    bench_split: Optional[str] = field(
        default="eval", metadata={"help": "The benchmark split to run on"}
    )
    bench_dataset: Optional[str] = field(
        default="pharaouk/dharma-1/dharma_1_mini.json",
        metadata={
            "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
        },
    )
    do_bench_eval: Optional[bool] = field(
        default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
    )
    do_causal_lm_eval: Optional[bool] = field(
        default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
    )
    max_bench_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
        },
    )
    bench_source_max_len: int = field(
        default=2048, metadata={"help": "Maximum source sequence length for bench."}
    )
    dataloader_prefetch_factor: Optional[int] = field(
        default=None,
        metadata={"help": "prefetch_factor argument to the dataloader"},
    )
    cosine_min_lr_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
    )
    cosine_constant_lr_ratio: Optional[float] = field(
        default=None,
        metadata={
            "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
        },
    )
    loraplus_lr_ratio: Optional[float] = field(
        default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
    )
    loraplus_lr_embedding: Optional[float] = field(
        default=1e-6,
        metadata={"help": "loraplus learning rate for lora embedding layers."},
    )
    embedding_lr_scale: Optional[float] = field(
        default=None,
        metadata={"help": "Scale the learning rate for the embedding layers."},
    )
    lr_groups: Optional[list[dict]] = field(
        default=None,
        metadata={"help": "Specify learning rate groups with different LRs."},
    )
    embedding_lr: Optional[float] = field(
        default=None,
        metadata={"help": "absolute learning rate for the embedding layers."},
    )
    qlora: bool = field(
        default=False,
        metadata={"help": "whether this is a qlora training"},
    )
    orpo_alpha: Optional[float] = field(
        default=None,
    )
    lisa_n_layers: Optional[int] = field(
        default=None,
        metadata={"help": "the number of active layers in LISA"},
    )
    lisa_step_interval: Optional[int] = field(
        default=None,
        metadata={"help": "how often to switch layers in LISA"},
    )
    lisa_layers_attribute: Optional[str] = field(
        default=None,
        metadata={"help": "path under the model to access the layers"},
    )
    curriculum_sampling: Optional[bool] = field(
        default=None,
        metadata={"help": "whether to use sequential sampling for curriculum learning"},
    )
    alternate_optimizer: Optional[str] = field(
        default=None,
        metadata={
            "help": "workaround to pass an alternate optimizer to the HF trainer"
        },
    )
    alternate_lr_scheduler_type: Optional[str] = field(
        default=None,
        metadata={
            "help": "workaround to pass an alternate lr scheduler to the HF trainer"
        },
    )
    chat_template: Optional[str] = field(
        default=None,
        metadata={"help": "Chat template converting chat messages to text"},
    )

    kd_ce_alpha: Optional[float] = field(
        default=None,
        metadata={
            "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
        },
    )

    kd_alpha: Optional[float] = field(
        default=1.0,
        metadata={"help": "The alpha scaling parameter for KD loss"},
    )

    kd_temperature: Optional[float] = field(
        default=1.0,
        metadata={
            "help": "the temperature parameter for KL divergence loss when using KD"
        },
    )

    kd_zscore_base_temp: Optional[float] = field(
        default=None,
        metadata={
            "help": "the base temperature parameter for KL divergence with z-score when using KD"
        },
    )

    kd_top_k_before_softmax: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether to apply top_k_before_softmax to the logits when using KD"
        },
    )


@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
    """
    Training arguments for Causal trainer

    This code is duplicated due to HF TrainingArguments not setting output_dir with a default value,
    so it can't be used as a mixin.
    """


@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
    """
    DPO config for DPO training
    """


@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
    """
    ORPO config for ORPO training
    """


@dataclass
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
    """
    KTO config for KTO training
    """


@dataclass
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
    """
    CPO config for CPO training
    """

    simpo_gamma: Optional[float] = field(
        default=None,
        metadata={"help": "simpo gamma parameter"},
    )


@dataclass
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
    """
    Reward config for Reward training
    """


@dataclass
class AxolotlPRMConfig(AxolotlTrainingMixins, PRMConfig):
    """
    PRM config for PRM training
    """
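A minimal usage sketch of the combined dataclass above (field values are illustrative; output_dir is required by the underlying HF TrainingArguments):

from axolotl.core.training_args import AxolotlTrainingArguments

args = AxolotlTrainingArguments(
    output_dir="./outputs",  # required by TrainingArguments
    sample_packing=True,
    max_seq_length=4096,
    relora_steps=200,
    relora_warmup_steps=10,
)
assert args.sample_packing_bin_size == 200  # default from the mixin above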
@@ -2,7 +2,7 @@

import logging
import os
from typing import List, Optional, Union

import torch
from datasets import Dataset, IterableDataset

@@ -51,7 +51,18 @@ class TokenizedPromptDataset(Dataset):
        map_kwargs = {}
        if self.prompt_tokenizer.supports_batched:
            map_kwargs["batched"] = True
            map_kwargs["batch_size"] = 1_000

        if (
            hasattr(self.prompt_tokenizer, "filter_rows")
            and self.prompt_tokenizer.filter_rows
        ):
            dataset = dataset.filter(
                self.prompt_tokenizer.filter_rows,
                num_proc=num_proc,
                desc="Strategy Filtering Rows",
            )

        return dataset.map(
            self.prompt_tokenizer.tokenize_prompt,
            num_proc=num_proc,
@@ -62,6 +73,24 @@ class TokenizedPromptDataset(Dataset):
        )


def wrap_dataset_for_tokenized_prompt(
    prompt_tokenizer: PromptTokenizingStrategy,
    dataset: Union[Dataset, IterableDataset],
    **kwargs,
):
    if isinstance(dataset, IterableDataset):
        map_kwargs = {}
        if prompt_tokenizer.supports_batched:
            map_kwargs["batched"] = True
        features = dataset.features.keys()
        return dataset.map(
            prompt_tokenizer.tokenize_prompt,
            remove_columns=features,
            **map_kwargs,
        )
    return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)


# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
    """
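A short sketch of the branch in wrap_dataset_for_tokenized_prompt above (the strategy instance and data file are assumed placeholders): streaming datasets are mapped lazily, while map-style datasets go through TokenizedPromptDataset.

from datasets import load_dataset

stream_ds = load_dataset("json", data_files="data.jsonl", split="train", streaming=True)
tokenized = wrap_dataset_for_tokenized_prompt(strategy, stream_ds)  # lazy IterableDataset.map

plain_ds = load_dataset("json", data_files="data.jsonl", split="train")
tokenized = wrap_dataset_for_tokenized_prompt(strategy, plain_ds)  # TokenizedPromptDataset path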
@@ -111,6 +111,17 @@ class BasePlugin:
        None
        """

    def get_trainer_cls(self, cfg):  # pylint: disable=unused-argument
        """
        Returns a custom class for the trainer.

        Parameters:
        cfg (dict): The global axolotl configuration.

        Returns:
        class: The class for the trainer.
        """

    def create_optimizer(self, cfg, trainer):  # pylint: disable=unused-argument
        """
        Creates and returns an optimizer for training.

@@ -212,7 +223,17 @@ def load_plugin(plugin_name: str) -> BasePlugin:
    module_name, class_name = plugin_name.rsplit(".", 1)

    # import the module
    try:
        module = importlib.import_module(module_name)
    except ModuleNotFoundError as orig_exc:
        try:
            if not module_name.startswith("axolotl.integrations."):
                module = importlib.import_module("axolotl.integrations." + module_name)
            else:
                raise orig_exc
        except ModuleNotFoundError as exc:
            raise orig_exc from exc

    # instantiate the class
    plugin_class = getattr(module, class_name)
    # create an instance of the class
@@ -272,8 +293,10 @@ class PluginManager:
        ImportError: If the plugin module cannot be imported.
        """
        try:
            logging.info(f"Attempting to load plugin: {plugin_name}")
            plugin = load_plugin(plugin_name)
            self.plugins[plugin_name] = plugin
            logging.info(f"Plugin loaded successfully: {plugin_name}")
        except ImportError:
            logging.error(f"Failed to load plugin: {plugin_name}")

@@ -346,6 +369,22 @@ class PluginManager:
        for plugin in self.plugins.values():
            plugin.post_lora_load(cfg, model)

    def get_trainer_cls(self, cfg):
        """
        Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.

        Parameters:
        cfg (dict): The configuration for the plugins.

        Returns:
        object: The trainer class, or None if none was found.
        """
        for plugin in self.plugins.values():
            trainer_cls = plugin.get_trainer_cls(cfg)
            if trainer_cls is not None:
                return trainer_cls
        return None

    def create_optimizer(self, cfg, trainer):
        """
        Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
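A hedged sketch of a third-party plugin built on the hooks above; the package, config flag, and trainer names are hypothetical.

from axolotl.integrations.base import BasePlugin

class MyPlugin(BasePlugin):
    """Example plugin returning a custom trainer class."""

    def get_trainer_cls(self, cfg):
        if cfg.get("my_trainer"):  # hypothetical config flag
            from my_pkg.trainer import MyTrainer  # hypothetical module

            return MyTrainer
        return None

# load_plugin("my_pkg.MyPlugin") tries importlib.import_module("my_pkg") first,
# then falls back to "axolotl.integrations.my_pkg" per the try/except above.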
36
src/axolotl/integrations/kd/__init__.py
Normal file
@@ -0,0 +1,36 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Plugin init to add KD support to Axolotl.
"""
from axolotl.integrations.base import BasePlugin

from .args import KDArgs  # pylint: disable=unused-import  # noqa: F401


class KDPlugin(BasePlugin):
    """
    Plugin for KD support in Axolotl.
    """

    def get_input_args(self):
        return "axolotl.integrations.kd.KDArgs"

    def get_trainer_cls(self, cfg):
        if cfg.kd_trainer:
            from .trainer import AxolotlKDTrainer

            return AxolotlKDTrainer
        return None
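A small sketch of how the trainer class gets resolved once the plugin is registered (config values illustrative):

from axolotl.utils.dict import DictDefault

cfg = DictDefault({"kd_trainer": True})
trainer_cls = KDPlugin().get_trainer_cls(cfg)  # -> AxolotlKDTrainer
assert KDPlugin().get_trainer_cls(DictDefault({"kd_trainer": False})) is None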
37
src/axolotl/integrations/kd/args.py
Normal file
@@ -0,0 +1,37 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Plugin args for KD support.
"""
from typing import Optional

from pydantic import BaseModel


class KDArgs(BaseModel):
    """
    Input args for knowledge distillation.
    """

    kd_trainer: Optional[bool] = None  # whether to use KD trainer
    kd_ce_alpha: Optional[float] = None  # loss coefficient for cross-entropy loss during KD
    kd_alpha: Optional[float] = None  # loss coefficient for KD loss
    kd_temperature: Optional[float] = None  # temperature for sampling during KD
    kd_zscore_base_temp: Optional[float] = None  # base temperature for zscore scaling
    kd_top_k_before_softmax: Optional[bool] = None  # whether to sample top k before softmax during KD
201
src/axolotl/integrations/kd/chat_template.py
Normal file
@@ -0,0 +1,201 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Chat template prompt strategy loader with KD support
"""
from typing import Any, Dict

import torch

from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader


class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
    """
    Handle fields for logprob KD
    """

    def __init__(
        self,
        prompter,
        tokenizer,
        train_on_inputs,
        sequence_len,
        roles_to_train=None,
        train_on_eos=None,
        logprobs_field="logprobs",
        gen_temperature=1.0,
        kd_temperature=1.0,
    ):
        self.logprobs_field = logprobs_field
        self.gen_temperature = gen_temperature
        self.kd_temperature = kd_temperature

        super().__init__(
            prompter,
            tokenizer,
            train_on_inputs,
            sequence_len,
            roles_to_train=roles_to_train,
            train_on_eos=train_on_eos,
        )

    @property
    def supports_batched(self) -> bool:
        # batching doesn't work well for logprob data
        return False

    def transform_logprobs(self, sample):
        """
        Transform logprobs to target format for KD training
        """

        logprobs = sample.pop(self.logprobs_field)
        target_seq_len = len(logprobs)
        input_seq_len = len(sample["input_ids"])
        input_padding_len = input_seq_len - target_seq_len
        # get non-zero top-k (prune None logprobs from vllm data step)
        top_k_vals = [
            len(logprobs[i])
            for i in range(len(logprobs))
            if logprobs[i] is not None and len(logprobs[i])
        ]
        max_top_k = max(set(top_k_vals), key=top_k_vals.count)
        min_top_k = min(set(top_k_vals), key=top_k_vals.count)
        top_k = min(max_top_k, min_top_k)
        if top_k == 0:
            raise ValueError("No non-zero top-k logprobs found.")

        target_logprobs = []
        target_token_ids = []
        target_mask = []

        if input_padding_len < 0:
            # logprobs is longer than target_seq_len,
            # so we need to slice from the left/beginning of logprobs
            logprobs = logprobs[-input_seq_len:]
            input_padding_len = 0
            # target_seq_len = input_seq_len

        # truncate the second dimension of the logprobs to top_k
        logprobs = [row[:top_k] for row in logprobs]

        # fill with -inf for padding_len tokens for top_k tokens
        # extend target_logprobs with a padding_len x top_k 2D list filled with -inf

        # for causal models, if we start the range at 1, then we don't need to shift in the trainer
        # otherwise, we need to shift in the trainer
        shift = 0
        for _ in range(shift, input_padding_len):
            target_logprobs.append([-float("inf")] * top_k)
            target_token_ids.append(list(range(top_k)))
            target_mask.append([0] * top_k)

        for position in range(input_padding_len, input_seq_len):
            if sample["labels"][position] == -100:
                target_mask.append([0] * top_k)
            else:
                target_mask.append([1] * top_k)

        for _, token_pos_logprobs in enumerate(logprobs):
            # Initialize collections for logprobs and token_ids
            position_logprobs = []
            position_token_ids = []

            # Process each token probability entry
            for entry in token_pos_logprobs:
                # Extract logprob value
                logprob = entry["logprob"]

                # Parse token_id from the "token_id:###" format
                token_id = int(entry["token"].split(":")[1])

                # Append to our collections
                position_logprobs.append(logprob)
                position_token_ids.append(token_id)

            # Convert to a tensor for easier manipulation
            position_logprobs_tensor = torch.tensor(
                position_logprobs, dtype=torch.float
            )

            # Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
            # Next, re-scale to T2 = self.kd_temperature via exponent-based trick
            #   p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
            #
            # Convert from log to probability
            teacher_probs_t1 = position_logprobs_tensor.exp()
            if self.kd_temperature != self.gen_temperature:
                # Exponentiate by factor (T1 / T2)
                exponent = self.gen_temperature / self.kd_temperature
                teacher_probs_t2 = teacher_probs_t1**exponent
            else:
                teacher_probs_t2 = teacher_probs_t1
            # Re-normalize
            teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
                dim=0, keepdim=True
            )
            # Convert back to log
            position_logprobs_tensor = torch.log(teacher_probs_t2)

            # Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
            position_logprobs_scaled = position_logprobs_tensor.tolist()

            target_logprobs.append(position_logprobs_scaled)
            target_token_ids.append(position_token_ids)

        if shift == 1:
            # since we started at index 1 for causal, we need one more padding token
            target_logprobs.append([-float("inf")] * top_k)
            target_token_ids.append(list(range(top_k)))
            target_mask.append([0] * top_k)

        # Update sample with transformed logprobs
        sample["target_logprobs"] = target_logprobs
        sample["target_token_ids"] = target_token_ids
        sample["target_mask"] = target_mask

        return sample

    def _tokenize_single_prompt(self, prompt):
        logprobs = prompt.pop(self.logprobs_field)
        tokenized_prompt = super()._tokenize_single_prompt(prompt)
        tokenized_prompt[self.logprobs_field] = logprobs
        tokenized_prompt = self.transform_logprobs(tokenized_prompt)

        return tokenized_prompt


class KDStrategyLoader(StrategyLoader):
    """
    Load ChatTemplateStrategy with KD support using StrategyLoader.
    """

    def _get_strategy_cls(self):
        return ChatTemplateStrategyWithKD

    def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
        strategy_params = super()._get_strategy_params(cfg, ds_cfg)
        if logprobs_field := ds_cfg.get("logprobs_field"):
            strategy_params["logprobs_field"] = logprobs_field
        if gen_temperature := ds_cfg.get("temperature"):
            strategy_params["gen_temperature"] = gen_temperature
        if kd_temperature := cfg.get("kd_temperature"):
            strategy_params["kd_temperature"] = kd_temperature

        return strategy_params


load = KDStrategyLoader()
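A worked numeric example of the exponent-based temperature re-scaling used in transform_logprobs above, p_T2(k) = p_T1(k)**(T1/T2) / Z (values illustrative):

import torch

probs_t1 = torch.tensor([0.7, 0.2, 0.1])  # top-k teacher probs sampled at T1 = 1.0
exponent = 1.0 / 2.0  # T1 / T2 with kd_temperature T2 = 2.0
probs_t2 = probs_t1 ** exponent
probs_t2 = probs_t2 / probs_t2.sum()  # the Z re-normalization
# probs_t2 ≈ [0.523, 0.279, 0.198]: same ranking, but a flatter distribution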
255
src/axolotl/integrations/kd/collator.py
Normal file
@@ -0,0 +1,255 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
DataCollator for axolotl to handle KD fields without using -inf for padding,
and with a teacher_mask to identify padded positions.
"""

from dataclasses import dataclass
from typing import Any, Optional, Union

import numpy as np
import torch
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

from axolotl.utils.collators.batching import DataCollatorForSeq2Seq


@dataclass
class DataCollatorForKD(DataCollatorForSeq2Seq):
    """
    Data collator for KD, including handling KD-specific fields.

    This version avoids using -inf and instead uses a large negative value for padding
    target_logprobs. It also creates a teacher_mask to indicate which entries are valid.
    """

    # pylint: disable=duplicate-code
    tokenizer: PreTrainedTokenizerBase
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    position_pad_token_id: int = 0
    return_tensors: str = "pt"

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors

        padding_side = self.tokenizer.padding_side

        # Pad labels and position_ids first
        for feature_name, pad_token_id in [
            ("labels", self.label_pad_token_id),
            ("position_ids", self.position_pad_token_id),
        ]:
            if feature_name in features[0]:
                feat = [f[feature_name] for f in features]
                max_len = max(len(x) for x in feat)
                if self.pad_to_multiple_of is not None:
                    max_len = (
                        (max_len + self.pad_to_multiple_of - 1)
                        // self.pad_to_multiple_of
                    ) * self.pad_to_multiple_of

                for f in features:  # pylint: disable=invalid-name
                    remainder = [pad_token_id] * (max_len - len(f[feature_name]))
                    if isinstance(f[feature_name], list):
                        f[feature_name] = (
                            f[feature_name] + remainder
                            if padding_side == "right"
                            else remainder + f[feature_name]
                        )
                    else:
                        # If they are numpy arrays
                        if padding_side == "right":
                            f[feature_name] = np.concatenate(
                                [f[feature_name], remainder]
                            ).astype(np.int64)
                        else:
                            f[feature_name] = np.concatenate(
                                [remainder, f[feature_name]]
                            ).astype(np.int64)

        # Handle target_logprobs and target_token_ids manually
        target_logprobs_list = []
        target_token_ids_list = []
        target_mask_list = []
        has_teacher_data = ("target_logprobs" in features[0]) and (
            "target_token_ids" in features[0]
        )

        if has_teacher_data:
            # Extract and remove from features
            for f in features:  # pylint: disable=invalid-name
                target_logprobs_list.append(f.pop("target_logprobs"))
                target_token_ids_list.append(f.pop("target_token_ids"))
                target_mask_list.append(f.pop("target_mask"))

            # Determine max lengths
            max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
            max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)

            padded_target_logprobs = []
            padded_target_token_ids = []
            padded_teacher_mask_list = []

            for t_logprobs, t_ids, t_mask in zip(
                target_logprobs_list, target_token_ids_list, target_mask_list
            ):
                t_logprobs_padded = []
                t_ids_padded = []
                t_mask_padded = []

                for lp, ids, mask in zip(  # pylint: disable=invalid-name
                    t_logprobs, t_ids, t_mask
                ):
                    lp_len = len(lp)
                    if lp_len < max_k:
                        # Use -1e9 for padding logprobs and 0 for token_ids
                        pad_len = max_k - lp_len
                        lp = lp + [-1e9] * pad_len  # pylint: disable=invalid-name
                        ids = ids + [0] * pad_len
                        mask = mask + [0] * pad_len
                    else:
                        lp = lp[:max_k]  # pylint: disable=invalid-name
                        ids = ids[:max_k]
                        mask = mask[:max_k]

                    t_logprobs_padded.append(lp)
                    t_ids_padded.append(ids)
                    t_mask_padded.append(mask)

                seq_len_diff = max_teacher_seq_len - len(t_logprobs_padded)
                if seq_len_diff > 0:
                    # Pad sequences fully if needed
                    t_logprobs_padded.extend(
                        [[-1e9] * max_k for _ in range(seq_len_diff)]
                    )
                    t_ids_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
                    t_mask_padded.extend([[0] * max_k for _ in range(seq_len_diff)])

                padded_target_logprobs.append(t_logprobs_padded)
                padded_target_token_ids.append(t_ids_padded)
                padded_teacher_mask_list.append(t_mask_padded)

            # Convert to tensors
            padded_target_logprobs = torch.tensor(
                padded_target_logprobs, dtype=torch.float
            )
            padded_target_token_ids = torch.tensor(
                padded_target_token_ids, dtype=torch.long
            )
            padded_teacher_mask_list = torch.tensor(
                padded_teacher_mask_list, dtype=torch.int
            )

        # Pad using tokenizer for regular fields
        features = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        # Add back teacher data if present
        if has_teacher_data:
            features["target_logprobs"] = padded_target_logprobs
            features["target_token_ids"] = padded_target_token_ids
            features["target_mask"] = padded_teacher_mask_list

        # Prepare decoder_input_ids if the model supports it
        if (
            "labels" in features
            and self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
                labels=features["labels"]
            )
            features["decoder_input_ids"] = decoder_input_ids

        return features


class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
    """
    Collator for multipack (batch of sub-batches) specifically for KD.
    Adapts DataCollatorForKD so it can pack multiple sequences in a single batch item.
    """

    def __call__(self, features, return_tensors=None):
        """
        Expects that `features` could be either:
          - a single list of dicts, OR
          - a list of lists of dicts (the "sub-batches" to be packed).
        """
        # 1) If we are *not* dealing with multiple sequences per batch element,
        #    just pass straight to parent.
        if not isinstance(features[0], list):
            return super().__call__(features, return_tensors=return_tensors)

        # 2) Otherwise, we *are* dealing with multiple sequences in each batch item.
        #    We want to produce a single "merged" feature dict for each sub-batch.
        out_features = [{} for _ in features]

        for i, sub_features in enumerate(features):
            # sub_features is a list of dicts, each dict = one sequence's features
            # We'll merge them into out_features[i].
            #
            # NOTE: You can customize how you combine fields as needed (e.g. summation
            # or offset for attention_mask). Below is a straightforward concatenation/extension.

            for field_name in sub_features[0].keys():
                # Some fields you might want to skip or treat specially:
                if field_name == "length":
                    continue

                # If it's a KD field that's a list-of-lists (e.g. target_logprobs),
                # you typically just want to flatten them by extending.
                if field_name in ["target_logprobs", "target_token_ids", "target_mask"]:
                    combined = []
                    for feat in sub_features:
                        combined.extend(feat[field_name])
                    out_features[i][field_name] = combined

                elif field_name == "attention_mask":
                    # Here we apply the (j+1) factor to differentiate each sub-sample
                    # within this merged batch item.
                    arrays = []
                    for j, feat in enumerate(sub_features):
                        if field_name in feat:
                            arrays.append((j + 1) * np.array(feat[field_name]))
                    out_features[i][field_name] = np.concatenate(arrays)
                else:
                    # By default, just concatenate them if they are arrays
                    # or extend them if they are lists.
                    # For example, input_ids or labels are often arrays.
                    arrays = []
                    for feat in sub_features:
                        if field_name in feat:
                            arr = np.array(feat[field_name])
                            arrays.append(arr)
                    out_features[i][field_name] = np.concatenate(arrays)

        # 3) Now call the parent collator, which will do:
        #    - padding of labels/position_ids
        #    - KD-specific padding for target_logprobs, target_token_ids, etc.
        #    - final conversion to return_tensors
        return super().__call__(out_features, return_tensors=return_tensors)
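To make the ragged top-k padding above concrete, a tiny list-level sketch (no tokenizer involved; values illustrative):

row_a = [-0.1, -2.3, -4.0]  # top-3 logprobs at one position
row_b = [-0.2, -1.9]        # only top-2 survived at another position
max_k = 3
row_b = row_b + [-1e9] * (max_k - len(row_b))  # -> [-0.2, -1.9, -1e9]
# the matching target_mask row becomes [1, 1, 0], so the padded slot
# is excluded from the KD loss instead of contributing an inf/NaN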
0
src/axolotl/integrations/kd/kernels/__init__.py
Normal file
58
src/axolotl/integrations/kd/topk_logprob/LICENSE.md
Normal file
@@ -0,0 +1,58 @@
### AXOLOTL COMMUNITY LICENSE AGREEMENT

This Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and
any individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms
and conditions set forth in this Agreement.

1. Definitions
1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement.
1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl,
which may be licensed separately by their respective authors and/or licensors.
1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at
https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which
permits Plugin Integrations to integrate with the Axolotl service.
2. Grant of License
2.1 Axolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge,
publish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions:
- Licensee must comply with all the terms and conditions of this Agreement.
- Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial
portions of the Software.
2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3.
3. Restrictions
3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for
free or for sale any services, platform, or equivalent to third parties for the purposes of allowing such
third parties to fine-tune artificial intelligence models.
3.2 Licensee shall not:
- Use the Software for any illegal or unauthorized purpose.
- Reverse engineer, decompile, or disassemble the Software.
- Remove or modify any copyright, trademark, or other proprietary notices contained in the Software.
- Use the Software in a way that could damage, disable, overburden, or impair the functionality of the
Software or interfere with any third-party use of the Software.
3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement.
4. Intellectual Property Rights
4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee
acknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to
Licensee.
5. Disclaimer of Warranty
5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF
CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
6. Termination
6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and
conditions set forth herein. Upon termination, Licensee shall cease all use of the Software and destroy any
copies in its possession.
7. Governing Law
7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California,
without regards to conflicts of laws provisions thereof.
8. Entire Agreement
8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter
hereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning
the Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and
Licensee’s continued use of the Software after any such updates shall constitute acceptance of updated terms
on a go-forward basis. Axolotl will use commercially reasonable efforts to provide Licensee notice of any
material updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be
bound by the terms and conditions of this Agreement.

This Agreement was last updated on August 23, 2024.
235
src/axolotl/integrations/kd/topk_logprob/forward_kl.py
Normal file
@@ -0,0 +1,235 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.

"""
loss for top_k KL divergence
"""
import torch


def zscore_standardize(
    logits: torch.Tensor,
    mask: torch.Tensor = None,
    base_temperature: float = 1.0,
    eps: float = 1e-9,
):
    """
    Z-score standardize along the last dimension of `logits`.
    i.e., for each [B, seq_len] row, across K entries:
        z = (logits - mean) / std,
    then scale by 1 / base_temperature if desired.

    mask can be broadcastable or None. If None, we standardize all elements.
    """
    if mask is None:
        # shape: [B, seq_len, K]
        # Mean and std over dim=-1
        mean = logits.mean(dim=-1, keepdim=True)
        var = logits.var(dim=-1, unbiased=False, keepdim=True)
    else:
        # If you have to exclude some tokens, multiply by mask, etc.
        float_mask = mask.to(logits.dtype)
        count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
        mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
        var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count

    std = torch.sqrt(var.clamp_min(eps))
    z = (logits - mean) / std

    # Scale by 1 / base_temperature
    z = z / base_temperature
    return z


@torch.jit.script
def loss(
    student_logits: torch.Tensor,
    target_token_ids: torch.Tensor,
    target_logprobs: torch.Tensor,
    target_mask: torch.Tensor,
    num_items_in_batch: int = -1,  # Use -1 to indicate "None"
    kd_temperature: float = 1.0,
    top_k_before_softmax: int = 0,
) -> torch.Tensor:
    """
    A KD loss function that is TorchScript-friendly.

    Arguments:
        student_logits (torch.Tensor): The logits of the student model.
            Shape: [B, student_seq_len, vocab_size]
        target_token_ids (torch.Tensor): The top-k teacher/target token IDs.
            Shape: [B, teacher_seq_len, top_k]
        target_logprobs (torch.Tensor): The top-k teacher/target logprobs, these should already be re-normalized.
            Shape: [B, teacher_seq_len, top_k]
        target_mask (torch.Tensor): The mask for valid tokens.
            Shape: [B, teacher_seq_len, top_k]
        num_items_in_batch (int, optional): The number of items in the batch.
        kd_temperature (float, optional): The temperature for KD.
            Default: 1.0
        top_k_before_softmax (int, optional): Flag of whether to apply softmax before gathering student top-k logits.
            Default: 0
    """

    target_logprobs = target_logprobs.float()

    # Determine the teacher sequence length
    # target_token_ids shape: [B, teacher_seq_len, K]
    # student_logits shape: [B, student_seq_len, vocab_size]
    teacher_seq_len = target_token_ids.shape[1]

    if top_k_before_softmax:
        # Slice student logits to match teacher-provided sequence length
        student_logits_for_kd = student_logits[
            :, :teacher_seq_len, :
        ]  # [B, teacher_seq_len, vocab_size]

        # Gather student logits for teacher's top-K tokens
        student_logits_topk = torch.gather(
            student_logits_for_kd, dim=-1, index=target_token_ids
        )  # [B, teacher_seq_len, K]

        student_logits_topk = student_logits_topk.float()

        # Apply KD temperature to student's logits
        if kd_temperature != 1.0:
            student_logits_topk = student_logits_topk / kd_temperature

        # Convert student top-k logits to logprobs
        student_logprobs_topk = student_logits_topk - torch.logsumexp(
            student_logits_topk, dim=-1, keepdim=True
        )  # [B, teacher_seq_len, K]
    else:
        # Slice student logits to match teacher-provided sequence length
        student_logits_for_kd = (
            student_logits[:, :teacher_seq_len, :] / kd_temperature
        )  # [B, teacher_seq_len, vocab_size]

        # keep in full precision for numerical stability of loss
        student_logits_for_kd = student_logits_for_kd.float()

        # Gather student logits for teacher's top-K tokens
        student_logits_topk = torch.gather(
            student_logits_for_kd, dim=-1, index=target_token_ids
        )  # [B, teacher_seq_len, K]

        # Compute logsumexp across full vocabulary
        student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)

        # Convert just the top-k logits to logprobs
        student_logprobs_topk = student_logits_topk - student_lse

    # Convert teacher_mask to boolean for indexing
    # In TorchScript, .bool() is sometimes unsupported, so we do:
    valid_mask = target_mask.to(torch.bool)

    # Prune tensors to only keep valid tokens
    student_logprobs_topk = student_logprobs_topk[valid_mask]
    target_logprobs = target_logprobs[valid_mask]

    # Convert teacher logprobs to probabilities
    teacher_probs = target_logprobs.exp()

    # Compute forward KL
    kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
    kd_loss = kd_loss_per_token.sum()

    # Multiply by T^2 (classical KD scaling)
    if kd_temperature != 1.0:
        kd_loss = kd_loss * (kd_temperature**2)

    # Normalize by number of items (if provided) or by valid tokens
    if num_items_in_batch > 0:
        kd_loss = kd_loss / float(num_items_in_batch)
    else:
        # Fall back to average over valid tokens
        kd_loss = kd_loss / float(kd_loss_per_token.size(0))

    return kd_loss


def topk_kd_loss_with_zscore(
    student_logits: torch.Tensor,  # [B, seq_len, vocab_size]
    target_token_ids: torch.Tensor,  # [B, seq_len, K]
    target_logprobs: torch.Tensor,  # [B, seq_len, K], sums to 1.0 in prob space
    target_mask: torch.Tensor,  # [B, seq_len, K] or [B, seq_len]
    kd_temperature: float = 1.0,  # classic KD temperature
    zscore_base_temp: float = 1.0,  # from the paper
    num_items_in_batch: int = -1,
):
    """
    A variant of top_k KL divergence with Z-score scaling
    from "Logit Standardization in Knowledge Distillation".
    """

    target_logprobs = target_logprobs.float()

    B, teacher_seq_len, K = target_logprobs.shape  # pylint: disable=invalid-name
    # 1) Gather the student's top-k logits to match teacher
    student_logits_for_kd = student_logits[
        :, :teacher_seq_len, :
    ]  # [B, seq_len, vocab]
    student_topk_logits = torch.gather(
        student_logits_for_kd, dim=-1, index=target_token_ids
    )  # [B, seq_len, K]

    student_topk_logits = student_topk_logits.float()

    # 2) If you want to keep the "classical" T scaling, apply it first
    if kd_temperature != 1.0:
        student_topk_logits = student_topk_logits / kd_temperature

    # 3) Convert teacher logprobs -> treat them as "logits" for z-score
    #    (They differ by +some_constant from real logits, but in z-score
    #    that constant is subtracted out anyway.)
    teacher_logits_for_zscore = target_logprobs  # rename variable for clarity

    # 4) Z-score teacher and student
    #    If target_mask is 2D, expand to 3D for the K dimension
    if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
        target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)

    teacher_z = zscore_standardize(
        teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
    )
    student_z = zscore_standardize(
        student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
    )

    # 5) Convert to log-probs for KL
    teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
    student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)

    # 6) Restrict to valid tokens if needed
    valid_mask = target_mask.bool()  # shape [B, seq_len, K]
    teacher_probs_z = teacher_logprobs_z.exp()
    teacher_probs_z = teacher_probs_z[valid_mask]
    teacher_logprobs_z = teacher_logprobs_z[valid_mask]
    student_logprobs_z = student_logprobs_z[valid_mask]

    # 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
    kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
    kd_loss = kd_loss_per_token.sum()

    # 8) If using classical KD scaling by T^2
    if kd_temperature != 1.0:
        kd_loss = kd_loss * (kd_temperature**2)

    # Optionally scale by zscore_base_temp**2 if you want (paper might differ).
    # kd_loss = kd_loss * (zscore_base_temp**2)

    # 9) Normalize
    if num_items_in_batch is not None and num_items_in_batch > 0:
        kd_loss = kd_loss / float(num_items_in_batch)
    else:
        kd_loss = kd_loss / float(kd_loss_per_token.size(0))

    return kd_loss
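Usage sketch for the TorchScript loss above with random tensors (shapes B=2, seq=5, K=8, vocab=32; purely illustrative):

import torch

student_logits = torch.randn(2, 5, 32)
target_token_ids = torch.randint(0, 32, (2, 5, 8))
target_logprobs = torch.log_softmax(torch.randn(2, 5, 8), dim=-1)  # re-normalized top-k
target_mask = torch.ones(2, 5, 8, dtype=torch.int)
kd = loss(student_logits, target_token_ids, target_logprobs, target_mask)
# kd is a scalar: forward KL averaged over the valid (masked-in) tokens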
113
src/axolotl/integrations/kd/trainer.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
KD trainer
|
||||
"""
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
||||
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
|
||||
|
||||
|
||||
class AxolotlKDTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Custom trainer subclass for Knowledge Distillation (KD)
|
||||
"""
|
||||
|
||||
def _set_signature_columns_if_needed(self):
|
||||
super()._set_signature_columns_if_needed()
|
||||
columns_to_add = []
|
||||
if self._signature_columns:
|
||||
if "target_logprobs" not in self._signature_columns:
|
||||
columns_to_add.append("target_logprobs")
|
||||
if "target_token_ids" not in self._signature_columns:
|
||||
columns_to_add.append("target_token_ids")
|
||||
if "target_mask" not in self._signature_columns:
|
||||
columns_to_add.append("target_mask")
|
||||
if columns_to_add:
|
||||
self._signature_columns += columns_to_add
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None,
|
||||
):
|
||||
"""
|
||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||
|
||||
Subclass and override for custom behavior.
|
||||
"""
|
||||
|
||||
target_logprobs = inputs.pop("target_logprobs")
|
||||
target_token_ids = inputs.pop("target_token_ids")
|
||||
target_mask = inputs.pop("target_mask")
|
||||
|
||||
seq_len = target_token_ids.shape[1]
|
||||
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss_kwargs = {}
|
||||
if num_items_in_batch is not None:
|
||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||
inputs = {**inputs, **loss_kwargs}
|
||||
outputs = model(**inputs)
|
||||
|
||||
# FIXME: account for tokenizer.padding_side
|
||||
student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
|
||||
|
||||
shift_logits = student_logits.contiguous()
|
||||
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||
|
||||
if self.args.kd_zscore_base_temp:
|
||||
loss_kd = topk_kd_loss_with_zscore(
|
||||
shift_logits,
|
||||
target_token_ids_for_loss,
|
||||
target_logprobs_for_loss,
|
||||
target_mask_for_loss,
|
||||
kd_temperature=self.args.kd_temperature,
|
||||
zscore_base_temp=self.args.kd_zscore_base_temp,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
else:
|
||||
loss_kd = topk_kd_loss(
|
||||
shift_logits,
|
||||
target_token_ids_for_loss,
|
||||
target_logprobs_for_loss,
|
||||
target_mask_for_loss,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
kd_temperature=self.args.kd_temperature,
|
||||
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
|
||||
)
|
||||
|
||||
if self.args.kd_ce_alpha > 0:
|
||||
kd_alpha = self.args.kd_alpha
|
||||
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
|
||||
else:
|
||||
loss = loss_kd
|
||||
# Save past state if it exists
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[ # pylint: disable=attribute-defined-outside-init
|
||||
self.args.past_index
|
||||
]
|
||||
|
||||
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
||||
loss *= self.accelerator.num_processes
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
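For reference, a minimal numeric sketch of how the blended objective in `compute_loss` behaves (scalar stand-ins; in the trainer `kd_ce_alpha`/`kd_alpha` come from `self.args` and the losses are real tensors):

```python
import torch

kd_ce_alpha = 0.1  # weight on the hard-label cross-entropy term
kd_alpha = 0.9     # weight on the distillation term

ce_loss = torch.tensor(2.31)  # stand-in for outputs["loss"]
loss_kd = torch.tensor(1.07)  # stand-in for the top-k KD loss

# Mirrors: loss = kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
loss = kd_ce_alpha * ce_loss + kd_alpha * loss_kd if kd_ce_alpha > 0 else loss_kd
print(loss)  # tensor(1.1940)
```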
@@ -2,9 +2,9 @@
Module for the Plugin for LM Eval Harness
"""
import subprocess  # nosec
from datetime import datetime

from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command

from .args import LMEvalArgs  # pylint: disable=unused-import  # noqa: F401

@@ -18,25 +18,20 @@ class LMEvalPlugin(BasePlugin):
        return "axolotl.integrations.lm_eval.LMEvalArgs"

    def post_train_unload(self, cfg):
        tasks = ",".join(cfg.lm_eval_tasks)
        fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
        dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
        output_path = cfg.output_dir
        output_path += "" if cfg.output_dir.endswith("/") else "/"
        output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
        subprocess.run(  # nosec
            [
                "lm_eval",
                "--model",
                "hf",
                "--model_args",
                f"pretrained={cfg.output_dir}{fa2}{dtype}",
                "--tasks",
                tasks,
                "--batch_size",
                str(cfg.lm_eval_batch_size),
                "--output_path",
                output_path,
            ],
            check=True,
        )
        if cfg.lm_eval_post_train:
            # pylint: disable=duplicate-code
            for lm_eval_args in build_lm_eval_command(
                cfg.lm_eval_tasks,
                bfloat16=cfg.bfloat16 or cfg.bf16,
                flash_attention=cfg.flash_attention,
                output_dir=cfg.output_dir,
                batch_size=cfg.lm_eval_batch_size,
                wandb_project=cfg.wandb_project,
                wandb_entity=cfg.wandb_entity,
                wandb_name=cfg.wandb_name,
                model=cfg.lm_eval_model or cfg.hub_model_id,
            ):
                subprocess.run(  # nosec
                    lm_eval_args,
                    check=True,
                )

@@ -13,3 +13,5 @@ class LMEvalArgs(BaseModel):

    lm_eval_tasks: List[str] = []
    lm_eval_batch_size: Optional[int] = 8
    lm_eval_post_train: Optional[bool] = True
    lm_eval_model: Optional[str] = None
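Taken together with the plugin above, a plausible config snippet wiring these options (the plugin path and task values are illustrative assumptions, not taken verbatim from this diff):

```yaml
plugins:
  - axolotl.integrations.lm_eval.LMEvalPlugin

lm_eval_tasks:
  - gsm8k:5     # "task:num_fewshot" pairs are split on ":" by build_lm_eval_command
  - arc_easy    # no suffix -> run without --num_fewshot
lm_eval_batch_size: 8
lm_eval_post_train: true
# lm_eval_model: org/model-name  # optional; falls back to hub_model_id
```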
119
src/axolotl/integrations/lm_eval/cli.py
Normal file
@@ -0,0 +1,119 @@
"""
axolotl CLI for running lm_eval tasks
"""
import subprocess  # nosec
from collections import defaultdict
from datetime import datetime
from typing import Optional

import click
import yaml

from axolotl.utils.dict import DictDefault


def build_lm_eval_command(
    tasks: list[str],
    bfloat16=True,
    flash_attention=False,
    output_dir="./",
    batch_size=8,
    wandb_project=None,
    wandb_entity=None,
    wandb_name=None,
    model=None,
    revision=None,
    apply_chat_template=None,
    fewshot_as_multiturn=None,
):
    tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
    if isinstance(tasks, str):
        tasks = [tasks]
    for task in tasks:
        num_fewshot = "-1"
        task_parts = task.split(":")
        task_name = task_parts[0]
        if len(task_parts) == 2:
            task_name, num_fewshot = task_parts
        tasks_by_num_fewshot[str(num_fewshot)].append(task_name)

    for num_fewshot, tasks_list in tasks_by_num_fewshot.items():
        tasks_str = ",".join(tasks_list)
        num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
        pretrained = "pretrained="
        pretrained += model if model else output_dir
        fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
        dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
        revision = f",revision={revision}" if revision else ""
        output_path = output_dir
        output_path += "" if output_dir.endswith("/") else "/"
        output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
        lm_eval_args = [
            "lm_eval",
            "--model",
            "hf",
            "--model_args",
            f"{pretrained}{fa2}{dtype}{revision}",
            "--tasks",
            tasks_str,
            "--batch_size",
            str(batch_size),
            "--output_path",
            output_path,
        ]
        wandb_args = []
        if wandb_project:
            wandb_args.append(f"project={wandb_project}")
        if wandb_entity:
            wandb_args.append(f"entity={wandb_entity}")
        if wandb_name:
            wandb_args.append(f"name={wandb_name}")
        if wandb_args:
            lm_eval_args.append("--wandb_args")
            lm_eval_args.append(",".join(wandb_args))
        if apply_chat_template:
            lm_eval_args.append("--apply_chat_template")
        if num_fewshot_val:
            lm_eval_args.append("--num_fewshot")
            lm_eval_args.append(str(num_fewshot_val))
        if apply_chat_template and fewshot_as_multiturn:
            lm_eval_args.append("--fewshot_as_multiturn")

        yield lm_eval_args


@click.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
def lm_eval(config: str, cloud: Optional[str] = None):
    """
    use lm eval to evaluate a trained language model
    """

    if cloud:
        from axolotl.cli.cloud import do_cli_lm_eval

        do_cli_lm_eval(cloud_config=cloud, config=config)
    else:
        with open(config, encoding="utf-8") as file:
            cfg: DictDefault = DictDefault(yaml.safe_load(file))

        # pylint: disable=duplicate-code
        for lm_eval_args in build_lm_eval_command(
            cfg.lm_eval_tasks,
            bfloat16=cfg.bfloat16 or cfg.bf16,
            flash_attention=cfg.flash_attention,
            output_dir=cfg.output_dir,
            batch_size=cfg.lm_eval_batch_size,
            wandb_project=cfg.wandb_project,
            wandb_entity=cfg.wandb_entity,
            wandb_name=cfg.wandb_name,
            model=cfg.lm_eval_model or cfg.hub_model_id,
            revision=cfg.revision,
            apply_chat_template=cfg.apply_chat_template,
            fewshot_as_multiturn=cfg.fewshot_as_multiturn,
        ):
            subprocess.run(  # nosec
                lm_eval_args,
                check=True,
            )
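As a usage sketch, tasks with different fewshot counts are grouped into separate `lm_eval` invocations (the output path additionally gains a timestamp suffix):

```python
from axolotl.integrations.lm_eval.cli import build_lm_eval_command

for args in build_lm_eval_command(
    ["gsm8k:5", "hellaswag:5", "arc_easy"],
    bfloat16=True,
    output_dir="./outputs/my-run",
    batch_size=4,
):
    print(args)
# Two commands are yielded: one for "gsm8k,hellaswag" with "--num_fewshot 5",
# and one for "arc_easy" with no --num_fewshot flag.
```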
201
src/axolotl/integrations/lolcats/LICENSE
Normal file
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
44
src/axolotl/integrations/lolcats/README.md
Normal file
@@ -0,0 +1,44 @@
# Low-rank Linear Conversion via Attention Transfer (LoLCATs)

https://github.com/HazyResearch/lolcats/

### Usage

Install the `causal_dot_product` CUDA kernel (see the README in the `csrc` directory):

```bash
cd src/axolotl/integrations/lolcats/linear_llama/csrc

# Edit `setup.py` to point to the correct CUDA capabilities (L40-44)
# nano setup.py

# Build the CUDA kernel
python setup.py install
```

Step 1: Enable the plugin and linearization in your config:

```yaml
plugins:
  - axolotl.integrations.lolcats.LinearizePlugin

linearize: true
```

Run axolotl: `python -m axolotl.cli.convert_linear_attention config.yaml` (TODO: change path CLI)

Step 2: Remove `linearize: true` from the config and finetune with LoRA using the target modules below:

```yaml
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]

# Optionally add the config below, but this requires patching axolotl
# to allow it to work with LoRA:
# unfrozen_parameters: ['.*feature_map_q.mlp.layer.*', '.*feature_map_k.mlp.layer.*', '.*window_factors.*']
```

`axolotl train config.yaml --base-model={output_dir}/distilled --trust-remote-code --learning-rate=0.0001 # --wandb-project="..."`

Step 3: Run inference on the finetuned model:

`axolotl inference config.yaml --lora-model-dir="{output_dir}" --trust-remote-code # --prompter="AlpacaPrompter"`
43
src/axolotl/integrations/lolcats/__init__.py
Normal file
@@ -0,0 +1,43 @@
"""
Module for the Plugin for LoLCATs linear attention integration with Axolotl.

Low-rank Linear Conversion via Attention Transfer
"""

import logging

from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lolcats.trainer.distill_attention_xent_mse import (
    DistillAttentionXentMSETrainer,
)

from .args import LinearAttentionArgs  # pylint: disable=unused-import  # noqa: F401

LOG = logging.getLogger("axolotl.integrations.lolcats")


class LinearizePlugin(BasePlugin):
    """
    Plugin for lolcats integration with Axolotl.
    """

    def __init__(self):
        super().__init__()

        # Register the Linear Llama model with transformers
        from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
            register_linear_llama,
        )

        register_linear_llama()

    def get_input_args(self):
        return "axolotl.integrations.lolcats.LinearAttentionArgs"

    def get_trainer_cls(self, cfg):
        # default to XentMSE
        # TODO: add check to allow MSE_linear
        if cfg.linearize:
            return DistillAttentionXentMSETrainer

        return None
47
src/axolotl/integrations/lolcats/args.py
Normal file
@@ -0,0 +1,47 @@
"""
Module for handling linear attention input arguments.
"""

from typing import Optional

from pydantic import BaseModel


class FeatureMapKwargs(BaseModel):
    """Args for feature map"""

    eps: float
    mlp: Optional[None] = None
    fullspace: bool


class LearnedKernelKwargs(BaseModel):
    """Args for learned kernel"""

    feature_dim: int
    skip_connection: bool
    bias: bool
    zero_init: bool


class AttentionConfig(BaseModel):
    """Args for attention config"""

    attention_type: str
    feature_map: str
    feature_map_kwargs: FeatureMapKwargs
    layer_idx: Optional[None] = None
    learned_kernel: str
    learned_kernel_kwargs: LearnedKernelKwargs
    tie_qk_kernels: bool
    train_qk: bool


class LinearAttentionArgs(BaseModel):
    """
    Input args for linear attention
    """

    attention_config: AttentionConfig

    linearize: Optional[bool] = False
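A plausible YAML block satisfying these schemas (the `feature_map`/`learned_kernel` names match the `init_feature_map_act`/`init_learned_kernel` registries later in this diff; `feature_dim: 64` is an illustrative assumption):

```yaml
attention_config:
  attention_type: lolcats_llama
  feature_map: softmax_dim
  feature_map_kwargs:
    eps: 1.0e-12
    fullspace: true
  learned_kernel: untied_head_einsum
  learned_kernel_kwargs:
    feature_dim: 64
    skip_connection: false
    bias: false
    zero_init: false
  tie_qk_kernels: false
  train_qk: false

linearize: true
```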
@@ -0,0 +1,90 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Linear LLaMA model configuration"""

from typing import Optional

from transformers import LlamaConfig


class LinearLlamaConfig(LlamaConfig):
    """
    This is the configuration class to store the configuration of a [`LinearLlamaModel`].
    It is a modified LlamaConfig that includes additional parameters for linear attention.

    Args:
        attention_config (`dict`):
            Dictionary containing the configuration for the linear attention mechanism.
            Expected contents:
            `attention_type` (`str`):
                The type of attention to convert to.
            `feature_map` (`str`):
                The type of feature map to use for linear attention.
            `feature_map_kwargs` (`dict`):
                Additional arguments for the feature map.
            `learned_kernel` (`str`, *optional*):
                Type of learned kernel to use, if any.
            `learned_kernel_kwargs` (`dict`, *optional*):
                Additional arguments for the learned kernel.
            `tie_qk_kernels` (`bool`, *optional*, defaults to False):
                Whether to tie query and key kernels.
            `rotary_config` (`dict`, *optional*):
                Configuration for rotary embeddings.
            `train_attention` (`bool`, *optional*, defaults to False):
                Whether to train attention to match softmax attention.
            `remove_base_attn` (`bool`, *optional*, defaults to True):
                Whether to remove base attention after initialization.
            `mask_value` (`int`, *optional*, defaults to 0):
                Value to use for masking.
            `eps` (`float`, *optional*, defaults to 1e-12):
                Epsilon value for numerical stability.
            `fp32_attention` (`bool`, *optional*, defaults to False):
                Whether to use fp32 precision for attention computation.
            `track_state_grads` (`bool`, *optional*, defaults to False):
                Whether to track gradients of attention states.

        **kwargs:
            Additional arguments inherited from LlamaConfig.
    """

    model_type = "linear_llama"

    def __init__(self, attention_config: Optional[dict] = None, **kwargs):
        super().__init__(**kwargs)

        # Set auto_map
        self.auto_map = {
            "AutoConfig": "configuration_linear_llama.LinearLlamaConfig",
            "AutoModel": "modeling_linear_llama.LinearLlamaModel",
            "AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM",
        }

        # Set default attention config if none provided
        self.attention_config = attention_config or {"attention_type": "softmax"}

    @classmethod
    def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):
        """
        Instantiate a LinearLlamaConfig from a LlamaConfig and additional attention config.

        Args:
            llama_config (:class:`~transformers.LlamaConfig`):
                The LlamaConfig to inherit from.

            attention_config (`dict`):
                Dictionary containing the configuration for the linear attention mechanism.
        """

        return cls(attention_config=attention_config, **llama_config.to_dict())
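A minimal sketch of building the config from a stock `LlamaConfig` (the import path is inferred from this diff's directory layout and `auto_map`; model sizes are illustrative):

```python
from transformers import LlamaConfig

from axolotl.integrations.lolcats.linear_llama.configuration_linear_llama import (
    LinearLlamaConfig,
)

base = LlamaConfig(hidden_size=2048, num_attention_heads=16, num_hidden_layers=16)
config = LinearLlamaConfig.from_llama(
    base,
    attention_config={
        "attention_type": "lolcats_llama",
        "feature_map": "softmax_dim",
        "feature_map_kwargs": {"eps": 1e-12, "fullspace": True},
    },
)
assert config.model_type == "linear_llama"
```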
30
src/axolotl/integrations/lolcats/linear_llama/csrc/README.md
Normal file
@@ -0,0 +1,30 @@
# Causal linear attention CUDA kernel

Usage:
```bash
cd src/axolotl/integrations/lolcats/linear_llama/csrc

# Edit `setup.py` to point to the correct CUDA capabilities (L40-44)
# nano setup.py

# Build the CUDA kernel
python setup.py install
```

Reference: https://github.com/idiap/fast-transformers/

```bib
@inproceedings{katharopoulos_et_al_2020,
    author = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
    title = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
    booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
    year = {2020}
}

@article{vyas_et_al_2020,
    author={Vyas, A. and Katharopoulos, A. and Fleuret, F.},
    title={Fast Transformers with Clustered Attention},
    booktitle = {Proceedings of the International Conference on Neural Information Processing Systems (NeurIPS)},
    year={2020}
}
```
@@ -0,0 +1,6 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
from .causal_attention import causal_dot_product
@@ -0,0 +1,225 @@
//
// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
// Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
// Apoorv Vyas <avyas@idiap.ch>
//

#include <torch/extension.h>


/**
 * Compute a*b^T and save it into out.
 *
 * a \in R^A
 * b \in R^B
 */
inline void vvt_dot(float *a, float *b, float *out, int A, int B) {
    for (int i=0; i<A; i++) {
        float *bi = b;
        for (int j=0; j<B; j++) {
            *out += (*a) * (*bi);
            out++;
            bi++;
        }
        a++;
    }
}


/**
 * Implement a vector matrix product v*m and save it into out.
 *
 * v \in R^A
 * m \in R^{AxB}
 */
inline void vm_dot(float *v, float *m, float *out, int A, int B) {
    // TODO: Consider removing the zeroing part and assuming out already
    //       contains 0s
    for (int i=0; i<B; i++) {
        out[i] = 0;
    }

    for (int i=0; i<A; i++) {
        float *oi = out;
        for (int j=0; j<B; j++) {
            *oi += (*v) * (*m);
            oi++;
            m++;
        }
        v++;
    }
}


/**
 * Implement a vector transposed-matrix product and save it into out.
 *
 * v \in R^B
 * m \in R^{AxB}
 */
inline void vmt_dot(float *v, float *m, float *out, int A, int B) {
    for (int i=0; i<A; i++) {
        float *vi = v;
        float s = 0;
        for (int j=0; j<B; j++) {
            s += (*vi) * (*m);
            vi++;
            m++;
        }
        // TODO: Should we be aggregating? See the comment on vm_dot.
        *out = s;
        out++;
    }
}


/**
 * Compute the causally masked dot products of queries, keys and values.
 *
 * Basically compute V_j' = (Q_{0:j} * K_{0:j}^T) * V_{0:j} for all j. The
 * computation is done efficiently by changing the order of the dot products.
 */
void causal_dot_product(
    const torch::Tensor queries,
    const torch::Tensor keys,
    const torch::Tensor values,
    torch::Tensor product
) {
    // Extract some shapes
    int N = queries.size(0);
    int H = queries.size(1);
    int L = queries.size(2);
    int E = queries.size(3);
    int M = values.size(3);

    // Create accessors for all the arguments
    auto qa = queries.accessor<float, 4>();
    auto ka = keys.accessor<float, 4>();
    auto va = values.accessor<float, 4>();
    auto pa = product.accessor<float, 4>();

    #pragma omp parallel for collapse(2)
    for (int n=0; n<N; n++) {
        for (int h=0; h<H; h++) {
            auto kv = torch::zeros({E, M}, queries.options());
            float *kvp = kv.data_ptr<float>();
            for (int l=0; l<L; l++) {
                vvt_dot(
                    &ka[n][h][l][0],
                    &va[n][h][l][0],
                    kvp,
                    E,
                    M
                );
                vm_dot(
                    &qa[n][h][l][0],
                    kvp,
                    &pa[n][h][l][0],
                    E,
                    M
                );
            }
        }
    }
}


/**
 * Compute the gradients of queries, keys and values given the gradient of the
 * causal_dot_product output.
 *
 * Make sure that everything is computed in O(N D^2) complexity.
 */
void causal_dot_backward(
    const torch::Tensor queries,
    const torch::Tensor keys,
    const torch::Tensor values,
    const torch::Tensor grad_out,
    torch::Tensor grad_queries,
    torch::Tensor grad_keys,
    torch::Tensor grad_values
) {
    // Extract some shapes
    int N = queries.size(0);
    int H = queries.size(1);
    int L = queries.size(2);
    int E = queries.size(3);
    int M = values.size(3);

    // Create accessors for all the arguments
    auto qa = queries.accessor<float, 4>();
    auto ka = keys.accessor<float, 4>();
    auto va = values.accessor<float, 4>();
    auto ga = grad_out.accessor<float, 4>();
    auto gqa = grad_queries.accessor<float, 4>();
    auto gka = grad_keys.accessor<float, 4>();
    auto gva = grad_values.accessor<float, 4>();

    #pragma omp parallel for collapse(2)
    for (int n=0; n<N; n++) {
        for (int h=0; h<H; h++) {
            auto kv = torch::zeros({E, M}, queries.options());
            float *kvp = kv.data_ptr<float>();

            // Compute the gradient wrt the queries
            for (int l=0; l<L; l++) {
                vvt_dot(
                    &ka[n][h][l][0],
                    &va[n][h][l][0],
                    kvp,
                    E,
                    M
                );
                vmt_dot(
                    &ga[n][h][l][0],
                    kvp,
                    &gqa[n][h][l][0],
                    E,
                    M
                );
            }

            // Compute the gradient wrt the keys and values
            kv.zero_();
            for (int l=L-1; l>=0; l--) {
                vvt_dot(
                    &qa[n][h][l][0],
                    &ga[n][h][l][0],
                    kvp,
                    E,
                    M
                );
                vmt_dot(
                    &va[n][h][l][0],
                    kvp,
                    &gka[n][h][l][0],
                    E,
                    M
                );
                vm_dot(
                    &ka[n][h][l][0],
                    kvp,
                    &gva[n][h][l][0],
                    E,
                    M
                );
            }
        }
    }
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "causal_dot_product",
        &causal_dot_product,
        "Compute the weighted sum of values but attending only to previous "
        "values."
    );
    m.def(
        "causal_dot_backward",
        &causal_dot_backward,
        "Compute the gradient of queries, keys and values given the gradient "
        "of causal_dot_product."
    );
}
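For intuition, the kernel computes the same prefix-sum-of-outer-products recurrence as this pure-PyTorch reference (the same einsum form appears as the fallback later in this diff):

```python
import torch

def causal_dot_product_ref(q, k, v):
    # out_l = q_l @ (sum_{i <= l} k_i v_i^T), matching the kv accumulation above
    kv = torch.einsum("bhlf,bhld->bhlfd", k, v).cumsum(dim=2)
    return torch.einsum("bhlf,bhlfd->bhld", q, kv)

q, k = torch.randn(2, 4, 8, 16), torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 32)
print(causal_dot_product_ref(q, k, v).shape)  # torch.Size([2, 4, 8, 32])
```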
@@ -0,0 +1,67 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

import torch

try:
    from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda
    from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda
except ImportError as e:
    print(e)
    causal_dot_product_cuda = causal_dot_backward_cuda = None


class CausalDotProduct(torch.autograd.Function):
    """Compute the weighted sum of values but attending only to previous
    values."""

    dot = {
        # "cpu": causal_dot_product_cpu,
        "cuda": causal_dot_product_cuda
    }
    dot_backward = {
        # "cpu": causal_dot_backward_cpu,
        "cuda": causal_dot_backward_cuda
    }

    @staticmethod
    def forward(ctx, Q, K, V):
        # Save the inputs for the gradient computation
        ctx.save_for_backward(Q, K, V)

        # Create the output tensor
        device = Q.device
        N, H, L, _ = Q.shape
        _, _, _, M = V.shape
        product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=device)

        # Actually perform the dot product
        CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product)

        return product

    @staticmethod
    def backward(ctx, grad_out):
        # Extract the saved tensors
        Q, K, V = ctx.saved_tensors

        # Allocate memory for the gradients
        grad_Q = torch.zeros_like(Q)
        grad_K = torch.zeros_like(K)
        grad_V = torch.zeros_like(V)

        # Actually compute the gradients
        CausalDotProduct.dot_backward[Q.device.type](
            Q.data, K.data, V.data, grad_out, grad_Q, grad_K, grad_V
        )

        return grad_Q, grad_K, grad_V


# Alias the autograd functions to python style snake case naming
causal_dot_product = CausalDotProduct.apply
File diff suppressed because it is too large
65
src/axolotl/integrations/lolcats/linear_llama/csrc/setup.py
Normal file
@@ -0,0 +1,65 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

import subprocess  # nosec

import torch
from setuptools import setup
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension


def get_last_arch_torch():
    arch = torch.cuda.get_arch_list()[-1]
    print(f"Found arch: {arch} from existing torch installation")
    return arch


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True  # nosec
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]
    return raw_output, bare_metal_major, bare_metal_minor


def append_nvcc_threads(nvcc_extra_args):
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args


arch = get_last_arch_torch()
sm_num = arch[-2:]
cc_flag = ["--generate-code=arch=compute_90,code=compute_90"]  # for H100
# cc_flag = ['--generate-code=arch=compute_80,code=compute_80']  # for A100
# cc_flag = ['--generate-code=arch=compute_89,code=compute_89']  # for RTX 6000, 4090
# cc_flag = ['--generate-code=arch=compute_86,code=compute_86']  # for A6000, 3090
# cc_flag = ['--generate-code=arch=compute_75,code=compute_75']

setup(
    name="causal_attention_cuda_cpp",
    ext_modules=[
        CUDAExtension(
            "causal_attention_cuda",
            [
                # 'causal_attention.cpp',
                "causal_attention_cuda.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": append_nvcc_threads(
                    ["-O3", "-lineinfo", "--use_fast_math", "-std=c++17"] + cc_flag
                ),
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
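Note that `sm_num` is derived from the installed torch build but never used; the `--generate-code` flag above is hardcoded for H100. A sketch of deriving the flag from the detected arch instead (an assumption, not part of the diff):

```python
# e.g. torch reports "sm_90" -> "--generate-code=arch=compute_90,code=compute_90"
cc_flag = [f"--generate-code=arch=compute_{sm_num},code=compute_{sm_num}"]
```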
@@ -0,0 +1,856 @@
"""
Linear attention classes
"""

import copy
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache

# Causal linear attention dot product CUDA kernel from fast-transformers
try:
    from csrc import causal_dot_product as fast_causal_dot_product
except ImportError:
    fast_causal_dot_product = None

from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

# -------------------
# Attention functions
# -------------------


def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """
    Causal linear attention dot product
    - If available, use CUDA kernel from fast-transformers
    """
    if fast_causal_dot_product is None:
        kv = torch.einsum("bhlf,bhld->bhlfd", k, v)
        return torch.einsum("bhlf,bhlfd->bhld", q, kv.cumsum(dim=2))
    return fast_causal_dot_product(q, k, v)


def linear_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    fp32_attention: bool = False,
    eps: float = 1e-12,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Compute linear attention with CUDA kernel implementation from fast-transformers
    - https://github.com/idiap/fast-transformers
    - Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim);
      v is shape (b, h, l, head_dim)
    """
    dtype = q.dtype
    # Causal mask already applied
    y = causal_dot_product(
        q.contiguous().to(dtype=torch.float32),
        k.contiguous().to(dtype=torch.float32),
        v.contiguous().to(dtype=torch.float32),
    )
    if fp32_attention:
        y = (
            y
            / (
                torch.einsum("bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2)) + eps
            )[..., None]
        ).to(dtype=dtype)
    else:
        y = y.to(dtype=dtype)
        k = k.float().cumsum(dim=2).to(dtype=dtype)
        y = y / (torch.einsum("bhld,bhld->bhl", q, k) + eps)[..., None]
    return y, None, None
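In math form, `linear_attention` computes the standard normalized linear attention estimator over feature-mapped queries and keys, with `eps` guarding the denominator:

```latex
y_l = \frac{q_l^\top \left( \sum_{i \le l} k_i v_i^\top \right)}
           {q_l^\top \left( \sum_{i \le l} k_i \right) + \epsilon}
```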
def softmax_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: Optional[torch.Tensor] = None,
    causal: bool = True,
    fp32_attention: bool = True,
):
    """
    Standard softmax attention; only compute outputs if v is not None
    -> Assume q, k, v are shape (batch_size, num_heads, seq_len, head_dim)
    """
    y = None
    a = torch.einsum("bhmd,bhnd->bhmn", q, k) * (k.shape[-1] ** -0.5)
    if causal:  # Apply causal mask
        m, n = a.shape[-2:]
        causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
            n - m + 1
        )
        a = a.masked_fill(causal_mask, -torch.finfo(a.dtype).max)
    if fp32_attention:
        a = torch.softmax(a, dim=-1, dtype=torch.float32).to(q.dtype)
    else:
        a = torch.softmax(a, dim=-1)
    if v is not None:
        y = torch.einsum("bhmn,bhnd->bhmd", a, v)
    return y, a, None


def quadratic_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: Optional[torch.Tensor] = None,
    causal: bool = True,
    fp32_attention: bool = False,
    eps: float = 1e-12,
):
    """
    Compute attention with feature maps by instantiating L x L matrix of attention weights
    -> Use for attention distillation
    -> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim)
    """
    y = None
    dtype = q.dtype
    if fp32_attention:
        q, k = q.float(), k.float()
    a = torch.einsum("bhmd,bhnd->bhmn", q, k)  # note we don't scale, tho we could
    if causal:  # Apply causal mask
        m, n = a.shape[-2:]
        causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
            n - m + 1
        )
        a = a.masked_fill(causal_mask, 0)
    # Normalize to compute attention
    a = a / (a.sum(dim=-1, keepdim=True) + eps)
    a = a.to(dtype=dtype) if fp32_attention else a
    if torch.isnan(a).sum() > 0:
        breakpoint()
    if v is not None:
        y = torch.einsum("bhmn,bhnd->bhmd", a, v)
    return y, a, None
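In `train_attention` mode the layer returns `(attn_pred, attn_true)` and `(y_pred, y_true)` pairs (see `forward` below) for the `DistillAttentionXentMSETrainer`. A hedged sketch of such a combined objective, since the exact weighting is not shown in this diff:

```python
import torch

def xent_mse_distill_loss(attn_pred, attn_true, y_pred, y_true,
                          xent_weight=1.0, mse_weight=1.0, eps=1e-12):
    """Soft cross-entropy over attention rows plus MSE on attention outputs."""
    xent = -(attn_true * (attn_pred + eps).log()).sum(dim=-1).mean()
    mse = torch.mean((y_pred - y_true) ** 2)
    return xent_weight * xent + mse_weight * mse
```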
# ---------------------
# Attention layer class
# ---------------------


class LolcatsLinearAttention(nn.Module):
    """
    LoLCATs attention implementation initialized from a
    `LlamaAttention` or `MistralAttention` object (base_attn)

    Most of the arguments are directly tied to argparse args
    - For now we don't support padding.
    """

    def __init__(
        self,
        base_attn: nn.Module,  # like LlamaAttention
        feature_map: str,
        feature_map_kwargs: dict,
        layer_idx: Optional[int] = None,
        max_layer_idx: Optional[int] = None,
        learned_kernel: Optional[str] = None,
        learned_kernel_kwargs: Optional[dict] = None,
        tie_qk_kernels: Optional[bool] = False,
        rotary_config: Optional[dict] = None,
        train_attention: Optional[bool] = False,
        remove_base_attn: bool = True,
        attention_type: Optional[str] = "lolcats_llama",
        mask_value: int = 0,
        eps: float = 1e-12,
        fp32_attention: bool = False,
        track_state_grads: bool = False,
        rank: Optional[int] = 0,
        **kwargs,
    ) -> None:
        super().__init__()
        self.base_config = getattr(base_attn, "config", None)
        if self.base_config is not None:
            self.base_config = self.base_config.to_dict()
        self.attention_type = attention_type
        self.mask_value = mask_value
        self.eps = eps
        self.layer_idx = layer_idx if layer_idx is not None else base_attn.layer_idx
        self.max_layer_idx = max_layer_idx
        self.tie_qk_kernels = tie_qk_kernels
        self.train_attention = train_attention
        self.base_inference = False
        self.fp32_attention = fp32_attention
        self.track_state_grads = track_state_grads
        if rank == 0:  # multi-gpu
            if fp32_attention and layer_idx == 0:
                print(f"-> fp32_attention is {fp32_attention}")
            if layer_idx == 0 and feature_map_kwargs is not None:
                for k, v in feature_map_kwargs.items():
                    print(f"-> {k}: {v}")
            if layer_idx == 0 and learned_kernel_kwargs is not None:
                for k, v in learned_kernel_kwargs.items():
                    print(f"-> {k}: {v}")

        self.remove_base_attn = remove_base_attn

        self.init_weights_(base_attn, remove_base_attn)
        self.init_feature_map_(
            feature_map, feature_map_kwargs, learned_kernel, learned_kernel_kwargs
        )

    def init_feature_map_(
        self,
        feature_map: str,
        feature_map_kwargs: dict,
        learned_kernel: Optional[str] = None,
        learned_kernel_kwargs: Optional[dict] = None,
    ):
        """
        Initialize MLP-based feature map
        """
        self.fmap_gqa = False  # Turn True if specified below
        if learned_kernel is not None and learned_kernel_kwargs is not None:
            # Ensure dict
            learned_kernel_kwargs = {k: v for k, v in learned_kernel_kwargs.items()}
            learned_kernel_kwargs["num_heads"] = self.num_heads
            learned_kernel_kwargs["head_dim"] = self.head_dim
            learned_kernel_kwargs["dtype"] = self.q_proj.weight.dtype
            learned_kernel_kwargs["device"] = self.q_proj.weight.device
            # Create MLP
            mlp_learned_kernel = init_learned_kernel(
                learned_kernel, **learned_kernel_kwargs
            )
        # Add "activation"; see src.models.feature_map.py
        self.feature_map_q = init_feature_map(
            name=feature_map, mlp=mlp_learned_kernel, **feature_map_kwargs
        )
        if self.tie_qk_kernels:  # tie mlp weights for query and key feature maps
            self.feature_map_k = self.feature_map_q
        else:
            self.feature_map_k = copy.deepcopy(self.feature_map_q)

    def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True):
        """
        Initialize module layers, weights, positional dependencies, etc.
        from original softmax attention layer (base_attn)
        """
        # Make other attributes accessible
        self.attention_dropout = 0  # We don't use dropout
        self.hidden_size = base_attn.config.hidden_size
        self.num_heads = base_attn.config.num_attention_heads
        self.head_dim = base_attn.head_dim
        self.num_key_value_heads = base_attn.config.num_key_value_heads
        self.num_key_value_groups = base_attn.num_key_value_groups

        self.q_shape = [self.num_heads, self.head_dim]
        self.k_shape = [self.num_key_value_heads, self.head_dim]
        self.v_shape = [self.num_key_value_heads, self.head_dim]

        # Copy original model projection layers
        self.q_proj = base_attn.q_proj
        self.k_proj = base_attn.k_proj
        self.v_proj = base_attn.v_proj
        self.o_proj = base_attn.o_proj
        try:  # If wanting to use FA2 for ground-truth inference
            self._flash_attn_uses_top_left_mask = (
                base_attn._flash_attn_uses_top_left_mask
            )
        except AttributeError:
            pass

        if self.remove_base_attn or remove_base_attn:
            del base_attn  # We don't need to keep these around
        else:
            self.base_attn = base_attn  # For some training runs helpful to just call

    def process_qkv(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_value: Optional[Any] = None,
    ):
        """
        Compute queries, keys, and values
        """
        b, l, _ = hidden_states.size()
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        kv_seq_len = k.shape[-2]

        # Shape is (batch_size, seq_len, num_heads, head_dim)
        q = q.view(b, l, *self.q_shape).transpose(1, 2)
        k = k.view(b, l, *self.k_shape).transpose(1, 2)
        v = v.view(b, l, *self.v_shape).transpose(1, 2)

        if (
            past_key_value is not None
        ):  # and k.shape[2] > q.shape[2]:  # e.g., when generating
            past_key_value.window_size = getattr(
                self, "decode_window_size", None
            )  # self.decode_window_size
            if isinstance(
                past_key_value, Cache
            ):  # In Transformers v4.36+ this is a DynamicCache object
                kv_seq_len += past_key_value.get_usable_length(
                    kv_seq_len, self.layer_idx
                )
            else:
                kv_seq_len += past_key_value[0].shape[-2]

        # Apply rotary embeddings
        if position_embeddings is not None:
            cos, sin = position_embeddings
            q, k = apply_rotary_pos_emb(q, k, cos, sin)

        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)
        return q, k, v, kv_seq_len
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_value: Optional[Any] = None,  # "legacy" cache approach
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36)
        - Consistent with HuggingFace Transformers for easy use with their pretrained models
        """
        b, l, _ = hidden_states.size()
        q, k, v, kv_seq_len = self.process_qkv(
            hidden_states, attention_mask, position_embeddings, past_key_value
        )

        if self.base_inference:
            with torch.no_grad():
                # 1. Compute "ground-truth" attention output and weights
                y_true, _, _ = softmax_attention(q, k, v, causal=True)
                y_true = (
                    y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                )
                y_true = self.o_proj(y_true)
                attn_weights = (None, None)

        elif self.train_attention:  # Distilling / learning attentions
            # Note for now we assume no padding when distilling; attention masks only enforce causality
            assert (
                output_attentions is True
            ), f"When training feature maps, output_attentions should be True but is {output_attentions}"
            with torch.no_grad():
                # 1. Compute "ground-truth" attention output and weights
                _y_true, attn_true, _ = softmax_attention(q, k, v, causal=True)
                y_true = (
                    _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                )
                y_true = self.o_proj(y_true)

            # 2. Compute "predicted" attention (just weights)
            q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k)
            y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True)
            attn_weights = (  # type: ignore
                (attn_pred, attn_true),
                (y_pred, _y_true),
            )  # Save both attention weights so we can supervise.

        else:  # Finetuning
            q, k = self.feature_map_q(q), self.feature_map_k(k)
            # Apply prefill mask
            if attention_mask is not None and q.shape[2] > 1:
                if len(attention_mask.shape) == 4:
                    lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][
                        ..., None
                    ]  # b, 1, k_len, 1
                else:
                    lin_attn_mask = attention_mask.bool()[
                        :, None, :, None
                    ]  # b, 1, k_len, 1
                k = k.masked_fill(~lin_attn_mask, 0)

            if past_key_value is not None:  # Initialize states
                if len(past_key_value.kv_states) == self.layer_idx:
                    b, h, _, f = k.shape
                    past_key_value.kv_states.append(
                        torch.zeros(
                            b, h, f, self.head_dim, dtype=q.dtype, device=q.device
                        )
                    )
                    past_key_value.k_states.append(
                        torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device)
                    )
                # Generating
                if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None:
                    assert use_cache is True
                    kv_state, k_state = past_key_value.update(
                        k, v, self.layer_idx, accumulate_in_fp32=self.fp32_attention
                    )
                    if self.fp32_attention:
                        q = q.float()
                        y_true = (
                            torch.einsum("bhlf,bhfd->bhld", q, kv_state.float())
                            / torch.einsum("bhlf,bhlf->bhl", q, k_state.float())[
                                ..., None
                            ]
                        ).to(dtype=k.dtype)
                    else:
                        y_true = (
                            torch.einsum("bhlf,bhfd->bhld", q, kv_state)
                            / torch.einsum("bhlf,bhlf->bhl", q, k_state)[..., None]
                        )
                else:
                    kv_state = past_key_value.kv_states[self.layer_idx]
                    k_state = past_key_value.k_states[self.layer_idx]
                    y_true, _, _ = linear_attention(
                        q, k, v, self.fp32_attention, self.eps
                    )  # Ordinarily the states are ignored
                    past_key_value.update(
                        k.detach(),
                        v.detach(),
                        self.layer_idx,
                        accumulate_in_fp32=self.fp32_attention,
                    )
                    # doing some unnecessary recomputation here
            else:
                y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)

            # Concatenate heads and apply output projection
            y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
            y_true = self.o_proj(y_true)
            attn_weights = None

        return y_true, attn_weights
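A hedged sketch of how such a layer might replace each `LlamaAttention` during conversion (the actual conversion CLI referenced in the README is not shown in this hunk; `linearize_layers` is a hypothetical helper):

```python
def linearize_layers(model, attention_config):
    # Swap every self-attention module for the linear version, reusing its
    # q/k/v/o projections; train_attention=True enables the distillation path.
    for layer_idx, layer in enumerate(model.model.layers):
        layer.self_attn = LolcatsLinearAttention(
            base_attn=layer.self_attn,
            layer_idx=layer_idx,
            feature_map=attention_config["feature_map"],
            feature_map_kwargs=attention_config["feature_map_kwargs"],
            learned_kernel=attention_config["learned_kernel"],
            learned_kernel_kwargs=attention_config["learned_kernel_kwargs"],
            tie_qk_kernels=attention_config["tie_qk_kernels"],
            train_attention=True,
        )
    return model
```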
class LinearAttentionState(Cache):
    """
    Handle the KV and K states for linear attention
    - Adopts HF Transformers `past_key_values` convention
    - Inherits from `Cache` class
    - Modified from transformers.cache_utils.DynamicCache (v4.36)
    """

    def __init__(self) -> None:
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        self._seen_tokens_by_layer: List[int] = []
        self.kv_states: List[torch.Tensor] = []
        self.k_states: List[torch.Tensor] = []

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Returns the sequence length of the cached states. A layer index can be optionally passed.
        """
        if layer_idx is None:
            raise ValueError("Layer index must not be None")

        if len(self._seen_tokens_by_layer) <= layer_idx:  # Initializing kv and k states
            self._seen_tokens_by_layer.append(0)
        return self._seen_tokens_by_layer[layer_idx]

    def get_max_length(self) -> Optional[int]:
        """
        Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
        """
        return None

    def get_usable_length(
        self, new_seq_length: int, layer_idx: Optional[int] = 0
    ) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the cache length plus the length of the new inputs is larger
        # than the maximum cache length, we will need to evict part of the cache (and thus not all
        # cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: Optional[int] = None,
        cache_kwargs: Optional[Any] = None,
        accumulate_in_fp32: bool = True,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if layer_idx is None:
            raise ValueError("Layer index must not be None")

        with torch.no_grad():
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]
            dtype = key_states.dtype
            if accumulate_in_fp32:
                key_states, value_states = key_states.float(), value_states.float()

            kv_state = torch.einsum(
                "bhlf,bhld->bhfd", key_states, value_states
            ).detach()
            k_state = key_states.sum(
                dim=-2, keepdim=True
            ).detach()  # b, h, 1, f; note the 1
            # Update the cache
            if len(self.k_states) <= layer_idx:  # Initializing kv and k states
                self.kv_states.append(kv_state.to(dtype))
                self.k_states.append(k_state.to(dtype))
            else:
                kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype)
                k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype)
                self.kv_states[layer_idx] = kv_state
                self.k_states[layer_idx] = k_state
            self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
        return self.kv_states[layer_idx], self.k_states[layer_idx]

    def to_legacy_cache(self):
        """Hack, but just return self"""
        return self

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """
        Reorders the cache for beam search, given the selected beam indices.
        -> Copied from transformers/src/transformers/cache_utils.py
        """
        raise NotImplementedError(
            "Reordering cache not implemented for LinearAttentionState"
        )

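One subtlety worth noting: `update` indexes `_seen_tokens_by_layer`, which is only initialized by `get_seq_length`, so callers should query the length once per layer first. A small usage sketch (hypothetical shapes):

import torch

cache = LinearAttentionState()
cache.get_seq_length(0)                    # initializes the per-layer token count
k = torch.rand(1, 2, 16, 8)                # feature-mapped keys (b, h, l, f)
v = torch.randn(1, 2, 16, 8)               # values (b, h, l, d)
kv_state, k_state = cache.update(k, v, layer_idx=0)
kv_state, k_state = cache.update(k, v, layer_idx=0)  # states now cover 32 tokens
assert cache.get_seq_length(0) == 32
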
# -------------------
# feature map functions
# -------------------


def init_feature_map(name: str, mlp: nn.Module, **kwargs):
    """
    Initialize the full feature map (learnable MLP + final activation) for linear attention
    """
    return FeatureMap(activation_name=name, mlp=mlp, **kwargs)


def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
    """
    Initialize feature map final activation for linear attention
    """
    if name == "softmax_dim" and fullspace:
        return SoftmaxDim(**kwargs)
    elif name == "softmax_dim" and not fullspace:
        return SoftmaxDimHalfspace(**kwargs)
    elif name == "exp_dim" and fullspace:
        return Exp(**kwargs)
    elif name == "exp_dim" and not fullspace:
        return ExpHalfspace(**kwargs)
    elif name == "pos_elu":
        return PosELU(**kwargs)
    elif name == "relu":
        return ReLU(**kwargs)
    else:
        raise NotImplementedError


def init_learned_kernel(name: str, **kwargs):
    """
    Initialize feature map MLP for linear attention
    """
    if name == "untied_head_einsum":
        return FeatureMapMLP(**kwargs)
    elif name == "untied_head_adapter":
        return FeatureMapAdapter(**kwargs)
    else:
        raise NotImplementedError

class FeatureMap(nn.Module):
    """
    Final 'activation' of feature map. Can probably be combined with
    `FeatureMapMLP` below

    Full feature map is like f(xW + b)
    -> This is the `f` part
    """

    def __init__(
        self,
        activation_name: str,
        head_dim_idx: int = -1,
        eps: float = 1e-12,
        mlp: Optional[nn.Module] = None,
        fullspace: bool = True,
    ):
        super().__init__()
        self.head_dim_idx = head_dim_idx
        self.eps = eps
        self.mlp = mlp if mlp is not None else nn.Identity()
        self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)

    def forward(self, x: torch.Tensor, *mlp_args, **mlp_kwargs):
        """
        Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
        """
        return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x)

    def q_map(self, *args, **kwargs):
        """
        Use for inference in case q and k feature maps differ
        """
        return self.forward(*args, **kwargs)

    def k_map(self, *args, **kwargs):
        """
        Use for inference in case q and k feature maps differ
        """
        return self.forward(*args, **kwargs)

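For intuition, the full map is `activation(mlp(x))`; with the default identity MLP and the `softmax_dim` activation, the head dimension doubles. A quick shape check (illustrative only; it uses `SoftmaxDim`, defined just below):

import torch

phi = FeatureMap(activation_name="softmax_dim")    # mlp defaults to nn.Identity()
x = torch.randn(1, 2, 5, 8)                        # (batch, heads, seq, head_dim)
assert phi(x).shape == (1, 2, 5, 16)               # softmax over +/-x doubles dim
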
# -----------------------
# Feature map activations
# -----------------------
class FeatureMapAct(nn.Module):
    """
    Base class for feature map activations
    """

    def __init__(self, eps: float = 1e-12):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """
        x.shape is (batch_size, n_heads, seq_len, head_dim)
        """
        return x


class PosELU(FeatureMapAct):
    """
    1 + ELU activation as in https://arxiv.org/abs/2006.16236
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        return (1 + F.elu(x)).clamp(min=self.eps)


class ReLU(FeatureMapAct):
    """
    ReLU activation as in https://arxiv.org/abs/2103.13076
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        return F.relu(x).clamp(min=self.eps)


class SoftmaxDim(FeatureMapAct):
    """
    Softmax activation as in https://arxiv.org/abs/2402.04347
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        return torch.cat(
            [torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)], dim=-1
        ).clamp(min=self.eps)


class SoftmaxDimHalfspace(FeatureMapAct):
    """
    Softmax activation as in https://arxiv.org/abs/2402.04347
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        return torch.softmax(x, dim=-1).clamp(min=self.eps)


class Exp(FeatureMapAct):
    """
    Exp activation as in https://arxiv.org/abs/2402.04347
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        x_max = torch.amax(x, dim=-1, keepdim=True)
        x_min = torch.amin(x, dim=-1, keepdim=True)
        return torch.cat([torch.exp(x - x_max), torch.exp(-x + x_min)], dim=-1).clamp(
            min=self.eps
        )


class ExpHalfspace(FeatureMapAct):
    """
    Exp activation as in https://arxiv.org/abs/2402.04347
    """

    def forward(self, x: torch.Tensor, *args, **kwargs):
        x_max = torch.amax(x, dim=-1, keepdim=True)
        return torch.exp(x - x_max).clamp(min=self.eps)

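All of the activations above clamp to `eps`, keeping outputs strictly positive so the linear-attention denominator `f_q . k_state` cannot vanish. A quick illustrative check:

import torch

x = torch.randn(1, 1, 4, 8)
for act in (PosELU(), ReLU(), SoftmaxDim(), SoftmaxDimHalfspace(), Exp(), ExpHalfspace()):
    assert act(x).min() >= act.eps
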
# ----------------
# Feature map MLPs
# ----------------


class FeatureMapMLP(nn.Module):
    """
    Learnable MLP in feature map.

    Full feature map is like f(xW + b)
    -> This is the `W` and (optional) `b` part
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,  # input dim
        feature_dim: int,  # output dim
        dtype: torch.dtype,
        device: torch.device,
        skip_connection: bool = False,
        bias: bool = False,
        zero_init: bool = False,
        normal_init: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.feature_dim = feature_dim
        self.dtype = dtype
        self.device = device
        self.skip_connection = skip_connection
        self.bias = bias
        self.zero_init = zero_init
        self.normal_init = normal_init
        self.init_weights_()

        if self.zero_init:  # Zero-out weights or set as identity post-initialization
            self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()

        if self.normal_init:
            with torch.no_grad():
                nn.init.normal_(self.layer)

        if self.skip_connection:
            assertion_fail = f"If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}"
            assert self.head_dim == self.feature_dim, assertion_fail

    def init_weights_(self):
        """
        Initialize (W)eights and (b)iases
        """
        self.layer = nn.Parameter(
            torch.zeros(
                (self.num_heads, self.head_dim, self.feature_dim),
                dtype=self.dtype,
                device=self.device,
            )
        )
        nn.init.kaiming_uniform_(self.layer)

        if self.bias:
            self.bias = nn.Parameter(
                torch.zeros(
                    (1, self.num_heads, 1, 1),  # self.feature_dim),
                    dtype=self.dtype,
                    device=self.device,
                )
            )
            nn.init.kaiming_uniform_(self.bias)
        else:
            self.bias = 0.0  # hack

    def zero_init_with_skip_(self):
        """
        Initialize weights to zero matrix if skip connection
        """
        with torch.no_grad():
            nn.init.zeros_(self.layer)

    def zero_init_(self):
        """
        Initialize weights to identity matrix if no skip connection
        """
        with torch.no_grad():
            for i in range(self.layer.shape[0]):
                try:
                    nn.init.eye_(self.layer[i])
                except RuntimeError:
                    with torch.no_grad():
                        dtype = self.layer[i].dtype
                        weight = torch.eye(
                            *self.layer[i].shape,
                            requires_grad=self.layer[i].requires_grad,
                            device=self.layer[i].device,
                        )
                        self.layer[i] = weight.to(dtype=dtype)

    def forward(self, x: torch.Tensor):
        """
        Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
        """
        _x = torch.einsum("hdf,bhld->bhlf", self.layer, x) + self.bias
        return x + _x if self.skip_connection else _x

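`FeatureMapMLP` is a per-head projection with untied weights, applied before the activation. A shape sketch (hypothetical dims):

import torch

mlp = FeatureMapMLP(
    num_heads=2, head_dim=8, feature_dim=16,
    dtype=torch.float32, device=torch.device("cpu"),
)
x = torch.randn(1, 2, 5, 8)            # (batch, heads, seq, head_dim)
assert mlp(x).shape == (1, 2, 5, 16)   # einsum("hdf,bhld->bhlf")
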
class FeatureMapAdapter(FeatureMapMLP):
    """
    Learnable feature map with bottleneck adapter,
    as in https://arxiv.org/abs/1902.00751

    We don't use this, but it could be fun to try
    """

    def __init__(self, hidden_dim: int, *args, **kwargs):
        kwargs["skip_connection"] = True
        kwargs["bias"] = True
        kwargs["zero_init"] = True
        self.hidden_dim = hidden_dim
        super().__init__(*args, **kwargs)

    def init_weights_(self):
        """
        Initialize (W)eights and (b)iases
        """
        kwargs = {"dtype": self.dtype, "device": self.device}
        self.layer0 = nn.Parameter(
            torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
        )
        self.layer1 = nn.Parameter(
            torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
        )
        nn.init.kaiming_uniform_(self.layer0)
        nn.init.kaiming_uniform_(self.layer1)

        self.bias0 = nn.Parameter(
            torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs)
        )
        self.bias1 = nn.Parameter(
            torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs)
        )
        nn.init.kaiming_uniform_(self.bias0)
        nn.init.kaiming_uniform_(self.bias1)

    def zero_init_with_skip_(self):
        with torch.no_grad():
            nn.init.zeros_(self.layer0)
            nn.init.zeros_(self.layer1)
            nn.init.zeros_(self.bias0)
            nn.init.zeros_(self.bias1)

    def zero_init_(self):
        raise NotImplementedError

    def forward(self, x: torch.Tensor):
        """
        Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
        -> Down-project, apply nonlinearity, up-project; add skip connection
        """
        _x = torch.einsum("hde,bhld->bhle", self.layer0, x) + self.bias0
        _x = F.relu(_x)
        _x = torch.einsum("hef,bhle->bhlf", self.layer1, _x) + self.bias1
        return x + _x if self.skip_connection else _x
@@ -0,0 +1,460 @@
"""
Subquadratic attention combining sliding window and linear attentions
- Using "standard" sliding windows
- Didactically computes outputs with n^2 attention weights for now
- Copied + adapted from linear_window_attention_tk.py for single-file reference

For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""

from typing import Any, Callable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache

from .linear_attention import (
    LinearAttentionState,
    LolcatsLinearAttention,
    softmax_attention,
)


# ----------------------
# Sliding window helpers
# ----------------------
def get_masks(
    window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return masks for softmax and linear attention terms
    -> 1 is include, 0 is ignore
    """
    causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
        k_len - q_len
    )
    linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
        k_len - q_len - window_size
    )
    window_mask = causal_mask - linear_mask
    # Return softmax mask (window), linear attention mask
    # -> shapes broadcast over (b, h, q_len, k_len)
    return window_mask[None, None, ...], linear_mask[None, None, ...]

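A worked example of the mask split (window_size=2, seq_len=4): each query attends to its last two keys via softmax and to everything older via linear attention:

import torch

w_mask, l_mask = get_masks(2, 4, 4, torch.device("cpu"))
assert w_mask[0, 0, 3].tolist() == [0, 0, 1, 1]   # softmax window for query 3
assert l_mask[0, 0, 3].tolist() == [1, 1, 0, 0]   # linear term covers the rest
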
def hybrid_attention_quadratic(
    q: torch.Tensor,
    k: torch.Tensor,
    f_q: torch.Tensor,
    f_k: torch.Tensor,
    v: torch.Tensor,
    window_factor: torch.Tensor,
    linear_factor: torch.Tensor,
    window_size: int,
    kv_state: Optional[torch.Tensor] = None,
    k_state: Optional[torch.Tensor] = None,
    eps: float = 1e-12,
    mask_value: float = -1e8,
):
    """
    Hybrid attention combining sliding window and linear attentions
    """

    mask_window, mask_linear = get_masks(
        window_size, q.shape[-2], k.shape[-2], q.device
    )

    # 1. Sliding window (softmax attention)
    a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
    a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
    # torch.softmax(a_sm, dim=-1), but we account for the max when combining
    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
    a_sm = window_factor * torch.exp(a_sm - a_sm_max)
    sum_sm = a_sm.sum(dim=-1, keepdim=True)

    # 2. Under window (linear attention)
    a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
    a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
    sum_ln = a_ln.sum(dim=-1, keepdim=True)

    # 3. Combine
    a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype)  # Save attention weights
    # Allow outputs to also depend on prior kv_state and k_state
    y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
    if (
        kv_state is not None and k_state is not None
    ):  # Combine with prior kv_state and k_state
        y += linear_factor * torch.einsum(
            "bhld,bhdf->bhlf", f_q.float(), kv_state.float()
        )
        sum_ln += (
            linear_factor
            * torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
        )
    y = (y / (sum_sm + sum_ln)).to(q.dtype)
    return y, a  # attention weights only for the last chunk

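By construction the combined weights `a` normalize across both terms; a quick sanity sketch with random inputs (illustrative only):

import torch

b, h, l, d = 1, 2, 8, 4
q = k = v = torch.randn(b, h, l, d)
f_q = f_k = torch.rand(b, h, l, d) + 0.1          # positive feature maps
y, a = hybrid_attention_quadratic(
    q, k, f_q, f_k, v, torch.tensor(0.5), torch.tensor(0.5), window_size=2
)
assert torch.allclose(a.sum(dim=-1), torch.ones(b, h, l), atol=1e-4)
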
# ---------------------
# Attention layer class
# ---------------------
class LolcatsSlidingWindowAttention(LolcatsLinearAttention):
    """
    Lolcats attention combining sliding window and linear attention
    """

    def __init__(
        self,
        window_size: int = 64,
        decode_window_size: Optional[int] = None,
        affine_attention_factors: bool = False,
        init_window_factor: float = 0,
        train_window_factor: bool = True,
        state_grad_enabled: bool = False,
        **kwargs,
    ):
        self.window_size = window_size
        self.decode_window_size = (
            decode_window_size if decode_window_size is not None else window_size
        )
        self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
        super().__init__(**kwargs)
        self.attention_type = kwargs["attention_type"]  # e.g., 'hedgehog_llama_window_sw'
        # Determine how we compute attentions
        self.quadratic_attention = hybrid_attention_quadratic
        # Learnable factor for combining attentions
        self.affine_attention_factors = affine_attention_factors
        device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
        if train_window_factor:
            self.window_factors = nn.Parameter(
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
            )
        else:
            self.register_buffer(
                "window_factors",
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
            )
        # Whether we use original flash attention 2 inference (use during attention transfer)
        self.base_inference = False
        self.state_grad_enabled = state_grad_enabled

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass with the option to compute attention weights multiple ways
        if self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models
        """
        b, l, _ = hidden_states.size()
        q, k, v, kv_seq_len = self.process_qkv(
            hidden_states, attention_mask, position_ids, past_key_value
        )
        f_q, f_k = self.feature_map_q(q), self.feature_map_k(
            k
        )  # Have to do after repeat for grouped-query attn if we use same fmap

        if self.train_attention:
            # 1. Compute "ground-truth" attention output and weights
            with torch.no_grad():
                _y_true, a_true = softmax_attention(q, k, v)[:2]
                y_true = (
                    _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                )
                y_true = self.o_proj(y_true)

            # 2. Compute "predicted" attention outputs
            # compute attn weights under sliding window
            window_factors = F.sigmoid(self.window_factors)
            linear_factors = 1 - window_factors if self.affine_attention_factors else 1
            y_pred, a_pred = self.quadratic_attention(
                q,
                k,
                f_q,
                f_k,
                v,
                window_factors,
                linear_factors,
                window_size=self.window_size,
            )
            attn_weights = ((a_pred, a_true), (y_pred, _y_true))
        else:
            attn_weights = None
            # attention_mask = None  # For now this is always True
            if past_key_value is None:  # Regular training
                window_factors = F.sigmoid(self.window_factors)
                linear_factors = (
                    1 - window_factors if self.affine_attention_factors else 1
                )
                y_true, a_pred = self.quadratic_attention(
                    q,
                    k,
                    f_q,
                    f_k,
                    v,
                    window_factors,
                    linear_factors,
                    window_size=self.window_size,
                )
                attn_weights = a_pred
            else:
                past_key_value.window_size = self.decode_window_size
                if (
                    f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
                ):  # Generating
                    assert use_cache is True
                    _kv = past_key_value.update_for_decoding(
                        k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
                    )
                    k_cache, v_cache, f_kv_state, f_k_state = _kv

                    # Sliding window + linear attention decode
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )

                    # Softmax attention terms
                    a_sm = torch.einsum(
                        "bhmd,bhnd->bhmn", q.float(), k_cache.float()
                    ) * (k.shape[-1] ** -0.5)
                    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
                    a_sm = window_factors * torch.exp(a_sm - a_sm_max)
                    sum_sm = a_sm.sum(dim=-1, keepdim=True)

                    # Combine with linear attention terms
                    y_true = torch.einsum(
                        "bhmn,bhnd->bhmd", a_sm, v_cache.float()
                    ) + linear_factors * torch.einsum(
                        "bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
                    )
                    sum_ln = (
                        linear_factors
                        * torch.einsum(
                            "bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
                        )[..., None]
                    )
                    y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)

                else:  # Stateful training
                    try:
                        kv_state = past_key_value.kv_states[self.layer_idx]
                        k_state = past_key_value.k_states[self.layer_idx]
                    except IndexError:
                        kv_state, k_state = None, None
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )
                    y_true, _ = self.quadratic_attention(
                        q,
                        k,
                        f_q,
                        f_k,
                        v,
                        window_factors,
                        linear_factors,
                        window_size=self.window_size,
                        kv_state=kv_state,
                        k_state=k_state,
                    )
                    # Save and update KV cache and states
                    # past_key_value.update(k, v.detach(), self.layer_idx,
                    #                       fmap_key_states=f_k.detach(),
                    #                       accumulate_in_fp32=True)
                    past_key_value.update(
                        k,
                        v,
                        self.layer_idx,
                        fmap_key_states=f_k,
                        accumulate_in_fp32=True,
                    )
            # Concatenate heads and apply output projection
            y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
            y_true = self.o_proj(y_true)
        return y_true, attn_weights, past_key_value

class LinearAttentionSlidingWindowCache(LinearAttentionState):
    """
    Class for `past_key_values`
    -> Alternative to KV cache; here we only maintain a "KV state" and "K state"
    -> Modified from transformers.cache_utils.DynamicCache (v4.36)
    """

    def __init__(self, window_size: int = 64) -> None:
        super().__init__()
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        self._seen_tokens_by_layer: List[int] = []
        self.kv_states: List[torch.Tensor] = []
        self.k_states: List[torch.Tensor] = []

        # Account for sliding windows
        self.decode_kv_states: List[torch.Tensor] = []
        self.decode_k_states: List[torch.Tensor] = []
        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []
        self.window_size = window_size

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: Optional[int] = None,
        cache_kwargs: Optional[Any] = None,
        accumulate_in_fp32: bool = False,
        fmap_key_states: Optional[torch.Tensor] = None,  # should not be None
        grad_enabled: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update KV, K states; and KV cache during training
        - For decoding, use `self.decode_kv_states` to keep track of KV states
          up to sliding window terms
        - For (chunked) training, use `self.kv_states` to keep track of KV states
          up to end of sequence
        - Likewise for `self.decode_k_states` and `self.k_states`
        """
        if fmap_key_states is None:
            raise ValueError("fmap_key_states must not be None")

        if layer_idx is None:
            raise ValueError("Layer index must not be None")

        with torch.set_grad_enabled(grad_enabled):
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]

            dtype = key_states.dtype
            if accumulate_in_fp32:
                # key_states = key_states.float()
                fmap_key_states = fmap_key_states.float()
                value_states = value_states.float()

            # Decoding KV state (KV terms up to last window_size)
            decode_kv_state = torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, : -self.window_size],
                value_states[:, :, : -self.window_size],
            )
            # KV state
            kv_state = decode_kv_state + torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, -self.window_size :],
                value_states[:, :, -self.window_size :],
            )
            # shape is b, h, 1, f; note the 1
            decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
                dim=-2, keepdim=True
            )
            k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
                dim=-2, keepdim=True
            )

            # Update the cache
            if len(self.k_states) <= layer_idx:  # Initializing kv and k states
                self.kv_states.append(kv_state.to(dtype))
                self.k_states.append(k_state.to(dtype))

                self.decode_kv_states.append(decode_kv_state.to(dtype))
                self.decode_k_states.append(decode_k_state.to(dtype))

                self.k_cache.append(key_states[:, :, -self.window_size :, :])
                self.v_cache.append(
                    value_states[:, :, -self.window_size :, :].to(dtype)
                )
                # self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
            else:
                # Update kv and k states recurrently
                kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype)
                k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype)
                self.kv_states[layer_idx] = kv_state
                self.k_states[layer_idx] = k_state

                decode_kv_state = (
                    self.decode_kv_states[layer_idx].to(kv_state.dtype)
                    + decode_kv_state
                ).to(dtype)
                decode_k_state = (
                    self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
                ).to(dtype)
                self.decode_kv_states[layer_idx] = decode_kv_state
                self.decode_k_states[layer_idx] = decode_k_state

                self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
                self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
            self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]

        return self.kv_states[layer_idx], self.k_states[layer_idx]

    def update_for_decoding(
        self,
        keys: torch.Tensor,
        values: torch.Tensor,
        layer_idx: int,
        feature_map_k: Callable,
        dtype: torch.dtype,
    ):
        """
        Update the decoding KV and K states, and KV cache, during decoding
        """
        with torch.no_grad():
            k_cache = self.k_cache[layer_idx]
            v_cache = self.v_cache[layer_idx]

            if k_cache.shape[-2] < self.window_size:  # build window-size cache
                self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
                self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
            else:
                # MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
                # if k_cache[:, :, :1, :].sum() == 0:  # heuristic for zeroing out padding in cache
                #     f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
                # else:
                #     f_k_state = feature_map_k(k_cache[:, :, :1, :])
                # -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
                k_state = feature_map_k(k_cache[:, :, :1, :])
                v_state = v_cache[:, :, :1, :]
                kv_state = torch.einsum(
                    "bhlf,bhld->bhfd", k_state.float(), v_state.float()
                ).to(dtype)  # b, h, f, d
                self.decode_kv_states[layer_idx] += kv_state
                self.decode_k_states[layer_idx] += k_state

                self.k_cache[layer_idx] = torch.cat(
                    [k_cache[:, :, 1:, :], keys], dim=-2
                )
                self.v_cache[layer_idx] = torch.cat(
                    [v_cache[:, :, 1:, :], values], dim=-2
                )

            if layer_idx == 0:
                self._seen_tokens += keys.shape[-2]
            self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
            return (
                self.k_cache[layer_idx],
                self.v_cache[layer_idx],
                self.decode_kv_states[layer_idx],
                self.decode_k_states[layer_idx],
            )
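A decode-step sketch of the cache semantics (hypothetical shapes; the `feature_map_k` lambda here is a stand-in): once the window cache is full, each new token evicts the oldest cached key/value pair into the linear-attention state.

import torch

cache = LinearAttentionSlidingWindowCache(window_size=4)
cache.get_seq_length(0)                    # initialize per-layer token count
k = torch.randn(1, 2, 8, 8)
v = torch.randn(1, 2, 8, 8)
f_k = torch.rand(1, 2, 8, 8)               # feature-mapped keys
cache.update(k, v, layer_idx=0, fmap_key_states=f_k)
out = cache.update_for_decoding(
    torch.randn(1, 2, 1, 8), torch.randn(1, 2, 1, 8),
    layer_idx=0, feature_map_k=lambda x: x.relu() + 1e-12, dtype=torch.float32,
)
k_cache, v_cache, decode_kv_state, decode_k_state = out
assert k_cache.shape[-2] == 4              # window stays fixed at window_size
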
@@ -0,0 +1,685 @@
"""
Subquadratic attention combining sliding window and linear attentions
- Using "standard" sliding windows
- Didactically computes outputs with n^2 attention weights for now
- Copied + adapted from linear_window_attention_tk.py for single-file reference

For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""

import logging
from typing import Any, Callable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache

try:
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
except ModuleNotFoundError:
    _flash_attention_forward = None  # Transformers v4.36

from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

# Causal linear attention dot product CUDA kernel from fast-transformers
from .linear_attention import (
    LinearAttentionState,
    LolcatsLinearAttention,
    causal_dot_product,
)

LOG = logging.getLogger(__name__)


# ----------------------
# Sliding window helpers
# ----------------------
def get_masks(
    window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return masks for softmax and linear attention terms
    -> 1 is include, 0 is ignore
    """
    causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
        max(k_len - q_len, 0)
    )
    linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
        max(k_len - q_len, 0) - window_size
    )
    window_mask = causal_mask - linear_mask
    # Return softmax mask (window), linear attention mask
    # -> shapes broadcast over (b, h, q_len, k_len)
    return window_mask[None, None, ...], linear_mask[None, None, ...]

def hybrid_attention_quadratic(
    q: torch.Tensor,
    k: torch.Tensor,
    f_q: torch.Tensor,
    f_k: torch.Tensor,
    v: torch.Tensor,
    window_factor: torch.Tensor,
    linear_factor: torch.Tensor,
    window_size: int,
    kv_state: Optional[torch.Tensor] = None,
    k_state: Optional[torch.Tensor] = None,
    eps: float = 1e-12,
    mask_value: float = -1e8,
):
    """
    Hybrid attention combining sliding window and linear attentions
    """

    mask_window, mask_linear = get_masks(
        window_size, q.shape[-2], k.shape[-2], q.device
    )

    # 1. Sliding window (softmax attention)
    a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
    a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
    # torch.softmax(a_sm, dim=-1), but we account for the max when combining
    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
    a_sm = window_factor * torch.exp(a_sm - a_sm_max)
    sum_sm = a_sm.sum(dim=-1, keepdim=True)

    # 2. Under window (linear attention)
    a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
    a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
    sum_ln = a_ln.sum(dim=-1, keepdim=True)

    # 3. Combine
    a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype)  # Save attention weights
    # Allow outputs to also depend on prior kv_state and k_state
    y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
    if (
        kv_state is not None and k_state is not None
    ):  # Combine with prior kv_state and k_state
        y += linear_factor * torch.einsum(
            "bhld,bhdf->bhlf", f_q.float(), kv_state.float()
        )
        sum_ln += (
            linear_factor
            * torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
        )
    y = (y / (sum_sm + sum_ln)).to(q.dtype)
    return y, a  # attention weights only for the last chunk

# ------------------------------
# Hybrid window attention linear
# ------------------------------
def under_window_linear_attention(
    f_q: torch.Tensor,
    f_k: torch.Tensor,
    v: torch.Tensor,
    window_size: int,
    linear_factor: torch.Tensor,
    eps: float = 1e-12,
):
    """Compute hybrid window attention dot product with linear complexity in q_len"""
    dtype = f_q.dtype
    w = window_size
    f_k = F.pad(f_k, (0, 0, w, 0), value=0)[:, :, :-w, :]
    v = F.pad(v, (0, 0, w, 0), value=0)[:, :, :-w, :]
    qkv = linear_factor * causal_dot_product(
        f_q.contiguous().to(dtype=torch.float32),
        f_k.contiguous().to(dtype=torch.float32),
        v.contiguous().to(dtype=torch.float32),
    ).to(dtype=dtype)
    sum_f_k = f_k.float().cumsum(dim=2).to(dtype=dtype)
    sum_qk = linear_factor * torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None]
    sum_qk[sum_qk == 0] += eps
    return qkv, sum_qk


def sliding_window_softmax_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    window_size: int,
    window_factor: torch.Tensor,
    mask_value: float = -1e8,
):
    """
    Compute sliding window softmax attention without materializing
    O(seq_len^2) attention weights
    """
    d = q.shape[-1]
    # Compute windows for keys
    window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
    k = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
    v = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)

    # Compute windowed_softmax(qk); causal in its construction
    a_sm = torch.einsum("bhld,bhldw->bhlw", q, k) * (d**-0.5)
    a_sm[a_sm == 0] = -torch.finfo(q.dtype).max  # heuristic for zeroing out padding above
    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
    a_sm = window_factor * torch.exp(a_sm - a_sm_max)
    sum_sm = a_sm.sum(dim=-1, keepdim=True)
    return torch.einsum("bhlw,bhldw->bhld", a_sm, v), sum_sm
    # return torch.einsum('bhlw,bhldw->bhld', torch.softmax(qk, dim=-1), v)

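The pad + unfold trick above gives every position a trailing window of `window_size` keys without materializing a seq_len x seq_len matrix; a small illustration:

import torch
import torch.nn.functional as F

w = 3
k = torch.arange(1.0, 6.0).view(1, 1, 5, 1)          # (b, h, l, d), l = 5
windows = F.pad(k, (0, 0, w - 1, 0), value=0).unfold(dimension=2, size=w, step=1)
assert windows.shape == (1, 1, 5, 1, 3)                 # (b, h, l, d, w)
assert windows[0, 0, 4, 0].tolist() == [3.0, 4.0, 5.0]  # window for position 4
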
def hybrid_attention_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    f_q: torch.Tensor,
    f_k: torch.Tensor,
    v: torch.Tensor,
    window_factor: Optional[torch.Tensor] = None,
    linear_factor: Optional[torch.Tensor] = None,
    window_size: int = 64,
    kv_state: Optional[torch.Tensor] = None,
    k_state: Optional[torch.Tensor] = None,
    eps: float = 1e-12,
    mask_value: float = -1e8,
):
    """
    Alternative hybrid attention combining sliding window and linear attentions
    -> Uses O(n) memory if n is sequence length by padding and unfolding windows
    """
    # window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
    if window_factor is None:
        raise ValueError("window_factor must be provided")

    if linear_factor is None:
        raise ValueError("linear_factor must be provided")

    # 1. Sliding window (softmax attention)
    with torch.no_grad():
        qkv_sm, sum_qk_sm = sliding_window_softmax_attention(
            q, k, v, window_size, window_factor, mask_value
        )

    # 2. Under window (linear attention)
    qkv_ln, sum_qk_ln = under_window_linear_attention(
        f_q, f_k, v, window_size, linear_factor, eps
    )

    # 3. Combine
    y = (qkv_sm + qkv_ln) / (sum_qk_sm + sum_qk_ln)
    return y, None

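For CPU sanity checks, a pure-PyTorch stand-in for the `causal_dot_product` CUDA kernel can be useful (assuming the fast-transformers semantics out[l] = sum over i <= l of (f_q[l] . f_k[i]) v[i]; it materializes per-position states, so only use it on small inputs):

import torch

def causal_dot_product_ref(f_q, f_k, v):
    # kv[l] holds the running outer-product state sum_{i<=l} f_k[i]^T v[i]
    kv = torch.einsum("bhlf,bhld->bhlfd", f_k, v).cumsum(dim=2)
    return torch.einsum("bhlf,bhlfd->bhld", f_q, kv)
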
# ---------------------
# Attention layer class
# ---------------------
class LolcatsLinearSlidingWindowAttention(LolcatsLinearAttention):
    """
    Lolcats attention combining sliding window and linear attention
    """

    def __init__(
        self,
        window_size: int = 64,
        decode_window_size: Optional[int] = None,
        affine_attention_factors: bool = False,
        init_window_factor: float = 0,
        train_window_factor: bool = True,
        state_grad_enabled: bool = False,
        **kwargs,
    ):
        self.window_size = window_size
        self.decode_window_size = (
            decode_window_size if decode_window_size is not None else window_size
        )
        self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
        super().__init__(**kwargs)
        # Determine how we compute attentions
        self.linear_attention = hybrid_attention_linear
        self.attention_type = "lolcats_llama_window_sw"
        # Learnable factor for combining attentions
        self.affine_attention_factors = affine_attention_factors
        device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
        if train_window_factor:
            self.window_factors = nn.Parameter(
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
            )
        else:
            self.register_buffer(
                "window_factors",
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
            )
        # Whether we use original flash attention 2 inference (use during attention transfer)
        self.base_inference = False
        self.state_grad_enabled = state_grad_enabled

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass with the option to compute attention weights multiple ways
        if self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models
        """
        b, l, _ = hidden_states.size()

        if self.train_attention and self.base_inference:
            with torch.no_grad():
                _y_true = flash_attention_2(
                    self,  # self.base_attn,
                    hidden_states=hidden_states,
                    attention_mask=None,
                    position_ids=position_ids,
                    past_key_value=None,
                    output_attentions=False,
                    use_cache=False,
                )[0]
                # _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
                y_true = _y_true.reshape(b, l, -1).contiguous()
                y_true = self.o_proj(y_true)
                # layer_io = (hidden_states, _y_true)  # hack
                layer_io = (hidden_states.cpu(), _y_true.cpu())  # hack
            return y_true, layer_io, None

        else:
            q, k, v, kv_seq_len = self.process_qkv(
                hidden_states, attention_mask, position_ids, past_key_value
            )
            f_q, f_k = self.feature_map_q(q), self.feature_map_k(
                k
            )  # Have to do after repeat for grouped-query attn if we use same fmap

            attn_weights = None
            # attention_mask = None  # For now this is always True
            if past_key_value is None:  # Regular training
                window_factors = F.sigmoid(self.window_factors)
                linear_factors = (
                    1 - window_factors if self.affine_attention_factors else 1
                )
                y_true, a_pred = self.linear_attention(
                    q,
                    k,
                    f_q,
                    f_k,
                    v,
                    window_factors,
                    linear_factors,
                    window_size=self.window_size,
                )
                attn_weights = a_pred
            else:
                past_key_value.window_size = self.decode_window_size
                if (
                    f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
                ):  # Generating
                    assert use_cache is True
                    _kv = past_key_value.update_for_decoding(
                        k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
                    )
                    k_cache, v_cache, f_kv_state, f_k_state = _kv

                    # Sliding window + linear attention decode
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )

                    # Softmax attention terms
                    a_sm = torch.einsum(
                        "bhmd,bhnd->bhmn", q.float(), k_cache.float()
                    ) * (k.shape[-1] ** -0.5)
                    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
                    a_sm = window_factors * torch.exp(a_sm - a_sm_max)
                    sum_sm = a_sm.sum(dim=-1, keepdim=True)

                    # Combine with linear attention terms
                    y_true = torch.einsum(
                        "bhmn,bhnd->bhmd", a_sm, v_cache.float()
                    ) + linear_factors * torch.einsum(
                        "bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
                    )
                    sum_ln = (
                        linear_factors
                        * torch.einsum(
                            "bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
                        )[..., None]
                    )
                    y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)

                else:  # Stateful training
                    try:
                        kv_state = past_key_value.kv_states[self.layer_idx]
                        k_state = past_key_value.k_states[self.layer_idx]
                    except IndexError:
                        kv_state, k_state = None, None
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )
                    y_true, _ = self.linear_attention(
                        q,
                        k,
                        f_q,
                        f_k,
                        v,
                        window_factors,
                        linear_factors,
                        window_size=self.window_size,
                        kv_state=kv_state,
                        k_state=k_state,
                    )
                    # Save and update KV cache and states
                    # past_key_value.update(k, v.detach(), self.layer_idx,
                    #                       fmap_key_states=f_k.detach(),
                    #                       accumulate_in_fp32=True)
                    past_key_value.update(
                        k,
                        v,
                        self.layer_idx,
                        fmap_key_states=f_k,
                        accumulate_in_fp32=True,
                    )
            # Concatenate heads and apply output projection
            _y_true = y_true.transpose(1, 2).contiguous()
            y_true = self.o_proj(_y_true.view(b, l, self.hidden_size))

            if self.train_attention:
                attn_weights = _y_true  # flash_attn outputs are shape (b, l, h, d)
            return y_true, attn_weights, past_key_value

class LinearAttentionSlidingWindowCache(LinearAttentionState):
    """
    Class for `past_key_values`
    -> Alternative to KV cache; here we only maintain a "KV state" and "K state"
    -> Modified from transformers.cache_utils.DynamicCache (v4.36)
    """

    def __init__(self, window_size: int = 64) -> None:
        super().__init__()
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        self._seen_tokens_by_layer: List[int] = []
        self.kv_states: List[torch.Tensor] = []
        self.k_states: List[torch.Tensor] = []

        # Account for sliding windows
        self.decode_kv_states: List[torch.Tensor] = []
        self.decode_k_states: List[torch.Tensor] = []
        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []
        self.window_size = window_size

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: Optional[int] = None,
        cache_kwargs: Optional[Any] = None,
        accumulate_in_fp32: bool = False,
        fmap_key_states: Optional[torch.Tensor] = None,  # should not be None
        grad_enabled: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update KV, K states; and KV cache during training
        - For decoding, use `self.decode_kv_states` to keep track of KV states
          up to sliding window terms
        - For (chunked) training, use `self.kv_states` to keep track of KV states
          up to end of sequence
        - Likewise for `self.decode_k_states` and `self.k_states`
        """
        if fmap_key_states is None:
            raise ValueError("fmap_key_states must not be None")

        if layer_idx is None:
            raise ValueError("Layer index must not be None")

        with torch.set_grad_enabled(grad_enabled):
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]

            dtype = key_states.dtype
            if accumulate_in_fp32:
                # key_states = key_states.float()
                fmap_key_states = fmap_key_states.float()
                value_states = value_states.float()

            # Decoding KV state (KV terms up to last window_size)
            decode_kv_state = torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, : -self.window_size],
                value_states[:, :, : -self.window_size],
            )
            # KV state
            kv_state = decode_kv_state + torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, -self.window_size :],
                value_states[:, :, -self.window_size :],
            )
            # shape is b, h, 1, f; note the 1
            decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
                dim=-2, keepdim=True
            )
            k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
                dim=-2, keepdim=True
            )

            # Update the cache
            if len(self.k_states) <= layer_idx:  # Initializing kv and k states
                self.kv_states.append(kv_state.to(dtype))
                self.k_states.append(k_state.to(dtype))

                self.decode_kv_states.append(decode_kv_state.to(dtype))
                self.decode_k_states.append(decode_k_state.to(dtype))

                self.k_cache.append(key_states[:, :, -self.window_size :, :])
                self.v_cache.append(
                    value_states[:, :, -self.window_size :, :].to(dtype)
                )
                # self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
            else:
                # Update kv and k states recurrently
                kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype)
                k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype)
                self.kv_states[layer_idx] = kv_state
                self.k_states[layer_idx] = k_state

                decode_kv_state = (
                    self.decode_kv_states[layer_idx].to(kv_state.dtype)
                    + decode_kv_state
                ).to(dtype)
                decode_k_state = (
                    self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
                ).to(dtype)
                self.decode_kv_states[layer_idx] = decode_kv_state
                self.decode_k_states[layer_idx] = decode_k_state

                self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
                self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
            self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]

        return self.kv_states[layer_idx], self.k_states[layer_idx]

    def update_for_decoding(
        self,
        keys: torch.Tensor,
        values: torch.Tensor,
        layer_idx: int,
        feature_map_k: Callable,
        dtype: torch.dtype,
    ):
        """
        Update the decoding KV and K states, and KV cache, during decoding
        """
        with torch.no_grad():
            k_cache = self.k_cache[layer_idx]
            v_cache = self.v_cache[layer_idx]

            if k_cache.shape[-2] < self.window_size:  # build window-size cache
                self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
                self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
            else:
                # MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
                # if k_cache[:, :, :1, :].sum() == 0:  # heuristic for zeroing out padding in cache
                #     f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
                # else:
                #     f_k_state = feature_map_k(k_cache[:, :, :1, :])
                # -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
                k_state = feature_map_k(k_cache[:, :, :1, :])
                v_state = v_cache[:, :, :1, :]
                kv_state = torch.einsum(
                    "bhlf,bhld->bhfd", k_state.float(), v_state.float()
                ).to(dtype)  # b, h, f, d
                self.decode_kv_states[layer_idx] += kv_state
                self.decode_k_states[layer_idx] += k_state

                self.k_cache[layer_idx] = torch.cat(
                    [k_cache[:, :, 1:, :], keys], dim=-2
                )
                self.v_cache[layer_idx] = torch.cat(
                    [v_cache[:, :, 1:, :], values], dim=-2
                )

            if layer_idx == 0:
                self._seen_tokens += keys.shape[-2]
            self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
            return (
                self.k_cache[layer_idx],
                self.v_cache[layer_idx],
                self.decode_kv_states[layer_idx],
                self.decode_k_states[layer_idx],
            )

# -----------------
# Flash Attention 2
# -----------------


def flash_attention_2(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
):
    """
    Wrapper for LlamaFlashAttention2
    Copied and modified from HF Transformers v4.36 and v4.43 implementations
    - (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
    - (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
    """

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Flash attention requires the input to have the shape
    # batch_size x seq_length x head_dim x hidden_dim
    # therefore we just need to keep the original shape
    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)

    try:  # As in Transformers v4.36
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids
        )
    except Exception:  # As in Transformers v4.39
        cos, sin = self.rotary_emb(key_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )

    # TODO: These transposes are quite inefficient, but Flash Attention requires the layout
    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV
    # cache to be able to avoid many of these transpose/reshape/view.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
    # so the input hidden states get silently cast to float32. Hence, we need to cast them
    # back to the correct dtype just to be sure everything works as expected.
    # This might slow down training & inference, so it is recommended not to cast the
    # LayerNorms in fp32. (LlamaRMSNorm handles it correctly)

    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        LOG.debug(
            "The input hidden states seem to be silently cast to float32. This might be"
            " related to the fact that you have upcasted embedding or layer norm layers"
            f" in float32. We will cast back the input in {target_dtype}."
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    if getattr(self, "_flash_attention_forward", False):
        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            is_causal=True,
        )
    else:
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=0,  # dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=True,
        )
    return attn_output, past_key_value
@@ -0,0 +1,24 @@
"""
LoLCATs attention combining sliding window and linear attentions
- Using standard sliding window arrangement
- Training over long sequences with fixed memory with recurrent view
- During attention transfer, use Flash Attention to compute softmax attention outputs

For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
from .linear_window_attention_sw import hybrid_attention_quadratic
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention


class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention):
    """
    Lolcats attention combining sliding window and linear attention
    """

    def __init__(self, remove_base_attn=True, **kwargs):
        # keep self.base_attn for Flash Attention inference
        super().__init__(remove_base_attn=remove_base_attn, **kwargs)
        self.quadratic_attention = hybrid_attention_quadratic
@@ -0,0 +1,466 @@
|
||||
"""
|
||||
Subquadratic attention combining sliding window and linear attentions
|
||||
- Using the TK "terracing" arrangement
|
||||
|
||||
For each layer:
|
||||
- We first compute (softmax) attention over sliding windows
|
||||
- We then compute standard linear attention to "fill in" the earlier parts
|
||||
- We combine to model the entire sequence
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
from .linear_attention import (
|
||||
LinearAttentionState,
|
||||
LolcatsLinearAttention,
|
||||
softmax_attention,
|
||||
)
|
||||
|
||||
|
||||
# ----------------------
|
||||
# Sliding window helpers
|
||||
# ----------------------
|
||||
def get_masks(
|
||||
window_size: int, q_len: int, k_len: int, device: torch.device
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return masks for softmax and linear attention terms
|
||||
-> 1 is include, 0 is ignore
|
||||
"""
|
||||
win_len = window_size
|
||||
m = math.ceil(max(q_len, k_len) / window_size)
|
||||
# Creates an n x n mask where n = window_size^2
|
||||
mask = torch.block_diag(
|
||||
*[
|
||||
torch.ones(
|
||||
(win_len, win_len),
|
||||
)
|
||||
]
|
||||
* m
|
||||
)
|
||||
mask += torch.roll(mask, -win_len, -1) # this adds the terracing
|
||||
if mask.shape[0] > q_len:
|
||||
mask = mask[-q_len:]
|
||||
if mask.shape[1] > k_len:
|
||||
mask = mask[:, -k_len:]
|
||||
# Return softmax mask (window), linear attention mask
|
||||
mask = mask[None, None, ...] # b, h, q_len, k_len
|
||||
return (
|
||||
torch.tril(mask).to(device=device, dtype=torch.int),
|
||||
torch.tril(1 - mask).to(device=device, dtype=torch.int),
|
||||
)
|
||||
|
||||
|
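To see the terracing concretely, here is the mask construction in isolation for a window of 2 over six positions: each diagonal block also attends to the block before it, and the causal complement goes to the linear term. A standalone snippet mirroring the body of `get_masks`, not part of the patch:

```python
import torch

win, m = 2, 3  # window_size = 2, enough blocks to cover 6 positions
mask = torch.block_diag(*([torch.ones(win, win)] * m))
mask += torch.roll(mask, -win, -1)   # terracing: each block also sees the previous one
print(torch.tril(mask).int())        # softmax (window) mask
print(torch.tril(1 - mask).int())    # linear-attention mask: everything below the window
```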

def hybrid_attention_quadratic(
    q: torch.Tensor,
    k: torch.Tensor,
    f_q: torch.Tensor,
    f_k: torch.Tensor,
    v: torch.Tensor,
    window_factor: torch.Tensor,
    linear_factor: torch.Tensor,
    window_size: int,
    kv_state: Optional[torch.Tensor] = None,
    k_state: Optional[torch.Tensor] = None,
    eps: float = 1e-12,
    mask_value: float = -1e8,
):
    """
    Hybrid attention combining sliding window and linear attentions
    """

    mask_window, mask_linear = get_masks(
        window_size, q.shape[-2], k.shape[-2], q.device
    )

    # 1. Sliding window (softmax attention)
    a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
    a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
    # torch.softmax(a_sm, dim=-1), but we account for the max when combining
    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
    a_sm = window_factor * torch.exp(a_sm - a_sm_max)
    sum_sm = a_sm.sum(dim=-1, keepdim=True)

    # 2. Under window (linear attention)
    a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
    a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
    sum_ln = a_ln.sum(dim=-1, keepdim=True)

    # 3. Combine
    a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype)  # Save attention weights
    # Allow outputs to also depend on prior kv_state and k_state
    y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
    if (
        kv_state is not None and k_state is not None
    ):  # Combine with prior kv_state and k_state
        y += linear_factor * torch.einsum(
            "bhld,bhdf->bhlf", f_q.float(), kv_state.float()
        )
        sum_ln += (
            linear_factor
            * torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
        )
    y = (y / (sum_sm + sum_ln)).to(q.dtype)
    return y, a  # attention weights only for the last chunk

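A quick smoke test of the hybrid kernel on toy tensors; each row of the returned weights sums to one because the softmax and linear terms share a normalizer. The ELU+1 feature map here is only a stand-in for the learned feature maps used elsewhere in the patch:

```python
import torch
import torch.nn.functional as F

b, h, n, d = 1, 2, 8, 4
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
f_q, f_k = F.elu(q) + 1, F.elu(k) + 1  # stand-in positive feature map

y, a = hybrid_attention_quadratic(
    q, k, f_q, f_k, v,
    window_factor=torch.ones(1), linear_factor=torch.ones(1),
    window_size=4,
)
print(y.shape)              # torch.Size([1, 2, 8, 4])
print(a[0, 0].sum(dim=-1))  # ~1.0 per query row
```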

# ---------------------
# Attention layer class
# ---------------------
class LolcatsTKWindowAttention(LolcatsLinearAttention):
    """
    Lolcats attention combining sliding window and linear attention
    """

    def __init__(
        self,
        window_size: int = 64,
        decode_window_size: Optional[int] = None,
        affine_attention_factors: bool = False,
        init_window_factor: float = 0,
        train_window_factor: bool = True,
        state_grad_enabled: bool = False,
        **kwargs,
    ):
        self.window_size = window_size
        self.decode_window_size = (
            decode_window_size if decode_window_size is not None else window_size
        )
        self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
        super().__init__(**kwargs)
        self.attention_type = kwargs["attention_type"]  # e.g. 'hedgehog_llama_window_tk'
        # Determine how we compute attentions
        self.quadratic_attention = hybrid_attention_quadratic
        # Learnable factor for combining attentions
        self.affine_attention_factors = affine_attention_factors
        device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
        if train_window_factor:
            self.window_factors = nn.Parameter(
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
            )
        else:
            self.register_buffer(
                "window_factors",
                init_window_factor
                * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
            )
        # Whether we use original flash attention 2 inference (use during attention transfer)
        self.base_inference = False
        self.state_grad_enabled = state_grad_enabled
        self.window_factor = self.window_factors  # legacy naming support

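The per-head `window_factors` are passed through a sigmoid in `forward` below, so `init_window_factor=0` starts every head at an even 0.5/0.5 split between the softmax-window and linear terms; with `affine_attention_factors=True` the two gates are tied to sum to one. A small standalone sketch of that gating with illustrative values:

```python
import torch
import torch.nn.functional as F

num_heads = 4
window_factors = torch.zeros(1, num_heads, 1, 1)  # init_window_factor = 0
w = F.sigmoid(window_factors)                     # 0.5 per head at init
lin = 1 - w                                       # affine_attention_factors=True
print(w.flatten().tolist(), lin.flatten().tolist())
```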
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        """
        Forward pass with the option to compute attention weights multiple ways
        if self.train_attention is True
        -> Consistent with HuggingFace Transformers for easy use with their pretrained models
        """
        b, l, _ = hidden_states.size()
        q, k, v, kv_seq_len = self.process_qkv(
            hidden_states, attention_mask, position_ids, past_key_value
        )
        # Have to do after repeat for grouped-query attn if we use same fmap
        f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)

        if self.train_attention:
            # 1. Compute "ground-truth" attention output and weights
            with torch.no_grad():
                _y_true, a_true = softmax_attention(q, k, v)[:2]
                y_true = (
                    _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
                )
                y_true = self.o_proj(y_true)

            # 2. Compute "predicted" attention outputs
            # compute attn weights under sliding window
            window_factors = F.sigmoid(self.window_factors)
            linear_factors = 1 - window_factors if self.affine_attention_factors else 1
            y_pred, a_pred = self.quadratic_attention(
                q,
                k,
                f_q,
                f_k,
                v,
                window_factors,
                linear_factors,
                window_size=self.window_size,
            )
            attn_weights = ((a_pred, a_true), (y_pred, _y_true))
        else:
            attn_weights = None
            # attention_mask = None  # For now this is always True
            if past_key_value is None:  # Regular training
                window_factors = F.sigmoid(self.window_factors)
                linear_factors = (
                    1 - window_factors if self.affine_attention_factors else 1
                )
                y_true, a_pred = self.quadratic_attention(
                    q,
                    k,
                    f_q,
                    f_k,
                    v,
                    window_factors,
                    linear_factors,
                    window_size=self.window_size,
                )
                attn_weights = a_pred
            else:
                past_key_value.window_size = self.decode_window_size
                if (
                    f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
                ):  # Generating
                    assert use_cache is True
                    _kv = past_key_value.update_for_decoding(
                        k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
                    )
                    k_cache, v_cache, f_kv_state, f_k_state = _kv

                    # Sliding window + linear attention decode
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )

                    # Softmax attention terms
                    a_sm = torch.einsum(
                        "bhmd,bhnd->bhmn", q.float(), k_cache.float()
                    ) * (k.shape[-1] ** -0.5)
                    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
                    a_sm = window_factors * torch.exp(a_sm - a_sm_max)
                    sum_sm = a_sm.sum(dim=-1, keepdim=True)

                    # Combine with linear attention terms
                    y_true = torch.einsum(
                        "bhmn,bhnd->bhmd", a_sm, v_cache.float()
                    ) + linear_factors * torch.einsum(
                        "bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
                    )
                    sum_ln = (
                        linear_factors
                        * torch.einsum(
                            "bhld,bhnd->bhl", f_q.float(), f_k_state.float()
                        )[..., None]
                    )
                    y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)

                else:  # Stateful training
                    try:
                        kv_state = past_key_value.kv_states[self.layer_idx]
                        k_state = past_key_value.k_states[self.layer_idx]
                    except IndexError:
                        kv_state, k_state = None, None
                    window_factors = F.sigmoid(self.window_factors)
                    linear_factors = (
                        1 - window_factors if self.affine_attention_factors else 1
                    )
                    y_true, _ = self.quadratic_attention(
                        q,
                        k,
                        f_q,
                        f_k,
                        v,
                        window_factors,
                        linear_factors,
                        window_size=self.window_size,
                        kv_state=kv_state,
                        k_state=k_state,
                    )
                    # Save and update KV cache and states
                    # past_key_value.update(k, v.detach(), self.layer_idx,
                    #                       fmap_key_states=f_k.detach(),
                    #                       accumulate_in_fp32=True)
                    past_key_value.update(
                        k,
                        v,
                        self.layer_idx,
                        fmap_key_states=f_k,
                        accumulate_in_fp32=True,
                    )
        # Concatenate heads and apply output projection
        y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
        y_true = self.o_proj(y_true)
        return y_true, attn_weights, past_key_value

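When `self.train_attention` is on, the layer returns both the predicted and ground-truth attentions, which is what attention transfer trains against. One plausible way a caller might turn that tuple into a distillation loss; this helper is illustrative, not part of the patch:

```python
import torch
import torch.nn.functional as F

def attention_transfer_loss(attn_weights, eps: float = 1e-8) -> torch.Tensor:
    """Match predicted hybrid attention to the softmax ground truth."""
    (a_pred, a_true), (y_pred, y_true) = attn_weights
    # Cross-entropy between true and predicted attention distributions...
    loss_a = -(a_true * torch.log(a_pred + eps)).sum(dim=-1).mean()
    # ...optionally plus an MSE on the pre-projection outputs
    loss_y = F.mse_loss(y_pred, y_true)
    return loss_a + loss_y
```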

class LinearAttentionTKWindowCache(LinearAttentionState):
    """
    Class for `past_key_values`
    -> Alternative to KV cache; here we only maintain a "KV state" and "K state"
    -> Modified from transformers.cache_utils.DynamicCache (v4.36)
    """

    def __init__(self, window_size: int = 64) -> None:
        super().__init__()
        self._seen_tokens = 0  # should be `self.seen_tokens` in Transformers v4.36
        self._seen_tokens_by_layer: List[int] = []
        self.kv_states: List[torch.Tensor] = []
        self.k_states: List[torch.Tensor] = []

        # Account for sliding windows
        self.decode_kv_states: List[torch.Tensor] = []
        self.decode_k_states: List[torch.Tensor] = []
        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []
        self.window_size = window_size

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: Optional[int] = None,
        cache_kwargs: Optional[Any] = None,
        accumulate_in_fp32: bool = False,
        fmap_key_states: Optional[torch.Tensor] = None,  # should not be None
        grad_enabled: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update KV, K states; and KV cache during training
        - For decoding, use `self.decode_kv_states` to keep track of KV states
          up to sliding window terms
        - For (chunked) training, use `self.kv_states` to keep track of KV states
          up to end of sequence
        - Likewise for `self.decode_k_states` and `self.k_states`
        """
        if fmap_key_states is None:
            raise ValueError("fmap_key_states should not be None")

        if layer_idx is None:
            raise ValueError("layer_idx should not be None")

        with torch.set_grad_enabled(grad_enabled):
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]

            dtype = key_states.dtype
            if accumulate_in_fp32:
                # key_states = key_states.float()
                fmap_key_states = fmap_key_states.float()
                value_states = value_states.float()

            # Decoding KV state (KV terms up to last window_size)
            decode_kv_state = torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, : -self.window_size],
                value_states[:, :, : -self.window_size],
            )
            # KV state
            kv_state = decode_kv_state + torch.einsum(
                "bhlf,bhld->bhfd",
                fmap_key_states[:, :, -self.window_size :],
                value_states[:, :, -self.window_size :],
            )
            # shape is b, h, 1, f; note the 1
            decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
                dim=-2, keepdim=True
            )
            k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
                dim=-2, keepdim=True
            )

            # Update the cache
            if len(self.k_states) <= layer_idx:  # Initializing kv and k states
                self.kv_states.append(kv_state.to(dtype))
                self.k_states.append(k_state.to(dtype))

                self.decode_kv_states.append(decode_kv_state.to(dtype))
                self.decode_k_states.append(decode_k_state.to(dtype))

                self.k_cache.append(key_states[:, :, -self.window_size :, :])
                self.v_cache.append(
                    value_states[:, :, -self.window_size :, :].to(dtype)
                )
                # Track tokens seen by this layer
                self._seen_tokens_by_layer.append(key_states.shape[-2])
            else:
                # Update kv and k states recurrently
                kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
                    dtype
                )
                k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
                    dtype
                )
                self.kv_states[layer_idx] = kv_state
                self.k_states[layer_idx] = k_state

                decode_kv_state = (
                    self.decode_kv_states[layer_idx].to(kv_state.dtype)
                    + decode_kv_state
                ).to(dtype)
                decode_k_state = (
                    self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
                ).to(dtype)
                self.decode_kv_states[layer_idx] = decode_kv_state
                self.decode_k_states[layer_idx] = decode_k_state

                self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
                self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
                self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]

        return self.kv_states[layer_idx], self.k_states[layer_idx]

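The split computed in `update` satisfies a simple identity: the full-chunk `kv_state` equals `decode_kv_state` (tokens before the window) plus the window terms. A standalone check on random tensors:

```python
import torch

b, h, n, f, d, w = 1, 1, 10, 4, 4, 4
f_k, v = torch.randn(b, h, n, f), torch.randn(b, h, n, d)
decode_kv = torch.einsum("bhlf,bhld->bhfd", f_k[:, :, :-w], v[:, :, :-w])
window_kv = torch.einsum("bhlf,bhld->bhfd", f_k[:, :, -w:], v[:, :, -w:])
full_kv = torch.einsum("bhlf,bhld->bhfd", f_k, v)
assert torch.allclose(decode_kv + window_kv, full_kv, atol=1e-5)
```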
    def update_for_decoding(
        self,
        keys: torch.Tensor,
        values: torch.Tensor,
        layer_idx: int,
        feature_map_k: Callable,
        dtype: torch.dtype,
    ):
        """
        Update the decoding KV and K states, and KV cache, during decoding
        """
        with torch.no_grad():
            k_cache = self.k_cache[layer_idx]
            v_cache = self.v_cache[layer_idx]

            if k_cache.shape[-2] < self.window_size:  # build window-size cache
                self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
                self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
            else:
                # Evict the oldest cached token into the linear-attention states
                k_state = feature_map_k(k_cache[:, :, :1, :])
                v_state = v_cache[:, :, :1, :]
                kv_state = torch.einsum(
                    "bhlf,bhld->bhfd", k_state.float(), v_state.float()
                ).to(dtype)  # b, h, f, d
                self.decode_kv_states[layer_idx] += kv_state
                self.decode_k_states[layer_idx] += k_state

                self.k_cache[layer_idx] = torch.cat(
                    [k_cache[:, :, 1:, :], keys], dim=-2
                )
                self.v_cache[layer_idx] = torch.cat(
                    [v_cache[:, :, 1:, :], values], dim=-2
                )

            if layer_idx == 0:
                self._seen_tokens += keys.shape[-2]
            self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
            return (
                self.k_cache[layer_idx],
                self.v_cache[layer_idx],
                self.decode_kv_states[layer_idx],
                self.decode_k_states[layer_idx],
            )
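A hypothetical single-layer decode loop showing the eviction behavior end to end: the window cache stays a fixed size while the oldest token folds into the linear states each step. The ELU+1 feature map, shapes, and layer index `0` are stand-ins; this assumes the class above is in scope:

```python
import torch
import torch.nn.functional as F

cache = LinearAttentionTKWindowCache(window_size=4)
fmap = lambda x: F.elu(x) + 1  # stand-in for the learned key feature map

# Prefill one chunk of 4 tokens, then decode 3 tokens one at a time
k0, v0 = torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8)
cache.update(k0, v0, layer_idx=0, fmap_key_states=fmap(k0))
for _ in range(3):
    k1, v1 = torch.randn(1, 2, 1, 8), torch.randn(1, 2, 1, 8)
    k_c, v_c, kv_state, k_state = cache.update_for_decoding(k1, v1, 0, fmap, k0.dtype)
print(k_c.shape, kv_state.shape)  # window stays (1, 2, 4, 8); state stays (1, 2, 8, 8)
```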