Compare commits
144 Commits
neft-v2
...
NanoCode01
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ecc3a408c | ||
|
|
9ca358b671 | ||
|
|
553c80f79a | ||
|
|
eb4c99431b | ||
|
|
cbdbf9e6e5 | ||
|
|
bdfefaf054 | ||
|
|
63fb3eb426 | ||
|
|
31d23504a5 | ||
|
|
f243c2186d | ||
|
|
59b2d302c8 | ||
|
|
bcc78d8fa3 | ||
|
|
74532ddc45 | ||
|
|
8ba27f3bde | ||
|
|
a3e8783328 | ||
|
|
b31038aae9 | ||
|
|
c75f916745 | ||
|
|
4d2e842e46 | ||
|
|
3678a6c41d | ||
|
|
f8ae59b0a8 | ||
|
|
4f4d638b84 | ||
|
|
ba043a361e | ||
|
|
41353d2ea0 | ||
|
|
f6ecf14dd4 | ||
|
|
dec66d7c53 | ||
|
|
76357dc5da | ||
|
|
70b46ca4f4 | ||
|
|
85dd4d525b | ||
|
|
384b817dc0 | ||
|
|
db9094df0f | ||
|
|
6ef46f8dca | ||
|
|
628b754824 | ||
|
|
37820f6540 | ||
|
|
7d4185ffcb | ||
|
|
93ebec1ac5 | ||
|
|
2e61dc3180 | ||
|
|
1ffa3866f2 | ||
|
|
62ba1609b6 | ||
|
|
7bbaac98f7 | ||
|
|
161bcb6517 | ||
|
|
d25c34caa6 | ||
|
|
13e938149d | ||
|
|
85de004dd4 | ||
|
|
80ec7af358 | ||
|
|
f28e75513b | ||
|
|
5ada140ff0 | ||
|
|
712fd27b3f | ||
|
|
ef24342538 | ||
|
|
5ea3aa31f0 | ||
|
|
f1f60cb5b2 | ||
|
|
450e04d3c4 | ||
|
|
b0cf397ecb | ||
|
|
5f79b8242f | ||
|
|
f1de29dd1e | ||
|
|
7fabc4d95e | ||
|
|
9a5eb3990c | ||
|
|
86487c2e96 | ||
|
|
35f9b0f149 | ||
|
|
68b227a7d8 | ||
|
|
03c6318ba3 | ||
|
|
40a6362c92 | ||
|
|
d339beb9d9 | ||
|
|
fde091cb12 | ||
|
|
06ae39200b | ||
|
|
a581e9f8f6 | ||
|
|
992e742cdc | ||
|
|
a1da39cd48 | ||
|
|
58ec8b1113 | ||
|
|
476a205cea | ||
|
|
3e3229e2d9 | ||
|
|
1d21aa6b0a | ||
|
|
71b7ea3c05 | ||
|
|
a48dbf6561 | ||
|
|
6a4562ac08 | ||
|
|
1115c501b8 | ||
|
|
7ee3c4cacb | ||
|
|
fb12895a17 | ||
|
|
9fc29e082b | ||
|
|
575a082aae | ||
|
|
ddf815022a | ||
|
|
9bf854e59c | ||
|
|
797f3dd1de | ||
|
|
0de1457189 | ||
|
|
3cc67d2cdd | ||
|
|
1bc11868eb | ||
|
|
b3a61e8ce2 | ||
|
|
8a8d1c4023 | ||
|
|
332984db18 | ||
|
|
48630f5b34 | ||
|
|
b33c1d55a2 | ||
|
|
0c2a630326 | ||
|
|
db8a8afcba | ||
|
|
14706504e3 | ||
|
|
501b4d1379 | ||
|
|
306fe19c54 | ||
|
|
614cff4107 | ||
|
|
1a6309c8a6 | ||
|
|
105d0b350b | ||
|
|
f544ab2bed | ||
|
|
641e6f7e51 | ||
|
|
6dc68a653f | ||
|
|
7de6a5639c | ||
|
|
c74f045ba7 | ||
|
|
0402d19759 | ||
|
|
b2430ce670 | ||
|
|
4c834bf25d | ||
|
|
8056ecd30e | ||
|
|
738a057674 | ||
|
|
cdc71f73c8 | ||
|
|
6459ac7357 | ||
|
|
964d858da0 | ||
|
|
10388a8daf | ||
|
|
9f7e8a971d | ||
|
|
637ed095a0 | ||
|
|
827ec3d274 | ||
|
|
8b79ff0e94 | ||
|
|
0800885e2f | ||
|
|
d3193beac3 | ||
|
|
2e71ff03a6 | ||
|
|
facc49f32b | ||
|
|
e50ab072e2 | ||
|
|
05bd6f1122 | ||
|
|
20aa4b57d2 | ||
|
|
11d1d607db | ||
|
|
6c81c61bc4 | ||
|
|
9b43e7ea15 | ||
|
|
2d8def68dc | ||
|
|
44c9d0151a | ||
|
|
ca84cca2c0 | ||
|
|
32eeeb5b64 | ||
|
|
afedc470bd | ||
|
|
9923b72649 | ||
|
|
21cf09b608 | ||
|
|
15d3a654bf | ||
|
|
a21935f07a | ||
|
|
8966a6f566 | ||
|
|
e4d1585c4e | ||
|
|
70157ccb8f | ||
|
|
3a99495b05 | ||
|
|
440c3ab527 | ||
|
|
992d57f20a | ||
|
|
91a016f410 | ||
|
|
a045db0214 | ||
|
|
e1b214c62b | ||
|
|
3553172e3c |
7
.github/workflows/base.yml
vendored
7
.github/workflows/base.yml
vendored
@@ -28,7 +28,12 @@ jobs:
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
pytorch: 2.1.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
- cuda: "121"
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
56
.github/workflows/main.yml
vendored
56
.github/workflows/main.yml
vendored
@@ -23,39 +23,60 @@ jobs:
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
pytorch: 2.1.1
|
||||
axolotl_extras:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.1
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
- name: Docker metadata
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v3
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: winglian/axolotl
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v4
|
||||
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
|
||||
- name: Build and export to Docker
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
load: true
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
tags: |
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
- name: Unit Tests
|
||||
run: |
|
||||
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
- name: Push to Docker Hub
|
||||
if: github.event_name != 'pull_request'
|
||||
run: |
|
||||
docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
if [ -n "$latest_tag" ]; then
|
||||
docker push "$latest_tag"
|
||||
fi
|
||||
|
||||
build-axolotl-runpod:
|
||||
needs: build-axolotl
|
||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||
@@ -77,26 +98,31 @@ jobs:
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
pytorch: 2.1.1
|
||||
axolotl_extras:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.1
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
- name: Docker metadata
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v3
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: winglian/axolotl-runpod
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v4
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
|
||||
46
.github/workflows/tests-docker.yml
vendored
Normal file
46
.github/workflows/tests-docker.yml
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
name: e2e-docker-tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
- '.github/workflows/*.yml'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build-axolotl:
|
||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.1
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Build Docker image
|
||||
run: |
|
||||
# Set up build arguments
|
||||
BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}"
|
||||
CUDA="${{ matrix.cuda }}"
|
||||
PYTORCH_VERSION="${{ matrix.pytorch }}"
|
||||
# Build the Docker image
|
||||
docker build . \
|
||||
--file ./docker/Dockerfile \
|
||||
--build-arg BASE_TAG=$BASE_TAG \
|
||||
--build-arg CUDA=$CUDA \
|
||||
--build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
|
||||
--tag test-axolotl
|
||||
- name: Unit Tests w docker image
|
||||
run: |
|
||||
docker run --rm test-axolotl pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
3
.github/workflows/tests.yml
vendored
3
.github/workflows/tests.yml
vendored
@@ -71,8 +71,9 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
|
||||
pip3 uninstall -y transformers accelerate
|
||||
pip3 install -U -e .[flash-attn]
|
||||
pip3 install -U -e .[flash-attn,mamba-ssm]
|
||||
pip3 install -r requirements-tests.txt
|
||||
|
||||
- name: Run e2e tests
|
||||
|
||||
@@ -8,6 +8,9 @@ ignore_missing_imports = True
|
||||
[mypy-axolotl.monkeypatch.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-axolotl.models.mixtral.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-axolotl.models.phi.*]
|
||||
ignore_errors = True
|
||||
|
||||
|
||||
317
README.md
317
README.md
@@ -25,17 +25,20 @@ Features:
|
||||
- [Installation](#installation)
|
||||
- [Docker](#docker)
|
||||
- [Conda/Pip venv](#condapip-venv)
|
||||
- [Runpod](#runpod)
|
||||
- [LambdaLabs](#lambdalabs)
|
||||
- [Windows](#windows)
|
||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||
- [Dataset](#dataset)
|
||||
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
||||
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
||||
- [Config](#config)
|
||||
- [Train](#train)
|
||||
- [Training w/ Deepspeed](#training-with-deepspeed)
|
||||
- [Inference](#inference)
|
||||
- [Merge LORA to Base](#merge-lora-to-base)
|
||||
- [Special Tokens](#special-tokens)
|
||||
- [Common Errors](#common-errors-)
|
||||
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
|
||||
- [Need Help?](#need-help-)
|
||||
- [Badge](#badge-)
|
||||
- [Community Showcase](#community-showcase)
|
||||
@@ -64,17 +67,21 @@ Features:
|
||||
|
||||
## Axolotl supports
|
||||
|
||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||
|----------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
|
||||
|
||||
## Quickstart ⚡
|
||||
@@ -83,20 +90,29 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
||||
|
||||
**Requirements**: Python >=3.9 and Pytorch >=2.0.
|
||||
|
||||
`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
|
||||
|
||||
### For developers
|
||||
```bash
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||
```
|
||||
|
||||
### Usage
|
||||
```bash
|
||||
# finetune lora
|
||||
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||
|
||||
# inference
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--lora_model_dir="./lora-out"
|
||||
|
||||
# gradio
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--lora_model_dir="./lora-out" --gradio
|
||||
```
|
||||
|
||||
## Installation
|
||||
@@ -107,7 +123,6 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
- `winglian/axolotl-runpod:main-latest`: for runpod or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
|
||||
Or run on the current files for development:
|
||||
|
||||
@@ -115,6 +130,27 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Docker advanced</summary>
|
||||
|
||||
A more powerful Docker command to run would be this:
|
||||
|
||||
```bash
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
|
||||
It additionally:
|
||||
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
|
||||
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
|
||||
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
|
||||
* The `--privileged` flag gives all capabilities to the container.
|
||||
* The `--shm-size 10g` argument increases the shared memory size. Use this if you see `exitcode: -7` errors using deepspeed.
|
||||
|
||||
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
|
||||
|
||||
</details>
|
||||
|
||||
#### Conda/Pip venv
|
||||
1. Install python >=**3.9**
|
||||
|
||||
@@ -131,6 +167,10 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
```
|
||||
Get the token at huggingface.co/settings/tokens
|
||||
|
||||
#### Runpod
|
||||
|
||||
Use `winglian/axolotl-runpod:main-latest` or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
|
||||
#### LambdaLabs
|
||||
<details>
|
||||
|
||||
@@ -178,6 +218,28 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
#### Windows
|
||||
Please use WSL or Docker!
|
||||
|
||||
|
||||
#### Launching on public clouds via SkyPilot
|
||||
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
|
||||
```bash
|
||||
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
|
||||
sky check
|
||||
```
|
||||
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
|
||||
```
|
||||
git clone https://github.com/skypilot-org/skypilot.git
|
||||
cd skypilot/llm/axolotl
|
||||
```
|
||||
Use one command to launch:
|
||||
```bash
|
||||
# On-demand
|
||||
HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
|
||||
|
||||
# Managed spot (auto-recovery on preemption)
|
||||
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
||||
```
|
||||
|
||||
|
||||
### Dataset
|
||||
|
||||
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
|
||||
@@ -187,10 +249,17 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "output": "..."}
|
||||
```
|
||||
- `sharegpt`: conversations where `from` is `human`/`gpt`
|
||||
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
|
||||
```json
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
- `llama-2`: the json is the same format as `sharegpt` above, with the following config (see the [config section](#config) for more details)
|
||||
```yml
|
||||
datasets:
|
||||
- path: <your-path>
|
||||
type: sharegpt
|
||||
conversation: llama-2
|
||||
```
|
||||
- `completion`: raw corpus
|
||||
```json
|
||||
{"text": "..."}
|
||||
@@ -297,25 +366,24 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
|
||||
#### How to add custom prompts
|
||||
|
||||
Using yaml. Example:
|
||||
For a dataset that is preprocessed for instruction purposes:
|
||||
|
||||
```json
|
||||
{"instruction": "...", "output": "..."}
|
||||
```
|
||||
|
||||
You can use this example in your YAML config:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: repo
|
||||
type:
|
||||
system_prompt: ""
|
||||
no_input_format: |-
|
||||
User: {instruction}<|end_of_turn|>
|
||||
Assistant:
|
||||
format: |-
|
||||
User: {instruction}
|
||||
{input}<|end_of_turn|>
|
||||
Assistant:
|
||||
field_system: system
|
||||
format: "[INST] {instruction} [/INST]"
|
||||
no_input_format: "[INST] {instruction} [/INST]"
|
||||
```
|
||||
|
||||
Using file:
|
||||
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
|
||||
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
|
||||
|
||||
#### How to use your custom pretokenized dataset
|
||||
|
||||
- Do not pass a `type:`
|
||||
@@ -357,6 +425,13 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- typescript
|
||||
type: ... # unimplemented custom format
|
||||
|
||||
# fastchat conversation
|
||||
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
datasets:
|
||||
- path: ...
|
||||
type: sharegpt
|
||||
conversation: chatml
|
||||
|
||||
# local
|
||||
datasets:
|
||||
- path: data.jsonl # or json
|
||||
@@ -368,6 +443,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- path: knowrohit07/know_sql
|
||||
type: context_qa.load_v2
|
||||
train_on_split: validation
|
||||
|
||||
# loading from s3 or gcs
|
||||
# s3 creds will be loaded from the system default and gcs only supports public access
|
||||
dataset:
|
||||
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
|
||||
...
|
||||
```
|
||||
|
||||
- loading
|
||||
@@ -395,7 +476,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
|
||||
<details>
|
||||
|
||||
<summary>All yaml options</summary>
|
||||
<summary>All yaml options (click me)</summary>
|
||||
|
||||
```yaml
|
||||
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
@@ -430,6 +511,23 @@ is_falcon_derived_model:
|
||||
is_llama_derived_model:
|
||||
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
||||
is_mistral_derived_model:
|
||||
is_qwen_derived_model:
|
||||
|
||||
# optional overrides to the base model configuration
|
||||
model_config:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
# optional overrides to the bnb 4bit quantization configuration
|
||||
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
|
||||
bnb_config_kwargs:
|
||||
# These are default values
|
||||
llm_int8_has_fp16_weight: false
|
||||
bnb_4bit_quant_type: nf4
|
||||
bnb_4bit_use_double_quant: true
|
||||
|
||||
|
||||
# Whether you are training a 4-bit GPTQ quantized model
|
||||
gptq: true
|
||||
@@ -452,9 +550,14 @@ tf32: true # require >=ampere
|
||||
bfloat16: true # require >=ampere
|
||||
float16: true
|
||||
|
||||
# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset
|
||||
gpu_memory_limit: 20GiB
|
||||
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
|
||||
lora_on_cpu: true
|
||||
|
||||
# A list of one or more datasets to finetune the model with
|
||||
datasets:
|
||||
# HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
|
||||
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
@@ -462,7 +565,12 @@ datasets:
|
||||
data_files: # Optional[str] path to source data files
|
||||
shards: # Optional[int] number of shards to split data into
|
||||
name: # Optional[str] name of dataset configuration to load
|
||||
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
train_on_split: train # Optional[str] name of dataset split to load from
|
||||
|
||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
field_human: # Optional[str]. Human key to use for conversation.
|
||||
field_model: # Optional[str]. Assistant key to use for conversation.
|
||||
|
||||
# Custom user prompt
|
||||
- path: repo
|
||||
@@ -486,6 +594,9 @@ datasets:
|
||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||
field:
|
||||
|
||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||
# Currently supports chatml and inst (mistral/mixtral)
|
||||
chat_template: chatml
|
||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
@@ -528,10 +639,17 @@ eval_sample_packing:
|
||||
sample_packing_eff_est:
|
||||
total_num_tokens:
|
||||
|
||||
# Passed through to transformers when loading the model when launched without accelerate
|
||||
# Use `sequential` when training w/ model parallelism to limit memory
|
||||
device_map:
|
||||
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
|
||||
max_memory:
|
||||
|
||||
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
adapter: lora
|
||||
# If you already have a lora model trained that you want to load, put that here.
|
||||
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
|
||||
# This means after training, if you want to test the model, you should set this to the value of `output_dir`.
|
||||
# Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`.
|
||||
lora_model_dir:
|
||||
|
||||
# LoRA hyperparameters
|
||||
@@ -558,10 +676,6 @@ lora_modules_to_save:
|
||||
# - embed_tokens
|
||||
# - lm_head
|
||||
|
||||
# Once you complete training, the model will be saved to the following directory.
|
||||
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
|
||||
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
|
||||
lora_out_dir:
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
# ReLoRA configuration
|
||||
@@ -571,11 +685,13 @@ relora_warmup_steps: # Number of per-restart warmup steps
|
||||
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||
|
||||
# wandb configuration if you're using it
|
||||
# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
|
||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||
wandb_project: # Your wandb project name
|
||||
wandb_entity: # A wandb Team name if using a Team
|
||||
wandb_watch:
|
||||
wandb_run_id: # Set the name of your wandb run
|
||||
wandb_name: # Set the name of your wandb run
|
||||
wandb_run_id: # Set the ID of your wandb run
|
||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||
|
||||
# Where to save the full-finetuned model to
|
||||
@@ -592,14 +708,17 @@ gradient_accumulation_steps: 1
|
||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||
micro_batch_size: 2
|
||||
eval_batch_size:
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
num_epochs: 4
|
||||
warmup_steps: 100 # cannot use with warmup_ratio
|
||||
warmup_ratio: 0.05 # cannot use with warmup_steps
|
||||
learning_rate: 0.00003
|
||||
lr_quadratic_warmup:
|
||||
logging_steps:
|
||||
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
||||
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
|
||||
save_strategy: # Set to `no` to skip checkpoint saves
|
||||
save_steps: # Leave empty to save at each epoch
|
||||
eval_steps: # Leave empty to eval at each epoch
|
||||
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
|
||||
save_total_limit: # Checkpoints saved at a time
|
||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||
# if both are set, num_epochs will not be guaranteed.
|
||||
@@ -609,6 +728,9 @@ max_steps:
|
||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
|
||||
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
||||
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
||||
|
||||
# Save model as safetensors (require safetensors package)
|
||||
save_safetensors:
|
||||
|
||||
@@ -621,6 +743,9 @@ group_by_length: false
|
||||
|
||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
gradient_checkpointing: false
|
||||
# additional kwargs to pass to the trainer for gradient checkpointing
|
||||
# gradient_checkpointing_kwargs:
|
||||
# use_reentrant: false
|
||||
|
||||
# Stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
@@ -675,7 +800,7 @@ max_grad_norm:
|
||||
# Augmentation techniques
|
||||
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
||||
# currently only supported on Llama and Mistral
|
||||
noisy_embedding_alpha:
|
||||
neftune_noise_alpha:
|
||||
|
||||
# Whether to bettertransformers
|
||||
flash_optimum:
|
||||
@@ -685,18 +810,11 @@ xformers_attention:
|
||||
flash_attention:
|
||||
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
|
||||
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
||||
# Whether to use scaled-dot-product attention
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
sdp_attention:
|
||||
# Landmark attention (only llama)
|
||||
landmark_attention:
|
||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
# LLaMA only
|
||||
xpos_rope:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
# Resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
@@ -814,14 +932,42 @@ Run
|
||||
accelerate launch -m axolotl.cli.train your_config.yml
|
||||
```
|
||||
|
||||
#### Multi-GPU
|
||||
#### Preprocess dataset
|
||||
|
||||
You can optionally pre-tokenize dataset with the following before finetuning.
|
||||
This is recommended for large datasets.
|
||||
|
||||
- Set `dataset_prepared_path:` to a local folder for saving and loading pre-tokenized dataset.
|
||||
- (Optional): Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
|
||||
- (Optional): Use `--debug` to see preprocessed examples.
|
||||
|
||||
You can optionally pre-tokenize dataset with the following before finetuning:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES="" accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
|
||||
python -m axolotl.cli.preprocess your_config.yml
|
||||
```
|
||||
|
||||
##### Config
|
||||
#### Multi-GPU
|
||||
|
||||
Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
|
||||
is the recommended multi-GPU option currently because FSDP may experience
|
||||
[loss instability](https://github.com/huggingface/transformers/issues/26498).
|
||||
|
||||
##### DeepSpeed
|
||||
|
||||
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
|
||||
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
||||
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
||||
|
||||
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
||||
|
||||
```yaml
|
||||
deepspeed: deepspeed/zero1.json
|
||||
```
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
|
||||
```
|
||||
|
||||
##### FSDP
|
||||
|
||||
- llama FSDP
|
||||
```yaml
|
||||
@@ -836,37 +982,40 @@ fsdp_config:
|
||||
|
||||
##### Weights & Biases Logging
|
||||
|
||||
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
|
||||
|
||||
- wandb options
|
||||
```yaml
|
||||
wandb_mode:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
```
|
||||
|
||||
### Training with Deepspeed
|
||||
##### Special Tokens
|
||||
|
||||
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
|
||||
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
||||
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
||||
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
|
||||
|
||||
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
|
||||
```yml
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
tokens: # these are delimiters
|
||||
- "<|im_start|>"
|
||||
- "<|im_end|>"
|
||||
```
|
||||
|
||||
or
|
||||
When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
|
||||
|
||||
```yaml
|
||||
deepspeed: deepspeed/zero1.json
|
||||
```
|
||||
### Inference Playground
|
||||
|
||||
### Inference
|
||||
Axolotl allows you to load your model in an interactive terminal playground for quick experimentation.
|
||||
The config file is the same config file used for training.
|
||||
|
||||
Pass the appropriate flag to the train command:
|
||||
Pass the appropriate flag to the inference command, depending upon what kind of model was trained:
|
||||
|
||||
- Pretrained LORA:
|
||||
```bash
|
||||
@@ -881,6 +1030,10 @@ Pass the appropriate flag to the train command:
|
||||
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
|
||||
--base_model="./completed-model" --prompter=None --load_in_8bit=True
|
||||
```
|
||||
-- With gradio hosting
|
||||
```bash
|
||||
python -m axolotl.cli.inference examples/your_config.yml --gradio
|
||||
```
|
||||
|
||||
Please use `--sample_packing False` if you have it on and receive the error similar to below:
|
||||
|
||||
@@ -888,20 +1041,24 @@ Please use `--sample_packing False` if you have it on and receive the error simi
|
||||
|
||||
### Merge LORA to base
|
||||
|
||||
Add below flag to train command above
|
||||
The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
|
||||
|
||||
```bash
|
||||
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
||||
python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model"
|
||||
```
|
||||
|
||||
If you run out of CUDA memory, you can try to merge in system RAM with
|
||||
You may need to use the `gpu_memory_limit` and/or `lora_on_cpu` config options to avoid running out of memory. If you still run out of CUDA memory, you can try to merge in system RAM with
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
|
||||
```
|
||||
|
||||
although this will be very slow, and using the config options above are recommended instead.
|
||||
|
||||
## Common Errors 🧰
|
||||
|
||||
See also the [FAQ's](./docs/faq.md).
|
||||
|
||||
> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
|
||||
|
||||
Please reduce any below
|
||||
@@ -910,6 +1067,10 @@ Please reduce any below
|
||||
- `gradient_accumulation_steps`
|
||||
- `sequence_len`
|
||||
|
||||
If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command.
|
||||
|
||||
Using adamw_bnb_8bit might also save you some memory.
|
||||
|
||||
> `failed (exitcode: -9)`
|
||||
|
||||
Usually means your system has run out of system memory.
|
||||
@@ -932,6 +1093,20 @@ It's safe to ignore it.
|
||||
|
||||
See the [NCCL](docs/nccl.md) guide.
|
||||
|
||||
|
||||
### Tokenization Mismatch b/w Inference & Training
|
||||
|
||||
For many formats, Axolotl constructs prompts by concatenating token ids _after_ tokenizing strings. The reason for concatenating token ids rather than operating on strings is to maintain precise accounting for attention masks.
|
||||
|
||||
If you decode a prompt constructed by axolotl, you might see spaces between tokens (or lack thereof) that you do not expect, especially around delimiters and special tokens. When you are starting out with a new format, you should always do the following:
|
||||
|
||||
1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
|
||||
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
|
||||
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same adjust your inference server accordingly.
|
||||
4. As an additional troubleshooting step, you can look look at the token ids between 1 and 2 to make sure they are identical.
|
||||
|
||||
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
|
||||
|
||||
## Need help? 🙋♂️
|
||||
|
||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
||||
|
||||
@@ -24,16 +24,6 @@
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
|
||||
@@ -28,16 +28,6 @@
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
|
||||
@@ -1,14 +1,6 @@
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 0,
|
||||
@@ -40,15 +32,6 @@
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
|
||||
39
deepspeed/zero3_bf16.json
Normal file
39
deepspeed/zero3_bf16.json
Normal file
@@ -0,0 +1,39 @@
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 0,
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"stage3_max_live_parameters": 0,
|
||||
"stage3_max_reuse_distance": 0,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": "auto",
|
||||
"eps": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
@@ -10,7 +10,7 @@ ARG PYTORCH_VERSION="2.0.1"
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y vim curl
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
@@ -19,13 +19,15 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
||||
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
|
||||
else \
|
||||
pip install -e .[flash-attn]; \
|
||||
pip install -e .[deepspeed,flash-attn]; \
|
||||
fi
|
||||
|
||||
# So we can test the Docker image
|
||||
RUN pip install pytest
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
git config --get remote.origin.fetch
|
||||
|
||||
@@ -10,8 +10,10 @@ ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
ARG PYTHON_VERSION="3.9"
|
||||
ARG PYTORCH_VERSION="2.0.1"
|
||||
ARG CUDA="118"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
|
||||
@@ -27,47 +29,9 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} deepspeed-kernels --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
|
||||
FROM base-builder AS deepspeed-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
|
||||
cd DeepSpeed && \
|
||||
MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 DS_BUILD_EVOFORMER_ATTN=0 python3 setup.py bdist_wheel
|
||||
|
||||
FROM base-builder AS bnb-builder
|
||||
|
||||
WORKDIR /workspace
|
||||
ARG CUDA="118"
|
||||
ENV CUDA=$CUDA
|
||||
ARG MAX_JOBS="-1"
|
||||
ENV MAX_JOBS=$MAX_JOBS
|
||||
|
||||
RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
|
||||
cd bitsandbytes && \
|
||||
CUDA_VERSION=$CUDA make cuda11x && \
|
||||
python setup.py bdist_wheel
|
||||
|
||||
FROM base-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN mkdir -p /workspace/builds
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
|
||||
|
||||
RUN mkdir -p /workspace/wheels/bitsandbytes
|
||||
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
|
||||
|
||||
RUN pip3 install wheels/deepspeed-*.whl
|
||||
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
|
||||
RUN git lfs install --skip-repo
|
||||
RUN pip3 install awscli && \
|
||||
RUN git lfs install --skip-repo && \
|
||||
pip3 install awscli && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
|
||||
@@ -4,6 +4,7 @@ FROM winglian/axolotl:$BASE_TAG
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||
|
||||
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
|
||||
|
||||
|
||||
18
docs/faq.md
Normal file
18
docs/faq.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# Axolotl FAQ's
|
||||
|
||||
|
||||
> The trainer stopped and hasn't progressed in several minutes.
|
||||
|
||||
Usually an issue with the GPU's communicating with each other. See the [NCCL doc](../docs/nccl.md)
|
||||
|
||||
> Exitcode -9
|
||||
|
||||
This usually happens when you run out of system RAM.
|
||||
|
||||
> Exitcode -7 while using deepspeed
|
||||
|
||||
Try upgrading deepspeed w: `pip install -U deepspeed`
|
||||
|
||||
> AttributeError: 'DummyOptim' object has no attribute 'step'
|
||||
|
||||
You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
|
||||
35
docs/rlhf.md
Normal file
35
docs/rlhf.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# RLHF (Beta)
|
||||
|
||||
### Overview
|
||||
|
||||
Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
|
||||
feedback. Various methods include, but not limited to:
|
||||
|
||||
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
|
||||
- Direct Preference Optimization (DPO)
|
||||
- Identity Preference Optimization (IPO)
|
||||
|
||||
|
||||
### RLHF using Axolotl
|
||||
|
||||
[!IMPORTANT]
|
||||
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
|
||||
|
||||
The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML
|
||||
|
||||
#### DPO
|
||||
```yaml
|
||||
rl: true
|
||||
datasets:
|
||||
- path: Intel/orca_dpo_pairs
|
||||
split: train
|
||||
type: intel_apply_chatml
|
||||
- path: argilla/ultrafeedback-binarized-preferences
|
||||
split: train
|
||||
type: argilla_apply_chatml
|
||||
```
|
||||
|
||||
#### IPO
|
||||
```yaml
|
||||
rl: ipo
|
||||
```
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: cerebras/btlm-3b-8k-base
|
||||
base_model_config: cerebras/btlm-3b-8k-base
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: GPT2Tokenizer
|
||||
trust_remote_code: true
|
||||
@@ -15,7 +14,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_prepared_run
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
@@ -36,7 +35,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
output_dir: btlm-out
|
||||
@@ -73,8 +72,8 @@ gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
|
||||
warmup_steps: 32
|
||||
eval_steps:
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
save_total_limit:
|
||||
|
||||
debug:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: cerebras/Cerebras-GPT-1.3B
|
||||
base_model_config: cerebras/Cerebras-GPT-1.3B
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
@@ -8,7 +7,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -25,7 +24,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
batch_size: 4
|
||||
@@ -50,8 +49,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: codellama/CodeLlama-13b-hf
|
||||
base_model_config: codellama/CodeLlama-13b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -12,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -30,12 +29,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -55,8 +54,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: codellama/CodeLlama-13b-hf
|
||||
base_model_config: codellama/CodeLlama-13b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -12,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -32,12 +31,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -57,8 +56,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: codellama/CodeLlama-34b-hf
|
||||
base_model_config: codellama/CodeLlama-34b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -12,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -30,12 +29,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -55,8 +54,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: codellama/CodeLlama-34b-hf
|
||||
base_model_config: codellama/CodeLlama-34b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -12,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -32,12 +31,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -57,8 +56,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: codellama/CodeLlama-7b-hf
|
||||
base_model_config: codellama/CodeLlama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -12,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -30,12 +29,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -55,8 +54,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: codellama/CodeLlama-7b-hf
|
||||
base_model_config: codellama/CodeLlama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -12,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -32,12 +31,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -57,8 +56,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: tiiuae/falcon-7b
|
||||
base_model_config: tiiuae/falcon-7b
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
@@ -13,7 +12,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca:chat
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -27,7 +26,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./falcon-7b
|
||||
batch_size: 2
|
||||
@@ -52,8 +51,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 40
|
||||
eval_steps: 5
|
||||
save_steps: 43
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# 1b: tiiuae/falcon-rw-1b
|
||||
# 40b: tiiuae/falcon-40b
|
||||
base_model: tiiuae/falcon-7b
|
||||
base_model_config: tiiuae/falcon-7b
|
||||
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
@@ -19,7 +18,7 @@ datasets:
|
||||
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
|
||||
type: "alpaca:chat"
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
@@ -41,7 +40,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
|
||||
@@ -54,7 +53,7 @@ output_dir: ./qlora-out
|
||||
# decrease if OOM, increase for max VRAM utilization
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
# Optimizer for QLoRA
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
@@ -81,8 +80,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 5
|
||||
save_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.000001
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: tiiuae/falcon-7b
|
||||
base_model_config: tiiuae/falcon-7b
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
@@ -13,7 +12,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca:chat
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -27,7 +26,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./falcon-7b
|
||||
batch_size: 2
|
||||
@@ -52,8 +51,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 40
|
||||
eval_steps: 5
|
||||
save_steps: 43
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: EleutherAI/gpt-j-6b
|
||||
base_model_config: EleutherAI/gpt-j-6b
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
@@ -8,7 +7,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -22,7 +21,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
gradient_accumulation_steps: 2
|
||||
@@ -47,8 +46,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: huggyllama/llama-7b
|
||||
base_model_config: huggyllama/llama-7b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
@@ -20,12 +19,12 @@ lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./jeopardy-bot-7b
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
@@ -43,8 +42,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 110
|
||||
save_steps: 660
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -9,12 +9,16 @@ gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
|
||||
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml
|
||||
```
|
||||
or
|
||||
|
||||
```shell
|
||||
accelerate launch scripts/finetune.py examples/llama-2/lora.yml
|
||||
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml
|
||||
```
|
||||
|
||||
To launch a full finetuning with 16-bit precision:
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
|
||||
```
|
||||
|
||||
72
examples/llama-2/fft_optimized.yml
Normal file
72
examples/llama-2/fft_optimized.yml
Normal file
@@ -0,0 +1,72 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_steps: 100
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed: #deepspeed/zero2.json # multi-gpu only
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: TheBloke/Llama-2-7B-GPTQ
|
||||
base_model_config: TheBloke/Llama-2-7B-GPTQ
|
||||
is_llama_derived_model: false
|
||||
gptq: true
|
||||
gptq_disable_exllama: true
|
||||
@@ -16,7 +15,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 4096
|
||||
@@ -33,12 +32,12 @@ lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./model-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_torch
|
||||
adam_beta2: 0.95
|
||||
adam_eps: 0.00001
|
||||
@@ -63,8 +62,8 @@ flash_attention:
|
||||
sdp_attention:
|
||||
flash_optimum:
|
||||
warmup_steps: 100
|
||||
eval_steps:
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
base_model_config: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -12,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -30,12 +29,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -55,10 +54,10 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
base_model_config: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -12,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -32,12 +31,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -57,9 +56,9 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
save_steps:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
base_model_config: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -12,7 +11,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./relora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -36,12 +35,12 @@ relora_cpu_offload: false
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -61,8 +60,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
save_steps: 50
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
61
examples/mamba/config.yml
Normal file
61
examples/mamba/config.yml
Normal file
@@ -0,0 +1,61 @@
|
||||
base_model: state-spaces/mamba-2.8b
|
||||
model_type: MambaLMHeadModel
|
||||
tokenizer_type: AutoTokenizer
|
||||
tokenizer_config: EleutherAI/gpt-neox-20b
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 2
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-5
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: true
|
||||
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
tokens:
|
||||
save_safetensors: False
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
base_model_config: mistralai/Mistral-7B-v0.1
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_mistral_derived_model: true
|
||||
@@ -12,22 +11,23 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
eval_sample_packing: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.000005
|
||||
@@ -47,10 +47,10 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_table_size: 5
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
91
examples/mistral/mixtral.yml
Normal file
91
examples/mistral/mixtral.yml
Normal file
@@ -0,0 +1,91 @@
|
||||
base_model: mistralai/Mixtral-8x7B-v0.1
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./qlora-out
|
||||
|
||||
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
||||
unfrozen_parameters:
|
||||
# - lm_head.*
|
||||
# - model.embed_tokens.*
|
||||
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
|
||||
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
|
||||
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
|
||||
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
|
||||
|
||||
model_config:
|
||||
output_router_logits: true
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
#lora_target_modules:
|
||||
# - gate
|
||||
# - q_proj
|
||||
# - k_proj
|
||||
# - v_proj
|
||||
# - o_proj
|
||||
# - w1
|
||||
# - w2
|
||||
# - w3
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed: deepspeed/zero2.json
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
base_model_config: mistralai/Mistral-7B-v0.1
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_mistral_derived_model: true
|
||||
@@ -12,7 +11,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.1
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -39,7 +38,7 @@ lora_target_modules:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
@@ -63,11 +62,14 @@ logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_table_size: 5
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: mosaicml/mpt-7b
|
||||
base_model_config: mosaicml/mpt-7b
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code: true # required for mpt as their model class is not merged into transformers yet
|
||||
load_in_8bit: false
|
||||
@@ -22,12 +21,12 @@ lora_fan_in_fan_out: false
|
||||
wandb_project: mpt-alpaca-7b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./mpt-alpaca-7b
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
@@ -45,8 +44,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 110
|
||||
save_steps: 660
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0001
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
base_model_config: openlm-research/open_llama_3b_v2
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
@@ -24,7 +23,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./openllama-out
|
||||
gradient_accumulation_steps: 1
|
||||
@@ -50,8 +49,8 @@ flash_attention: true
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
base_model_config: openlm-research/open_llama_3b_v2
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
@@ -30,7 +29,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-out
|
||||
gradient_accumulation_steps: 1
|
||||
@@ -55,8 +54,8 @@ flash_attention: true
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
base_model_config: openlm-research/open_llama_3b_v2
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
@@ -10,7 +9,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 1024
|
||||
@@ -24,7 +23,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
gradient_accumulation_steps: 1
|
||||
@@ -49,8 +48,8 @@ flash_attention: true
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
base_model: microsoft/phi-1_5
|
||||
base_model_config: microsoft/phi-1_5
|
||||
model_type: MixFormerSequentialForCausalLM
|
||||
model_type: PhiForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_llama_derived_model: false
|
||||
trust_remote_code: true
|
||||
@@ -32,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
@@ -60,8 +59,8 @@ xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 100
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: microsoft/phi-1_5
|
||||
base_model_config: microsoft/phi-1_5
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_llama_derived_model: false
|
||||
@@ -32,7 +31,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
@@ -60,8 +59,8 @@ xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 100
|
||||
eval_steps: 0.05
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: EleutherAI/pythia-12b-deduped
|
||||
base_model_config: EleutherAI/pythia-12b-deduped
|
||||
base_model_ignore_patterns: pytorch* # prefer safetensors
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
@@ -25,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./pythia-12b
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: EleutherAI/pythia-1.4b-deduped
|
||||
base_model_config: EleutherAI/pythia-1.4b-deduped
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
@@ -19,12 +18,12 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-alpaca-pythia
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
learning_rate: 0.00001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
@@ -34,5 +33,5 @@ early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
weight_decay: 0.1
|
||||
eval_steps: 20
|
||||
evals_per_epoch: 4
|
||||
logging_steps: 1
|
||||
|
||||
68
examples/qwen/lora.yml
Normal file
68
examples/qwen/lora.yml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: Qwen/Qwen-7B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
is_qwen_derived_model: true
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 2048 # supports up to 8192
|
||||
sample_packing: false
|
||||
pad_to_sequence_len:
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
68
examples/qwen/qlora.yml
Normal file
68
examples/qwen/qlora.yml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: Qwen/Qwen-7B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
is_qwen_derived_model: true
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 2048 # supports up to 8192
|
||||
sample_packing: false
|
||||
pad_to_sequence_len:
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code:
|
||||
@@ -23,12 +22,12 @@ lora_fan_in_fan_out: false
|
||||
wandb_project: redpajama-alpaca-3b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./redpajama-alpaca-3b
|
||||
batch_size: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
@@ -46,8 +45,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 110
|
||||
save_steps: 660
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0001
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
base_model: replit/replit-code-v1-3b
|
||||
base_model_config: replit/replit-code-v1-3b
|
||||
trust_remote_code: true
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
@@ -22,12 +21,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project: lora-replit
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-replit
|
||||
batch_size: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer:
|
||||
torchdistx_path:
|
||||
lr_scheduler:
|
||||
@@ -46,8 +45,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 20
|
||||
eval_steps: 50
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0
|
||||
|
||||
17
examples/tiny-llama/README.md
Normal file
17
examples/tiny-llama/README.md
Normal file
@@ -0,0 +1,17 @@
|
||||
# Overview
|
||||
|
||||
This is a simple example of how to finetune TinyLlama1.1B using either lora or qlora:
|
||||
|
||||
LoRa:
|
||||
|
||||
```
|
||||
accelerate launch -m axolotl.cli.train examples/tiny-llama/lora.yml
|
||||
```
|
||||
|
||||
qLoRa:
|
||||
|
||||
```
|
||||
accelerate launch -m axolotl.cli.train examples/tiny-llama/qlora.yml
|
||||
```
|
||||
|
||||
Both take about 10 minutes to complete on a 4090.
|
||||
@@ -1,6 +1,4 @@
|
||||
base_model: PY007/TinyLlama-1.1B-step-50K-105b
|
||||
base_model_config: PY007/TinyLlama-1.1B-step-50K-105b
|
||||
|
||||
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -13,11 +11,12 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
@@ -30,12 +29,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -55,15 +54,11 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 20
|
||||
eval_table_size:
|
||||
save_steps:
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
58
examples/tiny-llama/pretrain.yml
Normal file
58
examples/tiny-llama/pretrain.yml
Normal file
@@ -0,0 +1,58 @@
|
||||
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
||||
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
max_steps: 200
|
||||
pretraining_dataset:
|
||||
path: c4
|
||||
name: en
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./model-out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch:
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
66
examples/tiny-llama/qlora.yml
Normal file
66
examples/tiny-llama/qlora.yml
Normal file
@@ -0,0 +1,66 @@
|
||||
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -1,7 +1,6 @@
|
||||
# An example finetuning Saleforce's XGen-7b model with 8k context using qlora
|
||||
# on Tim Dettmer's Guanaco dataset.
|
||||
base_model: Salesforce/xgen-7b-8k-base
|
||||
base_model_config: Salesforce/xgen-7b-8k-base
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
@@ -17,7 +16,7 @@ datasets:
|
||||
- openassistant_best_replies_train.jsonl
|
||||
type: "completion"
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
val_set_size: 0.05
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
@@ -39,7 +38,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
|
||||
@@ -52,7 +51,7 @@ output_dir: ./qlora-out
|
||||
# decrease if OOM, increase for max VRAM utilization
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
num_epochs: 3
|
||||
num_epochs: 4
|
||||
# Optimizer for QLoRA
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
@@ -79,8 +78,8 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 50
|
||||
save_steps: 50
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
|
||||
5
examples/yi-34B-chat/README.md
Normal file
5
examples/yi-34B-chat/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# Overview
|
||||
|
||||
This is an example of a Yi-34B-Chat configuration. It demonstrates that it is possible to finetune a 34B model on a GPU with 24GB of VRAM.
|
||||
|
||||
Tested on an RTX 4090 with `python -m axolotl.cli.train examples/mistral/qlora.yml`, a single epoch of finetuning on the alpaca dataset using qlora runs in 47 mins, using 97% of available memory.
|
||||
76
examples/yi-34B-chat/qlora.yml
Normal file
76
examples/yi-34B-chat/qlora.yml
Normal file
@@ -0,0 +1,76 @@
|
||||
base_model: 01-ai/Yi-34B-Chat
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_mistral_derived_model: false
|
||||
is_llama_derived_model: true
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
sequence_len: 1024
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
flash_attention: true
|
||||
special_tokens:
|
||||
bos_token: "<|startoftext|>"
|
||||
eos_token: "<|endoftext|>"
|
||||
unk_token: "<unk>"
|
||||
|
||||
# Data
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
warmup_steps: 10
|
||||
|
||||
# Iterations
|
||||
num_epochs: 1
|
||||
|
||||
# Evaluation
|
||||
val_set_size: 0.1
|
||||
evals_per_epoch: 5
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
eval_sample_packing: false
|
||||
eval_batch_size: 1
|
||||
|
||||
# LoRA
|
||||
output_dir: ./qlora-out
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
|
||||
# Sampling
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
# Batching
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
gradient_checkpointing: true
|
||||
|
||||
# wandb
|
||||
wandb_project:
|
||||
|
||||
# Optimizer
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
# Misc
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
BIN
image/sticker_fixed.png
Normal file
BIN
image/sticker_fixed.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 370 KiB |
@@ -1,23 +1,22 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
torch==2.0.1
|
||||
auto-gptq
|
||||
auto-gptq==0.5.1
|
||||
packaging
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@bd6205919aad4d3a2300a39a98a642f1cc3a5348
|
||||
peft==0.6.0
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
|
||||
tokenizers==0.15.0
|
||||
bitsandbytes>=0.41.1
|
||||
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
|
||||
accelerate==0.24.1
|
||||
deepspeed
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
datasets
|
||||
flash-attn>=2.3.0
|
||||
datasets>=2.15.0
|
||||
flash-attn==2.3.3
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
xformers>=0.0.22
|
||||
optimum
|
||||
xformers==0.0.22
|
||||
optimum==1.13.2
|
||||
hf_transfer
|
||||
colorama
|
||||
numba
|
||||
@@ -30,4 +29,13 @@ scipy
|
||||
scikit-learn==1.2.2
|
||||
pynvml
|
||||
art
|
||||
fschat==0.2.29
|
||||
fschat==0.2.34
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
|
||||
# remote filesystems
|
||||
s3fs
|
||||
gcsfs
|
||||
# adlfs
|
||||
|
||||
trl @ git+https://github.com/huggingface/trl.git@main
|
||||
|
||||
@@ -45,8 +45,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if parsed_cli_args.prepare_ds_only:
|
||||
return
|
||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
|
||||
20
setup.py
20
setup.py
@@ -1,5 +1,7 @@
|
||||
"""setup.py for axolotl"""
|
||||
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
@@ -22,12 +24,13 @@ def parse_requirements():
|
||||
# Handle standard packages
|
||||
_install_requires.append(line)
|
||||
|
||||
# TODO(wing) remove once xformers release supports torch 2.1.0
|
||||
if "torch==2.1.0" in _install_requires:
|
||||
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
|
||||
_install_requires.append(
|
||||
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
|
||||
)
|
||||
try:
|
||||
torch_version = version("torch")
|
||||
if torch_version.startswith("2.1.1"):
|
||||
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||
_install_requires.append("xformers==0.0.23")
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
|
||||
return _install_requires, _dependency_links
|
||||
|
||||
@@ -46,10 +49,13 @@ setup(
|
||||
dependency_links=dependency_links,
|
||||
extras_require={
|
||||
"flash-attn": [
|
||||
"flash-attn>=2.2.1",
|
||||
"flash-attn==2.3.3",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
"mamba-ssm==1.0.1",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -2,21 +2,25 @@
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
from accelerate.commands.config import config_args
|
||||
from art import text2art
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.logging_config import configure_logging
|
||||
@@ -27,6 +31,7 @@ from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
@@ -44,7 +49,7 @@ def print_axolotl_text_art(suffix=None):
|
||||
ascii_text = " axolotl"
|
||||
if suffix:
|
||||
ascii_text += f" x {suffix}"
|
||||
ascii_art = text2art(" axolotl", font=font)
|
||||
ascii_art = text2art(ascii_text, font=font)
|
||||
|
||||
if is_main_process():
|
||||
print(ascii_art)
|
||||
@@ -68,14 +73,15 @@ def do_merge_lora(
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
LOG.info("running merge of LoRA with base model")
|
||||
model = model.merge_and_unload()
|
||||
model.to(dtype=torch.float16)
|
||||
model = model.merge_and_unload(progressbar=True)
|
||||
model.to(dtype=cfg.torch_dtype)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir) / "merged"),
|
||||
safe_serialization=safe_serialization,
|
||||
progressbar=True,
|
||||
)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
|
||||
@@ -100,15 +106,7 @@ def do_inference(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
|
||||
if cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
model.set_mem_cache_args(
|
||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||
)
|
||||
|
||||
model = model.to(cfg.device)
|
||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
while True:
|
||||
print("=" * 80)
|
||||
@@ -153,6 +151,83 @@ def do_inference(
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
def do_inference_gradio(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
prompter = cli_args.prompter
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
if prompter:
|
||||
prompter_module = getattr(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
|
||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
def generate(instruction):
|
||||
if not instruction:
|
||||
return
|
||||
if prompter_module:
|
||||
# pylint: disable=stop-iteration-return
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
else:
|
||||
prompt = instruction.strip()
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
streamer = TextIteratorStreamer(tokenizer)
|
||||
generation_kwargs = {
|
||||
"inputs": batch["input_ids"].to(cfg.device),
|
||||
"generation_config": generation_config,
|
||||
"streamer": streamer,
|
||||
}
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
|
||||
all_text = ""
|
||||
|
||||
for new_text in streamer:
|
||||
all_text += new_text
|
||||
yield all_text
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate,
|
||||
inputs="textbox",
|
||||
outputs="text",
|
||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||
)
|
||||
demo.queue().launch(show_api=False, share=True)
|
||||
|
||||
|
||||
def choose_config(path: Path):
|
||||
yaml_files = list(path.glob("*.yml"))
|
||||
|
||||
@@ -209,6 +284,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
prepare_optim_env(cfg)
|
||||
|
||||
normalize_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
@@ -222,7 +299,9 @@ def load_datasets(
|
||||
) -> TrainDatasetMeta:
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||
cfg, tokenizer
|
||||
)
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
@@ -238,6 +317,98 @@ def load_datasets(
|
||||
text_only=cli_args.debug_text_only,
|
||||
)
|
||||
|
||||
LOG.info("printing prompters...")
|
||||
for prompter in prompters:
|
||||
LOG.info(prompter)
|
||||
|
||||
return TrainDatasetMeta(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
total_num_steps=total_num_steps,
|
||||
)
|
||||
|
||||
|
||||
def load_rl_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs, # pylint: disable=unused-argument
|
||||
) -> TrainDatasetMeta:
|
||||
train_datasets: List[Any] = []
|
||||
for i, ds_cfg in enumerate(cfg.datasets):
|
||||
train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"]))
|
||||
# eval_dataset = load_dataset(
|
||||
# cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"]
|
||||
# )
|
||||
eval_dataset = None
|
||||
|
||||
def argilla_apply_chatml(sample): # pylint: disable=possibly-unused-variable
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
|
||||
return sample
|
||||
|
||||
def intel_apply_chatml(sample): # pylint: disable=possibly-unused-variable
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
|
||||
return sample
|
||||
|
||||
def apply_chatml(sample): # pylint: disable=possibly-unused-variable
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
|
||||
return sample
|
||||
|
||||
def ultra_apply_chatml(sample): # pylint: disable=possibly-unused-variable
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
|
||||
return sample
|
||||
|
||||
for i, data_set in enumerate(train_datasets):
|
||||
_type = cfg.datasets[i]["type"]
|
||||
ds_type_fn = locals()[_type]
|
||||
train_datasets[i] = data_set.map(ds_type_fn)
|
||||
train_dataset = concatenate_datasets(train_datasets)
|
||||
|
||||
# eval_dataset = eval_dataset.map(intel_apply_chatml)
|
||||
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
|
||||
return TrainDatasetMeta(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
|
||||
@@ -6,11 +6,16 @@ from pathlib import Path
|
||||
import fire
|
||||
import transformers
|
||||
|
||||
from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
|
||||
from axolotl.cli import (
|
||||
do_inference,
|
||||
do_inference_gradio,
|
||||
load_cfg,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
@@ -21,7 +26,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
)
|
||||
parsed_cli_args.inference = True
|
||||
|
||||
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if gradio:
|
||||
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -18,7 +18,22 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
return_remaining_strings=True
|
||||
)
|
||||
parsed_cli_args.merge_lora = True
|
||||
parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)
|
||||
|
||||
parsed_cfg = load_cfg(
|
||||
config,
|
||||
merge_lora=True,
|
||||
load_in_8bit=False,
|
||||
load_in_4bit=False,
|
||||
flash_attention=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir:
|
||||
parsed_cfg.lora_model_dir = parsed_cfg.output_dir
|
||||
if not Path(parsed_cfg.lora_model_dir).exists():
|
||||
raise ValueError(
|
||||
f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
|
||||
)
|
||||
|
||||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
|
||||
53
src/axolotl/cli/preprocess.py
Normal file
53
src/axolotl/cli/preprocess.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
from colorama import Fore
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
check_user_token,
|
||||
load_cfg,
|
||||
load_datasets,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import PreprocessCliArgs
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
parser = transformers.HfArgumentParser((PreprocessCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
if not parsed_cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
Fore.RED
|
||||
+ "preprocess CLI called without dataset_prepared_path set, "
|
||||
+ f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
|
||||
+ Fore.RESET
|
||||
)
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
LOG.info(
|
||||
Fore.GREEN
|
||||
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
||||
+ Fore.RESET
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(do_cli)
|
||||
@@ -6,17 +6,16 @@ from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
from colorama import Fore
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
check_user_token,
|
||||
load_cfg,
|
||||
load_datasets,
|
||||
load_rl_datasets,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.train import train
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.train")
|
||||
@@ -24,26 +23,18 @@ LOG = logging.getLogger("axolotl.cli.train")
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
Fore.RED
|
||||
+ "--prepare_ds_only called without dataset_prepared_path set."
|
||||
+ Fore.RESET
|
||||
)
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if parsed_cli_args.prepare_ds_only:
|
||||
return
|
||||
if parsed_cfg.rl:
|
||||
dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
|
||||
@@ -25,11 +25,22 @@ class TrainerCliArgs:
|
||||
debug_num_examples: int = field(default=5)
|
||||
inference: bool = field(default=False)
|
||||
merge_lora: bool = field(default=False)
|
||||
prepare_ds_only: bool = field(default=False)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
shard: bool = field(default=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessCliArgs:
|
||||
"""
|
||||
dataclass representing arguments for preprocessing only
|
||||
"""
|
||||
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=1)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
|
||||
0
src/axolotl/core/__init__.py
Normal file
0
src/axolotl/core/__init__.py
Normal file
943
src/axolotl/core/trainer_builder.py
Normal file
943
src/axolotl/core/trainer_builder.py
Normal file
@@ -0,0 +1,943 @@
|
||||
"""
|
||||
Builder for the training args and trainer
|
||||
"""
|
||||
|
||||
import abc
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import Dataset
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||
from transformers.trainer_utils import seed_worker
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
LossWatchDogCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
)
|
||||
from axolotl.utils.collators import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
MambaDataCollator,
|
||||
)
|
||||
from axolotl.utils.samplers import MultipackBatchSampler
|
||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||
|
||||
try:
|
||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(TrainingArguments):
|
||||
"""
|
||||
Extend the base TrainingArguments for axolotl helpers
|
||||
"""
|
||||
|
||||
model_type: Optional[str] = field(
|
||||
default=None, metadata={"help": "HF model configuration model_type."}
|
||||
)
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
pretraining: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Indicates to trainer whether we are doing continued pretraining."
|
||||
},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
sample_packing_seq_len_multiplier: int = field(
|
||||
default=1,
|
||||
metadata={"help": "the multiplier for the max len for packed sequences"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||
},
|
||||
)
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
|
||||
args = None # type: AxolotlTrainingArguments
|
||||
tag_names = ["axolotl"]
|
||||
|
||||
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
||||
self.num_epochs = num_epochs
|
||||
self.bench_data_collator = bench_data_collator
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
passed as an argument.
|
||||
|
||||
Args:
|
||||
num_training_steps (int): The number of training steps to do.
|
||||
optimizer (torch.optim.Optimizer): The training optimizer
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
if (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.lr_quadratic_warmup is True
|
||||
):
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
return self.lr_scheduler
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
return MultipackBatchSampler(
|
||||
RandomSampler(self.train_dataset),
|
||||
self.args.train_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self._train_batch_size * self.args.max_seq_length,
|
||||
lengths=(
|
||||
self.train_dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
return MultipackBatchSampler(
|
||||
SequentialSampler(eval_dataset),
|
||||
self.args.per_device_eval_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
|
||||
lengths=(
|
||||
eval_dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
train_dataset = self.train_dataset
|
||||
train_dataset = train_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
sampler = self._get_train_sampler()
|
||||
if isinstance(sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
eval_dataset = (
|
||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
)
|
||||
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if isinstance(eval_sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = eval_sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = eval_sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(eval_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
def _get_bench_sampler(
|
||||
self, bench_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size <= 1:
|
||||
return SequentialSampler(bench_dataset)
|
||||
return None
|
||||
|
||||
def get_bench_dataloader(
|
||||
self,
|
||||
bench_dataset: Dataset,
|
||||
) -> DataLoader:
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": self.bench_data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# use one's weighted cross entropy loss calc
|
||||
# if self.args.sample_packing:
|
||||
# labels = inputs.pop("labels")
|
||||
# outputs = model(**inputs)
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
|
||||
if isinstance(tag_names, str):
|
||||
tag_names = [tag_names]
|
||||
|
||||
if kwargs is not None:
|
||||
if "tags" not in kwargs:
|
||||
kwargs["tags"] = tag_names
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||
kwargs["tags"].extend(tag_names)
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||
tag_names.append(kwargs["tags"])
|
||||
kwargs["tags"] = tag_names
|
||||
|
||||
return kwargs
|
||||
|
||||
@wraps(Trainer.push_to_hub)
|
||||
def push_to_hub(self, *args, **kwargs) -> str:
|
||||
"""
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = self._sanitize_kwargs_for_tagging(
|
||||
tag_names=self.tag_names, kwargs=kwargs
|
||||
)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Mamba specific trainer to handle loss calculation
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "mamba"]
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False, # pylint: disable=unused-argument
|
||||
):
|
||||
input_ids = inputs.pop("input_ids")
|
||||
lm_logits = model(input_ids).logits
|
||||
|
||||
labels = input_ids.to(lm_logits.device)
|
||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
|
||||
loss_fct = torch.nn.CrossEntropyLoss()
|
||||
lm_loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||
)
|
||||
|
||||
return lm_loss
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "onecycle"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
|
||||
self.lr_scheduler = OneCycleLR(
|
||||
optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
total_steps=num_training_steps,
|
||||
pct_start=pct_start,
|
||||
div_factor=6,
|
||||
)
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class ReLoRATrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "relora"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
if self.args.relora_steps:
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler(
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
Base class for trainer builder
|
||||
"""
|
||||
|
||||
_train_dataset = None
|
||||
_eval_dataset = None
|
||||
_model_ref = None
|
||||
|
||||
def __init__(self, cfg, model, tokenizer):
|
||||
self.cfg = cfg
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@property
|
||||
def model_ref(self):
|
||||
return self._model_ref
|
||||
|
||||
@model_ref.setter
|
||||
def model_ref(self, model):
|
||||
self._model_ref = model
|
||||
|
||||
@property
|
||||
def train_dataset(self):
|
||||
return self._train_dataset
|
||||
|
||||
@train_dataset.setter
|
||||
def train_dataset(self, dataset):
|
||||
self._train_dataset = dataset
|
||||
|
||||
@property
|
||||
def eval_dataset(self):
|
||||
return self._eval_dataset
|
||||
|
||||
@eval_dataset.setter
|
||||
def eval_dataset(self, dataset):
|
||||
self._eval_dataset = dataset
|
||||
|
||||
@abstractmethod
|
||||
def build(self, total_num_steps):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_callbacks(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
"""
|
||||
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
||||
"""
|
||||
|
||||
|
||||
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
Build the HuggingFace training args/trainer for Causal models
|
||||
"""
|
||||
|
||||
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
||||
# TODO
|
||||
return training_arguments_kwargs
|
||||
|
||||
def hook_post_create_training_args(self, training_arguments):
|
||||
# TODO
|
||||
return training_arguments
|
||||
|
||||
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
|
||||
# TODO
|
||||
return trainer_kwargs, trainer_cls
|
||||
|
||||
def hook_post_create_trainer(self, trainer):
|
||||
# TODO
|
||||
return trainer
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = []
|
||||
callbacks.append(GPUStatsCallback(self.cfg))
|
||||
callbacks.append(EvalFirstStepCallback)
|
||||
|
||||
if self.cfg.relora_steps:
|
||||
callbacks.append(ReLoRACallback(self.cfg))
|
||||
|
||||
if (
|
||||
hasattr(self.model, "use_bettertransformer")
|
||||
and self.model.use_bettertransformer is True
|
||||
):
|
||||
callbacks.append(SaveBetterTransformerModelCallback)
|
||||
|
||||
if self.cfg.use_wandb:
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
if self.cfg.loss_watchdog_threshold is not None:
|
||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer
|
||||
)
|
||||
callbacks.append(LogPredictionCallback(self.cfg))
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
||||
|
||||
if self.cfg.early_stopping_patience:
|
||||
early_stop_cb = EarlyStoppingCallback(
|
||||
self.cfg.early_stopping_patience,
|
||||
)
|
||||
callbacks.append(early_stop_cb)
|
||||
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self):
|
||||
if self.cfg.lr_scheduler == "one_cycle" and (
|
||||
self.cfg.fsdp or self.cfg.adapter == "qlora"
|
||||
):
|
||||
return OneCycleLRSchedulerTrainer
|
||||
if self.cfg.relora_steps:
|
||||
return ReLoRATrainer
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return AxolotlMambaTrainer
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
warmup_steps = None
|
||||
if self.cfg.warmup_steps is not None:
|
||||
warmup_steps = self.cfg.warmup_steps
|
||||
elif self.cfg.warmup_ratio is not None:
|
||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||
else:
|
||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||
|
||||
logging_steps = (
|
||||
self.cfg.logging_steps
|
||||
if self.cfg.logging_steps is not None
|
||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||
)
|
||||
|
||||
training_arguments_kwargs = {}
|
||||
if self.cfg.bf16 == "full":
|
||||
training_arguments_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
training_arguments_kwargs["bf16"] = self.cfg.bf16
|
||||
training_arguments_kwargs["fp16"] = (
|
||||
self.cfg.fp16 and not self.cfg.bf16
|
||||
) or False
|
||||
training_arguments_kwargs["tf32"] = self.cfg.tf32
|
||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||
|
||||
if self.cfg.seed:
|
||||
training_arguments_kwargs["seed"] = self.cfg.seed
|
||||
|
||||
if self.cfg.gradient_checkpointing:
|
||||
training_arguments_kwargs[
|
||||
"gradient_checkpointing"
|
||||
] = self.cfg.gradient_checkpointing
|
||||
if self.cfg.gradient_checkpointing_kwargs:
|
||||
training_arguments_kwargs[
|
||||
"gradient_checkpointing_kwargs"
|
||||
] = self.cfg.gradient_checkpointing_kwargs
|
||||
else:
|
||||
training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
|
||||
"use_reentrant": False
|
||||
}
|
||||
if self.cfg.fsdp:
|
||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||
if self.cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
|
||||
|
||||
# deepspeed
|
||||
if self.cfg.deepspeed:
|
||||
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
|
||||
|
||||
if self.cfg.lr_quadratic_warmup is not None:
|
||||
training_arguments_kwargs[
|
||||
"lr_quadratic_warmup"
|
||||
] = self.cfg.lr_quadratic_warmup
|
||||
|
||||
if self.cfg.adam_beta1:
|
||||
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
|
||||
if self.cfg.adam_beta2:
|
||||
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
|
||||
if self.cfg.adam_epsilon:
|
||||
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
|
||||
if self.cfg.max_grad_norm:
|
||||
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
|
||||
|
||||
if self.cfg.hub_model_id:
|
||||
training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
||||
training_arguments_kwargs["push_to_hub"] = True
|
||||
training_arguments_kwargs["hub_private_repo"] = True
|
||||
training_arguments_kwargs["hub_always_push"] = True
|
||||
|
||||
if self.cfg.hub_strategy:
|
||||
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||
|
||||
if self.cfg.save_safetensors is not None:
|
||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||
|
||||
if self.cfg.sample_packing_eff_est:
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_efficiency"
|
||||
] = self.cfg.sample_packing_eff_est
|
||||
|
||||
if self.cfg.dataloader_pin_memory is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_pin_memory"
|
||||
] = self.cfg.dataloader_pin_memory
|
||||
if self.cfg.dataloader_num_workers is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_num_workers"
|
||||
] = self.cfg.dataloader_num_workers
|
||||
if self.cfg.dataloader_prefetch_factor is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_prefetch_factor"
|
||||
] = self.cfg.dataloader_prefetch_factor
|
||||
|
||||
if self.cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
elif self.cfg.evaluation_strategy:
|
||||
training_arguments_kwargs[
|
||||
"evaluation_strategy"
|
||||
] = self.cfg.evaluation_strategy
|
||||
else:
|
||||
# we have an eval set, but no steps defined, default to use epoch
|
||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.save_steps:
|
||||
training_arguments_kwargs["save_strategy"] = "steps"
|
||||
training_arguments_kwargs["save_steps"] = self.cfg.save_steps
|
||||
elif self.cfg.save_strategy:
|
||||
training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy
|
||||
else:
|
||||
# default to saving each epoch if not defined
|
||||
training_arguments_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
||||
if self.cfg.bench_dataset:
|
||||
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
|
||||
if self.cfg.metric_for_best_model:
|
||||
training_arguments_kwargs[
|
||||
"metric_for_best_model"
|
||||
] = self.cfg.metric_for_best_model
|
||||
if self.cfg.greater_is_better:
|
||||
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
|
||||
|
||||
if self.cfg.torch_compile:
|
||||
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
|
||||
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
|
||||
elif torch._dynamo: # pylint: disable=protected-access
|
||||
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||
True
|
||||
)
|
||||
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
|
||||
if self.cfg.torch_compile_backend:
|
||||
training_arguments_kwargs[
|
||||
"torch_compile_backend"
|
||||
] = self.cfg.torch_compile_backend
|
||||
|
||||
# DDP Config
|
||||
if self.cfg.ddp_timeout:
|
||||
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
|
||||
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
|
||||
if self.cfg.ddp_bucket_cap_mb:
|
||||
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
|
||||
if self.cfg.ddp_broadcast_buffers is not None:
|
||||
training_arguments_kwargs[
|
||||
"ddp_broadcast_buffers"
|
||||
] = self.cfg.ddp_broadcast_buffers
|
||||
|
||||
# these are all the "standard" kwargs that are def used
|
||||
training_arguments_kwargs["max_steps"] = (
|
||||
total_num_steps if self.cfg.max_steps else -1
|
||||
)
|
||||
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
||||
training_arguments_kwargs[
|
||||
"per_device_train_batch_size"
|
||||
] = self.cfg.micro_batch_size
|
||||
training_arguments_kwargs[
|
||||
"per_device_eval_batch_size"
|
||||
] = self.cfg.eval_batch_size
|
||||
training_arguments_kwargs[
|
||||
"gradient_accumulation_steps"
|
||||
] = self.cfg.gradient_accumulation_steps
|
||||
training_arguments_kwargs[
|
||||
"eval_accumulation_steps"
|
||||
] = self.cfg.gradient_accumulation_steps
|
||||
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
|
||||
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
|
||||
training_arguments_kwargs["save_total_limit"] = (
|
||||
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
|
||||
)
|
||||
training_arguments_kwargs["load_best_model_at_end"] = (
|
||||
(
|
||||
self.cfg.load_best_model_at_end is not False
|
||||
or self.cfg.early_stopping_patience
|
||||
)
|
||||
and self.cfg.val_set_size > 0
|
||||
and self.cfg.save_steps
|
||||
and self.cfg.eval_steps
|
||||
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
||||
) or False
|
||||
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
||||
False if self.cfg.ddp else None
|
||||
)
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
||||
training_arguments_kwargs["run_name"] = (
|
||||
self.cfg.wandb_name if self.cfg.use_wandb else None
|
||||
)
|
||||
training_arguments_kwargs["optim"] = (
|
||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||
)
|
||||
training_arguments_kwargs["lr_scheduler_type"] = (
|
||||
self.cfg.lr_scheduler
|
||||
if self.cfg.lr_scheduler
|
||||
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
||||
else "cosine"
|
||||
)
|
||||
training_arguments_kwargs["lr_scheduler_kwargs"] = (
|
||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||
)
|
||||
training_arguments_kwargs["weight_decay"] = (
|
||||
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
||||
)
|
||||
training_arguments_kwargs["sample_packing"] = (
|
||||
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||
)
|
||||
training_arguments_kwargs["eval_sample_packing"] = (
|
||||
self.cfg.sample_packing
|
||||
if self.cfg.eval_sample_packing is not False
|
||||
else False
|
||||
)
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_seq_len_multiplier"
|
||||
] = self.cfg.micro_batch_size
|
||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
|
||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||
training_arguments_kwargs
|
||||
)
|
||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
||||
|
||||
if self.cfg.neftune_noise_alpha is not None:
|
||||
training_arguments_kwargs[
|
||||
"neftune_noise_alpha"
|
||||
] = self.cfg.neftune_noise_alpha
|
||||
|
||||
training_args = (
|
||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
)
|
||||
training_args = self.hook_post_create_training_args(training_args)
|
||||
trainer_kwargs = {}
|
||||
|
||||
if self.cfg.optimizer == "adamw_anyprecision":
|
||||
if Path(self.cfg.torchdistx_path).exists():
|
||||
sys.path.append(self.cfg.torchdistx_path)
|
||||
importlib.import_module("torchdistx")
|
||||
|
||||
data_collator_kwargs = {
|
||||
"padding": True, # True/"longest" is the default
|
||||
}
|
||||
if self.cfg.pad_to_sequence_len:
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
|
||||
self.cfg.sequence_len / 64
|
||||
)
|
||||
else:
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||
|
||||
trainer_cls = self._get_trainer_cls()
|
||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||
trainer_kwargs, trainer_cls
|
||||
)
|
||||
trainer = trainer_cls(
|
||||
model=self.model,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
args=training_args,
|
||||
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
callbacks=self.get_callbacks(),
|
||||
num_epochs=self.cfg.num_epochs,
|
||||
**trainer_kwargs,
|
||||
)
|
||||
trainer = self.hook_post_create_trainer(trainer)
|
||||
for callback in self.get_post_trainer_create_callbacks(trainer):
|
||||
trainer.add_callback(callback)
|
||||
|
||||
if self.cfg.deepspeed and self.cfg.sample_packing:
|
||||
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
|
||||
"train_micro_batch_size_per_gpu"
|
||||
] = self.cfg.micro_batch_size
|
||||
|
||||
return trainer
|
||||
|
||||
def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
|
||||
if training_args.pretraining:
|
||||
return None
|
||||
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||
|
||||
return BatchSamplerDataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class HFDPOTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
Trainer factory class for DPO Trainer
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = []
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
return callbacks
|
||||
|
||||
def build_training_arguments(self, total_num_steps):
|
||||
training_args_kwargs = {}
|
||||
for arg in [
|
||||
"adam_beta1",
|
||||
"adam_beta2",
|
||||
"adam_epsilon",
|
||||
"dataloader_num_workers",
|
||||
"dataloader_pin_memory",
|
||||
]:
|
||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||
training_args = TrainingArguments(
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
max_steps=total_num_steps,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||
learning_rate=self.cfg.learning_rate,
|
||||
evaluation_strategy="no",
|
||||
# eval_steps=self.cfg.eval_steps,
|
||||
save_strategy="steps",
|
||||
save_steps=self.cfg.save_steps,
|
||||
output_dir=self.cfg.output_dir,
|
||||
warmup_steps=self.cfg.warmup_steps,
|
||||
bf16=True,
|
||||
gradient_checkpointing=self.cfg.gradient_checkpointing,
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False},
|
||||
logging_first_step=True,
|
||||
logging_steps=1,
|
||||
optim=self.cfg.optimizer,
|
||||
save_total_limit=self.cfg.save_total_limit or 5,
|
||||
**training_args_kwargs,
|
||||
)
|
||||
|
||||
return training_args
|
||||
|
||||
def build(self, total_num_steps):
|
||||
training_args = self.build_training_arguments(total_num_steps)
|
||||
dpo_trainer_kwargs = {}
|
||||
if self.cfg.rl == "ipo":
|
||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
||||
if self.cfg.dpo_label_smoothing:
|
||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||
|
||||
dpo_trainer = DPOTrainer(
|
||||
self.model,
|
||||
self.model_ref,
|
||||
args=training_args,
|
||||
beta=self.cfg.dpo_beta or 0.1,
|
||||
train_dataset=self.train_dataset,
|
||||
# eval_dataset=self.eval_dataset,
|
||||
eval_dataset=None,
|
||||
tokenizer=self.tokenizer,
|
||||
max_length=self.cfg.sequence_len,
|
||||
max_target_length=None,
|
||||
max_prompt_length=self.cfg.sequence_len,
|
||||
generate_during_eval=True,
|
||||
**dpo_trainer_kwargs,
|
||||
)
|
||||
|
||||
return dpo_trainer
|
||||
|
||||
|
||||
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
HF Factory class for PPO Trainer
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = []
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
return callbacks
|
||||
|
||||
def build(self, total_num_steps):
|
||||
# build PPOConfig
|
||||
pass
|
||||
0
src/axolotl/core/trainers/__init__.py
Normal file
0
src/axolotl/core/trainers/__init__.py
Normal file
66
src/axolotl/core/trainers/trl.py
Normal file
66
src/axolotl/core/trainers/trl.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
module for TRL PPO training
|
||||
"""
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from trl import PPOTrainer
|
||||
|
||||
|
||||
class TRLPPOTrainer(PPOTrainer):
|
||||
"""
|
||||
wrapper for ppo trainer to handle customizations
|
||||
"""
|
||||
|
||||
def train(
|
||||
self,
|
||||
reward_pipe,
|
||||
resume_from_checkpoint=None, # pylint: disable=unused-argument
|
||||
):
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": self.tokenizer.eos_token_id,
|
||||
"max_new_tokens": 32,
|
||||
}
|
||||
sent_kwargs = {
|
||||
"return_all_scores": True,
|
||||
"function_to_apply": "none",
|
||||
"batch_size": 16,
|
||||
}
|
||||
|
||||
for epoch, batch in tqdm( # pylint: disable=unused-variable
|
||||
enumerate(self.dataloader)
|
||||
):
|
||||
query_tensors = batch["input_ids"]
|
||||
|
||||
# generate model response
|
||||
response_tensors, ref_response_tensors = self.generate(
|
||||
query_tensors,
|
||||
return_prompt=False,
|
||||
generate_ref_response=True,
|
||||
**generation_kwargs
|
||||
)
|
||||
batch["response"] = self.tokenizer.batch_decode(response_tensors)
|
||||
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
|
||||
|
||||
# Compute sentiment score
|
||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
||||
pipe_outputs = reward_pipe(texts, **sent_kwargs)
|
||||
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
|
||||
ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
|
||||
ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
|
||||
ref_rewards = [
|
||||
torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
|
||||
]
|
||||
batch["ref_rewards"] = ref_rewards
|
||||
|
||||
# Run PPO step
|
||||
stats = self.step(query_tensors, response_tensors, rewards)
|
||||
self.log_stats(
|
||||
stats,
|
||||
batch,
|
||||
rewards,
|
||||
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
||||
)
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
@@ -30,14 +30,20 @@ class TokenizedPromptDataset(Dataset):
|
||||
self,
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: IterableDataset,
|
||||
process_count: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.prompt_tokenizer = prompt_tokenizer
|
||||
self.process_count = process_count
|
||||
super().__init__(self.process(dataset).data, **kwargs)
|
||||
|
||||
def process(self, dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, os.cpu_count())
|
||||
num_proc = (
|
||||
min(64, self.process_count)
|
||||
if self.process_count
|
||||
else min(64, os.cpu_count())
|
||||
)
|
||||
map_kwargs = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
|
||||
12
src/axolotl/models/mamba/__init__.py
Normal file
12
src/axolotl/models/mamba/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Modeling module for Mamba models
|
||||
"""
|
||||
|
||||
|
||||
def fix_mamba_attn_for_loss():
|
||||
from mamba_ssm.models import mixer_seq_simple
|
||||
|
||||
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
|
||||
|
||||
mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
|
||||
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name
|
||||
42
src/axolotl/models/mamba/configuration_mamba.py
Normal file
42
src/axolotl/models/mamba/configuration_mamba.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
HF Transformers MambaConfig
|
||||
"""
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class MambaConfig(PretrainedConfig):
|
||||
"""
|
||||
modeling configuration for state space model/mamba
|
||||
"""
|
||||
|
||||
model_type = "mamba"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50280,
|
||||
d_model=2560,
|
||||
n_layer=64,
|
||||
rms_norm=True,
|
||||
residual_in_fp32=True,
|
||||
fused_add_norm=True,
|
||||
pad_vocab_size_multiple=8,
|
||||
pad_token_id=50277,
|
||||
bos_token_id=0,
|
||||
eos_token_id=0,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.d_model = d_model
|
||||
self.n_layer = n_layer
|
||||
self.rms_norm = rms_norm
|
||||
self.residual_in_fp32 = residual_in_fp32
|
||||
self.fused_add_norm = fused_add_norm
|
||||
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
128
src/axolotl/models/mamba/modeling_mamba.py
Normal file
128
src/axolotl/models/mamba/modeling_mamba.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# pylint: skip-file
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
|
||||
from mamba_ssm.utils.generation import GenerationMixin
|
||||
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from axolotl.models.mamba.configuration_mamba import MambaConfig
|
||||
|
||||
|
||||
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_layer: int,
|
||||
vocab_size: int,
|
||||
initializer_cfg=None,
|
||||
pad_vocab_size_multiple: int = 1,
|
||||
device=None,
|
||||
dtype=None,
|
||||
**backbone_kwargs,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
if vocab_size % pad_vocab_size_multiple != 0:
|
||||
vocab_size += pad_vocab_size_multiple - (
|
||||
vocab_size % pad_vocab_size_multiple
|
||||
)
|
||||
self.config = MambaConfig(
|
||||
vocab_size=vocab_size,
|
||||
d_model=d_model,
|
||||
n_layer=n_layer,
|
||||
pad_vocab_size_multiple=pad_vocab_size_multiple,
|
||||
)
|
||||
self.backbone = MixerModel(
|
||||
d_model=d_model,
|
||||
n_layer=n_layer,
|
||||
vocab_size=vocab_size,
|
||||
initializer_cfg=initializer_cfg,
|
||||
**backbone_kwargs,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.apply(
|
||||
partial(
|
||||
_init_weights,
|
||||
n_layer=n_layer,
|
||||
**(initializer_cfg if initializer_cfg is not None else {}),
|
||||
)
|
||||
)
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
self.lm_head.weight = self.backbone.embedding.weight
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
return self.backbone.allocate_inference_cache(
|
||||
batch_size, max_seqlen, dtype=dtype, **kwargs
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids=None,
|
||||
inference_params=None,
|
||||
num_last_tokens=0,
|
||||
labels=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
||||
num_last_tokens: if > 0, only return the logits for the last n tokens
|
||||
"""
|
||||
hidden_states = self.backbone(input_ids, inference_params=inference_params)
|
||||
if num_last_tokens > 0:
|
||||
hidden_states = hidden_states[:, -num_last_tokens:]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
||||
return CausalLMOutput(logits=lm_logits)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
logits = lm_logits
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
|
||||
print(loss)
|
||||
return CausalLMOutput(logits=lm_logits, loss=loss)
|
||||
|
||||
else:
|
||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
||||
return CausalLMOutput(logits=lm_logits)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
state_dict: Optional[dict] = None,
|
||||
safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument
|
||||
):
|
||||
if state_dict is None:
|
||||
state_dict = self.state_dict()
|
||||
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
||||
config = load_config_hf(pretrained_model_name)
|
||||
model = cls(**config, device=device, dtype=dtype, **kwargs)
|
||||
model.load_state_dict(
|
||||
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
|
||||
)
|
||||
return model
|
||||
@@ -3,4 +3,6 @@ MixFormers model architecture used for phi models
|
||||
"""
|
||||
|
||||
from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
|
||||
from .configuration_phi import PhiConfig # noqa
|
||||
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
|
||||
from .modeling_phi import PhiForCausalLM # noqa
|
||||
|
||||
65
src/axolotl/models/phi/configuration_phi.py
Normal file
65
src/axolotl/models/phi/configuration_phi.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# pylint: skip-file
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class PhiConfig(PretrainedConfig):
|
||||
"""Phi configuration."""
|
||||
|
||||
model_type = "phi"
|
||||
attribute_map = {
|
||||
"max_position_embeddings": "n_positions",
|
||||
"hidden_size": "n_embd",
|
||||
"num_attention_heads": "n_head",
|
||||
"num_hidden_layers": "n_layer",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 50304,
|
||||
n_positions: int = 2048,
|
||||
n_embd: int = 1024,
|
||||
n_layer: int = 20,
|
||||
n_inner: Optional[int] = None,
|
||||
n_head: int = 16,
|
||||
n_head_kv: Optional[int] = None,
|
||||
rotary_dim: Optional[int] = 32,
|
||||
activation_function: Optional[str] = "gelu_new",
|
||||
flash_attn: bool = False,
|
||||
flash_rotary: bool = False,
|
||||
fused_dense: bool = False,
|
||||
attn_pdrop: float = 0.0,
|
||||
embd_pdrop: float = 0.0,
|
||||
resid_pdrop: float = 0.0,
|
||||
layer_norm_epsilon: float = 1e-5,
|
||||
initializer_range: float = 0.02,
|
||||
tie_word_embeddings: bool = False,
|
||||
pad_vocab_size_multiple: int = 64,
|
||||
**kwargs
|
||||
) -> None:
|
||||
self.vocab_size = int(
|
||||
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
||||
)
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_inner = n_inner
|
||||
self.n_head = n_head
|
||||
self.n_head_kv = n_head_kv
|
||||
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
||||
self.activation_function = activation_function
|
||||
self.flash_attn = flash_attn
|
||||
self.flash_rotary = flash_rotary
|
||||
self.fused_dense = fused_dense
|
||||
self.attn_pdrop = attn_pdrop
|
||||
self.embd_pdrop = embd_pdrop
|
||||
self.resid_pdrop = resid_pdrop
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
1063
src/axolotl/models/phi/modeling_phi.py
Normal file
1063
src/axolotl/models/phi/modeling_phi.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -82,15 +82,44 @@ def get_turns( # pylint: disable=too-many-return-statements
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA2:
|
||||
seps = [self.sep, self.sep2]
|
||||
if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
|
||||
if self.system_message:
|
||||
if self.messages:
|
||||
# For llama, the system message is incorporated into the first human instruction
|
||||
first_role, first_msg = self.messages[0]
|
||||
if first_role == self.roles[0]:
|
||||
system_prompt += first_msg
|
||||
self.messages.pop(0)
|
||||
yield "", system_prompt
|
||||
else:
|
||||
yield "", "[INST] "
|
||||
for i, (role, message) in enumerate(self.messages[1:]):
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role + " ", message + seps[i % 2]
|
||||
if (i % 2 == 0 and not self.system_message) or (
|
||||
i % 2 != 0 and self.system_message
|
||||
):
|
||||
role = "<s> " + role
|
||||
yield role + " ", message
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
|
||||
contains_sys_msg = False
|
||||
if self.system_message:
|
||||
contains_sys_msg = True
|
||||
if self.messages:
|
||||
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline
|
||||
first_role, first_msg = self.messages[0]
|
||||
if first_role == self.roles[0]:
|
||||
system_prompt = self.system_template.format(
|
||||
system_message=" " + self.system_message
|
||||
)
|
||||
system_prompt += first_msg
|
||||
self.messages.pop(0)
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message and i == 0 and not contains_sys_msg:
|
||||
yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a `<s> [INST]` at the beginning of the first instruction.
|
||||
elif message:
|
||||
yield role + " ", message
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
@@ -118,6 +147,15 @@ def get_turns( # pylint: disable=too-many-return-statements
|
||||
else:
|
||||
yield role + "\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATGLM3:
|
||||
if self.system_message:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", " " + message
|
||||
else:
|
||||
yield role
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATINTERN:
|
||||
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
|
||||
seps = [self.sep, self.sep2]
|
||||
|
||||
@@ -13,12 +13,18 @@ import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaMLP,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from xformers.ops import SwiGLU
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||
@@ -38,25 +44,37 @@ except ImportError:
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def replace_llama_mlp_with_swiglu(model):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaMLP):
|
||||
mlp = FusedMLP(
|
||||
module.config, module.gate_proj, module.up_proj, module.down_proj
|
||||
)
|
||||
set_module_name(model, name, mlp)
|
||||
|
||||
|
||||
def replace_llama_qkv_with_fused(model):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
qkv = FusedAttention(
|
||||
module.config,
|
||||
module.q_proj,
|
||||
module.k_proj,
|
||||
module.v_proj,
|
||||
module.o_proj,
|
||||
)
|
||||
set_module_name(model, name, qkv)
|
||||
|
||||
|
||||
def replace_llama_attn_with_flash_attn(
|
||||
packed: Optional[bool] = False,
|
||||
cross_entropy: Optional[bool] = False,
|
||||
rms_norm: Optional[bool] = False,
|
||||
noisy_embeddings_alpha: Optional[int] = False,
|
||||
):
|
||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||
_prepare_decoder_attention_mask
|
||||
)
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
|
||||
if noisy_embeddings_alpha:
|
||||
transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = partial(
|
||||
llama_model_get_inputs_embeds, noisy_embeddings_alpha=noisy_embeddings_alpha
|
||||
)
|
||||
else:
|
||||
transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = (
|
||||
llama_model_get_inputs_embeds
|
||||
)
|
||||
|
||||
if packed:
|
||||
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
||||
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
||||
@@ -96,6 +114,92 @@ def replace_llama_attn_with_flash_attn(
|
||||
)
|
||||
|
||||
|
||||
class FusedAttention(LlamaAttention):
|
||||
"""
|
||||
Fused QKV Attention layer for incrementally improved training efficiency
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
q: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
k: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
v: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
o: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.init_device = next(iter(q.state_dict().values())).device
|
||||
|
||||
# define equivalent fused qkv projection
|
||||
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
|
||||
self.qkv_proj = torch.nn.Linear(
|
||||
q.in_features, sum(self.out_features), device=self.init_device, bias=False
|
||||
)
|
||||
self.o_proj = o
|
||||
|
||||
# overwrite initialized weights with pretrained weights
|
||||
self.qkv_proj.weight.data = torch.cat(
|
||||
(q.weight.data, k.weight.data, v.weight.data), dim=0
|
||||
)
|
||||
|
||||
def _post_training(self, model, name):
|
||||
q_proj, k_proj, v_proj = torch.split(
|
||||
self.qkv_proj.weight.data, self.out_features, dim=0
|
||||
)
|
||||
|
||||
new_attn = LlamaAttention(self.config)
|
||||
new_attn.q_proj.weight.data = q_proj
|
||||
new_attn.k_proj.weight.data = k_proj
|
||||
new_attn.v_proj.weight.data = v_proj
|
||||
new_attn.o_proj.weight.data = self.o_proj.weight.data
|
||||
|
||||
set_module_name(model, name, new_attn)
|
||||
|
||||
|
||||
class FusedMLP(torch.nn.Module):
|
||||
"""
|
||||
Fused MLP layer for incrementally improved training efficiency
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
gate_proj: torch.nn.Linear,
|
||||
up_proj: torch.nn.Linear,
|
||||
down_proj: torch.nn.Linear,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.swiglu = SwiGLU(
|
||||
in_features=config.hidden_size,
|
||||
hidden_features=config.intermediate_size,
|
||||
bias=False,
|
||||
_pack_weights=True,
|
||||
)
|
||||
# overwrite initialized weights with pretrained weights
|
||||
self.swiglu.w12.weight.data = torch.cat(
|
||||
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
||||
)
|
||||
self.swiglu.w3.weight.data = down_proj.weight.data
|
||||
|
||||
def _post_training(self, model, name):
|
||||
w1, w2 = torch.split( # pylint: disable=invalid-name
|
||||
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
||||
)
|
||||
|
||||
# Assign the split weights back to the original layers
|
||||
new_mlp = LlamaMLP(self.config)
|
||||
new_mlp.gate_proj.weight.data = w1
|
||||
new_mlp.up_proj.weight.data = w2
|
||||
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
||||
|
||||
set_module_name(model, name, new_mlp)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
||||
return self.swiglu(x)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
@@ -157,9 +261,14 @@ def flashattn_forward(
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
if isinstance(self, FusedAttention):
|
||||
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
|
||||
self.out_features, dim=-1
|
||||
)
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
@@ -212,6 +321,8 @@ def flashattn_forward(
|
||||
# only on first autoregressive step q,k,v have same seqlen
|
||||
is_causal = key_states.shape == query_states.shape
|
||||
|
||||
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
@@ -221,7 +332,12 @@ def flashattn_forward(
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
@@ -244,7 +360,7 @@ def flashattn_forward(
|
||||
qkv_unpad,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
0.0,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
)
|
||||
@@ -257,6 +373,7 @@ def flashattn_forward(
|
||||
output = flash_attn_kvpacked_func(
|
||||
query_states,
|
||||
torch.stack([key_states, value_states], 2),
|
||||
dropout_p=dropout_rate,
|
||||
causal=is_causal,
|
||||
)
|
||||
else:
|
||||
@@ -289,7 +406,7 @@ def flashattn_forward(
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
0.0,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
)
|
||||
@@ -421,28 +538,6 @@ def generate_qkv(
|
||||
)
|
||||
|
||||
|
||||
def llama_model_get_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
noisy_embeddings_alpha: Optional[int] = None,
|
||||
):
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if noisy_embeddings_alpha:
|
||||
input_mask = attention_mask.to(inputs_embeds) # B x L
|
||||
input_lengths = torch.sum(input_mask, 1) # B
|
||||
|
||||
noise_ = torch.zeros_like(inputs_embeds).uniform_(-1, 1)
|
||||
delta = noise_ * input_mask.unsqueeze(2)
|
||||
dims = input_lengths * inputs_embeds.size(-1)
|
||||
mag = noisy_embeddings_alpha / torch.sqrt(dims)
|
||||
delta = (delta * mag.view(-1, 1, 1)).detach()
|
||||
inputs_embeds += delta
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
def llama_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
@@ -509,8 +604,7 @@ def llama_model_forward(
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_inputs_embeds(input_ids, attention_mask)
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
|
||||
@@ -25,6 +25,8 @@ def sdp_attention_forward(
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@@ -29,6 +29,8 @@ def xformers_forward(
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
"""
|
||||
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||
"""
|
||||
|
||||
import torch
|
||||
import transformers.models.llama.modeling_llama
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
|
||||
# pylint: disable=duplicate-code
|
||||
def noised_embed(orig_embed, noise_alpha, model):
|
||||
def new_func(input_ids):
|
||||
# during training, we add noise to the embedding
|
||||
# during generation, we don't add noise to the embedding
|
||||
if model.training:
|
||||
embed_init = orig_embed(input_ids)
|
||||
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
||||
mag_norm = noise_alpha / torch.sqrt(dims)
|
||||
return embed_init + torch.zeros_like(embed_init).uniform_(
|
||||
-mag_norm, mag_norm
|
||||
)
|
||||
return orig_embed(input_ids)
|
||||
|
||||
return new_func
|
||||
|
||||
def post_init(orig_post_init):
|
||||
def new_func(self):
|
||||
orig_post_init(self)
|
||||
self.embed_tokens.forward = noised_embed(
|
||||
self.embed_tokens.forward, noise_alpha, self
|
||||
)
|
||||
|
||||
return new_func
|
||||
|
||||
transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
|
||||
transformers.models.llama.modeling_llama.LlamaModel.post_init
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,6 +14,9 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralAttention as OriginalMistralAttention,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||
)
|
||||
@@ -42,6 +45,44 @@ def replace_mistral_attn_with_flash_attn(
|
||||
)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _make_sliding_window_causal_mask(
|
||||
bsz: int,
|
||||
tgt_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
past_key_values_length: int = 0,
|
||||
sliding_window: int = 4096,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for sliding window attention
|
||||
"""
|
||||
tensor = torch.full(
|
||||
(tgt_len, tgt_len),
|
||||
fill_value=1,
|
||||
device=device,
|
||||
)
|
||||
mask = torch.tril(tensor, diagonal=0)
|
||||
# make the mask banded to account for sliding window
|
||||
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
|
||||
mask = torch.triu(mask, diagonal=-sliding_window + 1)
|
||||
mask = torch.log(mask).to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
tgt_len, past_key_values_length, dtype=dtype, device=device
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return mask[None, None, :, :].expand(
|
||||
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
||||
)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
@@ -53,11 +94,29 @@ def _prepare_decoder_attention_mask(
|
||||
sliding_window,
|
||||
): # pylint: disable=unused-argument
|
||||
# [bsz, seq_len]
|
||||
if attention_mask is None:
|
||||
return attention_mask
|
||||
|
||||
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
||||
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
|
||||
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
|
||||
sliding_window_mask = _make_sliding_window_causal_mask(
|
||||
bsz=input_shape[0],
|
||||
tgt_len=input_shape[1],
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
past_key_values_length=past_key_values_length,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
attention_mask = attention_mask + sliding_window_mask
|
||||
else:
|
||||
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
def flashattn_forward(
|
||||
self,
|
||||
self: OriginalMistralAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -91,10 +150,41 @@ def flashattn_forward(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
|
||||
use_sliding_windows = (
|
||||
hasattr(self.config, "sliding_window") is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
)
|
||||
|
||||
if use_sliding_windows:
|
||||
window_size = (self.config.sliding_window, self.config.sliding_window)
|
||||
else:
|
||||
window_size = (-1, -1)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||
if (
|
||||
hasattr(self.config, "sliding_window")
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
):
|
||||
slicing_tokens = kv_seq_len - self.config.sliding_window
|
||||
|
||||
past_key = past_key_value[0]
|
||||
past_value = past_key_value[1]
|
||||
|
||||
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||
|
||||
if past_key.shape[-2] != self.config.sliding_window - 1:
|
||||
raise ValueError(
|
||||
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
||||
f" {past_key.shape}"
|
||||
)
|
||||
|
||||
past_key_value = (past_key, past_value) if use_cache else None
|
||||
|
||||
if past_key_value is not None:
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
@@ -111,6 +201,8 @@ def flashattn_forward(
|
||||
# only on first autoregressive step q,k,v have same seqlen
|
||||
is_causal = key_states.shape == query_states.shape
|
||||
|
||||
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
@@ -120,7 +212,13 @@ def flashattn_forward(
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
@@ -143,9 +241,10 @@ def flashattn_forward(
|
||||
qkv_unpad,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
0.0,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
else:
|
||||
@@ -156,7 +255,9 @@ def flashattn_forward(
|
||||
output = flash_attn_kvpacked_func(
|
||||
query_states,
|
||||
torch.stack([key_states, value_states], 2),
|
||||
dropout_p=dropout_rate,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
else:
|
||||
( # pylint: disable=unbalanced-tuple-unpacking
|
||||
@@ -188,9 +289,10 @@ def flashattn_forward(
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
0.0,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
"""
|
||||
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||
"""
|
||||
|
||||
import torch
|
||||
import transformers.models.mistral.modeling_mistral
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
|
||||
# pylint: disable=duplicate-code
|
||||
def noised_embed(orig_embed, noise_alpha, model):
|
||||
def new_func(input_ids):
|
||||
# during training, we add noise to the embedding
|
||||
# during generation, we don't add noise to the embedding
|
||||
if model.training:
|
||||
embed_init = orig_embed(input_ids)
|
||||
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
||||
mag_norm = noise_alpha / torch.sqrt(dims)
|
||||
return embed_init + torch.zeros_like(embed_init).uniform_(
|
||||
-mag_norm, mag_norm
|
||||
)
|
||||
return orig_embed(input_ids)
|
||||
|
||||
return new_func
|
||||
|
||||
def post_init(orig_post_init):
|
||||
def new_func(self):
|
||||
orig_post_init(self)
|
||||
self.embed_tokens.forward = noised_embed(
|
||||
self.embed_tokens.forward, noise_alpha, self
|
||||
)
|
||||
|
||||
return new_func
|
||||
|
||||
transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
|
||||
transformers.models.mistral.modeling_mistral.MistralModel.post_init
|
||||
)
|
||||
22
src/axolotl/monkeypatch/mixtral/__init__.py
Normal file
22
src/axolotl/monkeypatch/mixtral/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Patches to support multipack for mixtral
|
||||
"""
|
||||
import transformers
|
||||
|
||||
|
||||
def replace_mixtral_attn_with_multipack_flash_attn():
|
||||
from .modeling_mixtral import (
|
||||
MixtralMultipackFlashAttention2,
|
||||
mixtral_decoder_layer_forward,
|
||||
mixtral_model_forward,
|
||||
)
|
||||
|
||||
transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
|
||||
mixtral_decoder_layer_forward
|
||||
)
|
||||
transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
|
||||
mixtral_model_forward
|
||||
)
|
||||
transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
|
||||
"flash_attention_2"
|
||||
] = MixtralMultipackFlashAttention2
|
||||
383
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
Normal file
383
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
Normal file
@@ -0,0 +1,383 @@
|
||||
"""
|
||||
Mixtral modeling for multipack
|
||||
"""
|
||||
# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
|
||||
import logging
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from flash_attn import flash_attn_varlen_qkvpacked_func
|
||||
from transformers import Cache, DynamicCache
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.modeling_outputs import MoeModelOutputWithPast
|
||||
from transformers.models.mixtral.modeling_mixtral import (
|
||||
MixtralFlashAttention2,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
|
||||
|
||||
|
||||
class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
|
||||
"""
|
||||
Custom multipack implementation w flash attention 2
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._flash_attn_uses_top_left_mask = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
attn_output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=self.attention_dropout,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def mixtral_decoder_layer_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_router_logits: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_router_logits (`bool`, *optional*):
|
||||
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
||||
should not be returned during inference.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
if output_router_logits:
|
||||
outputs += (router_logits,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def mixtral_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, MoeModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_router_logits = (
|
||||
output_router_logits
|
||||
if output_router_logits is not None
|
||||
else self.config.output_router_logits
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
if input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
if use_cache:
|
||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||
if use_legacy_cache:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||
|
||||
cu_seqlens = None
|
||||
max_seqlen = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if (
|
||||
attention_mask is not None
|
||||
and self._attn_implementation == "flash_attention_2"
|
||||
and use_cache
|
||||
):
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
|
||||
if self._attn_implementation == "flash_attention_2":
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = (
|
||||
attention_mask
|
||||
if (attention_mask is not None and 0 in attention_mask)
|
||||
else None
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
LOG.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_router_logits = () if output_router_logits else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
output_router_logits,
|
||||
use_cache,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
output_router_logits=output_router_logits,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if output_router_logits:
|
||||
all_router_logits += (layer_outputs[-1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = (
|
||||
next_decoder_cache.to_legacy_cache()
|
||||
if use_legacy_cache
|
||||
else next_decoder_cache
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
all_router_logits,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return MoeModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
router_logits=all_router_logits,
|
||||
)
|
||||
@@ -101,3 +101,16 @@ def get_cu_seqlens_from_pos_ids(position_ids):
|
||||
max_seq_lens.append(max_seq_len)
|
||||
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
|
||||
|
||||
def set_module_name(model, name, value):
|
||||
if "." in name:
|
||||
parent_name = name.rsplit(".", 1)[0]
|
||||
child_name = name[len(parent_name) + 1 :]
|
||||
parent = model.get_submodule(parent_name)
|
||||
else:
|
||||
parent_name = ""
|
||||
parent = model
|
||||
child_name = name
|
||||
|
||||
setattr(parent, child_name, value)
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
# pylint: skip-file
|
||||
"""
|
||||
Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
"""
|
||||
import torch
|
||||
import transformers
|
||||
import transformers.models.llama.modeling_llama
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class XposRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scale_base=2048,
|
||||
use_xpos=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.scale_base = scale_base
|
||||
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq)
|
||||
freqs = torch.einsum("i , j -> i j", t, inv_freq)
|
||||
freqs = torch.cat((freqs, freqs), dim=-1)
|
||||
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.register_buffer("freqs_cached", freqs, persistent=False)
|
||||
|
||||
if not use_xpos:
|
||||
self.register_buffer("scale", None)
|
||||
self.register_buffer("scale_cached", torch.ones(1))
|
||||
return
|
||||
|
||||
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
||||
power = (t - (self.max_seq_len_cached // 2)) / self.scale_base
|
||||
scale_cached = scale ** rearrange(power, "n -> n 1")
|
||||
scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
|
||||
|
||||
self.register_buffer("scale", scale, persistent=False)
|
||||
self.register_buffer("scale_cached", scale_cached, persistent=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
seq_len,
|
||||
):
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=x.device).type_as(
|
||||
self.inv_freq
|
||||
)
|
||||
freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
|
||||
freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype)
|
||||
|
||||
self.register_buffer("freqs_cached", freqs)
|
||||
|
||||
if self.scale is None:
|
||||
self.register_buffer(
|
||||
"scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype)
|
||||
)
|
||||
|
||||
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached
|
||||
|
||||
power = (t - (seq_len // 2)) / self.scale_base
|
||||
scale = self.scale ** rearrange(power, "n -> n 1")
|
||||
scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype)
|
||||
self.register_buffer("scale_cached", scale)
|
||||
|
||||
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None):
|
||||
freqs = freqs[position_ids, :]
|
||||
if scale.shape[-1] != 1:
|
||||
scale = scale[position_ids, :]
|
||||
|
||||
q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale)
|
||||
k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale)
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def replace_llama_rope_with_xpos_rope():
|
||||
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding
|
||||
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
|
||||
"""Module for Alpaca prompt strategy classes"""
|
||||
|
||||
from typing import Tuple
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
@@ -9,9 +9,13 @@ from axolotl.prompt_tokenizers import (
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
prompt_style = PromptStyle.CHAT.value
|
||||
if ds_cfg and "conversation" in ds_cfg:
|
||||
prompt_style = ds_cfg["conversation"]
|
||||
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
||||
AlpacaPrompter(prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
|
||||
@@ -81,8 +81,9 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.sequence_len = 4096
|
||||
self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
||||
self.tokenizer.add_special_tokens(
|
||||
{"pad_token": getattr(self.tokenizer, "pad_token", "<pad>")}
|
||||
)
|
||||
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
|
||||
@@ -13,7 +13,7 @@ register_conv_template(
|
||||
system_message="You are a helpful assistant.",
|
||||
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
||||
sep_style=SeparatorStyle.CHATML,
|
||||
sep="<|im_end|>\n",
|
||||
sep="<|im_end|>",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
)
|
||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
return SimpleShareGPTPromptTokenizingStrategy(
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
@@ -34,6 +34,26 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
if ds_cfg and "strict" in ds_cfg:
|
||||
strategy.strict = ds_cfg["strict"]
|
||||
return strategy
|
||||
|
||||
|
||||
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||
)
|
||||
strategy = UltrachatShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
if ds_cfg and "strict" in ds_cfg:
|
||||
strategy.strict = ds_cfg["strict"]
|
||||
return strategy
|
||||
|
||||
|
||||
def load_role(tokenizer, cfg):
|
||||
@@ -59,8 +79,26 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
"""
|
||||
|
||||
_strict = True
|
||||
|
||||
@property
|
||||
def strict(self):
|
||||
return self._strict
|
||||
|
||||
@strict.setter
|
||||
def strict(self, strict):
|
||||
self._strict = strict
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt["conversations"]
|
||||
conversations = prompt["conversations"]
|
||||
if self.strict:
|
||||
return conversations
|
||||
# remap roles - allow for assistant turn
|
||||
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["from"]], "value": t["value"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
@@ -88,3 +126,17 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
sharegpt strategy that remaps ultrachat data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["messages"]
|
||||
role_map = {"user": "human", "assistant": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
@@ -45,6 +45,8 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
self.prompter = prompter
|
||||
self.tokenizer: PreTrainedTokenizer = tokenizer
|
||||
self.train_on_inputs = train_on_inputs
|
||||
# sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.
|
||||
# TODO: Document how they are different.
|
||||
self.sequence_len = sequence_len
|
||||
self.max_length = sequence_len
|
||||
|
||||
@@ -59,34 +61,31 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
def _tokenize(
|
||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||
) -> BatchEncoding:
|
||||
result: BatchEncoding
|
||||
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
||||
if not prompt:
|
||||
LOG.warning("Empty text requested for tokenization.")
|
||||
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
||||
else:
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
return empty
|
||||
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
if len(result["input_ids"]) == 0:
|
||||
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
|
||||
return empty
|
||||
|
||||
if (
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
and len(result["input_ids"]) < self.max_length
|
||||
and add_eos_token
|
||||
):
|
||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
|
||||
if (
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
and strip_bos_token
|
||||
):
|
||||
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||
result["input_ids"] = result["input_ids"][1:]
|
||||
result["attention_mask"] = result["attention_mask"][1:]
|
||||
|
||||
@@ -122,7 +121,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(tokenized_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
||||
tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
|
||||
tokenized_res_prompt = self._tokenize(
|
||||
response, strip_bos_token=True, add_eos_token=True
|
||||
)
|
||||
@@ -246,6 +245,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# pylint: disable=duplicate-code
|
||||
(
|
||||
instruction,
|
||||
input, # pylint: disable=redefined-builtin
|
||||
@@ -270,7 +270,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_full_prompt["labels"] = [
|
||||
-100
|
||||
IGNORE_INDEX
|
||||
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
|
||||
|
||||
return tokenized_full_prompt
|
||||
@@ -334,6 +334,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
return prompt["conversations"]
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# Initial values. We will append to these as we go through the conversation.
|
||||
result, current_len = tokenize_prompt_default()
|
||||
conversation: Conversation = (
|
||||
self.prompter._conversation.copy() # pylint: disable=protected-access
|
||||
@@ -355,62 +356,67 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
for _, part in enumerate(
|
||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
||||
):
|
||||
if isinstance(part, tuple):
|
||||
if conversation.roles[0] in part[0]:
|
||||
role = (
|
||||
part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||
if role_remap
|
||||
else part[0]
|
||||
)
|
||||
turn = role + part[1]
|
||||
# this is still the user query, we should
|
||||
if not part[1].strip():
|
||||
LOG.warning(f"user turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif conversation.roles[1] in part[0]:
|
||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||
role = (
|
||||
part[0].replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||
if role_remap
|
||||
else part[0]
|
||||
)
|
||||
turn = role + part[1]
|
||||
# this should be the assistant response, should end with an eos token
|
||||
if not part[1].strip():
|
||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=True,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
role_res = self._tokenize(
|
||||
role.rstrip(),
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
# not masked out from labels
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
len_role = len(role_res["input_ids"])
|
||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
||||
len_role, len(labels)
|
||||
)
|
||||
elif part[0] == "":
|
||||
turn = part[1]
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
turn, add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
else:
|
||||
LOG.warning(f"unhandled role: {part[0]}")
|
||||
continue
|
||||
if not isinstance(part, tuple):
|
||||
LOG.warning(f"expected tuple, got {part}")
|
||||
continue
|
||||
|
||||
user, assistant = conversation.roles
|
||||
role, content = part
|
||||
|
||||
# Uses "in" because role contains extra characters
|
||||
if user in role:
|
||||
role = (
|
||||
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||
if role_remap
|
||||
else role
|
||||
)
|
||||
turn = role + content
|
||||
# this is still the user query, we should
|
||||
if not content.strip():
|
||||
LOG.warning(f"user turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif assistant in role:
|
||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||
role = (
|
||||
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||
if role_remap
|
||||
else role
|
||||
)
|
||||
turn = role + content
|
||||
# this should be the assistant response, should end with an eos token
|
||||
if not content.strip():
|
||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=True,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
role_res = self._tokenize(
|
||||
role.rstrip(),
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
# not masked out from labels
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
len_role = len(role_res["input_ids"])
|
||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
|
||||
elif role == "":
|
||||
turn = content
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
turn, add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
else:
|
||||
LOG.warning(f"unhandled role: {role}")
|
||||
continue
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
@@ -424,38 +430,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
except (KeyError, AssertionError, IndexError) as err:
|
||||
raise InvalidDataException(str(err)) from err
|
||||
|
||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||
if not prompt.strip():
|
||||
LOG.warning("Empty text requested for tokenization.")
|
||||
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
||||
else:
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.sequence_len,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
if (
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
and len(result["input_ids"]) < self.sequence_len
|
||||
and add_eos_token
|
||||
):
|
||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
|
||||
if (
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
and strip_bos_token
|
||||
):
|
||||
result["input_ids"] = result["input_ids"][1:]
|
||||
result["attention_mask"] = result["attention_mask"][1:]
|
||||
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
|
||||
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
|
||||
"""
|
||||
|
||||
@@ -4,10 +4,12 @@ import logging
|
||||
from enum import Enum
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
from colorama import Fore
|
||||
from fastchat.conversation import Conversation, get_conv_template
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
IGNORE_TOKEN_ID = -100
|
||||
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
|
||||
|
||||
|
||||
class PromptStyle(Enum):
|
||||
@@ -20,13 +22,19 @@ class PromptStyle(Enum):
|
||||
CHATML = "chatml"
|
||||
|
||||
|
||||
class AlpacaPrompter:
|
||||
class Prompter:
|
||||
"""
|
||||
Base prompter class for all prompters
|
||||
"""
|
||||
|
||||
|
||||
class AlpacaPrompter(Prompter):
|
||||
"""
|
||||
Base class for alpaca prompters
|
||||
"""
|
||||
|
||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
|
||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
|
||||
system_format: str = "{system}"
|
||||
turn_format: str
|
||||
turn_no_input_format: str
|
||||
@@ -55,29 +63,38 @@ class AlpacaPrompter:
|
||||
)
|
||||
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
||||
|
||||
def _build_result(self, instruction, input_text, output):
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input_text:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_prompt)
|
||||
if self.system_prompt
|
||||
else ""
|
||||
) + self.turn_format.format(instruction=instruction, input=input_text)
|
||||
else:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_no_input_prompt)
|
||||
if self.system_no_input_prompt
|
||||
else ""
|
||||
) + self.turn_no_input_format.format(instruction=instruction)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
|
||||
return res
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_prompt)
|
||||
if self.system_prompt
|
||||
else ""
|
||||
) + self.turn_format.format(instruction=instruction, input=input)
|
||||
else:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_no_input_prompt)
|
||||
if self.system_prompt
|
||||
else ""
|
||||
) + self.turn_no_input_format.format(instruction=instruction)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
yield self._build_result(instruction, input, output)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return REPR_TEMPLATE.format(
|
||||
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
||||
)
|
||||
|
||||
|
||||
class UnpromptedPrompter(AlpacaPrompter):
|
||||
@@ -148,7 +165,7 @@ class NomicGPT4AllPrompter(AlpacaPrompter):
|
||||
"""
|
||||
|
||||
|
||||
class ReflectAlpacaPrompter:
|
||||
class ReflectAlpacaPrompter(Prompter):
|
||||
"""
|
||||
Prompter for ReflectAlpaca
|
||||
"""
|
||||
@@ -191,14 +208,14 @@ class ReflectAlpacaPrompter:
|
||||
)
|
||||
self.response_split = "ASSISTANT:"
|
||||
|
||||
def build_prompt(
|
||||
def _build_result(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
reflection: Union[None, str] = None,
|
||||
corrected: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
):
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
@@ -212,7 +229,30 @@ class ReflectAlpacaPrompter:
|
||||
corrected=corrected,
|
||||
)
|
||||
res = f"{res}{label}"
|
||||
yield res
|
||||
|
||||
return res
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
reflection: Union[None, str] = None,
|
||||
corrected: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# pylint: disable=duplicate-code
|
||||
yield self._build_result(
|
||||
instruction,
|
||||
input,
|
||||
output,
|
||||
reflection,
|
||||
corrected,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return REPR_TEMPLATE.format(
|
||||
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
||||
)
|
||||
|
||||
|
||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
@@ -220,7 +260,7 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
)
|
||||
|
||||
|
||||
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
@@ -247,7 +287,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
if role_key_model:
|
||||
self.role_key_model = role_key_model
|
||||
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
def _build_result(self, source):
|
||||
if len(source) < 2:
|
||||
# If there isn't a back and forth conversation, ignore it
|
||||
# also happens on the data splitting leaving empty conversations
|
||||
@@ -282,11 +322,20 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||
conv.append_message(role, sentence["value"])
|
||||
|
||||
for part in conv.get_turns():
|
||||
return conv.get_turns()
|
||||
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
turns = self._build_result(source)
|
||||
|
||||
for part in turns:
|
||||
if part[0] and not part[1]:
|
||||
LOG.warning(f"role with empty message: {part[0]}")
|
||||
yield part
|
||||
|
||||
def __repr__(self) -> str:
|
||||
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
|
||||
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
|
||||
|
||||
|
||||
class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
"""
|
||||
@@ -304,3 +353,15 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
role_key_human=role_key_human,
|
||||
role_key_model=role_key_model,
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedPrompter(Prompter):
|
||||
"""
|
||||
A dummy class for custom prompters
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return "Pre-tokenized or custom dataset types are unsupported for logging"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
@@ -10,12 +9,16 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import transformers.modelcard
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import Dataset
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from pkg_resources import get_distribution # type: ignore
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.freeze import freeze_parameters_except
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
@@ -24,7 +27,7 @@ src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
configure_logging()
|
||||
LOG = logging.getLogger("axolotl.train")
|
||||
LOG = get_logger("axolotl.train")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -39,13 +42,13 @@ class TrainDatasetMeta:
|
||||
|
||||
|
||||
def train(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
dataset_meta: TrainDatasetMeta,
|
||||
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
||||
):
|
||||
# load the tokenizer first
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
LOG.debug(
|
||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||
main_process_only=True,
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
@@ -53,8 +56,17 @@ def train(
|
||||
total_num_steps = dataset_meta.total_num_steps
|
||||
|
||||
# Load the model and tokenizer
|
||||
LOG.info("loading model and (optionally) peft_config...")
|
||||
msg = "loading model"
|
||||
if cfg.adapter:
|
||||
msg += " and peft_config..."
|
||||
LOG.debug(msg)
|
||||
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||
model_ref = None
|
||||
if cfg.rl:
|
||||
# load the model again for model_ref/baseline
|
||||
model_ref, _ = load_model(
|
||||
cfg, tokenizer, inference=cli_args.inference, reference_model=True
|
||||
)
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
@@ -73,11 +85,15 @@ def train(
|
||||
)
|
||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||
|
||||
if cfg.unfrozen_parameters:
|
||||
freeze_parameters_except(model, cfg.unfrozen_parameters)
|
||||
|
||||
trainer = setup_trainer(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||
cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps
|
||||
)
|
||||
|
||||
model.config.use_cache = False
|
||||
if hasattr(model, "config"):
|
||||
model.config.use_cache = False
|
||||
|
||||
# go ahead and presave, so we have the adapter config available to inspect
|
||||
if peft_config:
|
||||
@@ -87,7 +103,8 @@ def train(
|
||||
if not Path(cfg.output_dir).is_dir():
|
||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
||||
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
||||
if hasattr(model, "config"):
|
||||
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
||||
|
||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||
if cfg.local_rank == 0:
|
||||
@@ -105,10 +122,17 @@ def train(
|
||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||
|
||||
if getattr(cfg, "axolotl_config_path"):
|
||||
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
|
||||
version = get_distribution("axolotl").version
|
||||
if raw_axolotl_cfg.is_file():
|
||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
|
||||
|
||||
LOG.info("Starting trainer...")
|
||||
if cfg.group_by_length:
|
||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||
|
||||
pretrain_hooks(cfg, trainer)
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||
@@ -116,9 +140,15 @@ def train(
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
post_train_hooks(cfg, trainer)
|
||||
|
||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
|
||||
# post training
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, "_post_training"):
|
||||
module._post_training(model, name) # pylint: disable=protected-access
|
||||
|
||||
if trainer.is_fsdp_enabled:
|
||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
||||
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
|
||||
@@ -134,6 +164,22 @@ def train(
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||
if cfg.fsdp:
|
||||
trainer.save_model(cfg.output_dir)
|
||||
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
||||
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
|
||||
|
||||
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
|
||||
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
|
||||
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
|
||||
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
|
||||
# The model name saved is `pytorch_model.bin`
|
||||
unwrapped_model.save_pretrained(
|
||||
cfg.output_dir,
|
||||
is_main_process=trainer.accelerator.is_main_process,
|
||||
save_function=trainer.accelerator.save,
|
||||
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
|
||||
)
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
@@ -142,5 +188,26 @@ def train(
|
||||
|
||||
if not cfg.hub_model_id:
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
elif cfg.hub_model_id:
|
||||
# defensively push to the hub to ensure the model card is updated
|
||||
trainer.push_to_hub()
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def pretrain_hooks(_cfg, _trainer):
|
||||
"""
|
||||
Run hooks right before kicking off the training
|
||||
:param cfg:
|
||||
:param trainer:
|
||||
:return:
|
||||
"""
|
||||
|
||||
|
||||
def post_train_hooks(_cfg, _trainer):
|
||||
"""
|
||||
Run hooks right after training completes
|
||||
:param cfg:
|
||||
:param trainer:
|
||||
:return:
|
||||
"""
|
||||
|
||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import TYPE_CHECKING, Dict, List
|
||||
|
||||
import evaluate
|
||||
@@ -37,7 +39,7 @@ from axolotl.utils.distributed import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.utils.trainer import AxolotlTrainingArguments
|
||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
IGNORE_INDEX = -100
|
||||
@@ -124,6 +126,36 @@ class GPUStatsCallback(
|
||||
return control
|
||||
|
||||
|
||||
class LossWatchDogCallback(TrainerCallback):
|
||||
"""Callback to track loss and stop training if loss is too high"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.logged = False
|
||||
self.violations = 0
|
||||
self.threshold = cfg.loss_watchdog_threshold
|
||||
self.patience = cfg.loss_watchdog_patience or 3
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
_args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**_kwargs,
|
||||
):
|
||||
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
|
||||
if state.log_history[-1]["loss"] > self.threshold:
|
||||
self.violations += 1
|
||||
if self.violations >= self.patience:
|
||||
LOG.warning(
|
||||
"Loss is too high, stopping training (loss_watchdog_threshold)"
|
||||
)
|
||||
control.should_training_stop = True
|
||||
else:
|
||||
self.violations = 0
|
||||
return control
|
||||
|
||||
|
||||
def bench_eval_callback_factory(trainer, tokenizer):
|
||||
accuracy = evaluate.load("accuracy")
|
||||
abcd_idx = [
|
||||
@@ -531,10 +563,15 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
artifact = wandb.Artifact(name="axolotl-config", type="config")
|
||||
artifact.add_file(local_path=self.axolotl_config_path)
|
||||
wandb.run.log_artifact(artifact)
|
||||
LOG.info("Axolotl config has been saved to WandB as an artifact.")
|
||||
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
|
||||
with NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||
) as temp_file:
|
||||
copyfile(self.axolotl_config_path, temp_file.name)
|
||||
wandb.save(temp_file.name)
|
||||
LOG.info(
|
||||
"The Axolotl config has been saved to the WandB run under files."
|
||||
)
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||
return control
|
||||
|
||||
29
src/axolotl/utils/chat_templates.py
Normal file
29
src/axolotl/utils/chat_templates.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
This module provides functionality for selecting chat templates based on user choices.
|
||||
These templates are used for formatting messages in a conversation.
|
||||
"""
|
||||
|
||||
|
||||
def chat_templates(user_choice: str):
|
||||
"""
|
||||
Finds the correct chat_template for the tokenizer_config.
|
||||
|
||||
Args:
|
||||
user_choice (str): The user's choice of template.
|
||||
|
||||
Returns:
|
||||
str: The chosen template string.
|
||||
|
||||
Raises:
|
||||
ValueError: If the user_choice is not found in the templates.
|
||||
"""
|
||||
|
||||
templates = {
|
||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
}
|
||||
|
||||
if user_choice in templates:
|
||||
return templates[user_choice]
|
||||
|
||||
raise ValueError(f"Template '{user_choice}' not found.")
|
||||
@@ -2,12 +2,16 @@
|
||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Dict, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSeq2Seq:
|
||||
@@ -119,3 +123,79 @@ class DataCollatorForSeq2Seq:
|
||||
features["decoder_input_ids"] = decoder_input_ids
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
Collator for multipack specific to the using the BatchSampler
|
||||
"""
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
chunked_data = {}
|
||||
for feature in features[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(1) * np.array(item[feature])
|
||||
for item in features
|
||||
if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MambaDataCollator:
|
||||
"""
|
||||
Collator for State Space Models (Mamba)
|
||||
"""
|
||||
|
||||
tokenizer: transformers.PreTrainedTokenizer
|
||||
|
||||
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
||||
input_ids, labels = tuple(
|
||||
[torch.LongTensor(instance[key]) for instance in instances]
|
||||
for key in ("input_ids", "labels")
|
||||
)
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
input_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.tokenizer.pad_token_id,
|
||||
)
|
||||
labels = torch.nn.utils.rnn.pad_sequence(
|
||||
labels, batch_first=True, padding_value=IGNORE_INDEX
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
Collator for multipack specific to the using the BatchSampler
|
||||
"""
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
chunked_data = {}
|
||||
for feature in features.keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [(1) * np.array(item) for item in features[feature]]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [np.array(item) for item in features[feature]]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
@@ -27,7 +27,7 @@ def choose_device(cfg):
|
||||
|
||||
cfg.device = get_device()
|
||||
if cfg.world_size == 1:
|
||||
cfg.device_map = "auto"
|
||||
cfg.device_map = cfg.device_map or "auto"
|
||||
else:
|
||||
if cfg.device.startswith("cuda"):
|
||||
cfg.device_map = {"": torch.cuda.current_device()}
|
||||
@@ -77,8 +77,20 @@ def normalize_config(cfg):
|
||||
else:
|
||||
cfg.torch_dtype = torch.float32
|
||||
|
||||
if cfg.saves_per_epoch:
|
||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||
if save_steps < 1.0: # prevent saves on every step
|
||||
cfg.save_steps = save_steps
|
||||
if cfg.evals_per_epoch:
|
||||
eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
|
||||
if eval_steps < 1.0: # prevent evals on every step
|
||||
cfg.eval_steps = eval_steps
|
||||
|
||||
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
|
||||
|
||||
if not cfg.base_model_config:
|
||||
cfg.base_model_config = cfg.base_model
|
||||
|
||||
model_config = load_model_config(cfg)
|
||||
cfg.model_config_type = model_config.model_type
|
||||
|
||||
@@ -119,6 +131,22 @@ def normalize_config(cfg):
|
||||
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
||||
)
|
||||
|
||||
cfg.is_qwen_derived_model = (
|
||||
(
|
||||
hasattr(model_config, "model_type")
|
||||
and model_config.model_type
|
||||
in [
|
||||
"qwen",
|
||||
]
|
||||
)
|
||||
or cfg.is_qwen_derived_model
|
||||
or "qwen" in cfg.base_model.lower()
|
||||
or (cfg.model_type and "qwen" in cfg.model_type.lower())
|
||||
)
|
||||
|
||||
if isinstance(cfg.learning_rate, str):
|
||||
cfg.learning_rate = float(cfg.learning_rate)
|
||||
|
||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||
|
||||
|
||||
@@ -159,7 +187,11 @@ def validate_config(cfg):
|
||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||
)
|
||||
if cfg.eval_batch_size != cfg.micro_batch_size:
|
||||
if (
|
||||
cfg.eval_batch_size
|
||||
and cfg.micro_batch_size
|
||||
and cfg.eval_batch_size != cfg.micro_batch_size
|
||||
):
|
||||
LOG.warning(
|
||||
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
||||
)
|
||||
@@ -189,9 +221,15 @@ def validate_config(cfg):
|
||||
if not cfg.load_in_4bit:
|
||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||
|
||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with QLoRA")
|
||||
|
||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||
|
||||
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
||||
raise ValueError("Fused modules are not supported with LoRA")
|
||||
|
||||
if cfg.relora_steps:
|
||||
if cfg.adapter not in ("lora", "qlora"):
|
||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||
@@ -205,6 +243,9 @@ def validate_config(cfg):
|
||||
if cfg.lr_scheduler == "one_cycle":
|
||||
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
||||
|
||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with ReLoRA")
|
||||
|
||||
if cfg.trust_remote_code:
|
||||
LOG.warning(
|
||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||
@@ -320,6 +361,27 @@ def validate_config(cfg):
|
||||
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
||||
"sharegpt_simple", "sharegpt"
|
||||
)
|
||||
|
||||
if cfg.saves_per_epoch and cfg.save_steps:
|
||||
raise ValueError(
|
||||
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
||||
)
|
||||
if cfg.evals_per_epoch and cfg.eval_steps:
|
||||
raise ValueError(
|
||||
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if (
|
||||
cfg.evals_per_epoch
|
||||
and cfg.evaluation_strategy
|
||||
and cfg.evaluation_strategy != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||
@@ -339,6 +401,67 @@ def validate_config(cfg):
|
||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.sample_packing
|
||||
and cfg.eval_table_size
|
||||
and cfg.eval_sample_packing is not False
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
|
||||
)
|
||||
|
||||
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
|
||||
raise ValueError(
|
||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
|
||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||
)
|
||||
|
||||
if cfg.rope_scaling:
|
||||
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
||||
|
||||
if cfg.warmup_steps and cfg.warmup_ratio:
|
||||
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
|
||||
|
||||
if cfg.wandb_run_id and not cfg.wandb_name:
|
||||
cfg.wandb_name = cfg.wandb_run_id
|
||||
|
||||
LOG.warning(
|
||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||
)
|
||||
|
||||
if cfg.noisy_embedding_alpha is not None:
|
||||
# Deprecated, use neftune_noise_alpha
|
||||
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
||||
if cfg.neftune_noise_alpha is None:
|
||||
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
|
||||
else:
|
||||
# User is providing both; bail and have them sort out their settings
|
||||
raise ValueError(
|
||||
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
|
||||
)
|
||||
|
||||
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
||||
raise ValueError("neftune_noise_alpha must be > 0.0")
|
||||
|
||||
if (
|
||||
cfg.adapter
|
||||
and cfg.tokens
|
||||
and (
|
||||
not cfg.lora_modules_to_save
|
||||
or not all(
|
||||
x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
|
||||
)
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
|
||||
)
|
||||
|
||||
if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
|
||||
raise ValueError(
|
||||
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
||||
)
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user