Compare commits
86 Commits
pytest-ski
...
mora
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d7ec10e337 | ||
|
|
05b0bd08d2 | ||
|
|
d4f6c65e4c | ||
|
|
a944f7b32b | ||
|
|
9d4225a058 | ||
|
|
f7332ac449 | ||
|
|
16d46b74e4 | ||
|
|
a6b37bdeb4 | ||
|
|
b7520801a3 | ||
|
|
fe650dd326 | ||
|
|
49b967b62f | ||
|
|
65db903714 | ||
|
|
6a5a725f10 | ||
|
|
f5febc729a | ||
|
|
230e0ac363 | ||
|
|
cc11c6bce2 | ||
|
|
5f91064040 | ||
|
|
ef223519c9 | ||
|
|
8a20a7b711 | ||
|
|
367b2e879b | ||
|
|
bbfed318bc | ||
|
|
84bb8061ba | ||
|
|
a27d5e1f4e | ||
|
|
6299eb5919 | ||
|
|
7c2bf3091f | ||
|
|
22ae21a6c2 | ||
|
|
ba45531802 | ||
|
|
8a1572a831 | ||
|
|
702a669cad | ||
|
|
891ae8aa13 | ||
|
|
0c49ecc429 | ||
|
|
60113437e4 | ||
|
|
419b2a6a98 | ||
|
|
2501a371c6 | ||
|
|
e6937e884b | ||
|
|
039e2a0370 | ||
|
|
4fde300e5f | ||
|
|
3319780300 | ||
|
|
81da7d2531 | ||
|
|
1e1921b794 | ||
|
|
1634ac82e0 | ||
|
|
02982733ec | ||
|
|
5d97e65f95 | ||
|
|
2147cf6837 | ||
|
|
50421c8b1d | ||
|
|
b32c08f8cc | ||
|
|
fff06af8d0 | ||
|
|
796a085b2f | ||
|
|
cb78a36374 | ||
|
|
8b9c15b17f | ||
|
|
9e1480e9ca | ||
|
|
3367fca732 | ||
|
|
1ac899800b | ||
|
|
70185763f6 | ||
|
|
120b809465 | ||
|
|
29cf15a28c | ||
|
|
dde02fcb94 | ||
|
|
b9bb169602 | ||
|
|
601c08b4c2 | ||
|
|
cc5d31e0d9 | ||
|
|
1aeece6e24 | ||
|
|
5294653a2d | ||
|
|
98c25e15cb | ||
|
|
68601ec6ad | ||
|
|
60f5ce0569 | ||
|
|
7477a53287 | ||
|
|
7d1d22f72f | ||
|
|
0e8f340945 | ||
|
|
59ef25470c | ||
|
|
c10563c444 | ||
|
|
37c037c69d | ||
|
|
15f7910d33 | ||
|
|
d28ba2e405 | ||
|
|
0eadfc8c86 | ||
|
|
bcaa92325d | ||
|
|
7d9bafcb88 | ||
|
|
e07dcb288c | ||
|
|
6319da1f9b | ||
|
|
132eb740f0 | ||
|
|
5ed29393e3 | ||
|
|
da9b1a3196 | ||
|
|
057fa44191 | ||
|
|
8fa0785f74 | ||
|
|
4313b1a6a0 | ||
|
|
7f17eff81a | ||
|
|
ff01c45127 |
7
.github/workflows/base.yml
vendored
7
.github/workflows/base.yml
vendored
@@ -30,7 +30,12 @@ jobs:
|
|||||||
- cuda: "121"
|
- cuda: "121"
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.1
|
pytorch: 2.2.2
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
- cuda: "121"
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.3.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
1
.github/workflows/lint.yml
vendored
1
.github/workflows/lint.yml
vendored
@@ -7,6 +7,7 @@ on:
|
|||||||
- 'requirements.txt'
|
- 'requirements.txt'
|
||||||
- '.github/workflows/*.yml'
|
- '.github/workflows/*.yml'
|
||||||
- "*.md"
|
- "*.md"
|
||||||
|
- "examples/**/*.y[a]?ml"
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|||||||
56
.github/workflows/main.yml
vendored
56
.github/workflows/main.yml
vendored
@@ -28,7 +28,12 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.1
|
pytorch: 2.2.2
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.3.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
@@ -84,7 +89,12 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.1
|
pytorch: 2.2.2
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.3.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
@@ -115,3 +125,45 @@ jobs:
|
|||||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
|
|
||||||
|
build-axolotl-cloud-no-tmux:
|
||||||
|
needs: build-axolotl
|
||||||
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
||||||
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.3.0
|
||||||
|
axolotl_extras:
|
||||||
|
runs-on: axolotl-gpu-runner
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- name: Docker metadata
|
||||||
|
id: metadata
|
||||||
|
uses: docker/metadata-action@v5
|
||||||
|
with:
|
||||||
|
images: winglian/axolotl-cloud-term
|
||||||
|
- name: Login to Docker Hub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v2
|
||||||
|
- name: Build
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
build-args: |
|
||||||
|
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
|
CUDA=${{ matrix.cuda }}
|
||||||
|
file: ./docker/Dockerfile-cloud-no-tmux
|
||||||
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
|
tags: |
|
||||||
|
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
|
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||||
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
|
|||||||
14
.github/workflows/nightlies.yml
vendored
14
.github/workflows/nightlies.yml
vendored
@@ -27,7 +27,12 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.1
|
pytorch: 2.2.2
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.3.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
@@ -84,7 +89,12 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.1
|
pytorch: 2.2.2
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.3.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
7
.github/workflows/tests.yml
vendored
7
.github/workflows/tests.yml
vendored
@@ -82,7 +82,12 @@ jobs:
|
|||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.2.1
|
pytorch: 2.2.2
|
||||||
|
num_gpus: 1
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.3.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -133,6 +133,7 @@ venv/
|
|||||||
ENV/
|
ENV/
|
||||||
env.bak/
|
env.bak/
|
||||||
venv.bak/
|
venv.bak/
|
||||||
|
venv3.10/
|
||||||
|
|
||||||
# Spyder project settings
|
# Spyder project settings
|
||||||
.spyderproject
|
.spyderproject
|
||||||
|
|||||||
45
README.md
45
README.md
@@ -34,6 +34,7 @@ Features:
|
|||||||
- [Mac](#mac)
|
- [Mac](#mac)
|
||||||
- [Google Colab](#google-colab)
|
- [Google Colab](#google-colab)
|
||||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||||
|
- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
|
||||||
- [Dataset](#dataset)
|
- [Dataset](#dataset)
|
||||||
- [Config](#config)
|
- [Config](#config)
|
||||||
- [Train](#train)
|
- [Train](#train)
|
||||||
@@ -44,6 +45,7 @@ Features:
|
|||||||
- Advanced Topics
|
- Advanced Topics
|
||||||
- [Multipack](./docs/multipack.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
- [Multipack](./docs/multipack.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||||
- [RLHF & DPO](./docs/rlhf.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
- [RLHF & DPO](./docs/rlhf.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||||
|
- [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||||
- [Common Errors](#common-errors-)
|
- [Common Errors](#common-errors-)
|
||||||
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
|
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
|
||||||
- [Debugging Axolotl](#debugging-axolotl)
|
- [Debugging Axolotl](#debugging-axolotl)
|
||||||
@@ -81,6 +83,7 @@ Features:
|
|||||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||||
|
| Mixtral8X22 | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
@@ -121,11 +124,11 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
|||||||
|
|
||||||
# inference
|
# inference
|
||||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||||
--lora_model_dir="./lora-out"
|
--lora_model_dir="./outputs/lora-out"
|
||||||
|
|
||||||
# gradio
|
# gradio
|
||||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||||
--lora_model_dir="./lora-out" --gradio
|
--lora_model_dir="./outputs/lora-out" --gradio
|
||||||
|
|
||||||
# remote yaml files - the yaml config can be hosted on a public URL
|
# remote yaml files - the yaml config can be hosted on a public URL
|
||||||
# Note: the yaml config must directly link to the **raw** yaml
|
# Note: the yaml config must directly link to the **raw** yaml
|
||||||
@@ -290,6 +293,42 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
|
|||||||
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Launching on public clouds via dstack
|
||||||
|
To launch on GPU instance (both on-demand and spot instances) on public clouds (GCP, AWS, Azure, Lambda Labs, TensorDock, Vast.ai, and CUDO), you can use [dstack](https://dstack.ai/).
|
||||||
|
|
||||||
|
Write a job description in YAML as below:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# dstack.yaml
|
||||||
|
type: task
|
||||||
|
|
||||||
|
image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2
|
||||||
|
|
||||||
|
env:
|
||||||
|
- HUGGING_FACE_HUB_TOKEN
|
||||||
|
- WANDB_API_KEY
|
||||||
|
|
||||||
|
commands:
|
||||||
|
- accelerate launch -m axolotl.cli.train config.yaml
|
||||||
|
|
||||||
|
ports:
|
||||||
|
- 6006
|
||||||
|
|
||||||
|
resources:
|
||||||
|
gpu:
|
||||||
|
memory: 24GB..
|
||||||
|
count: 2
|
||||||
|
```
|
||||||
|
|
||||||
|
then, simply run the job with `dstack run` command. Append `--spot` option if you want spot instance. `dstack run` command will show you the instance with cheapest price across multi cloud services:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install dstack
|
||||||
|
HUGGING_FACE_HUB_TOKEN=xxx WANDB_API_KEY=xxx dstack run . -f dstack.yaml # --spot
|
||||||
|
```
|
||||||
|
|
||||||
|
For further and fine-grained use cases, please refer to the official [dstack documents](https://dstack.ai/docs/) and the detailed description of [axolotl example](https://github.com/dstackai/dstack/tree/master/examples/fine-tuning/axolotl) on the official repository.
|
||||||
|
|
||||||
### Dataset
|
### Dataset
|
||||||
|
|
||||||
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
||||||
@@ -425,7 +464,7 @@ deepspeed: deepspeed_configs/zero1.json
|
|||||||
```
|
```
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed_configs/zero1.json
|
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
|
||||||
```
|
```
|
||||||
|
|
||||||
##### FSDP
|
##### FSDP
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||||
pytest /workspace/axolotl/tests/e2e/patched/
|
pytest /workspace/axolotl/tests/e2e/patched/
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
{
|
{
|
||||||
|
"zero_force_ds_cpu_optimizer": false,
|
||||||
|
"zero_allow_untested_optimizer": true,
|
||||||
"zero_optimization": {
|
"zero_optimization": {
|
||||||
"stage": 3,
|
"stage": 3,
|
||||||
"offload_optimizer": {
|
"offload_optimizer": {
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
{
|
{
|
||||||
|
"zero_force_ds_cpu_optimizer": false,
|
||||||
|
"zero_allow_untested_optimizer": true,
|
||||||
"zero_optimization": {
|
"zero_optimization": {
|
||||||
"stage": 3,
|
"stage": 3,
|
||||||
"offload_param": {
|
"offload_param": {
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ ARG PYTORCH_VERSION="2.1.2"
|
|||||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
|
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs
|
||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
|||||||
27
docker/Dockerfile-cloud-no-tmux
Normal file
27
docker/Dockerfile-cloud-no-tmux
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
ARG BASE_TAG=main
|
||||||
|
FROM winglian/axolotl:$BASE_TAG
|
||||||
|
|
||||||
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
|
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
|
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
|
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||||
|
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||||
|
|
||||||
|
EXPOSE 8888
|
||||||
|
EXPOSE 22
|
||||||
|
|
||||||
|
COPY scripts/cloud-entrypoint-term.sh /root/cloud-entrypoint.sh
|
||||||
|
COPY scripts/motd /etc/motd
|
||||||
|
|
||||||
|
RUN pip install jupyterlab notebook ipywidgets && \
|
||||||
|
jupyter lab clean
|
||||||
|
RUN apt install --yes --no-install-recommends openssh-server tmux sudo && \
|
||||||
|
pip3 install -U --no-cache-dir grpcio ray[default]==2.9.3 && \
|
||||||
|
mkdir -p ~/.ssh && \
|
||||||
|
chmod 700 ~/.ssh && \
|
||||||
|
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
|
||||||
|
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
|
||||||
|
chmod +x /root/cloud-entrypoint.sh
|
||||||
|
|
||||||
|
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
|
||||||
|
CMD ["sleep", "infinity"]
|
||||||
@@ -186,6 +186,11 @@ eval_sample_packing:
|
|||||||
# The trainer will provide recommended values for these values.
|
# The trainer will provide recommended values for these values.
|
||||||
sample_packing_eff_est:
|
sample_packing_eff_est:
|
||||||
total_num_tokens:
|
total_num_tokens:
|
||||||
|
# Increasing the following values helps with packing, but usually only slightly (<%1.)
|
||||||
|
# The number of samples packed at a time.
|
||||||
|
sample_packing_group_size: 100000
|
||||||
|
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
|
||||||
|
sample_packing_bin_size: 200
|
||||||
|
|
||||||
# Passed through to transformers when loading the model when launched without accelerate
|
# Passed through to transformers when loading the model when launched without accelerate
|
||||||
# Use `sequential` when training w/ model parallelism to limit memory
|
# Use `sequential` when training w/ model parallelism to limit memory
|
||||||
@@ -227,6 +232,12 @@ lora_modules_to_save:
|
|||||||
|
|
||||||
lora_fan_in_fan_out: false
|
lora_fan_in_fan_out: false
|
||||||
|
|
||||||
|
# LoRA+ hyperparameters
|
||||||
|
# For more details about the following options, see:
|
||||||
|
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
|
||||||
|
loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4.
|
||||||
|
loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. Default value is 1e-6.
|
||||||
|
|
||||||
peft:
|
peft:
|
||||||
# Configuration options for loftq initialization for LoRA
|
# Configuration options for loftq initialization for LoRA
|
||||||
# https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
|
# https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
|
||||||
@@ -268,6 +279,7 @@ torch_compile_backend: # Optional[str]
|
|||||||
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
|
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||||
|
# Batch size per gpu = micro_batch_size * gradient_accumulation_steps
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
eval_batch_size:
|
eval_batch_size:
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
@@ -278,7 +290,7 @@ lr_quadratic_warmup:
|
|||||||
logging_steps:
|
logging_steps:
|
||||||
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
||||||
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
|
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
|
||||||
save_strategy: # Set to `no` to skip checkpoint saves
|
save_strategy: # Set to `"no"` to skip checkpoint saves
|
||||||
save_steps: # Leave empty to save at each epoch
|
save_steps: # Leave empty to save at each epoch
|
||||||
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
|
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
|
||||||
save_total_limit: # Checkpoints saved at a time
|
save_total_limit: # Checkpoints saved at a time
|
||||||
@@ -412,6 +424,7 @@ special_tokens:
|
|||||||
# bos_token: "<s>"
|
# bos_token: "<s>"
|
||||||
# eos_token: "</s>"
|
# eos_token: "</s>"
|
||||||
# unk_token: "<unk>"
|
# unk_token: "<unk>"
|
||||||
|
# pad_token: "[PAD]"
|
||||||
|
|
||||||
# Add extra tokens.
|
# Add extra tokens.
|
||||||
tokens:
|
tokens:
|
||||||
|
|||||||
35
docs/dataset_preprocessing.qmd
Normal file
35
docs/dataset_preprocessing.qmd
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
---
|
||||||
|
title: Dataset Preprocessing
|
||||||
|
description: How datasets are processed
|
||||||
|
---
|
||||||
|
|
||||||
|
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
||||||
|
the (dataset format)[../dataset-formats/] and prompt strategies to:
|
||||||
|
- parse the dataset based on the *dataset format*
|
||||||
|
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
||||||
|
- tokenize the dataset based on the configured model & tokenizer
|
||||||
|
- shuffle and merge multiple datasets together if using more than one
|
||||||
|
|
||||||
|
The processing of the datasets can happen one of two ways:
|
||||||
|
|
||||||
|
1. Before kicking off training by calling `python -m axolotl.cli.preprocess /path/to/your.yaml --debug`
|
||||||
|
2. When training is started
|
||||||
|
|
||||||
|
What are the benefits of pre-processing? When training interactively or for sweeps
|
||||||
|
(e.g. you are restarting the trainer often), processing the datasets can oftentimes be frustratingly
|
||||||
|
slow. Pre-processing will cache the tokenized/formatted datasets according to a hash of dependent
|
||||||
|
training parameters so that it will intelligently pull from its cache when possible.
|
||||||
|
|
||||||
|
The path of the cache is controlled by `dataset_prepared_path:` and is often left blank in example
|
||||||
|
YAMLs as this leads to a more robust solution that prevents unexpectedly reusing cached data.
|
||||||
|
|
||||||
|
If `dataset_prepared_path:` is left empty, when training, the processed dataset will be cached in a
|
||||||
|
default path of `./last_run_prepared/`, but will ignore anything already cached there. By explicitly
|
||||||
|
setting `dataset_prepared_path: ./last_run_prepared`, the trainer will use whatever pre-processed
|
||||||
|
data is in the cache.
|
||||||
|
|
||||||
|
What are the edge cases? Let's say you are writing a custom prompt strategy or using a user-defined
|
||||||
|
prompt template. Because the trainer cannot readily detect these changes, we cannot change the
|
||||||
|
calculated hash value for the pre-processed dataset. If you have `dataset_prepared_path: ...` set
|
||||||
|
and change your prompt templating logic, it may not pick up the changes you made and you will be
|
||||||
|
training over the old prompt.
|
||||||
@@ -49,7 +49,7 @@ remove_unused_columns: false
|
|||||||
chat_template: chatml
|
chat_template: chatml
|
||||||
datasets:
|
datasets:
|
||||||
- path: argilla/ultrafeedback-binarized-preferences-cleaned
|
- path: argilla/ultrafeedback-binarized-preferences-cleaned
|
||||||
type: orpo.chat_template
|
type: chat_template.argilla
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Using local dataset files
|
#### Using local dataset files
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ wandb_watch:
|
|||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
output_dir: btlm-out
|
output_dir: ./outputs/btlm-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./outputs/qlora-out
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
micro_batch_size: 4
|
micro_batch_size: 4
|
||||||
num_epochs: 2
|
num_epochs: 2
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./qlora-out
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./qlora-out
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./qlora-out
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -1,216 +1,223 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "AKjdG7tbTb-n"
|
"id": "AKjdG7tbTb-n"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"# Example notebook for running Axolotl on google colab"
|
"# Example notebook for running Axolotl on google colab"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "RcbNpOgWRcii"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import torch\n",
|
|
||||||
"# Check so there is a gpu available, a T4(free tier) is enough to run this notebook\n",
|
|
||||||
"assert (torch.cuda.is_available()==True)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "h3nLav8oTRA5"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## Install Axolotl and dependencies"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"base_uri": "https://localhost:8080/"
|
|
||||||
},
|
|
||||||
"id": "3c3yGAwnOIdi",
|
|
||||||
"outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"!pip install torch==\"2.1.2\"\n",
|
|
||||||
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
|
|
||||||
"!pip install flash-attn==\"2.5.0\"\n",
|
|
||||||
"!pip install deepspeed==\"0.13.1\""
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "BW2MFr7HTjub"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## Create an yaml config file"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "9pkF2dSoQEUN"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import yaml\n",
|
|
||||||
"\n",
|
|
||||||
"# Your YAML string\n",
|
|
||||||
"yaml_string = \"\"\"\n",
|
|
||||||
"base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
|
|
||||||
"model_type: LlamaForCausalLM\n",
|
|
||||||
"tokenizer_type: LlamaTokenizer\n",
|
|
||||||
"is_llama_derived_model: true\n",
|
|
||||||
"\n",
|
|
||||||
"load_in_8bit: false\n",
|
|
||||||
"load_in_4bit: true\n",
|
|
||||||
"strict: false\n",
|
|
||||||
"\n",
|
|
||||||
"datasets:\n",
|
|
||||||
" - path: mhenrichsen/alpaca_2k_test\n",
|
|
||||||
" type: alpaca\n",
|
|
||||||
"dataset_prepared_path:\n",
|
|
||||||
"val_set_size: 0.05\n",
|
|
||||||
"output_dir: ./qlora-out\n",
|
|
||||||
"\n",
|
|
||||||
"adapter: qlora\n",
|
|
||||||
"lora_model_dir:\n",
|
|
||||||
"\n",
|
|
||||||
"sequence_len: 1096\n",
|
|
||||||
"sample_packing: true\n",
|
|
||||||
"pad_to_sequence_len: true\n",
|
|
||||||
"\n",
|
|
||||||
"lora_r: 32\n",
|
|
||||||
"lora_alpha: 16\n",
|
|
||||||
"lora_dropout: 0.05\n",
|
|
||||||
"lora_target_modules:\n",
|
|
||||||
"lora_target_linear: true\n",
|
|
||||||
"lora_fan_in_fan_out:\n",
|
|
||||||
"\n",
|
|
||||||
"wandb_project:\n",
|
|
||||||
"wandb_entity:\n",
|
|
||||||
"wandb_watch:\n",
|
|
||||||
"wandb_name:\n",
|
|
||||||
"wandb_log_model:\n",
|
|
||||||
"\n",
|
|
||||||
"mlflow_experiment_name: colab-example\n",
|
|
||||||
"\n",
|
|
||||||
"gradient_accumulation_steps: 1\n",
|
|
||||||
"micro_batch_size: 1\n",
|
|
||||||
"num_epochs: 4\n",
|
|
||||||
"max_steps: 20\n",
|
|
||||||
"optimizer: paged_adamw_32bit\n",
|
|
||||||
"lr_scheduler: cosine\n",
|
|
||||||
"learning_rate: 0.0002\n",
|
|
||||||
"\n",
|
|
||||||
"train_on_inputs: false\n",
|
|
||||||
"group_by_length: false\n",
|
|
||||||
"bf16: false\n",
|
|
||||||
"fp16: true\n",
|
|
||||||
"tf32: false\n",
|
|
||||||
"\n",
|
|
||||||
"gradient_checkpointing: true\n",
|
|
||||||
"early_stopping_patience:\n",
|
|
||||||
"resume_from_checkpoint:\n",
|
|
||||||
"local_rank:\n",
|
|
||||||
"logging_steps: 1\n",
|
|
||||||
"xformers_attention:\n",
|
|
||||||
"flash_attention: false\n",
|
|
||||||
"\n",
|
|
||||||
"warmup_steps: 10\n",
|
|
||||||
"evals_per_epoch:\n",
|
|
||||||
"saves_per_epoch:\n",
|
|
||||||
"debug:\n",
|
|
||||||
"deepspeed:\n",
|
|
||||||
"weight_decay: 0.0\n",
|
|
||||||
"fsdp:\n",
|
|
||||||
"fsdp_config:\n",
|
|
||||||
"special_tokens:\n",
|
|
||||||
"\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
"# Convert the YAML string to a Python dictionary\n",
|
|
||||||
"yaml_dict = yaml.safe_load(yaml_string)\n",
|
|
||||||
"\n",
|
|
||||||
"# Specify your file path\n",
|
|
||||||
"file_path = 'test_axolotl.yaml'\n",
|
|
||||||
"\n",
|
|
||||||
"# Write the YAML file\n",
|
|
||||||
"with open(file_path, 'w') as file:\n",
|
|
||||||
" yaml.dump(yaml_dict, file)\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"id": "bidoj8YLTusD"
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## Launch the training"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"colab": {
|
|
||||||
"base_uri": "https://localhost:8080/"
|
|
||||||
},
|
|
||||||
"id": "ydTI2Jk2RStU",
|
|
||||||
"outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Buy using the ! the comand will be executed as a bash command\n",
|
|
||||||
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Play with inference"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Buy using the ! the comand will be executed as a bash command\n",
|
|
||||||
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
|
|
||||||
" --qlora_model_dir=\"./qlora-out\" --gradio"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"accelerator": "GPU",
|
|
||||||
"colab": {
|
|
||||||
"gpuType": "T4",
|
|
||||||
"provenance": []
|
|
||||||
},
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"name": "python"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
{
|
||||||
"nbformat_minor": 0
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "RcbNpOgWRcii"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import torch\n",
|
||||||
|
"# Check so there is a gpu available, a T4(free tier) is enough to run this notebook\n",
|
||||||
|
"assert (torch.cuda.is_available()==True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "h3nLav8oTRA5"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Install Axolotl and dependencies"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "3c3yGAwnOIdi",
|
||||||
|
"outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!pip install torch==\"2.1.2\"\n",
|
||||||
|
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
|
||||||
|
"!pip install flash-attn==\"2.5.0\"\n",
|
||||||
|
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "BW2MFr7HTjub"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Create an yaml config file"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "9pkF2dSoQEUN"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import yaml\n",
|
||||||
|
"\n",
|
||||||
|
"# Your YAML string\n",
|
||||||
|
"yaml_string = \"\"\"\n",
|
||||||
|
"base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
|
||||||
|
"model_type: LlamaForCausalLM\n",
|
||||||
|
"tokenizer_type: LlamaTokenizer\n",
|
||||||
|
"\n",
|
||||||
|
"load_in_8bit: false\n",
|
||||||
|
"load_in_4bit: true\n",
|
||||||
|
"strict: false\n",
|
||||||
|
"\n",
|
||||||
|
"datasets:\n",
|
||||||
|
" - path: mhenrichsen/alpaca_2k_test\n",
|
||||||
|
" type: alpaca\n",
|
||||||
|
"dataset_prepared_path:\n",
|
||||||
|
"val_set_size: 0.05\n",
|
||||||
|
"output_dir: ./outputs/qlora-out\n",
|
||||||
|
"\n",
|
||||||
|
"adapter: qlora\n",
|
||||||
|
"lora_model_dir:\n",
|
||||||
|
"\n",
|
||||||
|
"sequence_len: 4096\n",
|
||||||
|
"sample_packing: true\n",
|
||||||
|
"eval_sample_packing: false\n",
|
||||||
|
"pad_to_sequence_len: true\n",
|
||||||
|
"\n",
|
||||||
|
"lora_r: 32\n",
|
||||||
|
"lora_alpha: 16\n",
|
||||||
|
"lora_dropout: 0.05\n",
|
||||||
|
"lora_target_modules:\n",
|
||||||
|
"lora_target_linear: true\n",
|
||||||
|
"lora_fan_in_fan_out:\n",
|
||||||
|
"\n",
|
||||||
|
"wandb_project:\n",
|
||||||
|
"wandb_entity:\n",
|
||||||
|
"wandb_watch:\n",
|
||||||
|
"wandb_name:\n",
|
||||||
|
"wandb_log_model:\n",
|
||||||
|
"\n",
|
||||||
|
"gradient_accumulation_steps: 4\n",
|
||||||
|
"micro_batch_size: 2\n",
|
||||||
|
"num_epochs: 4\n",
|
||||||
|
"optimizer: paged_adamw_32bit\n",
|
||||||
|
"lr_scheduler: cosine\n",
|
||||||
|
"learning_rate: 0.0002\n",
|
||||||
|
"\n",
|
||||||
|
"train_on_inputs: false\n",
|
||||||
|
"group_by_length: false\n",
|
||||||
|
"bf16: auto\n",
|
||||||
|
"fp16:\n",
|
||||||
|
"tf32: false\n",
|
||||||
|
"\n",
|
||||||
|
"gradient_checkpointing: true\n",
|
||||||
|
"early_stopping_patience:\n",
|
||||||
|
"resume_from_checkpoint:\n",
|
||||||
|
"local_rank:\n",
|
||||||
|
"logging_steps: 1\n",
|
||||||
|
"xformers_attention:\n",
|
||||||
|
"flash_attention: true\n",
|
||||||
|
"\n",
|
||||||
|
"warmup_steps: 10\n",
|
||||||
|
"evals_per_epoch: 4\n",
|
||||||
|
"saves_per_epoch: 1\n",
|
||||||
|
"debug:\n",
|
||||||
|
"deepspeed:\n",
|
||||||
|
"weight_decay: 0.0\n",
|
||||||
|
"fsdp:\n",
|
||||||
|
"fsdp_config:\n",
|
||||||
|
"special_tokens:\n",
|
||||||
|
"\n",
|
||||||
|
"\"\"\"\n",
|
||||||
|
"\n",
|
||||||
|
"# Convert the YAML string to a Python dictionary\n",
|
||||||
|
"yaml_dict = yaml.safe_load(yaml_string)\n",
|
||||||
|
"\n",
|
||||||
|
"# Specify your file path\n",
|
||||||
|
"file_path = 'test_axolotl.yaml'\n",
|
||||||
|
"\n",
|
||||||
|
"# Write the YAML file\n",
|
||||||
|
"with open(file_path, 'w') as file:\n",
|
||||||
|
" yaml.dump(yaml_dict, file)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "bidoj8YLTusD"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Launch the training"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"base_uri": "https://localhost:8080/"
|
||||||
|
},
|
||||||
|
"id": "ydTI2Jk2RStU",
|
||||||
|
"outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Buy using the ! the comand will be executed as a bash command\n",
|
||||||
|
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Play with inference"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Buy using the ! the comand will be executed as a bash command\n",
|
||||||
|
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
|
||||||
|
" --qlora_model_dir=\"./qlora-out\" --gradio"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"accelerator": "GPU",
|
||||||
|
"colab": {
|
||||||
|
"gpuType": "T4",
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.12.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
}
|
}
|
||||||
|
|||||||
81
examples/dbrx/16bit-lora.yaml
Normal file
81
examples/dbrx/16bit-lora.yaml
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
base_model: LnL-AI/dbrx-base-converted-v2
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
sequence_len: 512
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 8
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
# w1, w2, & v1 will hang the trainer
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj # attn
|
||||||
|
- k_proj # attn
|
||||||
|
- v_proj # attn
|
||||||
|
- out_proj # attn
|
||||||
|
- layer # router
|
||||||
|
# - w1
|
||||||
|
# - w2
|
||||||
|
# - v1
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_limit_all_gathers: true
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_activation_checkpointing: true
|
||||||
81
examples/dbrx/8bit-lora.yaml
Normal file
81
examples/dbrx/8bit-lora.yaml
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
base_model: LnL-AI/dbrx-base-converted-v2
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
sequence_len: 512
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 8
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
# w1, w2, & v1 will hang the trainer
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj # attn
|
||||||
|
- k_proj # attn
|
||||||
|
- v_proj # attn
|
||||||
|
- out_proj # attn
|
||||||
|
- layer # router
|
||||||
|
# - w1
|
||||||
|
# - w2
|
||||||
|
# - v1
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_limit_all_gathers: true
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_activation_checkpointing: true
|
||||||
26
examples/dbrx/README.md
Normal file
26
examples/dbrx/README.md
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# DBRX MoE
|
||||||
|
|
||||||
|
Currently, for LoRA, only the `q_proj`, `k_proj`, `v_proj` `out_proj` and `layer` Linear layers are trainable.
|
||||||
|
|
||||||
|
We are using the "converted" base models based on [this issue](https://huggingface.co/databricks/dbrx-instruct/discussions/10)
|
||||||
|
where the Experts are fused as an `nn.Parameter` rather than a `nn.Linear` layer. However, the implementation
|
||||||
|
is still a bit buggy and attempting to train a LoRA adapter over those `w1`, `w2` and `v1` layers
|
||||||
|
results in the trainer hanging.
|
||||||
|
|
||||||
|
|
||||||
|
### FSDP
|
||||||
|
We've tested using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as the base model for FSDP.
|
||||||
|
|
||||||
|
The high memory usage seen w/ FSDP is due to FSDP not supporting 8bit optimizers.
|
||||||
|
|
||||||
|
- 16-bit LoRA w/ FSDP
|
||||||
|
- ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu
|
||||||
|
- ❌ w/ CPU Offload - `paged_adamw_8bit` optimizer errors from being on cpu
|
||||||
|
- ✅ 8-bit LoRA w/ FSDP
|
||||||
|
- ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`
|
||||||
|
- ✅ bf16 full finetune w/ FSDP, freezing all but first 8 layers (8x80GB uses ~78GiB/gpu)
|
||||||
|
|
||||||
|
|
||||||
|
### Deepspeed
|
||||||
|
|
||||||
|
WIP
|
||||||
56
examples/dbrx/fft-ds-zero3.yaml
Normal file
56
examples/dbrx/fft-ds-zero3.yaml
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
base_model: LnL-AI/dbrx-base-converted-v2
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
sequence_len: 512
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
unfrozen_parameters:
|
||||||
|
- transformer.blocks.[0-7].
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
weight_decay: 0.0
|
||||||
|
deepspeed: deepspeed_configs/zero3_bf16.json
|
||||||
@@ -28,7 +28,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./falcon-7b
|
output_dir: ./outputs/falcon-7b
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
# QLoRA paper Table 9
|
# QLoRA paper Table 9
|
||||||
# - 16 for 7b & 13b
|
# - 16 for 7b & 13b
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./falcon-7b
|
output_dir: ./outputs/falcon-7b
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
output_dir: ./out
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./outputs/qlora-out
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 2
|
num_epochs: 2
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./out
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./out
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./jeopardy-bot-7b
|
output_dir: ./outputs/jeopardy-bot-7b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./out
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ wandb_project:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./model-out
|
output_dir: ./outputs/model-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./lisa-out
|
output_dir: ./outputs/lisa-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./qlora-out
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
@@ -65,12 +65,14 @@ deepspeed:
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
fsdp:
|
fsdp:
|
||||||
- full_shard
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
fsdp_limit_all_gathers: true
|
fsdp_limit_all_gathers: true
|
||||||
fsdp_sync_module_states: true
|
fsdp_sync_module_states: true
|
||||||
fsdp_offload_params: true
|
fsdp_offload_params: true
|
||||||
fsdp_use_orig_params: false
|
fsdp_use_orig_params: false
|
||||||
fsdp_cpu_ram_efficient_loading: true
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
fsdp_state_dict_type: SHARDED_STATE_DICT
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./qlora-out
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./relora-out
|
output_dir: ./outputs/relora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
13
examples/llama-3/README.md
Normal file
13
examples/llama-3/README.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Llama-3
|
||||||
|
|
||||||
|
https://llama.meta.com/llama3/
|
||||||
|
|
||||||
|
[8B Base Model](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
|
||||||
|
- [Full Fine Tune](./fft-8b.yaml)
|
||||||
|
- Single GPU @ 48GB VRAM
|
||||||
|
- [LoRA](./lora-8b.yml)
|
||||||
|
- Single GPU @ 11GB VRAM
|
||||||
|
|
||||||
|
[70B Base Model](https://huggingface.co/meta-llama/Meta-Llama-3-70B)
|
||||||
|
- [QLORA+FSDP](./qlora-fsdp-70b.yaml)
|
||||||
|
- Dual GPU @ 21GB VRAM
|
||||||
58
examples/llama-3/fft-8b.yaml
Normal file
58
examples/llama-3/fft-8b.yaml
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
base_model: meta-llama/Meta-Llama-3-8B
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 2e-5
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 100
|
||||||
|
evals_per_epoch: 2
|
||||||
|
eval_table_size:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|end_of_text|>
|
||||||
76
examples/llama-3/instruct-lora-8b.yml
Normal file
76
examples/llama-3/instruct-lora-8b.yml
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
chat_template: llama3
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
chat_template: llama3
|
||||||
|
field_messages: messages
|
||||||
|
message_field_role: role
|
||||||
|
message_field_content: content
|
||||||
|
roles:
|
||||||
|
user:
|
||||||
|
- user
|
||||||
|
assistant:
|
||||||
|
- assistant
|
||||||
|
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
s2_attention:
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
70
examples/llama-3/lora-8b.yml
Normal file
70
examples/llama-3/lora-8b.yml
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
base_model: meta-llama/Meta-Llama-3-8B
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
lora_modules_to_save:
|
||||||
|
- embed_tokens
|
||||||
|
- lm_head
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
s2_attention:
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|end_of_text|>
|
||||||
80
examples/llama-3/qlora-fsdp-70b.yaml
Normal file
80
examples/llama-3/qlora-fsdp-70b.yaml
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
base_model: casperhansen/llama-3-70b-fp16
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer # PreTrainedTokenizerFast
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/out/qlora-llama3-70b
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 512
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
lora_r: 8
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules:
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_torch
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.00001
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_limit_all_gathers: true
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_offload_params: true
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|end_of_text|>
|
||||||
67
examples/llama-3/qlora.yml
Normal file
67
examples/llama-3/qlora.yml
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
base_model: meta-llama/Meta-Llama-3-8B
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: aaditya/alpaca_subset_1
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0
|
||||||
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules:
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: paged_adamw_32bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: "<|end_of_text|>"
|
||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./out
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
63
examples/mistral/bigstral-ds-zero3.yaml
Normal file
63
examples/mistral/bigstral-ds-zero3.yaml
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
base_model: mistral-community/Mixtral-8x22B-v0.1
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
unfrozen_parameters:
|
||||||
|
- ^lm_head.weight$
|
||||||
|
- ^model.embed_tokens.weight$
|
||||||
|
- model.layers.4[4-9]+.block_sparse_moe.gate
|
||||||
|
- model.layers.4[4-9]+.block_sparse_moe.experts
|
||||||
|
- model.layers.5[0-5]+.block_sparse_moe.gate
|
||||||
|
- model.layers.5[0-5]+.block_sparse_moe.experts
|
||||||
|
|
||||||
|
model_config:
|
||||||
|
output_router_logits: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 3
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0001
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
save_total_limit: 1
|
||||||
|
save_steps:
|
||||||
|
debug:
|
||||||
|
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
eos_token: "<|im_end|>"
|
||||||
|
tokens:
|
||||||
|
- "<|im_start|>"
|
||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./out
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
sequence_len: 8192
|
sequence_len: 8192
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0
|
val_set_size: 0
|
||||||
output_dir: ./lora-out
|
output_dir: ./outputs/lora-out
|
||||||
eval_sample_packing: false
|
eval_sample_packing: false
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
output_dir: ./lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
82
examples/mistral/mistral-qlora-fsdp.yml
Normal file
82
examples/mistral/mistral-qlora-fsdp.yml
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
base_model: mistralai/Mixtral-8x7B-v0.1
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.02
|
||||||
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
|
model_config:
|
||||||
|
output_router_logits: true
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 1024
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_limit_all_gathers: true
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: false
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
special_tokens:
|
||||||
82
examples/mistral/mistral-qlora-orpo.yml
Normal file
82
examples/mistral/mistral-qlora-orpo.yml
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
base_model: mistralai/Mistral-7B-v0.1
|
||||||
|
model_type: MistralForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
rl: orpo
|
||||||
|
orpo_alpha: 0.1
|
||||||
|
remove_unused_columns: false
|
||||||
|
|
||||||
|
chat_template: chatml
|
||||||
|
datasets:
|
||||||
|
- path: argilla/ultrafeedback-binarized-preferences-cleaned
|
||||||
|
type: chat_template.argilla
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.1
|
||||||
|
output_dir: ./outputs/mistral-qlora-orpo-out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
lora_target_modules:
|
||||||
|
- gate_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
81
examples/mistral/mixtral-8x22b-qlora-fsdp.yml
Normal file
81
examples/mistral/mixtral-8x22b-qlora-fsdp.yml
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
base_model: mistral-community/Mixtral-8x22B-v0.1
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.02
|
||||||
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
|
model_config:
|
||||||
|
output_router_logits: true
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 1024
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_limit_all_gathers: true
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_offload_params: true
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
special_tokens:
|
||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
output_dir: ./qlora-out
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
model_config:
|
model_config:
|
||||||
output_router_logits: true
|
output_router_logits: true
|
||||||
@@ -39,7 +39,7 @@ wandb_log_model:
|
|||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
optimizer: paged_adamw_8bit
|
optimizer: adamw_torch
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ train_on_inputs: false
|
|||||||
group_by_length: false
|
group_by_length: false
|
||||||
bf16: auto
|
bf16: auto
|
||||||
fp16:
|
fp16:
|
||||||
tf32: false
|
tf32: true
|
||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
early_stopping_patience:
|
early_stopping_patience:
|
||||||
@@ -69,6 +69,17 @@ debug:
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
fsdp:
|
fsdp:
|
||||||
- full_shard
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
|
fsdp_limit_all_gathers: true
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_offload_params: true
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
|
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
fsdp_forward_prefetch: false
|
||||||
|
fsdp_backward_prefetch: BACKWARD_PRE
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./qlora-out
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
||||||
unfrozen_parameters:
|
unfrozen_parameters:
|
||||||
|
|||||||
61
examples/mistral/mixtral_22.yml
Normal file
61
examples/mistral/mixtral_22.yml
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
base_model: mistral-community/Mixtral-8x22B-v0.1
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
unfrozen_parameters:
|
||||||
|
- ^lm_head.weight$
|
||||||
|
- ^model.embed_tokens.weight$
|
||||||
|
- model.layers.4[4-9]+.block_sparse_moe.gate
|
||||||
|
- model.layers.4[4-9]+.block_sparse_moe.experts
|
||||||
|
- model.layers.5[0-5]+.block_sparse_moe.gate
|
||||||
|
- model.layers.5[0-5]+.block_sparse_moe.experts
|
||||||
|
|
||||||
|
model_config:
|
||||||
|
output_router_logits: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: yahma/alpaca-cleaned
|
||||||
|
type: alpaca
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
sequence_len: 8000
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 3
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0001
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
save_total_limit: 1
|
||||||
|
save_steps:
|
||||||
|
debug:
|
||||||
|
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_all.json
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
eos_token: "<|im_end|>"
|
||||||
|
tokens:
|
||||||
|
- "<|im_start|>"
|
||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
output_dir: ./qlora-out
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./mpt-alpaca-7b
|
output_dir: ./outputs/mpt-alpaca-7b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./openllama-out
|
output_dir: ./outputs/openllama-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-out
|
output_dir: ./outputs/lora-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./outputs/qlora-out
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
|
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./phi-sft-out
|
output_dir: ./outputs/phi-sft-out
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
|
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./phi-sft-out
|
output_dir: ./outputs/phi-sft-out
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
|
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./phi-sft-out
|
output_dir: ./outputs/phi-sft-out
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./pythia-12b
|
output_dir: ./outputs/pythia-12b
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 5
|
num_epochs: 5
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-alpaca-pythia
|
output_dir: ./outputs/lora-alpaca-pythia
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 4
|
micro_batch_size: 4
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
sequence_len: 2048 # supports up to 8192
|
sequence_len: 2048 # supports up to 8192
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
sequence_len: 2048 # supports up to 8192
|
sequence_len: 2048 # supports up to 8192
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./out
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
sequence_len: 1024 # supports up to 32k
|
sequence_len: 1024 # supports up to 32k
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -10,13 +10,13 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./out
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
sequence_len: 1024 # supports up to 32k
|
sequence_len: 1024 # supports up to 32k
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
pad_to_sequence_len: false
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
adapter: lora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./redpajama-alpaca-3b
|
output_dir: ./outputs/redpajama-alpaca-3b
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./lora-replit
|
output_dir: ./outputs/lora-replit
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./out
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
|
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.2
|
val_set_size: 0.2
|
||||||
output_dir: ./qlora
|
output_dir: ./outputs/qlora
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0
|
val_set_size: 0
|
||||||
output_dir: ./lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ pretraining_dataset:
|
|||||||
type: pretrain
|
type: pretrain
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./model-out
|
output_dir: ./outputs/model-out
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -11,13 +11,14 @@ datasets:
|
|||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
output_dir: ./qlora-out
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
eval_sample_packing: false
|
||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ wandb_entity:
|
|||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_name:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
# QLoRA paper Table 9
|
# QLoRA paper Table 9
|
||||||
# - 16 for 7b & 13b
|
# - 16 for 7b & 13b
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ eval_sample_packing: false
|
|||||||
eval_batch_size: 1
|
eval_batch_size: 1
|
||||||
|
|
||||||
# LoRA
|
# LoRA
|
||||||
output_dir: ./qlora-out
|
output_dir: ./outputs/qlora-out
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
|
|||||||
@@ -1,22 +1,22 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.10.0
|
peft==0.11.1
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git@43d17c18360ac9c3d3491389328e2fe55fe8f9ce
|
transformers==4.41.1
|
||||||
tokenizers==0.15.0
|
tokenizers==0.19.1
|
||||||
bitsandbytes==0.43.0
|
bitsandbytes==0.43.1
|
||||||
accelerate==0.28.0
|
accelerate==0.30.1
|
||||||
deepspeed==0.13.1
|
deepspeed==0.14.2
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
requests
|
requests
|
||||||
datasets>=2.15.0
|
datasets==2.19.1
|
||||||
flash-attn==2.5.5
|
flash-attn==2.5.8
|
||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
xformers==0.0.22
|
xformers==0.0.26.post1
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
colorama
|
colorama
|
||||||
@@ -28,7 +28,7 @@ scipy
|
|||||||
scikit-learn==1.2.2
|
scikit-learn==1.2.2
|
||||||
pynvml
|
pynvml
|
||||||
art
|
art
|
||||||
fschat==0.2.36
|
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
tensorboard
|
tensorboard
|
||||||
|
|
||||||
@@ -39,5 +39,6 @@ s3fs
|
|||||||
gcsfs
|
gcsfs
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
|
trl==0.8.6
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
|
fastcore
|
||||||
|
|||||||
82
scripts/cloud-entrypoint-term.sh
Executable file
82
scripts/cloud-entrypoint-term.sh
Executable file
@@ -0,0 +1,82 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Export specific ENV variables to /etc/rp_environment
|
||||||
|
echo "Exporting environment variables..."
|
||||||
|
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
|
||||||
|
conda init
|
||||||
|
# this needs to come after conda init
|
||||||
|
echo 'source /etc/rp_environment' >> ~/.bashrc
|
||||||
|
|
||||||
|
add_keys_to_authorized() {
|
||||||
|
local key_value=$1
|
||||||
|
|
||||||
|
# Create the ~/.ssh directory and set permissions
|
||||||
|
mkdir -p ~/.ssh
|
||||||
|
chmod 700 ~/.ssh
|
||||||
|
|
||||||
|
# Create the authorized_keys file if it doesn't exist
|
||||||
|
touch ~/.ssh/authorized_keys
|
||||||
|
|
||||||
|
# Initialize an empty key variable
|
||||||
|
local key=""
|
||||||
|
|
||||||
|
# Read the key variable word by word
|
||||||
|
for word in $key_value; do
|
||||||
|
# Check if the word looks like the start of a key
|
||||||
|
if [[ $word == ssh-* ]]; then
|
||||||
|
# If there's a key being built, add it to the authorized_keys file
|
||||||
|
if [[ -n $key ]]; then
|
||||||
|
echo $key >> ~/.ssh/authorized_keys
|
||||||
|
fi
|
||||||
|
# Start a new key
|
||||||
|
key=$word
|
||||||
|
else
|
||||||
|
# Append the word to the current key
|
||||||
|
key="$key $word"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# Add the last key to the authorized_keys file
|
||||||
|
if [[ -n $key ]]; then
|
||||||
|
echo $key >> ~/.ssh/authorized_keys
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Set the correct permissions
|
||||||
|
chmod 600 ~/.ssh/authorized_keys
|
||||||
|
chmod 700 -R ~/.ssh
|
||||||
|
}
|
||||||
|
|
||||||
|
if [[ $PUBLIC_KEY ]]; then
|
||||||
|
# runpod
|
||||||
|
add_keys_to_authorized "$PUBLIC_KEY"
|
||||||
|
# Start the SSH service in the background
|
||||||
|
service ssh start
|
||||||
|
elif [[ $SSH_KEY ]]; then
|
||||||
|
# latitude.sh
|
||||||
|
add_keys_to_authorized "$SSH_KEY"
|
||||||
|
# Start the SSH service in the background
|
||||||
|
service ssh start
|
||||||
|
else
|
||||||
|
echo "No PUBLIC_KEY or SSH_KEY environment variable provided, not starting openSSH daemon"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if JUPYTER_PASSWORD is set and not empty
|
||||||
|
if [ -n "$JUPYTER_PASSWORD" ]; then
|
||||||
|
# Set JUPYTER_TOKEN to the value of JUPYTER_PASSWORD
|
||||||
|
export JUPYTER_TOKEN="$JUPYTER_PASSWORD"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "$JUPYTER_DISABLE" != "1" ]; then
|
||||||
|
# Run Jupyter Lab in the background
|
||||||
|
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -d "/workspace/data/axolotl-artifacts" ]; then
|
||||||
|
mkdir -p /workspace/data/axolotl-artifacts
|
||||||
|
fi
|
||||||
|
if [ ! -L "/workspace/axolotl/outputs" ]; then
|
||||||
|
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Execute the passed arguments (CMD)
|
||||||
|
exec "$@"
|
||||||
@@ -5,20 +5,53 @@ echo "Exporting environment variables..."
|
|||||||
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
|
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
|
||||||
echo 'source /etc/rp_environment' >> ~/.bashrc
|
echo 'source /etc/rp_environment' >> ~/.bashrc
|
||||||
|
|
||||||
|
add_keys_to_authorized() {
|
||||||
|
local key_value=$1
|
||||||
|
|
||||||
|
# Create the ~/.ssh directory and set permissions
|
||||||
|
mkdir -p ~/.ssh
|
||||||
|
chmod 700 ~/.ssh
|
||||||
|
|
||||||
|
# Create the authorized_keys file if it doesn't exist
|
||||||
|
touch ~/.ssh/authorized_keys
|
||||||
|
|
||||||
|
# Initialize an empty key variable
|
||||||
|
local key=""
|
||||||
|
|
||||||
|
# Read the key variable word by word
|
||||||
|
for word in $key_value; do
|
||||||
|
# Check if the word looks like the start of a key
|
||||||
|
if [[ $word == ssh-* ]]; then
|
||||||
|
# If there's a key being built, add it to the authorized_keys file
|
||||||
|
if [[ -n $key ]]; then
|
||||||
|
echo $key >> ~/.ssh/authorized_keys
|
||||||
|
fi
|
||||||
|
# Start a new key
|
||||||
|
key=$word
|
||||||
|
else
|
||||||
|
# Append the word to the current key
|
||||||
|
key="$key $word"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# Add the last key to the authorized_keys file
|
||||||
|
if [[ -n $key ]]; then
|
||||||
|
echo $key >> ~/.ssh/authorized_keys
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Set the correct permissions
|
||||||
|
chmod 600 ~/.ssh/authorized_keys
|
||||||
|
chmod 700 -R ~/.ssh
|
||||||
|
}
|
||||||
|
|
||||||
if [[ $PUBLIC_KEY ]]; then
|
if [[ $PUBLIC_KEY ]]; then
|
||||||
# runpod
|
# runpod
|
||||||
mkdir -p ~/.ssh
|
add_keys_to_authorized "$PUBLIC_KEY"
|
||||||
chmod 700 ~/.ssh
|
|
||||||
echo $PUBLIC_KEY >> ~/.ssh/authorized_keys
|
|
||||||
chmod 700 -R ~/.ssh
|
|
||||||
# Start the SSH service in the background
|
# Start the SSH service in the background
|
||||||
service ssh start
|
service ssh start
|
||||||
elif [ -n "$SSH_KEY" ]; then
|
elif [[ $SSH_KEY ]]; then
|
||||||
# latitude.sh
|
# latitude.sh
|
||||||
mkdir -p ~/.ssh
|
add_keys_to_authorized "$SSH_KEY"
|
||||||
chmod 700 ~/.ssh
|
|
||||||
echo $SSH_KEY >> ~/.ssh/authorized_keys
|
|
||||||
chmod 700 -R ~/.ssh
|
|
||||||
# Start the SSH service in the background
|
# Start the SSH service in the background
|
||||||
service ssh start
|
service ssh start
|
||||||
else
|
else
|
||||||
@@ -33,7 +66,14 @@ fi
|
|||||||
|
|
||||||
if [ "$JUPYTER_DISABLE" != "1" ]; then
|
if [ "$JUPYTER_DISABLE" != "1" ]; then
|
||||||
# Run Jupyter Lab in the background
|
# Run Jupyter Lab in the background
|
||||||
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
|
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* &
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -d "/workspace/data/axolotl-artifacts" ]; then
|
||||||
|
mkdir -p /workspace/data/axolotl-artifacts
|
||||||
|
fi
|
||||||
|
if [ ! -L "/workspace/axolotl/outputs" ]; then
|
||||||
|
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Execute the passed arguments (CMD)
|
# Execute the passed arguments (CMD)
|
||||||
|
|||||||
25
setup.py
25
setup.py
@@ -30,8 +30,11 @@ def parse_requirements():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if "Darwin" in platform.system():
|
if "Darwin" in platform.system():
|
||||||
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
# don't install xformers on MacOS
|
||||||
|
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
|
||||||
else:
|
else:
|
||||||
|
# detect the version of torch already installed
|
||||||
|
# and set it so dependencies don't clobber the torch version
|
||||||
torch_version = version("torch")
|
torch_version = version("torch")
|
||||||
_install_requires.append(f"torch=={torch_version}")
|
_install_requires.append(f"torch=={torch_version}")
|
||||||
|
|
||||||
@@ -45,9 +48,15 @@ def parse_requirements():
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid version format")
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
if (major, minor) >= (2, 1):
|
if (major, minor) >= (2, 3):
|
||||||
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
pass
|
||||||
_install_requires.append("xformers>=0.0.23")
|
elif (major, minor) >= (2, 2):
|
||||||
|
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
|
||||||
|
_install_requires.append("xformers>=0.0.25.post1")
|
||||||
|
else:
|
||||||
|
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1"))
|
||||||
|
_install_requires.append("xformers>=0.0.23.post1")
|
||||||
|
|
||||||
except PackageNotFoundError:
|
except PackageNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -59,7 +68,7 @@ install_requires, dependency_links = parse_requirements()
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="axolotl",
|
name="axolotl",
|
||||||
version="0.4.0",
|
version="0.4.1",
|
||||||
description="LLM Trainer",
|
description="LLM Trainer",
|
||||||
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
|
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
|
||||||
package_dir={"": "src"},
|
package_dir={"": "src"},
|
||||||
@@ -68,13 +77,13 @@ setup(
|
|||||||
dependency_links=dependency_links,
|
dependency_links=dependency_links,
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn==2.5.5",
|
"flash-attn==2.5.8",
|
||||||
],
|
],
|
||||||
"fused-dense-lib": [
|
"fused-dense-lib": [
|
||||||
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
|
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed==0.13.1",
|
"deepspeed==0.14.2",
|
||||||
"deepspeed-kernels",
|
"deepspeed-kernels",
|
||||||
],
|
],
|
||||||
"mamba-ssm": [
|
"mamba-ssm": [
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from huggingface_hub import HfApi
|
|||||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
from transformers.utils.import_utils import _is_package_available
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
@@ -62,6 +63,20 @@ def print_axolotl_text_art(suffix=None):
|
|||||||
if is_main_process():
|
if is_main_process():
|
||||||
print(ascii_art)
|
print(ascii_art)
|
||||||
|
|
||||||
|
print_dep_versions()
|
||||||
|
|
||||||
|
|
||||||
|
def print_dep_versions():
|
||||||
|
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
|
||||||
|
max_len = max(len(pkg) for pkg in packages)
|
||||||
|
if is_main_process():
|
||||||
|
print("*" * 40)
|
||||||
|
print("**** Axolotl Dependency Versions *****")
|
||||||
|
for pkg in packages:
|
||||||
|
version = _is_package_available(pkg, return_version=True)
|
||||||
|
print(f"{pkg: >{max_len}}: {version[1]: <15}")
|
||||||
|
print("*" * 40)
|
||||||
|
|
||||||
|
|
||||||
def check_remote_config(config: Union[str, Path]):
|
def check_remote_config(config: Union[str, Path]):
|
||||||
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
|
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
|
||||||
@@ -249,8 +264,8 @@ def do_inference_gradio(
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
generation_config = GenerationConfig(
|
generation_config = GenerationConfig(
|
||||||
repetition_penalty=1.1,
|
repetition_penalty=1.1,
|
||||||
max_new_tokens=1024,
|
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
|
||||||
temperature=0.9,
|
temperature=cfg.get("gradio_temperature", 0.9),
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=40,
|
top_k=40,
|
||||||
bos_token_id=tokenizer.bos_token_id,
|
bos_token_id=tokenizer.bos_token_id,
|
||||||
@@ -285,7 +300,13 @@ def do_inference_gradio(
|
|||||||
outputs="text",
|
outputs="text",
|
||||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||||
)
|
)
|
||||||
demo.queue().launch(show_api=False, share=True)
|
|
||||||
|
demo.queue().launch(
|
||||||
|
show_api=False,
|
||||||
|
share=cfg.get("gradio_share", True),
|
||||||
|
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||||
|
server_port=cfg.get("gradio_server_port", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def choose_config(path: Path):
|
def choose_config(path: Path):
|
||||||
@@ -418,6 +439,23 @@ def load_rl_datasets(
|
|||||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cli_args.debug or cfg.debug:
|
||||||
|
LOG.info("check_dataset_labels...")
|
||||||
|
|
||||||
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
check_dataset_labels(
|
||||||
|
train_dataset.select(
|
||||||
|
[
|
||||||
|
random.randrange(0, len(train_dataset) - 1) # nosec
|
||||||
|
for _ in range(cli_args.debug_num_examples)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
tokenizer,
|
||||||
|
num_examples=cli_args.debug_num_examples,
|
||||||
|
text_only=cli_args.debug_text_only,
|
||||||
|
rl_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
return TrainDatasetMeta(
|
return TrainDatasetMeta(
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
load_in_8bit=False,
|
load_in_8bit=False,
|
||||||
load_in_4bit=False,
|
load_in_4bit=False,
|
||||||
flash_attention=False,
|
flash_attention=False,
|
||||||
|
deepspeed=None,
|
||||||
|
fsdp=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -40,6 +42,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
parsed_cfg.flash_attention = False
|
parsed_cfg.flash_attention = False
|
||||||
parsed_cfg.deepspeed = None
|
parsed_cfg.deepspeed = None
|
||||||
parsed_cfg.fsdp = None
|
parsed_cfg.fsdp = None
|
||||||
|
parsed_cfg.fsdp_config = None
|
||||||
|
|
||||||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,10 @@ from axolotl.cli import (
|
|||||||
)
|
)
|
||||||
from axolotl.common.cli import PreprocessCliArgs
|
from axolotl.common.cli import PreprocessCliArgs
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
from axolotl.prompt_strategies.sharegpt import (
|
||||||
|
register_chatml_template,
|
||||||
|
register_llama3_template,
|
||||||
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||||
|
|
||||||
@@ -36,13 +39,22 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
|
if parsed_cfg.chat_template == "chatml":
|
||||||
LOG.info(
|
if parsed_cfg.default_system_message:
|
||||||
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
LOG.info(
|
||||||
)
|
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||||
register_chatml_template(parsed_cfg.default_system_message)
|
)
|
||||||
else:
|
register_chatml_template(parsed_cfg.default_system_message)
|
||||||
register_chatml_template()
|
else:
|
||||||
|
register_chatml_template()
|
||||||
|
elif parsed_cfg.chat_template == "llama3":
|
||||||
|
if parsed_cfg.default_system_message:
|
||||||
|
LOG.info(
|
||||||
|
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||||
|
)
|
||||||
|
register_llama3_template(parsed_cfg.default_system_message)
|
||||||
|
else:
|
||||||
|
register_llama3_template()
|
||||||
|
|
||||||
if not parsed_cfg.dataset_prepared_path:
|
if not parsed_cfg.dataset_prepared_path:
|
||||||
msg = (
|
msg = (
|
||||||
@@ -54,7 +66,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
LOG.warning(msg)
|
LOG.warning(msg)
|
||||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
|
||||||
if parsed_cfg.rl and parsed_cfg.rl != "orpo":
|
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
|
||||||
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
else:
|
else:
|
||||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|||||||
@@ -19,7 +19,10 @@ from axolotl.cli import (
|
|||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
from axolotl.prompt_strategies.sharegpt import (
|
||||||
|
register_chatml_template,
|
||||||
|
register_llama3_template,
|
||||||
|
)
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.train")
|
LOG = logging.getLogger("axolotl.cli.train")
|
||||||
@@ -47,7 +50,15 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|||||||
else:
|
else:
|
||||||
register_chatml_template()
|
register_chatml_template()
|
||||||
|
|
||||||
if cfg.rl and cfg.rl != "orpo":
|
if cfg.chat_template == "llama3" and cfg.default_system_message:
|
||||||
|
LOG.info(
|
||||||
|
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
|
||||||
|
)
|
||||||
|
register_llama3_template(cfg.default_system_message)
|
||||||
|
else:
|
||||||
|
register_llama3_template()
|
||||||
|
|
||||||
|
if cfg.rl: # and cfg.rl != "orpo":
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
284
src/axolotl/core/trainer_builder.py
Normal file → Executable file
284
src/axolotl/core/trainer_builder.py
Normal file → Executable file
@@ -30,18 +30,20 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_utils import seed_worker
|
from transformers.trainer_utils import seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
from axolotl.loraplus import create_loraplus_optimizer
|
from axolotl.loraplus import create_loraplus_optimizer
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||||
|
from axolotl.utils import is_mlflow_available
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
LossWatchDogCallback,
|
LossWatchDogCallback,
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
|
SaveModelCallback,
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
causal_lm_bench_eval_callback_factory,
|
causal_lm_bench_eval_callback_factory,
|
||||||
log_prediction_callback_factory,
|
log_prediction_callback_factory,
|
||||||
@@ -53,6 +55,7 @@ from axolotl.utils.collators import (
|
|||||||
MambaDataCollator,
|
MambaDataCollator,
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.models import ensure_dtype
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
from axolotl.utils.schedulers import (
|
from axolotl.utils.schedulers import (
|
||||||
get_cosine_schedule_with_min_lr,
|
get_cosine_schedule_with_min_lr,
|
||||||
@@ -71,10 +74,6 @@ except ImportError:
|
|||||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||||
|
|
||||||
|
|
||||||
def is_mlflow_available():
|
|
||||||
return importlib.util.find_spec("mlflow") is not None
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||||
if isinstance(tag_names, str):
|
if isinstance(tag_names, str):
|
||||||
tag_names = [tag_names]
|
tag_names = [tag_names]
|
||||||
@@ -92,11 +91,12 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingArguments(TrainingArguments):
|
class AxolotlTrainingMixins:
|
||||||
"""
|
"""
|
||||||
Extend the base TrainingArguments for axolotl helpers
|
Mixin class for the Axolotl training args.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
default=None, metadata={"help": "HF model configuration model_type."}
|
default=None, metadata={"help": "HF model configuration model_type."}
|
||||||
)
|
)
|
||||||
@@ -126,14 +126,22 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
default=1.0,
|
default=1.0,
|
||||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||||
)
|
)
|
||||||
|
sample_packing_bin_size: int = field(
|
||||||
|
default=200,
|
||||||
|
metadata={
|
||||||
|
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
sample_packing_group_size: int = field(
|
||||||
|
default=100000,
|
||||||
|
metadata={
|
||||||
|
"help": "The number of samples to group together for packing. Increase for better packing."
|
||||||
|
},
|
||||||
|
)
|
||||||
max_seq_length: int = field(
|
max_seq_length: int = field(
|
||||||
default=2048,
|
default=2048,
|
||||||
metadata={"help": "The maximum sequence length the model can handle"},
|
metadata={"help": "The maximum sequence length the model can handle"},
|
||||||
)
|
)
|
||||||
sample_packing_seq_len_multiplier: int = field(
|
|
||||||
default=1,
|
|
||||||
metadata={"help": "the multiplier for the max len for packed sequences"},
|
|
||||||
)
|
|
||||||
relora_steps: Optional[int] = field(
|
relora_steps: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "how often to reset for ReLoRA"},
|
metadata={"help": "how often to reset for ReLoRA"},
|
||||||
@@ -214,6 +222,34 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "path under the model to access the layers"},
|
metadata={"help": "path under the model to access the layers"},
|
||||||
)
|
)
|
||||||
|
curriculum_sampling: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||||
|
"""
|
||||||
|
Training arguments for Causal trainer
|
||||||
|
|
||||||
|
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
||||||
|
so it can't be used as a mixin.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
||||||
|
"""
|
||||||
|
ORPO config for ORPO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
|
||||||
|
"""
|
||||||
|
KTO config for KTO training
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(Trainer):
|
class AxolotlTrainer(Trainer):
|
||||||
@@ -343,12 +379,15 @@ class AxolotlTrainer(Trainer):
|
|||||||
)
|
)
|
||||||
return MultipackBatchSampler(
|
return MultipackBatchSampler(
|
||||||
RandomSampler(self.train_dataset),
|
RandomSampler(self.train_dataset),
|
||||||
batch_size=batch_size,
|
|
||||||
drop_last=True,
|
|
||||||
batch_max_len=batch_max_len,
|
|
||||||
lengths=get_dataset_lengths(self.train_dataset),
|
lengths=get_dataset_lengths(self.train_dataset),
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
batch_max_len=batch_max_len,
|
||||||
|
batch_size=batch_size,
|
||||||
|
group_size=self.args.sample_packing_group_size,
|
||||||
|
bin_size=self.args.sample_packing_bin_size,
|
||||||
|
drop_last=True,
|
||||||
)
|
)
|
||||||
|
if self.args.curriculum_sampling:
|
||||||
|
return SequentialSampler(self.train_dataset)
|
||||||
return super()._get_train_sampler()
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
def _get_eval_sampler(
|
def _get_eval_sampler(
|
||||||
@@ -365,11 +404,12 @@ class AxolotlTrainer(Trainer):
|
|||||||
)
|
)
|
||||||
return MultipackBatchSampler(
|
return MultipackBatchSampler(
|
||||||
SequentialSampler(eval_dataset),
|
SequentialSampler(eval_dataset),
|
||||||
batch_size=batch_size,
|
lengths=get_dataset_lengths(self.eval_dataset),
|
||||||
drop_last=True,
|
|
||||||
batch_max_len=batch_max_len,
|
batch_max_len=batch_max_len,
|
||||||
lengths=get_dataset_lengths(eval_dataset),
|
batch_size=batch_size,
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
group_size=self.args.sample_packing_group_size,
|
||||||
|
bin_size=self.args.sample_packing_bin_size,
|
||||||
|
drop_last=True,
|
||||||
)
|
)
|
||||||
return super()._get_eval_sampler(eval_dataset)
|
return super()._get_eval_sampler(eval_dataset)
|
||||||
|
|
||||||
@@ -793,6 +833,40 @@ class AxolotlDPOTrainer(DPOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "dpo"]
|
tag_names = ["axolotl", "dpo"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.optimizer = None
|
||||||
|
|
||||||
|
def create_optimizer(self):
|
||||||
|
if self.args.loraplus_lr_ratio is None:
|
||||||
|
return super().create_optimizer()
|
||||||
|
|
||||||
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
|
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||||
|
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||||
|
self.args,
|
||||||
|
opt_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
|
if loraplus_lr_ratio:
|
||||||
|
print("Using lora+")
|
||||||
|
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
||||||
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
opt_model,
|
||||||
|
optimizer_cls,
|
||||||
|
optimizer_kwargs,
|
||||||
|
loraplus_lr_ratio,
|
||||||
|
loraplus_lr_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.optimizer
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.optimizer
|
||||||
|
|
||||||
@wraps(DPOTrainer.push_to_hub)
|
@wraps(DPOTrainer.push_to_hub)
|
||||||
def push_to_hub(self, *args, **kwargs) -> str:
|
def push_to_hub(self, *args, **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -813,6 +887,22 @@ class AxolotlDPOTrainer(DPOTrainer):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlORPOTrainer(ORPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base ORPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlKTOTrainer(KTOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base KTOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""
|
"""
|
||||||
Base class for trainer builder
|
Base class for trainer builder
|
||||||
@@ -876,6 +966,14 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
callbacks.append(
|
callbacks.append(
|
||||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
if self.cfg.use_mlflow and is_mlflow_available():
|
||||||
|
from axolotl.utils.callbacks.mlflow_ import (
|
||||||
|
SaveAxolotlConfigtoMlflowCallback,
|
||||||
|
)
|
||||||
|
|
||||||
|
callbacks.append(
|
||||||
|
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||||
|
)
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
@@ -921,29 +1019,27 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
):
|
):
|
||||||
callbacks.append(SaveBetterTransformerModelCallback())
|
callbacks.append(SaveBetterTransformerModelCallback())
|
||||||
|
|
||||||
if self.cfg.use_wandb:
|
|
||||||
callbacks.append(
|
|
||||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
|
||||||
)
|
|
||||||
if self.cfg.use_mlflow and is_mlflow_available():
|
|
||||||
from axolotl.utils.callbacks.mlflow_ import (
|
|
||||||
SaveAxolotlConfigtoMlflowCallback,
|
|
||||||
)
|
|
||||||
|
|
||||||
callbacks.append(
|
|
||||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.loss_watchdog_threshold is not None:
|
if self.cfg.loss_watchdog_threshold is not None:
|
||||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||||
|
|
||||||
|
callbacks.append(SaveModelCallback())
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
callbacks = []
|
callbacks = []
|
||||||
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
||||||
LogPredictionCallback = log_prediction_callback_factory(
|
LogPredictionCallback = log_prediction_callback_factory(
|
||||||
trainer, self.tokenizer
|
trainer, self.tokenizer, "wandb"
|
||||||
|
)
|
||||||
|
callbacks.append(LogPredictionCallback(self.cfg))
|
||||||
|
if (
|
||||||
|
self.cfg.use_mlflow
|
||||||
|
and is_mlflow_available()
|
||||||
|
and self.cfg.eval_table_size > 0
|
||||||
|
):
|
||||||
|
LogPredictionCallback = log_prediction_callback_factory(
|
||||||
|
trainer, self.tokenizer, "mlflow"
|
||||||
)
|
)
|
||||||
callbacks.append(LogPredictionCallback(self.cfg))
|
callbacks.append(LogPredictionCallback(self.cfg))
|
||||||
|
|
||||||
@@ -1052,11 +1148,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.save_safetensors is not None:
|
if self.cfg.save_safetensors is not None:
|
||||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||||
|
|
||||||
if self.cfg.sample_packing_eff_est:
|
|
||||||
training_arguments_kwargs[
|
|
||||||
"sample_packing_efficiency"
|
|
||||||
] = self.cfg.sample_packing_eff_est
|
|
||||||
|
|
||||||
if self.cfg.dataloader_pin_memory is not None:
|
if self.cfg.dataloader_pin_memory is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"dataloader_pin_memory"
|
"dataloader_pin_memory"
|
||||||
@@ -1104,6 +1195,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
# default to saving each epoch if not defined
|
# default to saving each epoch if not defined
|
||||||
training_arguments_kwargs["save_strategy"] = "epoch"
|
training_arguments_kwargs["save_strategy"] = "epoch"
|
||||||
|
|
||||||
|
training_arguments_kwargs["save_only_model"] = self.cfg.save_only_model
|
||||||
|
|
||||||
if self.cfg.do_bench_eval:
|
if self.cfg.do_bench_eval:
|
||||||
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
||||||
if self.cfg.bench_dataset:
|
if self.cfg.bench_dataset:
|
||||||
@@ -1182,11 +1275,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
False if self.cfg.ddp else None
|
False if self.cfg.ddp else None
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||||
report_to = None
|
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||||
|
report_to = []
|
||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
report_to = "wandb"
|
report_to.append("wandb")
|
||||||
if self.cfg.use_mlflow:
|
if self.cfg.use_mlflow:
|
||||||
report_to = "mlflow"
|
report_to.append("mlflow")
|
||||||
|
if self.cfg.use_tensorboard:
|
||||||
|
report_to.append("tensorboard")
|
||||||
|
|
||||||
training_arguments_kwargs["report_to"] = report_to
|
training_arguments_kwargs["report_to"] = report_to
|
||||||
training_arguments_kwargs["run_name"] = (
|
training_arguments_kwargs["run_name"] = (
|
||||||
self.cfg.wandb_name if self.cfg.use_wandb else None
|
self.cfg.wandb_name if self.cfg.use_wandb else None
|
||||||
@@ -1226,20 +1323,27 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["weight_decay"] = (
|
training_arguments_kwargs["weight_decay"] = (
|
||||||
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["sample_packing"] = (
|
|
||||||
self.cfg.sample_packing if self.cfg.sample_packing else False
|
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
||||||
)
|
|
||||||
training_arguments_kwargs["multipack_real_batches"] = (
|
|
||||||
self.cfg.flash_attention is not True
|
|
||||||
)
|
|
||||||
training_arguments_kwargs["eval_sample_packing"] = (
|
|
||||||
self.cfg.sample_packing
|
|
||||||
if self.cfg.eval_sample_packing is not False
|
|
||||||
else False
|
|
||||||
)
|
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"sample_packing_seq_len_multiplier"
|
"multipack_real_batches"
|
||||||
] = self.cfg.micro_batch_size
|
] = not self.cfg.flash_attention
|
||||||
|
training_arguments_kwargs["eval_sample_packing"] = bool(
|
||||||
|
self.cfg.eval_sample_packing
|
||||||
|
)
|
||||||
|
if self.cfg.sample_packing_bin_size is not None:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"sample_packing_bin_size"
|
||||||
|
] = self.cfg.sample_packing_bin_size
|
||||||
|
if self.cfg.sample_packing_group_size is not None:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"sample_packing_group_size"
|
||||||
|
] = self.cfg.sample_packing_group_size
|
||||||
|
if self.cfg.sample_packing_eff_est:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"sample_packing_efficiency"
|
||||||
|
] = self.cfg.sample_packing_eff_est
|
||||||
|
|
||||||
if self.cfg.relora_steps:
|
if self.cfg.relora_steps:
|
||||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
@@ -1402,13 +1506,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class HFDPOTrainerBuilder(TrainerBuilderBase):
|
class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||||
"""
|
"""
|
||||||
Trainer factory class for DPO Trainer
|
Trainer factory class for DPO Trainer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
callbacks = super().get_callbacks()
|
callbacks = super().get_callbacks()
|
||||||
|
callbacks.append(SaveModelCallback())
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
@@ -1444,9 +1550,12 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["evaluation_strategy"] = "no"
|
training_args_kwargs["evaluation_strategy"] = "no"
|
||||||
|
|
||||||
if self.cfg.bf16 or self.cfg.bfloat16:
|
if self.cfg.bf16 or self.cfg.bfloat16:
|
||||||
training_args_kwargs["bf16"] = True
|
training_args_kwargs["bf16"] = True
|
||||||
|
|
||||||
|
training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||||
|
training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
|
||||||
training_args_kwargs["lr_scheduler_type"] = (
|
training_args_kwargs["lr_scheduler_type"] = (
|
||||||
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
|
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
|
||||||
)
|
)
|
||||||
@@ -1495,12 +1604,40 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|||||||
# default to saving each epoch if not defined
|
# default to saving each epoch if not defined
|
||||||
training_args_kwargs["save_strategy"] = "epoch"
|
training_args_kwargs["save_strategy"] = "epoch"
|
||||||
|
|
||||||
training_args = TrainingArguments(
|
if self.cfg.orpo_alpha:
|
||||||
|
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||||
|
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||||
|
|
||||||
|
training_args_cls = AxolotlTrainingArguments
|
||||||
|
if self.cfg.rl == "orpo":
|
||||||
|
training_args_cls = AxolotlORPOConfig
|
||||||
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
if self.cfg.max_prompt_len:
|
||||||
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
|
|
||||||
|
if self.cfg.rl == "kto":
|
||||||
|
training_args_cls = AxolotlKTOConfig
|
||||||
|
|
||||||
|
training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
||||||
|
training_args_kwargs["desirable_weight"] = (
|
||||||
|
self.cfg.kto_desirable_weight or 1.0
|
||||||
|
)
|
||||||
|
training_args_kwargs["undesirable_weight"] = (
|
||||||
|
self.cfg.kto_undesirable_weight or 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
if self.cfg.max_prompt_len:
|
||||||
|
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||||
|
|
||||||
|
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||||
|
output_dir=self.cfg.output_dir,
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
max_steps=self.cfg.max_steps or total_num_steps,
|
max_steps=self.cfg.max_steps or total_num_steps,
|
||||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||||
learning_rate=self.cfg.learning_rate,
|
learning_rate=self.cfg.learning_rate,
|
||||||
output_dir=self.cfg.output_dir,
|
|
||||||
warmup_steps=self.cfg.warmup_steps,
|
warmup_steps=self.cfg.warmup_steps,
|
||||||
logging_first_step=True,
|
logging_first_step=True,
|
||||||
logging_steps=1,
|
logging_steps=1,
|
||||||
@@ -1528,20 +1665,37 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs[
|
dpo_trainer_kwargs[
|
||||||
"precompute_ref_log_probs"
|
"precompute_ref_log_probs"
|
||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
dpo_trainer = AxolotlDPOTrainer(
|
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
||||||
self.model,
|
trainer_cls = AxolotlDPOTrainer
|
||||||
self.model_ref,
|
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
||||||
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
|
|
||||||
|
# these aren't used for the ORPO trainer
|
||||||
|
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
dpo_trainer_kwargs["max_target_length"] = None
|
||||||
|
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||||
|
dpo_trainer_kwargs["generate_during_eval"] = True
|
||||||
|
if self.cfg.rl == "dpo":
|
||||||
|
dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
elif self.cfg.rl == "orpo":
|
||||||
|
trainer_cls = AxolotlORPOTrainer
|
||||||
|
trainer_cls_args = [self.model]
|
||||||
|
elif self.cfg.rl == "kto":
|
||||||
|
trainer_cls = AxolotlKTOTrainer
|
||||||
|
trainer_cls_args = [self.model]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||||
|
dpo_trainer = trainer_cls(
|
||||||
|
*trainer_cls_args,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
beta=self.cfg.dpo_beta or 0.1,
|
|
||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
max_length=self.cfg.sequence_len,
|
|
||||||
max_target_length=None,
|
|
||||||
max_prompt_length=self.cfg.sequence_len,
|
|
||||||
generate_during_eval=True,
|
|
||||||
callbacks=self.get_callbacks(),
|
callbacks=self.get_callbacks(),
|
||||||
**dpo_trainer_kwargs,
|
**dpo_trainer_kwargs,
|
||||||
)
|
)
|
||||||
|
if self.cfg.fsdp:
|
||||||
|
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
|
||||||
|
|
||||||
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
|
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
|
||||||
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
|
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
|
||||||
dpo_trainer.add_callback(callback)
|
dpo_trainer.add_callback(callback)
|
||||||
|
|||||||
@@ -123,6 +123,25 @@ def get_turns( # pylint: disable=too-many-return-statements
|
|||||||
else:
|
else:
|
||||||
yield role, ""
|
yield role, ""
|
||||||
return
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.LLAMA3:
|
||||||
|
if self.system_message:
|
||||||
|
# For llama3, the system message is NOT incorporated into the first human instruction
|
||||||
|
# All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
|
||||||
|
yield "", system_prompt
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message:
|
||||||
|
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
|
||||||
|
else:
|
||||||
|
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.GEMMA:
|
||||||
|
if self.system_message:
|
||||||
|
raise ValueError("Gemma chat template does not support system messages")
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
prefix = "<bos>" if i == 0 else ""
|
||||||
|
message_str = message if message else ""
|
||||||
|
yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
|
||||||
|
return
|
||||||
if self.sep_style == SeparatorStyle.CHATGLM:
|
if self.sep_style == SeparatorStyle.CHATGLM:
|
||||||
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
||||||
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||||
|
|||||||
@@ -516,24 +516,18 @@ def mistral_model_forward(
|
|||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
|
layer_outputs = (
|
||||||
def create_custom_forward(module):
|
self._gradient_checkpointing_func( # pylint: disable=protected-access
|
||||||
def custom_forward(*inputs):
|
decoder_layer.__call__,
|
||||||
# None for past_key_value
|
hidden_states,
|
||||||
return module(*inputs)
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
return custom_forward
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
None,
|
||||||
create_custom_forward(decoder_layer),
|
cu_seqlens,
|
||||||
hidden_states,
|
max_seqlen,
|
||||||
attention_mask,
|
)
|
||||||
position_ids,
|
|
||||||
past_key_value,
|
|
||||||
output_attentions,
|
|
||||||
None,
|
|
||||||
cu_seqlens,
|
|
||||||
max_seqlen,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
|
|||||||
@@ -42,9 +42,9 @@ def patch_mixtral_moe_forward_zero3() -> None:
|
|||||||
return final_hidden_states, router_logits
|
return final_hidden_states, router_logits
|
||||||
|
|
||||||
from transformers.models.mixtral.modeling_mixtral import (
|
from transformers.models.mixtral.modeling_mixtral import (
|
||||||
MixtralBLockSparseTop2MLP,
|
MixtralBlockSparseTop2MLP,
|
||||||
MixtralSparseMoeBlock,
|
MixtralSparseMoeBlock,
|
||||||
)
|
)
|
||||||
|
|
||||||
MixtralBLockSparseTop2MLP.forward = mlp_forward
|
MixtralBlockSparseTop2MLP.forward = mlp_forward
|
||||||
MixtralSparseMoeBlock.forward = moe_forward
|
MixtralSparseMoeBlock.forward = moe_forward
|
||||||
|
|||||||
267
src/axolotl/monkeypatch/unsloth_.py
Normal file
267
src/axolotl/monkeypatch/unsloth_.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""module for patching with unsloth optimizations"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import types
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from peft import PeftModelForCausalLM
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LlamaFlashAttention2,
|
||||||
|
LlamaForCausalLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.monkeypatch.unsloth")
|
||||||
|
|
||||||
|
ORIGINAL_CEL_CODE = """ if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_CEL_CODE = """ if labels is not None:
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
loss = fast_cross_entropy_loss(
|
||||||
|
logits = shift_logits,
|
||||||
|
labels = shift_labels,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
ORIGINAL_QKV_CODE = """
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
""".lstrip(
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
PATCHED_QKV_CODE = """
|
||||||
|
query_states, key_states, value_states = self.apply_qkv(self, hidden_states)
|
||||||
|
""".lstrip(
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
ORIGINAL_O_CODE = """
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
""".lstrip(
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
PATCHED_O_CODE = """
|
||||||
|
attn_output = self.apply_o(self, attn_output)
|
||||||
|
""".lstrip(
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def original_apply_qkv(self, hidden_states):
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
return query_states, key_states, value_states
|
||||||
|
|
||||||
|
|
||||||
|
def original_apply_o(self, hidden_states):
|
||||||
|
attn_output = self.o_proj(hidden_states)
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
def get_forward_code() -> str:
|
||||||
|
forward = inspect.getsource(LlamaForCausalLM.forward)
|
||||||
|
return forward
|
||||||
|
|
||||||
|
|
||||||
|
def test_cel_is_patchable() -> bool:
|
||||||
|
forward = get_forward_code()
|
||||||
|
return ORIGINAL_CEL_CODE in forward
|
||||||
|
|
||||||
|
|
||||||
|
def get_self_attn_code() -> str:
|
||||||
|
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
||||||
|
return forward
|
||||||
|
|
||||||
|
|
||||||
|
def test_self_attn_is_patchable() -> bool:
|
||||||
|
qkv = get_self_attn_code()
|
||||||
|
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_QKV_CODE in qkv
|
||||||
|
|
||||||
|
|
||||||
|
def integrate_cross_entropy_loss_patch():
|
||||||
|
forward = get_forward_code()
|
||||||
|
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
||||||
|
forward, _ = detab_code(forward)
|
||||||
|
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
|
||||||
|
|
||||||
|
forward = forward.replace(
|
||||||
|
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
|
||||||
|
)
|
||||||
|
forward = forward.replace(
|
||||||
|
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
|
||||||
|
forward = forward.replace(
|
||||||
|
"def forward(",
|
||||||
|
"def fast_cross_entropy_loss_forward(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load imports necessary
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(transformers.models.llama.modeling_llama):
|
||||||
|
if item in forward:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from transformers.models.llama.modeling_llama import ("
|
||||||
|
+ ", ".join(x for x in items_to_import)
|
||||||
|
+ ")",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
print("patching unsloth fast_cross_entropy_loss")
|
||||||
|
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
|
||||||
|
|
||||||
|
def detab_code(code: str) -> Tuple[str, str]:
|
||||||
|
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
||||||
|
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
|
||||||
|
return code, spaces
|
||||||
|
|
||||||
|
|
||||||
|
def patch_self_attn_lora():
|
||||||
|
self_attn_forward = get_self_attn_code()
|
||||||
|
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
|
||||||
|
self_attn_forward
|
||||||
|
)
|
||||||
|
self_attn_forward, _ = detab_code(self_attn_forward)
|
||||||
|
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original qkv code not found"
|
||||||
|
assert ORIGINAL_O_CODE in self_attn_forward, "Original o code not found"
|
||||||
|
|
||||||
|
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
|
||||||
|
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
||||||
|
self_attn_forward = self_attn_forward.replace(
|
||||||
|
"def forward(",
|
||||||
|
"def unsloth_attn_forward(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load imports necessary
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(transformers.models.llama.modeling_llama):
|
||||||
|
if item in self_attn_forward:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from transformers.models.llama.modeling_llama import ("
|
||||||
|
+ ", ".join(x for x in items_to_import)
|
||||||
|
+ ")",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
print("patching unsloth attn lora")
|
||||||
|
LlamaFlashAttention2.forward = (
|
||||||
|
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
||||||
|
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
|
||||||
|
from unsloth.kernels import apply_lora_mlp_swiglu
|
||||||
|
|
||||||
|
apply_lora_mlp = apply_lora_mlp_swiglu
|
||||||
|
elif peft_model.base_model.config.model_type == "gemma":
|
||||||
|
from unsloth.kernels import apply_lora_mlp_geglu_approx
|
||||||
|
|
||||||
|
apply_lora_mlp = apply_lora_mlp_geglu_approx
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Model type {peft_model.base_model.config.model_type} not supported"
|
||||||
|
)
|
||||||
|
|
||||||
|
for idx, layer in enumerate(peft_model.model.model.layers):
|
||||||
|
layer_modules = [
|
||||||
|
getattr(layer.mlp, linear_proj)
|
||||||
|
for linear_proj in ["gate_proj", "up_proj", "down_proj"]
|
||||||
|
]
|
||||||
|
is_mlp_lora = all(hasattr(module, "lora_A") for module in layer_modules)
|
||||||
|
mlp_no_bias = all(
|
||||||
|
getattr(module, "base_layer", module).bias is None
|
||||||
|
for module in layer_modules
|
||||||
|
)
|
||||||
|
mlp_not_dora = all(
|
||||||
|
getattr(module, "lora_magnitude_vector", None) is None
|
||||||
|
for module in layer_modules
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
|
||||||
|
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
|
||||||
|
else:
|
||||||
|
logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
|
||||||
|
|
||||||
|
|
||||||
|
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
||||||
|
from unsloth.kernels import apply_lora_o, apply_lora_qkv
|
||||||
|
|
||||||
|
for idx, layer in enumerate(peft_model.model.model.layers):
|
||||||
|
if cfg.unsloth_lora_qkv:
|
||||||
|
layer_modules = [
|
||||||
|
getattr(layer.self_attn, linear_proj)
|
||||||
|
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
||||||
|
]
|
||||||
|
is_qkv_lora = all(hasattr(module, "lora_A") for module in layer_modules)
|
||||||
|
qkv_no_bias = all(
|
||||||
|
getattr(module, "base_layer", module).bias is None
|
||||||
|
for module in layer_modules
|
||||||
|
)
|
||||||
|
qkv_not_dora = all(
|
||||||
|
getattr(module, "lora_magnitude_vector", None) is None
|
||||||
|
for module in layer_modules
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_qkv_lora and qkv_no_bias and qkv_not_dora:
|
||||||
|
layer.self_attn.apply_qkv = apply_lora_qkv
|
||||||
|
else:
|
||||||
|
layer.self_attn.apply_qkv = original_apply_qkv
|
||||||
|
logging.warning(
|
||||||
|
"unable to apply unsloth lora qkv patch to layer %d", idx
|
||||||
|
)
|
||||||
|
if cfg.unsloth_lora_o:
|
||||||
|
layer_modules = [
|
||||||
|
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
|
||||||
|
]
|
||||||
|
is_o_lora = all(hasattr(module, "lora_A") for module in layer_modules)
|
||||||
|
o_no_bias = all(
|
||||||
|
getattr(module, "base_layer", module).bias is None
|
||||||
|
for module in layer_modules
|
||||||
|
)
|
||||||
|
o_not_dora = all(
|
||||||
|
getattr(module, "lora_magnitude_vector", None) is None
|
||||||
|
for module in layer_modules
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_o_lora and o_no_bias and o_not_dora:
|
||||||
|
layer.self_attn.apply_o = apply_lora_o
|
||||||
|
else:
|
||||||
|
layer.self_attn.apply_o = original_apply_o
|
||||||
|
logging.warning(
|
||||||
|
"unable to apply unsloth lora o_proj patch to layer %d", idx
|
||||||
|
)
|
||||||
@@ -1,24 +1,56 @@
|
|||||||
"""
|
"""
|
||||||
HF Chat Templates prompt strategy
|
HF Chat Templates prompt strategy
|
||||||
"""
|
"""
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import Prompter
|
from axolotl.prompters import Prompter
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplatePrompter(Prompter):
|
class ChatTemplatePrompter(Prompter):
|
||||||
"""prompter for HF chat templates"""
|
"""prompter for HF chat templates"""
|
||||||
|
|
||||||
def __init__(self, tokenizer, chat_template=None, max_length=2048):
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer,
|
||||||
|
chat_template=None,
|
||||||
|
max_length=2048,
|
||||||
|
message_field_role: str = "from",
|
||||||
|
message_field_content: str = "value",
|
||||||
|
roles: Optional[Dict[str, List[str]]] = None,
|
||||||
|
):
|
||||||
|
if roles:
|
||||||
|
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
||||||
|
else:
|
||||||
|
self.roles = {
|
||||||
|
"human": "user",
|
||||||
|
"user": "user",
|
||||||
|
"assistant": "assistant",
|
||||||
|
"gpt": "assistant",
|
||||||
|
"system": "system",
|
||||||
|
}
|
||||||
|
self.message_field_role = message_field_role
|
||||||
|
self.message_field_content = message_field_content
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
|
||||||
def build_prompt(self, conversation, add_generation_prompt=False):
|
def build_prompt(self, conversation, add_generation_prompt=False):
|
||||||
|
turns = [
|
||||||
|
{
|
||||||
|
"role": self.roles[t[self.message_field_role]],
|
||||||
|
"content": t[self.message_field_content],
|
||||||
|
}
|
||||||
|
for t in conversation
|
||||||
|
]
|
||||||
|
|
||||||
return self.tokenizer.apply_chat_template(
|
return self.tokenizer.apply_chat_template(
|
||||||
conversation,
|
turns,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=self.max_length,
|
max_length=self.max_length,
|
||||||
add_generation_prompt=add_generation_prompt,
|
add_generation_prompt=add_generation_prompt,
|
||||||
@@ -31,9 +63,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
Tokenizing strategy for instruction-based prompts.
|
Tokenizing strategy for instruction-based prompts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_messages = "conversations"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def messages(self):
|
||||||
|
return self._messages
|
||||||
|
|
||||||
|
@messages.setter
|
||||||
|
def messages(self, messages):
|
||||||
|
self._messages = messages
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
turns = self.get_conversation_thread(prompt)
|
turns = self.get_conversation_thread(prompt)
|
||||||
prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
|
prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
|
||||||
input_ids = self.prompter.build_prompt(turns)
|
input_ids = self.prompter.build_prompt(turns)
|
||||||
|
|
||||||
if not self.train_on_inputs:
|
if not self.train_on_inputs:
|
||||||
@@ -51,28 +93,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
return tokenized_prompt
|
return tokenized_prompt
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
def get_conversation_thread(self, prompt):
|
||||||
conversations = prompt["conversations"]
|
return prompt[self.messages]
|
||||||
# remap roles - allow for assistant turn
|
|
||||||
role_map = {
|
|
||||||
"human": "user",
|
|
||||||
"user": "user",
|
|
||||||
"assistant": "assistant",
|
|
||||||
"gpt": "assistant",
|
|
||||||
}
|
|
||||||
turns = [
|
|
||||||
{"role": role_map[t["from"]], "content": t["value"]} for t in conversations
|
|
||||||
]
|
|
||||||
return turns
|
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
chat_template = (
|
chat_template = (
|
||||||
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
|
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
|
||||||
)
|
)
|
||||||
|
message_field_role = (
|
||||||
|
ds_cfg["message_field_role"]
|
||||||
|
if ds_cfg and "message_field_role" in ds_cfg
|
||||||
|
else "from"
|
||||||
|
)
|
||||||
|
message_field_content = (
|
||||||
|
ds_cfg["message_field_content"]
|
||||||
|
if ds_cfg and "message_field_content" in ds_cfg
|
||||||
|
else "value"
|
||||||
|
)
|
||||||
|
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
|
||||||
|
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
|
ChatTemplatePrompter(
|
||||||
|
tokenizer,
|
||||||
|
chat_templates(chat_template),
|
||||||
|
message_field_role=message_field_role,
|
||||||
|
message_field_content=message_field_content,
|
||||||
|
roles=roles,
|
||||||
|
),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
|
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||||
|
strategy.messages = ds_cfg["field_messages"]
|
||||||
return strategy
|
return strategy
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user