Compare commits
237 Commits
multi-gpu-
...
mixtral_op
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
450e04d3c4 | ||
|
|
b0cf397ecb | ||
|
|
5f79b8242f | ||
|
|
f1de29dd1e | ||
|
|
7fabc4d95e | ||
|
|
9a5eb3990c | ||
|
|
86487c2e96 | ||
|
|
35f9b0f149 | ||
|
|
68b227a7d8 | ||
|
|
03c6318ba3 | ||
|
|
40a6362c92 | ||
|
|
d339beb9d9 | ||
|
|
fde091cb12 | ||
|
|
06ae39200b | ||
|
|
a581e9f8f6 | ||
|
|
992e742cdc | ||
|
|
a1da39cd48 | ||
|
|
58ec8b1113 | ||
|
|
476a205cea | ||
|
|
3e3229e2d9 | ||
|
|
1d21aa6b0a | ||
|
|
71b7ea3c05 | ||
|
|
a48dbf6561 | ||
|
|
6a4562ac08 | ||
|
|
1115c501b8 | ||
|
|
7ee3c4cacb | ||
|
|
fb12895a17 | ||
|
|
9fc29e082b | ||
|
|
575a082aae | ||
|
|
ddf815022a | ||
|
|
9bf854e59c | ||
|
|
797f3dd1de | ||
|
|
0de1457189 | ||
|
|
3cc67d2cdd | ||
|
|
1bc11868eb | ||
|
|
b3a61e8ce2 | ||
|
|
8a8d1c4023 | ||
|
|
332984db18 | ||
|
|
48630f5b34 | ||
|
|
b33c1d55a2 | ||
|
|
0c2a630326 | ||
|
|
db8a8afcba | ||
|
|
14706504e3 | ||
|
|
501b4d1379 | ||
|
|
306fe19c54 | ||
|
|
614cff4107 | ||
|
|
1a6309c8a6 | ||
|
|
105d0b350b | ||
|
|
f544ab2bed | ||
|
|
641e6f7e51 | ||
|
|
6dc68a653f | ||
|
|
7de6a5639c | ||
|
|
c74f045ba7 | ||
|
|
0402d19759 | ||
|
|
b2430ce670 | ||
|
|
4c834bf25d | ||
|
|
8056ecd30e | ||
|
|
738a057674 | ||
|
|
cdc71f73c8 | ||
|
|
6459ac7357 | ||
|
|
964d858da0 | ||
|
|
10388a8daf | ||
|
|
9f7e8a971d | ||
|
|
637ed095a0 | ||
|
|
827ec3d274 | ||
|
|
8b79ff0e94 | ||
|
|
0800885e2f | ||
|
|
d3193beac3 | ||
|
|
2e71ff03a6 | ||
|
|
facc49f32b | ||
|
|
e50ab072e2 | ||
|
|
05bd6f1122 | ||
|
|
20aa4b57d2 | ||
|
|
11d1d607db | ||
|
|
6c81c61bc4 | ||
|
|
9b43e7ea15 | ||
|
|
2d8def68dc | ||
|
|
44c9d0151a | ||
|
|
ca84cca2c0 | ||
|
|
32eeeb5b64 | ||
|
|
afedc470bd | ||
|
|
9923b72649 | ||
|
|
21cf09b608 | ||
|
|
15d3a654bf | ||
|
|
a21935f07a | ||
|
|
8966a6f566 | ||
|
|
e4d1585c4e | ||
|
|
70157ccb8f | ||
|
|
3a99495b05 | ||
|
|
440c3ab527 | ||
|
|
992d57f20a | ||
|
|
91a016f410 | ||
|
|
a045db0214 | ||
|
|
e1b214c62b | ||
|
|
3553172e3c | ||
|
|
7f2027d93f | ||
|
|
8d288a2ad4 | ||
|
|
f30afe4544 | ||
|
|
bfbdba8614 | ||
|
|
3bd9528390 | ||
|
|
2aa1f71464 | ||
|
|
1c412c7e9d | ||
|
|
490923fb78 | ||
|
|
5855dded3d | ||
|
|
ace70b33c6 | ||
|
|
11c48c5e03 | ||
|
|
295b2662e1 | ||
|
|
77c84e02fd | ||
|
|
f91db198f3 | ||
|
|
7f2618b5f4 | ||
|
|
aca0398315 | ||
|
|
29b8f46aed | ||
|
|
83a950bb87 | ||
|
|
de87ea68f6 | ||
|
|
4c8ddf2c6f | ||
|
|
669f1d052c | ||
|
|
d4a88e4eca | ||
|
|
2d60ba3a6e | ||
|
|
eb480dfd68 | ||
|
|
133e676bcc | ||
|
|
69fac9a020 | ||
|
|
e0b7eeabfd | ||
|
|
43856c0a39 | ||
|
|
e62d5901b5 | ||
|
|
697c50d408 | ||
|
|
90e0d673f7 | ||
|
|
2642caedf2 | ||
|
|
f34648c8b9 | ||
|
|
e50a64e85e | ||
|
|
f4868d733c | ||
|
|
a7e56d83c2 | ||
|
|
5b0bc48fbc | ||
|
|
9ec20777ba | ||
|
|
590d6032fd | ||
|
|
409ca0f21c | ||
|
|
8662e8ffe8 | ||
|
|
b2edaaeff6 | ||
|
|
b88f51512a | ||
|
|
eb41f76f92 | ||
|
|
383f88d7a7 | ||
|
|
b6ab8aad62 | ||
|
|
85b0be2ba7 | ||
|
|
8fe0e633d2 | ||
|
|
d1236f2c41 | ||
|
|
895f0a0723 | ||
|
|
e7d3e2dbb6 | ||
|
|
60c7c48c97 | ||
|
|
e8cbf50be6 | ||
|
|
d887ad86c3 | ||
|
|
19a600a8b8 | ||
|
|
5e5296a77c | ||
|
|
f3d939016a | ||
|
|
cfbce020e9 | ||
|
|
4fecbfe5e1 | ||
|
|
67b9888630 | ||
|
|
923eb91304 | ||
|
|
a363604dcf | ||
|
|
501958bb6f | ||
|
|
c25ba7939b | ||
|
|
d5f8589021 | ||
|
|
03e59077a0 | ||
|
|
97d3776ce6 | ||
|
|
2844eb22b6 | ||
|
|
e85d2eb06b | ||
|
|
196ff1181e | ||
|
|
92512c390b | ||
|
|
2fe95cdcc1 | ||
|
|
c1382e79b6 | ||
|
|
5d931cc042 | ||
|
|
ec0958f4f8 | ||
|
|
faecff9798 | ||
|
|
aa656e04bd | ||
|
|
b53e77775b | ||
|
|
674c57692d | ||
|
|
1eebbd09c3 | ||
|
|
62a774140b | ||
|
|
31b9e0c6e8 | ||
|
|
6b9b229356 | ||
|
|
131afdbd89 | ||
|
|
00dce35fb2 | ||
|
|
b15b19eb8d | ||
|
|
ab534d75ba | ||
|
|
21ec195c9f | ||
|
|
62eaee7649 | ||
|
|
be75668400 | ||
|
|
aeec7c4688 | ||
|
|
360788296a | ||
|
|
12a2dbbc2c | ||
|
|
3a2edc85c3 | ||
|
|
f7a22632d7 | ||
|
|
1aa400721e | ||
|
|
8dcd40ac78 | ||
|
|
a5a625f47e | ||
|
|
861cecac2a | ||
|
|
1078d3eae7 | ||
|
|
24146733db | ||
|
|
9218ebecd2 | ||
|
|
228420972e | ||
|
|
c6d870b91d | ||
|
|
115795079d | ||
|
|
3b18c963cc | ||
|
|
3fbde762ab | ||
|
|
f6060a664e | ||
|
|
a4e1bb6606 | ||
|
|
36e53c7442 | ||
|
|
e7aa7b1a1e | ||
|
|
e5bb22a56b | ||
|
|
fdb777bc06 | ||
|
|
bf0804447c | ||
|
|
5b67ea98a6 | ||
|
|
2f586d18db | ||
|
|
9845c5e12d | ||
|
|
772cd870d4 | ||
|
|
6c5fbe6223 | ||
|
|
bcbc9597e9 | ||
|
|
6d57f2f0f0 | ||
|
|
20ed4c1f9e | ||
|
|
c5dedb17ad | ||
|
|
b56503d423 | ||
|
|
a94f9cb99e | ||
|
|
c1921c9acb | ||
|
|
0b4cf5bc8c | ||
|
|
78ee2cdab2 | ||
|
|
34c0a86a11 | ||
|
|
5e2d8a42d9 | ||
|
|
e30f1e3cf7 | ||
|
|
343714972b | ||
|
|
245c5c41e2 | ||
|
|
a546ca2813 | ||
|
|
3355706e22 | ||
|
|
daa4faca12 | ||
|
|
fc8766e502 | ||
|
|
72a6fe1c1f | ||
|
|
5fe30b1497 | ||
|
|
44454ae4c4 | ||
|
|
09f154397e | ||
|
|
995557bdf3 |
7
.github/ISSUE_TEMPLATE/bug-report.yaml
vendored
7
.github/ISSUE_TEMPLATE/bug-report.yaml
vendored
@@ -53,6 +53,13 @@ body:
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: config
|
||||
attributes:
|
||||
label: Config yaml
|
||||
description: |
|
||||
Please attach the config yaml!
|
||||
|
||||
- type: textarea
|
||||
id: possible-solution
|
||||
attributes:
|
||||
|
||||
5
.github/workflows/base.yml
vendored
5
.github/workflows/base.yml
vendored
@@ -25,6 +25,11 @@ jobs:
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
22
.github/workflows/main.yml
vendored
22
.github/workflows/main.yml
vendored
@@ -23,12 +23,13 @@ jobs:
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras: gptq
|
||||
runs-on: self-hosted
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
@@ -51,9 +52,12 @@ jobs:
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
tags: |
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
build-axolotl-runpod:
|
||||
needs: build-axolotl
|
||||
@@ -75,10 +79,10 @@ jobs:
|
||||
is_latest: true
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras: gptq
|
||||
runs-on: self-hosted
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
16
.github/workflows/pre-commit.yml
vendored
16
.github/workflows/pre-commit.yml
vendored
@@ -1,16 +0,0 @@
|
||||
name: pre-commit
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.9"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
45
.github/workflows/pypi.yml
vendored
Normal file
45
.github/workflows/pypi.yml
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
name: publish pypi
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
jobs:
|
||||
pypi-publish:
|
||||
name: Upload release to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/axolotl
|
||||
permissions:
|
||||
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install wheel
|
||||
pip3 install -e .
|
||||
pip3 install -r requirements-tests.txt
|
||||
|
||||
- name: Extract tag name
|
||||
id: tag
|
||||
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
|
||||
|
||||
- name: Update version in setup.py
|
||||
run: >-
|
||||
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
|
||||
|
||||
- name: Build a binary wheel
|
||||
run: >-
|
||||
python setup.py sdist bdist_wheel
|
||||
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
59
.github/workflows/tests.yml
vendored
59
.github/workflows/tests.yml
vendored
@@ -1,10 +1,32 @@
|
||||
name: PyTest
|
||||
name: Tests
|
||||
on:
|
||||
# check on push/merge to main, PRs, and manual triggers
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
pull_request:
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
pre-commit:
|
||||
name: pre-commit
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.9"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
|
||||
pytest:
|
||||
name: PyTest
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -24,9 +46,36 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -e .[peft]
|
||||
pip install -r requirements-tests.txt
|
||||
pip3 install -U -e .
|
||||
pip3 install -r requirements-tests.txt
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest tests/
|
||||
pytest --ignore=tests/e2e/ tests/
|
||||
|
||||
e2e-test:
|
||||
name: E2E Tests
|
||||
runs-on: [self-hosted, gpu]
|
||||
timeout-minutes: 20
|
||||
needs: [pre-commit, pytest]
|
||||
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
# cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
|
||||
pip3 uninstall -y transformers accelerate
|
||||
pip3 install -U -e .[flash-attn,mamba-ssm]
|
||||
pip3 install -r requirements-tests.txt
|
||||
|
||||
- name: Run e2e tests
|
||||
run: |
|
||||
pytest tests/e2e/
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -161,3 +161,7 @@ cython_debug/
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
|
||||
# WandB
|
||||
# wandb creates a folder to store logs for training runs
|
||||
wandb
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
[settings]
|
||||
profile=black
|
||||
known_third_party=wandb
|
||||
|
||||
@@ -8,6 +8,12 @@ ignore_missing_imports = True
|
||||
[mypy-axolotl.monkeypatch.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-axolotl.models.mixtral.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-axolotl.models.phi.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-flash_attn.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
@@ -20,6 +26,9 @@ ignore_missing_imports = True
|
||||
[mypy-peft]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-wandb]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-bitsandbytes]
|
||||
ignore_missing_imports = True
|
||||
|
||||
|
||||
598
README.md
598
README.md
@@ -2,6 +2,18 @@
|
||||
|
||||
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
|
||||
|
||||
Features:
|
||||
- Train various Huggingface models such as llama, pythia, falcon, mpt
|
||||
- Supports fullfinetune, lora, qlora, relora, and gptq
|
||||
- Customize configurations using a simple yaml file or CLI overwrite
|
||||
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
|
||||
- Integrated with xformer, flash attention, rope scaling, and multipacking
|
||||
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
|
||||
- Easily run with Docker locally or on the cloud
|
||||
- Log results and optionally checkpoints to wandb
|
||||
- And more!
|
||||
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td>
|
||||
@@ -11,9 +23,12 @@ Axolotl is a tool designed to streamline the fine-tuning of various AI models, o
|
||||
- [Supported Features](#axolotl-supports)
|
||||
- [Quickstart](#quickstart-)
|
||||
- [Installation](#installation)
|
||||
- [Docker Installation](#environment)
|
||||
- [Conda/Pip venv Installation](#condapip-venv)
|
||||
- [LambdaLabs Installation](#lambdalabs)
|
||||
- [Docker](#docker)
|
||||
- [Conda/Pip venv](#condapip-venv)
|
||||
- [Runpod](#runpod)
|
||||
- [LambdaLabs](#lambdalabs)
|
||||
- [Windows](#windows)
|
||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||
- [Dataset](#dataset)
|
||||
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
||||
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
||||
@@ -37,7 +52,7 @@ Axolotl is a tool designed to streamline the fine-tuning of various AI models, o
|
||||
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
|
||||
</p>
|
||||
<p>
|
||||
Go ahead and axolotl questions!!
|
||||
Go ahead and Axolotl questions!!
|
||||
</p>
|
||||
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
|
||||
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
|
||||
@@ -50,15 +65,21 @@ Axolotl is a tool designed to streamline the fine-tuning of various AI models, o
|
||||
|
||||
## Axolotl supports
|
||||
|
||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||
|----------|:----------|:-----|-------|------|-------------------|------------|---------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
|
||||
|
||||
## Quickstart ⚡
|
||||
@@ -67,31 +88,39 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
||||
|
||||
**Requirements**: Python >=3.9 and Pytorch >=2.0.
|
||||
|
||||
`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
|
||||
|
||||
### For developers
|
||||
```bash
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
cd axolotl
|
||||
|
||||
pip3 install -e .[flash-attn]
|
||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
### Usage
|
||||
```bash
|
||||
# finetune lora
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
|
||||
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||
|
||||
# inference
|
||||
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
|
||||
--inference --lora_model_dir="./lora-out"
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--lora_model_dir="./lora-out"
|
||||
|
||||
# gradio
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--lora_model_dir="./lora-out" --gradio
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
### Environment
|
||||
|
||||
- Docker
|
||||
#### Docker
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
- `winglian/axolotl-runpod:main-py3.10-cu118-2.0.1`: for runpod
|
||||
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.1-gptq`: for gptq
|
||||
|
||||
Or run on the current files for development:
|
||||
|
||||
@@ -99,27 +128,48 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
- Conda/Pip venv
|
||||
<details>
|
||||
|
||||
<summary>Docker advanced</summary>
|
||||
|
||||
A more powerful Docker command to run would be this:
|
||||
|
||||
```bash
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
|
||||
It additionally:
|
||||
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
|
||||
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
|
||||
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
|
||||
* The `--privileged` flag gives all capabilities to the container.
|
||||
* The `--shm-size 10g` argument increases the shared memory size. Use this if you see `exitcode: -7` errors using deepspeed.
|
||||
|
||||
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
|
||||
|
||||
</details>
|
||||
|
||||
#### Conda/Pip venv
|
||||
1. Install python >=**3.9**
|
||||
|
||||
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
||||
|
||||
3. Install python dependencies with ONE of the following:
|
||||
- Recommended, supports QLoRA, NO gptq/int4 support
|
||||
3. Install Axolotl along with python dependencies
|
||||
```bash
|
||||
pip3 install -e .
|
||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
- gptq/int4 support, NO QLoRA
|
||||
4. (Optional) Login to Huggingface to use gated models/datasets.
|
||||
```bash
|
||||
pip3 install -e .[gptq]
|
||||
```
|
||||
- same as above but not recommended
|
||||
```bash
|
||||
pip3 install -e .[gptq_triton]
|
||||
huggingface-cli login
|
||||
```
|
||||
Get the token at huggingface.co/settings/tokens
|
||||
|
||||
- LambdaLabs
|
||||
#### Runpod
|
||||
|
||||
Use `winglian/axolotl-runpod:main-latest` or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
|
||||
#### LambdaLabs
|
||||
<details>
|
||||
|
||||
<summary>Click to Expand</summary>
|
||||
@@ -151,10 +201,10 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
cd axolotl
|
||||
|
||||
pip3 install -e . # change depend on needs
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
pip3 install protobuf==3.20.3
|
||||
pip3 install -U --ignore-installed requests Pillow psutil scipy
|
||||
pip3 install git+https://github.com/huggingface/peft.git # not for gptq
|
||||
```
|
||||
|
||||
5. Set path
|
||||
@@ -163,7 +213,30 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
|
||||
```
|
||||
</details>
|
||||
|
||||
- Windows: Please use WSL or Docker!
|
||||
#### Windows
|
||||
Please use WSL or Docker!
|
||||
|
||||
|
||||
#### Launching on public clouds via SkyPilot
|
||||
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
|
||||
```bash
|
||||
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
|
||||
sky check
|
||||
```
|
||||
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
|
||||
```
|
||||
git clone https://github.com/skypilot-org/skypilot.git
|
||||
cd skypilot/llm/axolotl
|
||||
```
|
||||
Use one command to launch:
|
||||
```bash
|
||||
# On-demand
|
||||
HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
|
||||
|
||||
# Managed spot (auto-recovery on preemption)
|
||||
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
||||
```
|
||||
|
||||
|
||||
### Dataset
|
||||
|
||||
@@ -174,7 +247,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "output": "..."}
|
||||
```
|
||||
- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
|
||||
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
|
||||
```json
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
@@ -239,6 +312,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"article": "...", "question": "...", "answer": "..."}
|
||||
```
|
||||
- `context_qa.load_v2`: in context question answering (alternate)
|
||||
```json
|
||||
{"context": "...", "question": "...", "answer": "..."}
|
||||
```
|
||||
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
|
||||
```json
|
||||
{"article": "...", "unanswerable_question": "..."}
|
||||
@@ -263,11 +340,11 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"prompt": "...", "generation": "..."}
|
||||
```
|
||||
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
|
||||
- `sharegpt.load_role`: conversations where `role` is used instead of `from`
|
||||
```json
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
|
||||
- `sharegpt.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
|
||||
```json
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
@@ -280,29 +357,28 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
|
||||
#### How to add custom prompts
|
||||
|
||||
Using yaml. Example:
|
||||
For a dataset that is preprocessed for instruction purposes:
|
||||
|
||||
```json
|
||||
{"instruction": "...", "output": "..."}
|
||||
```
|
||||
|
||||
You can use this example in your YAML config:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: repo
|
||||
type:
|
||||
system_prompt: ""
|
||||
no_input_format: |-
|
||||
User: {instruction}<|end_of_turn|>
|
||||
Assistant:
|
||||
format: |-
|
||||
User: {instruction}
|
||||
{input}<|end_of_turn|>
|
||||
Assistant:
|
||||
field_system: system
|
||||
format: "[INST] {instruction} [/INST]"
|
||||
no_input_format: "[INST] {instruction} [/INST]"
|
||||
```
|
||||
|
||||
Using file:
|
||||
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
|
||||
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
|
||||
|
||||
#### How to use your custom pretokenized dataset
|
||||
|
||||
- Do not pass a `type:`
|
||||
- Dataset must contain `input_ids`, `attention_mask`, `labels` in columns
|
||||
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
|
||||
|
||||
|
||||
### Config
|
||||
@@ -329,6 +405,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- path: EleutherAI/pile
|
||||
name: enron_emails
|
||||
type: completion # format from earlier
|
||||
field: text # Optional[str] default: text, field to use for completion data
|
||||
|
||||
# huggingface repo with multiple named configurations/subsets
|
||||
datasets:
|
||||
@@ -339,11 +416,30 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- typescript
|
||||
type: ... # unimplemented custom format
|
||||
|
||||
# fastchat conversation
|
||||
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
datasets:
|
||||
- path: ...
|
||||
type: sharegpt
|
||||
conversation: chatml
|
||||
|
||||
# local
|
||||
datasets:
|
||||
- path: data.jsonl # or json
|
||||
ds_type: json # see other options below
|
||||
type: alpaca
|
||||
|
||||
# dataset with splits, but no train split
|
||||
dataset:
|
||||
- path: knowrohit07/know_sql
|
||||
type: context_qa.load_v2
|
||||
train_on_split: validation
|
||||
|
||||
# loading from s3 or gcs
|
||||
# s3 creds will be loaded from the system default and gcs only supports public access
|
||||
dataset:
|
||||
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
|
||||
...
|
||||
```
|
||||
|
||||
- loading
|
||||
@@ -371,18 +467,18 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
|
||||
<details>
|
||||
|
||||
<summary>All yaml options</summary>
|
||||
<summary>All yaml options (click me)</summary>
|
||||
|
||||
```yaml
|
||||
# this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# this can also be a relative path to a model on disk
|
||||
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# This can also be a relative path to a model on disk
|
||||
base_model: ./llama-7b-hf
|
||||
# you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
base_model_ignore_patterns:
|
||||
# if the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# you can set that here, or leave this empty to default to base_model
|
||||
# If the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# You can set that here, or leave this empty to default to base_model
|
||||
base_model_config: ./llama-7b-hf
|
||||
# you can specify to choose a specific model revision from huggingface hub
|
||||
# You can specify to choose a specific model revision from huggingface hub
|
||||
model_revision:
|
||||
# Optional tokenizer configuration override in case you want to use a different tokenizer
|
||||
# than the one defined in the base model
|
||||
@@ -397,18 +493,33 @@ trust_remote_code:
|
||||
tokenizer_use_fast:
|
||||
# Whether to use the legacy tokenizer setting, defaults to True
|
||||
tokenizer_legacy:
|
||||
# resize the model embeddings when new tokens are added to multiples of 32
|
||||
# this is reported to improve training speed on some models
|
||||
# Resize the model embeddings when new tokens are added to multiples of 32
|
||||
# This is reported to improve training speed on some models
|
||||
resize_token_embeddings_to_32x:
|
||||
|
||||
# whether you are training a 4-bit GPTQ quantized model
|
||||
# Used to identify which the model is based on
|
||||
is_falcon_derived_model:
|
||||
is_llama_derived_model:
|
||||
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
||||
is_mistral_derived_model:
|
||||
is_qwen_derived_model:
|
||||
|
||||
# optional overrides to the base model configuration
|
||||
model_config:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
|
||||
# Whether you are training a 4-bit GPTQ quantized model
|
||||
gptq: true
|
||||
gptq_groupsize: 128 # group size
|
||||
gptq_model_v1: false # v1 or v2
|
||||
|
||||
# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
load_in_8bit: true
|
||||
# use bitsandbytes 4 bit
|
||||
# Use bitsandbytes 4 bit
|
||||
load_in_4bit:
|
||||
|
||||
# Use CUDA bf16
|
||||
@@ -422,28 +533,35 @@ tf32: true # require >=ampere
|
||||
bfloat16: true # require >=ampere
|
||||
float16: true
|
||||
|
||||
# a list of one or more datasets to finetune the model with
|
||||
# A list of one or more datasets to finetune the model with
|
||||
datasets:
|
||||
# hf dataset repo | "json" for local dataset, make sure to fill data_files
|
||||
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
ds_type: # Optional[str] (json|arrow|parquet) defines the datatype when path is a file
|
||||
data_files: # path to source data files
|
||||
shards: # number of shards to split data into
|
||||
name: # name of dataset configuration to load
|
||||
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
|
||||
data_files: # Optional[str] path to source data files
|
||||
shards: # Optional[int] number of shards to split data into
|
||||
name: # Optional[str] name of dataset configuration to load
|
||||
train_on_split: train # Optional[str] name of dataset split to load from
|
||||
|
||||
# custom user prompt
|
||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
field_human: # Optional[str]. Human key to use for conversation.
|
||||
field_model: # Optional[str]. Assistant key to use for conversation.
|
||||
|
||||
# Custom user prompt
|
||||
- path: repo
|
||||
type:
|
||||
# the below are defaults. only set what's needed.
|
||||
# The below are defaults. only set what's needed.
|
||||
system_prompt: ""
|
||||
system_format: "{system}"
|
||||
field_system: system
|
||||
field_instruction: instruction
|
||||
field_output: input
|
||||
field_input: input
|
||||
field_output: output
|
||||
|
||||
# customizable to be single line or multi-line
|
||||
system_format: "{system}"
|
||||
# Customizable to be single line or multi-line
|
||||
# 'format' can include {input}
|
||||
format: |-
|
||||
User: {instruction} {input}
|
||||
@@ -451,18 +569,24 @@ datasets:
|
||||
# 'no_input_format' cannot include {input}
|
||||
no_input_format: "{instruction} "
|
||||
|
||||
# axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||
field:
|
||||
|
||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
# push prepared dataset to hub
|
||||
# Push prepared dataset to hub
|
||||
push_dataset_to_hub: # repo path
|
||||
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||
# if not set.
|
||||
dataset_processes: # defaults to os.cpu_count() if not set
|
||||
# push checkpoints to hub
|
||||
hub_model_id: # repo path to push finetuned model
|
||||
# how to push checkpoints to hub
|
||||
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
|
||||
hub_strategy:
|
||||
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||
# required to be true when used in combination with `push_dataset_to_hub`
|
||||
# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||
# Required to be true when used in combination with `push_dataset_to_hub`
|
||||
hf_use_auth_token: # boolean
|
||||
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
|
||||
val_set_size: 0.04
|
||||
@@ -471,28 +595,40 @@ dataset_shard_num:
|
||||
# Index of shard to use for whole dataset
|
||||
dataset_shard_idx:
|
||||
|
||||
# the maximum length of an input to train with, this should typically be less than 2048
|
||||
# The maximum length of an input to train with, this should typically be less than 2048
|
||||
# as most models have a token/context limit of 2048
|
||||
sequence_len: 2048
|
||||
# pad inputs so each step uses constant sized buffers
|
||||
# this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||
# Pad inputs so each step uses constant sized buffers
|
||||
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||
pad_to_sequence_len:
|
||||
# max sequence length to concatenate training samples together up to
|
||||
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
# Max sequence length to concatenate training samples together up to
|
||||
# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
# FutureWarning: This will soon be DEPRECATED
|
||||
max_packed_sequence_len: 1024
|
||||
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
sample_packing:
|
||||
# you can set these packing optimizations AFTER starting a training at least once.
|
||||
# Set to 'false' if getting errors during eval with sample_packing on.
|
||||
eval_sample_packing:
|
||||
# You can set these packing optimizations AFTER starting a training at least once.
|
||||
# The trainer will provide recommended values for these values.
|
||||
sample_packing_eff_est:
|
||||
total_num_tokens:
|
||||
|
||||
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
# Passed through to transformers when loading the model when launched without accelerate
|
||||
# Use `sequential` when training w/ model parallelism to limit memory
|
||||
device_map:
|
||||
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
|
||||
max_memory:
|
||||
|
||||
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
adapter: lora
|
||||
# if you already have a lora model trained that you want to load, put that here
|
||||
# lora hyperparameters
|
||||
# If you already have a lora model trained that you want to load, put that here.
|
||||
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
|
||||
lora_model_dir:
|
||||
|
||||
# LoRA hyperparameters
|
||||
# For more details about the following options, see:
|
||||
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
@@ -504,76 +640,129 @@ lora_target_modules:
|
||||
# - gate_proj
|
||||
# - down_proj
|
||||
# - up_proj
|
||||
lora_target_linear: # if true, will target all linear layers
|
||||
lora_target_linear: # If true, will target all linear layers
|
||||
|
||||
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
||||
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
||||
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
|
||||
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
|
||||
lora_modules_to_save:
|
||||
# - embed_tokens
|
||||
# - lm_head
|
||||
|
||||
# Once you complete training, the model will be saved to the following directory.
|
||||
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
|
||||
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
|
||||
lora_out_dir:
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
# ReLoRA configuration
|
||||
# must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||
relora_steps: # number of steps per ReLoRA restart
|
||||
relora_warmup_steps: # number of per-restart warmup steps
|
||||
relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||
relora_steps: # Number of steps per ReLoRA restart
|
||||
relora_warmup_steps: # Number of per-restart warmup steps
|
||||
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||
|
||||
# wandb configuration if you're using it
|
||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||
wandb_project: # your wandb project name
|
||||
wandb_entity: # a wandb Team name if using a Team
|
||||
wandb_project: # Your wandb project name
|
||||
wandb_entity: # A wandb Team name if using a Team
|
||||
wandb_watch:
|
||||
wandb_run_id: # set the name of your wandb run
|
||||
wandb_name: # Set the name of your wandb run
|
||||
wandb_run_id: # Set the ID of your wandb run
|
||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||
|
||||
# where to save the finished model to
|
||||
# Where to save the full-finetuned model to
|
||||
output_dir: ./completed-model
|
||||
|
||||
# training hyperparameters
|
||||
# Whether to use torch.compile and which backend to use
|
||||
torch_compile: # bool
|
||||
torch_compile_backend: # Optional[str]
|
||||
|
||||
# Training hyperparameters
|
||||
|
||||
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
|
||||
gradient_accumulation_steps: 1
|
||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||
micro_batch_size: 2
|
||||
eval_batch_size: 2
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
eval_batch_size:
|
||||
num_epochs: 4
|
||||
warmup_steps: 100 # cannot use with warmup_ratio
|
||||
warmup_ratio: 0.05 # cannot use with warmup_steps
|
||||
learning_rate: 0.00003
|
||||
lr_quadratic_warmup:
|
||||
logging_steps:
|
||||
save_strategy: # set to `no` to skip checkpoint saves
|
||||
save_steps: # leave empty to save at each epoch
|
||||
eval_steps: # leave empty to eval at each epoch
|
||||
save_total_limit: # checkpoints saved at a time
|
||||
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
||||
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
|
||||
save_strategy: # Set to `no` to skip checkpoint saves
|
||||
save_steps: # Leave empty to save at each epoch
|
||||
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
|
||||
save_total_limit: # Checkpoints saved at a time
|
||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||
# if both are set, num_epochs will not be guaranteed.
|
||||
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
|
||||
max_steps:
|
||||
|
||||
# save model as safetensors (require safetensors package)
|
||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
|
||||
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
||||
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
||||
|
||||
# Save model as safetensors (require safetensors package)
|
||||
save_safetensors:
|
||||
|
||||
# whether to mask out or include the human's prompt from the training labels
|
||||
# Whether to mask out or include the human's prompt from the training labels
|
||||
train_on_inputs: false
|
||||
# group similarly sized data to minimize padding
|
||||
# may be slower to start, as it must download and sort the entire dataset
|
||||
# note that training loss may have an oscillating pattern with this enabled
|
||||
# Group similarly sized data to minimize padding.
|
||||
# May be slower to start, as it must download and sort the entire dataset.
|
||||
# Note that training loss may have an oscillating pattern with this enabled.
|
||||
group_by_length: false
|
||||
|
||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
gradient_checkpointing: false
|
||||
|
||||
# stop training after this many evaluation losses have increased in a row
|
||||
# Stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
early_stopping_patience: 3
|
||||
|
||||
# specify a scheduler and kwargs to use with the optimizer
|
||||
# Specify a scheduler and kwargs to use with the optimizer
|
||||
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
||||
lr_scheduler_kwargs:
|
||||
|
||||
# for one_cycle optim
|
||||
lr_div_factor: # learning rate div factor
|
||||
# For one_cycle optim
|
||||
lr_div_factor: # Learning rate div factor
|
||||
|
||||
# for log_sweep optim
|
||||
# For log_sweep optim
|
||||
log_sweep_min_lr:
|
||||
log_sweep_max_lr:
|
||||
|
||||
# specify optimizer
|
||||
# Specify optimizer
|
||||
# Valid values are driven by the Transformers OptimizerNames class, see:
|
||||
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
||||
#
|
||||
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
|
||||
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
|
||||
# in the examples/ for your model and fine-tuning use case.
|
||||
#
|
||||
# Valid values for 'optimizer' include:
|
||||
# - adamw_hf
|
||||
# - adamw_torch
|
||||
# - adamw_torch_fused
|
||||
# - adamw_torch_xla
|
||||
# - adamw_apex_fused
|
||||
# - adafactor
|
||||
# - adamw_anyprecision
|
||||
# - sgd
|
||||
# - adagrad
|
||||
# - adamw_bnb_8bit
|
||||
# - lion_8bit
|
||||
# - lion_32bit
|
||||
# - paged_adamw_32bit
|
||||
# - paged_adamw_8bit
|
||||
# - paged_lion_32bit
|
||||
# - paged_lion_8bit
|
||||
optimizer:
|
||||
# specify weight decay
|
||||
# Specify weight decay
|
||||
weight_decay:
|
||||
# adamw hyperparams
|
||||
adam_beta1:
|
||||
@@ -582,47 +771,54 @@ adam_epsilon:
|
||||
# Gradient clipping max norm
|
||||
max_grad_norm:
|
||||
|
||||
# whether to bettertransformers
|
||||
# Augmentation techniques
|
||||
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
||||
# currently only supported on Llama and Mistral
|
||||
noisy_embedding_alpha:
|
||||
|
||||
# Whether to bettertransformers
|
||||
flash_optimum:
|
||||
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
xformers_attention:
|
||||
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
flash_attention:
|
||||
# whether to use scaled-dot-product attention
|
||||
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
|
||||
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
||||
# Whether to use scaled-dot-product attention
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
sdp_attention:
|
||||
# Landmark attention (only llama)
|
||||
landmark_attention:
|
||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
# llama only
|
||||
# LLaMA only
|
||||
xpos_rope:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
# resume from a specific checkpoint dir
|
||||
# Resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
# if resume_from_checkpoint isn't set and you simply want it to start where it left off
|
||||
# be careful with this being turned on between different models
|
||||
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
||||
# Be careful with this being turned on between different models.
|
||||
auto_resume_from_checkpoints: false
|
||||
|
||||
# don't mess with this, it's here for accelerate and torchrun
|
||||
# Don't mess with this, it's here for accelerate and torchrun
|
||||
local_rank:
|
||||
|
||||
# add or change special tokens
|
||||
# Add or change special tokens.
|
||||
# If you add tokens here, you don't need to add them to the `tokens` list.
|
||||
special_tokens:
|
||||
# bos_token: "<s>"
|
||||
# eos_token: "</s>"
|
||||
# unk_token: "<unk>"
|
||||
# add extra tokens
|
||||
|
||||
# Add extra tokens.
|
||||
tokens:
|
||||
|
||||
# FSDP
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
|
||||
# Deepspeed config path
|
||||
# Deepspeed config path. e.g., deepspeed/zero3.json
|
||||
deepspeed:
|
||||
|
||||
# Advanced DDP Arguments
|
||||
@@ -648,21 +844,108 @@ strict:
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary> Understanding of batch size and gradient accumulation steps </summary>
|
||||
<br/>
|
||||
Gradient accumulation means accumulating gradients over several mini-batches and updating the model weights afterward. When the samples in each batch are diverse, this technique doesn't significantly impact learning.
|
||||
|
||||
This method allows for effective training with larger effective batch sizes without needing proportionally larger memory. Here's why:
|
||||
|
||||
1. **Memory Consumption with Batch Size**: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption.
|
||||
|
||||
2. **Gradient Accumulation**: With gradient accumulation, you're effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you're only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch.
|
||||
|
||||
**Example 1:**
|
||||
Micro batch size: 3
|
||||
Gradient accumulation steps: 2
|
||||
Number of GPUs: 3
|
||||
Total batch size = 3 * 2 * 3 = 18
|
||||
|
||||
```
|
||||
| GPU 1 | GPU 2 | GPU 3 |
|
||||
|----------------|----------------|----------------|
|
||||
| S1, S2, S3 | S4, S5, S6 | S7, S8, S9 |
|
||||
| e1, e2, e3 | e4, e5, e6 | e7, e8, e9 |
|
||||
|----------------|----------------|----------------|
|
||||
| → (accumulate) | → (accumulate) | → (accumulate) |
|
||||
|----------------|----------------|----------------|
|
||||
| S10, S11, S12 | S13, S14, S15 | S16, S17, S18 |
|
||||
| e10, e11, e12 | e13, e14, e15 | e16, e17, e18 |
|
||||
|----------------|----------------|----------------|
|
||||
| → (apply) | → (apply) | → (apply) |
|
||||
|
||||
Accumulated gradient for the weight w1 after the second iteration (considering all GPUs):
|
||||
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 + e7 + e8 + e9 + e10 + e11 + e12 + e13 + e14 + e15 + e16 + e17 + e18
|
||||
|
||||
Weight update for w1:
|
||||
w1_new = w1_old - learning rate x (Total gradient for w1 / 18)
|
||||
```
|
||||
|
||||
**Example 2:**
|
||||
Micro batch size: 2
|
||||
Gradient accumulation steps: 1
|
||||
Number of GPUs: 3
|
||||
Total batch size = 2 * 1 * 3 = 6
|
||||
|
||||
```
|
||||
| GPU 1 | GPU 2 | GPU 3 |
|
||||
|-----------|-----------|-----------|
|
||||
| S1, S2 | S3, S4 | S5, S6 |
|
||||
| e1, e2 | e3, e4 | e5, e6 |
|
||||
|-----------|-----------|-----------|
|
||||
| → (apply) | → (apply) | → (apply) |
|
||||
|
||||
Accumulated gradient for the weight w1 (considering all GPUs):
|
||||
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6
|
||||
|
||||
Weight update for w1:
|
||||
w1_new = w1_old - learning rate × (Total gradient for w1 / 6)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Train
|
||||
|
||||
Run
|
||||
```bash
|
||||
accelerate launch scripts/finetune.py your_config.yml
|
||||
accelerate launch -m axolotl.cli.train your_config.yml
|
||||
```
|
||||
|
||||
#### Preprocess dataset
|
||||
|
||||
You can optionally pre-tokenize dataset with the following before finetuning.
|
||||
This is recommended for large datasets.
|
||||
|
||||
- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
|
||||
- Use `--debug` to see preprocessed examples.
|
||||
|
||||
```bash
|
||||
python -m axolotl.cli.preprocess your_config.yml
|
||||
```
|
||||
|
||||
#### Multi-GPU
|
||||
|
||||
You can optionally pre-tokenize dataset with the following before finetuning:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
|
||||
Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
|
||||
is the recommended multi-GPU option currently because FSDP may experience
|
||||
[loss instability](https://github.com/huggingface/transformers/issues/26498).
|
||||
|
||||
##### DeepSpeed
|
||||
|
||||
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
|
||||
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
||||
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
||||
|
||||
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
||||
|
||||
```yaml
|
||||
deepspeed: deepspeed/zero1.json
|
||||
```
|
||||
|
||||
##### Config
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
|
||||
```
|
||||
|
||||
##### FSDP
|
||||
|
||||
- llama FSDP
|
||||
```yaml
|
||||
@@ -675,11 +958,6 @@ fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
```
|
||||
|
||||
- llama Deepspeed
|
||||
```yaml
|
||||
deepspeed: deepspeed/zero3.json
|
||||
```
|
||||
|
||||
##### Weights & Biases Logging
|
||||
|
||||
- wandb options
|
||||
@@ -688,7 +966,7 @@ wandb_mode:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
```
|
||||
|
||||
@@ -698,34 +976,44 @@ Pass the appropriate flag to the train command:
|
||||
|
||||
- Pretrained LORA:
|
||||
```bash
|
||||
--inference --lora_model_dir="./lora-output-dir"
|
||||
python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
|
||||
```
|
||||
- Full weights finetune:
|
||||
```bash
|
||||
--inference --base_model="./completed-model"
|
||||
python -m axolotl.cli.inference examples/your_config.yml --base_model="./completed-model"
|
||||
```
|
||||
- Full weights finetune w/ a prompt from a text file:
|
||||
```bash
|
||||
cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
|
||||
--base_model="./completed-model" --inference --prompter=None --load_in_8bit=True
|
||||
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
|
||||
--base_model="./completed-model" --prompter=None --load_in_8bit=True
|
||||
```
|
||||
-- With gradio hosting
|
||||
```bash
|
||||
python -m axolotl.cli.inference examples/your_config.yml --gradio
|
||||
```
|
||||
|
||||
Please use `--sample_packing False` if you have it on and receive the error similar to below:
|
||||
|
||||
> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1
|
||||
|
||||
### Merge LORA to base
|
||||
|
||||
Add below flag to train command above
|
||||
|
||||
```bash
|
||||
--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
||||
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
||||
```
|
||||
|
||||
If you run out of CUDA memory, you can try to merge in system RAM with
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
|
||||
CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
|
||||
```
|
||||
|
||||
## Common Errors 🧰
|
||||
|
||||
See also the [FAQ's](./docs/faq.md).
|
||||
|
||||
> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
|
||||
|
||||
Please reduce any below
|
||||
@@ -734,6 +1022,10 @@ Please reduce any below
|
||||
- `gradient_accumulation_steps`
|
||||
- `sequence_len`
|
||||
|
||||
If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command.
|
||||
|
||||
Using adamw_bnb_8bit might also save you some memory.
|
||||
|
||||
> `failed (exitcode: -9)`
|
||||
|
||||
Usually means your system has run out of system memory.
|
||||
@@ -752,6 +1044,10 @@ Try to turn off xformers.
|
||||
|
||||
It's safe to ignore it.
|
||||
|
||||
> NCCL Timeouts during training
|
||||
|
||||
See the [NCCL](docs/nccl.md) guide.
|
||||
|
||||
## Need help? 🙋♂️
|
||||
|
||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
||||
|
||||
31
deepspeed/zero1.json
Normal file
31
deepspeed/zero1.json
Normal file
@@ -0,0 +1,31 @@
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 1,
|
||||
"overlap_comm": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": "auto",
|
||||
"eps": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
@@ -1,46 +1,35 @@
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu"
|
||||
},
|
||||
"contiguous_gradients": true,
|
||||
"overlap_comm": true
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu"
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": [
|
||||
0.9,
|
||||
0.999
|
||||
],
|
||||
"eps": 1e-8,
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
"contiguous_gradients": true,
|
||||
"overlap_comm": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": "auto",
|
||||
"eps": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
|
||||
@@ -1,14 +1,6 @@
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 0,
|
||||
@@ -36,18 +28,11 @@
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": "auto",
|
||||
"eps": 1e-8,
|
||||
"eps": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
|
||||
@@ -9,6 +9,11 @@ services:
|
||||
- ~/.cache/huggingface/:/root/.cache/huggingface/
|
||||
# set environment variables
|
||||
environment:
|
||||
# Set environment variables
|
||||
- GIT_AUTHOR_NAME=${GIT_AUTHOR_NAME}
|
||||
- GIT_AUTHOR_EMAIL=${GIT_AUTHOR_EMAIL}
|
||||
- GIT_COMMITTER_NAME=${GIT_COMMITTER_NAME}
|
||||
- GIT_COMMITTER_EMAIL=${GIT_COMMITTER_EMAIL}
|
||||
- WANDB_API_KEY=${WANDB_API_KEY}
|
||||
deploy:
|
||||
resources:
|
||||
|
||||
@@ -5,25 +5,29 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
ARG CUDA="118"
|
||||
ENV BNB_CUDA_VERSION=$CUDA
|
||||
ARG PYTORCH_VERSION="2.0.1"
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y vim curl
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main"
|
||||
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN cd axolotl && \
|
||||
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
||||
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
|
||||
else \
|
||||
pip install -e .[flash-attn]; \
|
||||
pip install -e .[deepspeed,flash-attn]; \
|
||||
fi
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
RUN cd axolotl && \
|
||||
git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
git config --get remote.origin.fetch
|
||||
|
||||
# helper for huggingface-login cli
|
||||
|
||||
@@ -10,70 +10,28 @@ ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
ARG PYTHON_VERSION="3.9"
|
||||
ARG PYTORCH_VERSION="2.0.1"
|
||||
ARG CUDA="118"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN apt-get update
|
||||
RUN apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN wget \
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh
|
||||
|
||||
RUN conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
|
||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} deepspeed-kernels --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
|
||||
FROM base-builder AS deepspeed-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
|
||||
cd DeepSpeed && \
|
||||
MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 python3 setup.py bdist_wheel
|
||||
|
||||
FROM base-builder AS bnb-builder
|
||||
|
||||
WORKDIR /workspace
|
||||
ARG CUDA="118"
|
||||
ENV CUDA=$CUDA
|
||||
|
||||
RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
|
||||
cd bitsandbytes && \
|
||||
CUDA_VERSION=$CUDA make cuda11x && \
|
||||
python setup.py bdist_wheel
|
||||
|
||||
FROM base-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
# recompile apex
|
||||
RUN python3 -m pip uninstall -y apex
|
||||
RUN git clone https://github.com/NVIDIA/apex
|
||||
# `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
|
||||
RUN cd apex && MAX_JOBS=1 python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
|
||||
|
||||
RUN mkdir -p /workspace/builds
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
|
||||
|
||||
RUN mkdir -p /workspace/wheels/bitsandbytes
|
||||
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
|
||||
|
||||
RUN pip3 install wheels/deepspeed-*.whl
|
||||
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
|
||||
RUN git lfs install --skip-repo
|
||||
RUN pip3 install awscli && \
|
||||
RUN git lfs install --skip-repo && \
|
||||
pip3 install awscli && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
|
||||
@@ -4,6 +4,7 @@ FROM winglian/axolotl:$BASE_TAG
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||
|
||||
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
|
||||
|
||||
|
||||
18
docs/faq.md
Normal file
18
docs/faq.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# Axolotl FAQ's
|
||||
|
||||
|
||||
> The trainer stopped and hasn't progressed in several minutes.
|
||||
|
||||
Usually an issue with the GPU's communicating with each other. See the [NCCL doc](../docs/nccl.md)
|
||||
|
||||
> Exitcode -9
|
||||
|
||||
This usually happens when you run out of system RAM.
|
||||
|
||||
> Exitcode -7 while using deepspeed
|
||||
|
||||
Try upgrading deepspeed w: `pip install -U deepspeed`
|
||||
|
||||
> AttributeError: 'DummyOptim' object has no attribute 'step'
|
||||
|
||||
You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
|
||||
45
docs/multi-node.md
Normal file
45
docs/multi-node.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# Multi Node
|
||||
|
||||
You will need to create a configuration for accelerate, either by using `accelerate config` and follow the instructions or you can use one of the preset below:
|
||||
|
||||
~/.cache/huggingface/accelerate/default_config.yaml
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: FSDP
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0 # Set to 0 for the main machine, increment by one for other machines
|
||||
main_process_ip: 10.0.0.4 # Set to main machine's IP
|
||||
main_process_port: 5000
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 2 # Change to the number of machines
|
||||
num_processes: 4 # That's the total number of GPUs, (for example: if you have 2 machines with 4 GPU, put 8)
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
Configure your model to use FSDP with for example:
|
||||
```yaml
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_offload_params: true
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
```
|
||||
|
||||
## Machine configuration
|
||||
|
||||
On each machine you need a copy of Axolotl, we suggest using the same commit to ensure compatibility.
|
||||
|
||||
You will also need to have the same configuration file for your model on each machine.
|
||||
|
||||
On the main machine only, make sure the port you set as `main_process_port` is open in TCP and reachable by other machines.
|
||||
|
||||
All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine.
|
||||
51
docs/multipack.md
Normal file
51
docs/multipack.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# Multipack
|
||||
|
||||
4k context, bsz =4,
|
||||
each character represents 256 tokens
|
||||
X represents a padding token
|
||||
|
||||
```
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
[[ A A A A A A A A A A A ]
|
||||
B B B B B B ]
|
||||
C C C C C C C ]
|
||||
D D D D ]]
|
||||
|
||||
[[ E E E E E E E E ]
|
||||
[ F F F F ]
|
||||
[ G G G ]
|
||||
[ H H H H ]]
|
||||
|
||||
[[ I I I ]
|
||||
[ J J J ]
|
||||
[ K K K K K]
|
||||
[ L L L ]]
|
||||
```
|
||||
|
||||
after padding to longest input in each step
|
||||
```
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
[[ A A A A A A A A A A A ]
|
||||
B B B B B B X X X X X X ]
|
||||
C C C C C C C X X X X ]
|
||||
D D D D X X X X X X X ]]
|
||||
|
||||
[[ E E E E E E E E ]
|
||||
[ F F F F X X X X ]
|
||||
[ G G G X X X X X ]
|
||||
[ H H H H X X X X ]]
|
||||
|
||||
[[ I I I X X ]
|
||||
[ J J J X X ]
|
||||
[ K K K K K ]
|
||||
[ L L L X X ]]
|
||||
```
|
||||
|
||||
w packing ( note it's the same effective number of tokens per step, but a true bsz of 1)
|
||||
```
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
[[ A A A A A A A A A A A B B B B B
|
||||
B C C C C C C C D D D D E E E E
|
||||
E E E E F F F F F G G G H H H H
|
||||
I I I J J J J K K K K K L L L X ]]
|
||||
```
|
||||
46
docs/nccl.md
Normal file
46
docs/nccl.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# NCCL
|
||||
|
||||
NVIDIA NCCL is a library to facilitate and optimize multi-GPU communication operations, such as broadcast, all-gather, reduce, all-reduce, etc. Broadly, NCCL configuration is highly environment-specific and is configured via several [environment variables](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html). A common NCCL-related problem occurs when a long-running operation times out causing the training process to abort:
|
||||
|
||||
```text
|
||||
Watchdog caught collective operation timeout: WorkNCCL(SeqNum=42, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1806948 milliseconds before timing out.
|
||||
```
|
||||
|
||||
Often, this timeout will happen after 30 minutes (the default setting) and is accompanied by below-average power consumption with near 100% GPU utilization before the error is raised. Nvidia recommends [disabling PCI access control services (ACS)](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#pci-access-control-services-acs) as a possible solution if this is available to you.
|
||||
|
||||
Forcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. To verify that your configuration is leveraging NVLink run the following command:
|
||||
|
||||
```shell
|
||||
nvidia-smi nvlink --status
|
||||
```
|
||||
|
||||
To force NCCL to use NVLink, simply set this in the environment:
|
||||
|
||||
```shell
|
||||
export NCCL_P2P_LEVEL=NVL
|
||||
```
|
||||
|
||||
If NVLink is not available in your environment there are other options for ``NCCL_P2P_LEVEL`` in the table below:
|
||||
|
||||
| NCCL_P2P_LEVEL | Description |
|
||||
| -------------- | ----------- |
|
||||
| PIX | P2P data transfers through no more than a single PCIe bridge. Faster data transfer rates vs to paths involving multiple bridges, but slower compared to direct GPU-to-GPU communication. |
|
||||
| PXB | P2P data transfers through multiple PCIe bridges but not going through the PCIe Host Bridge; this path involves a complex routing process, potentially incurring a moderate level of latency. |
|
||||
| PHB | P2P data transfers occur over the PCIe and through a PCIe Host Bridge, typically involving the CPU, which can facilitate direct memory access but might introduce additional latency compared to more direct paths (ex PIX, NVL) |
|
||||
|
||||
To validate that acceptable data transfer speeds exist for your training job, running [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) can help pinpoint bottlenecks, for example:
|
||||
|
||||
```shell
|
||||
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
|
||||
```
|
||||
|
||||
It can be useful when debugging NCCL communication timeouts to activate additional logging in both PyTorch and NCCL:
|
||||
|
||||
```shell
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_DEBUG_SUBSYS=ALL
|
||||
export TORCH_DISTRIBUTED_DEBUG=INFO
|
||||
export TORCHELASTIC_ERROR_FILE=/PATH/TO/torcherror.log
|
||||
```
|
||||
|
||||
Finally, if you believe your training job needs more time you can increase the timeout past 30 minutes by setting the ``ddp_timeout`` value in the Axolotl configuration. See [PyTorch init_process_group](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for documentation on this value.
|
||||
89
examples/cerebras/btlm-ft.yml
Normal file
89
examples/cerebras/btlm-ft.yml
Normal file
@@ -0,0 +1,89 @@
|
||||
base_model: cerebras/btlm-3b-8k-base
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: GPT2Tokenizer
|
||||
trust_remote_code: true
|
||||
tokenizer_use_fast: true
|
||||
tokenizer_legacy: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
hf_use_auth_token: true
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_prepared_run
|
||||
val_set_size: 0.05
|
||||
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
sample_packing: false
|
||||
sample_packing_eff_est:
|
||||
sample_packing_seq_len_multiplier:
|
||||
total_num_tokens:
|
||||
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_modules:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
output_dir: btlm-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch
|
||||
adam_beta2: 0.95
|
||||
adam_eps: 0.000000001
|
||||
max_grad_norm: 1.0
|
||||
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
lr_quadratic_warmup: true
|
||||
learning_rate: 0.000085
|
||||
train_on_inputs: true
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
sdp_attention:
|
||||
flash_optimum:
|
||||
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
|
||||
warmup_steps: 32
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
save_total_limit:
|
||||
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
fsdp:
|
||||
# - full_shard
|
||||
# - auto_wrap
|
||||
fsdp_config:
|
||||
# fsdp_state_dict_type: FULL_STATE_DICT
|
||||
# fsdp_transformer_layer_cls_to_wrap: BTLMBlock
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: cerebras/Cerebras-GPT-1.3B
|
||||
base_model_config: cerebras/Cerebras-GPT-1.3B
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
@@ -7,8 +6,8 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -25,7 +24,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
batch_size: 4
|
||||
@@ -50,8 +49,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: codellama/CodeLlama-13b-hf
|
||||
base_model_config: codellama/CodeLlama-13b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,12 +10,13 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 100000
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
@@ -29,12 +29,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,8 +54,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: codellama/CodeLlama-13b-hf
|
||||
base_model_config: codellama/CodeLlama-13b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,15 +10,16 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 100000
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
@@ -31,12 +31,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,8 +56,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: codellama/CodeLlama-34b-hf
|
||||
base_model_config: codellama/CodeLlama-34b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,12 +10,13 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 100000
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
@@ -29,12 +29,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,8 +54,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: codellama/CodeLlama-34b-hf
|
||||
base_model_config: codellama/CodeLlama-34b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,15 +10,16 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 100000
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
@@ -31,12 +31,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,8 +56,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: codellama/CodeLlama-7b-hf
|
||||
base_model_config: codellama/CodeLlama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,12 +10,13 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 100000
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
@@ -29,12 +29,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,8 +54,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: codellama/CodeLlama-7b-hf
|
||||
base_model_config: codellama/CodeLlama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,15 +10,16 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 100000
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
@@ -31,12 +31,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,8 +56,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
base_model: tiiuae/falcon-7b
|
||||
base_model_config: tiiuae/falcon-7b
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_falcon_derived_model: true
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
gptq: false
|
||||
@@ -11,8 +11,8 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca:chat
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./falcon-7b
|
||||
batch_size: 2
|
||||
@@ -51,8 +51,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 40
|
||||
eval_steps: 5
|
||||
save_steps: 43
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# 1b: tiiuae/falcon-rw-1b
|
||||
# 40b: tiiuae/falcon-40b
|
||||
base_model: tiiuae/falcon-7b
|
||||
base_model_config: tiiuae/falcon-7b
|
||||
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_falcon_derived_model: true
|
||||
load_in_8bit: false
|
||||
# enable 4bit for QLoRA
|
||||
load_in_4bit: true
|
||||
@@ -17,8 +17,8 @@ datasets:
|
||||
data_files:
|
||||
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
|
||||
type: "alpaca:chat"
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
|
||||
@@ -53,7 +53,7 @@ output_dir: ./qlora-out
|
||||
# decrease if OOM, increase for max VRAM utilization
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
# Optimizer for QLoRA
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
@@ -80,8 +80,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 5
|
||||
save_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.000001
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
base_model: tiiuae/falcon-7b
|
||||
base_model_config: tiiuae/falcon-7b
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_falcon_derived_model: true
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
gptq: false
|
||||
@@ -11,8 +11,8 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca:chat
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./falcon-7b
|
||||
batch_size: 2
|
||||
@@ -51,8 +51,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 40
|
||||
eval_steps: 5
|
||||
save_steps: 43
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: EleutherAI/gpt-j-6b
|
||||
base_model_config: EleutherAI/gpt-j-6b
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
@@ -7,8 +6,8 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -22,7 +21,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
gradient_accumulation_steps: 2
|
||||
@@ -47,8 +46,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
# LLaMa 7B using LoRA
|
||||
|
||||
This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
|
||||
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml
|
||||
|
||||
```
|
||||
@@ -1,63 +0,0 @@
|
||||
base_model: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
|
||||
base_model_config: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
trust_remote_code:
|
||||
load_in_8bit: true
|
||||
gptq: true
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_fan_in_fan_out: false
|
||||
wandb_project: llama-7b-lora-int4
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./llama-7b-lora-int4
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0000002
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
fp16: true
|
||||
bf16: false
|
||||
tf32: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 5
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
gradient_checkpointing: true
|
||||
gptq_groupsize: 128
|
||||
gptq_model_v1: false
|
||||
warmup_steps: 20
|
||||
eval_steps: 110
|
||||
save_steps: 660
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0001
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
tokens:
|
||||
pad_token: "<pad>"
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,12 +1,11 @@
|
||||
base_model: huggyllama/llama-7b
|
||||
base_model_config: huggyllama/llama-7b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
- path: openaccess-ai-collective/jeopardy
|
||||
type: jeopardy
|
||||
dataset_prepared_path: last_run_prepared
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
@@ -20,12 +19,12 @@ lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./jeopardy-bot-7b
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
@@ -43,8 +42,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 110
|
||||
save_steps: 660
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -9,12 +9,16 @@ gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
|
||||
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml
|
||||
```
|
||||
or
|
||||
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/llama-2/lora.yml
|
||||
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml
|
||||
```
|
||||
|
||||
To launch a full finetuning with 16-bit precision:
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
|
||||
```
|
||||
|
||||
72
examples/llama-2/fft_optimized.yml
Normal file
72
examples/llama-2/fft_optimized.yml
Normal file
@@ -0,0 +1,72 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_steps: 100
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed: #deepspeed/zero2.json # multi-gpu only
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
73
examples/llama-2/gptq-lora.yml
Normal file
73
examples/llama-2/gptq-lora.yml
Normal file
@@ -0,0 +1,73 @@
|
||||
base_model: TheBloke/Llama-2-7B-GPTQ
|
||||
is_llama_derived_model: false
|
||||
gptq: true
|
||||
gptq_disable_exllama: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
tokenizer_use_fast: true
|
||||
tokenizer_legacy: true
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
hf_use_auth_token: true
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 4096
|
||||
sample_packing:
|
||||
lora_r: 8
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- k_proj
|
||||
- o_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./model-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
optimizer: adamw_torch
|
||||
adam_beta2: 0.95
|
||||
adam_eps: 0.00001
|
||||
max_grad_norm: 1.0
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
lr_quadratic_warmup: true
|
||||
learning_rate: 0.000017
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: false
|
||||
fp16: false
|
||||
float16: true
|
||||
tf32: true
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
sdp_attention:
|
||||
flash_optimum:
|
||||
warmup_steps: 100
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: meta-llama/Llama-2-7b-hf
|
||||
base_model_config: meta-llama/Llama-2-7b-hf
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,12 +10,13 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
@@ -29,12 +29,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,8 +54,10 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: meta-llama/Llama-2-7b-hf
|
||||
base_model_config: meta-llama/Llama-2-7b-hf
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,8 +10,8 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -20,6 +19,7 @@ lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
@@ -31,12 +31,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,8 +56,9 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: meta-llama/Llama-2-7b-hf
|
||||
base_model_config: meta-llama/Llama-2-7b-hf
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,8 +10,8 @@ strict: false
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./relora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -20,6 +19,7 @@ lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
@@ -35,12 +35,12 @@ relora_cpu_offload: false
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -60,8 +60,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps: 50
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
68
examples/llama-2/tiny-llama.yml
Normal file
68
examples/llama-2/tiny-llama.yml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T
|
||||
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
61
examples/mamba/config.yml
Normal file
61
examples/mamba/config.yml
Normal file
@@ -0,0 +1,61 @@
|
||||
base_model: state-spaces/mamba-2.8b
|
||||
model_type: MambaLMHeadModel
|
||||
tokenizer_type: AutoTokenizer
|
||||
tokenizer_config: EleutherAI/gpt-neox-20b
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 2
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-5
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
tokens:
|
||||
save_safetensors: False
|
||||
12
examples/mistral/README.md
Normal file
12
examples/mistral/README.md
Normal file
@@ -0,0 +1,12 @@
|
||||
**Mistral 7B** is a language model with a total of 7.3 billion parameters, showcasing a notable performance across a variety of benchmarks.
|
||||
|
||||
Fine Tune:
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/mistral/config.yml
|
||||
|
||||
```
|
||||
|
||||
If you run into CUDA OOM, use deepspeed with config zero2.json:
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json
|
||||
```
|
||||
61
examples/mistral/config.yml
Normal file
61
examples/mistral/config.yml
Normal file
@@ -0,0 +1,61 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_mistral_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.000005
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
79
examples/mistral/mixtral.yml
Normal file
79
examples/mistral/mixtral.yml
Normal file
@@ -0,0 +1,79 @@
|
||||
base_model: mistralai/Mixtral-8x7B-v0.1
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
#lora_target_modules:
|
||||
# - gate
|
||||
# - q_proj
|
||||
# - k_proj
|
||||
# - v_proj
|
||||
# - o_proj
|
||||
# - w1
|
||||
# - w2
|
||||
# - w3
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed: deepspeed/zero2.json
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
81
examples/mistral/qlora.yml
Normal file
81
examples/mistral/qlora.yml
Normal file
@@ -0,0 +1,81 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_mistral_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,12 +1,11 @@
|
||||
base_model: mosaicml/mpt-7b
|
||||
base_model_config: mosaicml/mpt-7b
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code: true # required for mpt as their model class is not merged into transformers yet
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
@@ -22,12 +21,12 @@ lora_fan_in_fan_out: false
|
||||
wandb_project: mpt-alpaca-7b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./mpt-alpaca-7b
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
@@ -45,8 +44,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 110
|
||||
save_steps: 660
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0001
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: openlm-research/open_llama_3b
|
||||
base_model_config: openlm-research/open_llama_3b
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
@@ -9,12 +8,12 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 256
|
||||
max_packed_sequence_len:
|
||||
sequence_len: 1024
|
||||
sample_packing: true
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
@@ -24,16 +23,16 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./openllama-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.00001
|
||||
learning_rate: 0.000003
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
float16: true
|
||||
@@ -45,13 +44,13 @@ early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 50
|
||||
save_steps:
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: openlm-research/open_llama_3b
|
||||
base_model_config: openlm-research/open_llama_3b
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
@@ -9,12 +8,12 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.02
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 256
|
||||
max_packed_sequence_len:
|
||||
sequence_len: 1024
|
||||
sample_packing: true
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.0
|
||||
@@ -30,12 +29,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-out
|
||||
batch_size: 16
|
||||
micro_batch_size: 4
|
||||
num_epochs: 3
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
@@ -50,16 +49,16 @@ early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 50
|
||||
save_steps:
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: openlm-research/open_llama_3b
|
||||
base_model_config: openlm-research/open_llama_3b
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
@@ -9,12 +8,12 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
sequence_len: 1024
|
||||
sample_packing: true
|
||||
lora_r: 8
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
@@ -24,36 +23,36 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
batch_size: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 2
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
bf16: false
|
||||
fp16: true
|
||||
tf32: false
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
|
||||
11
examples/phi/README.md
Normal file
11
examples/phi/README.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# Phi
|
||||
|
||||
Due to some nuances with the phi code, please use deepspeed when training phi for full finetune.
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed/zero1.json
|
||||
|
||||
# OR
|
||||
|
||||
python -m axolotl.cli.train examples/phi/phi-qlora.yml
|
||||
```
|
||||
74
examples/phi/phi-ft.yml
Normal file
74
examples/phi/phi-ft.yml
Normal file
@@ -0,0 +1,74 @@
|
||||
base_model: microsoft/phi-1_5
|
||||
model_type: PhiForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_llama_derived_model: false
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: garage-bAInd/Open-Platypus
|
||||
type: alpaca
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./phi-sft-out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len:
|
||||
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
optimizer: adamw_torch
|
||||
adam_beta2: 0.95
|
||||
adam_epsilon: 0.00001
|
||||
max_grad_norm: 1.0
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.000003
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing:
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 100
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
resize_token_embeddings_to_32x: true
|
||||
special_tokens:
|
||||
bos_token: "<|endoftext|>"
|
||||
eos_token: "<|endoftext|>"
|
||||
unk_token: "<|endoftext|>"
|
||||
pad_token: "<|endoftext|>"
|
||||
74
examples/phi/phi-qlora.yml
Normal file
74
examples/phi/phi-qlora.yml
Normal file
@@ -0,0 +1,74 @@
|
||||
base_model: microsoft/phi-1_5
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_llama_derived_model: false
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: garage-bAInd/Open-Platypus
|
||||
type: alpaca
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./phi-sft-out
|
||||
|
||||
sequence_len: 1024
|
||||
sample_packing: false # not CURRENTLY compatible with LoRAs
|
||||
pad_to_sequence_len:
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
lora_r: 64
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
optimizer: adamw_torch
|
||||
adam_beta2: 0.95
|
||||
adam_epsilon: 0.00001
|
||||
max_grad_norm: 1.0
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.000003
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing:
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 100
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
resize_token_embeddings_to_32x: true
|
||||
special_tokens:
|
||||
bos_token: "<|endoftext|>"
|
||||
eos_token: "<|endoftext|>"
|
||||
unk_token: "<|endoftext|>"
|
||||
pad_token: "<|endoftext|>"
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: EleutherAI/pythia-12b-deduped
|
||||
base_model_config: EleutherAI/pythia-12b-deduped
|
||||
base_model_ignore_patterns: pytorch* # prefer safetensors
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
@@ -10,7 +9,7 @@ device_map: auto
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
@@ -25,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./pythia-12b
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
base_model: EleutherAI/pythia-1.4b-deduped
|
||||
base_model_config: EleutherAI/pythia-1.4b-deduped
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
@@ -19,20 +18,20 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-alpaca-pythia
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
learning_rate: 0.00001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: True
|
||||
tf32: True
|
||||
bf16: true
|
||||
tf32: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
weight_decay: 0.1
|
||||
eval_steps: 20
|
||||
evals_per_epoch: 4
|
||||
logging_steps: 1
|
||||
|
||||
68
examples/qwen/lora.yml
Normal file
68
examples/qwen/lora.yml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: Qwen/Qwen-7B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
is_qwen_derived_model: true
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 2048 # supports up to 8192
|
||||
sample_packing: false
|
||||
pad_to_sequence_len:
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
68
examples/qwen/qlora.yml
Normal file
68
examples/qwen/qlora.yml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: Qwen/Qwen-7B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
is_qwen_derived_model: true
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 2048 # supports up to 8192
|
||||
sample_packing: false
|
||||
pad_to_sequence_len:
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code:
|
||||
@@ -7,7 +6,7 @@ load_in_8bit: false
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
@@ -23,12 +22,12 @@ lora_fan_in_fan_out: false
|
||||
wandb_project: redpajama-alpaca-3b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./redpajama-alpaca-3b
|
||||
batch_size: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
@@ -46,8 +45,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 110
|
||||
save_steps: 660
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0001
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
base_model: replit/replit-code-v1-3b
|
||||
base_model_config: replit/replit-code-v1-3b
|
||||
trust_remote_code: true
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
@@ -22,12 +21,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project: lora-replit
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-replit
|
||||
batch_size: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer:
|
||||
torchdistx_path:
|
||||
lr_scheduler:
|
||||
@@ -46,8 +45,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 50
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# An example finetuning Saleforce's XGen-7b model with 8k context using qlora
|
||||
# on Tim Dettmer's Guanaco dataset.
|
||||
base_model: Salesforce/xgen-7b-8k-base
|
||||
base_model_config: Salesforce/xgen-7b-8k-base
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
@@ -16,8 +15,8 @@ datasets:
|
||||
data_files:
|
||||
- openassistant_best_replies_train.jsonl
|
||||
type: "completion"
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
@@ -39,7 +38,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
|
||||
@@ -52,7 +51,7 @@ output_dir: ./qlora-out
|
||||
# decrease if OOM, increase for max VRAM utilization
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
# Optimizer for QLoRA
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
@@ -79,8 +78,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 50
|
||||
save_steps: 50
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
BIN
image/sticker_fixed.png
Normal file
BIN
image/sticker_fixed.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 370 KiB |
@@ -1,19 +1,22 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
auto-gptq==0.5.1
|
||||
packaging
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git
|
||||
peft==0.6.0
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@e5079b0b2abcef11ecbdae60ba4a6636c57b725d
|
||||
tokenizers==0.15.0
|
||||
bitsandbytes>=0.41.1
|
||||
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
|
||||
accelerate==0.24.1
|
||||
deepspeed
|
||||
addict
|
||||
evaluate
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
datasets
|
||||
flash-attn>=2.0.8
|
||||
datasets>=2.15.0
|
||||
flash-attn==2.3.3
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
xformers
|
||||
optimum
|
||||
xformers==0.0.22
|
||||
optimum==1.13.2
|
||||
hf_transfer
|
||||
colorama
|
||||
numba
|
||||
@@ -26,3 +29,11 @@ scipy
|
||||
scikit-learn==1.2.2
|
||||
pynvml
|
||||
art
|
||||
fschat==0.2.34
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
|
||||
# remote filesystems
|
||||
s3fs
|
||||
gcsfs
|
||||
# adlfs
|
||||
|
||||
@@ -1,271 +1,38 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import fire
|
||||
import torch
|
||||
import transformers
|
||||
import yaml
|
||||
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
from art import text2art
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
check_user_token,
|
||||
do_inference,
|
||||
do_merge_lora,
|
||||
load_cfg,
|
||||
load_datasets,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.cli.shard import shard
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta, train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
from axolotl.utils.models import load_model_config, load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
configure_logging()
|
||||
LOG = logging.getLogger("axolotl.scripts")
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
|
||||
|
||||
def print_axolotl_text_art(suffix=None):
|
||||
font = "nancyj"
|
||||
ascii_text = " axolotl"
|
||||
if suffix:
|
||||
ascii_text += f" x {suffix}"
|
||||
ascii_art = text2art(" axolotl", font=font)
|
||||
|
||||
if is_main_process():
|
||||
print(ascii_art)
|
||||
|
||||
|
||||
def get_multi_line_input() -> Optional[str]:
|
||||
print("Give me an instruction (Ctrl + D to finish): ")
|
||||
instruction = ""
|
||||
for line in sys.stdin:
|
||||
instruction += line # pylint: disable=consider-using-join
|
||||
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
|
||||
return instruction
|
||||
|
||||
|
||||
def do_merge_lora(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
LOG.info("running merge of LoRA with base model")
|
||||
model = model.merge_and_unload()
|
||||
model.to(dtype=torch.float16)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info("saving merged model")
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir) / "merged"),
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
|
||||
|
||||
def shard(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
LOG.debug("Re-saving model w/ sharding")
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
|
||||
def do_inference(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
prompter = cli_args.prompter
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
if prompter:
|
||||
prompter_module = getattr(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
|
||||
if cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
model.set_mem_cache_args(
|
||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||
)
|
||||
|
||||
model = model.to(cfg.device)
|
||||
|
||||
while True:
|
||||
print("=" * 80)
|
||||
# support for multiline inputs
|
||||
instruction = get_multi_line_input()
|
||||
if not instruction:
|
||||
return
|
||||
if prompter_module:
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
else:
|
||||
prompt = instruction.strip()
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
print("=" * 40)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
streamer = TextStreamer(tokenizer)
|
||||
generated = model.generate(
|
||||
inputs=batch["input_ids"].to(cfg.device),
|
||||
generation_config=generation_config,
|
||||
streamer=streamer,
|
||||
)
|
||||
print("=" * 40)
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
def choose_config(path: Path):
|
||||
yaml_files = list(path.glob("*.yml"))
|
||||
|
||||
if not yaml_files:
|
||||
raise ValueError(
|
||||
"No YAML config files found in the specified directory. Are you using a .yml extension?"
|
||||
)
|
||||
|
||||
if len(yaml_files) == 1:
|
||||
print(f"Using default YAML file '{yaml_files[0]}'")
|
||||
return yaml_files[0]
|
||||
|
||||
print("Choose a YAML file:")
|
||||
for idx, file in enumerate(yaml_files):
|
||||
print(f"{idx + 1}. {file}")
|
||||
|
||||
chosen_file = None
|
||||
while chosen_file is None:
|
||||
try:
|
||||
choice = int(input("Enter the number of your choice: "))
|
||||
if 1 <= choice <= len(yaml_files):
|
||||
chosen_file = yaml_files[choice - 1]
|
||||
else:
|
||||
print("Invalid choice. Please choose a number from the list.")
|
||||
except ValueError:
|
||||
print("Invalid input. Please enter a number.")
|
||||
|
||||
return chosen_file
|
||||
|
||||
|
||||
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
|
||||
return not any(el in list2 for el in list1)
|
||||
|
||||
|
||||
def load_cfg(config: Path = Path("examples/"), **kwargs):
|
||||
if Path(config).is_dir():
|
||||
config = choose_config(config)
|
||||
|
||||
# load the config from the yaml file
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||
# then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
for k, _ in kwargs.items():
|
||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||
if k in cfg_keys or not cfg.strict:
|
||||
# handle booleans
|
||||
if isinstance(cfg[k], bool):
|
||||
cfg[k] = bool(kwargs[k])
|
||||
else:
|
||||
cfg[k] = kwargs[k]
|
||||
|
||||
model_config = load_model_config(cfg)
|
||||
|
||||
# figure out if the model is llama
|
||||
cfg.is_llama_derived_model = (
|
||||
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
||||
or cfg.is_llama_derived_model
|
||||
or "llama" in cfg.base_model
|
||||
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
||||
)
|
||||
validate_config(cfg)
|
||||
|
||||
normalize_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
return cfg
|
||||
|
||||
|
||||
def load_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
) -> TrainDatasetMeta:
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
check_dataset_labels(
|
||||
train_dataset.select(
|
||||
[
|
||||
random.randrange(0, len(train_dataset) - 1) # nosec
|
||||
for _ in range(cli_args.debug_num_examples)
|
||||
]
|
||||
),
|
||||
tokenizer,
|
||||
num_examples=cli_args.debug_num_examples,
|
||||
text_only=cli_args.debug_text_only,
|
||||
)
|
||||
|
||||
return TrainDatasetMeta(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
total_num_steps=total_num_steps,
|
||||
)
|
||||
LOG = logging.getLogger("axolotl.scripts.finetune")
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
print_axolotl_text_art()
|
||||
LOG.warning(
|
||||
str(
|
||||
PendingDeprecationWarning(
|
||||
"scripts/finetune.py will be replaced with calling axolotl.cli.train"
|
||||
)
|
||||
)
|
||||
)
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
@@ -278,8 +45,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if parsed_cli_args.prepare_ds_only:
|
||||
return
|
||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
|
||||
61
setup.py
61
setup.py
@@ -2,38 +2,57 @@
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
install_requires = []
|
||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||
# don't include peft yet until we check the int4
|
||||
# need to manually install peft for now...
|
||||
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
|
||||
reqs = [r for r in reqs if "flash-attn" not in r]
|
||||
reqs = [r for r in reqs if r and r[0] != "#"]
|
||||
for r in reqs:
|
||||
install_requires.append(r)
|
||||
|
||||
def parse_requirements():
|
||||
_install_requires = []
|
||||
_dependency_links = []
|
||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||
lines = [r.strip() for r in requirements_file.readlines()]
|
||||
for line in lines:
|
||||
if line.startswith("--extra-index-url"):
|
||||
# Handle custom index URLs
|
||||
_, url = line.split()
|
||||
_dependency_links.append(url)
|
||||
elif (
|
||||
"flash-attn" not in line
|
||||
and "deepspeed" not in line
|
||||
and line
|
||||
and line[0] != "#"
|
||||
):
|
||||
# Handle standard packages
|
||||
_install_requires.append(line)
|
||||
|
||||
# TODO(wing) remove once xformers release supports torch 2.1.0
|
||||
if "torch==2.1.0" in _install_requires:
|
||||
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
|
||||
_install_requires.append(
|
||||
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
|
||||
)
|
||||
|
||||
return _install_requires, _dependency_links
|
||||
|
||||
|
||||
install_requires, dependency_links = parse_requirements()
|
||||
|
||||
|
||||
setup(
|
||||
name="axolotl",
|
||||
version="0.1",
|
||||
description="You know you're going to axolotl questions",
|
||||
version="0.3.0",
|
||||
description="LLM Trainer",
|
||||
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages(),
|
||||
install_requires=install_requires,
|
||||
dependency_links=dependency_links,
|
||||
extras_require={
|
||||
"gptq": [
|
||||
"alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
|
||||
],
|
||||
"gptq_triton": [
|
||||
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
|
||||
],
|
||||
"flash-attn": [
|
||||
"flash-attn==2.0.8",
|
||||
"flash-attn==2.3.3",
|
||||
],
|
||||
"extras": [
|
||||
"deepspeed": [
|
||||
"deepspeed",
|
||||
],
|
||||
"peft": [
|
||||
"peft @ git+https://github.com/huggingface/peft.git",
|
||||
"mamba-ssm": [
|
||||
"mamba-ssm==1.0.1",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
361
src/axolotl/cli/__init__.py
Normal file
361
src/axolotl/cli/__init__.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
from accelerate.commands.config import config_args
|
||||
from art import text2art
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
configure_logging()
|
||||
LOG = logging.getLogger("axolotl.scripts")
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
|
||||
|
||||
def print_axolotl_text_art(suffix=None):
|
||||
font = "nancyj"
|
||||
ascii_text = " axolotl"
|
||||
if suffix:
|
||||
ascii_text += f" x {suffix}"
|
||||
ascii_art = text2art(ascii_text, font=font)
|
||||
|
||||
if is_main_process():
|
||||
print(ascii_art)
|
||||
|
||||
|
||||
def get_multi_line_input() -> Optional[str]:
|
||||
print("Give me an instruction (Ctrl + D to submit): ")
|
||||
instruction = ""
|
||||
for line in sys.stdin:
|
||||
instruction += line # pylint: disable=consider-using-join
|
||||
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
|
||||
return instruction
|
||||
|
||||
|
||||
def do_merge_lora(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
LOG.info("running merge of LoRA with base model")
|
||||
model = model.merge_and_unload()
|
||||
model.to(dtype=cfg.torch_dtype)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir) / "merged"),
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
|
||||
|
||||
def do_inference(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
prompter = cli_args.prompter
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
if prompter:
|
||||
prompter_module = getattr(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
|
||||
if cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
model.set_mem_cache_args(
|
||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||
)
|
||||
|
||||
model = model.to(cfg.device)
|
||||
|
||||
while True:
|
||||
print("=" * 80)
|
||||
# support for multiline inputs
|
||||
instruction = get_multi_line_input()
|
||||
if not instruction:
|
||||
return
|
||||
if prompter_module:
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
else:
|
||||
prompt = instruction.strip()
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
print("=" * 40)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
streamer = TextStreamer(tokenizer)
|
||||
generated = model.generate(
|
||||
inputs=batch["input_ids"].to(cfg.device),
|
||||
generation_config=generation_config,
|
||||
streamer=streamer,
|
||||
)
|
||||
print("=" * 40)
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
def do_inference_gradio(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
prompter = cli_args.prompter
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
if prompter:
|
||||
prompter_module = getattr(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
|
||||
if cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
model.set_mem_cache_args(
|
||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||
)
|
||||
|
||||
model = model.to(cfg.device)
|
||||
|
||||
def generate(instruction):
|
||||
if not instruction:
|
||||
return
|
||||
if prompter_module:
|
||||
# pylint: disable=stop-iteration-return
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
else:
|
||||
prompt = instruction.strip()
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
streamer = TextIteratorStreamer(tokenizer)
|
||||
generation_kwargs = {
|
||||
"inputs": batch["input_ids"].to(cfg.device),
|
||||
"generation_config": generation_config,
|
||||
"streamer": streamer,
|
||||
}
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
|
||||
all_text = ""
|
||||
|
||||
for new_text in streamer:
|
||||
all_text += new_text
|
||||
yield all_text
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate,
|
||||
inputs="textbox",
|
||||
outputs="text",
|
||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||
)
|
||||
demo.queue().launch(show_api=False, share=True)
|
||||
|
||||
|
||||
def choose_config(path: Path):
|
||||
yaml_files = list(path.glob("*.yml"))
|
||||
|
||||
if not yaml_files:
|
||||
raise ValueError(
|
||||
"No YAML config files found in the specified directory. Are you using a .yml extension?"
|
||||
)
|
||||
|
||||
if len(yaml_files) == 1:
|
||||
print(f"Using default YAML file '{yaml_files[0]}'")
|
||||
return yaml_files[0]
|
||||
|
||||
print("Choose a YAML file:")
|
||||
for idx, file in enumerate(yaml_files):
|
||||
print(f"{idx + 1}. {file}")
|
||||
|
||||
chosen_file = None
|
||||
while chosen_file is None:
|
||||
try:
|
||||
choice = int(input("Enter the number of your choice: "))
|
||||
if 1 <= choice <= len(yaml_files):
|
||||
chosen_file = yaml_files[choice - 1]
|
||||
else:
|
||||
print("Invalid choice. Please choose a number from the list.")
|
||||
except ValueError:
|
||||
print("Invalid input. Please enter a number.")
|
||||
|
||||
return chosen_file
|
||||
|
||||
|
||||
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
|
||||
return not any(el in list2 for el in list1)
|
||||
|
||||
|
||||
def load_cfg(config: Path = Path("examples/"), **kwargs):
|
||||
if Path(config).is_dir():
|
||||
config = choose_config(config)
|
||||
|
||||
# load the config from the yaml file
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
cfg.axolotl_config_path = config
|
||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||
# then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
for k, _ in kwargs.items():
|
||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||
if k in cfg_keys or not cfg.strict:
|
||||
# handle booleans
|
||||
if isinstance(cfg[k], bool):
|
||||
cfg[k] = bool(kwargs[k])
|
||||
else:
|
||||
cfg[k] = kwargs[k]
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
prepare_optim_env(cfg)
|
||||
|
||||
normalize_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
return cfg
|
||||
|
||||
|
||||
def load_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
) -> TrainDatasetMeta:
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||
cfg, tokenizer
|
||||
)
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
check_dataset_labels(
|
||||
train_dataset.select(
|
||||
[
|
||||
random.randrange(0, len(train_dataset) - 1) # nosec
|
||||
for _ in range(cli_args.debug_num_examples)
|
||||
]
|
||||
),
|
||||
tokenizer,
|
||||
num_examples=cli_args.debug_num_examples,
|
||||
text_only=cli_args.debug_text_only,
|
||||
)
|
||||
|
||||
LOG.info("printing prompters...")
|
||||
for prompter in prompters:
|
||||
LOG.info(prompter)
|
||||
|
||||
return TrainDatasetMeta(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
total_num_steps=total_num_steps,
|
||||
)
|
||||
|
||||
|
||||
def check_accelerate_default_config():
|
||||
if Path(config_args.default_yaml_config_file).exists():
|
||||
LOG.warning(
|
||||
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
|
||||
)
|
||||
|
||||
|
||||
def check_user_token():
|
||||
# Verify if token is valid
|
||||
api = HfApi()
|
||||
try:
|
||||
user_info = api.whoami()
|
||||
return bool(user_info)
|
||||
except LocalTokenNotFoundError:
|
||||
LOG.warning(
|
||||
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||
)
|
||||
return False
|
||||
36
src/axolotl/cli/inference.py
Normal file
36
src/axolotl/cli/inference.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
CLI to run inference on a trained model
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
|
||||
from axolotl.cli import (
|
||||
do_inference,
|
||||
do_inference_gradio,
|
||||
load_cfg,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parsed_cfg.sample_packing = False
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
parsed_cli_args.inference = True
|
||||
|
||||
if gradio:
|
||||
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(do_cli)
|
||||
27
src/axolotl/cli/merge_lora.py
Normal file
27
src/axolotl/cli/merge_lora.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
CLI to run merge a trained LoRA into a base model
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
|
||||
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
parsed_cli_args.merge_lora = True
|
||||
parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)
|
||||
|
||||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(do_cli)
|
||||
53
src/axolotl/cli/preprocess.py
Normal file
53
src/axolotl/cli/preprocess.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
from colorama import Fore
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
check_user_token,
|
||||
load_cfg,
|
||||
load_datasets,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import PreprocessCliArgs
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
parser = transformers.HfArgumentParser((PreprocessCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
if not parsed_cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
Fore.RED
|
||||
+ "preprocess CLI called without dataset_prepared_path set, "
|
||||
+ f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
|
||||
+ Fore.RESET
|
||||
)
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
LOG.info(
|
||||
Fore.GREEN
|
||||
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
||||
+ Fore.RESET
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(do_cli)
|
||||
42
src/axolotl/cli/shard.py
Normal file
42
src/axolotl/cli/shard.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
CLI to shard a trained model into 10GiB chunks
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
|
||||
from axolotl.cli import load_cfg, print_axolotl_text_art
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.scripts")
|
||||
|
||||
|
||||
def shard(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
LOG.debug("Re-saving model w/ sharding")
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
parsed_cli_args.shard = True
|
||||
|
||||
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(do_cli)
|
||||
38
src/axolotl/cli/train.py
Normal file
38
src/axolotl/cli/train.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
check_user_token,
|
||||
load_cfg,
|
||||
load_datasets,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.train")
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(do_cli)
|
||||
@@ -25,11 +25,22 @@ class TrainerCliArgs:
|
||||
debug_num_examples: int = field(default=5)
|
||||
inference: bool = field(default=False)
|
||||
merge_lora: bool = field(default=False)
|
||||
prepare_ds_only: bool = field(default=False)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
shard: bool = field(default=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessCliArgs:
|
||||
"""
|
||||
dataclass representing arguments for preprocessing only
|
||||
"""
|
||||
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=1)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
|
||||
5
src/axolotl/common/const.py
Normal file
5
src/axolotl/common/const.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Various shared constants
|
||||
"""
|
||||
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
0
src/axolotl/core/__init__.py
Normal file
0
src/axolotl/core/__init__.py
Normal file
798
src/axolotl/core/trainer_builder.py
Normal file
798
src/axolotl/core/trainer_builder.py
Normal file
@@ -0,0 +1,798 @@
|
||||
"""
|
||||
Builder for the training args and trainer
|
||||
"""
|
||||
|
||||
import abc
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import Dataset
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||
from transformers.trainer_utils import seed_worker
|
||||
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
LossWatchDogCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
)
|
||||
from axolotl.utils.collators import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
MambaDataCollator,
|
||||
)
|
||||
from axolotl.utils.samplers import MultipackBatchSampler
|
||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||
|
||||
try:
|
||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(TrainingArguments):
|
||||
"""
|
||||
Extend the base TrainingArguments for axolotl helpers
|
||||
"""
|
||||
|
||||
model_type: Optional[str] = field(
|
||||
default=None, metadata={"help": "HF model configuration model_type."}
|
||||
)
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
sample_packing_seq_len_multiplier: int = field(
|
||||
default=1,
|
||||
metadata={"help": "the multiplier for the max len for packed sequences"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||
},
|
||||
)
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
|
||||
args = None # type: AxolotlTrainingArguments
|
||||
|
||||
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
||||
self.num_epochs = num_epochs
|
||||
self.bench_data_collator = bench_data_collator
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
passed as an argument.
|
||||
|
||||
Args:
|
||||
num_training_steps (int): The number of training steps to do.
|
||||
optimizer (torch.optim.Optimizer): The training optimizer
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
if (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.lr_quadratic_warmup is True
|
||||
):
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
return self.lr_scheduler
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing:
|
||||
return MultipackBatchSampler(
|
||||
RandomSampler(self.train_dataset),
|
||||
self.args.train_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self._train_batch_size * self.args.max_seq_length,
|
||||
lengths=(
|
||||
self.train_dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
return MultipackBatchSampler(
|
||||
SequentialSampler(eval_dataset),
|
||||
self.args.per_device_eval_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
|
||||
lengths=(
|
||||
eval_dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
if self.args.sample_packing:
|
||||
train_dataset = self.train_dataset
|
||||
train_dataset = train_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
sampler = self._get_train_sampler()
|
||||
if isinstance(sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
eval_dataset = (
|
||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
)
|
||||
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if isinstance(eval_sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = eval_sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = eval_sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(eval_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
def _get_bench_sampler(
|
||||
self, bench_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size <= 1:
|
||||
return SequentialSampler(bench_dataset)
|
||||
return None
|
||||
|
||||
def get_bench_dataloader(
|
||||
self,
|
||||
bench_dataset: Dataset,
|
||||
) -> DataLoader:
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": self.bench_data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# use one's weighted cross entropy loss calc
|
||||
# if self.args.sample_packing:
|
||||
# labels = inputs.pop("labels")
|
||||
# outputs = model(**inputs)
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Mamba specific trainer to handle loss calculation
|
||||
"""
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False, # pylint: disable=unused-argument
|
||||
):
|
||||
input_ids = inputs.pop("input_ids")
|
||||
lm_logits = model(input_ids).logits
|
||||
|
||||
labels = input_ids.to(lm_logits.device)
|
||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
|
||||
loss_fct = torch.nn.CrossEntropyLoss()
|
||||
lm_loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||
)
|
||||
|
||||
return lm_loss
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
|
||||
self.lr_scheduler = OneCycleLR(
|
||||
optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
total_steps=num_training_steps,
|
||||
pct_start=pct_start,
|
||||
div_factor=6,
|
||||
)
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class ReLoRATrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
if self.args.relora_steps:
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler(
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
Base class for trainer builder
|
||||
"""
|
||||
|
||||
_train_dataset = None
|
||||
_eval_dataset = None
|
||||
|
||||
def __init__(self, cfg, model, tokenizer):
|
||||
self.cfg = cfg
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@property
|
||||
def train_dataset(self):
|
||||
return self._train_dataset
|
||||
|
||||
@train_dataset.setter
|
||||
def train_dataset(self, dataset):
|
||||
self._train_dataset = dataset
|
||||
|
||||
@property
|
||||
def eval_dataset(self):
|
||||
return self._eval_dataset
|
||||
|
||||
@eval_dataset.setter
|
||||
def eval_dataset(self, dataset):
|
||||
self._eval_dataset = dataset
|
||||
|
||||
@abstractmethod
|
||||
def build(self, total_num_steps):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_callbacks(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
"""
|
||||
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
||||
"""
|
||||
|
||||
|
||||
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
Build the HuggingFace training args/trainer for Causal models
|
||||
"""
|
||||
|
||||
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
||||
# TODO
|
||||
return training_arguments_kwargs
|
||||
|
||||
def hook_post_create_training_args(self, training_arguments):
|
||||
# TODO
|
||||
return training_arguments
|
||||
|
||||
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
|
||||
# TODO
|
||||
return trainer_kwargs, trainer_cls
|
||||
|
||||
def hook_post_create_trainer(self, trainer):
|
||||
# TODO
|
||||
return trainer
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = []
|
||||
callbacks.append(GPUStatsCallback(self.cfg))
|
||||
callbacks.append(EvalFirstStepCallback)
|
||||
|
||||
if self.cfg.relora_steps:
|
||||
callbacks.append(ReLoRACallback(self.cfg))
|
||||
|
||||
if (
|
||||
hasattr(self.model, "use_bettertransformer")
|
||||
and self.model.use_bettertransformer is True
|
||||
):
|
||||
callbacks.append(SaveBetterTransformerModelCallback)
|
||||
|
||||
if self.cfg.use_wandb:
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
if self.cfg.loss_watchdog_threshold is not None:
|
||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer
|
||||
)
|
||||
callbacks.append(LogPredictionCallback(self.cfg))
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
||||
|
||||
if self.cfg.early_stopping_patience:
|
||||
early_stop_cb = EarlyStoppingCallback(
|
||||
self.cfg.early_stopping_patience,
|
||||
)
|
||||
callbacks.append(early_stop_cb)
|
||||
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self):
|
||||
if self.cfg.lr_scheduler == "one_cycle" and (
|
||||
self.cfg.fsdp or self.cfg.adapter == "qlora"
|
||||
):
|
||||
return OneCycleLRSchedulerTrainer
|
||||
if self.cfg.relora_steps:
|
||||
return ReLoRATrainer
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return AxolotlMambaTrainer
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
warmup_steps = None
|
||||
if self.cfg.warmup_steps is not None:
|
||||
warmup_steps = self.cfg.warmup_steps
|
||||
elif self.cfg.warmup_ratio is not None:
|
||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||
else:
|
||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||
|
||||
logging_steps = (
|
||||
self.cfg.logging_steps
|
||||
if self.cfg.logging_steps is not None
|
||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||
)
|
||||
|
||||
training_arguments_kwargs = {}
|
||||
if self.cfg.bf16 == "full":
|
||||
training_arguments_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
training_arguments_kwargs["bf16"] = self.cfg.bf16
|
||||
training_arguments_kwargs["fp16"] = (
|
||||
self.cfg.fp16 and not self.cfg.bf16
|
||||
) or False
|
||||
training_arguments_kwargs["tf32"] = self.cfg.tf32
|
||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||
|
||||
if self.cfg.seed:
|
||||
training_arguments_kwargs["seed"] = self.cfg.seed
|
||||
|
||||
if self.cfg.gradient_checkpointing:
|
||||
training_arguments_kwargs[
|
||||
"gradient_checkpointing"
|
||||
] = self.cfg.gradient_checkpointing
|
||||
if self.cfg.fsdp:
|
||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||
if self.cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
|
||||
|
||||
# deepspeed
|
||||
if self.cfg.deepspeed:
|
||||
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
|
||||
|
||||
if self.cfg.lr_quadratic_warmup is not None:
|
||||
training_arguments_kwargs[
|
||||
"lr_quadratic_warmup"
|
||||
] = self.cfg.lr_quadratic_warmup
|
||||
|
||||
if self.cfg.adam_beta1:
|
||||
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
|
||||
if self.cfg.adam_beta2:
|
||||
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
|
||||
if self.cfg.adam_epsilon:
|
||||
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
|
||||
if self.cfg.max_grad_norm:
|
||||
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
|
||||
|
||||
if self.cfg.hub_model_id:
|
||||
training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
||||
training_arguments_kwargs["push_to_hub"] = True
|
||||
training_arguments_kwargs["hub_private_repo"] = True
|
||||
|
||||
if self.cfg.hub_strategy:
|
||||
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||
|
||||
if self.cfg.save_safetensors is not None:
|
||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||
|
||||
if self.cfg.sample_packing_eff_est:
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_efficiency"
|
||||
] = self.cfg.sample_packing_eff_est
|
||||
|
||||
if self.cfg.dataloader_pin_memory is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_pin_memory"
|
||||
] = self.cfg.dataloader_pin_memory
|
||||
if self.cfg.dataloader_num_workers is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_num_workers"
|
||||
] = self.cfg.dataloader_num_workers
|
||||
if self.cfg.dataloader_prefetch_factor is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_prefetch_factor"
|
||||
] = self.cfg.dataloader_prefetch_factor
|
||||
|
||||
if self.cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
elif self.cfg.evaluation_strategy:
|
||||
training_arguments_kwargs[
|
||||
"evaluation_strategy"
|
||||
] = self.cfg.evaluation_strategy
|
||||
else:
|
||||
# we have an eval set, but no steps defined, default to use epoch
|
||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.save_steps:
|
||||
training_arguments_kwargs["save_strategy"] = "steps"
|
||||
training_arguments_kwargs["save_steps"] = self.cfg.save_steps
|
||||
elif self.cfg.save_strategy:
|
||||
training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy
|
||||
else:
|
||||
# default to saving each epoch if not defined
|
||||
training_arguments_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
||||
if self.cfg.bench_dataset:
|
||||
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
|
||||
if self.cfg.metric_for_best_model:
|
||||
training_arguments_kwargs[
|
||||
"metric_for_best_model"
|
||||
] = self.cfg.metric_for_best_model
|
||||
if self.cfg.greater_is_better:
|
||||
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
|
||||
|
||||
if self.cfg.torch_compile:
|
||||
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
|
||||
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
|
||||
elif torch._dynamo: # pylint: disable=protected-access
|
||||
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||
True
|
||||
)
|
||||
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
|
||||
if self.cfg.torch_compile_backend:
|
||||
training_arguments_kwargs[
|
||||
"torch_compile_backend"
|
||||
] = self.cfg.torch_compile_backend
|
||||
|
||||
# DDP Config
|
||||
if self.cfg.ddp_timeout:
|
||||
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
|
||||
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
|
||||
if self.cfg.ddp_bucket_cap_mb:
|
||||
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
|
||||
if self.cfg.ddp_broadcast_buffers is not None:
|
||||
training_arguments_kwargs[
|
||||
"ddp_broadcast_buffers"
|
||||
] = self.cfg.ddp_broadcast_buffers
|
||||
|
||||
# these are all the "standard" kwargs that are def used
|
||||
training_arguments_kwargs["max_steps"] = (
|
||||
total_num_steps if self.cfg.max_steps else -1
|
||||
)
|
||||
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
||||
training_arguments_kwargs[
|
||||
"per_device_train_batch_size"
|
||||
] = self.cfg.micro_batch_size
|
||||
training_arguments_kwargs[
|
||||
"per_device_eval_batch_size"
|
||||
] = self.cfg.eval_batch_size
|
||||
training_arguments_kwargs[
|
||||
"gradient_accumulation_steps"
|
||||
] = self.cfg.gradient_accumulation_steps
|
||||
training_arguments_kwargs[
|
||||
"eval_accumulation_steps"
|
||||
] = self.cfg.gradient_accumulation_steps
|
||||
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
|
||||
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
|
||||
training_arguments_kwargs["save_total_limit"] = (
|
||||
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
|
||||
)
|
||||
training_arguments_kwargs["load_best_model_at_end"] = (
|
||||
(
|
||||
self.cfg.load_best_model_at_end is not False
|
||||
or self.cfg.early_stopping_patience
|
||||
)
|
||||
and self.cfg.val_set_size > 0
|
||||
and self.cfg.save_steps
|
||||
and self.cfg.eval_steps
|
||||
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
||||
) or False
|
||||
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
||||
False if self.cfg.ddp else None
|
||||
)
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
||||
training_arguments_kwargs["run_name"] = (
|
||||
self.cfg.wandb_name if self.cfg.use_wandb else None
|
||||
)
|
||||
training_arguments_kwargs["optim"] = (
|
||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||
)
|
||||
training_arguments_kwargs["lr_scheduler_type"] = (
|
||||
self.cfg.lr_scheduler
|
||||
if self.cfg.lr_scheduler
|
||||
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
||||
else "cosine"
|
||||
)
|
||||
training_arguments_kwargs["weight_decay"] = (
|
||||
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
||||
)
|
||||
training_arguments_kwargs["sample_packing"] = (
|
||||
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||
)
|
||||
training_arguments_kwargs["eval_sample_packing"] = (
|
||||
self.cfg.sample_packing
|
||||
if self.cfg.eval_sample_packing is not False
|
||||
else False
|
||||
)
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_seq_len_multiplier"
|
||||
] = self.cfg.micro_batch_size
|
||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
|
||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||
training_arguments_kwargs
|
||||
)
|
||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||
training_args = (
|
||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
)
|
||||
training_args = self.hook_post_create_training_args(training_args)
|
||||
trainer_kwargs = {}
|
||||
|
||||
if self.cfg.optimizer == "adamw_anyprecision":
|
||||
if Path(self.cfg.torchdistx_path).exists():
|
||||
sys.path.append(self.cfg.torchdistx_path)
|
||||
importlib.import_module("torchdistx")
|
||||
|
||||
data_collator_kwargs = {
|
||||
"padding": True, # True/"longest" is the default
|
||||
}
|
||||
if self.cfg.pad_to_sequence_len:
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
|
||||
self.cfg.sequence_len / 64
|
||||
)
|
||||
else:
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||
|
||||
if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||
add_mem_tokens,
|
||||
get_mem_id,
|
||||
set_model_mem_id,
|
||||
)
|
||||
|
||||
set_model_mem_id(self.model, self.tokenizer)
|
||||
|
||||
LOG.info("Adding landmark attention tokens to dataset")
|
||||
|
||||
for dataset in [self.train_dataset, self.eval_dataset]:
|
||||
dataset = dataset.map(
|
||||
partial(
|
||||
add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer)
|
||||
),
|
||||
batched=False,
|
||||
num_proc=32,
|
||||
)
|
||||
|
||||
trainer_cls = self._get_trainer_cls()
|
||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||
trainer_kwargs, trainer_cls
|
||||
)
|
||||
trainer = trainer_cls(
|
||||
model=self.model,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
args=training_args,
|
||||
data_collator=self.build_collator(**data_collator_kwargs),
|
||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
callbacks=self.get_callbacks(),
|
||||
num_epochs=self.cfg.num_epochs,
|
||||
**trainer_kwargs,
|
||||
)
|
||||
trainer = self.hook_post_create_trainer(trainer)
|
||||
for callback in self.get_post_trainer_create_callbacks(trainer):
|
||||
trainer.add_callback(callback)
|
||||
|
||||
if self.cfg.deepspeed and self.cfg.sample_packing:
|
||||
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
|
||||
"train_micro_batch_size_per_gpu"
|
||||
] = self.cfg.micro_batch_size
|
||||
|
||||
return trainer
|
||||
|
||||
def build_collator(self, **kwargs):
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||
|
||||
return BatchSamplerDataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**kwargs,
|
||||
)
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
@@ -22,7 +22,7 @@ class TokenizedPromptDataset(Dataset):
|
||||
"""
|
||||
Dataset that returns tokenized prompts from a stream of text files.
|
||||
Args:
|
||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
|
||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
"""
|
||||
|
||||
@@ -30,18 +30,29 @@ class TokenizedPromptDataset(Dataset):
|
||||
self,
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: IterableDataset,
|
||||
process_count: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.prompt_tokenizer = prompt_tokenizer
|
||||
self.process_count = process_count
|
||||
super().__init__(self.process(dataset).data, **kwargs)
|
||||
|
||||
def process(self, dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, os.cpu_count())
|
||||
num_proc = (
|
||||
min(64, self.process_count)
|
||||
if self.process_count
|
||||
else min(64, os.cpu_count())
|
||||
)
|
||||
map_kwargs = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
map_kwargs["batch_size"] = 100
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
**map_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -50,7 +61,7 @@ class ConstantLengthDataset(IterableDataset):
|
||||
"""
|
||||
Iterable dataset that returns constant length chunks of tokens from stream of text files.
|
||||
Args:
|
||||
tokenizer (Tokenizer): The processor used for proccessing the data.
|
||||
tokenizer (Tokenizer): The processor used for processing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
seq_length (int): Length of token sequences to return.
|
||||
"""
|
||||
|
||||
@@ -23,6 +23,7 @@ class ColorfulFormatter(Formatter):
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
record.rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
log_message = super().format(record)
|
||||
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
|
||||
|
||||
@@ -35,7 +36,7 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
||||
},
|
||||
"colorful": {
|
||||
"()": ColorfulFormatter,
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
|
||||
},
|
||||
},
|
||||
"filters": {},
|
||||
|
||||
0
src/axolotl/models/__init__.py
Normal file
0
src/axolotl/models/__init__.py
Normal file
12
src/axolotl/models/mamba/__init__.py
Normal file
12
src/axolotl/models/mamba/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Modeling module for Mamba models
|
||||
"""
|
||||
|
||||
|
||||
def fix_mamba_attn_for_loss():
|
||||
from mamba_ssm.models import mixer_seq_simple
|
||||
|
||||
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
|
||||
|
||||
mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
|
||||
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name
|
||||
42
src/axolotl/models/mamba/configuration_mamba.py
Normal file
42
src/axolotl/models/mamba/configuration_mamba.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
HF Transformers MambaConfig
|
||||
"""
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class MambaConfig(PretrainedConfig):
|
||||
"""
|
||||
modeling configuration for state space model/mamba
|
||||
"""
|
||||
|
||||
model_type = "mamba"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50280,
|
||||
d_model=2560,
|
||||
n_layer=64,
|
||||
rms_norm=True,
|
||||
residual_in_fp32=True,
|
||||
fused_add_norm=True,
|
||||
pad_vocab_size_multiple=8,
|
||||
pad_token_id=50277,
|
||||
bos_token_id=0,
|
||||
eos_token_id=0,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.d_model = d_model
|
||||
self.n_layer = n_layer
|
||||
self.rms_norm = rms_norm
|
||||
self.residual_in_fp32 = residual_in_fp32
|
||||
self.fused_add_norm = fused_add_norm
|
||||
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
128
src/axolotl/models/mamba/modeling_mamba.py
Normal file
128
src/axolotl/models/mamba/modeling_mamba.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# pylint: skip-file
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
|
||||
from mamba_ssm.utils.generation import GenerationMixin
|
||||
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from axolotl.models.mamba.configuration_mamba import MambaConfig
|
||||
|
||||
|
||||
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_layer: int,
|
||||
vocab_size: int,
|
||||
initializer_cfg=None,
|
||||
pad_vocab_size_multiple: int = 1,
|
||||
device=None,
|
||||
dtype=None,
|
||||
**backbone_kwargs,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
if vocab_size % pad_vocab_size_multiple != 0:
|
||||
vocab_size += pad_vocab_size_multiple - (
|
||||
vocab_size % pad_vocab_size_multiple
|
||||
)
|
||||
self.config = MambaConfig(
|
||||
vocab_size=vocab_size,
|
||||
d_model=d_model,
|
||||
n_layer=n_layer,
|
||||
pad_vocab_size_multiple=pad_vocab_size_multiple,
|
||||
)
|
||||
self.backbone = MixerModel(
|
||||
d_model=d_model,
|
||||
n_layer=n_layer,
|
||||
vocab_size=vocab_size,
|
||||
initializer_cfg=initializer_cfg,
|
||||
**backbone_kwargs,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.apply(
|
||||
partial(
|
||||
_init_weights,
|
||||
n_layer=n_layer,
|
||||
**(initializer_cfg if initializer_cfg is not None else {}),
|
||||
)
|
||||
)
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
self.lm_head.weight = self.backbone.embedding.weight
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
return self.backbone.allocate_inference_cache(
|
||||
batch_size, max_seqlen, dtype=dtype, **kwargs
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids=None,
|
||||
inference_params=None,
|
||||
num_last_tokens=0,
|
||||
labels=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
||||
num_last_tokens: if > 0, only return the logits for the last n tokens
|
||||
"""
|
||||
hidden_states = self.backbone(input_ids, inference_params=inference_params)
|
||||
if num_last_tokens > 0:
|
||||
hidden_states = hidden_states[:, -num_last_tokens:]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
||||
return CausalLMOutput(logits=lm_logits)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
logits = lm_logits
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
|
||||
print(loss)
|
||||
return CausalLMOutput(logits=lm_logits, loss=loss)
|
||||
|
||||
else:
|
||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
||||
return CausalLMOutput(logits=lm_logits)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
state_dict: Optional[dict] = None,
|
||||
safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument
|
||||
):
|
||||
if state_dict is None:
|
||||
state_dict = self.state_dict()
|
||||
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
||||
config = load_config_hf(pretrained_model_name)
|
||||
model = cls(**config, device=device, dtype=dtype, **kwargs)
|
||||
model.load_state_dict(
|
||||
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
|
||||
)
|
||||
return model
|
||||
8
src/axolotl/models/phi/__init__.py
Normal file
8
src/axolotl/models/phi/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
MixFormers model architecture used for phi models
|
||||
"""
|
||||
|
||||
from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
|
||||
from .configuration_phi import PhiConfig # noqa
|
||||
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
|
||||
from .modeling_phi import PhiForCausalLM # noqa
|
||||
63
src/axolotl/models/phi/configuration_mixformer_sequential.py
Normal file
63
src/axolotl/models/phi/configuration_mixformer_sequential.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# pylint: skip-file
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class MixFormerSequentialConfig(PretrainedConfig):
|
||||
"""MixFormer (sequential for DeepSpeed) configuration."""
|
||||
|
||||
model_type = "mixformer-sequential"
|
||||
|
||||
attribute_map = {
|
||||
"max_position_embeddings": "n_positions",
|
||||
"hidden_size": "n_embd",
|
||||
"num_attention_heads": "n_head",
|
||||
"num_hidden_layers": "n_layer",
|
||||
"input_emb_layer": "embd_layer", # `input_emb_layer` key is for backward compatibility
|
||||
"blocks": "architecture", # `blocks` key is for backward compatibility
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: Optional[int] = 50304,
|
||||
n_positions: Optional[int] = 2048,
|
||||
n_embd: Optional[int] = 1024,
|
||||
n_layer: Optional[int] = 20,
|
||||
n_inner: Optional[int] = None,
|
||||
n_head: Optional[int] = 16,
|
||||
rotary_dim: Optional[int] = 32,
|
||||
activation_function: Optional[str] = "gelu_new",
|
||||
embd_layer: Optional[str] = "default",
|
||||
architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None,
|
||||
embd_pdrop: Optional[float] = 0.0,
|
||||
resid_pdrop: Optional[float] = 0.0,
|
||||
layer_norm_epsilon: Optional[float] = 1e-5,
|
||||
initializer_range: Optional[float] = 0.02,
|
||||
tie_word_embeddings: Optional[bool] = False,
|
||||
pad_vocab_size_multiple: Optional[int] = 64,
|
||||
**kwargs
|
||||
) -> None:
|
||||
self.vocab_size = int(
|
||||
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
||||
)
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_inner = n_inner
|
||||
self.n_head = n_head
|
||||
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
||||
self.activation_function = activation_function
|
||||
self.embd_layer = embd_layer
|
||||
self.architecture = architecture
|
||||
self.embd_pdrop = embd_pdrop
|
||||
self.resid_pdrop = resid_pdrop
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
65
src/axolotl/models/phi/configuration_phi.py
Normal file
65
src/axolotl/models/phi/configuration_phi.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# pylint: skip-file
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class PhiConfig(PretrainedConfig):
|
||||
"""Phi configuration."""
|
||||
|
||||
model_type = "phi"
|
||||
attribute_map = {
|
||||
"max_position_embeddings": "n_positions",
|
||||
"hidden_size": "n_embd",
|
||||
"num_attention_heads": "n_head",
|
||||
"num_hidden_layers": "n_layer",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 50304,
|
||||
n_positions: int = 2048,
|
||||
n_embd: int = 1024,
|
||||
n_layer: int = 20,
|
||||
n_inner: Optional[int] = None,
|
||||
n_head: int = 16,
|
||||
n_head_kv: Optional[int] = None,
|
||||
rotary_dim: Optional[int] = 32,
|
||||
activation_function: Optional[str] = "gelu_new",
|
||||
flash_attn: bool = False,
|
||||
flash_rotary: bool = False,
|
||||
fused_dense: bool = False,
|
||||
attn_pdrop: float = 0.0,
|
||||
embd_pdrop: float = 0.0,
|
||||
resid_pdrop: float = 0.0,
|
||||
layer_norm_epsilon: float = 1e-5,
|
||||
initializer_range: float = 0.02,
|
||||
tie_word_embeddings: bool = False,
|
||||
pad_vocab_size_multiple: int = 64,
|
||||
**kwargs
|
||||
) -> None:
|
||||
self.vocab_size = int(
|
||||
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
||||
)
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_inner = n_inner
|
||||
self.n_head = n_head
|
||||
self.n_head_kv = n_head_kv
|
||||
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
||||
self.activation_function = activation_function
|
||||
self.flash_attn = flash_attn
|
||||
self.flash_rotary = flash_rotary
|
||||
self.fused_dense = fused_dense
|
||||
self.attn_pdrop = attn_pdrop
|
||||
self.embd_pdrop = embd_pdrop
|
||||
self.resid_pdrop = resid_pdrop
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
930
src/axolotl/models/phi/modeling_mixformer_sequential.py
Normal file
930
src/axolotl/models/phi/modeling_mixformer_sequential.py
Normal file
@@ -0,0 +1,930 @@
|
||||
# pylint: skip-file
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# BSD 3-Clause License
|
||||
#
|
||||
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# * Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_kvpacked_func,
|
||||
flash_attn_qkvpacked_func,
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from .configuration_mixformer_sequential import MixFormerSequentialConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceParams:
|
||||
"""Inference parameters that are passed to the main model in order
|
||||
to efficienly calculate and store the context during inference.
|
||||
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
||||
|
||||
max_sequence_len: int
|
||||
max_batch_size: int
|
||||
sequence_len_offset: int = 0
|
||||
batch_size_offset: int = 0
|
||||
key_value_memory_dict: dict = field(default_factory=dict)
|
||||
fused_ft_kernel: bool = False
|
||||
lengths_per_sample: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class Embedding(nn.Module):
|
||||
"""Token embedding with dropout."""
|
||||
|
||||
def __init__(self, config: PretrainedConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
||||
self.drop = nn.Dropout(config.embd_pdrop)
|
||||
|
||||
def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
|
||||
hidden_states = self.wte(input_ids)
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
"""PyTorch implementation of `flash-attn` RotaryEmbedding layer.
|
||||
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
base: Optional[int] = 10000,
|
||||
scale_base: Optional[float] = None,
|
||||
device: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
if scale_base is not None:
|
||||
raise NotImplementedError
|
||||
|
||||
# Generate and save the inverse frequency buffer (non-trainable)
|
||||
self.dim = dim
|
||||
self.base = base
|
||||
self.scale_base = scale_base
|
||||
self.device = device
|
||||
|
||||
inv_freq = 1.0 / (
|
||||
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
scale = (
|
||||
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
|
||||
/ (1.4 * dim)
|
||||
if scale_base is not None
|
||||
else None
|
||||
)
|
||||
self.register_buffer("scale", scale)
|
||||
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._cos_k_cached = None
|
||||
self._sin_k_cached = None
|
||||
|
||||
def _update_cos_sin_cache(
|
||||
self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0
|
||||
) -> None:
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
seqlen = x.shape[1] + seqlen_offset
|
||||
|
||||
# Re-generate the inverse frequency buffer if it's not fp32
|
||||
# (for instance if model.half() was called)
|
||||
if self.inv_freq.dtype != "torch.float32":
|
||||
self.inv_freq = 1.0 / (
|
||||
self.base
|
||||
** (
|
||||
torch.arange(
|
||||
0, self.dim, 2, device=self.device, dtype=torch.float32
|
||||
)
|
||||
/ self.dim
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached.device != x.device
|
||||
or self._cos_cached.dtype != x.dtype
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
|
||||
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
freqs = torch.outer(
|
||||
t, self.inv_freq.to(device=t.device, dtype=torch.float32)
|
||||
)
|
||||
if self.scale is None:
|
||||
self._cos_cached = torch.cos(freqs).to(x.dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(x.dtype)
|
||||
else:
|
||||
power = (
|
||||
torch.arange(
|
||||
seqlen, dtype=self.scale.dtype, device=self.scale.device
|
||||
)
|
||||
- seqlen // 2
|
||||
) / self.scale_base
|
||||
scale = self.scale.to(device=power.device) ** rearrange(
|
||||
power, "s -> s 1"
|
||||
)
|
||||
|
||||
# We want the multiplication by scale to happen in fp32
|
||||
self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
|
||||
self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
|
||||
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
|
||||
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
|
||||
|
||||
def apply_rotary_emb_qkv(
|
||||
self,
|
||||
qkv: torch.FloatTensor,
|
||||
sin: torch.FloatTensor,
|
||||
cos: torch.FloatTensor,
|
||||
sin_k: Optional[torch.FloatTensor] = None,
|
||||
cos_k: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
_, seqlen, three, _, headdim = qkv.shape
|
||||
assert three == 3
|
||||
|
||||
rotary_seqlen, rotary_dim = cos.shape
|
||||
rotary_dim *= 2
|
||||
assert rotary_dim <= headdim
|
||||
assert seqlen <= rotary_seqlen
|
||||
|
||||
cos_k = cos if cos_k is None else cos_k
|
||||
sin_k = sin if sin_k is None else sin_k
|
||||
assert (
|
||||
sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
|
||||
)
|
||||
|
||||
q_rot = qkv[:, :, 0, :, :rotary_dim]
|
||||
q_pass = qkv[:, :, 0, :, rotary_dim:]
|
||||
|
||||
k_rot = qkv[:, :, 1, :, :rotary_dim]
|
||||
k_pass = qkv[:, :, 1, :, rotary_dim:]
|
||||
|
||||
# Splits the queries and keys in half
|
||||
q1, q2 = q_rot.chunk(2, dim=-1)
|
||||
k1, k2 = k_rot.chunk(2, dim=-1)
|
||||
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
|
||||
sin[:seqlen], "s d -> s 1 d"
|
||||
)
|
||||
|
||||
# Casts to fp32 are necessary to prevent fp16 overflow issues
|
||||
q1, q2, k1, k2, c, s = [
|
||||
t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]
|
||||
]
|
||||
|
||||
# Computes the new keys and queries, recasting to original dtype
|
||||
q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
|
||||
|
||||
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
|
||||
|
||||
return torch.cat(
|
||||
[
|
||||
torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
|
||||
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
|
||||
qkv[:, :, 2:3, :, :],
|
||||
],
|
||||
axis=2,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, qkv: torch.Tensor, seqlen_offset: int = 0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Perform the forward pass.
|
||||
|
||||
Args:
|
||||
qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim).
|
||||
seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch.
|
||||
|
||||
Returns:
|
||||
New `qkv` and the cached sinusoids.
|
||||
|
||||
"""
|
||||
|
||||
self._update_cos_sin_cache(qkv, seqlen_offset)
|
||||
|
||||
return self.apply_rotary_emb_qkv(
|
||||
qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:]
|
||||
)
|
||||
|
||||
|
||||
def _update_kv_cache(kv, inference_params, layer_idx):
|
||||
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
|
||||
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
||||
# Pre-allocate memory for key-values for inference.
|
||||
num_heads, head_dim = kv.shape[-2:]
|
||||
if layer_idx not in inference_params.key_value_memory_dict:
|
||||
kv_cache = torch.empty(
|
||||
inference_params.max_batch_size,
|
||||
inference_params.max_sequence_len,
|
||||
2,
|
||||
num_heads,
|
||||
head_dim,
|
||||
dtype=kv.dtype,
|
||||
device=kv.device,
|
||||
)
|
||||
inference_params.key_value_memory_dict[layer_idx] = kv_cache
|
||||
else:
|
||||
kv_cache = inference_params.key_value_memory_dict[layer_idx]
|
||||
|
||||
# Adjust key and value for inference
|
||||
batch_start = inference_params.batch_size_offset
|
||||
batch_end = batch_start + kv.shape[0]
|
||||
sequence_start = inference_params.sequence_len_offset
|
||||
sequence_end = sequence_start + kv.shape[1]
|
||||
assert batch_end <= (
|
||||
kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0] # noqa
|
||||
)
|
||||
assert sequence_end <= (
|
||||
kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2] # noqa
|
||||
)
|
||||
|
||||
assert kv_cache is not None
|
||||
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
||||
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
|
||||
return kv
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""Multi-Layer Perceptron.
|
||||
|
||||
Reference:
|
||||
Attention Is All You Need.
|
||||
https://arxiv.org/pdf/1706.03762.pdf.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
n_inner: Optional[int] = None,
|
||||
act_fn: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
act_fn = config.activation_function if act_fn is None else act_fn
|
||||
assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
|
||||
|
||||
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
|
||||
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
|
||||
|
||||
self.fc1 = nn.Linear(config.n_embd, n_inner)
|
||||
self.fc2 = nn.Linear(n_inner, config.n_embd)
|
||||
self.act = ACT2FN[act_fn]
|
||||
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
old_keys = [
|
||||
prefix + "fc_in.weight",
|
||||
prefix + "fc_out.weight",
|
||||
prefix + "fc_in.bias",
|
||||
prefix + "fc_out.bias",
|
||||
]
|
||||
new_keys = [
|
||||
prefix + "fc1.weight",
|
||||
prefix + "fc2.weight",
|
||||
prefix + "fc1.bias",
|
||||
prefix + "fc2.bias",
|
||||
]
|
||||
|
||||
if all(k in state_dict for k in old_keys) and not all(
|
||||
k in state_dict for k in new_keys
|
||||
):
|
||||
# Older version of `MLP` saved with different key names.
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
return super()._load_from_state_dict(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedMLP(nn.Module):
|
||||
"""Fused Multi-Layer Perceptron from `flash-attn`.
|
||||
|
||||
Reference:
|
||||
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
n_inner: Optional[int] = None,
|
||||
act_fn: Optional[str] = None,
|
||||
raise_on_missing: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
act_fn = config.activation_function if act_fn is None else act_fn
|
||||
assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
|
||||
|
||||
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
|
||||
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
|
||||
|
||||
gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"] # noqa
|
||||
activation = "gelu_approx" if act_fn in gelu_activations else "relu" # noqa
|
||||
|
||||
self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return self.mlp(hidden_states)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
"""Implement the scaled dot product attention with softmax.
|
||||
Adapted from https://github.com/Dao-AILab/flash-attention.
|
||||
Arguments
|
||||
---------
|
||||
softmax_scale: The temperature to use for the softmax attention.
|
||||
(default: 1/sqrt(d_keys) where d_keys is computed at
|
||||
runtime)
|
||||
attention_dropout: The dropout rate to apply to the attention
|
||||
(default: 0.0)
|
||||
"""
|
||||
|
||||
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
|
||||
super().__init__()
|
||||
self.causal = causal
|
||||
self.softmax_scale = softmax_scale
|
||||
self.drop = nn.Dropout(attention_dropout)
|
||||
|
||||
def forward(
|
||||
self, qkv, causal=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None
|
||||
):
|
||||
"""Implements the multihead softmax attention.
|
||||
Arguments
|
||||
---------
|
||||
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
|
||||
causal: if passed, will override self.causal
|
||||
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
|
||||
False means to mask out. (B, S)
|
||||
"""
|
||||
causal = self.causal if causal is None else causal
|
||||
if cu_seqlens is not None:
|
||||
return flash_attn_varlen_qkvpacked_func(
|
||||
qkv.squeeze(0),
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=self.drop.p,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
else:
|
||||
return flash_attn_qkvpacked_func(
|
||||
qkv,
|
||||
dropout_p=self.drop.p,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
"""Implement the scaled dot product attention with softmax.
|
||||
Adapted from https://github.com/Dao-AILab/flash-attention.
|
||||
Arguments
|
||||
---------
|
||||
softmax_scale: The temperature to use for the softmax attention.
|
||||
(default: 1/sqrt(d_keys) where d_keys is computed at
|
||||
runtime)
|
||||
attention_dropout: The dropout rate to apply to the attention
|
||||
(default: 0.0)
|
||||
"""
|
||||
|
||||
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
|
||||
super().__init__()
|
||||
self.causal = causal
|
||||
self.softmax_scale = softmax_scale
|
||||
self.drop = nn.Dropout(attention_dropout)
|
||||
|
||||
def forward(self, q, kv, causal=None, key_padding_mask=None):
|
||||
"""Implements the multihead softmax attention.
|
||||
Arguments
|
||||
---------
|
||||
q: The tensor containing the query. (B, Sq, H, D)
|
||||
kv: The tensor containing the key and value. (B, Sk, 2, H, D)
|
||||
causal: if passed, will override self.causal
|
||||
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
|
||||
False means to mask out. (B, Sk)
|
||||
"""
|
||||
causal = self.causal if causal is None else causal
|
||||
return flash_attn_kvpacked_func(
|
||||
q,
|
||||
kv,
|
||||
dropout_p=self.drop.p,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
|
||||
def find_mha_dims(
|
||||
config: PretrainedConfig,
|
||||
n_head: Optional[int] = None,
|
||||
head_dim: Optional[int] = None,
|
||||
) -> Tuple[int, int]:
|
||||
"""Validate and return the number of heads and head dimension for multi-head attention.
|
||||
|
||||
Args:
|
||||
config: Model configuration.
|
||||
n_head: Number of heads.
|
||||
head_dim: Head dimension.
|
||||
|
||||
Returns:
|
||||
Number of heads and head dimension.
|
||||
|
||||
"""
|
||||
|
||||
assert all(
|
||||
hasattr(config, attr) for attr in ["n_embd", "n_head"]
|
||||
), "`config` must have `n_embd` and `n_head` attributes."
|
||||
|
||||
if head_dim is None:
|
||||
assert (
|
||||
config.n_embd % config.n_head == 0
|
||||
), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
|
||||
|
||||
if n_head is None and head_dim is None:
|
||||
head_dim = config.n_embd // config.n_head
|
||||
n_head = config.n_head
|
||||
elif n_head is None or head_dim is None:
|
||||
raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
|
||||
|
||||
return n_head, head_dim
|
||||
|
||||
|
||||
class MHA(nn.Module):
|
||||
"""Multi-head attention layer.
|
||||
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
rotary_dim: Optional[int] = None,
|
||||
n_head: Optional[int] = None,
|
||||
head_dim: Optional[int] = None,
|
||||
bias: Optional[bool] = True,
|
||||
dropout: Optional[float] = 0.0,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: Optional[bool] = True,
|
||||
layer_idx: Optional[int] = None,
|
||||
rotary_emb_scale_base: Optional[float] = None,
|
||||
return_residual: Optional[bool] = False,
|
||||
checkpointing: Optional[bool] = False,
|
||||
device: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
fused_dense: Optional[bool] = True,
|
||||
flash_attn: Optional[bool] = True,
|
||||
cutlass_attn: Optional[bool] = False,
|
||||
flash_rotary: Optional[bool] = True,
|
||||
raise_on_missing: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
n_head, head_dim = find_mha_dims(config, n_head, head_dim)
|
||||
|
||||
self.hidden_size = config.n_embd
|
||||
self.n_head = n_head
|
||||
self.head_dim = head_dim
|
||||
self.op_size = n_head * head_dim
|
||||
|
||||
self.causal = causal
|
||||
self.layer_idx = layer_idx
|
||||
self.rotary_emb_dim = (
|
||||
rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
|
||||
)
|
||||
self.fused_dense = fused_dense
|
||||
self.flash_attn = flash_attn
|
||||
self.cutlass_attn = cutlass_attn
|
||||
self.flash_rotary = flash_rotary
|
||||
self.return_residual = return_residual
|
||||
self.checkpointing = checkpointing
|
||||
|
||||
if self.rotary_emb_dim > 0:
|
||||
rotary_kwargs = {"device": device}
|
||||
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
|
||||
rotary_kwargs["scale_base"] = rotary_emb_scale_base
|
||||
|
||||
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
|
||||
else:
|
||||
pass
|
||||
|
||||
self.Wqkv = nn.Linear(
|
||||
self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs
|
||||
)
|
||||
self.out_proj = nn.Linear(
|
||||
self.op_size, self.hidden_size, bias=bias, **factory_kwargs
|
||||
)
|
||||
|
||||
self.inner_attn = SelfAttention(
|
||||
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
|
||||
)
|
||||
self.inner_cross_attn = CrossAttention(
|
||||
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
|
||||
)
|
||||
|
||||
def _update_kv_cache(
|
||||
self, kv: torch.FloatTensor, inference_params: InferenceParams
|
||||
) -> None:
|
||||
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
|
||||
Adapted from https://github.com/Dao-AILab/flash-attention."""
|
||||
|
||||
assert (
|
||||
self.layer_idx is not None
|
||||
), "Generation requires layer_idx in the constructor"
|
||||
|
||||
return _update_kv_cache(kv, inference_params, self.layer_idx)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.FloatTensor,
|
||||
x_kv: Optional[torch.FloatTensor] = None,
|
||||
key_padding_mask: Optional[torch.BoolTensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
mixer_subset: Optional[torch.LongTensor] = None,
|
||||
past_cache: Optional[InferenceParams] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
"""Perform the forward pass.
|
||||
|
||||
Args:
|
||||
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
|
||||
cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
|
||||
is the is the sum of the sequence lengths in the batch.
|
||||
x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
|
||||
key_padding_mask: boolean mask, True means to keep, False means to mask out.
|
||||
(batch, seqlen). Only applicable when not using FlashAttention.
|
||||
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
||||
of the sequences in the batch, used to index into x. Only applicable when using
|
||||
FlashAttention.
|
||||
max_seqlen: int. Maximum sequence length in the batch.
|
||||
mixer_subset: for cross-attention only. If not None, will take a subset of x
|
||||
before applying the query projection. Useful for e.g., ViT where we only care
|
||||
about the CLS token in the last layer.
|
||||
past_cache: For generation only.
|
||||
|
||||
Returns:
|
||||
(batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None,
|
||||
else (total, hidden_dim) where total is the is the sum of the sequence lengths
|
||||
in the batch.
|
||||
|
||||
"""
|
||||
|
||||
if cu_seqlens is not None:
|
||||
assert max_seqlen is not None
|
||||
assert key_padding_mask is None
|
||||
assert self.flash_attn
|
||||
# assert self.rotary_emb_dim == 0
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert cu_seqlens is None
|
||||
assert max_seqlen is None
|
||||
assert not self.flash_attn
|
||||
|
||||
if past_cache is not None:
|
||||
assert key_padding_mask is None
|
||||
assert cu_seqlens is None and max_seqlen is None
|
||||
|
||||
attn_kwargs = {"key_padding_mask": key_padding_mask}
|
||||
|
||||
assert x_kv is None and mixer_subset is None
|
||||
|
||||
qkv = self.Wqkv(x)
|
||||
qkv = rearrange(
|
||||
qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
|
||||
)
|
||||
|
||||
if past_cache is None:
|
||||
if self.rotary_emb_dim > 0:
|
||||
qkv = self.rotary_emb(qkv)
|
||||
context = self.inner_attn(
|
||||
qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, **attn_kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
if self.rotary_emb_dim > 0:
|
||||
qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset)
|
||||
q = qkv[:, :, 0]
|
||||
kv = self._update_kv_cache(qkv[:, :, 1:], past_cache)
|
||||
# If we're processing the prompt, causal=None (use self.causal).
|
||||
# If we're decoding, then causal=False.
|
||||
causal = None if past_cache.sequence_len_offset == 0 else False
|
||||
context = self.inner_cross_attn(q, kv, causal=causal)
|
||||
|
||||
out = rearrange(context, "... h d -> ... (h d)")
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out if not self.return_residual else (out, x)
|
||||
|
||||
|
||||
class ParallelBlock(nn.Module):
|
||||
"""Parallel block.
|
||||
|
||||
This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
mixer: Optional[Dict[str, Any]] = None,
|
||||
mlp: Optional[Dict[str, Any]] = None,
|
||||
block_idx: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||
self.block_idx = block_idx
|
||||
|
||||
self.mixer = MHA(config, layer_idx=block_idx)
|
||||
self.mlp = MLP(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
past_cache: Optional[torch.FloatTensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln(hidden_states)
|
||||
|
||||
attn_outputs = self.mixer(
|
||||
hidden_states,
|
||||
past_cache=past_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
if isinstance(attn_outputs, tuple):
|
||||
attn_outputs = attn_outputs[0]
|
||||
|
||||
attn_outputs = self.resid_dropout(attn_outputs)
|
||||
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
|
||||
|
||||
hidden_states = attn_outputs + feed_forward_hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CausalLMHead(nn.Module):
|
||||
"""Causal Language Modeling head.
|
||||
|
||||
Reference:
|
||||
Improving Language Understanding by Generative Pre-Training.
|
||||
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: PretrainedConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
self.linear = nn.Linear(config.n_embd, config.vocab_size)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = self.ln(hidden_states)
|
||||
logits = self.linear(hidden_states).to(torch.float32)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
class CausalLMLoss(nn.Module):
|
||||
"""Causal Language Modeling loss.
|
||||
|
||||
Reference:
|
||||
Improving Language Understanding by Generative Pre-Training.
|
||||
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, shift_labels: Optional[bool] = True) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.shift_labels = shift_labels
|
||||
self.loss_fct = nn.CrossEntropyLoss()
|
||||
|
||||
def forward(
|
||||
self, logits: torch.FloatTensor, labels: torch.LongTensor
|
||||
) -> torch.FloatTensor:
|
||||
if self.shift_labels:
|
||||
logits = logits[..., :-1, :].contiguous()
|
||||
labels = labels[..., 1:].contiguous()
|
||||
|
||||
loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class MixFormerSequentialPreTrainedModel(PreTrainedModel):
|
||||
"""MixFormer (sequential for DeepSpeed) pre-trained model."""
|
||||
|
||||
config_class = MixFormerSequentialConfig
|
||||
base_model_prefix = "transformer"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(self, *inputs, **kwargs) -> None:
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, past_key_values=None, **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
if "use_cache" in kwargs and not kwargs["use_cache"]:
|
||||
return {"input_ids": input_ids}
|
||||
|
||||
if past_key_values is None or not (
|
||||
isinstance(past_key_values, InferenceParams)
|
||||
):
|
||||
past_key_values = InferenceParams(
|
||||
max_batch_size=input_ids.shape[0],
|
||||
max_sequence_len=self.config.n_positions,
|
||||
sequence_len_offset=0,
|
||||
batch_size_offset=0,
|
||||
fused_ft_kernel=False,
|
||||
key_value_memory_dict={},
|
||||
)
|
||||
else:
|
||||
# assume past_key_values has cached all but last token in input_ids
|
||||
past_key_values.sequence_len_offset = len(input_ids[0]) - 1
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
|
||||
|
||||
|
||||
class PackedSequential(nn.Sequential):
|
||||
def forward(
|
||||
self,
|
||||
input,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
max_seqlen: Optional[int] = None,
|
||||
):
|
||||
for module in self:
|
||||
sig = inspect.signature(module.forward)
|
||||
if "cu_seqlens" in sig.parameters:
|
||||
input = module(input, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
|
||||
else:
|
||||
input = module(input)
|
||||
return input
|
||||
|
||||
|
||||
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
|
||||
"""MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
|
||||
|
||||
_keys_to_ignore_on_load_missing = [""]
|
||||
_keys_to_ignore_on_load_unexpected = [
|
||||
r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"
|
||||
]
|
||||
_no_split_modules = ["ParallelBlock"]
|
||||
|
||||
def __init__(self, config: MixFormerSequentialConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
modules = [Embedding(config)]
|
||||
block_config = config.architecture
|
||||
|
||||
if not isinstance(block_config, list):
|
||||
block_config = [block_config for _ in range(config.n_layer)]
|
||||
|
||||
if config.n_layer != len(block_config):
|
||||
config.n_layer = len(block_config)
|
||||
|
||||
for block_idx, block in enumerate(block_config):
|
||||
# `block_cls` with `legacy` value is for backward compatibility
|
||||
# `path` key is for backward compatibility
|
||||
block = copy.deepcopy(block) or {"block_cls": "parallel"}
|
||||
block.pop("path", None) or block.pop("block_cls", None)
|
||||
|
||||
block["block_idx"] = block_idx
|
||||
modules.append(ParallelBlock(config, **block))
|
||||
|
||||
modules.append(CausalLMHead(config))
|
||||
|
||||
self.layers = PackedSequential(*modules)
|
||||
self.loss = CausalLMLoss()
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
return self.layers[0].wte
|
||||
|
||||
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
|
||||
self.layers[0].wte = new_embeddings
|
||||
|
||||
def get_output_embeddings(self) -> nn.Linear:
|
||||
return self.layers[-1].linear
|
||||
|
||||
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
|
||||
self.layers[-1].linear = new_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> CausalLMOutputWithPast:
|
||||
cu_seqlens: Optional[torch.LongTensor] = None
|
||||
max_seqlen: Optional[int] = None
|
||||
if position_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if not past_key_values:
|
||||
lm_logits = self.layers(
|
||||
input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
|
||||
)
|
||||
else:
|
||||
hidden_layer = self.layers[0](input_ids)
|
||||
for module in self.layers[1:-1]:
|
||||
hidden_layer = module(
|
||||
hidden_layer,
|
||||
past_cache=past_key_values,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
lm_logits = self.layers[-1](hidden_layer)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss(lm_logits, labels)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss, logits=lm_logits, past_key_values=past_key_values
|
||||
)
|
||||
1063
src/axolotl/models/phi/modeling_phi.py
Normal file
1063
src/axolotl/models/phi/modeling_phi.py
Normal file
File diff suppressed because it is too large
Load Diff
66
src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
Normal file
66
src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Flash attention monkey patch for cerebras btlm model
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
|
||||
# this is a wonky hack to get the remotely loaded module
|
||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
# we need to load the model here in order for modeling_btlm to be available
|
||||
with init_empty_weights():
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
module_name = model_config.__class__.__module__.replace(
|
||||
".configuration_btlm", ".modeling_btlm"
|
||||
)
|
||||
modeling_btlm = importlib.import_module(module_name)
|
||||
modeling_btlm.BTLMAttention._attn = ( # pylint: disable=protected-access
|
||||
flashattn_attn
|
||||
)
|
||||
|
||||
|
||||
def flashattn_attn(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
value: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
position_bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
softmax_scale = (
|
||||
1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None
|
||||
)
|
||||
|
||||
query = query.permute(0, 2, 1, 3)
|
||||
key = key.permute(0, 2, 1, 3)
|
||||
value = value.permute(0, 2, 1, 3)
|
||||
|
||||
# Perform Flash attention
|
||||
attn_output = flash_attn_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout_p=0.0, # Assuming you have this attribute
|
||||
softmax_scale=softmax_scale, # Set this if you have specific scaling in mind
|
||||
causal=not self.is_cross_attention, # Assuming you have this attribute
|
||||
return_attn_probs=False, # Set this based on your needs
|
||||
)
|
||||
|
||||
# Optional: Apply head mask if it's not None
|
||||
if head_mask is not None:
|
||||
attn_output *= head_mask
|
||||
|
||||
attn_output = attn_output.permute(0, 2, 1, 3)
|
||||
|
||||
return attn_output, None # We don't have explicit attn_weights in Flash attention
|
||||
174
src/axolotl/monkeypatch/fastchat_conversation_turns.py
Normal file
174
src/axolotl/monkeypatch/fastchat_conversation_turns.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
monkeypatch to add a get_turns method
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Generator, Tuple
|
||||
|
||||
from fastchat.conversation import SeparatorStyle
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
|
||||
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
ret = ""
|
||||
for role, msg in self.get_turns():
|
||||
ret += role + msg
|
||||
return ret
|
||||
|
||||
|
||||
def get_turns( # pylint: disable=too-many-return-statements
|
||||
self,
|
||||
) -> Generator[Tuple[str, str], None, None]:
|
||||
"""Get the prompt for generation."""
|
||||
system_prompt = self.system_template.format(system_message=self.system_message)
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt + seps[0]
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role + ": ", message + seps[i % 2]
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ": ", "" # must be end with a space
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
|
||||
yield "", "" if system_prompt == "" else system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", message + self.sep
|
||||
else:
|
||||
yield role + "\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role, message + self.sep
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role, message + seps[i % 2]
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.RWKV:
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role + ": ", message.replace("\r\n", "\n").replace(
|
||||
"\n\n", "\n"
|
||||
) + "\n\n"
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA2:
|
||||
seps = [self.sep, self.sep2]
|
||||
if self.system_message:
|
||||
yield "", system_prompt
|
||||
else:
|
||||
yield "", "[INST] "
|
||||
for i, (role, message) in enumerate(self.messages[1:]):
|
||||
if message:
|
||||
yield role + " ", message + seps[i % 2]
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATGLM:
|
||||
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
||||
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||
round_add_n = 1 if self.name == "chatglm2" else 0
|
||||
if system_prompt:
|
||||
yield "", system_prompt + self.sep
|
||||
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if i % 2 == 0:
|
||||
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
|
||||
|
||||
if message:
|
||||
yield f"{role}:", f"{message}{self.sep}"
|
||||
else:
|
||||
yield f"{role}:", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATML:
|
||||
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", message + self.sep + "\n"
|
||||
else:
|
||||
yield role + "\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATINTERN:
|
||||
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
prefix = "<s>" if i % 2 == 0 else ""
|
||||
if message:
|
||||
yield prefix + role + ":", message + seps[i % 2] + "\n"
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.DOLLY:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
suffix = "\n\n" if i % 2 == 1 else ""
|
||||
yield role + ":\n", message + seps[i % 2] + suffix
|
||||
else:
|
||||
yield role + ":\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.PHOENIX:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", "<s>" + message + "</s>"
|
||||
else:
|
||||
yield role + ": " + "<s>", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ROBIN:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ":\n", message + self.sep
|
||||
else:
|
||||
yield role + ":\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.FALCON_CHAT:
|
||||
if self.system_message:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ":", ""
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
|
||||
def add_get_turns_to_conversation():
|
||||
import fastchat.conversation
|
||||
|
||||
fastchat.conversation.Conversation.get_turns = get_turns
|
||||
fastchat.conversation.Conversation.get_prompt = get_prompt
|
||||
@@ -2,7 +2,9 @@
|
||||
|
||||
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -11,12 +13,18 @@ import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaMLP,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from xformers.ops import SwiGLU
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||
@@ -33,7 +41,36 @@ except ImportError:
|
||||
)
|
||||
|
||||
|
||||
def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def replace_llama_mlp_with_swiglu(model):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaMLP):
|
||||
mlp = FusedMLP(
|
||||
module.config, module.gate_proj, module.up_proj, module.down_proj
|
||||
)
|
||||
set_module_name(model, name, mlp)
|
||||
|
||||
|
||||
def replace_llama_qkv_with_fused(model):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
qkv = FusedAttention(
|
||||
module.config,
|
||||
module.q_proj,
|
||||
module.k_proj,
|
||||
module.v_proj,
|
||||
module.o_proj,
|
||||
)
|
||||
set_module_name(model, name, qkv)
|
||||
|
||||
|
||||
def replace_llama_attn_with_flash_attn(
|
||||
packed: Optional[bool] = False,
|
||||
cross_entropy: Optional[bool] = False,
|
||||
rms_norm: Optional[bool] = False,
|
||||
):
|
||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||
_prepare_decoder_attention_mask
|
||||
)
|
||||
@@ -44,6 +81,124 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
|
||||
llama_model_forward
|
||||
)
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if cross_entropy:
|
||||
try:
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
|
||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||
CrossEntropyLoss, inplace_backward=True
|
||||
)
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
||||
)
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if rms_norm:
|
||||
try:
|
||||
from flash_attn.ops.rms_norm import RMSNorm
|
||||
|
||||
class LlamaRMSNorm(RMSNorm):
|
||||
"""Patched LLamaRMSNorm"""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__(hidden_size, eps=eps)
|
||||
|
||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||
)
|
||||
|
||||
|
||||
class FusedAttention(LlamaAttention):
|
||||
"""
|
||||
Fused QKV Attention layer for incrementally improved training efficiency
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
q: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
k: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
v: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
o: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.init_device = next(iter(q.state_dict().values())).device
|
||||
|
||||
# define equivalent fused qkv projection
|
||||
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
|
||||
self.qkv_proj = torch.nn.Linear(
|
||||
q.in_features, sum(self.out_features), device=self.init_device, bias=False
|
||||
)
|
||||
self.o_proj = o
|
||||
|
||||
# overwrite initialized weights with pretrained weights
|
||||
self.qkv_proj.weight.data = torch.cat(
|
||||
(q.weight.data, k.weight.data, v.weight.data), dim=0
|
||||
)
|
||||
|
||||
def _post_training(self, model, name):
|
||||
q_proj, k_proj, v_proj = torch.split(
|
||||
self.qkv_proj.weight.data, self.out_features, dim=0
|
||||
)
|
||||
|
||||
new_attn = LlamaAttention(self.config)
|
||||
new_attn.q_proj.weight.data = q_proj
|
||||
new_attn.k_proj.weight.data = k_proj
|
||||
new_attn.v_proj.weight.data = v_proj
|
||||
new_attn.o_proj.weight.data = self.o_proj.weight.data
|
||||
|
||||
set_module_name(model, name, new_attn)
|
||||
|
||||
|
||||
class FusedMLP(torch.nn.Module):
|
||||
"""
|
||||
Fused MLP layer for incrementally improved training efficiency
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
gate_proj: torch.nn.Linear,
|
||||
up_proj: torch.nn.Linear,
|
||||
down_proj: torch.nn.Linear,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.swiglu = SwiGLU(
|
||||
in_features=config.hidden_size,
|
||||
hidden_features=config.intermediate_size,
|
||||
bias=False,
|
||||
_pack_weights=True,
|
||||
)
|
||||
# overwrite initialized weights with pretrained weights
|
||||
self.swiglu.w12.weight.data = torch.cat(
|
||||
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
||||
)
|
||||
self.swiglu.w3.weight.data = down_proj.weight.data
|
||||
|
||||
def _post_training(self, model, name):
|
||||
w1, w2 = torch.split( # pylint: disable=invalid-name
|
||||
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
||||
)
|
||||
|
||||
# Assign the split weights back to the original layers
|
||||
new_mlp = LlamaMLP(self.config)
|
||||
new_mlp.gate_proj.weight.data = w1
|
||||
new_mlp.up_proj.weight.data = w2
|
||||
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
||||
|
||||
set_module_name(model, name, new_mlp)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
||||
return self.swiglu(x)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
@@ -66,6 +221,7 @@ def flashattn_forward(
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
@@ -105,9 +261,14 @@ def flashattn_forward(
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
if isinstance(self, FusedAttention):
|
||||
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
|
||||
self.out_features, dim=-1
|
||||
)
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
@@ -160,7 +321,9 @@ def flashattn_forward(
|
||||
# only on first autoregressive step q,k,v have same seqlen
|
||||
is_causal = key_states.shape == query_states.shape
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None:
|
||||
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
@@ -169,7 +332,12 @@ def flashattn_forward(
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
@@ -192,7 +360,7 @@ def flashattn_forward(
|
||||
qkv_unpad,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
0.0,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
)
|
||||
@@ -205,6 +373,7 @@ def flashattn_forward(
|
||||
output = flash_attn_kvpacked_func(
|
||||
query_states,
|
||||
torch.stack([key_states, value_states], 2),
|
||||
dropout_p=dropout_rate,
|
||||
causal=is_causal,
|
||||
)
|
||||
else:
|
||||
@@ -228,6 +397,8 @@ def flashattn_forward(
|
||||
if attention_mask is not None
|
||||
else None,
|
||||
)
|
||||
if q_unpad.dtype != kv_unpad.dtype:
|
||||
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
@@ -235,7 +406,7 @@ def flashattn_forward(
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
0.0,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
)
|
||||
@@ -441,6 +612,13 @@ def llama_model_forward(
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
padding_mask = None
|
||||
else:
|
||||
if 0 in attention_mask:
|
||||
padding_mask = attention_mask
|
||||
else:
|
||||
padding_mask = None
|
||||
|
||||
attention_mask = (
|
||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
@@ -475,7 +653,9 @@ def llama_model_forward(
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs)
|
||||
return module(
|
||||
*inputs,
|
||||
)
|
||||
|
||||
return custom_forward
|
||||
|
||||
@@ -484,9 +664,10 @@ def llama_model_forward(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
None,
|
||||
padding_mask,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
@@ -498,6 +679,7 @@ def llama_model_forward(
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
@@ -544,6 +726,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[
|
||||
@@ -576,6 +759,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
@@ -25,6 +25,8 @@ def sdp_attention_forward(
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@@ -29,6 +29,8 @@ def xformers_forward(
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
643
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
Normal file
643
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
Normal file
@@ -0,0 +1,643 @@
|
||||
"""Flash attention monkey patch for mistral model"""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||
flash_attn_kvpacked_func,
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralAttention as OriginalMistralAttention,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
||||
|
||||
|
||||
def replace_mistral_attn_with_flash_attn(
|
||||
packed: Optional[bool] = False,
|
||||
):
|
||||
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||
_prepare_decoder_attention_mask
|
||||
)
|
||||
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
||||
flashattn_forward
|
||||
)
|
||||
if packed:
|
||||
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
||||
MistralDecoderLayer
|
||||
)
|
||||
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
|
||||
mistral_model_forward
|
||||
)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _make_sliding_window_causal_mask(
|
||||
bsz: int,
|
||||
tgt_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
past_key_values_length: int = 0,
|
||||
sliding_window: int = 4096,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for sliding window attention
|
||||
"""
|
||||
tensor = torch.full(
|
||||
(tgt_len, tgt_len),
|
||||
fill_value=1,
|
||||
device=device,
|
||||
)
|
||||
mask = torch.tril(tensor, diagonal=0)
|
||||
# make the mask banded to account for sliding window
|
||||
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
|
||||
mask = torch.triu(mask, diagonal=-sliding_window + 1)
|
||||
mask = torch.log(mask).to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
tgt_len, past_key_values_length, dtype=dtype, device=device
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return mask[None, None, :, :].expand(
|
||||
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
||||
)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
input_shape,
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window,
|
||||
): # pylint: disable=unused-argument
|
||||
# [bsz, seq_len]
|
||||
if attention_mask is None:
|
||||
return attention_mask
|
||||
|
||||
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
||||
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
|
||||
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
|
||||
sliding_window_mask = _make_sliding_window_causal_mask(
|
||||
bsz=input_shape[0],
|
||||
tgt_len=input_shape[1],
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
past_key_values_length=past_key_values_length,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
attention_mask = attention_mask + sliding_window_mask
|
||||
else:
|
||||
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
def flashattn_forward(
|
||||
self: OriginalMistralAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
|
||||
use_sliding_windows = (
|
||||
hasattr(self.config, "sliding_window") is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
)
|
||||
|
||||
if use_sliding_windows:
|
||||
window_size = (self.config.sliding_window, self.config.sliding_window)
|
||||
else:
|
||||
window_size = (-1, -1)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||
if (
|
||||
hasattr(self.config, "sliding_window")
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
):
|
||||
slicing_tokens = kv_seq_len - self.config.sliding_window
|
||||
|
||||
past_key = past_key_value[0]
|
||||
past_value = past_key_value[1]
|
||||
|
||||
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||
|
||||
if past_key.shape[-2] != self.config.sliding_window - 1:
|
||||
raise ValueError(
|
||||
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
||||
f" {past_key.shape}"
|
||||
)
|
||||
|
||||
past_key_value = (past_key, past_value) if use_cache else None
|
||||
|
||||
if past_key_value is not None:
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if self.training:
|
||||
# during training q,k,v always have same seqlen
|
||||
assert key_states.shape == query_states.shape
|
||||
is_causal = True
|
||||
else:
|
||||
# turn off FA causal mask after first inference autoregressive iteration
|
||||
# only on first autoregressive step q,k,v have same seqlen
|
||||
is_causal = key_states.shape == query_states.shape
|
||||
|
||||
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
qkvpacked=True,
|
||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||
# the attention_mask should be the same as the key_padding_mask
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None,
|
||||
)
|
||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||
qkv_unpad,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
else:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
if attention_mask is None or attention_mask.all().item():
|
||||
output = flash_attn_kvpacked_func(
|
||||
query_states,
|
||||
torch.stack([key_states, value_states], 2),
|
||||
dropout_p=dropout_rate,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
else:
|
||||
( # pylint: disable=unbalanced-tuple-unpacking
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
_,
|
||||
_,
|
||||
output_pad_fn,
|
||||
) = generate_qkv(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
kvpacked=True,
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None,
|
||||
)
|
||||
if q_unpad.dtype != kv_unpad.dtype:
|
||||
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
|
||||
attn_output = output
|
||||
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
||||
def generate_qkv(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_padding_mask=None,
|
||||
key_padding_mask=None,
|
||||
kvpacked=False,
|
||||
qkvpacked=False,
|
||||
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seqlen_q, nheads, d)
|
||||
k: (batch_size, seqlen_k, nheads_k, d)
|
||||
v: (batch_size, seqlen_k, nheads_k, d)
|
||||
query_padding_mask: (batch_size, seqlen), bool
|
||||
key_padding_mask: (batch_size, seqlen), bool
|
||||
"""
|
||||
assert not (kvpacked and qkvpacked)
|
||||
batch_size, seqlen_q, nheads, d = q.shape
|
||||
_, seqlen_k, nheads_k, _ = k.shape
|
||||
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||
|
||||
if query_padding_mask is not None:
|
||||
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
||||
q, query_padding_mask
|
||||
)
|
||||
|
||||
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
|
||||
output_unpad, indices_q, batch_size, seqlen_q
|
||||
)
|
||||
|
||||
else:
|
||||
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||
cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen_q,
|
||||
step=seqlen_q,
|
||||
dtype=torch.int32,
|
||||
device=q_unpad.device,
|
||||
)
|
||||
max_seqlen_q = seqlen_q
|
||||
|
||||
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
|
||||
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||
)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
||||
else:
|
||||
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
||||
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
||||
cu_seqlens_k = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen_k,
|
||||
step=seqlen_k,
|
||||
dtype=torch.int32,
|
||||
device=k_unpad.device,
|
||||
)
|
||||
max_seqlen_k = seqlen_k
|
||||
|
||||
if qkvpacked:
|
||||
assert nheads == nheads_k
|
||||
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
||||
qkv = torch.stack([q, k, v], dim=2)
|
||||
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
||||
|
||||
if kvpacked:
|
||||
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
||||
kv = torch.stack([k, v], dim=2)
|
||||
return (
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
kv,
|
||||
output_pad_fn,
|
||||
)
|
||||
|
||||
return (
|
||||
q_unpad,
|
||||
k_unpad,
|
||||
v_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
output_pad_fn,
|
||||
)
|
||||
|
||||
|
||||
def mistral_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
if input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
cu_seqlens = None
|
||||
max_seqlen = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
attention_mask = (
|
||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
transformers.logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
None,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class MistralDecoderLayer(OriginalMistralDecoderLayer):
|
||||
"""
|
||||
patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[
|
||||
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
||||
]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
22
src/axolotl/monkeypatch/mixtral/__init__.py
Normal file
22
src/axolotl/monkeypatch/mixtral/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Patches to support multipack for mixtral
|
||||
"""
|
||||
import transformers
|
||||
|
||||
|
||||
def replace_mixtral_attn_with_multipack_flash_attn():
|
||||
from .modeling_mixtral import (
|
||||
MixtralMultipackFlashAttention2,
|
||||
mixtral_decoder_layer_forward,
|
||||
mixtral_model_forward,
|
||||
)
|
||||
|
||||
transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
|
||||
mixtral_decoder_layer_forward
|
||||
)
|
||||
transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
|
||||
mixtral_model_forward
|
||||
)
|
||||
transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[
|
||||
"flash_attention_2"
|
||||
] = MixtralMultipackFlashAttention2
|
||||
379
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
Normal file
379
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""
|
||||
Mixtral modeling for multipack
|
||||
"""
|
||||
# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
|
||||
import logging
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from flash_attn import flash_attn_varlen_qkvpacked_func
|
||||
from transformers import Cache, DynamicCache
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.modeling_outputs import MoeModelOutputWithPast
|
||||
from transformers.models.mixtral.modeling_mixtral import (
|
||||
MixtralFlashAttention2,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
|
||||
|
||||
|
||||
class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
|
||||
"""
|
||||
Custom multipack implementation w flash attention 2
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._flash_attn_uses_top_left_mask = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
attn_output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=self.attention_dropout,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def mixtral_decoder_layer_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_router_logits: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_router_logits (`bool`, *optional*):
|
||||
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
||||
should not be returned during inference.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
if output_router_logits:
|
||||
outputs += (router_logits,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def mixtral_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, MoeModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_router_logits = (
|
||||
output_router_logits
|
||||
if output_router_logits is not None
|
||||
else self.config.output_router_logits
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
if input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
if use_cache:
|
||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||
if use_legacy_cache:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||
|
||||
cu_seqlens = None
|
||||
max_seqlen = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = (
|
||||
attention_mask
|
||||
if (attention_mask is not None and 0 in attention_mask)
|
||||
else None
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
LOG.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_router_logits = () if output_router_logits else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
output_router_logits,
|
||||
use_cache,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
output_router_logits=output_router_logits,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if output_router_logits:
|
||||
all_router_logits += (layer_outputs[-1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = (
|
||||
next_decoder_cache.to_legacy_cache()
|
||||
if use_legacy_cache
|
||||
else next_decoder_cache
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
all_router_logits,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return MoeModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
router_logits=all_router_logits,
|
||||
)
|
||||
65
src/axolotl/monkeypatch/neft_embeddings.py
Normal file
65
src/axolotl/monkeypatch/neft_embeddings.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||
"""
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
|
||||
def patch_neft(alpha, model):
|
||||
embeddings = None
|
||||
if isinstance(model, PreTrainedModel):
|
||||
embeddings = model.get_input_embeddings()
|
||||
if isinstance(model, PeftModel):
|
||||
embeddings = model.base_model.get_input_embeddings()
|
||||
if not embeddings:
|
||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
||||
embeddings.noisy_embedding_alpha = alpha
|
||||
old_forward = embeddings.forward
|
||||
|
||||
# This hack seems to be needed to properly use a custom forward pass
|
||||
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
|
||||
bound_method = neft_forward.__get__( # pylint: disable=no-value-for-parameter
|
||||
embeddings, embeddings.__class__
|
||||
)
|
||||
setattr(embeddings, "forward", bound_method)
|
||||
|
||||
embeddings._old_forward = old_forward # pylint: disable=protected-access
|
||||
return model
|
||||
|
||||
|
||||
def unpatch_neft(model):
|
||||
embeddings = None
|
||||
if isinstance(model, PreTrainedModel):
|
||||
embeddings = model.get_input_embeddings()
|
||||
if isinstance(model, PeftModel):
|
||||
embeddings = model.base_model.get_input_embeddings()
|
||||
if not embeddings:
|
||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
||||
if hasattr(embeddings, "_old_forward"):
|
||||
embeddings.forward = embeddings._old_forward # pylint: disable=protected-access
|
||||
del embeddings._old_forward # pylint: disable=protected-access
|
||||
del embeddings.noisy_embedding_alpha
|
||||
|
||||
|
||||
def neft_forward(self, inputs: torch.Tensor):
|
||||
embeddings = self._old_forward(inputs) # pylint: disable=protected-access
|
||||
|
||||
if self.training:
|
||||
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
|
||||
mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
|
||||
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
|
||||
-mag_norm, mag_norm
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def pretrain_hook(cfg, trainer):
|
||||
if cfg.noisy_embedding_alpha:
|
||||
trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
|
||||
|
||||
|
||||
def post_train_hook(cfg, trainer):
|
||||
if cfg.noisy_embedding_alpha:
|
||||
unpatch_neft(trainer.model)
|
||||
415
src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py
Normal file
415
src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py
Normal file
@@ -0,0 +1,415 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# This code is based off the following work:
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||
""" PyTorch StableLM Epoch model. """
|
||||
import importlib
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from accelerate import init_empty_weights
|
||||
from einops import rearrange
|
||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from torch import nn
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.utils import logging
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"):
|
||||
# this is a wonky hack to get the remotely loaded module
|
||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
# we need to load the model here in order for modeling_stablelm_epoch to be available
|
||||
with init_empty_weights():
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
module_name = model_config.__class__.__module__.replace(
|
||||
".configuration_stablelm_epoch", ".modeling_stablelm_epoch"
|
||||
)
|
||||
modeling_stablelm = importlib.import_module(module_name)
|
||||
modeling_stablelm.Attention.forward = ( # pylint: disable=protected-access
|
||||
flashattn_attn
|
||||
)
|
||||
modeling_stablelm.StableLMEpochModel.forward = ( # pylint: disable=protected-access
|
||||
stablelm_model_forward
|
||||
)
|
||||
modeling_stablelm.DecoderLayer.forward = ( # pylint: disable=protected-access
|
||||
decoder_layer_forward
|
||||
)
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
# pylint: disable=invalid-name
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||
# pylint: disable=invalid-name
|
||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
|
||||
sin = sin[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(
|
||||
batch, num_key_value_heads, n_rep, slen, head_dim
|
||||
)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def flashattn_attn(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
attention_mask: torch.FloatTensor,
|
||||
position_ids: torch.LongTensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False, # pylint: disable=unused-argument
|
||||
use_cache: Optional[bool] = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
query_rot = query_states[..., : self.rotary_ndims]
|
||||
query_pass = query_states[..., self.rotary_ndims :]
|
||||
key_rot = key_states[..., : self.rotary_ndims]
|
||||
key_pass = key_states[..., self.rotary_ndims :]
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_rot, key_rot, cos, sin, position_ids
|
||||
)
|
||||
|
||||
# [batch_size, num_heads, seq_len, head_dim]
|
||||
query_states = torch.cat((query_states, query_pass), dim=-1)
|
||||
key_states = torch.cat((key_states, key_pass), dim=-1)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Reuse k, v, self_attention
|
||||
key_states = torch.cat((past_key_value[0], key_states), dim=2)
|
||||
value_states = torch.cat((past_key_value[1], value_states), dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# Repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
softmax_scale = None
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=softmax_scale, causal=True
|
||||
)
|
||||
|
||||
attn_output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||
else:
|
||||
attn_weights = torch.matmul(
|
||||
query_states, key_states.transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# Upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
# Merge heads
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
# Final linear projection
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
def decoder_layer_forward(
|
||||
self,
|
||||
hidden_states: Optional[torch.FloatTensor],
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Union[
|
||||
Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]
|
||||
]:
|
||||
# pylint: disable=duplicate-code
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def stablelm_model_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
# pylint: disable=duplicate-code
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# Retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
if input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
cu_seqlens = None
|
||||
max_seqlen = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# Embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
attention_mask = (
|
||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# Decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
None,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# Add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
@@ -101,3 +101,16 @@ def get_cu_seqlens_from_pos_ids(position_ids):
|
||||
max_seq_lens.append(max_seq_len)
|
||||
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
|
||||
|
||||
def set_module_name(model, name, value):
|
||||
if "." in name:
|
||||
parent_name = name.rsplit(".", 1)[0]
|
||||
child_name = name[len(parent_name) + 1 :]
|
||||
parent = model.get_submodule(parent_name)
|
||||
else:
|
||||
parent_name = ""
|
||||
parent = model
|
||||
child_name = name
|
||||
|
||||
setattr(parent, child_name, value)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Module to load prompt strategies."""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
||||
|
||||
@@ -16,6 +17,10 @@ def load(strategy, tokenizer, cfg, ds_cfg):
|
||||
load_kwargs = {}
|
||||
if strategy == "user_defined":
|
||||
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
|
||||
else:
|
||||
sig = inspect.signature(func)
|
||||
if "ds_cfg" in sig.parameters:
|
||||
load_kwargs["ds_cfg"] = ds_cfg
|
||||
return func(tokenizer, cfg, **load_kwargs)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return None
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
|
||||
"""Module for Alpaca prompt strategy classes"""
|
||||
|
||||
from typing import Tuple
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
@@ -9,9 +9,13 @@ from axolotl.prompt_tokenizers import (
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
prompt_style = PromptStyle.CHAT.value
|
||||
if ds_cfg and "conversation" in ds_cfg:
|
||||
prompt_style = ds_cfg["conversation"]
|
||||
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
||||
AlpacaPrompter(prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user