Compare commits
138 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
272bced137 | ||
|
|
c371d6b546 | ||
|
|
d6273188f0 | ||
|
|
72797b04a5 | ||
|
|
de47bb5eb0 | ||
|
|
c04df54b4b | ||
|
|
e3716db386 | ||
|
|
97943d8fc4 | ||
|
|
9d3f80cd40 | ||
|
|
bfae79a634 | ||
|
|
5a85ee16eb | ||
|
|
3678a6c41d | ||
|
|
f8ae59b0a8 | ||
|
|
4f4d638b84 | ||
|
|
ba043a361e | ||
|
|
41353d2ea0 | ||
|
|
f6ecf14dd4 | ||
|
|
dec66d7c53 | ||
|
|
76357dc5da | ||
|
|
70b46ca4f4 | ||
|
|
85dd4d525b | ||
|
|
384b817dc0 | ||
|
|
db9094df0f | ||
|
|
6ef46f8dca | ||
|
|
628b754824 | ||
|
|
37820f6540 | ||
|
|
7d4185ffcb | ||
|
|
93ebec1ac5 | ||
|
|
2e61dc3180 | ||
|
|
1ffa3866f2 | ||
|
|
62ba1609b6 | ||
|
|
7bbaac98f7 | ||
|
|
161bcb6517 | ||
|
|
d25c34caa6 | ||
|
|
13e938149d | ||
|
|
85de004dd4 | ||
|
|
80ec7af358 | ||
|
|
f28e75513b | ||
|
|
5ada140ff0 | ||
|
|
712fd27b3f | ||
|
|
ef24342538 | ||
|
|
5ea3aa31f0 | ||
|
|
f1f60cb5b2 | ||
|
|
450e04d3c4 | ||
|
|
b0cf397ecb | ||
|
|
5f79b8242f | ||
|
|
f1de29dd1e | ||
|
|
7fabc4d95e | ||
|
|
9a5eb3990c | ||
|
|
86487c2e96 | ||
|
|
35f9b0f149 | ||
|
|
68b227a7d8 | ||
|
|
03c6318ba3 | ||
|
|
40a6362c92 | ||
|
|
d339beb9d9 | ||
|
|
fde091cb12 | ||
|
|
06ae39200b | ||
|
|
a581e9f8f6 | ||
|
|
992e742cdc | ||
|
|
a1da39cd48 | ||
|
|
58ec8b1113 | ||
|
|
476a205cea | ||
|
|
3e3229e2d9 | ||
|
|
1d21aa6b0a | ||
|
|
71b7ea3c05 | ||
|
|
a48dbf6561 | ||
|
|
6a4562ac08 | ||
|
|
1115c501b8 | ||
|
|
7ee3c4cacb | ||
|
|
fb12895a17 | ||
|
|
9fc29e082b | ||
|
|
575a082aae | ||
|
|
ddf815022a | ||
|
|
9bf854e59c | ||
|
|
797f3dd1de | ||
|
|
0de1457189 | ||
|
|
3cc67d2cdd | ||
|
|
1bc11868eb | ||
|
|
b3a61e8ce2 | ||
|
|
8a8d1c4023 | ||
|
|
332984db18 | ||
|
|
48630f5b34 | ||
|
|
b33c1d55a2 | ||
|
|
0c2a630326 | ||
|
|
db8a8afcba | ||
|
|
14706504e3 | ||
|
|
501b4d1379 | ||
|
|
306fe19c54 | ||
|
|
614cff4107 | ||
|
|
1a6309c8a6 | ||
|
|
105d0b350b | ||
|
|
f544ab2bed | ||
|
|
641e6f7e51 | ||
|
|
6dc68a653f | ||
|
|
7de6a5639c | ||
|
|
c74f045ba7 | ||
|
|
0402d19759 | ||
|
|
b2430ce670 | ||
|
|
4c834bf25d | ||
|
|
8056ecd30e | ||
|
|
738a057674 | ||
|
|
cdc71f73c8 | ||
|
|
6459ac7357 | ||
|
|
964d858da0 | ||
|
|
10388a8daf | ||
|
|
9f7e8a971d | ||
|
|
637ed095a0 | ||
|
|
827ec3d274 | ||
|
|
8b79ff0e94 | ||
|
|
0800885e2f | ||
|
|
d3193beac3 | ||
|
|
2e71ff03a6 | ||
|
|
facc49f32b | ||
|
|
e50ab072e2 | ||
|
|
05bd6f1122 | ||
|
|
20aa4b57d2 | ||
|
|
11d1d607db | ||
|
|
6c81c61bc4 | ||
|
|
9b43e7ea15 | ||
|
|
2d8def68dc | ||
|
|
44c9d0151a | ||
|
|
ca84cca2c0 | ||
|
|
32eeeb5b64 | ||
|
|
afedc470bd | ||
|
|
9923b72649 | ||
|
|
21cf09b608 | ||
|
|
15d3a654bf | ||
|
|
a21935f07a | ||
|
|
8966a6f566 | ||
|
|
e4d1585c4e | ||
|
|
70157ccb8f | ||
|
|
3a99495b05 | ||
|
|
440c3ab527 | ||
|
|
992d57f20a | ||
|
|
91a016f410 | ||
|
|
a045db0214 | ||
|
|
e1b214c62b | ||
|
|
3553172e3c |
7
.github/workflows/base.yml
vendored
7
.github/workflows/base.yml
vendored
@@ -28,7 +28,12 @@ jobs:
|
|||||||
- cuda: "118"
|
- cuda: "118"
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.0
|
pytorch: 2.1.1
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
|
- cuda: "121"
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.10"
|
||||||
|
pytorch: 2.1.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
56
.github/workflows/main.yml
vendored
56
.github/workflows/main.yml
vendored
@@ -23,39 +23,60 @@ jobs:
|
|||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.0.1
|
pytorch: 2.0.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.0
|
pytorch: 2.1.1
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.10"
|
||||||
|
pytorch: 2.1.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: [self-hosted, gpu, docker]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- name: Docker metadata
|
- name: Docker metadata
|
||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v3
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl
|
images: winglian/axolotl
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Set up Docker Buildx
|
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
|
||||||
uses: docker/setup-buildx-action@v2
|
- name: Build and export to Docker
|
||||||
- name: Build
|
uses: docker/build-push-action@v5
|
||||||
uses: docker/build-push-action@v4
|
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
|
load: true
|
||||||
build-args: |
|
build-args: |
|
||||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||||
file: ./docker/Dockerfile
|
file: ./docker/Dockerfile
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
tags: |
|
||||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
|
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
|
- name: Unit Tests
|
||||||
|
run: |
|
||||||
|
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||||
|
- name: Push to Docker Hub
|
||||||
|
if: github.event_name != 'pull_request'
|
||||||
|
run: |
|
||||||
|
docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
|
latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||||
|
if [ -n "$latest_tag" ]; then
|
||||||
|
docker push "$latest_tag"
|
||||||
|
fi
|
||||||
|
|
||||||
build-axolotl-runpod:
|
build-axolotl-runpod:
|
||||||
needs: build-axolotl
|
needs: build-axolotl
|
||||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||||
@@ -77,26 +98,31 @@ jobs:
|
|||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.0
|
pytorch: 2.1.1
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.10"
|
||||||
|
pytorch: 2.1.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: [self-hosted, gpu, docker]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- name: Docker metadata
|
- name: Docker metadata
|
||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v3
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl-runpod
|
images: winglian/axolotl-runpod
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v2
|
||||||
- name: Build
|
- name: Build
|
||||||
uses: docker/build-push-action@v4
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
build-args: |
|
build-args: |
|
||||||
|
|||||||
3
.github/workflows/tests.yml
vendored
3
.github/workflows/tests.yml
vendored
@@ -71,8 +71,9 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
|
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
|
||||||
pip3 uninstall -y transformers accelerate
|
pip3 uninstall -y transformers accelerate
|
||||||
pip3 install -U -e .[flash-attn]
|
pip3 install -U -e .[flash-attn,mamba-ssm]
|
||||||
pip3 install -r requirements-tests.txt
|
pip3 install -r requirements-tests.txt
|
||||||
|
|
||||||
- name: Run e2e tests
|
- name: Run e2e tests
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ ignore_missing_imports = True
|
|||||||
[mypy-axolotl.monkeypatch.*]
|
[mypy-axolotl.monkeypatch.*]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
[mypy-axolotl.models.mixtral.*]
|
||||||
|
ignore_errors = True
|
||||||
|
|
||||||
[mypy-axolotl.models.phi.*]
|
[mypy-axolotl.models.phi.*]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
|||||||
295
README.md
295
README.md
@@ -25,17 +25,20 @@ Features:
|
|||||||
- [Installation](#installation)
|
- [Installation](#installation)
|
||||||
- [Docker](#docker)
|
- [Docker](#docker)
|
||||||
- [Conda/Pip venv](#condapip-venv)
|
- [Conda/Pip venv](#condapip-venv)
|
||||||
|
- [Runpod](#runpod)
|
||||||
- [LambdaLabs](#lambdalabs)
|
- [LambdaLabs](#lambdalabs)
|
||||||
- [Windows](#windows)
|
- [Windows](#windows)
|
||||||
|
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||||
- [Dataset](#dataset)
|
- [Dataset](#dataset)
|
||||||
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
||||||
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
||||||
- [Config](#config)
|
- [Config](#config)
|
||||||
- [Train](#train)
|
- [Train](#train)
|
||||||
- [Training w/ Deepspeed](#training-with-deepspeed)
|
|
||||||
- [Inference](#inference)
|
- [Inference](#inference)
|
||||||
- [Merge LORA to Base](#merge-lora-to-base)
|
- [Merge LORA to Base](#merge-lora-to-base)
|
||||||
|
- [Special Tokens](#special-tokens)
|
||||||
- [Common Errors](#common-errors-)
|
- [Common Errors](#common-errors-)
|
||||||
|
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
|
||||||
- [Need Help?](#need-help-)
|
- [Need Help?](#need-help-)
|
||||||
- [Badge](#badge-)
|
- [Badge](#badge-)
|
||||||
- [Community Showcase](#community-showcase)
|
- [Community Showcase](#community-showcase)
|
||||||
@@ -64,17 +67,21 @@ Features:
|
|||||||
|
|
||||||
## Axolotl supports
|
## Axolotl supports
|
||||||
|
|
||||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||||
|----------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
||||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||||
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||||
|
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||||
|
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||||
|
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||||
|
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||||
|
|
||||||
|
|
||||||
## Quickstart ⚡
|
## Quickstart ⚡
|
||||||
@@ -83,20 +90,29 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
|||||||
|
|
||||||
**Requirements**: Python >=3.9 and Pytorch >=2.0.
|
**Requirements**: Python >=3.9 and Pytorch >=2.0.
|
||||||
|
|
||||||
|
`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
|
||||||
|
|
||||||
|
### For developers
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging
|
pip3 install packaging
|
||||||
pip3 install -e '.[flash-attn,deepspeed]'
|
pip3 install -e '.[flash-attn,deepspeed]'
|
||||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
```
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
```bash
|
||||||
# finetune lora
|
# finetune lora
|
||||||
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||||
|
|
||||||
# inference
|
# inference
|
||||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||||
--lora_model_dir="./lora-out"
|
--lora_model_dir="./lora-out"
|
||||||
|
|
||||||
|
# gradio
|
||||||
|
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||||
|
--lora_model_dir="./lora-out" --gradio
|
||||||
```
|
```
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
@@ -107,7 +123,6 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|||||||
```bash
|
```bash
|
||||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||||
```
|
```
|
||||||
- `winglian/axolotl-runpod:main-latest`: for runpod or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
|
||||||
|
|
||||||
Or run on the current files for development:
|
Or run on the current files for development:
|
||||||
|
|
||||||
@@ -115,6 +130,27 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|||||||
docker compose up -d
|
docker compose up -d
|
||||||
```
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Docker advanced</summary>
|
||||||
|
|
||||||
|
A more powerful Docker command to run would be this:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||||
|
```
|
||||||
|
|
||||||
|
It additionally:
|
||||||
|
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
|
||||||
|
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
|
||||||
|
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
|
||||||
|
* The `--privileged` flag gives all capabilities to the container.
|
||||||
|
* The `--shm-size 10g` argument increases the shared memory size. Use this if you see `exitcode: -7` errors using deepspeed.
|
||||||
|
|
||||||
|
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
#### Conda/Pip venv
|
#### Conda/Pip venv
|
||||||
1. Install python >=**3.9**
|
1. Install python >=**3.9**
|
||||||
|
|
||||||
@@ -131,6 +167,10 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|||||||
```
|
```
|
||||||
Get the token at huggingface.co/settings/tokens
|
Get the token at huggingface.co/settings/tokens
|
||||||
|
|
||||||
|
#### Runpod
|
||||||
|
|
||||||
|
Use `winglian/axolotl-runpod:main-latest` or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||||
|
|
||||||
#### LambdaLabs
|
#### LambdaLabs
|
||||||
<details>
|
<details>
|
||||||
|
|
||||||
@@ -178,6 +218,28 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|||||||
#### Windows
|
#### Windows
|
||||||
Please use WSL or Docker!
|
Please use WSL or Docker!
|
||||||
|
|
||||||
|
|
||||||
|
#### Launching on public clouds via SkyPilot
|
||||||
|
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
|
||||||
|
```bash
|
||||||
|
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
|
||||||
|
sky check
|
||||||
|
```
|
||||||
|
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
|
||||||
|
```
|
||||||
|
git clone https://github.com/skypilot-org/skypilot.git
|
||||||
|
cd skypilot/llm/axolotl
|
||||||
|
```
|
||||||
|
Use one command to launch:
|
||||||
|
```bash
|
||||||
|
# On-demand
|
||||||
|
HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
|
||||||
|
|
||||||
|
# Managed spot (auto-recovery on preemption)
|
||||||
|
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### Dataset
|
### Dataset
|
||||||
|
|
||||||
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
|
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
|
||||||
@@ -187,10 +249,17 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"instruction": "...", "input": "...", "output": "..."}
|
{"instruction": "...", "input": "...", "output": "..."}
|
||||||
```
|
```
|
||||||
- `sharegpt`: conversations where `from` is `human`/`gpt`
|
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"from": "...", "value": "..."}]}
|
{"conversations": [{"from": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
|
- `llama-2`: the json is the same format as `sharegpt` above, with the following config (see the [config section](#config) for more details)
|
||||||
|
```yml
|
||||||
|
datasets:
|
||||||
|
- path: <your-path>
|
||||||
|
type: sharegpt
|
||||||
|
conversation: llama-2
|
||||||
|
```
|
||||||
- `completion`: raw corpus
|
- `completion`: raw corpus
|
||||||
```json
|
```json
|
||||||
{"text": "..."}
|
{"text": "..."}
|
||||||
@@ -297,25 +366,24 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
|
|
||||||
#### How to add custom prompts
|
#### How to add custom prompts
|
||||||
|
|
||||||
Using yaml. Example:
|
For a dataset that is preprocessed for instruction purposes:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"instruction": "...", "output": "..."}
|
||||||
|
```
|
||||||
|
|
||||||
|
You can use this example in your YAML config:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
datasets:
|
datasets:
|
||||||
- path: repo
|
- path: repo
|
||||||
type:
|
type:
|
||||||
system_prompt: ""
|
system_prompt: ""
|
||||||
no_input_format: |-
|
field_system: system
|
||||||
User: {instruction}<|end_of_turn|>
|
format: "[INST] {instruction} [/INST]"
|
||||||
Assistant:
|
no_input_format: "[INST] {instruction} [/INST]"
|
||||||
format: |-
|
|
||||||
User: {instruction}
|
|
||||||
{input}<|end_of_turn|>
|
|
||||||
Assistant:
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Using file:
|
|
||||||
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
|
|
||||||
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
|
|
||||||
|
|
||||||
#### How to use your custom pretokenized dataset
|
#### How to use your custom pretokenized dataset
|
||||||
|
|
||||||
- Do not pass a `type:`
|
- Do not pass a `type:`
|
||||||
@@ -357,6 +425,13 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
- typescript
|
- typescript
|
||||||
type: ... # unimplemented custom format
|
type: ... # unimplemented custom format
|
||||||
|
|
||||||
|
# fastchat conversation
|
||||||
|
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
|
datasets:
|
||||||
|
- path: ...
|
||||||
|
type: sharegpt
|
||||||
|
conversation: chatml
|
||||||
|
|
||||||
# local
|
# local
|
||||||
datasets:
|
datasets:
|
||||||
- path: data.jsonl # or json
|
- path: data.jsonl # or json
|
||||||
@@ -368,6 +443,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
- path: knowrohit07/know_sql
|
- path: knowrohit07/know_sql
|
||||||
type: context_qa.load_v2
|
type: context_qa.load_v2
|
||||||
train_on_split: validation
|
train_on_split: validation
|
||||||
|
|
||||||
|
# loading from s3 or gcs
|
||||||
|
# s3 creds will be loaded from the system default and gcs only supports public access
|
||||||
|
dataset:
|
||||||
|
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
|
||||||
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
- loading
|
- loading
|
||||||
@@ -395,7 +476,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
|
|
||||||
<details>
|
<details>
|
||||||
|
|
||||||
<summary>All yaml options</summary>
|
<summary>All yaml options (click me)</summary>
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||||
@@ -430,6 +511,23 @@ is_falcon_derived_model:
|
|||||||
is_llama_derived_model:
|
is_llama_derived_model:
|
||||||
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
||||||
is_mistral_derived_model:
|
is_mistral_derived_model:
|
||||||
|
is_qwen_derived_model:
|
||||||
|
|
||||||
|
# optional overrides to the base model configuration
|
||||||
|
model_config:
|
||||||
|
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||||
|
rope_scaling:
|
||||||
|
type: # linear | dynamic
|
||||||
|
factor: # float
|
||||||
|
|
||||||
|
# optional overrides to the bnb 4bit quantization configuration
|
||||||
|
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
|
||||||
|
bnb_config_kwargs:
|
||||||
|
# These are default values
|
||||||
|
llm_int8_has_fp16_weight: false
|
||||||
|
bnb_4bit_quant_type: nf4
|
||||||
|
bnb_4bit_use_double_quant: true
|
||||||
|
|
||||||
|
|
||||||
# Whether you are training a 4-bit GPTQ quantized model
|
# Whether you are training a 4-bit GPTQ quantized model
|
||||||
gptq: true
|
gptq: true
|
||||||
@@ -454,7 +552,7 @@ float16: true
|
|||||||
|
|
||||||
# A list of one or more datasets to finetune the model with
|
# A list of one or more datasets to finetune the model with
|
||||||
datasets:
|
datasets:
|
||||||
# HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
|
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
|
||||||
- path: vicgalle/alpaca-gpt4
|
- path: vicgalle/alpaca-gpt4
|
||||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||||
@@ -462,7 +560,12 @@ datasets:
|
|||||||
data_files: # Optional[str] path to source data files
|
data_files: # Optional[str] path to source data files
|
||||||
shards: # Optional[int] number of shards to split data into
|
shards: # Optional[int] number of shards to split data into
|
||||||
name: # Optional[str] name of dataset configuration to load
|
name: # Optional[str] name of dataset configuration to load
|
||||||
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
|
train_on_split: train # Optional[str] name of dataset split to load from
|
||||||
|
|
||||||
|
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||||
|
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
|
field_human: # Optional[str]. Human key to use for conversation.
|
||||||
|
field_model: # Optional[str]. Assistant key to use for conversation.
|
||||||
|
|
||||||
# Custom user prompt
|
# Custom user prompt
|
||||||
- path: repo
|
- path: repo
|
||||||
@@ -486,6 +589,9 @@ datasets:
|
|||||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||||
field:
|
field:
|
||||||
|
|
||||||
|
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||||
|
# Currently supports chatml and inst (mistral/mixtral)
|
||||||
|
chat_template: chatml
|
||||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||||
# subsequent training attempts load faster, relative path
|
# subsequent training attempts load faster, relative path
|
||||||
dataset_prepared_path: data/last_run_prepared
|
dataset_prepared_path: data/last_run_prepared
|
||||||
@@ -528,6 +634,12 @@ eval_sample_packing:
|
|||||||
sample_packing_eff_est:
|
sample_packing_eff_est:
|
||||||
total_num_tokens:
|
total_num_tokens:
|
||||||
|
|
||||||
|
# Passed through to transformers when loading the model when launched without accelerate
|
||||||
|
# Use `sequential` when training w/ model parallelism to limit memory
|
||||||
|
device_map:
|
||||||
|
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
|
||||||
|
max_memory:
|
||||||
|
|
||||||
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||||
adapter: lora
|
adapter: lora
|
||||||
# If you already have a lora model trained that you want to load, put that here.
|
# If you already have a lora model trained that you want to load, put that here.
|
||||||
@@ -571,11 +683,13 @@ relora_warmup_steps: # Number of per-restart warmup steps
|
|||||||
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||||
|
|
||||||
# wandb configuration if you're using it
|
# wandb configuration if you're using it
|
||||||
|
# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
|
||||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||||
wandb_project: # Your wandb project name
|
wandb_project: # Your wandb project name
|
||||||
wandb_entity: # A wandb Team name if using a Team
|
wandb_entity: # A wandb Team name if using a Team
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id: # Set the name of your wandb run
|
wandb_name: # Set the name of your wandb run
|
||||||
|
wandb_run_id: # Set the ID of your wandb run
|
||||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||||
|
|
||||||
# Where to save the full-finetuned model to
|
# Where to save the full-finetuned model to
|
||||||
@@ -592,14 +706,17 @@ gradient_accumulation_steps: 1
|
|||||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
eval_batch_size:
|
eval_batch_size:
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
warmup_steps: 100
|
warmup_steps: 100 # cannot use with warmup_ratio
|
||||||
|
warmup_ratio: 0.05 # cannot use with warmup_steps
|
||||||
learning_rate: 0.00003
|
learning_rate: 0.00003
|
||||||
lr_quadratic_warmup:
|
lr_quadratic_warmup:
|
||||||
logging_steps:
|
logging_steps:
|
||||||
|
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
||||||
|
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
|
||||||
save_strategy: # Set to `no` to skip checkpoint saves
|
save_strategy: # Set to `no` to skip checkpoint saves
|
||||||
save_steps: # Leave empty to save at each epoch
|
save_steps: # Leave empty to save at each epoch
|
||||||
eval_steps: # Leave empty to eval at each epoch
|
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
|
||||||
save_total_limit: # Checkpoints saved at a time
|
save_total_limit: # Checkpoints saved at a time
|
||||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||||
# if both are set, num_epochs will not be guaranteed.
|
# if both are set, num_epochs will not be guaranteed.
|
||||||
@@ -609,6 +726,9 @@ max_steps:
|
|||||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||||
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||||
|
|
||||||
|
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
||||||
|
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
||||||
|
|
||||||
# Save model as safetensors (require safetensors package)
|
# Save model as safetensors (require safetensors package)
|
||||||
save_safetensors:
|
save_safetensors:
|
||||||
|
|
||||||
@@ -675,7 +795,7 @@ max_grad_norm:
|
|||||||
# Augmentation techniques
|
# Augmentation techniques
|
||||||
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
||||||
# currently only supported on Llama and Mistral
|
# currently only supported on Llama and Mistral
|
||||||
noisy_embedding_alpha:
|
neftune_noise_alpha:
|
||||||
|
|
||||||
# Whether to bettertransformers
|
# Whether to bettertransformers
|
||||||
flash_optimum:
|
flash_optimum:
|
||||||
@@ -685,18 +805,11 @@ xformers_attention:
|
|||||||
flash_attention:
|
flash_attention:
|
||||||
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||||
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||||
|
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
|
||||||
|
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
||||||
# Whether to use scaled-dot-product attention
|
# Whether to use scaled-dot-product attention
|
||||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||||
sdp_attention:
|
sdp_attention:
|
||||||
# Landmark attention (only llama)
|
|
||||||
landmark_attention:
|
|
||||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
|
||||||
# LLaMA only
|
|
||||||
xpos_rope:
|
|
||||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
|
||||||
rope_scaling:
|
|
||||||
type: # linear | dynamic
|
|
||||||
factor: # float
|
|
||||||
|
|
||||||
# Resume from a specific checkpoint dir
|
# Resume from a specific checkpoint dir
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
@@ -814,14 +927,41 @@ Run
|
|||||||
accelerate launch -m axolotl.cli.train your_config.yml
|
accelerate launch -m axolotl.cli.train your_config.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Multi-GPU
|
#### Preprocess dataset
|
||||||
|
|
||||||
|
You can optionally pre-tokenize dataset with the following before finetuning.
|
||||||
|
This is recommended for large datasets.
|
||||||
|
|
||||||
|
- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
|
||||||
|
- Use `--debug` to see preprocessed examples.
|
||||||
|
|
||||||
You can optionally pre-tokenize dataset with the following before finetuning:
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES="" accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
|
python -m axolotl.cli.preprocess your_config.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
##### Config
|
#### Multi-GPU
|
||||||
|
|
||||||
|
Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
|
||||||
|
is the recommended multi-GPU option currently because FSDP may experience
|
||||||
|
[loss instability](https://github.com/huggingface/transformers/issues/26498).
|
||||||
|
|
||||||
|
##### DeepSpeed
|
||||||
|
|
||||||
|
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
|
||||||
|
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
||||||
|
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
||||||
|
|
||||||
|
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
deepspeed: deepspeed/zero1.json
|
||||||
|
```
|
||||||
|
|
||||||
|
```shell
|
||||||
|
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
|
||||||
|
```
|
||||||
|
|
||||||
|
##### FSDP
|
||||||
|
|
||||||
- llama FSDP
|
- llama FSDP
|
||||||
```yaml
|
```yaml
|
||||||
@@ -836,37 +976,40 @@ fsdp_config:
|
|||||||
|
|
||||||
##### Weights & Biases Logging
|
##### Weights & Biases Logging
|
||||||
|
|
||||||
|
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
|
||||||
|
|
||||||
- wandb options
|
- wandb options
|
||||||
```yaml
|
```yaml
|
||||||
wandb_mode:
|
wandb_mode:
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
```
|
```
|
||||||
|
|
||||||
### Training with Deepspeed
|
##### Special Tokens
|
||||||
|
|
||||||
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
|
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
|
||||||
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
|
||||||
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
|
||||||
|
|
||||||
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
```yml
|
||||||
|
special_tokens:
|
||||||
```shell
|
bos_token: "<s>"
|
||||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
|
eos_token: "</s>"
|
||||||
|
unk_token: "<unk>"
|
||||||
|
tokens: # these are delimiters
|
||||||
|
- "<|im_start|>"
|
||||||
|
- "<|im_end|>"
|
||||||
```
|
```
|
||||||
|
|
||||||
or
|
When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
|
||||||
|
|
||||||
```yaml
|
### Inference Playground
|
||||||
deepspeed: deepspeed/zero1.json
|
|
||||||
```
|
|
||||||
|
|
||||||
### Inference
|
Axolotl allows you to load your model in an interactive terminal playground for quick experimentation.
|
||||||
|
The config file is the same config file used for training.
|
||||||
|
|
||||||
Pass the appropriate flag to the train command:
|
Pass the appropriate flag to the inference command, depending upon what kind of model was trained:
|
||||||
|
|
||||||
- Pretrained LORA:
|
- Pretrained LORA:
|
||||||
```bash
|
```bash
|
||||||
@@ -881,6 +1024,10 @@ Pass the appropriate flag to the train command:
|
|||||||
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
|
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
|
||||||
--base_model="./completed-model" --prompter=None --load_in_8bit=True
|
--base_model="./completed-model" --prompter=None --load_in_8bit=True
|
||||||
```
|
```
|
||||||
|
-- With gradio hosting
|
||||||
|
```bash
|
||||||
|
python -m axolotl.cli.inference examples/your_config.yml --gradio
|
||||||
|
```
|
||||||
|
|
||||||
Please use `--sample_packing False` if you have it on and receive the error similar to below:
|
Please use `--sample_packing False` if you have it on and receive the error similar to below:
|
||||||
|
|
||||||
@@ -891,7 +1038,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
|
|||||||
Add below flag to train command above
|
Add below flag to train command above
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model"
|
||||||
```
|
```
|
||||||
|
|
||||||
If you run out of CUDA memory, you can try to merge in system RAM with
|
If you run out of CUDA memory, you can try to merge in system RAM with
|
||||||
@@ -902,6 +1049,8 @@ CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
|
|||||||
|
|
||||||
## Common Errors 🧰
|
## Common Errors 🧰
|
||||||
|
|
||||||
|
See also the [FAQ's](./docs/faq.md).
|
||||||
|
|
||||||
> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
|
> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
|
||||||
|
|
||||||
Please reduce any below
|
Please reduce any below
|
||||||
@@ -910,6 +1059,10 @@ Please reduce any below
|
|||||||
- `gradient_accumulation_steps`
|
- `gradient_accumulation_steps`
|
||||||
- `sequence_len`
|
- `sequence_len`
|
||||||
|
|
||||||
|
If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command.
|
||||||
|
|
||||||
|
Using adamw_bnb_8bit might also save you some memory.
|
||||||
|
|
||||||
> `failed (exitcode: -9)`
|
> `failed (exitcode: -9)`
|
||||||
|
|
||||||
Usually means your system has run out of system memory.
|
Usually means your system has run out of system memory.
|
||||||
@@ -932,6 +1085,20 @@ It's safe to ignore it.
|
|||||||
|
|
||||||
See the [NCCL](docs/nccl.md) guide.
|
See the [NCCL](docs/nccl.md) guide.
|
||||||
|
|
||||||
|
|
||||||
|
### Tokenization Mismatch b/w Inference & Training
|
||||||
|
|
||||||
|
For many formats, Axolotl constructs prompts by concatenating token ids _after_ tokenizing strings. The reason for concatenating token ids rather than operating on strings is to maintain precise accounting for attention masks.
|
||||||
|
|
||||||
|
If you decode a prompt constructed by axolotl, you might see spaces between tokens (or lack thereof) that you do not expect, especially around delimiters and special tokens. When you are starting out with a new format, you should always do the following:
|
||||||
|
|
||||||
|
1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
|
||||||
|
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
|
||||||
|
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same adjust your inference server accordingly.
|
||||||
|
4. As an additional troubleshooting step, you can look look at the token ids between 1 and 2 to make sure they are identical.
|
||||||
|
|
||||||
|
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
|
||||||
|
|
||||||
## Need help? 🙋♂️
|
## Need help? 🙋♂️
|
||||||
|
|
||||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
||||||
|
|||||||
@@ -24,16 +24,6 @@
|
|||||||
"weight_decay": "auto"
|
"weight_decay": "auto"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"scheduler": {
|
|
||||||
"type": "WarmupDecayLR",
|
|
||||||
"params": {
|
|
||||||
"warmup_min_lr": "auto",
|
|
||||||
"warmup_max_lr": "auto",
|
|
||||||
"warmup_num_steps": "auto",
|
|
||||||
"warmup_type": "linear",
|
|
||||||
"total_num_steps": "auto"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
|||||||
@@ -28,16 +28,6 @@
|
|||||||
"weight_decay": "auto"
|
"weight_decay": "auto"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"scheduler": {
|
|
||||||
"type": "WarmupDecayLR",
|
|
||||||
"params": {
|
|
||||||
"warmup_min_lr": "auto",
|
|
||||||
"warmup_max_lr": "auto",
|
|
||||||
"warmup_num_steps": "auto",
|
|
||||||
"warmup_type": "linear",
|
|
||||||
"total_num_steps": "auto"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
|||||||
@@ -1,14 +1,6 @@
|
|||||||
{
|
{
|
||||||
"zero_optimization": {
|
"zero_optimization": {
|
||||||
"stage": 3,
|
"stage": 3,
|
||||||
"offload_optimizer": {
|
|
||||||
"device": "cpu",
|
|
||||||
"pin_memory": true
|
|
||||||
},
|
|
||||||
"offload_param": {
|
|
||||||
"device": "cpu",
|
|
||||||
"pin_memory": true
|
|
||||||
},
|
|
||||||
"overlap_comm": true,
|
"overlap_comm": true,
|
||||||
"contiguous_gradients": true,
|
"contiguous_gradients": true,
|
||||||
"sub_group_size": 0,
|
"sub_group_size": 0,
|
||||||
@@ -40,15 +32,6 @@
|
|||||||
"weight_decay": "auto"
|
"weight_decay": "auto"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"scheduler": {
|
|
||||||
"type": "WarmupLR",
|
|
||||||
"params": {
|
|
||||||
"warmup_min_lr": "auto",
|
|
||||||
"warmup_max_lr": "auto",
|
|
||||||
"warmup_num_steps": "auto",
|
|
||||||
"warmup_type": "linear"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"gradient_accumulation_steps": "auto",
|
"gradient_accumulation_steps": "auto",
|
||||||
"train_batch_size": "auto",
|
"train_batch_size": "auto",
|
||||||
"train_micro_batch_size_per_gpu": "auto",
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
|||||||
39
deepspeed/zero3_bf16.json
Normal file
39
deepspeed/zero3_bf16.json
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
{
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 3,
|
||||||
|
"overlap_comm": true,
|
||||||
|
"contiguous_gradients": true,
|
||||||
|
"sub_group_size": 0,
|
||||||
|
"reduce_bucket_size": "auto",
|
||||||
|
"stage3_prefetch_bucket_size": "auto",
|
||||||
|
"stage3_param_persistence_threshold": "auto",
|
||||||
|
"stage3_max_live_parameters": 0,
|
||||||
|
"stage3_max_reuse_distance": 0,
|
||||||
|
"stage3_gather_16bit_weights_on_model_save": true
|
||||||
|
},
|
||||||
|
"bf16": {
|
||||||
|
"enabled": true
|
||||||
|
},
|
||||||
|
"fp16": {
|
||||||
|
"enabled": "auto",
|
||||||
|
"auto_cast": false,
|
||||||
|
"loss_scale": 0,
|
||||||
|
"initial_scale_power": 32,
|
||||||
|
"loss_scale_window": 1000,
|
||||||
|
"hysteresis": 2,
|
||||||
|
"min_loss_scale": 1
|
||||||
|
},
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"params": {
|
||||||
|
"lr": "auto",
|
||||||
|
"betas": "auto",
|
||||||
|
"eps": "auto",
|
||||||
|
"weight_decay": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"gradient_accumulation_steps": "auto",
|
||||||
|
"train_batch_size": "auto",
|
||||||
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
"wall_clock_breakdown": false
|
||||||
|
}
|
||||||
47
deepspeed/zero3_cpu.json
Normal file
47
deepspeed/zero3_cpu.json
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
{
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 3,
|
||||||
|
"offload_optimizer": {
|
||||||
|
"device": "cpu",
|
||||||
|
"pin_memory": true
|
||||||
|
},
|
||||||
|
"offload_param": {
|
||||||
|
"device": "cpu",
|
||||||
|
"pin_memory": true
|
||||||
|
},
|
||||||
|
"overlap_comm": true,
|
||||||
|
"contiguous_gradients": true,
|
||||||
|
"sub_group_size": 0,
|
||||||
|
"reduce_bucket_size": "auto",
|
||||||
|
"stage3_prefetch_bucket_size": "auto",
|
||||||
|
"stage3_param_persistence_threshold": "auto",
|
||||||
|
"stage3_max_live_parameters": 0,
|
||||||
|
"stage3_max_reuse_distance": 0,
|
||||||
|
"stage3_gather_16bit_weights_on_model_save": true
|
||||||
|
},
|
||||||
|
"bf16": {
|
||||||
|
"enabled": "auto"
|
||||||
|
},
|
||||||
|
"fp16": {
|
||||||
|
"enabled": "auto",
|
||||||
|
"auto_cast": false,
|
||||||
|
"loss_scale": 0,
|
||||||
|
"initial_scale_power": 32,
|
||||||
|
"loss_scale_window": 1000,
|
||||||
|
"hysteresis": 2,
|
||||||
|
"min_loss_scale": 1
|
||||||
|
},
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"params": {
|
||||||
|
"lr": "auto",
|
||||||
|
"betas": "auto",
|
||||||
|
"eps": "auto",
|
||||||
|
"weight_decay": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"gradient_accumulation_steps": "auto",
|
||||||
|
"train_batch_size": "auto",
|
||||||
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
"wall_clock_breakdown": false
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ ARG PYTORCH_VERSION="2.0.1"
|
|||||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install -y vim curl
|
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
|
||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
@@ -19,13 +19,15 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
|||||||
WORKDIR /workspace/axolotl
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
|
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[flash-attn]; \
|
pip install -e .[deepspeed,flash-attn]; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# So we can test the Docker image
|
||||||
|
RUN pip install pytest
|
||||||
|
|
||||||
# fix so that git fetch/pull from remote works
|
# fix so that git fetch/pull from remote works
|
||||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||||
git config --get remote.origin.fetch
|
git config --get remote.origin.fetch
|
||||||
|
|||||||
@@ -10,8 +10,10 @@ ENV PATH="/root/miniconda3/bin:${PATH}"
|
|||||||
ARG PYTHON_VERSION="3.9"
|
ARG PYTHON_VERSION="3.9"
|
||||||
ARG PYTORCH_VERSION="2.0.1"
|
ARG PYTORCH_VERSION="2.0.1"
|
||||||
ARG CUDA="118"
|
ARG CUDA="118"
|
||||||
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
|
|
||||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||||
|
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||||
|
|
||||||
RUN apt-get update \
|
RUN apt-get update \
|
||||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
|
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
|
||||||
@@ -27,47 +29,9 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
|||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} deepspeed-kernels --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||||
|
|
||||||
FROM base-builder AS deepspeed-builder
|
RUN git lfs install --skip-repo && \
|
||||||
|
pip3 install awscli && \
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
|
||||||
|
|
||||||
WORKDIR /workspace
|
|
||||||
|
|
||||||
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
|
|
||||||
cd DeepSpeed && \
|
|
||||||
MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 DS_BUILD_EVOFORMER_ATTN=0 python3 setup.py bdist_wheel
|
|
||||||
|
|
||||||
FROM base-builder AS bnb-builder
|
|
||||||
|
|
||||||
WORKDIR /workspace
|
|
||||||
ARG CUDA="118"
|
|
||||||
ENV CUDA=$CUDA
|
|
||||||
ARG MAX_JOBS="-1"
|
|
||||||
ENV MAX_JOBS=$MAX_JOBS
|
|
||||||
|
|
||||||
RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
|
|
||||||
cd bitsandbytes && \
|
|
||||||
CUDA_VERSION=$CUDA make cuda11x && \
|
|
||||||
python setup.py bdist_wheel
|
|
||||||
|
|
||||||
FROM base-builder
|
|
||||||
|
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
|
||||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
|
||||||
|
|
||||||
RUN mkdir -p /workspace/builds
|
|
||||||
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
|
|
||||||
|
|
||||||
RUN mkdir -p /workspace/wheels/bitsandbytes
|
|
||||||
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
|
|
||||||
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
|
|
||||||
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
|
|
||||||
|
|
||||||
RUN pip3 install wheels/deepspeed-*.whl
|
|
||||||
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
|
|
||||||
RUN git lfs install --skip-repo
|
|
||||||
RUN pip3 install awscli && \
|
|
||||||
# The base image ships with `pydantic==1.8.2` which is not working
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ FROM winglian/axolotl:$BASE_TAG
|
|||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
|
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||||
|
|
||||||
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
|
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
|
||||||
|
|
||||||
|
|||||||
18
docs/faq.md
Normal file
18
docs/faq.md
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# Axolotl FAQ's
|
||||||
|
|
||||||
|
|
||||||
|
> The trainer stopped and hasn't progressed in several minutes.
|
||||||
|
|
||||||
|
Usually an issue with the GPU's communicating with each other. See the [NCCL doc](../docs/nccl.md)
|
||||||
|
|
||||||
|
> Exitcode -9
|
||||||
|
|
||||||
|
This usually happens when you run out of system RAM.
|
||||||
|
|
||||||
|
> Exitcode -7 while using deepspeed
|
||||||
|
|
||||||
|
Try upgrading deepspeed w: `pip install -U deepspeed`
|
||||||
|
|
||||||
|
> AttributeError: 'DummyOptim' object has no attribute 'step'
|
||||||
|
|
||||||
|
You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: cerebras/btlm-3b-8k-base
|
base_model: cerebras/btlm-3b-8k-base
|
||||||
base_model_config: cerebras/btlm-3b-8k-base
|
|
||||||
model_type: AutoModelForCausalLM
|
model_type: AutoModelForCausalLM
|
||||||
tokenizer_type: GPT2Tokenizer
|
tokenizer_type: GPT2Tokenizer
|
||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
@@ -15,7 +14,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_prepared_run
|
dataset_prepared_path: last_prepared_run
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
|
|
||||||
adapter:
|
adapter:
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
@@ -36,7 +35,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
output_dir: btlm-out
|
output_dir: btlm-out
|
||||||
@@ -73,8 +72,8 @@ gptq_groupsize:
|
|||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
|
|
||||||
warmup_steps: 32
|
warmup_steps: 32
|
||||||
eval_steps:
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
save_total_limit:
|
save_total_limit:
|
||||||
|
|
||||||
debug:
|
debug:
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: cerebras/Cerebras-GPT-1.3B
|
base_model: cerebras/Cerebras-GPT-1.3B
|
||||||
base_model_config: cerebras/Cerebras-GPT-1.3B
|
|
||||||
load_in_8bit: false
|
load_in_8bit: false
|
||||||
load_in_4bit: true
|
load_in_4bit: true
|
||||||
strict: false
|
strict: false
|
||||||
@@ -8,7 +7,7 @@ datasets:
|
|||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
@@ -25,7 +24,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
@@ -50,8 +49,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: codellama/CodeLlama-13b-hf
|
base_model: codellama/CodeLlama-13b-hf
|
||||||
base_model_config: codellama/CodeLlama-13b-hf
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: CodeLlamaTokenizer
|
tokenizer_type: CodeLlamaTokenizer
|
||||||
is_llama_derived_model: true
|
is_llama_derived_model: true
|
||||||
@@ -12,7 +11,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
@@ -30,12 +29,12 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -55,8 +54,8 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: codellama/CodeLlama-13b-hf
|
base_model: codellama/CodeLlama-13b-hf
|
||||||
base_model_config: codellama/CodeLlama-13b-hf
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: CodeLlamaTokenizer
|
tokenizer_type: CodeLlamaTokenizer
|
||||||
is_llama_derived_model: true
|
is_llama_derived_model: true
|
||||||
@@ -12,7 +11,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
@@ -32,12 +31,12 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer: paged_adamw_32bit
|
optimizer: paged_adamw_32bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -57,8 +56,8 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: codellama/CodeLlama-34b-hf
|
base_model: codellama/CodeLlama-34b-hf
|
||||||
base_model_config: codellama/CodeLlama-34b-hf
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: CodeLlamaTokenizer
|
tokenizer_type: CodeLlamaTokenizer
|
||||||
is_llama_derived_model: true
|
is_llama_derived_model: true
|
||||||
@@ -12,7 +11,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
@@ -30,12 +29,12 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -55,8 +54,8 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: codellama/CodeLlama-34b-hf
|
base_model: codellama/CodeLlama-34b-hf
|
||||||
base_model_config: codellama/CodeLlama-34b-hf
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: CodeLlamaTokenizer
|
tokenizer_type: CodeLlamaTokenizer
|
||||||
is_llama_derived_model: true
|
is_llama_derived_model: true
|
||||||
@@ -12,7 +11,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
@@ -32,12 +31,12 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer: paged_adamw_32bit
|
optimizer: paged_adamw_32bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -57,8 +56,8 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: codellama/CodeLlama-7b-hf
|
base_model: codellama/CodeLlama-7b-hf
|
||||||
base_model_config: codellama/CodeLlama-7b-hf
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: CodeLlamaTokenizer
|
tokenizer_type: CodeLlamaTokenizer
|
||||||
is_llama_derived_model: true
|
is_llama_derived_model: true
|
||||||
@@ -12,7 +11,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
@@ -30,12 +29,12 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -55,8 +54,8 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: codellama/CodeLlama-7b-hf
|
base_model: codellama/CodeLlama-7b-hf
|
||||||
base_model_config: codellama/CodeLlama-7b-hf
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: CodeLlamaTokenizer
|
tokenizer_type: CodeLlamaTokenizer
|
||||||
is_llama_derived_model: true
|
is_llama_derived_model: true
|
||||||
@@ -12,7 +11,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
@@ -32,12 +31,12 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer: paged_adamw_32bit
|
optimizer: paged_adamw_32bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -57,8 +56,8 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: tiiuae/falcon-7b
|
base_model: tiiuae/falcon-7b
|
||||||
base_model_config: tiiuae/falcon-7b
|
|
||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
model_type: AutoModelForCausalLM
|
model_type: AutoModelForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
@@ -13,7 +12,7 @@ datasets:
|
|||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca:chat
|
type: alpaca:chat
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
@@ -27,7 +26,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./falcon-7b
|
output_dir: ./falcon-7b
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
@@ -52,8 +51,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 40
|
warmup_steps: 40
|
||||||
eval_steps: 5
|
evals_per_epoch: 4
|
||||||
save_steps: 43
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
# 1b: tiiuae/falcon-rw-1b
|
# 1b: tiiuae/falcon-rw-1b
|
||||||
# 40b: tiiuae/falcon-40b
|
# 40b: tiiuae/falcon-40b
|
||||||
base_model: tiiuae/falcon-7b
|
base_model: tiiuae/falcon-7b
|
||||||
base_model_config: tiiuae/falcon-7b
|
|
||||||
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
|
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
|
||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
model_type: AutoModelForCausalLM
|
model_type: AutoModelForCausalLM
|
||||||
@@ -19,7 +18,7 @@ datasets:
|
|||||||
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
|
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
|
||||||
type: "alpaca:chat"
|
type: "alpaca:chat"
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
# enable QLoRA
|
# enable QLoRA
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
@@ -41,7 +40,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
@@ -54,7 +53,7 @@ output_dir: ./qlora-out
|
|||||||
# decrease if OOM, increase for max VRAM utilization
|
# decrease if OOM, increase for max VRAM utilization
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
# Optimizer for QLoRA
|
# Optimizer for QLoRA
|
||||||
optimizer: paged_adamw_32bit
|
optimizer: paged_adamw_32bit
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
@@ -81,8 +80,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 5
|
evals_per_epoch: 4
|
||||||
save_steps: 10
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.000001
|
weight_decay: 0.000001
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: tiiuae/falcon-7b
|
base_model: tiiuae/falcon-7b
|
||||||
base_model_config: tiiuae/falcon-7b
|
|
||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
model_type: AutoModelForCausalLM
|
model_type: AutoModelForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
@@ -13,7 +12,7 @@ datasets:
|
|||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca:chat
|
type: alpaca:chat
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
adapter:
|
adapter:
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
@@ -27,7 +26,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./falcon-7b
|
output_dir: ./falcon-7b
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
@@ -52,8 +51,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 40
|
warmup_steps: 40
|
||||||
eval_steps: 5
|
evals_per_epoch: 4
|
||||||
save_steps: 43
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: EleutherAI/gpt-j-6b
|
base_model: EleutherAI/gpt-j-6b
|
||||||
base_model_config: EleutherAI/gpt-j-6b
|
|
||||||
load_in_8bit: false
|
load_in_8bit: false
|
||||||
load_in_4bit: true
|
load_in_4bit: true
|
||||||
strict: false
|
strict: false
|
||||||
@@ -8,7 +7,7 @@ datasets:
|
|||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
@@ -22,7 +21,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
@@ -47,8 +46,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: huggyllama/llama-7b
|
base_model: huggyllama/llama-7b
|
||||||
base_model_config: huggyllama/llama-7b
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
load_in_8bit: false
|
load_in_8bit: false
|
||||||
@@ -20,12 +19,12 @@ lora_fan_in_fan_out: false
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./jeopardy-bot-7b
|
output_dir: ./jeopardy-bot-7b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
@@ -43,8 +42,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 20
|
warmup_steps: 20
|
||||||
eval_steps: 110
|
evals_per_epoch: 4
|
||||||
save_steps: 660
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -9,12 +9,16 @@ gradient_accumulation_steps: 2
|
|||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
|
accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml
|
||||||
|
|
||||||
```
|
```
|
||||||
or
|
or
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
accelerate launch scripts/finetune.py examples/llama-2/lora.yml
|
accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
To launch a full finetuning with 16-bit precision:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
|
||||||
```
|
```
|
||||||
|
|||||||
72
examples/llama-2/fft_optimized.yml
Normal file
72
examples/llama-2/fft_optimized.yml
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
base_model: NousResearch/Llama-2-7b-hf
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
is_llama_derived_model: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./out
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
adapter:
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r:
|
||||||
|
lora_alpha:
|
||||||
|
lora_dropout:
|
||||||
|
lora_target_linear:
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
flash_attn_cross_entropy: false
|
||||||
|
flash_attn_rms_norm: true
|
||||||
|
flash_attn_fuse_qkv: false
|
||||||
|
flash_attn_fuse_mlp: true
|
||||||
|
|
||||||
|
warmup_steps: 100
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed: #deepspeed/zero2.json # multi-gpu only
|
||||||
|
weight_decay: 0.1
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
bos_token: "<s>"
|
||||||
|
eos_token: "</s>"
|
||||||
|
unk_token: "<unk>"
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: TheBloke/Llama-2-7B-GPTQ
|
base_model: TheBloke/Llama-2-7B-GPTQ
|
||||||
base_model_config: TheBloke/Llama-2-7B-GPTQ
|
|
||||||
is_llama_derived_model: false
|
is_llama_derived_model: false
|
||||||
gptq: true
|
gptq: true
|
||||||
gptq_disable_exllama: true
|
gptq_disable_exllama: true
|
||||||
@@ -16,7 +15,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
@@ -33,12 +32,12 @@ lora_target_linear:
|
|||||||
lora_fan_in_fan_out:
|
lora_fan_in_fan_out:
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./model-out
|
output_dir: ./model-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer: adamw_torch
|
optimizer: adamw_torch
|
||||||
adam_beta2: 0.95
|
adam_beta2: 0.95
|
||||||
adam_eps: 0.00001
|
adam_eps: 0.00001
|
||||||
@@ -63,8 +62,8 @@ flash_attention:
|
|||||||
sdp_attention:
|
sdp_attention:
|
||||||
flash_optimum:
|
flash_optimum:
|
||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
eval_steps:
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: NousResearch/Llama-2-7b-hf
|
base_model: NousResearch/Llama-2-7b-hf
|
||||||
base_model_config: NousResearch/Llama-2-7b-hf
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
is_llama_derived_model: true
|
is_llama_derived_model: true
|
||||||
@@ -12,7 +11,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
@@ -30,12 +29,12 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -55,10 +54,10 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_table_max_new_tokens: 128
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: NousResearch/Llama-2-7b-hf
|
base_model: NousResearch/Llama-2-7b-hf
|
||||||
base_model_config: NousResearch/Llama-2-7b-hf
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
is_llama_derived_model: true
|
is_llama_derived_model: true
|
||||||
@@ -12,7 +11,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
@@ -32,12 +31,12 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer: paged_adamw_32bit
|
optimizer: paged_adamw_32bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -57,9 +56,9 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: NousResearch/Llama-2-7b-hf
|
base_model: NousResearch/Llama-2-7b-hf
|
||||||
base_model_config: NousResearch/Llama-2-7b-hf
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
is_llama_derived_model: true
|
is_llama_derived_model: true
|
||||||
@@ -12,7 +11,7 @@ datasets:
|
|||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
output_dir: ./relora-out
|
output_dir: ./relora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
@@ -36,12 +35,12 @@ relora_cpu_offload: false
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 4
|
micro_batch_size: 4
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -61,8 +60,8 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
evals_per_epoch: 4
|
||||||
save_steps: 50
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: PY007/TinyLlama-1.1B-step-50K-105b
|
base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T
|
||||||
base_model_config: PY007/TinyLlama-1.1B-step-50K-105b
|
|
||||||
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
@@ -13,7 +12,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
@@ -30,12 +29,12 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
@@ -55,9 +54,9 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
61
examples/mamba/config.yml
Normal file
61
examples/mamba/config.yml
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
base_model: state-spaces/mamba-2.8b
|
||||||
|
model_type: MambaLMHeadModel
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
tokenizer_config: EleutherAI/gpt-neox-20b
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./out
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 2
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 5e-5
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: true
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention:
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
tokens:
|
||||||
|
save_safetensors: False
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: mistralai/Mistral-7B-v0.1
|
base_model: mistralai/Mistral-7B-v0.1
|
||||||
base_model_config: mistralai/Mistral-7B-v0.1
|
|
||||||
model_type: MistralForCausalLM
|
model_type: MistralForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
is_mistral_derived_model: true
|
is_mistral_derived_model: true
|
||||||
@@ -12,22 +11,23 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
output_dir: ./out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 8192
|
sequence_len: 8192
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
eval_sample_packing: false
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.000005
|
learning_rate: 0.000005
|
||||||
@@ -47,10 +47,10 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
evals_per_epoch: 4
|
||||||
eval_table_size: 5
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_table_max_new_tokens: 128
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
91
examples/mistral/mixtral.yml
Normal file
91
examples/mistral/mixtral.yml
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
base_model: mistralai/Mixtral-8x7B-v0.1
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
||||||
|
unfrozen_parameters:
|
||||||
|
# - lm_head.*
|
||||||
|
# - model.embed_tokens.*
|
||||||
|
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
|
||||||
|
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
|
||||||
|
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
|
||||||
|
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
|
||||||
|
|
||||||
|
model_config:
|
||||||
|
output_router_logits: true
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
#lora_target_modules:
|
||||||
|
# - gate
|
||||||
|
# - q_proj
|
||||||
|
# - k_proj
|
||||||
|
# - v_proj
|
||||||
|
# - o_proj
|
||||||
|
# - w1
|
||||||
|
# - w2
|
||||||
|
# - w3
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed: deepspeed/zero2.json
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: mistralai/Mistral-7B-v0.1
|
base_model: mistralai/Mistral-7B-v0.1
|
||||||
base_model_config: mistralai/Mistral-7B-v0.1
|
|
||||||
model_type: MistralForCausalLM
|
model_type: MistralForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
is_mistral_derived_model: true
|
is_mistral_derived_model: true
|
||||||
@@ -12,7 +11,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.01
|
val_set_size: 0.1
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
@@ -39,7 +38,7 @@ lora_target_modules:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
@@ -63,11 +62,14 @@ logging_steps: 1
|
|||||||
xformers_attention:
|
xformers_attention:
|
||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
evals_per_epoch: 4
|
||||||
eval_table_size: 5
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_table_max_new_tokens: 128
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: mosaicml/mpt-7b
|
base_model: mosaicml/mpt-7b
|
||||||
base_model_config: mosaicml/mpt-7b
|
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
trust_remote_code: true # required for mpt as their model class is not merged into transformers yet
|
trust_remote_code: true # required for mpt as their model class is not merged into transformers yet
|
||||||
load_in_8bit: false
|
load_in_8bit: false
|
||||||
@@ -22,12 +21,12 @@ lora_fan_in_fan_out: false
|
|||||||
wandb_project: mpt-alpaca-7b
|
wandb_project: mpt-alpaca-7b
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./mpt-alpaca-7b
|
output_dir: ./mpt-alpaca-7b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
@@ -45,8 +44,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 20
|
warmup_steps: 20
|
||||||
eval_steps: 110
|
evals_per_epoch: 4
|
||||||
save_steps: 660
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0001
|
weight_decay: 0.0001
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: openlm-research/open_llama_3b_v2
|
base_model: openlm-research/open_llama_3b_v2
|
||||||
base_model_config: openlm-research/open_llama_3b_v2
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
load_in_8bit: false
|
load_in_8bit: false
|
||||||
@@ -24,7 +23,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./openllama-out
|
output_dir: ./openllama-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
@@ -50,8 +49,8 @@ flash_attention: true
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 20
|
warmup_steps: 20
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: openlm-research/open_llama_3b_v2
|
base_model: openlm-research/open_llama_3b_v2
|
||||||
base_model_config: openlm-research/open_llama_3b_v2
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
load_in_8bit: true
|
load_in_8bit: true
|
||||||
@@ -30,7 +29,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
@@ -55,8 +54,8 @@ flash_attention: true
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 20
|
warmup_steps: 20
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: openlm-research/open_llama_3b_v2
|
base_model: openlm-research/open_llama_3b_v2
|
||||||
base_model_config: openlm-research/open_llama_3b_v2
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
load_in_8bit: false
|
load_in_8bit: false
|
||||||
@@ -10,7 +9,7 @@ datasets:
|
|||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 1024
|
sequence_len: 1024
|
||||||
@@ -24,7 +23,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
@@ -49,8 +48,8 @@ flash_attention: true
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 20
|
warmup_steps: 20
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
base_model: microsoft/phi-1_5
|
base_model: microsoft/phi-1_5
|
||||||
base_model_config: microsoft/phi-1_5
|
model_type: PhiForCausalLM
|
||||||
model_type: MixFormerSequentialForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
is_llama_derived_model: false
|
is_llama_derived_model: false
|
||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
@@ -32,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
@@ -60,8 +59,8 @@ xformers_attention:
|
|||||||
flash_attention:
|
flash_attention:
|
||||||
|
|
||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: microsoft/phi-1_5
|
base_model: microsoft/phi-1_5
|
||||||
base_model_config: microsoft/phi-1_5
|
|
||||||
model_type: AutoModelForCausalLM
|
model_type: AutoModelForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
is_llama_derived_model: false
|
is_llama_derived_model: false
|
||||||
@@ -32,7 +31,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
@@ -60,8 +59,8 @@ xformers_attention:
|
|||||||
flash_attention:
|
flash_attention:
|
||||||
|
|
||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: EleutherAI/pythia-12b-deduped
|
base_model: EleutherAI/pythia-12b-deduped
|
||||||
base_model_config: EleutherAI/pythia-12b-deduped
|
|
||||||
base_model_ignore_patterns: pytorch* # prefer safetensors
|
base_model_ignore_patterns: pytorch* # prefer safetensors
|
||||||
model_type: GPTNeoXForCausalLM
|
model_type: GPTNeoXForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
@@ -25,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./pythia-12b
|
output_dir: ./pythia-12b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: EleutherAI/pythia-1.4b-deduped
|
base_model: EleutherAI/pythia-1.4b-deduped
|
||||||
base_model_config: EleutherAI/pythia-1.4b-deduped
|
|
||||||
load_in_8bit: true
|
load_in_8bit: true
|
||||||
datasets:
|
datasets:
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
@@ -19,12 +18,12 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-alpaca-pythia
|
output_dir: ./lora-alpaca-pythia
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 4
|
micro_batch_size: 4
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
learning_rate: 0.00001
|
learning_rate: 0.00001
|
||||||
train_on_inputs: false
|
train_on_inputs: false
|
||||||
group_by_length: false
|
group_by_length: false
|
||||||
@@ -34,5 +33,5 @@ early_stopping_patience:
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
eval_steps: 20
|
evals_per_epoch: 4
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
|
|||||||
68
examples/qwen/lora.yml
Normal file
68
examples/qwen/lora.yml
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
base_model: Qwen/Qwen-7B
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
is_qwen_derived_model: true
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./lora-out
|
||||||
|
|
||||||
|
sequence_len: 2048 # supports up to 8192
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len:
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention:
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
68
examples/qwen/qlora.yml
Normal file
68
examples/qwen/qlora.yml
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
base_model: Qwen/Qwen-7B
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
is_qwen_derived_model: true
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./lora-out
|
||||||
|
|
||||||
|
sequence_len: 2048 # supports up to 8192
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len:
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention:
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||||
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
|
||||||
model_type: GPTNeoXForCausalLM
|
model_type: GPTNeoXForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
trust_remote_code:
|
trust_remote_code:
|
||||||
@@ -23,12 +22,12 @@ lora_fan_in_fan_out: false
|
|||||||
wandb_project: redpajama-alpaca-3b
|
wandb_project: redpajama-alpaca-3b
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./redpajama-alpaca-3b
|
output_dir: ./redpajama-alpaca-3b
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
@@ -46,8 +45,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 20
|
warmup_steps: 20
|
||||||
eval_steps: 110
|
evals_per_epoch: 4
|
||||||
save_steps: 660
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0001
|
weight_decay: 0.0001
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: replit/replit-code-v1-3b
|
base_model: replit/replit-code-v1-3b
|
||||||
base_model_config: replit/replit-code-v1-3b
|
|
||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
load_in_8bit: false
|
load_in_8bit: false
|
||||||
datasets:
|
datasets:
|
||||||
@@ -22,12 +21,12 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project: lora-replit
|
wandb_project: lora-replit
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-replit
|
output_dir: ./lora-replit
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
optimizer:
|
optimizer:
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
lr_scheduler:
|
lr_scheduler:
|
||||||
@@ -46,8 +45,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 20
|
warmup_steps: 20
|
||||||
eval_steps: 50
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0
|
weight_decay: 0
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
# An example finetuning Saleforce's XGen-7b model with 8k context using qlora
|
# An example finetuning Saleforce's XGen-7b model with 8k context using qlora
|
||||||
# on Tim Dettmer's Guanaco dataset.
|
# on Tim Dettmer's Guanaco dataset.
|
||||||
base_model: Salesforce/xgen-7b-8k-base
|
base_model: Salesforce/xgen-7b-8k-base
|
||||||
base_model_config: Salesforce/xgen-7b-8k-base
|
|
||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
model_type: AutoModelForCausalLM
|
model_type: AutoModelForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
@@ -17,7 +16,7 @@ datasets:
|
|||||||
- openassistant_best_replies_train.jsonl
|
- openassistant_best_replies_train.jsonl
|
||||||
type: "completion"
|
type: "completion"
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.05
|
||||||
# enable QLoRA
|
# enable QLoRA
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
@@ -39,7 +38,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
@@ -52,7 +51,7 @@ output_dir: ./qlora-out
|
|||||||
# decrease if OOM, increase for max VRAM utilization
|
# decrease if OOM, increase for max VRAM utilization
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
num_epochs: 3
|
num_epochs: 4
|
||||||
# Optimizer for QLoRA
|
# Optimizer for QLoRA
|
||||||
optimizer: paged_adamw_32bit
|
optimizer: paged_adamw_32bit
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
@@ -79,8 +78,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 50
|
evals_per_epoch: 4
|
||||||
save_steps: 50
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
64
examples/yayi2-30b/fft.yml
Normal file
64
examples/yayi2-30b/fft.yml
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
base_model: models/yayi2-30b
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
is_mistral_derived_model: false
|
||||||
|
trust_remote_code: true
|
||||||
|
model_revision: refs/pr/5
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./out
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
eval_sample_packing: false
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.000005
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed: deepspeed/zero3_cpu.json
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
bos_token: "<s>"
|
||||||
|
eos_token: "</s>"
|
||||||
|
unk_token: "<unk>"
|
||||||
76
examples/yayi2-30b/qlora.yml
Normal file
76
examples/yayi2-30b/qlora.yml
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
base_model: wenge-research/yayi2-30b
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
is_mistral_derived_model: false
|
||||||
|
trust_remote_code: true
|
||||||
|
model_revision: refs/pr/5
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.1
|
||||||
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048 # Fits in 40gb VRAM. Can easily do 4096 in A100 80 or a A6000
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
lora_target_modules:
|
||||||
|
|
||||||
|
wandb_project: yayi2
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0005
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: false
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
bos_token: "<s>"
|
||||||
|
eos_token: "</s>"
|
||||||
|
unk_token: "<unk>"
|
||||||
5
examples/yi-34B-chat/README.md
Normal file
5
examples/yi-34B-chat/README.md
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Overview
|
||||||
|
|
||||||
|
This is an example of a Yi-34B-Chat configuration. It demonstrates that it is possible to finetune a 34B model on a GPU with 24GB of VRAM.
|
||||||
|
|
||||||
|
Tested on an RTX 4090 with `python -m axolotl.cli.train examples/mistral/qlora.yml`, a single epoch of finetuning on the alpaca dataset using qlora runs in 47 mins, using 97% of available memory.
|
||||||
76
examples/yi-34B-chat/qlora.yml
Normal file
76
examples/yi-34B-chat/qlora.yml
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
base_model: 01-ai/Yi-34B-Chat
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
is_mistral_derived_model: false
|
||||||
|
is_llama_derived_model: true
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
sequence_len: 1024
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
flash_attention: true
|
||||||
|
special_tokens:
|
||||||
|
bos_token: "<|startoftext|>"
|
||||||
|
eos_token: "<|endoftext|>"
|
||||||
|
unk_token: "<unk>"
|
||||||
|
|
||||||
|
# Data
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
warmup_steps: 10
|
||||||
|
|
||||||
|
# Iterations
|
||||||
|
num_epochs: 1
|
||||||
|
|
||||||
|
# Evaluation
|
||||||
|
val_set_size: 0.1
|
||||||
|
evals_per_epoch: 5
|
||||||
|
eval_table_size:
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
|
eval_sample_packing: false
|
||||||
|
eval_batch_size: 1
|
||||||
|
|
||||||
|
# LoRA
|
||||||
|
output_dir: ./qlora-out
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
lora_target_modules:
|
||||||
|
|
||||||
|
# Sampling
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
# Batching
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
gradient_checkpointing: true
|
||||||
|
|
||||||
|
# wandb
|
||||||
|
wandb_project:
|
||||||
|
|
||||||
|
# Optimizer
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
# Misc
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
BIN
image/sticker_fixed.png
Normal file
BIN
image/sticker_fixed.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 370 KiB |
@@ -1,23 +1,22 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
|
||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
torch==2.0.1
|
auto-gptq==0.5.1
|
||||||
auto-gptq
|
|
||||||
packaging
|
packaging
|
||||||
peft @ git+https://github.com/huggingface/peft.git
|
peft==0.6.0
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git@bd6205919aad4d3a2300a39a98a642f1cc3a5348
|
transformers==4.36.2
|
||||||
|
tokenizers==0.15.0
|
||||||
bitsandbytes>=0.41.1
|
bitsandbytes>=0.41.1
|
||||||
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
|
accelerate==0.24.1
|
||||||
deepspeed
|
deepspeed
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
datasets
|
datasets>=2.15.0
|
||||||
flash-attn>=2.3.0
|
flash-attn==2.3.3
|
||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
xformers>=0.0.22
|
xformers==0.0.22
|
||||||
optimum
|
optimum==1.13.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
colorama
|
colorama
|
||||||
numba
|
numba
|
||||||
@@ -30,4 +29,11 @@ scipy
|
|||||||
scikit-learn==1.2.2
|
scikit-learn==1.2.2
|
||||||
pynvml
|
pynvml
|
||||||
art
|
art
|
||||||
fschat==0.2.29
|
fschat==0.2.34
|
||||||
|
gradio==3.50.2
|
||||||
|
tensorboard
|
||||||
|
|
||||||
|
# remote filesystems
|
||||||
|
s3fs
|
||||||
|
gcsfs
|
||||||
|
# adlfs
|
||||||
|
|||||||
@@ -45,8 +45,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
if parsed_cli_args.prepare_ds_only:
|
|
||||||
return
|
|
||||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
20
setup.py
20
setup.py
@@ -1,5 +1,7 @@
|
|||||||
"""setup.py for axolotl"""
|
"""setup.py for axolotl"""
|
||||||
|
|
||||||
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
|
||||||
@@ -22,12 +24,13 @@ def parse_requirements():
|
|||||||
# Handle standard packages
|
# Handle standard packages
|
||||||
_install_requires.append(line)
|
_install_requires.append(line)
|
||||||
|
|
||||||
# TODO(wing) remove once xformers release supports torch 2.1.0
|
try:
|
||||||
if "torch==2.1.0" in _install_requires:
|
torch_version = version("torch")
|
||||||
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
|
if torch_version.startswith("2.1.1"):
|
||||||
_install_requires.append(
|
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||||
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
|
_install_requires.append("xformers==0.0.23")
|
||||||
)
|
except PackageNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
return _install_requires, _dependency_links
|
return _install_requires, _dependency_links
|
||||||
|
|
||||||
@@ -46,10 +49,13 @@ setup(
|
|||||||
dependency_links=dependency_links,
|
dependency_links=dependency_links,
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn>=2.2.1",
|
"flash-attn==2.3.3",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed",
|
"deepspeed",
|
||||||
],
|
],
|
||||||
|
"mamba-ssm": [
|
||||||
|
"mamba-ssm==1.0.1",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,8 +6,10 @@ import os
|
|||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from threading import Thread
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@@ -16,7 +18,7 @@ from accelerate.commands.config import config_args
|
|||||||
from art import text2art
|
from art import text2art
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||||
from transformers import GenerationConfig, TextStreamer
|
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
@@ -27,6 +29,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
|
from axolotl.utils.trainer import prepare_optim_env
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
@@ -44,7 +47,7 @@ def print_axolotl_text_art(suffix=None):
|
|||||||
ascii_text = " axolotl"
|
ascii_text = " axolotl"
|
||||||
if suffix:
|
if suffix:
|
||||||
ascii_text += f" x {suffix}"
|
ascii_text += f" x {suffix}"
|
||||||
ascii_art = text2art(" axolotl", font=font)
|
ascii_art = text2art(ascii_text, font=font)
|
||||||
|
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
print(ascii_art)
|
print(ascii_art)
|
||||||
@@ -69,7 +72,7 @@ def do_merge_lora(
|
|||||||
|
|
||||||
LOG.info("running merge of LoRA with base model")
|
LOG.info("running merge of LoRA with base model")
|
||||||
model = model.merge_and_unload()
|
model = model.merge_and_unload()
|
||||||
model.to(dtype=torch.float16)
|
model.to(dtype=cfg.torch_dtype)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||||
@@ -100,15 +103,7 @@ def do_inference(
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.landmark_attention:
|
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
|
||||||
|
|
||||||
set_model_mem_id(model, tokenizer)
|
|
||||||
model.set_mem_cache_args(
|
|
||||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
|
||||||
)
|
|
||||||
|
|
||||||
model = model.to(cfg.device)
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
@@ -153,6 +148,83 @@ def do_inference(
|
|||||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||||
|
|
||||||
|
|
||||||
|
def do_inference_gradio(
|
||||||
|
*,
|
||||||
|
cfg: DictDefault,
|
||||||
|
cli_args: TrainerCliArgs,
|
||||||
|
):
|
||||||
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||||
|
prompter = cli_args.prompter
|
||||||
|
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||||
|
|
||||||
|
for token, symbol in default_tokens.items():
|
||||||
|
# If the token isn't already specified in the config, add it
|
||||||
|
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||||
|
tokenizer.add_special_tokens({token: symbol})
|
||||||
|
|
||||||
|
prompter_module = None
|
||||||
|
if prompter:
|
||||||
|
prompter_module = getattr(
|
||||||
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
|
)
|
||||||
|
|
||||||
|
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
|
||||||
|
def generate(instruction):
|
||||||
|
if not instruction:
|
||||||
|
return
|
||||||
|
if prompter_module:
|
||||||
|
# pylint: disable=stop-iteration-return
|
||||||
|
prompt: str = next(
|
||||||
|
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = instruction.strip()
|
||||||
|
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
generation_config = GenerationConfig(
|
||||||
|
repetition_penalty=1.1,
|
||||||
|
max_new_tokens=1024,
|
||||||
|
temperature=0.9,
|
||||||
|
top_p=0.95,
|
||||||
|
top_k=40,
|
||||||
|
bos_token_id=tokenizer.bos_token_id,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
do_sample=True,
|
||||||
|
use_cache=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_attentions=False,
|
||||||
|
output_hidden_states=False,
|
||||||
|
output_scores=False,
|
||||||
|
)
|
||||||
|
streamer = TextIteratorStreamer(tokenizer)
|
||||||
|
generation_kwargs = {
|
||||||
|
"inputs": batch["input_ids"].to(cfg.device),
|
||||||
|
"generation_config": generation_config,
|
||||||
|
"streamer": streamer,
|
||||||
|
}
|
||||||
|
|
||||||
|
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
all_text = ""
|
||||||
|
|
||||||
|
for new_text in streamer:
|
||||||
|
all_text += new_text
|
||||||
|
yield all_text
|
||||||
|
|
||||||
|
demo = gr.Interface(
|
||||||
|
fn=generate,
|
||||||
|
inputs="textbox",
|
||||||
|
outputs="text",
|
||||||
|
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||||
|
)
|
||||||
|
demo.queue().launch(show_api=False, share=True)
|
||||||
|
|
||||||
|
|
||||||
def choose_config(path: Path):
|
def choose_config(path: Path):
|
||||||
yaml_files = list(path.glob("*.yml"))
|
yaml_files = list(path.glob("*.yml"))
|
||||||
|
|
||||||
@@ -209,6 +281,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
prepare_optim_env(cfg)
|
||||||
|
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
@@ -222,7 +296,9 @@ def load_datasets(
|
|||||||
) -> TrainDatasetMeta:
|
) -> TrainDatasetMeta:
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
|
||||||
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
|
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||||
|
cfg, tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
if cli_args.debug or cfg.debug:
|
if cli_args.debug or cfg.debug:
|
||||||
LOG.info("check_dataset_labels...")
|
LOG.info("check_dataset_labels...")
|
||||||
@@ -238,6 +314,10 @@ def load_datasets(
|
|||||||
text_only=cli_args.debug_text_only,
|
text_only=cli_args.debug_text_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
LOG.info("printing prompters...")
|
||||||
|
for prompter in prompters:
|
||||||
|
LOG.info(prompter)
|
||||||
|
|
||||||
return TrainDatasetMeta(
|
return TrainDatasetMeta(
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
|
|||||||
@@ -6,11 +6,16 @@ from pathlib import Path
|
|||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
|
from axolotl.cli import (
|
||||||
|
do_inference,
|
||||||
|
do_inference_gradio,
|
||||||
|
load_cfg,
|
||||||
|
print_axolotl_text_art,
|
||||||
|
)
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
@@ -21,7 +26,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
)
|
)
|
||||||
parsed_cli_args.inference = True
|
parsed_cli_args.inference = True
|
||||||
|
|
||||||
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
if gradio:
|
||||||
|
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
else:
|
||||||
|
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -18,7 +18,15 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
parsed_cli_args.merge_lora = True
|
parsed_cli_args.merge_lora = True
|
||||||
parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)
|
|
||||||
|
parsed_cfg = load_cfg(
|
||||||
|
config,
|
||||||
|
merge_lora=True,
|
||||||
|
load_in_8bit=False,
|
||||||
|
load_in_4bit=False,
|
||||||
|
flash_attention=False,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|
||||||
|
|||||||
53
src/axolotl/cli/preprocess.py
Normal file
53
src/axolotl/cli/preprocess.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""
|
||||||
|
CLI to run training on a model
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import transformers
|
||||||
|
from colorama import Fore
|
||||||
|
|
||||||
|
from axolotl.cli import (
|
||||||
|
check_accelerate_default_config,
|
||||||
|
check_user_token,
|
||||||
|
load_cfg,
|
||||||
|
load_datasets,
|
||||||
|
print_axolotl_text_art,
|
||||||
|
)
|
||||||
|
from axolotl.common.cli import PreprocessCliArgs
|
||||||
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||||
|
|
||||||
|
|
||||||
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
print_axolotl_text_art()
|
||||||
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
check_accelerate_default_config()
|
||||||
|
check_user_token()
|
||||||
|
parser = transformers.HfArgumentParser((PreprocessCliArgs))
|
||||||
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
|
return_remaining_strings=True
|
||||||
|
)
|
||||||
|
if not parsed_cfg.dataset_prepared_path:
|
||||||
|
msg = (
|
||||||
|
Fore.RED
|
||||||
|
+ "preprocess CLI called without dataset_prepared_path set, "
|
||||||
|
+ f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
|
||||||
|
+ Fore.RESET
|
||||||
|
)
|
||||||
|
LOG.warning(msg)
|
||||||
|
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
|
||||||
|
_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
LOG.info(
|
||||||
|
Fore.GREEN
|
||||||
|
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
||||||
|
+ Fore.RESET
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(do_cli)
|
||||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
|||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
from colorama import Fore
|
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
check_accelerate_default_config,
|
check_accelerate_default_config,
|
||||||
@@ -16,7 +15,6 @@ from axolotl.cli import (
|
|||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.train")
|
LOG = logging.getLogger("axolotl.cli.train")
|
||||||
@@ -24,26 +22,15 @@ LOG = logging.getLogger("axolotl.cli.train")
|
|||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
check_user_token()
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
|
|
||||||
msg = (
|
|
||||||
Fore.RED
|
|
||||||
+ "--prepare_ds_only called without dataset_prepared_path set."
|
|
||||||
+ Fore.RESET
|
|
||||||
)
|
|
||||||
LOG.warning(msg)
|
|
||||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
|
||||||
|
|
||||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
if parsed_cli_args.prepare_ds_only:
|
|
||||||
return
|
|
||||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -25,11 +25,22 @@ class TrainerCliArgs:
|
|||||||
debug_num_examples: int = field(default=5)
|
debug_num_examples: int = field(default=5)
|
||||||
inference: bool = field(default=False)
|
inference: bool = field(default=False)
|
||||||
merge_lora: bool = field(default=False)
|
merge_lora: bool = field(default=False)
|
||||||
prepare_ds_only: bool = field(default=False)
|
|
||||||
prompter: Optional[str] = field(default=None)
|
prompter: Optional[str] = field(default=None)
|
||||||
shard: bool = field(default=False)
|
shard: bool = field(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PreprocessCliArgs:
|
||||||
|
"""
|
||||||
|
dataclass representing arguments for preprocessing only
|
||||||
|
"""
|
||||||
|
|
||||||
|
debug: bool = field(default=False)
|
||||||
|
debug_text_only: bool = field(default=False)
|
||||||
|
debug_num_examples: int = field(default=1)
|
||||||
|
prompter: Optional[str] = field(default=None)
|
||||||
|
|
||||||
|
|
||||||
def load_model_and_tokenizer(
|
def load_model_and_tokenizer(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
|
|||||||
0
src/axolotl/core/__init__.py
Normal file
0
src/axolotl/core/__init__.py
Normal file
821
src/axolotl/core/trainer_builder.py
Normal file
821
src/axolotl/core/trainer_builder.py
Normal file
@@ -0,0 +1,821 @@
|
|||||||
|
"""
|
||||||
|
Builder for the training args and trainer
|
||||||
|
"""
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
|
from abc import abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from functools import wraps
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from datasets import Dataset
|
||||||
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
|
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||||
|
from transformers.trainer_utils import seed_worker
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||||
|
from axolotl.utils.callbacks import (
|
||||||
|
EvalFirstStepCallback,
|
||||||
|
GPUStatsCallback,
|
||||||
|
LossWatchDogCallback,
|
||||||
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
|
SaveBetterTransformerModelCallback,
|
||||||
|
bench_eval_callback_factory,
|
||||||
|
log_prediction_callback_factory,
|
||||||
|
)
|
||||||
|
from axolotl.utils.collators import (
|
||||||
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
|
MambaDataCollator,
|
||||||
|
)
|
||||||
|
from axolotl.utils.samplers import MultipackBatchSampler
|
||||||
|
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlTrainingArguments(TrainingArguments):
|
||||||
|
"""
|
||||||
|
Extend the base TrainingArguments for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_type: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "HF model configuration model_type."}
|
||||||
|
)
|
||||||
|
lr_quadratic_warmup: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||||
|
)
|
||||||
|
sample_packing: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Use sample packing for efficient training."},
|
||||||
|
)
|
||||||
|
eval_sample_packing: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Use sample packing for efficient evals."},
|
||||||
|
)
|
||||||
|
sample_packing_efficiency: float = field(
|
||||||
|
default=1.0,
|
||||||
|
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||||
|
)
|
||||||
|
max_seq_length: int = field(
|
||||||
|
default=2048,
|
||||||
|
metadata={"help": "The maximum sequence length the model can handle"},
|
||||||
|
)
|
||||||
|
sample_packing_seq_len_multiplier: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "the multiplier for the max len for packed sequences"},
|
||||||
|
)
|
||||||
|
relora_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how often to reset for ReLoRA"},
|
||||||
|
)
|
||||||
|
relora_warmup_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||||
|
)
|
||||||
|
bench_split: Optional[str] = field(
|
||||||
|
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||||
|
)
|
||||||
|
bench_dataset: Optional[str] = field(
|
||||||
|
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||||
|
metadata={
|
||||||
|
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
do_bench_eval: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||||
|
)
|
||||||
|
max_bench_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
bench_source_max_len: int = field(
|
||||||
|
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||||
|
)
|
||||||
|
dataloader_prefetch_factor: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlTrainer(Trainer):
|
||||||
|
"""
|
||||||
|
Extend the base Trainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: AxolotlTrainingArguments
|
||||||
|
tag_names = ["axolotl"]
|
||||||
|
|
||||||
|
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
||||||
|
self.num_epochs = num_epochs
|
||||||
|
self.bench_data_collator = bench_data_collator
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||||
|
passed as an argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_training_steps (int): The number of training steps to do.
|
||||||
|
optimizer (torch.optim.Optimizer): The training optimizer
|
||||||
|
"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||||
|
# fmt: on
|
||||||
|
if (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.lr_quadratic_warmup is True
|
||||||
|
):
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return super().create_scheduler(num_training_steps, optimizer)
|
||||||
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
|
if self.args.sample_packing:
|
||||||
|
return MultipackBatchSampler(
|
||||||
|
RandomSampler(self.train_dataset),
|
||||||
|
self.args.train_batch_size,
|
||||||
|
drop_last=True,
|
||||||
|
batch_max_len=self._train_batch_size * self.args.max_seq_length,
|
||||||
|
lengths=(
|
||||||
|
self.train_dataset.data.column("position_ids")
|
||||||
|
.to_pandas()
|
||||||
|
.apply(lambda x: x[-1] + 1)
|
||||||
|
.values
|
||||||
|
),
|
||||||
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
|
)
|
||||||
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
|
def _get_eval_sampler(
|
||||||
|
self, eval_dataset: Dataset
|
||||||
|
) -> Optional[torch.utils.data.Sampler]:
|
||||||
|
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||||
|
return MultipackBatchSampler(
|
||||||
|
SequentialSampler(eval_dataset),
|
||||||
|
self.args.per_device_eval_batch_size,
|
||||||
|
drop_last=True,
|
||||||
|
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
|
||||||
|
lengths=(
|
||||||
|
eval_dataset.data.column("position_ids")
|
||||||
|
.to_pandas()
|
||||||
|
.apply(lambda x: x[-1] + 1)
|
||||||
|
.values
|
||||||
|
),
|
||||||
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
|
)
|
||||||
|
return super()._get_eval_sampler(eval_dataset)
|
||||||
|
|
||||||
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
|
if self.args.sample_packing:
|
||||||
|
train_dataset = self.train_dataset
|
||||||
|
train_dataset = train_dataset.remove_columns(["length"])
|
||||||
|
data_collator = self.data_collator
|
||||||
|
dataloader_params = {
|
||||||
|
"batch_size": self._train_batch_size,
|
||||||
|
"collate_fn": data_collator,
|
||||||
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
|
}
|
||||||
|
if self.args.dataloader_prefetch_factor:
|
||||||
|
dataloader_params[
|
||||||
|
"prefetch_factor"
|
||||||
|
] = self.args.dataloader_prefetch_factor
|
||||||
|
|
||||||
|
sampler = self._get_train_sampler()
|
||||||
|
if isinstance(sampler, BatchSampler):
|
||||||
|
dataloader_params["batch_sampler"] = sampler
|
||||||
|
del dataloader_params["batch_size"]
|
||||||
|
else:
|
||||||
|
dataloader_params["sampler"] = sampler
|
||||||
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
dataloader_params["worker_init_fn"] = seed_worker
|
||||||
|
|
||||||
|
self.accelerator.even_batches = False
|
||||||
|
return self.accelerator.prepare_data_loader(
|
||||||
|
DataLoader(train_dataset, **dataloader_params)
|
||||||
|
)
|
||||||
|
return super().get_train_dataloader()
|
||||||
|
|
||||||
|
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||||
|
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||||
|
eval_dataset = (
|
||||||
|
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||||
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
|
data_collator = self.data_collator
|
||||||
|
dataloader_params = {
|
||||||
|
"batch_size": self.args.eval_batch_size,
|
||||||
|
"collate_fn": data_collator,
|
||||||
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
|
}
|
||||||
|
if self.args.dataloader_prefetch_factor:
|
||||||
|
dataloader_params[
|
||||||
|
"prefetch_factor"
|
||||||
|
] = self.args.dataloader_prefetch_factor
|
||||||
|
|
||||||
|
if isinstance(eval_sampler, BatchSampler):
|
||||||
|
dataloader_params["batch_sampler"] = eval_sampler
|
||||||
|
del dataloader_params["batch_size"]
|
||||||
|
else:
|
||||||
|
dataloader_params["sampler"] = eval_sampler
|
||||||
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
|
self.accelerator.even_batches = False
|
||||||
|
return self.accelerator.prepare_data_loader(
|
||||||
|
DataLoader(eval_dataset, **dataloader_params)
|
||||||
|
)
|
||||||
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
|
def _get_bench_sampler(
|
||||||
|
self, bench_dataset: Dataset
|
||||||
|
) -> Optional[torch.utils.data.Sampler]:
|
||||||
|
if self.args.world_size <= 1:
|
||||||
|
return SequentialSampler(bench_dataset)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_bench_dataloader(
|
||||||
|
self,
|
||||||
|
bench_dataset: Dataset,
|
||||||
|
) -> DataLoader:
|
||||||
|
dataloader_params = {
|
||||||
|
"batch_size": self.args.eval_batch_size,
|
||||||
|
"collate_fn": self.bench_data_collator,
|
||||||
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
|
}
|
||||||
|
if self.args.dataloader_prefetch_factor:
|
||||||
|
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||||
|
|
||||||
|
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
|
||||||
|
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
||||||
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
|
return DataLoader(bench_dataset, **dataloader_params)
|
||||||
|
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||||
|
|
||||||
|
def compute_loss(self, model, inputs, return_outputs=False):
|
||||||
|
# use one's weighted cross entropy loss calc
|
||||||
|
# if self.args.sample_packing:
|
||||||
|
# labels = inputs.pop("labels")
|
||||||
|
# outputs = model(**inputs)
|
||||||
|
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||||
|
# return (loss, outputs) if return_outputs else loss
|
||||||
|
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||||
|
|
||||||
|
def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
|
||||||
|
if isinstance(tag_names, str):
|
||||||
|
tag_names = [tag_names]
|
||||||
|
|
||||||
|
if kwargs is not None:
|
||||||
|
if "tags" not in kwargs:
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||||
|
kwargs["tags"].extend(tag_names)
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||||
|
tag_names.append(kwargs["tags"])
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
@wraps(Trainer.push_to_hub)
|
||||||
|
def push_to_hub(self, *args, **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
|
"""
|
||||||
|
kwargs = self._sanitize_kwargs_for_tagging(
|
||||||
|
tag_names=self.tag_names, kwargs=kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
|
"""
|
||||||
|
Mamba specific trainer to handle loss calculation
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "mamba"]
|
||||||
|
|
||||||
|
def compute_loss(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
inputs,
|
||||||
|
return_outputs=False, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
input_ids = inputs.pop("input_ids")
|
||||||
|
lm_logits = model(input_ids).logits
|
||||||
|
|
||||||
|
labels = input_ids.to(lm_logits.device)
|
||||||
|
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||||
|
labels = labels[:, 1:].contiguous()
|
||||||
|
|
||||||
|
loss_fct = torch.nn.CrossEntropyLoss()
|
||||||
|
lm_loss = loss_fct(
|
||||||
|
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
return lm_loss
|
||||||
|
|
||||||
|
|
||||||
|
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||||
|
"""
|
||||||
|
Trainer subclass that uses the OneCycleLR scheduler
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "onecycle"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.lr_scheduler = None
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self,
|
||||||
|
num_training_steps: int,
|
||||||
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
|
):
|
||||||
|
optimizer = self.optimizer if optimizer is None else optimizer
|
||||||
|
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||||
|
pct_start = num_warmup_steps / num_training_steps
|
||||||
|
|
||||||
|
self.lr_scheduler = OneCycleLR(
|
||||||
|
optimizer,
|
||||||
|
max_lr=self.args.learning_rate,
|
||||||
|
total_steps=num_training_steps,
|
||||||
|
pct_start=pct_start,
|
||||||
|
div_factor=6,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class ReLoRATrainer(AxolotlTrainer):
|
||||||
|
"""
|
||||||
|
Trainer subclass that uses the OneCycleLR scheduler
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "relora"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.lr_scheduler = None
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self,
|
||||||
|
num_training_steps: int,
|
||||||
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
|
):
|
||||||
|
optimizer = self.optimizer if optimizer is None else optimizer
|
||||||
|
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||||
|
|
||||||
|
if self.args.relora_steps:
|
||||||
|
warmup_steps = (
|
||||||
|
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||||
|
)
|
||||||
|
self.lr_scheduler = ReLoRAScheduler(
|
||||||
|
optimizer,
|
||||||
|
lr_scheduler,
|
||||||
|
self.args.relora_steps,
|
||||||
|
warmup_steps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.lr_scheduler = lr_scheduler
|
||||||
|
|
||||||
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class TrainerBuilderBase(abc.ABC):
|
||||||
|
"""
|
||||||
|
Base class for trainer builder
|
||||||
|
"""
|
||||||
|
|
||||||
|
_train_dataset = None
|
||||||
|
_eval_dataset = None
|
||||||
|
|
||||||
|
def __init__(self, cfg, model, tokenizer):
|
||||||
|
self.cfg = cfg
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def train_dataset(self):
|
||||||
|
return self._train_dataset
|
||||||
|
|
||||||
|
@train_dataset.setter
|
||||||
|
def train_dataset(self, dataset):
|
||||||
|
self._train_dataset = dataset
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eval_dataset(self):
|
||||||
|
return self._eval_dataset
|
||||||
|
|
||||||
|
@eval_dataset.setter
|
||||||
|
def eval_dataset(self, dataset):
|
||||||
|
self._eval_dataset = dataset
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def build(self, total_num_steps):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_callbacks(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
|
"""
|
||||||
|
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||||
|
"""
|
||||||
|
Build the HuggingFace training args/trainer for Causal models
|
||||||
|
"""
|
||||||
|
|
||||||
|
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
||||||
|
# TODO
|
||||||
|
return training_arguments_kwargs
|
||||||
|
|
||||||
|
def hook_post_create_training_args(self, training_arguments):
|
||||||
|
# TODO
|
||||||
|
return training_arguments
|
||||||
|
|
||||||
|
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
|
||||||
|
# TODO
|
||||||
|
return trainer_kwargs, trainer_cls
|
||||||
|
|
||||||
|
def hook_post_create_trainer(self, trainer):
|
||||||
|
# TODO
|
||||||
|
return trainer
|
||||||
|
|
||||||
|
def get_callbacks(self):
|
||||||
|
callbacks = []
|
||||||
|
callbacks.append(GPUStatsCallback(self.cfg))
|
||||||
|
callbacks.append(EvalFirstStepCallback)
|
||||||
|
|
||||||
|
if self.cfg.relora_steps:
|
||||||
|
callbacks.append(ReLoRACallback(self.cfg))
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(self.model, "use_bettertransformer")
|
||||||
|
and self.model.use_bettertransformer is True
|
||||||
|
):
|
||||||
|
callbacks.append(SaveBetterTransformerModelCallback)
|
||||||
|
|
||||||
|
if self.cfg.use_wandb:
|
||||||
|
callbacks.append(
|
||||||
|
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.cfg.loss_watchdog_threshold is not None:
|
||||||
|
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||||
|
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
|
callbacks = []
|
||||||
|
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
||||||
|
LogPredictionCallback = log_prediction_callback_factory(
|
||||||
|
trainer, self.tokenizer
|
||||||
|
)
|
||||||
|
callbacks.append(LogPredictionCallback(self.cfg))
|
||||||
|
|
||||||
|
if self.cfg.do_bench_eval:
|
||||||
|
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
||||||
|
|
||||||
|
if self.cfg.early_stopping_patience:
|
||||||
|
early_stop_cb = EarlyStoppingCallback(
|
||||||
|
self.cfg.early_stopping_patience,
|
||||||
|
)
|
||||||
|
callbacks.append(early_stop_cb)
|
||||||
|
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
def _get_trainer_cls(self):
|
||||||
|
if self.cfg.lr_scheduler == "one_cycle" and (
|
||||||
|
self.cfg.fsdp or self.cfg.adapter == "qlora"
|
||||||
|
):
|
||||||
|
return OneCycleLRSchedulerTrainer
|
||||||
|
if self.cfg.relora_steps:
|
||||||
|
return ReLoRATrainer
|
||||||
|
if self.cfg.model_config_type == "mamba":
|
||||||
|
return AxolotlMambaTrainer
|
||||||
|
return AxolotlTrainer
|
||||||
|
|
||||||
|
def build(self, total_num_steps):
|
||||||
|
warmup_steps = None
|
||||||
|
if self.cfg.warmup_steps is not None:
|
||||||
|
warmup_steps = self.cfg.warmup_steps
|
||||||
|
elif self.cfg.warmup_ratio is not None:
|
||||||
|
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||||
|
else:
|
||||||
|
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||||
|
|
||||||
|
logging_steps = (
|
||||||
|
self.cfg.logging_steps
|
||||||
|
if self.cfg.logging_steps is not None
|
||||||
|
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
training_arguments_kwargs = {}
|
||||||
|
if self.cfg.bf16 == "full":
|
||||||
|
training_arguments_kwargs["bf16_full_eval"] = True
|
||||||
|
else:
|
||||||
|
training_arguments_kwargs["bf16"] = self.cfg.bf16
|
||||||
|
training_arguments_kwargs["fp16"] = (
|
||||||
|
self.cfg.fp16 and not self.cfg.bf16
|
||||||
|
) or False
|
||||||
|
training_arguments_kwargs["tf32"] = self.cfg.tf32
|
||||||
|
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||||
|
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||||
|
|
||||||
|
if self.cfg.seed:
|
||||||
|
training_arguments_kwargs["seed"] = self.cfg.seed
|
||||||
|
|
||||||
|
if self.cfg.gradient_checkpointing:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"gradient_checkpointing"
|
||||||
|
] = self.cfg.gradient_checkpointing
|
||||||
|
if self.cfg.fsdp:
|
||||||
|
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||||
|
if self.cfg.fsdp_config:
|
||||||
|
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
|
||||||
|
|
||||||
|
# deepspeed
|
||||||
|
if self.cfg.deepspeed:
|
||||||
|
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
|
||||||
|
|
||||||
|
if self.cfg.lr_quadratic_warmup is not None:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"lr_quadratic_warmup"
|
||||||
|
] = self.cfg.lr_quadratic_warmup
|
||||||
|
|
||||||
|
if self.cfg.adam_beta1:
|
||||||
|
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
|
||||||
|
if self.cfg.adam_beta2:
|
||||||
|
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
|
||||||
|
if self.cfg.adam_epsilon:
|
||||||
|
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
|
||||||
|
if self.cfg.max_grad_norm:
|
||||||
|
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
|
||||||
|
|
||||||
|
if self.cfg.hub_model_id:
|
||||||
|
training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
||||||
|
training_arguments_kwargs["push_to_hub"] = True
|
||||||
|
training_arguments_kwargs["hub_private_repo"] = True
|
||||||
|
|
||||||
|
if self.cfg.hub_strategy:
|
||||||
|
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||||
|
|
||||||
|
if self.cfg.save_safetensors is not None:
|
||||||
|
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||||
|
|
||||||
|
if self.cfg.sample_packing_eff_est:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"sample_packing_efficiency"
|
||||||
|
] = self.cfg.sample_packing_eff_est
|
||||||
|
|
||||||
|
if self.cfg.dataloader_pin_memory is not None:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"dataloader_pin_memory"
|
||||||
|
] = self.cfg.dataloader_pin_memory
|
||||||
|
if self.cfg.dataloader_num_workers is not None:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"dataloader_num_workers"
|
||||||
|
] = self.cfg.dataloader_num_workers
|
||||||
|
if self.cfg.dataloader_prefetch_factor is not None:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"dataloader_prefetch_factor"
|
||||||
|
] = self.cfg.dataloader_prefetch_factor
|
||||||
|
|
||||||
|
if self.cfg.val_set_size == 0:
|
||||||
|
# no eval set, so don't eval
|
||||||
|
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||||
|
elif self.cfg.eval_steps:
|
||||||
|
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||||
|
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||||
|
elif self.cfg.evaluation_strategy:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"evaluation_strategy"
|
||||||
|
] = self.cfg.evaluation_strategy
|
||||||
|
else:
|
||||||
|
# we have an eval set, but no steps defined, default to use epoch
|
||||||
|
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||||
|
|
||||||
|
if self.cfg.save_steps:
|
||||||
|
training_arguments_kwargs["save_strategy"] = "steps"
|
||||||
|
training_arguments_kwargs["save_steps"] = self.cfg.save_steps
|
||||||
|
elif self.cfg.save_strategy:
|
||||||
|
training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy
|
||||||
|
else:
|
||||||
|
# default to saving each epoch if not defined
|
||||||
|
training_arguments_kwargs["save_strategy"] = "epoch"
|
||||||
|
|
||||||
|
if self.cfg.do_bench_eval:
|
||||||
|
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
||||||
|
if self.cfg.bench_dataset:
|
||||||
|
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
|
||||||
|
if self.cfg.metric_for_best_model:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"metric_for_best_model"
|
||||||
|
] = self.cfg.metric_for_best_model
|
||||||
|
if self.cfg.greater_is_better:
|
||||||
|
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
|
||||||
|
|
||||||
|
if self.cfg.torch_compile:
|
||||||
|
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
|
||||||
|
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
|
||||||
|
elif torch._dynamo: # pylint: disable=protected-access
|
||||||
|
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||||
|
True
|
||||||
|
)
|
||||||
|
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
|
||||||
|
if self.cfg.torch_compile_backend:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"torch_compile_backend"
|
||||||
|
] = self.cfg.torch_compile_backend
|
||||||
|
|
||||||
|
# DDP Config
|
||||||
|
if self.cfg.ddp_timeout:
|
||||||
|
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
|
||||||
|
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
|
||||||
|
if self.cfg.ddp_bucket_cap_mb:
|
||||||
|
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
|
||||||
|
if self.cfg.ddp_broadcast_buffers is not None:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"ddp_broadcast_buffers"
|
||||||
|
] = self.cfg.ddp_broadcast_buffers
|
||||||
|
|
||||||
|
# these are all the "standard" kwargs that are def used
|
||||||
|
training_arguments_kwargs["max_steps"] = (
|
||||||
|
total_num_steps if self.cfg.max_steps else -1
|
||||||
|
)
|
||||||
|
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"per_device_train_batch_size"
|
||||||
|
] = self.cfg.micro_batch_size
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"per_device_eval_batch_size"
|
||||||
|
] = self.cfg.eval_batch_size
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"gradient_accumulation_steps"
|
||||||
|
] = self.cfg.gradient_accumulation_steps
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"eval_accumulation_steps"
|
||||||
|
] = self.cfg.gradient_accumulation_steps
|
||||||
|
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||||
|
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
|
||||||
|
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
|
||||||
|
training_arguments_kwargs["save_total_limit"] = (
|
||||||
|
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
|
||||||
|
)
|
||||||
|
training_arguments_kwargs["load_best_model_at_end"] = (
|
||||||
|
(
|
||||||
|
self.cfg.load_best_model_at_end is not False
|
||||||
|
or self.cfg.early_stopping_patience
|
||||||
|
)
|
||||||
|
and self.cfg.val_set_size > 0
|
||||||
|
and self.cfg.save_steps
|
||||||
|
and self.cfg.eval_steps
|
||||||
|
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
||||||
|
) or False
|
||||||
|
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
||||||
|
False if self.cfg.ddp else None
|
||||||
|
)
|
||||||
|
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||||
|
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
||||||
|
training_arguments_kwargs["run_name"] = (
|
||||||
|
self.cfg.wandb_name if self.cfg.use_wandb else None
|
||||||
|
)
|
||||||
|
training_arguments_kwargs["optim"] = (
|
||||||
|
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||||
|
)
|
||||||
|
training_arguments_kwargs["lr_scheduler_type"] = (
|
||||||
|
self.cfg.lr_scheduler
|
||||||
|
if self.cfg.lr_scheduler
|
||||||
|
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
||||||
|
else "cosine"
|
||||||
|
)
|
||||||
|
training_arguments_kwargs["lr_scheduler_kwargs"] = (
|
||||||
|
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||||
|
)
|
||||||
|
training_arguments_kwargs["weight_decay"] = (
|
||||||
|
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
||||||
|
)
|
||||||
|
training_arguments_kwargs["sample_packing"] = (
|
||||||
|
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||||
|
)
|
||||||
|
training_arguments_kwargs["eval_sample_packing"] = (
|
||||||
|
self.cfg.sample_packing
|
||||||
|
if self.cfg.eval_sample_packing is not False
|
||||||
|
else False
|
||||||
|
)
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"sample_packing_seq_len_multiplier"
|
||||||
|
] = self.cfg.micro_batch_size
|
||||||
|
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||||
|
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
|
||||||
|
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||||
|
training_arguments_kwargs
|
||||||
|
)
|
||||||
|
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||||
|
|
||||||
|
if self.cfg.neftune_noise_alpha is not None:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"neftune_noise_alpha"
|
||||||
|
] = self.cfg.neftune_noise_alpha
|
||||||
|
|
||||||
|
training_args = (
|
||||||
|
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
|
**training_arguments_kwargs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
training_args = self.hook_post_create_training_args(training_args)
|
||||||
|
trainer_kwargs = {}
|
||||||
|
|
||||||
|
if self.cfg.optimizer == "adamw_anyprecision":
|
||||||
|
if Path(self.cfg.torchdistx_path).exists():
|
||||||
|
sys.path.append(self.cfg.torchdistx_path)
|
||||||
|
importlib.import_module("torchdistx")
|
||||||
|
|
||||||
|
data_collator_kwargs = {
|
||||||
|
"padding": True, # True/"longest" is the default
|
||||||
|
}
|
||||||
|
if self.cfg.pad_to_sequence_len:
|
||||||
|
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
|
||||||
|
self.cfg.sequence_len / 64
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||||
|
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||||
|
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||||
|
|
||||||
|
trainer_cls = self._get_trainer_cls()
|
||||||
|
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||||
|
trainer_kwargs, trainer_cls
|
||||||
|
)
|
||||||
|
trainer = trainer_cls(
|
||||||
|
model=self.model,
|
||||||
|
train_dataset=self.train_dataset,
|
||||||
|
eval_dataset=self.eval_dataset,
|
||||||
|
args=training_args,
|
||||||
|
data_collator=self.build_collator(**data_collator_kwargs),
|
||||||
|
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||||
|
self.tokenizer,
|
||||||
|
return_tensors="pt",
|
||||||
|
**data_collator_kwargs,
|
||||||
|
),
|
||||||
|
callbacks=self.get_callbacks(),
|
||||||
|
num_epochs=self.cfg.num_epochs,
|
||||||
|
**trainer_kwargs,
|
||||||
|
)
|
||||||
|
trainer = self.hook_post_create_trainer(trainer)
|
||||||
|
for callback in self.get_post_trainer_create_callbacks(trainer):
|
||||||
|
trainer.add_callback(callback)
|
||||||
|
|
||||||
|
if self.cfg.deepspeed and self.cfg.sample_packing:
|
||||||
|
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
|
||||||
|
"train_micro_batch_size_per_gpu"
|
||||||
|
] = self.cfg.micro_batch_size
|
||||||
|
|
||||||
|
return trainer
|
||||||
|
|
||||||
|
def build_collator(self, **kwargs):
|
||||||
|
if self.cfg.model_config_type == "mamba":
|
||||||
|
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||||
|
|
||||||
|
return BatchSamplerDataCollatorForSeq2Seq(
|
||||||
|
self.tokenizer,
|
||||||
|
return_tensors="pt",
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
@@ -30,14 +30,20 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
self,
|
self,
|
||||||
prompt_tokenizer: PromptTokenizingStrategy,
|
prompt_tokenizer: PromptTokenizingStrategy,
|
||||||
dataset: IterableDataset,
|
dataset: IterableDataset,
|
||||||
|
process_count: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.prompt_tokenizer = prompt_tokenizer
|
self.prompt_tokenizer = prompt_tokenizer
|
||||||
|
self.process_count = process_count
|
||||||
super().__init__(self.process(dataset).data, **kwargs)
|
super().__init__(self.process(dataset).data, **kwargs)
|
||||||
|
|
||||||
def process(self, dataset):
|
def process(self, dataset):
|
||||||
features = dataset.features.keys()
|
features = dataset.features.keys()
|
||||||
num_proc = min(64, os.cpu_count())
|
num_proc = (
|
||||||
|
min(64, self.process_count)
|
||||||
|
if self.process_count
|
||||||
|
else min(64, os.cpu_count())
|
||||||
|
)
|
||||||
map_kwargs = {}
|
map_kwargs = {}
|
||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
map_kwargs["batched"] = True
|
map_kwargs["batched"] = True
|
||||||
|
|||||||
12
src/axolotl/models/mamba/__init__.py
Normal file
12
src/axolotl/models/mamba/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
Modeling module for Mamba models
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def fix_mamba_attn_for_loss():
|
||||||
|
from mamba_ssm.models import mixer_seq_simple
|
||||||
|
|
||||||
|
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
|
||||||
|
|
||||||
|
mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
|
||||||
|
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name
|
||||||
42
src/axolotl/models/mamba/configuration_mamba.py
Normal file
42
src/axolotl/models/mamba/configuration_mamba.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""
|
||||||
|
HF Transformers MambaConfig
|
||||||
|
"""
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class MambaConfig(PretrainedConfig):
|
||||||
|
"""
|
||||||
|
modeling configuration for state space model/mamba
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_type = "mamba"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=50280,
|
||||||
|
d_model=2560,
|
||||||
|
n_layer=64,
|
||||||
|
rms_norm=True,
|
||||||
|
residual_in_fp32=True,
|
||||||
|
fused_add_norm=True,
|
||||||
|
pad_vocab_size_multiple=8,
|
||||||
|
pad_token_id=50277,
|
||||||
|
bos_token_id=0,
|
||||||
|
eos_token_id=0,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.d_model = d_model
|
||||||
|
self.n_layer = n_layer
|
||||||
|
self.rms_norm = rms_norm
|
||||||
|
self.residual_in_fp32 = residual_in_fp32
|
||||||
|
self.fused_add_norm = fused_add_norm
|
||||||
|
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
128
src/axolotl/models/mamba/modeling_mamba.py
Normal file
128
src/axolotl/models/mamba/modeling_mamba.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
# pylint: skip-file
|
||||||
|
import os
|
||||||
|
from collections import namedtuple
|
||||||
|
from functools import partial
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
|
||||||
|
from mamba_ssm.utils.generation import GenerationMixin
|
||||||
|
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
|
from axolotl.models.mamba.configuration_mamba import MambaConfig
|
||||||
|
|
||||||
|
|
||||||
|
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
n_layer: int,
|
||||||
|
vocab_size: int,
|
||||||
|
initializer_cfg=None,
|
||||||
|
pad_vocab_size_multiple: int = 1,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
**backbone_kwargs,
|
||||||
|
) -> None:
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super().__init__()
|
||||||
|
if vocab_size % pad_vocab_size_multiple != 0:
|
||||||
|
vocab_size += pad_vocab_size_multiple - (
|
||||||
|
vocab_size % pad_vocab_size_multiple
|
||||||
|
)
|
||||||
|
self.config = MambaConfig(
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
d_model=d_model,
|
||||||
|
n_layer=n_layer,
|
||||||
|
pad_vocab_size_multiple=pad_vocab_size_multiple,
|
||||||
|
)
|
||||||
|
self.backbone = MixerModel(
|
||||||
|
d_model=d_model,
|
||||||
|
n_layer=n_layer,
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
initializer_cfg=initializer_cfg,
|
||||||
|
**backbone_kwargs,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.apply(
|
||||||
|
partial(
|
||||||
|
_init_weights,
|
||||||
|
n_layer=n_layer,
|
||||||
|
**(initializer_cfg if initializer_cfg is not None else {}),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.tie_weights()
|
||||||
|
|
||||||
|
def tie_weights(self):
|
||||||
|
self.lm_head.weight = self.backbone.embedding.weight
|
||||||
|
|
||||||
|
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||||
|
return self.backbone.allocate_inference_cache(
|
||||||
|
batch_size, max_seqlen, dtype=dtype, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
position_ids=None,
|
||||||
|
inference_params=None,
|
||||||
|
num_last_tokens=0,
|
||||||
|
labels=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
||||||
|
num_last_tokens: if > 0, only return the logits for the last n tokens
|
||||||
|
"""
|
||||||
|
hidden_states = self.backbone(input_ids, inference_params=inference_params)
|
||||||
|
if num_last_tokens > 0:
|
||||||
|
hidden_states = hidden_states[:, -num_last_tokens:]
|
||||||
|
lm_logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
||||||
|
return CausalLMOutput(logits=lm_logits)
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
logits = lm_logits
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
|
||||||
|
print(loss)
|
||||||
|
return CausalLMOutput(logits=lm_logits, loss=loss)
|
||||||
|
|
||||||
|
else:
|
||||||
|
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
||||||
|
return CausalLMOutput(logits=lm_logits)
|
||||||
|
|
||||||
|
def save_pretrained(
|
||||||
|
self,
|
||||||
|
save_directory: Union[str, os.PathLike],
|
||||||
|
state_dict: Optional[dict] = None,
|
||||||
|
safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
if state_dict is None:
|
||||||
|
state_dict = self.state_dict()
|
||||||
|
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
||||||
|
config = load_config_hf(pretrained_model_name)
|
||||||
|
model = cls(**config, device=device, dtype=dtype, **kwargs)
|
||||||
|
model.load_state_dict(
|
||||||
|
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
|
||||||
|
)
|
||||||
|
return model
|
||||||
@@ -3,4 +3,6 @@ MixFormers model architecture used for phi models
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
|
from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
|
||||||
|
from .configuration_phi import PhiConfig # noqa
|
||||||
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
|
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
|
||||||
|
from .modeling_phi import PhiForCausalLM # noqa
|
||||||
|
|||||||
65
src/axolotl/models/phi/configuration_phi.py
Normal file
65
src/axolotl/models/phi/configuration_phi.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
# pylint: skip-file
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class PhiConfig(PretrainedConfig):
|
||||||
|
"""Phi configuration."""
|
||||||
|
|
||||||
|
model_type = "phi"
|
||||||
|
attribute_map = {
|
||||||
|
"max_position_embeddings": "n_positions",
|
||||||
|
"hidden_size": "n_embd",
|
||||||
|
"num_attention_heads": "n_head",
|
||||||
|
"num_hidden_layers": "n_layer",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size: int = 50304,
|
||||||
|
n_positions: int = 2048,
|
||||||
|
n_embd: int = 1024,
|
||||||
|
n_layer: int = 20,
|
||||||
|
n_inner: Optional[int] = None,
|
||||||
|
n_head: int = 16,
|
||||||
|
n_head_kv: Optional[int] = None,
|
||||||
|
rotary_dim: Optional[int] = 32,
|
||||||
|
activation_function: Optional[str] = "gelu_new",
|
||||||
|
flash_attn: bool = False,
|
||||||
|
flash_rotary: bool = False,
|
||||||
|
fused_dense: bool = False,
|
||||||
|
attn_pdrop: float = 0.0,
|
||||||
|
embd_pdrop: float = 0.0,
|
||||||
|
resid_pdrop: float = 0.0,
|
||||||
|
layer_norm_epsilon: float = 1e-5,
|
||||||
|
initializer_range: float = 0.02,
|
||||||
|
tie_word_embeddings: bool = False,
|
||||||
|
pad_vocab_size_multiple: int = 64,
|
||||||
|
**kwargs
|
||||||
|
) -> None:
|
||||||
|
self.vocab_size = int(
|
||||||
|
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
||||||
|
)
|
||||||
|
self.n_positions = n_positions
|
||||||
|
self.n_embd = n_embd
|
||||||
|
self.n_layer = n_layer
|
||||||
|
self.n_inner = n_inner
|
||||||
|
self.n_head = n_head
|
||||||
|
self.n_head_kv = n_head_kv
|
||||||
|
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
||||||
|
self.activation_function = activation_function
|
||||||
|
self.flash_attn = flash_attn
|
||||||
|
self.flash_rotary = flash_rotary
|
||||||
|
self.fused_dense = fused_dense
|
||||||
|
self.attn_pdrop = attn_pdrop
|
||||||
|
self.embd_pdrop = embd_pdrop
|
||||||
|
self.resid_pdrop = resid_pdrop
|
||||||
|
self.layer_norm_epsilon = layer_norm_epsilon
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
|
||||||
|
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||||
1063
src/axolotl/models/phi/modeling_phi.py
Normal file
1063
src/axolotl/models/phi/modeling_phi.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -82,15 +82,44 @@ def get_turns( # pylint: disable=too-many-return-statements
|
|||||||
else:
|
else:
|
||||||
yield role + ":", ""
|
yield role + ":", ""
|
||||||
return
|
return
|
||||||
if self.sep_style == SeparatorStyle.LLAMA2:
|
if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
|
||||||
seps = [self.sep, self.sep2]
|
|
||||||
if self.system_message:
|
if self.system_message:
|
||||||
|
if self.messages:
|
||||||
|
# For llama, the system message is incorporated into the first human instruction
|
||||||
|
first_role, first_msg = self.messages[0]
|
||||||
|
if first_role == self.roles[0]:
|
||||||
|
system_prompt += first_msg
|
||||||
|
self.messages.pop(0)
|
||||||
yield "", system_prompt
|
yield "", system_prompt
|
||||||
else:
|
for i, (role, message) in enumerate(self.messages):
|
||||||
yield "", "[INST] "
|
|
||||||
for i, (role, message) in enumerate(self.messages[1:]):
|
|
||||||
if message:
|
if message:
|
||||||
yield role + " ", message + seps[i % 2]
|
if (i % 2 == 0 and not self.system_message) or (
|
||||||
|
i % 2 != 0 and self.system_message
|
||||||
|
):
|
||||||
|
role = "<s> " + role
|
||||||
|
yield role + " ", message
|
||||||
|
else:
|
||||||
|
yield role, ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
|
||||||
|
contains_sys_msg = False
|
||||||
|
if self.system_message:
|
||||||
|
contains_sys_msg = True
|
||||||
|
if self.messages:
|
||||||
|
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline
|
||||||
|
first_role, first_msg = self.messages[0]
|
||||||
|
if first_role == self.roles[0]:
|
||||||
|
system_prompt = self.system_template.format(
|
||||||
|
system_message=" " + self.system_message
|
||||||
|
)
|
||||||
|
system_prompt += first_msg
|
||||||
|
self.messages.pop(0)
|
||||||
|
yield "", system_prompt
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message and i == 0 and not contains_sys_msg:
|
||||||
|
yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a `<s> [INST]` at the beginning of the first instruction.
|
||||||
|
elif message:
|
||||||
|
yield role + " ", message
|
||||||
else:
|
else:
|
||||||
yield role, ""
|
yield role, ""
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -13,12 +13,18 @@ import transformers
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
||||||
)
|
)
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LlamaMLP,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
repeat_kv,
|
||||||
|
)
|
||||||
|
from xformers.ops import SwiGLU
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||||
@@ -38,6 +44,28 @@ except ImportError:
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
def replace_llama_mlp_with_swiglu(model):
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if isinstance(module, LlamaMLP):
|
||||||
|
mlp = FusedMLP(
|
||||||
|
module.config, module.gate_proj, module.up_proj, module.down_proj
|
||||||
|
)
|
||||||
|
set_module_name(model, name, mlp)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_llama_qkv_with_fused(model):
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if isinstance(module, LlamaAttention):
|
||||||
|
qkv = FusedAttention(
|
||||||
|
module.config,
|
||||||
|
module.q_proj,
|
||||||
|
module.k_proj,
|
||||||
|
module.v_proj,
|
||||||
|
module.o_proj,
|
||||||
|
)
|
||||||
|
set_module_name(model, name, qkv)
|
||||||
|
|
||||||
|
|
||||||
def replace_llama_attn_with_flash_attn(
|
def replace_llama_attn_with_flash_attn(
|
||||||
packed: Optional[bool] = False,
|
packed: Optional[bool] = False,
|
||||||
cross_entropy: Optional[bool] = False,
|
cross_entropy: Optional[bool] = False,
|
||||||
@@ -86,6 +114,92 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FusedAttention(LlamaAttention):
|
||||||
|
"""
|
||||||
|
Fused QKV Attention layer for incrementally improved training efficiency
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
q: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
k: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
v: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
o: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.init_device = next(iter(q.state_dict().values())).device
|
||||||
|
|
||||||
|
# define equivalent fused qkv projection
|
||||||
|
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
|
||||||
|
self.qkv_proj = torch.nn.Linear(
|
||||||
|
q.in_features, sum(self.out_features), device=self.init_device, bias=False
|
||||||
|
)
|
||||||
|
self.o_proj = o
|
||||||
|
|
||||||
|
# overwrite initialized weights with pretrained weights
|
||||||
|
self.qkv_proj.weight.data = torch.cat(
|
||||||
|
(q.weight.data, k.weight.data, v.weight.data), dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
def _post_training(self, model, name):
|
||||||
|
q_proj, k_proj, v_proj = torch.split(
|
||||||
|
self.qkv_proj.weight.data, self.out_features, dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
new_attn = LlamaAttention(self.config)
|
||||||
|
new_attn.q_proj.weight.data = q_proj
|
||||||
|
new_attn.k_proj.weight.data = k_proj
|
||||||
|
new_attn.v_proj.weight.data = v_proj
|
||||||
|
new_attn.o_proj.weight.data = self.o_proj.weight.data
|
||||||
|
|
||||||
|
set_module_name(model, name, new_attn)
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMLP(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Fused MLP layer for incrementally improved training efficiency
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
gate_proj: torch.nn.Linear,
|
||||||
|
up_proj: torch.nn.Linear,
|
||||||
|
down_proj: torch.nn.Linear,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.swiglu = SwiGLU(
|
||||||
|
in_features=config.hidden_size,
|
||||||
|
hidden_features=config.intermediate_size,
|
||||||
|
bias=False,
|
||||||
|
_pack_weights=True,
|
||||||
|
)
|
||||||
|
# overwrite initialized weights with pretrained weights
|
||||||
|
self.swiglu.w12.weight.data = torch.cat(
|
||||||
|
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
||||||
|
)
|
||||||
|
self.swiglu.w3.weight.data = down_proj.weight.data
|
||||||
|
|
||||||
|
def _post_training(self, model, name):
|
||||||
|
w1, w2 = torch.split( # pylint: disable=invalid-name
|
||||||
|
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign the split weights back to the original layers
|
||||||
|
new_mlp = LlamaMLP(self.config)
|
||||||
|
new_mlp.gate_proj.weight.data = w1
|
||||||
|
new_mlp.up_proj.weight.data = w2
|
||||||
|
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
||||||
|
|
||||||
|
set_module_name(model, name, new_mlp)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
||||||
|
return self.swiglu(x)
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
# requires the attention mask to be the same as the key_padding_mask
|
# requires the attention mask to be the same as the key_padding_mask
|
||||||
def _prepare_decoder_attention_mask(
|
def _prepare_decoder_attention_mask(
|
||||||
@@ -147,9 +261,14 @@ def flashattn_forward(
|
|||||||
value_states = torch.cat(value_states, dim=-1)
|
value_states = torch.cat(value_states, dim=-1)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
query_states = self.q_proj(hidden_states)
|
if isinstance(self, FusedAttention):
|
||||||
key_states = self.k_proj(hidden_states)
|
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
|
||||||
value_states = self.v_proj(hidden_states)
|
self.out_features, dim=-1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
query_states = query_states.view(
|
query_states = query_states.view(
|
||||||
bsz, q_len, self.num_heads, self.head_dim
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
@@ -202,6 +321,8 @@ def flashattn_forward(
|
|||||||
# only on first autoregressive step q,k,v have same seqlen
|
# only on first autoregressive step q,k,v have same seqlen
|
||||||
is_causal = key_states.shape == query_states.shape
|
is_causal = key_states.shape == query_states.shape
|
||||||
|
|
||||||
|
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||||
|
|
||||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||||
# special handling using sample packing
|
# special handling using sample packing
|
||||||
qkv = torch.stack(
|
qkv = torch.stack(
|
||||||
@@ -211,7 +332,12 @@ def flashattn_forward(
|
|||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
|
||||||
output = flash_attn_varlen_qkvpacked_func(
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
qkv,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
dropout_p=dropout_rate,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=True,
|
||||||
)
|
)
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
elif query_states.shape == key_states.shape:
|
elif query_states.shape == key_states.shape:
|
||||||
@@ -234,7 +360,7 @@ def flashattn_forward(
|
|||||||
qkv_unpad,
|
qkv_unpad,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
max_seqlen_q,
|
max_seqlen_q,
|
||||||
0.0,
|
dropout_p=dropout_rate,
|
||||||
softmax_scale=None,
|
softmax_scale=None,
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
)
|
)
|
||||||
@@ -247,6 +373,7 @@ def flashattn_forward(
|
|||||||
output = flash_attn_kvpacked_func(
|
output = flash_attn_kvpacked_func(
|
||||||
query_states,
|
query_states,
|
||||||
torch.stack([key_states, value_states], 2),
|
torch.stack([key_states, value_states], 2),
|
||||||
|
dropout_p=dropout_rate,
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -279,7 +406,7 @@ def flashattn_forward(
|
|||||||
cu_seqlens_k,
|
cu_seqlens_k,
|
||||||
max_seqlen_q,
|
max_seqlen_q,
|
||||||
max_seqlen_k,
|
max_seqlen_k,
|
||||||
0.0,
|
dropout_p=dropout_rate,
|
||||||
softmax_scale=None,
|
softmax_scale=None,
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ def sdp_attention_forward(
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ def xformers_forward(
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
"""
|
|
||||||
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers.models.llama.modeling_llama
|
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
def noised_embed(orig_embed, noise_alpha, model):
|
|
||||||
def new_func(input_ids):
|
|
||||||
# during training, we add noise to the embedding
|
|
||||||
# during generation, we don't add noise to the embedding
|
|
||||||
if model.training:
|
|
||||||
embed_init = orig_embed(input_ids)
|
|
||||||
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
|
||||||
mag_norm = noise_alpha / torch.sqrt(dims)
|
|
||||||
return embed_init + torch.zeros_like(embed_init).uniform_(
|
|
||||||
-mag_norm, mag_norm
|
|
||||||
)
|
|
||||||
return orig_embed(input_ids)
|
|
||||||
|
|
||||||
return new_func
|
|
||||||
|
|
||||||
def post_init(orig_post_init):
|
|
||||||
def new_func(self):
|
|
||||||
orig_post_init(self)
|
|
||||||
self.embed_tokens.forward = noised_embed(
|
|
||||||
self.embed_tokens.forward, noise_alpha, self
|
|
||||||
)
|
|
||||||
|
|
||||||
return new_func
|
|
||||||
|
|
||||||
transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
|
|
||||||
transformers.models.llama.modeling_llama.LlamaModel.post_init
|
|
||||||
)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -14,6 +14,9 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
|
|||||||
flash_attn_varlen_qkvpacked_func,
|
flash_attn_varlen_qkvpacked_func,
|
||||||
)
|
)
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
|
MistralAttention as OriginalMistralAttention,
|
||||||
|
)
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||||
)
|
)
|
||||||
@@ -42,6 +45,44 @@ def replace_mistral_attn_with_flash_attn(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def _make_sliding_window_causal_mask(
|
||||||
|
bsz: int,
|
||||||
|
tgt_len: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
past_key_values_length: int = 0,
|
||||||
|
sliding_window: int = 4096,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Make causal mask used for sliding window attention
|
||||||
|
"""
|
||||||
|
tensor = torch.full(
|
||||||
|
(tgt_len, tgt_len),
|
||||||
|
fill_value=1,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
mask = torch.tril(tensor, diagonal=0)
|
||||||
|
# make the mask banded to account for sliding window
|
||||||
|
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
|
||||||
|
mask = torch.triu(mask, diagonal=-sliding_window + 1)
|
||||||
|
mask = torch.log(mask).to(dtype)
|
||||||
|
|
||||||
|
if past_key_values_length > 0:
|
||||||
|
mask = torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros(
|
||||||
|
tgt_len, past_key_values_length, dtype=dtype, device=device
|
||||||
|
),
|
||||||
|
mask,
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
return mask[None, None, :, :].expand(
|
||||||
|
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
# requires the attention mask to be the same as the key_padding_mask
|
# requires the attention mask to be the same as the key_padding_mask
|
||||||
def _prepare_decoder_attention_mask(
|
def _prepare_decoder_attention_mask(
|
||||||
@@ -53,11 +94,29 @@ def _prepare_decoder_attention_mask(
|
|||||||
sliding_window,
|
sliding_window,
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
# [bsz, seq_len]
|
# [bsz, seq_len]
|
||||||
|
if attention_mask is None:
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
||||||
|
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
|
||||||
|
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
|
||||||
|
sliding_window_mask = _make_sliding_window_causal_mask(
|
||||||
|
bsz=input_shape[0],
|
||||||
|
tgt_len=input_shape[1],
|
||||||
|
dtype=inputs_embeds.dtype,
|
||||||
|
device=inputs_embeds.device,
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
|
sliding_window=sliding_window,
|
||||||
|
)
|
||||||
|
attention_mask = attention_mask + sliding_window_mask
|
||||||
|
else:
|
||||||
|
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
|
||||||
|
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
def flashattn_forward(
|
def flashattn_forward(
|
||||||
self,
|
self: OriginalMistralAttention,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
@@ -91,10 +150,41 @@ def flashattn_forward(
|
|||||||
query_states, key_states, cos, sin, position_ids
|
query_states, key_states, cos, sin, position_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
|
use_sliding_windows = (
|
||||||
|
hasattr(self.config, "sliding_window") is not None
|
||||||
|
and kv_seq_len > self.config.sliding_window
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_sliding_windows:
|
||||||
|
window_size = (self.config.sliding_window, self.config.sliding_window)
|
||||||
|
else:
|
||||||
|
window_size = (-1, -1)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# reuse k, v, self_attention
|
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
if (
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
hasattr(self.config, "sliding_window")
|
||||||
|
and kv_seq_len > self.config.sliding_window
|
||||||
|
):
|
||||||
|
slicing_tokens = kv_seq_len - self.config.sliding_window
|
||||||
|
|
||||||
|
past_key = past_key_value[0]
|
||||||
|
past_value = past_key_value[1]
|
||||||
|
|
||||||
|
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||||
|
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||||
|
|
||||||
|
if past_key.shape[-2] != self.config.sliding_window - 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
||||||
|
f" {past_key.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
past_key_value = (past_key, past_value) if use_cache else None
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
@@ -111,6 +201,8 @@ def flashattn_forward(
|
|||||||
# only on first autoregressive step q,k,v have same seqlen
|
# only on first autoregressive step q,k,v have same seqlen
|
||||||
is_causal = key_states.shape == query_states.shape
|
is_causal = key_states.shape == query_states.shape
|
||||||
|
|
||||||
|
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||||
|
|
||||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||||
# special handling using sample packing
|
# special handling using sample packing
|
||||||
qkv = torch.stack(
|
qkv = torch.stack(
|
||||||
@@ -120,7 +212,13 @@ def flashattn_forward(
|
|||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
|
||||||
output = flash_attn_varlen_qkvpacked_func(
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
qkv,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
dropout_p=dropout_rate,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=True,
|
||||||
|
window_size=window_size,
|
||||||
)
|
)
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
elif query_states.shape == key_states.shape:
|
elif query_states.shape == key_states.shape:
|
||||||
@@ -143,9 +241,10 @@ def flashattn_forward(
|
|||||||
qkv_unpad,
|
qkv_unpad,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
max_seqlen_q,
|
max_seqlen_q,
|
||||||
0.0,
|
dropout_p=dropout_rate,
|
||||||
softmax_scale=None,
|
softmax_scale=None,
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
|
window_size=window_size,
|
||||||
)
|
)
|
||||||
output = output_pad_fn(output_unpad)
|
output = output_pad_fn(output_unpad)
|
||||||
else:
|
else:
|
||||||
@@ -156,7 +255,9 @@ def flashattn_forward(
|
|||||||
output = flash_attn_kvpacked_func(
|
output = flash_attn_kvpacked_func(
|
||||||
query_states,
|
query_states,
|
||||||
torch.stack([key_states, value_states], 2),
|
torch.stack([key_states, value_states], 2),
|
||||||
|
dropout_p=dropout_rate,
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
|
window_size=window_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
( # pylint: disable=unbalanced-tuple-unpacking
|
( # pylint: disable=unbalanced-tuple-unpacking
|
||||||
@@ -188,9 +289,10 @@ def flashattn_forward(
|
|||||||
cu_seqlens_k,
|
cu_seqlens_k,
|
||||||
max_seqlen_q,
|
max_seqlen_q,
|
||||||
max_seqlen_k,
|
max_seqlen_k,
|
||||||
0.0,
|
dropout_p=dropout_rate,
|
||||||
softmax_scale=None,
|
softmax_scale=None,
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
|
window_size=window_size,
|
||||||
)
|
)
|
||||||
output = output_pad_fn(output_unpad)
|
output = output_pad_fn(output_unpad)
|
||||||
|
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
"""
|
|
||||||
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers.models.mistral.modeling_mistral
|
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
def noised_embed(orig_embed, noise_alpha, model):
|
|
||||||
def new_func(input_ids):
|
|
||||||
# during training, we add noise to the embedding
|
|
||||||
# during generation, we don't add noise to the embedding
|
|
||||||
if model.training:
|
|
||||||
embed_init = orig_embed(input_ids)
|
|
||||||
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
|
||||||
mag_norm = noise_alpha / torch.sqrt(dims)
|
|
||||||
return embed_init + torch.zeros_like(embed_init).uniform_(
|
|
||||||
-mag_norm, mag_norm
|
|
||||||
)
|
|
||||||
return orig_embed(input_ids)
|
|
||||||
|
|
||||||
return new_func
|
|
||||||
|
|
||||||
def post_init(orig_post_init):
|
|
||||||
def new_func(self):
|
|
||||||
orig_post_init(self)
|
|
||||||
self.embed_tokens.forward = noised_embed(
|
|
||||||
self.embed_tokens.forward, noise_alpha, self
|
|
||||||
)
|
|
||||||
|
|
||||||
return new_func
|
|
||||||
|
|
||||||
transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
|
|
||||||
transformers.models.mistral.modeling_mistral.MistralModel.post_init
|
|
||||||
)
|
|
||||||
22
src/axolotl/monkeypatch/mixtral/__init__.py
Normal file
22
src/axolotl/monkeypatch/mixtral/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
"""
|
||||||
|
Patches to support multipack for mixtral
|
||||||
|
"""
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
|
||||||
|
def replace_mixtral_attn_with_multipack_flash_attn():
|
||||||
|
from .modeling_mixtral import (
|
||||||
|
MixtralMultipackFlashAttention2,
|
||||||
|
mixtral_decoder_layer_forward,
|
||||||
|
mixtral_model_forward,
|
||||||
|
)
|
||||||
|
|
||||||
|
transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
|
||||||
|
mixtral_decoder_layer_forward
|
||||||
|
)
|
||||||
|
transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
|
||||||
|
mixtral_model_forward
|
||||||
|
)
|
||||||
|
transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[
|
||||||
|
"flash_attention_2"
|
||||||
|
] = MixtralMultipackFlashAttention2
|
||||||
379
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
Normal file
379
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
Normal file
@@ -0,0 +1,379 @@
|
|||||||
|
"""
|
||||||
|
Mixtral modeling for multipack
|
||||||
|
"""
|
||||||
|
# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
|
||||||
|
import logging
|
||||||
|
import warnings
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from flash_attn import flash_attn_varlen_qkvpacked_func
|
||||||
|
from transformers import Cache, DynamicCache
|
||||||
|
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||||
|
from transformers.modeling_outputs import MoeModelOutputWithPast
|
||||||
|
from transformers.models.mixtral.modeling_mixtral import (
|
||||||
|
MixtralFlashAttention2,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
repeat_kv,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
|
||||||
|
"""
|
||||||
|
Custom multipack implementation w flash attention 2
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._flash_attn_uses_top_left_mask = True
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if "padding_mask" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||||
|
)
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
if self.layer_idx is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||||
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||||
|
"with a layer index."
|
||||||
|
)
|
||||||
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin, position_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
key_states, value_states, self.layer_idx, cache_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||||
|
# special handling using sample packing
|
||||||
|
qkv = torch.stack(
|
||||||
|
[query_states, key_states, value_states], dim=2
|
||||||
|
) # [bsz, nh, 3, q_len, hd]
|
||||||
|
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||||
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
|
||||||
|
attn_output = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
dropout_p=self.attention_dropout,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def mixtral_decoder_layer_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
output_router_logits: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
if "padding_mask" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||||
|
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||||
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
output_router_logits (`bool`, *optional*):
|
||||||
|
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
||||||
|
should not be returned during inference.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
|
(see `past_key_values`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (self_attn_weights,)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
outputs += (present_key_value,)
|
||||||
|
|
||||||
|
if output_router_logits:
|
||||||
|
outputs += (router_logits,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def mixtral_model_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
output_router_logits: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, MoeModelOutputWithPast]:
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_router_logits = (
|
||||||
|
output_router_logits
|
||||||
|
if output_router_logits is not None
|
||||||
|
else self.config.output_router_logits
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||||
|
)
|
||||||
|
if input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||||
|
)
|
||||||
|
|
||||||
|
past_key_values_length = 0
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||||
|
if use_legacy_cache:
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||||
|
|
||||||
|
cu_seqlens = None
|
||||||
|
max_seqlen = None
|
||||||
|
if position_ids is None:
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length,
|
||||||
|
seq_length + past_key_values_length,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
|
cu_seqlens = cu_seqlens.squeeze()
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
|
||||||
|
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||||
|
if is_padding_right:
|
||||||
|
raise ValueError(
|
||||||
|
"You are attempting to perform batched generation with padding_side='right'"
|
||||||
|
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
|
||||||
|
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._use_flash_attention_2:
|
||||||
|
# 2d mask is passed through the layers
|
||||||
|
attention_mask = (
|
||||||
|
attention_mask
|
||||||
|
if (attention_mask is not None and 0 in attention_mask)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 4d mask is passed through the layers
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask(
|
||||||
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
LOG.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
all_router_logits = () if output_router_logits else None
|
||||||
|
next_decoder_cache = None
|
||||||
|
|
||||||
|
for decoder_layer in self.layers:
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
|
decoder_layer.__call__,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
past_key_values,
|
||||||
|
output_attentions,
|
||||||
|
output_router_logits,
|
||||||
|
use_cache,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_router_logits=output_router_logits,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
if output_router_logits:
|
||||||
|
all_router_logits += (layer_outputs[-1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = None
|
||||||
|
if use_cache:
|
||||||
|
next_cache = (
|
||||||
|
next_decoder_cache.to_legacy_cache()
|
||||||
|
if use_legacy_cache
|
||||||
|
else next_decoder_cache
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [
|
||||||
|
hidden_states,
|
||||||
|
next_cache,
|
||||||
|
all_hidden_states,
|
||||||
|
all_self_attns,
|
||||||
|
all_router_logits,
|
||||||
|
]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
return MoeModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
router_logits=all_router_logits,
|
||||||
|
)
|
||||||
@@ -101,3 +101,16 @@ def get_cu_seqlens_from_pos_ids(position_ids):
|
|||||||
max_seq_lens.append(max_seq_len)
|
max_seq_lens.append(max_seq_len)
|
||||||
|
|
||||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||||
|
|
||||||
|
|
||||||
|
def set_module_name(model, name, value):
|
||||||
|
if "." in name:
|
||||||
|
parent_name = name.rsplit(".", 1)[0]
|
||||||
|
child_name = name[len(parent_name) + 1 :]
|
||||||
|
parent = model.get_submodule(parent_name)
|
||||||
|
else:
|
||||||
|
parent_name = ""
|
||||||
|
parent = model
|
||||||
|
child_name = name
|
||||||
|
|
||||||
|
setattr(parent, child_name, value)
|
||||||
|
|||||||
@@ -1,94 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
"""
|
|
||||||
Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
import transformers.models.llama.modeling_llama
|
|
||||||
from einops import rearrange
|
|
||||||
|
|
||||||
|
|
||||||
class XposRotaryEmbedding(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
max_position_embeddings=2048,
|
|
||||||
base=10000,
|
|
||||||
device=None,
|
|
||||||
scale_base=2048,
|
|
||||||
use_xpos=True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.max_seq_len_cached = max_position_embeddings
|
|
||||||
self.scale_base = scale_base
|
|
||||||
|
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
|
||||||
t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq)
|
|
||||||
freqs = torch.einsum("i , j -> i j", t, inv_freq)
|
|
||||||
freqs = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
||||||
self.register_buffer("freqs_cached", freqs, persistent=False)
|
|
||||||
|
|
||||||
if not use_xpos:
|
|
||||||
self.register_buffer("scale", None)
|
|
||||||
self.register_buffer("scale_cached", torch.ones(1))
|
|
||||||
return
|
|
||||||
|
|
||||||
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
|
||||||
power = (t - (self.max_seq_len_cached // 2)) / self.scale_base
|
|
||||||
scale_cached = scale ** rearrange(power, "n -> n 1")
|
|
||||||
scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
|
|
||||||
|
|
||||||
self.register_buffer("scale", scale, persistent=False)
|
|
||||||
self.register_buffer("scale_cached", scale_cached, persistent=False)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
seq_len,
|
|
||||||
):
|
|
||||||
if seq_len > self.max_seq_len_cached:
|
|
||||||
self.max_seq_len_cached = seq_len
|
|
||||||
t = torch.arange(self.max_seq_len_cached, device=x.device).type_as(
|
|
||||||
self.inv_freq
|
|
||||||
)
|
|
||||||
freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
|
|
||||||
freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype)
|
|
||||||
|
|
||||||
self.register_buffer("freqs_cached", freqs)
|
|
||||||
|
|
||||||
if self.scale is None:
|
|
||||||
self.register_buffer(
|
|
||||||
"scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype)
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached
|
|
||||||
|
|
||||||
power = (t - (seq_len // 2)) / self.scale_base
|
|
||||||
scale = self.scale ** rearrange(power, "n -> n 1")
|
|
||||||
scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype)
|
|
||||||
self.register_buffer("scale_cached", scale)
|
|
||||||
|
|
||||||
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
x1, x2 = x.chunk(2, dim=-1)
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None):
|
|
||||||
freqs = freqs[position_ids, :]
|
|
||||||
if scale.shape[-1] != 1:
|
|
||||||
scale = scale[position_ids, :]
|
|
||||||
|
|
||||||
q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale)
|
|
||||||
k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale)
|
|
||||||
|
|
||||||
return q_embed, k_embed
|
|
||||||
|
|
||||||
|
|
||||||
def replace_llama_rope_with_xpos_rope():
|
|
||||||
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding
|
|
||||||
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
|
"""Module for Alpaca prompt strategy classes"""
|
||||||
|
|
||||||
from typing import Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import (
|
from axolotl.prompt_tokenizers import (
|
||||||
AlpacaPromptTokenizingStrategy,
|
AlpacaPromptTokenizingStrategy,
|
||||||
@@ -9,9 +9,13 @@ from axolotl.prompt_tokenizers import (
|
|||||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg):
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
|
prompt_style = PromptStyle.CHAT.value
|
||||||
|
if ds_cfg and "conversation" in ds_cfg:
|
||||||
|
prompt_style = ds_cfg["conversation"]
|
||||||
|
|
||||||
return AlpacaPromptTokenizingStrategy(
|
return AlpacaPromptTokenizingStrategy(
|
||||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
AlpacaPrompter(prompt_style),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
|
|||||||
@@ -81,8 +81,9 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.sequence_len = 4096
|
self.tokenizer.add_special_tokens(
|
||||||
self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
{"pad_token": getattr(self.tokenizer, "pad_token", "<pad>")}
|
||||||
|
)
|
||||||
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
|
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ register_conv_template(
|
|||||||
system_message="You are a helpful assistant.",
|
system_message="You are a helpful assistant.",
|
||||||
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
||||||
sep_style=SeparatorStyle.CHATML,
|
sep_style=SeparatorStyle.CHATML,
|
||||||
sep="<|im_end|>\n",
|
sep="<|im_end|>",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
)
|
)
|
||||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||||
return SimpleShareGPTPromptTokenizingStrategy(
|
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||||
ShareGPTPrompterV2(
|
ShareGPTPrompterV2(
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
role_key_model=field_model,
|
role_key_model=field_model,
|
||||||
@@ -34,6 +34,26 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
|
if ds_cfg and "strict" in ds_cfg:
|
||||||
|
strategy.strict = ds_cfg["strict"]
|
||||||
|
return strategy
|
||||||
|
|
||||||
|
|
||||||
|
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
|
conversation = (
|
||||||
|
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||||
|
)
|
||||||
|
strategy = UltrachatShareGPTPromptTokenizingStrategy(
|
||||||
|
ShareGPTPrompterV2(
|
||||||
|
conversation=conversation,
|
||||||
|
),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
if ds_cfg and "strict" in ds_cfg:
|
||||||
|
strategy.strict = ds_cfg["strict"]
|
||||||
|
return strategy
|
||||||
|
|
||||||
|
|
||||||
def load_role(tokenizer, cfg):
|
def load_role(tokenizer, cfg):
|
||||||
@@ -59,8 +79,26 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|||||||
basic sharegpt strategy to grab conversations from the sample row
|
basic sharegpt strategy to grab conversations from the sample row
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_strict = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def strict(self):
|
||||||
|
return self._strict
|
||||||
|
|
||||||
|
@strict.setter
|
||||||
|
def strict(self, strict):
|
||||||
|
self._strict = strict
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
def get_conversation_thread(self, prompt):
|
||||||
return prompt["conversations"]
|
conversations = prompt["conversations"]
|
||||||
|
if self.strict:
|
||||||
|
return conversations
|
||||||
|
# remap roles - allow for assistant turn
|
||||||
|
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
|
||||||
|
turns = [
|
||||||
|
{"from": role_map[t["from"]], "value": t["value"]} for t in conversations
|
||||||
|
]
|
||||||
|
return turns
|
||||||
|
|
||||||
|
|
||||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||||
@@ -88,3 +126,17 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|||||||
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
||||||
]
|
]
|
||||||
return turns
|
return turns
|
||||||
|
|
||||||
|
|
||||||
|
class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
|
||||||
|
"""
|
||||||
|
sharegpt strategy that remaps ultrachat data to sharegpt format
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_conversation_thread(self, prompt):
|
||||||
|
conversations = prompt["messages"]
|
||||||
|
role_map = {"user": "human", "assistant": "gpt"}
|
||||||
|
turns = [
|
||||||
|
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
|
||||||
|
]
|
||||||
|
return turns
|
||||||
|
|||||||
@@ -45,6 +45,8 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
self.prompter = prompter
|
self.prompter = prompter
|
||||||
self.tokenizer: PreTrainedTokenizer = tokenizer
|
self.tokenizer: PreTrainedTokenizer = tokenizer
|
||||||
self.train_on_inputs = train_on_inputs
|
self.train_on_inputs = train_on_inputs
|
||||||
|
# sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.
|
||||||
|
# TODO: Document how they are different.
|
||||||
self.sequence_len = sequence_len
|
self.sequence_len = sequence_len
|
||||||
self.max_length = sequence_len
|
self.max_length = sequence_len
|
||||||
|
|
||||||
@@ -59,34 +61,31 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
def _tokenize(
|
def _tokenize(
|
||||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||||
) -> BatchEncoding:
|
) -> BatchEncoding:
|
||||||
result: BatchEncoding
|
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
||||||
if not prompt:
|
if not prompt:
|
||||||
LOG.warning("Empty text requested for tokenization.")
|
LOG.warning("Empty text requested for tokenization.")
|
||||||
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
return empty
|
||||||
else:
|
|
||||||
result = self.tokenizer(
|
result = self.tokenizer(
|
||||||
prompt,
|
prompt,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=self.max_length,
|
max_length=self.max_length,
|
||||||
padding=False,
|
padding=False,
|
||||||
return_tensors=None,
|
return_tensors=None,
|
||||||
)
|
)
|
||||||
if len(result["input_ids"]) == 0:
|
if len(result["input_ids"]) == 0:
|
||||||
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
|
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
|
||||||
|
return empty
|
||||||
|
|
||||||
if (
|
if (
|
||||||
len(result["input_ids"]) > 0
|
result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||||
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
|
||||||
and len(result["input_ids"]) < self.max_length
|
and len(result["input_ids"]) < self.max_length
|
||||||
and add_eos_token
|
and add_eos_token
|
||||||
):
|
):
|
||||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||||
result["attention_mask"].append(1)
|
result["attention_mask"].append(1)
|
||||||
|
|
||||||
if (
|
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||||
len(result["input_ids"]) > 0
|
|
||||||
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
|
||||||
and strip_bos_token
|
|
||||||
):
|
|
||||||
result["input_ids"] = result["input_ids"][1:]
|
result["input_ids"] = result["input_ids"][1:]
|
||||||
result["attention_mask"] = result["attention_mask"][1:]
|
result["attention_mask"] = result["attention_mask"][1:]
|
||||||
|
|
||||||
@@ -122,7 +121,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
if not self.train_on_inputs:
|
if not self.train_on_inputs:
|
||||||
user_prompt_len = len(tokenized_prompt["input_ids"])
|
user_prompt_len = len(tokenized_prompt["input_ids"])
|
||||||
# TODO this could be sped up using numpy array slicing
|
# TODO this could be sped up using numpy array slicing
|
||||||
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
|
||||||
tokenized_res_prompt = self._tokenize(
|
tokenized_res_prompt = self._tokenize(
|
||||||
response, strip_bos_token=True, add_eos_token=True
|
response, strip_bos_token=True, add_eos_token=True
|
||||||
)
|
)
|
||||||
@@ -246,6 +245,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
(
|
(
|
||||||
instruction,
|
instruction,
|
||||||
input, # pylint: disable=redefined-builtin
|
input, # pylint: disable=redefined-builtin
|
||||||
@@ -270,7 +270,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||||
# TODO this could be sped up using numpy array slicing
|
# TODO this could be sped up using numpy array slicing
|
||||||
tokenized_full_prompt["labels"] = [
|
tokenized_full_prompt["labels"] = [
|
||||||
-100
|
IGNORE_INDEX
|
||||||
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
|
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
|
||||||
|
|
||||||
return tokenized_full_prompt
|
return tokenized_full_prompt
|
||||||
@@ -334,6 +334,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
return prompt["conversations"]
|
return prompt["conversations"]
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
|
# Initial values. We will append to these as we go through the conversation.
|
||||||
result, current_len = tokenize_prompt_default()
|
result, current_len = tokenize_prompt_default()
|
||||||
conversation: Conversation = (
|
conversation: Conversation = (
|
||||||
self.prompter._conversation.copy() # pylint: disable=protected-access
|
self.prompter._conversation.copy() # pylint: disable=protected-access
|
||||||
@@ -355,62 +356,67 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
for _, part in enumerate(
|
for _, part in enumerate(
|
||||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
||||||
):
|
):
|
||||||
if isinstance(part, tuple):
|
if not isinstance(part, tuple):
|
||||||
if conversation.roles[0] in part[0]:
|
LOG.warning(f"expected tuple, got {part}")
|
||||||
role = (
|
continue
|
||||||
part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
|
|
||||||
if role_remap
|
user, assistant = conversation.roles
|
||||||
else part[0]
|
role, content = part
|
||||||
)
|
|
||||||
turn = role + part[1]
|
# Uses "in" because role contains extra characters
|
||||||
# this is still the user query, we should
|
if user in role:
|
||||||
if not part[1].strip():
|
role = (
|
||||||
LOG.warning(f"user turn has empty text: {prompt}")
|
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||||
res = self._tokenize(
|
if role_remap
|
||||||
turn,
|
else role
|
||||||
add_eos_token=False,
|
)
|
||||||
strip_bos_token=True,
|
turn = role + content
|
||||||
)
|
# this is still the user query, we should
|
||||||
# everything from this is masked out from the labels
|
if not content.strip():
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
LOG.warning(f"user turn has empty text: {prompt}")
|
||||||
elif conversation.roles[1] in part[0]:
|
res = self._tokenize(
|
||||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
turn,
|
||||||
role = (
|
add_eos_token=False,
|
||||||
part[0].replace(role_remap[1]["from"], role_remap[1]["to"])
|
strip_bos_token=True,
|
||||||
if role_remap
|
)
|
||||||
else part[0]
|
# everything from this is masked out from the labels
|
||||||
)
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
turn = role + part[1]
|
elif assistant in role:
|
||||||
# this should be the assistant response, should end with an eos token
|
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||||
if not part[1].strip():
|
role = (
|
||||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||||
res = self._tokenize(
|
if role_remap
|
||||||
turn,
|
else role
|
||||||
add_eos_token=True,
|
)
|
||||||
strip_bos_token=True,
|
turn = role + content
|
||||||
)
|
# this should be the assistant response, should end with an eos token
|
||||||
role_res = self._tokenize(
|
if not content.strip():
|
||||||
role.rstrip(),
|
LOG.warning(f"assistant turn has empty text: {prompt}")
|
||||||
add_eos_token=False,
|
res = self._tokenize(
|
||||||
strip_bos_token=True,
|
turn,
|
||||||
)
|
add_eos_token=True,
|
||||||
# not masked out from labels
|
strip_bos_token=True,
|
||||||
labels = copy.deepcopy(res["input_ids"])
|
)
|
||||||
len_role = len(role_res["input_ids"])
|
role_res = self._tokenize(
|
||||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
role.rstrip(),
|
||||||
len_role, len(labels)
|
add_eos_token=False,
|
||||||
)
|
strip_bos_token=True,
|
||||||
elif part[0] == "":
|
)
|
||||||
turn = part[1]
|
# not masked out from labels
|
||||||
# this is only ever the first part, should include the bos token and the user query
|
labels = copy.deepcopy(res["input_ids"])
|
||||||
res = self._tokenize(
|
len_role = len(role_res["input_ids"])
|
||||||
turn, add_eos_token=False, strip_bos_token=False
|
labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
|
||||||
)
|
elif role == "":
|
||||||
# everything from this is masked out from the labels
|
turn = content
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
# this is only ever the first part, should include the bos token and the user query
|
||||||
else:
|
res = self._tokenize(
|
||||||
LOG.warning(f"unhandled role: {part[0]}")
|
turn, add_eos_token=False, strip_bos_token=False
|
||||||
continue
|
)
|
||||||
|
# everything from this is masked out from the labels
|
||||||
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
|
else:
|
||||||
|
LOG.warning(f"unhandled role: {role}")
|
||||||
|
continue
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
result, current_len = parse_tokenized_to_result(
|
result, current_len = parse_tokenized_to_result(
|
||||||
@@ -424,38 +430,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
except (KeyError, AssertionError, IndexError) as err:
|
except (KeyError, AssertionError, IndexError) as err:
|
||||||
raise InvalidDataException(str(err)) from err
|
raise InvalidDataException(str(err)) from err
|
||||||
|
|
||||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
|
||||||
if not prompt.strip():
|
|
||||||
LOG.warning("Empty text requested for tokenization.")
|
|
||||||
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
|
||||||
else:
|
|
||||||
result = self.tokenizer(
|
|
||||||
prompt,
|
|
||||||
truncation=True,
|
|
||||||
max_length=self.sequence_len,
|
|
||||||
padding=False,
|
|
||||||
return_tensors=None,
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
len(result["input_ids"]) > 0
|
|
||||||
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
|
||||||
and len(result["input_ids"]) < self.sequence_len
|
|
||||||
and add_eos_token
|
|
||||||
):
|
|
||||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
|
||||||
result["attention_mask"].append(1)
|
|
||||||
|
|
||||||
if (
|
|
||||||
len(result["input_ids"]) > 0
|
|
||||||
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
|
||||||
and strip_bos_token
|
|
||||||
):
|
|
||||||
result["input_ids"] = result["input_ids"][1:]
|
|
||||||
result["attention_mask"] = result["attention_mask"][1:]
|
|
||||||
|
|
||||||
result["labels"] = result["input_ids"].copy()
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
|
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,10 +4,12 @@ import logging
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Generator, Optional, Union
|
from typing import Generator, Optional, Union
|
||||||
|
|
||||||
|
from colorama import Fore
|
||||||
from fastchat.conversation import Conversation, get_conv_template
|
from fastchat.conversation import Conversation, get_conv_template
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
IGNORE_TOKEN_ID = -100
|
IGNORE_TOKEN_ID = -100
|
||||||
|
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
|
||||||
|
|
||||||
|
|
||||||
class PromptStyle(Enum):
|
class PromptStyle(Enum):
|
||||||
@@ -20,13 +22,19 @@ class PromptStyle(Enum):
|
|||||||
CHATML = "chatml"
|
CHATML = "chatml"
|
||||||
|
|
||||||
|
|
||||||
class AlpacaPrompter:
|
class Prompter:
|
||||||
|
"""
|
||||||
|
Base prompter class for all prompters
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class AlpacaPrompter(Prompter):
|
||||||
"""
|
"""
|
||||||
Base class for alpaca prompters
|
Base class for alpaca prompters
|
||||||
"""
|
"""
|
||||||
|
|
||||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
|
||||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
|
||||||
system_format: str = "{system}"
|
system_format: str = "{system}"
|
||||||
turn_format: str
|
turn_format: str
|
||||||
turn_no_input_format: str
|
turn_no_input_format: str
|
||||||
@@ -55,29 +63,38 @@ class AlpacaPrompter:
|
|||||||
)
|
)
|
||||||
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
||||||
|
|
||||||
|
def _build_result(self, instruction, input_text, output):
|
||||||
|
# returns the full prompt from instruction and optional input
|
||||||
|
# if a label (=response, =output) is provided, it's also appended.
|
||||||
|
if input_text:
|
||||||
|
res = (
|
||||||
|
self.system_format.format(system=self.system_prompt)
|
||||||
|
if self.system_prompt
|
||||||
|
else ""
|
||||||
|
) + self.turn_format.format(instruction=instruction, input=input_text)
|
||||||
|
else:
|
||||||
|
res = (
|
||||||
|
self.system_format.format(system=self.system_no_input_prompt)
|
||||||
|
if self.system_no_input_prompt
|
||||||
|
else ""
|
||||||
|
) + self.turn_no_input_format.format(instruction=instruction)
|
||||||
|
if output:
|
||||||
|
res = f"{res}{output}"
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
def build_prompt(
|
def build_prompt(
|
||||||
self,
|
self,
|
||||||
instruction: str,
|
instruction: str,
|
||||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||||
output: Union[None, str] = None,
|
output: Union[None, str] = None,
|
||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
# returns the full prompt from instruction and optional input
|
yield self._build_result(instruction, input, output)
|
||||||
# if a label (=response, =output) is provided, it's also appended.
|
|
||||||
if input:
|
def __repr__(self) -> str:
|
||||||
res = (
|
return REPR_TEMPLATE.format(
|
||||||
self.system_format.format(system=self.system_prompt)
|
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
||||||
if self.system_prompt
|
)
|
||||||
else ""
|
|
||||||
) + self.turn_format.format(instruction=instruction, input=input)
|
|
||||||
else:
|
|
||||||
res = (
|
|
||||||
self.system_format.format(system=self.system_no_input_prompt)
|
|
||||||
if self.system_prompt
|
|
||||||
else ""
|
|
||||||
) + self.turn_no_input_format.format(instruction=instruction)
|
|
||||||
if output:
|
|
||||||
res = f"{res}{output}"
|
|
||||||
yield res
|
|
||||||
|
|
||||||
|
|
||||||
class UnpromptedPrompter(AlpacaPrompter):
|
class UnpromptedPrompter(AlpacaPrompter):
|
||||||
@@ -148,7 +165,7 @@ class NomicGPT4AllPrompter(AlpacaPrompter):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class ReflectAlpacaPrompter:
|
class ReflectAlpacaPrompter(Prompter):
|
||||||
"""
|
"""
|
||||||
Prompter for ReflectAlpaca
|
Prompter for ReflectAlpaca
|
||||||
"""
|
"""
|
||||||
@@ -191,14 +208,14 @@ class ReflectAlpacaPrompter:
|
|||||||
)
|
)
|
||||||
self.response_split = "ASSISTANT:"
|
self.response_split = "ASSISTANT:"
|
||||||
|
|
||||||
def build_prompt(
|
def _build_result(
|
||||||
self,
|
self,
|
||||||
instruction: str,
|
instruction: str,
|
||||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||||
output: Union[None, str] = None,
|
output: Union[None, str] = None,
|
||||||
reflection: Union[None, str] = None,
|
reflection: Union[None, str] = None,
|
||||||
corrected: Union[None, str] = None,
|
corrected: Union[None, str] = None,
|
||||||
) -> Generator[str, None, None]:
|
):
|
||||||
# returns the full prompt from instruction and optional input
|
# returns the full prompt from instruction and optional input
|
||||||
# if a label (=response, =output) is provided, it's also appended.
|
# if a label (=response, =output) is provided, it's also appended.
|
||||||
if input:
|
if input:
|
||||||
@@ -212,7 +229,30 @@ class ReflectAlpacaPrompter:
|
|||||||
corrected=corrected,
|
corrected=corrected,
|
||||||
)
|
)
|
||||||
res = f"{res}{label}"
|
res = f"{res}{label}"
|
||||||
yield res
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def build_prompt(
|
||||||
|
self,
|
||||||
|
instruction: str,
|
||||||
|
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||||
|
output: Union[None, str] = None,
|
||||||
|
reflection: Union[None, str] = None,
|
||||||
|
corrected: Union[None, str] = None,
|
||||||
|
) -> Generator[str, None, None]:
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
yield self._build_result(
|
||||||
|
instruction,
|
||||||
|
input,
|
||||||
|
output,
|
||||||
|
reflection,
|
||||||
|
corrected,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return REPR_TEMPLATE.format(
|
||||||
|
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||||
@@ -220,7 +260,7 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||||
"""
|
"""
|
||||||
A prompter that generates prompts for the ShareGPT
|
A prompter that generates prompts for the ShareGPT
|
||||||
"""
|
"""
|
||||||
@@ -247,7 +287,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
if role_key_model:
|
if role_key_model:
|
||||||
self.role_key_model = role_key_model
|
self.role_key_model = role_key_model
|
||||||
|
|
||||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
def _build_result(self, source):
|
||||||
if len(source) < 2:
|
if len(source) < 2:
|
||||||
# If there isn't a back and forth conversation, ignore it
|
# If there isn't a back and forth conversation, ignore it
|
||||||
# also happens on the data splitting leaving empty conversations
|
# also happens on the data splitting leaving empty conversations
|
||||||
@@ -282,11 +322,20 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
|
|
||||||
for part in conv.get_turns():
|
return conv.get_turns()
|
||||||
|
|
||||||
|
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||||
|
turns = self._build_result(source)
|
||||||
|
|
||||||
|
for part in turns:
|
||||||
if part[0] and not part[1]:
|
if part[0] and not part[1]:
|
||||||
LOG.warning(f"role with empty message: {part[0]}")
|
LOG.warning(f"role with empty message: {part[0]}")
|
||||||
yield part
|
yield part
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
|
||||||
|
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
|
||||||
|
|
||||||
|
|
||||||
class ShareGPTPrompterV2(ShareGPTPrompter):
|
class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||||
"""
|
"""
|
||||||
@@ -304,3 +353,15 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
|
|||||||
role_key_human=role_key_human,
|
role_key_human=role_key_human,
|
||||||
role_key_model=role_key_model,
|
role_key_model=role_key_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedPrompter(Prompter):
|
||||||
|
"""
|
||||||
|
A dummy class for custom prompters
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "Pre-tokenized or custom dataset types are unsupported for logging"
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
@@ -10,12 +9,16 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers.modelcard
|
import transformers.modelcard
|
||||||
|
from accelerate.logging import get_logger
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
|
from pkg_resources import get_distribution # type: ignore
|
||||||
|
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.freeze import freeze_parameters_except
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
@@ -24,7 +27,7 @@ src_dir = os.path.join(project_root, "src")
|
|||||||
sys.path.insert(0, src_dir)
|
sys.path.insert(0, src_dir)
|
||||||
|
|
||||||
configure_logging()
|
configure_logging()
|
||||||
LOG = logging.getLogger("axolotl.train")
|
LOG = get_logger("axolotl.train")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -39,13 +42,13 @@ class TrainDatasetMeta:
|
|||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
*,
|
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
||||||
cfg: DictDefault,
|
|
||||||
cli_args: TrainerCliArgs,
|
|
||||||
dataset_meta: TrainDatasetMeta,
|
|
||||||
):
|
):
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
LOG.debug(
|
||||||
|
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||||
|
main_process_only=True,
|
||||||
|
)
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
|
||||||
train_dataset = dataset_meta.train_dataset
|
train_dataset = dataset_meta.train_dataset
|
||||||
@@ -53,7 +56,10 @@ def train(
|
|||||||
total_num_steps = dataset_meta.total_num_steps
|
total_num_steps = dataset_meta.total_num_steps
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
LOG.info("loading model and (optionally) peft_config...")
|
msg = "loading model"
|
||||||
|
if cfg.adapter:
|
||||||
|
msg += " and peft_config..."
|
||||||
|
LOG.debug(msg)
|
||||||
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||||
|
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
@@ -73,11 +79,15 @@ def train(
|
|||||||
)
|
)
|
||||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||||
|
|
||||||
|
if cfg.unfrozen_parameters:
|
||||||
|
freeze_parameters_except(model, cfg.unfrozen_parameters)
|
||||||
|
|
||||||
trainer = setup_trainer(
|
trainer = setup_trainer(
|
||||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||||
)
|
)
|
||||||
|
|
||||||
model.config.use_cache = False
|
if hasattr(model, "config"):
|
||||||
|
model.config.use_cache = False
|
||||||
|
|
||||||
# go ahead and presave, so we have the adapter config available to inspect
|
# go ahead and presave, so we have the adapter config available to inspect
|
||||||
if peft_config:
|
if peft_config:
|
||||||
@@ -87,7 +97,8 @@ def train(
|
|||||||
if not Path(cfg.output_dir).is_dir():
|
if not Path(cfg.output_dir).is_dir():
|
||||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||||
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
||||||
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
if hasattr(model, "config"):
|
||||||
|
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
||||||
|
|
||||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
@@ -105,10 +116,17 @@ def train(
|
|||||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
||||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||||
|
|
||||||
|
if getattr(cfg, "axolotl_config_path"):
|
||||||
|
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
|
||||||
|
version = get_distribution("axolotl").version
|
||||||
|
if raw_axolotl_cfg.is_file():
|
||||||
|
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
|
||||||
|
|
||||||
LOG.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
if cfg.group_by_length:
|
if cfg.group_by_length:
|
||||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||||
|
|
||||||
|
pretrain_hooks(cfg, trainer)
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
with torch.backends.cuda.sdp_kernel(
|
with torch.backends.cuda.sdp_kernel(
|
||||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||||
@@ -116,9 +134,15 @@ def train(
|
|||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
else:
|
else:
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
post_train_hooks(cfg, trainer)
|
||||||
|
|
||||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||||
|
|
||||||
|
# post training
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if hasattr(module, "_post_training"):
|
||||||
|
module._post_training(model, name) # pylint: disable=protected-access
|
||||||
|
|
||||||
if trainer.is_fsdp_enabled:
|
if trainer.is_fsdp_enabled:
|
||||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
||||||
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
|
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
|
||||||
@@ -134,6 +158,22 @@ def train(
|
|||||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||||
if cfg.fsdp:
|
if cfg.fsdp:
|
||||||
trainer.save_model(cfg.output_dir)
|
trainer.save_model(cfg.output_dir)
|
||||||
|
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
||||||
|
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
||||||
|
trainer.accelerator.wait_for_everyone()
|
||||||
|
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
|
||||||
|
|
||||||
|
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
|
||||||
|
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
|
||||||
|
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
|
||||||
|
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
|
||||||
|
# The model name saved is `pytorch_model.bin`
|
||||||
|
unwrapped_model.save_pretrained(
|
||||||
|
cfg.output_dir,
|
||||||
|
is_main_process=trainer.accelerator.is_main_process,
|
||||||
|
save_function=trainer.accelerator.save,
|
||||||
|
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
|
||||||
|
)
|
||||||
elif cfg.local_rank == 0:
|
elif cfg.local_rank == 0:
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
@@ -144,3 +184,21 @@ def train(
|
|||||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def pretrain_hooks(_cfg, _trainer):
|
||||||
|
"""
|
||||||
|
Run hooks right before kicking off the training
|
||||||
|
:param cfg:
|
||||||
|
:param trainer:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def post_train_hooks(_cfg, _trainer):
|
||||||
|
"""
|
||||||
|
Run hooks right after training completes
|
||||||
|
:param cfg:
|
||||||
|
:param trainer:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from shutil import copyfile
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
from typing import TYPE_CHECKING, Dict, List
|
from typing import TYPE_CHECKING, Dict, List
|
||||||
|
|
||||||
import evaluate
|
import evaluate
|
||||||
@@ -37,7 +39,7 @@ from axolotl.utils.distributed import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.utils.trainer import AxolotlTrainingArguments
|
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.callbacks")
|
LOG = logging.getLogger("axolotl.callbacks")
|
||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
@@ -124,6 +126,36 @@ class GPUStatsCallback(
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
class LossWatchDogCallback(TrainerCallback):
|
||||||
|
"""Callback to track loss and stop training if loss is too high"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
self.cfg = cfg
|
||||||
|
self.logged = False
|
||||||
|
self.violations = 0
|
||||||
|
self.threshold = cfg.loss_watchdog_threshold
|
||||||
|
self.patience = cfg.loss_watchdog_patience or 3
|
||||||
|
|
||||||
|
def on_step_end(
|
||||||
|
self,
|
||||||
|
_args: TrainingArguments,
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
**_kwargs,
|
||||||
|
):
|
||||||
|
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
|
||||||
|
if state.log_history[-1]["loss"] > self.threshold:
|
||||||
|
self.violations += 1
|
||||||
|
if self.violations >= self.patience:
|
||||||
|
LOG.warning(
|
||||||
|
"Loss is too high, stopping training (loss_watchdog_threshold)"
|
||||||
|
)
|
||||||
|
control.should_training_stop = True
|
||||||
|
else:
|
||||||
|
self.violations = 0
|
||||||
|
return control
|
||||||
|
|
||||||
|
|
||||||
def bench_eval_callback_factory(trainer, tokenizer):
|
def bench_eval_callback_factory(trainer, tokenizer):
|
||||||
accuracy = evaluate.load("accuracy")
|
accuracy = evaluate.load("accuracy")
|
||||||
abcd_idx = [
|
abcd_idx = [
|
||||||
@@ -531,10 +563,15 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
):
|
):
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
try:
|
try:
|
||||||
artifact = wandb.Artifact(name="axolotl-config", type="config")
|
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
|
||||||
artifact.add_file(local_path=self.axolotl_config_path)
|
with NamedTemporaryFile(
|
||||||
wandb.run.log_artifact(artifact)
|
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||||
LOG.info("Axolotl config has been saved to WandB as an artifact.")
|
) as temp_file:
|
||||||
|
copyfile(self.axolotl_config_path, temp_file.name)
|
||||||
|
wandb.save(temp_file.name)
|
||||||
|
LOG.info(
|
||||||
|
"The Axolotl config has been saved to the WandB run under files."
|
||||||
|
)
|
||||||
except (FileNotFoundError, ConnectionError) as err:
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||||
return control
|
return control
|
||||||
|
|||||||
29
src/axolotl/utils/chat_templates.py
Normal file
29
src/axolotl/utils/chat_templates.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""
|
||||||
|
This module provides functionality for selecting chat templates based on user choices.
|
||||||
|
These templates are used for formatting messages in a conversation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def chat_templates(user_choice: str):
|
||||||
|
"""
|
||||||
|
Finds the correct chat_template for the tokenizer_config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_choice (str): The user's choice of template.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The chosen template string.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the user_choice is not found in the templates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
templates = {
|
||||||
|
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||||
|
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
|
}
|
||||||
|
|
||||||
|
if user_choice in templates:
|
||||||
|
return templates[user_choice]
|
||||||
|
|
||||||
|
raise ValueError(f"Template '{user_choice}' not found.")
|
||||||
@@ -2,12 +2,16 @@
|
|||||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
||||||
"""
|
"""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Dict, Optional, Sequence, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
|
IGNORE_INDEX = -100
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForSeq2Seq:
|
class DataCollatorForSeq2Seq:
|
||||||
@@ -119,3 +123,58 @@ class DataCollatorForSeq2Seq:
|
|||||||
features["decoder_input_ids"] = decoder_input_ids
|
features["decoder_input_ids"] = decoder_input_ids
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
|
"""
|
||||||
|
Collator for multipack specific to the using the BatchSampler
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, features, return_tensors=None):
|
||||||
|
chunked_data = {}
|
||||||
|
for feature in features[0].keys():
|
||||||
|
if feature == "length":
|
||||||
|
continue
|
||||||
|
if feature == "attention_mask":
|
||||||
|
arrays = [
|
||||||
|
(1) * np.array(item[feature])
|
||||||
|
for item in features
|
||||||
|
if feature in item
|
||||||
|
]
|
||||||
|
chunked_data[feature] = np.concatenate(arrays)
|
||||||
|
else:
|
||||||
|
arrays = [
|
||||||
|
np.array(item[feature]) for item in features if feature in item
|
||||||
|
]
|
||||||
|
chunked_data[feature] = np.concatenate(arrays)
|
||||||
|
features = [chunked_data]
|
||||||
|
return super().__call__(features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MambaDataCollator:
|
||||||
|
"""
|
||||||
|
Collator for State Space Models (Mamba)
|
||||||
|
"""
|
||||||
|
|
||||||
|
tokenizer: transformers.PreTrainedTokenizer
|
||||||
|
|
||||||
|
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
||||||
|
input_ids, labels = tuple(
|
||||||
|
[torch.LongTensor(instance[key]) for instance in instances]
|
||||||
|
for key in ("input_ids", "labels")
|
||||||
|
)
|
||||||
|
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||||
|
input_ids,
|
||||||
|
batch_first=True,
|
||||||
|
padding_value=self.tokenizer.pad_token_id,
|
||||||
|
)
|
||||||
|
labels = torch.nn.utils.rnn.pad_sequence(
|
||||||
|
labels, batch_first=True, padding_value=IGNORE_INDEX
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"labels": labels,
|
||||||
|
}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ def choose_device(cfg):
|
|||||||
|
|
||||||
cfg.device = get_device()
|
cfg.device = get_device()
|
||||||
if cfg.world_size == 1:
|
if cfg.world_size == 1:
|
||||||
cfg.device_map = "auto"
|
cfg.device_map = cfg.device_map or "auto"
|
||||||
else:
|
else:
|
||||||
if cfg.device.startswith("cuda"):
|
if cfg.device.startswith("cuda"):
|
||||||
cfg.device_map = {"": torch.cuda.current_device()}
|
cfg.device_map = {"": torch.cuda.current_device()}
|
||||||
@@ -77,8 +77,20 @@ def normalize_config(cfg):
|
|||||||
else:
|
else:
|
||||||
cfg.torch_dtype = torch.float32
|
cfg.torch_dtype = torch.float32
|
||||||
|
|
||||||
|
if cfg.saves_per_epoch:
|
||||||
|
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||||
|
if save_steps < 1.0: # prevent saves on every step
|
||||||
|
cfg.save_steps = save_steps
|
||||||
|
if cfg.evals_per_epoch:
|
||||||
|
eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
|
||||||
|
if eval_steps < 1.0: # prevent evals on every step
|
||||||
|
cfg.eval_steps = eval_steps
|
||||||
|
|
||||||
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
|
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
|
||||||
|
|
||||||
|
if not cfg.base_model_config:
|
||||||
|
cfg.base_model_config = cfg.base_model
|
||||||
|
|
||||||
model_config = load_model_config(cfg)
|
model_config = load_model_config(cfg)
|
||||||
cfg.model_config_type = model_config.model_type
|
cfg.model_config_type = model_config.model_type
|
||||||
|
|
||||||
@@ -119,6 +131,22 @@ def normalize_config(cfg):
|
|||||||
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg.is_qwen_derived_model = (
|
||||||
|
(
|
||||||
|
hasattr(model_config, "model_type")
|
||||||
|
and model_config.model_type
|
||||||
|
in [
|
||||||
|
"qwen",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
or cfg.is_qwen_derived_model
|
||||||
|
or "qwen" in cfg.base_model.lower()
|
||||||
|
or (cfg.model_type and "qwen" in cfg.model_type.lower())
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(cfg.learning_rate, str):
|
||||||
|
cfg.learning_rate = float(cfg.learning_rate)
|
||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
|
|
||||||
@@ -159,7 +187,11 @@ def validate_config(cfg):
|
|||||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||||
)
|
)
|
||||||
if cfg.eval_batch_size != cfg.micro_batch_size:
|
if (
|
||||||
|
cfg.eval_batch_size
|
||||||
|
and cfg.micro_batch_size
|
||||||
|
and cfg.eval_batch_size != cfg.micro_batch_size
|
||||||
|
):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
||||||
)
|
)
|
||||||
@@ -189,9 +221,15 @@ def validate_config(cfg):
|
|||||||
if not cfg.load_in_4bit:
|
if not cfg.load_in_4bit:
|
||||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||||
|
|
||||||
|
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||||
|
raise ValueError("Fused modules are not supported with QLoRA")
|
||||||
|
|
||||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||||
|
|
||||||
|
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
||||||
|
raise ValueError("Fused modules are not supported with LoRA")
|
||||||
|
|
||||||
if cfg.relora_steps:
|
if cfg.relora_steps:
|
||||||
if cfg.adapter not in ("lora", "qlora"):
|
if cfg.adapter not in ("lora", "qlora"):
|
||||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||||
@@ -205,6 +243,9 @@ def validate_config(cfg):
|
|||||||
if cfg.lr_scheduler == "one_cycle":
|
if cfg.lr_scheduler == "one_cycle":
|
||||||
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
||||||
|
|
||||||
|
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||||
|
raise ValueError("Fused modules are not supported with ReLoRA")
|
||||||
|
|
||||||
if cfg.trust_remote_code:
|
if cfg.trust_remote_code:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||||
@@ -320,6 +361,27 @@ def validate_config(cfg):
|
|||||||
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
||||||
"sharegpt_simple", "sharegpt"
|
"sharegpt_simple", "sharegpt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.saves_per_epoch and cfg.save_steps:
|
||||||
|
raise ValueError(
|
||||||
|
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||||
|
)
|
||||||
|
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
|
||||||
|
raise ValueError(
|
||||||
|
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
||||||
|
)
|
||||||
|
if cfg.evals_per_epoch and cfg.eval_steps:
|
||||||
|
raise ValueError(
|
||||||
|
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
cfg.evals_per_epoch
|
||||||
|
and cfg.evaluation_strategy
|
||||||
|
and cfg.evaluation_strategy != "steps"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||||
|
)
|
||||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||||
@@ -339,6 +401,67 @@ def validate_config(cfg):
|
|||||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
cfg.sample_packing
|
||||||
|
and cfg.eval_table_size
|
||||||
|
and cfg.eval_sample_packing is not False
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
|
||||||
|
raise ValueError(
|
||||||
|
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
|
||||||
|
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.rope_scaling:
|
||||||
|
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
||||||
|
|
||||||
|
if cfg.warmup_steps and cfg.warmup_ratio:
|
||||||
|
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
|
||||||
|
|
||||||
|
if cfg.is_qwen_derived_model and cfg.gradient_checkpointing:
|
||||||
|
LOG.warning(
|
||||||
|
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.wandb_run_id and not cfg.wandb_name:
|
||||||
|
cfg.wandb_name = cfg.wandb_run_id
|
||||||
|
|
||||||
|
LOG.warning(
|
||||||
|
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.noisy_embedding_alpha is not None:
|
||||||
|
# Deprecated, use neftune_noise_alpha
|
||||||
|
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
||||||
|
if cfg.neftune_noise_alpha is None:
|
||||||
|
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
|
||||||
|
else:
|
||||||
|
# User is providing both; bail and have them sort out their settings
|
||||||
|
raise ValueError(
|
||||||
|
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
||||||
|
raise ValueError("neftune_noise_alpha must be > 0.0")
|
||||||
|
|
||||||
|
if (
|
||||||
|
cfg.adapter
|
||||||
|
and cfg.tokens
|
||||||
|
and (
|
||||||
|
not cfg.lora_modules_to_save
|
||||||
|
or not all(
|
||||||
|
x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
|
||||||
|
)
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -34,8 +34,10 @@ from axolotl.prompters import (
|
|||||||
JeopardyPrompter,
|
JeopardyPrompter,
|
||||||
MultipleChoiceConcisePrompter,
|
MultipleChoiceConcisePrompter,
|
||||||
MultipleChoiceExplainPrompter,
|
MultipleChoiceExplainPrompter,
|
||||||
|
Prompter,
|
||||||
ReflectAlpacaPrompter,
|
ReflectAlpacaPrompter,
|
||||||
SummarizeTLDRPrompter,
|
SummarizeTLDRPrompter,
|
||||||
|
UnsupportedPrompter,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
@@ -55,9 +57,10 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
|||||||
|
|
||||||
|
|
||||||
def prepare_dataset(cfg, tokenizer):
|
def prepare_dataset(cfg, tokenizer):
|
||||||
|
prompters = []
|
||||||
if not cfg.pretraining_dataset:
|
if not cfg.pretraining_dataset:
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
train_dataset, eval_dataset = load_prepare_datasets(
|
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -70,25 +73,33 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||||
train_dataset = train_dataset.with_format("torch")
|
train_dataset = train_dataset.with_format("torch")
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
return train_dataset, eval_dataset, cfg.max_steps
|
return train_dataset, eval_dataset, cfg.max_steps, prompters
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||||
cfg, train_dataset, eval_dataset, tokenizer
|
cfg, train_dataset, eval_dataset, tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
||||||
|
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||||
|
if total_eval_steps == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.max_steps:
|
if cfg.max_steps:
|
||||||
total_num_steps = min(
|
total_num_steps = min(
|
||||||
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
|
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||||
)
|
)
|
||||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||||
else:
|
else:
|
||||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
||||||
return train_dataset, eval_dataset, total_num_steps
|
return train_dataset, eval_dataset, total_num_steps, prompters
|
||||||
|
|
||||||
|
|
||||||
def load_tokenized_prepared_datasets(
|
def load_tokenized_prepared_datasets(
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
tokenizer, cfg, default_dataset_prepared_path
|
||||||
) -> DatasetDict:
|
) -> Tuple[DatasetDict, List[Prompter]]:
|
||||||
tokenizer_name = tokenizer.__class__.__name__
|
tokenizer_name = tokenizer.__class__.__name__
|
||||||
ds_hash = str(
|
ds_hash = str(
|
||||||
md5(
|
md5(
|
||||||
@@ -96,7 +107,12 @@ def load_tokenized_prepared_datasets(
|
|||||||
str(cfg.sequence_len)
|
str(cfg.sequence_len)
|
||||||
+ "@"
|
+ "@"
|
||||||
+ "|".join(
|
+ "|".join(
|
||||||
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
sorted(
|
||||||
|
[
|
||||||
|
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
|
||||||
|
for d in cfg.datasets
|
||||||
|
]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
+ "|"
|
+ "|"
|
||||||
+ tokenizer_name
|
+ tokenizer_name
|
||||||
@@ -109,6 +125,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
else Path(default_dataset_prepared_path) / ds_hash
|
else Path(default_dataset_prepared_path) / ds_hash
|
||||||
)
|
)
|
||||||
dataset = None
|
dataset = None
|
||||||
|
prompters = []
|
||||||
use_auth_token = cfg.hf_use_auth_token
|
use_auth_token = cfg.hf_use_auth_token
|
||||||
try:
|
try:
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
@@ -147,48 +164,99 @@ def load_tokenized_prepared_datasets(
|
|||||||
yield dataset
|
yield dataset
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
for d in for_d_in_datasets(cfg.datasets):
|
for config_dataset in for_d_in_datasets(cfg.datasets):
|
||||||
ds: Union[Dataset, DatasetDict] = None
|
ds: Union[Dataset, DatasetDict] = None
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
try:
|
try:
|
||||||
load_dataset(
|
load_dataset(
|
||||||
d.path,
|
config_dataset.path,
|
||||||
name=d.name,
|
name=config_dataset.name,
|
||||||
streaming=True,
|
streaming=True,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
ds_from_hub = True
|
ds_from_hub = True
|
||||||
except FileNotFoundError:
|
except (FileNotFoundError, ConnectionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
ds_from_cloud = False
|
||||||
|
storage_options = {}
|
||||||
|
remote_file_system = None
|
||||||
|
if config_dataset.path.startswith("s3://"):
|
||||||
|
try:
|
||||||
|
import aiobotocore.session # type: ignore
|
||||||
|
import s3fs # type: ignore
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"s3:// paths require aiobotocore and s3fs to be installed"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
# Takes credentials from ~/.aws/credentials for default profile
|
||||||
|
s3_session = aiobotocore.session.AioSession(profile="default")
|
||||||
|
storage_options = {"session": s3_session}
|
||||||
|
remote_file_system = s3fs.S3FileSystem(**storage_options)
|
||||||
|
elif config_dataset.path.startswith(
|
||||||
|
"gs://"
|
||||||
|
) or config_dataset.path.startswith("gcs://"):
|
||||||
|
try:
|
||||||
|
import gcsfs # type: ignore
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"gs:// or gcs:// paths require gcsfs to be installed"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
# gcsfs will use default credentials from the environment else anon
|
||||||
|
# https://gcsfs.readthedocs.io/en/latest/#credentials
|
||||||
|
storage_options = {"token": None}
|
||||||
|
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
|
||||||
|
# TODO: Figure out how to get auth creds passed
|
||||||
|
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
|
||||||
|
# try:
|
||||||
|
# import adlfs
|
||||||
|
# except ImportError as exc:
|
||||||
|
# raise ImportError(
|
||||||
|
# "adl:// or abfs:// paths require adlfs to be installed"
|
||||||
|
# ) from exc
|
||||||
|
|
||||||
|
# # Gen 1
|
||||||
|
# storage_options = {
|
||||||
|
# "tenant_id": TENANT_ID,
|
||||||
|
# "client_id": CLIENT_ID,
|
||||||
|
# "client_secret": CLIENT_SECRET,
|
||||||
|
# }
|
||||||
|
# # Gen 2
|
||||||
|
# storage_options = {
|
||||||
|
# "account_name": ACCOUNT_NAME,
|
||||||
|
# "account_key": ACCOUNT_KEY,
|
||||||
|
# }
|
||||||
|
|
||||||
|
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
|
||||||
|
try:
|
||||||
|
if remote_file_system and remote_file_system.exists(
|
||||||
|
config_dataset.path
|
||||||
|
):
|
||||||
|
ds_from_cloud = True
|
||||||
|
except (FileNotFoundError, ConnectionError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# prefer local dataset, even if hub exists
|
# prefer local dataset, even if hub exists
|
||||||
local_path = Path(d.path)
|
local_path = Path(config_dataset.path)
|
||||||
if local_path.exists():
|
if local_path.exists():
|
||||||
if local_path.is_dir():
|
if local_path.is_dir():
|
||||||
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
d.path,
|
config_dataset.path,
|
||||||
name=d.name,
|
name=config_dataset.name,
|
||||||
data_files=d.data_files,
|
data_files=config_dataset.data_files,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
split=None,
|
||||||
)
|
)
|
||||||
elif local_path.is_file():
|
elif local_path.is_file():
|
||||||
ds_type = "json"
|
ds_type = get_ds_type(config_dataset)
|
||||||
if d.ds_type:
|
|
||||||
ds_type = d.ds_type
|
|
||||||
elif ".parquet" in d.path:
|
|
||||||
ds_type = "parquet"
|
|
||||||
elif ".arrow" in d.path:
|
|
||||||
ds_type = "arrow"
|
|
||||||
elif ".csv" in d.path:
|
|
||||||
ds_type = "csv"
|
|
||||||
elif ".txt" in d.path:
|
|
||||||
ds_type = "text"
|
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
ds_type,
|
ds_type,
|
||||||
name=d.name,
|
name=config_dataset.name,
|
||||||
data_files=d.path,
|
data_files=config_dataset.path,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
split=None,
|
||||||
)
|
)
|
||||||
@@ -198,25 +266,41 @@ def load_tokenized_prepared_datasets(
|
|||||||
)
|
)
|
||||||
elif ds_from_hub:
|
elif ds_from_hub:
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
d.path,
|
config_dataset.path,
|
||||||
name=d.name,
|
name=config_dataset.name,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
data_files=d.data_files,
|
data_files=config_dataset.data_files,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
else:
|
elif ds_from_cloud and remote_file_system:
|
||||||
if isinstance(d.data_files, str):
|
if remote_file_system.isdir(config_dataset.path):
|
||||||
fp = hf_hub_download(
|
ds = load_from_disk(
|
||||||
repo_id=d.path,
|
config_dataset.path,
|
||||||
repo_type="dataset",
|
storage_options=storage_options,
|
||||||
filename=d.data_files,
|
|
||||||
)
|
)
|
||||||
elif isinstance(d.data_files, list):
|
elif remote_file_system.isfile(config_dataset.path):
|
||||||
|
ds_type = get_ds_type(config_dataset)
|
||||||
|
ds = load_dataset(
|
||||||
|
ds_type,
|
||||||
|
name=config_dataset.name,
|
||||||
|
data_files=config_dataset.path,
|
||||||
|
streaming=False,
|
||||||
|
split=None,
|
||||||
|
storage_options=storage_options,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if isinstance(config_dataset.data_files, str):
|
||||||
|
fp = hf_hub_download(
|
||||||
|
repo_id=config_dataset.path,
|
||||||
|
repo_type="dataset",
|
||||||
|
filename=config_dataset.data_files,
|
||||||
|
)
|
||||||
|
elif isinstance(config_dataset.data_files, list):
|
||||||
fp = []
|
fp = []
|
||||||
for file in d.data_files:
|
for file in config_dataset.data_files:
|
||||||
fp.append(
|
fp.append(
|
||||||
hf_hub_download(
|
hf_hub_download(
|
||||||
repo_id=d.path,
|
repo_id=config_dataset.path,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
filename=file,
|
filename=file,
|
||||||
)
|
)
|
||||||
@@ -226,21 +310,27 @@ def load_tokenized_prepared_datasets(
|
|||||||
"data_files must be either a string or list of strings"
|
"data_files must be either a string or list of strings"
|
||||||
)
|
)
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
"json", name=d.name, data_files=fp, streaming=False, split=None
|
"json",
|
||||||
|
name=config_dataset.name,
|
||||||
|
data_files=fp,
|
||||||
|
streaming=False,
|
||||||
|
split=None,
|
||||||
)
|
)
|
||||||
if not ds:
|
if not ds:
|
||||||
raise ValueError("unhandled dataset load")
|
raise ValueError("unhandled dataset load")
|
||||||
# support for using a subset of the data
|
# support for using a subset of the data
|
||||||
if d.shards:
|
if config_dataset.shards:
|
||||||
if "train" in ds:
|
if "train" in ds:
|
||||||
ds = ds.shuffle(seed=seed)["train"].shard(
|
ds = ds.shuffle(seed=seed)["train"].shard(
|
||||||
num_shards=d.shards, index=0
|
num_shards=config_dataset.shards, index=0
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
ds = ds.shuffle(seed=seed).shard(
|
||||||
|
num_shards=config_dataset.shards, index=0
|
||||||
|
)
|
||||||
|
|
||||||
d_base_type = d_prompt_style = None
|
d_base_type = d_prompt_style = None
|
||||||
d_type = d.type
|
d_type = config_dataset.type
|
||||||
if isinstance(d_type, str):
|
if isinstance(d_type, str):
|
||||||
d_type_split = d_type.split(":")
|
d_type_split = d_type.split(":")
|
||||||
d_base_type = d_type_split[0]
|
d_base_type = d_type_split[0]
|
||||||
@@ -249,108 +339,26 @@ def load_tokenized_prepared_datasets(
|
|||||||
ds = ds["train"]
|
ds = ds["train"]
|
||||||
elif (
|
elif (
|
||||||
isinstance(ds, DatasetDict)
|
isinstance(ds, DatasetDict)
|
||||||
and d.train_on_split
|
and config_dataset.train_on_split
|
||||||
and d.train_on_split in ds
|
and config_dataset.train_on_split in ds
|
||||||
):
|
):
|
||||||
ds = ds[d.train_on_split]
|
ds = ds[config_dataset.train_on_split]
|
||||||
elif isinstance(ds, DatasetDict):
|
elif isinstance(ds, DatasetDict):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
|
f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
|
||||||
)
|
|
||||||
if (
|
|
||||||
"input_ids" in ds.features
|
|
||||||
and "attention_mask" in ds.features
|
|
||||||
and "labels" in ds.features
|
|
||||||
):
|
|
||||||
# dataset is already tokenized, just drop it straight in
|
|
||||||
datasets.append(ds)
|
|
||||||
elif isinstance(d.type, DictDefault):
|
|
||||||
ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif ds_strategy := load(d.type, tokenizer, cfg, d):
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "alpaca":
|
|
||||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
|
||||||
AlpacaPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "explainchoice":
|
|
||||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
|
||||||
MultipleChoiceExplainPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "concisechoice":
|
|
||||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
|
||||||
MultipleChoiceConcisePrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "summarizetldr":
|
|
||||||
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
|
||||||
SummarizeTLDRPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "jeopardy":
|
|
||||||
ds_strategy = JeopardyPromptTokenizingStrategy(
|
|
||||||
JeopardyPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "oasst":
|
|
||||||
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
|
||||||
AlpacaPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "gpteacher":
|
|
||||||
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
|
||||||
GPTeacherPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "reflection":
|
|
||||||
ds_strategy = AlpacaReflectionPTStrategy(
|
|
||||||
ReflectAlpacaPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
else:
|
|
||||||
suffix = ""
|
|
||||||
if ":load_" in d.type:
|
|
||||||
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
|
|
||||||
LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
|
|
||||||
raise ValueError(
|
|
||||||
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
||||||
|
config_dataset=config_dataset,
|
||||||
|
dataset=ds,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
cfg=cfg,
|
||||||
|
d_base_type=d_base_type,
|
||||||
|
d_prompt_style=d_prompt_style,
|
||||||
|
)
|
||||||
|
datasets.append(dataset_wrapper)
|
||||||
|
prompters.append(dataset_prompter)
|
||||||
|
|
||||||
LOG.info("merging datasets")
|
LOG.info("merging datasets")
|
||||||
dataset = concatenate_datasets(datasets)
|
dataset = concatenate_datasets(datasets)
|
||||||
|
|
||||||
@@ -368,14 +376,32 @@ def load_tokenized_prepared_datasets(
|
|||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataset
|
return dataset, prompters
|
||||||
|
|
||||||
|
|
||||||
|
def get_ds_type(config_dataset: DictDefault):
|
||||||
|
"""
|
||||||
|
Get the dataset type from the path if it's not specified
|
||||||
|
"""
|
||||||
|
ds_type = "json"
|
||||||
|
if config_dataset.ds_type:
|
||||||
|
ds_type = config_dataset.ds_type
|
||||||
|
elif ".parquet" in config_dataset.path:
|
||||||
|
ds_type = "parquet"
|
||||||
|
elif ".arrow" in config_dataset.path:
|
||||||
|
ds_type = "arrow"
|
||||||
|
elif ".csv" in config_dataset.path:
|
||||||
|
ds_type = "csv"
|
||||||
|
elif ".txt" in config_dataset.path:
|
||||||
|
ds_type = "text"
|
||||||
|
return ds_type
|
||||||
|
|
||||||
|
|
||||||
def load_prepare_datasets(
|
def load_prepare_datasets(
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
cfg,
|
cfg,
|
||||||
default_dataset_prepared_path,
|
default_dataset_prepared_path,
|
||||||
) -> Tuple[Dataset, Dataset]:
|
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
||||||
max_packed_sequence_len = (
|
max_packed_sequence_len = (
|
||||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||||
)
|
)
|
||||||
@@ -384,6 +410,7 @@ def load_prepare_datasets(
|
|||||||
) # make sure we don't accidentally set it larger than sequence_len
|
) # make sure we don't accidentally set it larger than sequence_len
|
||||||
|
|
||||||
tokenizer_name = tokenizer.__class__.__name__
|
tokenizer_name = tokenizer.__class__.__name__
|
||||||
|
prompters: List[Prompter] = []
|
||||||
if cfg.max_packed_sequence_len is not None:
|
if cfg.max_packed_sequence_len is not None:
|
||||||
# see if we can go ahead and load the stacked dataset
|
# see if we can go ahead and load the stacked dataset
|
||||||
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
||||||
@@ -439,7 +466,7 @@ def load_prepare_datasets(
|
|||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dataset = load_tokenized_prepared_datasets(
|
dataset, prompters = load_tokenized_prepared_datasets(
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
tokenizer, cfg, default_dataset_prepared_path
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -481,7 +508,7 @@ def load_prepare_datasets(
|
|||||||
private=True,
|
private=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dataset = load_tokenized_prepared_datasets(
|
dataset, prompters = load_tokenized_prepared_datasets(
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
tokenizer, cfg, default_dataset_prepared_path
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -517,14 +544,13 @@ def load_prepare_datasets(
|
|||||||
train_fingerprint = md5(to_hash_train)
|
train_fingerprint = md5(to_hash_train)
|
||||||
test_fingerprint = md5(to_hash_test)
|
test_fingerprint = md5(to_hash_test)
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
dataset = dataset.train_test_split(
|
||||||
dataset = dataset.train_test_split(
|
test_size=cfg.val_set_size,
|
||||||
test_size=cfg.val_set_size,
|
shuffle=False,
|
||||||
shuffle=False,
|
seed=cfg.seed or 42,
|
||||||
seed=cfg.seed or 42,
|
train_new_fingerprint=train_fingerprint,
|
||||||
train_new_fingerprint=train_fingerprint,
|
test_new_fingerprint=test_fingerprint,
|
||||||
test_new_fingerprint=test_fingerprint,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
train_dataset = dataset["train"]
|
train_dataset = dataset["train"]
|
||||||
eval_dataset = dataset["test"]
|
eval_dataset = dataset["test"]
|
||||||
@@ -532,7 +558,144 @@ def load_prepare_datasets(
|
|||||||
train_dataset = dataset
|
train_dataset = dataset
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
|
|
||||||
return train_dataset, eval_dataset
|
return train_dataset, eval_dataset, prompters
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset_wrapper(
|
||||||
|
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
|
||||||
|
):
|
||||||
|
dataset_wrapper = None
|
||||||
|
dataset_prompter = None
|
||||||
|
|
||||||
|
if (
|
||||||
|
"input_ids" in dataset.features
|
||||||
|
and "attention_mask" in dataset.features
|
||||||
|
and "labels" in dataset.features
|
||||||
|
):
|
||||||
|
# dataset is already tokenized, just drop it straight in
|
||||||
|
dataset_prompter = UnsupportedPrompter()
|
||||||
|
dataset_wrapper = dataset
|
||||||
|
elif isinstance(config_dataset.type, DictDefault):
|
||||||
|
ds_strategy = load(
|
||||||
|
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
||||||
|
)
|
||||||
|
dataset_prompter = UnsupportedPrompter()
|
||||||
|
dataset_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
|
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
||||||
|
dataset_prompter = UnsupportedPrompter()
|
||||||
|
dataset_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
|
elif d_base_type == "alpaca":
|
||||||
|
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||||
|
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||||
|
dataset_prompter,
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
|
dataset_wrapper = ds_wrapper
|
||||||
|
elif d_base_type == "explainchoice":
|
||||||
|
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
|
||||||
|
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||||
|
dataset_prompter,
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
|
dataset_wrapper = ds_wrapper
|
||||||
|
elif d_base_type == "concisechoice":
|
||||||
|
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
|
||||||
|
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||||
|
dataset_prompter,
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
|
dataset_wrapper = ds_wrapper
|
||||||
|
elif d_base_type == "summarizetldr":
|
||||||
|
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
|
||||||
|
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
||||||
|
dataset_prompter,
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
|
dataset_wrapper = ds_wrapper
|
||||||
|
elif d_base_type == "jeopardy":
|
||||||
|
dataset_prompter = JeopardyPrompter(d_prompt_style)
|
||||||
|
ds_strategy = JeopardyPromptTokenizingStrategy(
|
||||||
|
dataset_prompter,
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
|
dataset_wrapper = ds_wrapper
|
||||||
|
elif d_base_type == "oasst":
|
||||||
|
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||||
|
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
||||||
|
dataset_prompter,
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
|
dataset_wrapper = ds_wrapper
|
||||||
|
elif d_base_type == "gpteacher":
|
||||||
|
dataset_prompter = GPTeacherPrompter(d_prompt_style)
|
||||||
|
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
||||||
|
dataset_prompter,
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
|
dataset_wrapper = ds_wrapper
|
||||||
|
elif d_base_type == "reflection":
|
||||||
|
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
|
||||||
|
ds_strategy = AlpacaReflectionPTStrategy(
|
||||||
|
dataset_prompter,
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
|
dataset_wrapper = ds_wrapper
|
||||||
|
else:
|
||||||
|
suffix = ""
|
||||||
|
if ":load_" in config_dataset.type:
|
||||||
|
suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
|
||||||
|
LOG.error(
|
||||||
|
f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}"
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return dataset_wrapper, dataset_prompter
|
||||||
|
|
||||||
|
|
||||||
def encode_pretraining(
|
def encode_pretraining(
|
||||||
|
|||||||
@@ -1,302 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
import hashlib
|
|
||||||
import itertools
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
from typing import Any, Callable, List, Union
|
|
||||||
|
|
||||||
import numba
|
|
||||||
import numpy as np
|
|
||||||
from torch.utils.data import DistributedSampler, Sampler
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.dataloader")
|
|
||||||
|
|
||||||
|
|
||||||
@numba.njit
|
|
||||||
def ffd_check(a: np.ndarray, c: int, n: int):
|
|
||||||
# First-fit-decreasing bin packing
|
|
||||||
# Check if a[] could fit in n bins with capacity c
|
|
||||||
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
|
||||||
|
|
||||||
a = np.sort(a)[::-1]
|
|
||||||
bins = np.full((n,), c, dtype=a.dtype)
|
|
||||||
for size in a:
|
|
||||||
not_found = True
|
|
||||||
for idx in range(n):
|
|
||||||
if bins[idx] >= size:
|
|
||||||
bins[idx] -= size
|
|
||||||
not_found = False
|
|
||||||
break
|
|
||||||
|
|
||||||
if not_found:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
@numba.njit
|
|
||||||
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
|
||||||
# First-fit-decreasing bin packing (with result return)
|
|
||||||
|
|
||||||
indices = np.argsort(a)[::-1]
|
|
||||||
a = a[indices]
|
|
||||||
|
|
||||||
bins: List[Any] = []
|
|
||||||
bins_result: List[Any] = []
|
|
||||||
for a_id, size in enumerate(a):
|
|
||||||
add_new = True
|
|
||||||
for idx in range(len(bins)):
|
|
||||||
if bins[idx] >= size:
|
|
||||||
bins[idx] -= size
|
|
||||||
bins_result[idx].append(indices[a_id] + start_index)
|
|
||||||
add_new = False
|
|
||||||
break
|
|
||||||
|
|
||||||
if add_new:
|
|
||||||
bins.append(c - size)
|
|
||||||
bins_result.append([indices[a_id] + start_index])
|
|
||||||
|
|
||||||
return bins_result, len(a)
|
|
||||||
|
|
||||||
|
|
||||||
@numba.njit
|
|
||||||
def allocate(
|
|
||||||
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
:param lengths: array of lengths of each sample
|
|
||||||
:param lengths_cumsum: cumulative sum of consecutive lengths
|
|
||||||
:param rank: rank for this process
|
|
||||||
:param c: length of tokens per batch
|
|
||||||
:param n: number of ranks
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# Dynamic batch allocator, similar to Multifit
|
|
||||||
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
|
||||||
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
|
||||||
|
|
||||||
s = 0
|
|
||||||
start_index = 0
|
|
||||||
result = []
|
|
||||||
result_totseqs = []
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# binary search [left, right)
|
|
||||||
left = 1
|
|
||||||
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
|
||||||
|
|
||||||
while right - left > 1:
|
|
||||||
mid = (left + right) // 2
|
|
||||||
if ffd_check(lengths[start_index : start_index + mid], c, n):
|
|
||||||
left = mid
|
|
||||||
else:
|
|
||||||
right = mid
|
|
||||||
|
|
||||||
# use length left
|
|
||||||
batch, tot_seqs = ffd_with_result(
|
|
||||||
lengths[start_index : start_index + left], c, start_index
|
|
||||||
)
|
|
||||||
if len(batch) < n:
|
|
||||||
break
|
|
||||||
|
|
||||||
start_index += left
|
|
||||||
s = lengths_cumsum[start_index - 1]
|
|
||||||
|
|
||||||
# add local rank
|
|
||||||
result.append(batch[rank])
|
|
||||||
# add total seqs for all ranks
|
|
||||||
result_totseqs.append(tot_seqs)
|
|
||||||
# yield batch[rank], tot_seqs, s, len(result) * c * n
|
|
||||||
return result, result_totseqs, s, len(result) * c * n
|
|
||||||
|
|
||||||
|
|
||||||
def chunk(iterable, n):
|
|
||||||
"""
|
|
||||||
Chunk data into tuples of length n
|
|
||||||
"""
|
|
||||||
# batched('ABCDEFG', 3) --> ABC DEF G
|
|
||||||
if n < 1:
|
|
||||||
raise ValueError("n must be at least one")
|
|
||||||
it = iter(iterable)
|
|
||||||
while batch := tuple(itertools.islice(it, n)):
|
|
||||||
yield batch
|
|
||||||
|
|
||||||
|
|
||||||
def hash_indices(lst: List[int]) -> str:
|
|
||||||
# Convert the list of integers to a string representation
|
|
||||||
concatenated = ",".join(map(str, lst))
|
|
||||||
|
|
||||||
# Generate the hash
|
|
||||||
sha256 = hashlib.sha256()
|
|
||||||
sha256.update(concatenated.encode())
|
|
||||||
|
|
||||||
return sha256.hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
class MultipackDistributedDataloader:
|
|
||||||
"""Unpadded data loading using Multipack.
|
|
||||||
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
|
|
||||||
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dataset: Any,
|
|
||||||
collate_fn: Callable,
|
|
||||||
seq_max_length: int = 2048,
|
|
||||||
batch_size: int = 1,
|
|
||||||
sampler: Union[Sampler, DistributedSampler] = None,
|
|
||||||
packing_efficiency_estimate: float = 1.0,
|
|
||||||
sample_packing_seq_len_multiplier: int = 1,
|
|
||||||
device_count: int = 1,
|
|
||||||
):
|
|
||||||
# Dataset
|
|
||||||
self.dataset = dataset
|
|
||||||
self.lengths = (
|
|
||||||
dataset.data.column("position_ids")
|
|
||||||
.to_pandas()
|
|
||||||
.apply(lambda x: x[-1] + 1)
|
|
||||||
.values
|
|
||||||
)
|
|
||||||
assert isinstance(self.lengths, np.ndarray)
|
|
||||||
assert batch_size % sample_packing_seq_len_multiplier == 0
|
|
||||||
assert batch_size >= sample_packing_seq_len_multiplier
|
|
||||||
self.sampler = sampler
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
|
|
||||||
self.seq_max_length = seq_max_length
|
|
||||||
self.batch_max_length = batch_size * seq_max_length
|
|
||||||
self.collate_fn = collate_fn
|
|
||||||
|
|
||||||
self.num_replicas = 1
|
|
||||||
self.rank = 0
|
|
||||||
|
|
||||||
# statistics
|
|
||||||
self.eff_total_used = 0
|
|
||||||
self.eff_total_slots = 0
|
|
||||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
|
||||||
self.device_count = device_count
|
|
||||||
|
|
||||||
def generate_batches(self, set_stats=False):
|
|
||||||
LOG.info("generating packed batches")
|
|
||||||
if self.sampler:
|
|
||||||
indices = [idx for idx in self.sampler]
|
|
||||||
else:
|
|
||||||
indices = range(0, len(self.dataset))
|
|
||||||
|
|
||||||
LOG.info(hash_indices(indices))
|
|
||||||
lengths = self.lengths[indices]
|
|
||||||
lengths_cumsum = np.cumsum(lengths)
|
|
||||||
|
|
||||||
batches, totseqs, total_used, total_slots = allocate(
|
|
||||||
lengths=lengths,
|
|
||||||
lengths_cumsum=lengths_cumsum,
|
|
||||||
rank=self.rank,
|
|
||||||
# c=self.batch_max_length,
|
|
||||||
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
|
|
||||||
n=self.num_replicas,
|
|
||||||
)
|
|
||||||
|
|
||||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
|
||||||
|
|
||||||
# statistics
|
|
||||||
if set_stats:
|
|
||||||
self.eff_total_used += total_used
|
|
||||||
self.eff_total_slots += total_slots
|
|
||||||
|
|
||||||
return batches, totseqs
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
if hasattr(self.sampler, "set_epoch"):
|
|
||||||
new_epoch = self.sampler.epoch + 1
|
|
||||||
self.sampler.set_epoch(new_epoch)
|
|
||||||
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
|
||||||
all_batches, _ = self.generate_batches(set_stats=True)
|
|
||||||
features = self.dataset.features.keys()
|
|
||||||
len_remaining = self._len_est()
|
|
||||||
for batches in chunk(
|
|
||||||
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
|
|
||||||
):
|
|
||||||
chunked_data = []
|
|
||||||
attn_mask_cum_idx = 0
|
|
||||||
for batch in batches:
|
|
||||||
concatenated = {}
|
|
||||||
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
|
||||||
for feature in features:
|
|
||||||
if feature == "length":
|
|
||||||
continue
|
|
||||||
if feature == "attention_mask":
|
|
||||||
arrays = [
|
|
||||||
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
|
|
||||||
for idx, item in enumerate(batched_data)
|
|
||||||
if feature in item
|
|
||||||
]
|
|
||||||
attn_mask_cum_idx += len(batched_data)
|
|
||||||
concatenated[feature] = np.concatenate(arrays)
|
|
||||||
else:
|
|
||||||
arrays = [
|
|
||||||
np.array(item[feature])
|
|
||||||
for item in batched_data
|
|
||||||
if feature in item
|
|
||||||
]
|
|
||||||
concatenated[feature] = np.concatenate(arrays)
|
|
||||||
chunked_data.append(concatenated)
|
|
||||||
yield self.collate_fn(chunked_data)
|
|
||||||
len_remaining -= 1
|
|
||||||
if not len_remaining:
|
|
||||||
return
|
|
||||||
# yield a no-op for cases where we don't have any data left to pack
|
|
||||||
for i in range(0, len_remaining):
|
|
||||||
yield self.collate_fn(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"input_ids": [0],
|
|
||||||
"labels": [-100],
|
|
||||||
"attention_mask": [True],
|
|
||||||
"position_ids": [0],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _len_est(self):
|
|
||||||
lengths_sum = np.sum(self.lengths)
|
|
||||||
lengths_sum_per_device = lengths_sum // self.device_count
|
|
||||||
LOG.info(
|
|
||||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
|
||||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
|
||||||
return (
|
|
||||||
math.floor(
|
|
||||||
0.99
|
|
||||||
* lengths_sum_per_device
|
|
||||||
/ self.packing_efficiency_estimate
|
|
||||||
// self.seq_max_length
|
|
||||||
// self.batch_size
|
|
||||||
)
|
|
||||||
- 1
|
|
||||||
)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
|
|
||||||
# the same share of total tokens
|
|
||||||
# if not self.eff_total_used:
|
|
||||||
# batches, _ = self.generate_batches(set_stats=True)
|
|
||||||
# LOG.info(
|
|
||||||
# f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
|
||||||
# f"actual packing efficiency: {self.efficiency()}"
|
|
||||||
# )
|
|
||||||
return max(1, self._len_est())
|
|
||||||
|
|
||||||
def len_w_stats(self):
|
|
||||||
if not self.eff_total_used:
|
|
||||||
batches, _ = self.generate_batches(set_stats=True)
|
|
||||||
LOG.info(
|
|
||||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
|
||||||
f"actual packing efficiency: {self.efficiency()}"
|
|
||||||
)
|
|
||||||
return max(1, self._len_est())
|
|
||||||
|
|
||||||
def efficiency(self):
|
|
||||||
return self.eff_total_used / self.eff_total_slots
|
|
||||||
@@ -50,6 +50,17 @@ def get_world_size():
|
|||||||
return int(os.getenv("WORLD_SIZE", "1"))
|
return int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def zero_only():
|
||||||
|
"""
|
||||||
|
Context manager that only runs the enclosed block on the main rank.
|
||||||
|
"""
|
||||||
|
if is_main_process():
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
yield None
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def zero_first(is_main):
|
def zero_first(is_main):
|
||||||
"""
|
"""
|
||||||
|
|||||||
38
src/axolotl/utils/freeze.py
Normal file
38
src/axolotl/utils/freeze.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""
|
||||||
|
module to freeze/unfreeze parameters by name
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.utils.freeze")
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_parameters_except(model, regex_patterns):
|
||||||
|
"""
|
||||||
|
Freezes all layers of the given model except for the layers that match given regex patterns.
|
||||||
|
Periods in the patterns are treated as literal periods, not as wildcard characters.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- model (nn.Module): The PyTorch model to be modified.
|
||||||
|
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None; the model is modified in place.
|
||||||
|
"""
|
||||||
|
# Escape periods and compile the regex patterns
|
||||||
|
compiled_patterns = [
|
||||||
|
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
|
||||||
|
]
|
||||||
|
|
||||||
|
# First, freeze all parameters in the model
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# Unfreeze layers that match the regex patterns
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if any(pattern.match(name) for pattern in compiled_patterns):
|
||||||
|
if is_main_process():
|
||||||
|
LOG.debug(f"unfreezing {name}")
|
||||||
|
param.requires_grad = True
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user