Compare commits

10 commits: NanoCode01 ... tensor-par

| Author | SHA1 | Date |
|---|---|---|
| | 87e8f13056 | |
| | 026172eaa8 | |
| | b3689f73e3 | |
| | c4664ba8ee | |
| | 75e4fc2825 | |
| | e13c2fd6b1 | |
| | 8a21e14a21 | |
| | 9c52a83403 | |
| | fb8ee37ca6 | |
| | 65f3a4f703 | |
**.github/workflows/base.yml** (7 changes)

````diff
@@ -28,12 +28,7 @@ jobs:
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.1.1
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
-          - cuda: "121"
-            cuda_version: 12.1.0
-            python_version: "3.10"
-            pytorch: 2.1.1
+            pytorch: 2.1.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
     steps:
       - name: Checkout
````
**.github/workflows/main.yml** (51 changes)

````diff
@@ -27,56 +27,38 @@ jobs:
           - cuda: 118
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.1.1
-            axolotl_extras:
-          - cuda: 121
-            cuda_version: 12.1.0
-            python_version: "3.10"
-            pytorch: 2.1.1
+            pytorch: 2.1.0
             axolotl_extras:
     runs-on: [self-hosted, gpu, docker]
     steps:
       - name: Checkout
-        uses: actions/checkout@v4
+        uses: actions/checkout@v3
       - name: Docker metadata
         id: metadata
-        uses: docker/metadata-action@v5
+        uses: docker/metadata-action@v3
         with:
           images: winglian/axolotl
-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
       - name: Login to Docker Hub
-        uses: docker/login-action@v3
+        uses: docker/login-action@v2
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_TOKEN }}
-      # guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
-      - name: Build and export to Docker
-        uses: docker/build-push-action@v5
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v2
+      - name: Build
+        uses: docker/build-push-action@v4
         with:
           context: .
-          load: true
           build-args: |
            BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
            CUDA=${{ matrix.cuda }}
            PYTORCH_VERSION=${{ matrix.pytorch }}
           file: ./docker/Dockerfile
+          push: ${{ github.event_name != 'pull_request' }}
           tags: |
            ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
            ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
           labels: ${{ steps.metadata.outputs.labels }}
-      - name: Unit Tests
-        run: |
-          docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
-      - name: Push to Docker Hub
-        if: github.event_name != 'pull_request'
-        run: |
-          docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
-          latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
-          if [ -n "$latest_tag" ]; then
-            docker push "$latest_tag"
-          fi
-
   build-axolotl-runpod:
     needs: build-axolotl
     if: github.repository_owner == 'OpenAccess-AI-Collective'
@@ -98,31 +80,26 @@ jobs:
           - cuda: 118
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.1.1
-            axolotl_extras:
-          - cuda: 121
-            cuda_version: 12.1.0
-            python_version: "3.10"
-            pytorch: 2.1.1
+            pytorch: 2.1.0
             axolotl_extras:
     runs-on: [self-hosted, gpu, docker]
     steps:
      - name: Checkout
-        uses: actions/checkout@v4
+        uses: actions/checkout@v3
      - name: Docker metadata
        id: metadata
-        uses: docker/metadata-action@v5
+        uses: docker/metadata-action@v3
        with:
          images: winglian/axolotl-runpod
      - name: Login to Docker Hub
-        uses: docker/login-action@v3
+        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2
      - name: Build
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@v4
        with:
          context: .
          build-args: |
````
**.github/workflows/tests-docker.yml** (46 changes)

````diff
@@ -1,46 +0,0 @@
-name: e2e-docker-tests
-
-on:
-  pull_request:
-    paths:
-      - '**.py'
-      - 'requirements.txt'
-      - '.github/workflows/*.yml'
-  workflow_dispatch:
-
-jobs:
-  build-axolotl:
-    if: github.repository_owner == 'OpenAccess-AI-Collective'
-    # this job needs to be run on self-hosted GPU runners...
-    strategy:
-      fail-fast: false
-      matrix:
-        include:
-          - cuda: 118
-            cuda_version: 11.8.0
-            python_version: "3.10"
-            pytorch: 2.0.1
-          - cuda: 121
-            cuda_version: 12.1.0
-            python_version: "3.10"
-            pytorch: 2.1.1
-    runs-on: [self-hosted, gpu, docker]
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-      - name: Build Docker image
-        run: |
-          # Set up build arguments
-          BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}"
-          CUDA="${{ matrix.cuda }}"
-          PYTORCH_VERSION="${{ matrix.pytorch }}"
-          # Build the Docker image
-          docker build . \
-            --file ./docker/Dockerfile \
-            --build-arg BASE_TAG=$BASE_TAG \
-            --build-arg CUDA=$CUDA \
-            --build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
-            --tag test-axolotl
-      - name: Unit Tests w docker image
-        run: |
-          docker run --rm test-axolotl pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
````
**.github/workflows/tests.yml** (3 changes)

````diff
@@ -71,9 +71,8 @@ jobs:
 
       - name: Install dependencies
         run: |
-          pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
           pip3 uninstall -y transformers accelerate
-          pip3 install -U -e .[flash-attn,mamba-ssm]
+          pip3 install -U -e .[flash-attn]
           pip3 install -r requirements-tests.txt
 
       - name: Run e2e tests
````
````diff
@@ -8,9 +8,6 @@ ignore_missing_imports = True
 [mypy-axolotl.monkeypatch.*]
 ignore_errors = True
 
-[mypy-axolotl.models.mixtral.*]
-ignore_errors = True
-
 [mypy-axolotl.models.phi.*]
 ignore_errors = True
 
````
**README.md** (220 changes)

````diff
@@ -25,10 +25,8 @@ Features:
 - [Installation](#installation)
   - [Docker](#docker)
   - [Conda/Pip venv](#condapip-venv)
-  - [Runpod](#runpod)
   - [LambdaLabs](#lambdalabs)
   - [Windows](#windows)
-  - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
 - [Dataset](#dataset)
   - [How to Add Custom Prompts](#how-to-add-custom-prompts)
   - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
@@ -36,9 +34,7 @@ Features:
 - [Train](#train)
 - [Inference](#inference)
 - [Merge LORA to Base](#merge-lora-to-base)
-- [Special Tokens](#special-tokens)
 - [Common Errors](#common-errors-)
-- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
 - [Need Help?](#need-help-)
 - [Badge](#badge-)
 - [Community Showcase](#community-showcase)
@@ -67,21 +63,17 @@ Features:
 
 ## Axolotl supports
 
 | | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
-|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
+|----------|:----------|:-----|-------|------|-------------------|------------|--------------|
 | llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
 | Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
 | cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
 | btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
 | mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
 | falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
 | gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
 | XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
 | phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
-| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
-| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
 
 
 ## Quickstart ⚡
````
````diff
@@ -90,29 +82,20 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
 
 **Requirements**: Python >=3.9 and Pytorch >=2.0.
 
-`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
-
-### For developers
 ```bash
 git clone https://github.com/OpenAccess-AI-Collective/axolotl
 cd axolotl
 
 pip3 install packaging
 pip3 install -e '.[flash-attn,deepspeed]'
-```
+pip3 install -U git+https://github.com/huggingface/peft.git
 
-### Usage
-```bash
 # finetune lora
 accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
 
 # inference
 accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
     --lora_model_dir="./lora-out"
-
-# gradio
-accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
-    --lora_model_dir="./lora-out" --gradio
 ```
 
 ## Installation
````
````diff
@@ -123,6 +106,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
 ```bash
 docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
 ```
+- `winglian/axolotl-runpod:main-latest`: for runpod or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
 
 Or run on the current files for development:
 
@@ -137,15 +121,13 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
 A more powerful Docker command to run would be this:
 
 ```bash
-docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
+docker run --gpus '"all"' --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
 ```
 
 It additionally:
 * Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
 * Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
 * The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
-* The `--privileged` flag gives all capabilities to the container.
-* The `--shm-size 10g` argument increases the shared memory size. Use this if you see `exitcode: -7` errors using deepspeed.
 
 [More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
 
````
````diff
@@ -167,10 +149,6 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
 ```
 Get the token at huggingface.co/settings/tokens
 
-#### Runpod
-
-Use `winglian/axolotl-runpod:main-latest` or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
-
 #### LambdaLabs
 <details>
 
@@ -218,28 +196,6 @@ Use `winglian/axolotl-runpod:main-latest` or use this [direct link](https://runp
 #### Windows
 Please use WSL or Docker!
 
-
-#### Launching on public clouds via SkyPilot
-To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
-```bash
-pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
-sky check
-```
-Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
-```
-git clone https://github.com/skypilot-org/skypilot.git
-cd skypilot/llm/axolotl
-```
-Use one command to launch:
-```bash
-# On-demand
-HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
-
-# Managed spot (auto-recovery on preemption)
-HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
-```
-
-
 ### Dataset
 
 Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
````
````diff
@@ -249,17 +205,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
   ```json
   {"instruction": "...", "input": "...", "output": "..."}
   ```
-- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
+- `sharegpt`: conversations where `from` is `human`/`gpt`
   ```json
   {"conversations": [{"from": "...", "value": "..."}]}
   ```
-- `llama-2`: the json is the same format as `sharegpt` above, with the following config (see the [config section](#config) for more details)
-  ```yml
-  datasets:
-    - path: <your-path>
-      type: sharegpt
-      conversation: llama-2
-  ```
 - `completion`: raw corpus
   ```json
   {"text": "..."}
````
````diff
@@ -443,12 +392,6 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
   - path: knowrohit07/know_sql
     type: context_qa.load_v2
     train_on_split: validation
-
-  # loading from s3 or gcs
-  # s3 creds will be loaded from the system default and gcs only supports public access
-  dataset:
-    - path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
-      ...
 ```
 
 - loading
````
````diff
@@ -511,23 +454,6 @@ is_falcon_derived_model:
 is_llama_derived_model:
 # Please note that if you set this to true, `padding_side` will be set to "left" by default
 is_mistral_derived_model:
-is_qwen_derived_model:
-
-# optional overrides to the base model configuration
-model_config:
-  # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
-  rope_scaling:
-    type: # linear | dynamic
-    factor: # float
-
-# optional overrides to the bnb 4bit quantization configuration
-# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
-bnb_config_kwargs:
-  # These are default values
-  llm_int8_has_fp16_weight: false
-  bnb_4bit_quant_type: nf4
-  bnb_4bit_use_double_quant: true
-
 
 # Whether you are training a 4-bit GPTQ quantized model
 gptq: true
````
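For context on the `rope_scaling` and `bnb_config_kwargs` overrides that the hunk above removes, here is a minimal sketch of what they map to on the Hugging Face side, assuming the `transformers` `BitsAndBytesConfig` and Llama-style `rope_scaling` config; the model name and values are placeholders, not taken from this diff:

```python
# Hedged sketch: how bnb_config_kwargs / model_config.rope_scaling-style overrides
# map onto the Hugging Face API. Names and values are illustrative placeholders.
import torch
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig

# Mirrors the bnb_config_kwargs defaults listed in the removed README block.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

# Mirrors the rope_scaling override (type: linear | dynamic, factor: float).
config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model
config.rope_scaling = {"type": "linear", "factor": 2.0}

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    config=config,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
)
```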
````diff
@@ -550,14 +476,9 @@ tf32: true # require >=ampere
 bfloat16: true # require >=ampere
 float16: true
 
-# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset
-gpu_memory_limit: 20GiB
-# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
-lora_on_cpu: true
-
 # A list of one or more datasets to finetune the model with
 datasets:
-  # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
+  # HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
   - path: vicgalle/alpaca-gpt4
     # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
     type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
````
````diff
@@ -565,12 +486,9 @@ datasets:
     data_files: # Optional[str] path to source data files
     shards: # Optional[int] number of shards to split data into
     name: # Optional[str] name of dataset configuration to load
-    train_on_split: train # Optional[str] name of dataset split to load from
 
     # Optional[str] fastchat conversation type, only used with type: sharegpt
     conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
-    field_human: # Optional[str]. Human key to use for conversation.
-    field_model: # Optional[str]. Assistant key to use for conversation.
 
   # Custom user prompt
   - path: repo
````
````diff
@@ -594,9 +512,6 @@ datasets:
     # For `completion` datsets only, uses the provided field instead of `text` column
     field:
 
-# Saves the desired chat template to the tokenizer_config.json for easier inferencing
-# Currently supports chatml and inst (mistral/mixtral)
-chat_template: chatml
 # Axolotl attempts to save the dataset as an arrow after packing the data together so
 # subsequent training attempts load faster, relative path
 dataset_prepared_path: data/last_run_prepared
````
````diff
@@ -639,17 +554,10 @@ eval_sample_packing:
 sample_packing_eff_est:
 total_num_tokens:
 
-# Passed through to transformers when loading the model when launched without accelerate
-# Use `sequential` when training w/ model parallelism to limit memory
-device_map:
-# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
-max_memory:
-
 # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
 adapter: lora
 # If you already have a lora model trained that you want to load, put that here.
-# This means after training, if you want to test the model, you should set this to the value of `output_dir`.
-# Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`.
+# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
 lora_model_dir:
 
 # LoRA hyperparameters
````
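As background for the `device_map:` and `max_memory:` options in the hunk above (the README says they are passed through to `transformers` when loading the model), a minimal sketch of the call they feed into; the model name and memory figures are placeholders:

```python
# Hedged sketch of the transformers loading options that device_map / max_memory
# are passed through to. Model name and memory caps are illustrative only.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "openlm-research/open_llama_3b_v2",          # placeholder base model
    device_map="sequential",                     # "sequential" limits memory under model parallelism
    max_memory={0: "20GiB", "cpu": "96GiB"},     # per-device caps, like the gpu_memory_limit example
)
```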
````diff
@@ -676,6 +584,10 @@ lora_modules_to_save:
 #  - embed_tokens
 #  - lm_head
 
+# Once you complete training, the model will be saved to the following directory.
+# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
+# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
+lora_out_dir:
 lora_fan_in_fan_out: false
 
 # ReLoRA configuration
````
````diff
@@ -685,13 +597,11 @@ relora_warmup_steps: # Number of per-restart warmup steps
 relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
 
 # wandb configuration if you're using it
-# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
 wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
 wandb_project: # Your wandb project name
 wandb_entity: # A wandb Team name if using a Team
 wandb_watch:
-wandb_name: # Set the name of your wandb run
-wandb_run_id: # Set the ID of your wandb run
+wandb_run_id: # Set the name of your wandb run
 wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
 
 # Where to save the full-finetuned model to
````
````diff
@@ -709,16 +619,13 @@ gradient_accumulation_steps: 1
 micro_batch_size: 2
 eval_batch_size:
 num_epochs: 4
-warmup_steps: 100 # cannot use with warmup_ratio
-warmup_ratio: 0.05 # cannot use with warmup_steps
+warmup_steps: 100
 learning_rate: 0.00003
 lr_quadratic_warmup:
 logging_steps:
-eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
-evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
 save_strategy: # Set to `no` to skip checkpoint saves
 save_steps: # Leave empty to save at each epoch
-saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
+eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
 save_total_limit: # Checkpoints saved at a time
 # Maximum number of iterations to train for. It precedes num_epochs which means that
 # if both are set, num_epochs will not be guaranteed.
````
````diff
@@ -728,9 +635,6 @@ max_steps:
 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
 eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
 
-loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
-loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
-
 # Save model as safetensors (require safetensors package)
 save_safetensors:
 
````
````diff
@@ -743,9 +647,6 @@ group_by_length: false
 
 # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
 gradient_checkpointing: false
-# additional kwargs to pass to the trainer for gradient checkpointing
-# gradient_checkpointing_kwargs:
-#   use_reentrant: false
 
 # Stop training after this many evaluation losses have increased in a row
 # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
````
````diff
@@ -800,7 +701,7 @@ max_grad_norm:
 # Augmentation techniques
 # NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
 # currently only supported on Llama and Mistral
-neftune_noise_alpha:
+noisy_embedding_alpha:
 
 # Whether to bettertransformers
 flash_optimum:
````
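As background for the `neftune_noise_alpha` / `noisy_embedding_alpha` option in the hunk above, a rough sketch of the NEFT embedding noise described in the linked paper (https://arxiv.org/abs/2310.05914); this is illustrative only and is not axolotl's or trl's implementation:

```python
# Hedged sketch of NEFTune-style embedding noise (arXiv:2310.05914).
# Illustrative only; shapes and alpha value are placeholders.
import torch

def neftune_noise(embeddings: torch.Tensor, alpha: float = 5.0) -> torch.Tensor:
    """Add uniform noise scaled by alpha / sqrt(seq_len * hidden_dim) to token embeddings."""
    seq_len, hidden_dim = embeddings.shape[-2], embeddings.shape[-1]
    scale = alpha / (seq_len * hidden_dim) ** 0.5
    noise = torch.empty_like(embeddings).uniform_(-scale, scale)
    return embeddings + noise

# Example: a batch of 4 sequences, 128 tokens, 4096-dim embeddings.
noisy = neftune_noise(torch.randn(4, 128, 4096), alpha=5.0)
```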
````diff
@@ -815,6 +716,15 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
 # Whether to use scaled-dot-product attention
 # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
 sdp_attention:
+# Landmark attention (only llama)
+landmark_attention:
+# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
+# LLaMA only
+xpos_rope:
+# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
+rope_scaling:
+  type: # linear | dynamic
+  factor: # float
 
 # Resume from a specific checkpoint dir
 resume_from_checkpoint:
````
````diff
@@ -937,9 +847,8 @@ accelerate launch -m axolotl.cli.train your_config.yml
 You can optionally pre-tokenize dataset with the following before finetuning.
 This is recommended for large datasets.
 
-- Set `dataset_prepared_path:` to a local folder for saving and loading pre-tokenized dataset.
-- (Optional): Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
-- (Optional): Use `--debug` to see preprocessed examples.
+- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
+- Use `--debug` to see preprocessed examples.
 
 ```bash
 python -m axolotl.cli.preprocess your_config.yml
````
````diff
@@ -982,40 +891,19 @@ fsdp_config:
 
 ##### Weights & Biases Logging
 
-Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
-
 - wandb options
 ```yaml
 wandb_mode:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 ```
 
-##### Special Tokens
+### Inference
 
-It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
+Pass the appropriate flag to the train command:
 
-```yml
-special_tokens:
-  bos_token: "<s>"
-  eos_token: "</s>"
-  unk_token: "<unk>"
-tokens: # these are delimiters
-  - "<|im_start|>"
-  - "<|im_end|>"
-```
-
-When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
-
-### Inference Playground
-
-Axolotl allows you to load your model in an interactive terminal playground for quick experimentation.
-The config file is the same config file used for training.
-
-Pass the appropriate flag to the inference command, depending upon what kind of model was trained:
-
 - Pretrained LORA:
 ```bash
````
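The removed "Special Tokens" text above says that axolotl adds the configured tokens to the tokenizer's vocabulary. As a rough sketch of what that amounts to in terms of the Hugging Face API (not axolotl's exact code; model and tokenizer names are placeholders):

```python
# Hedged sketch of registering special tokens and delimiters with a HF tokenizer.
# Placeholder model; values mirror the special_tokens example in the removed block.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b_v2")
model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_3b_v2")

tokenizer.add_special_tokens({"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>"})
tokenizer.add_tokens(["<|im_start|>", "<|im_end|>"])  # delimiters from the `tokens:` list

# The embedding table must grow to cover the newly added token ids.
model.resize_token_embeddings(len(tokenizer))
```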
````diff
@@ -1030,10 +918,6 @@ Pass the appropriate flag to the inference command, depending upon what kind of
 cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
   --base_model="./completed-model" --prompter=None --load_in_8bit=True
 ```
--- With gradio hosting
-```bash
-python -m axolotl.cli.inference examples/your_config.yml --gradio
-```
 
 Please use `--sample_packing False` if you have it on and receive the error similar to below:
 
````
````diff
@@ -1041,20 +925,18 @@ Please use `--sample_packing False` if you have it on and receive the error simi
 
 ### Merge LORA to base
 
-The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
+Add below flag to train command above
 
 ```bash
-python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model"
+python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
 ```
 
-You may need to use the `gpu_memory_limit` and/or `lora_on_cpu` config options to avoid running out of memory. If you still run out of CUDA memory, you can try to merge in system RAM with
+If you run out of CUDA memory, you can try to merge in system RAM with
 
 ```bash
 CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
 ```
 
-although this will be very slow, and using the config options above are recommended instead.
-
 ## Common Errors 🧰
 
 See also the [FAQ's](./docs/faq.md).
````
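For context on the merge step discussed in the hunk above, a rough sketch of a LoRA merge in terms of the `peft` API (this is not axolotl's exact code; paths are placeholders matching the README example):

```python
# Hedged sketch of merging a LoRA adapter into its base model with peft.
# Paths and model name are placeholders, not from this diff.
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_3b_v2")
model = PeftModel.from_pretrained(base, "./completed-model")  # the --lora_model_dir directory
merged = model.merge_and_unload()                # folds the LoRA deltas into the base weights
merged.save_pretrained("./completed-model/merged")
```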
````diff
@@ -1067,10 +949,6 @@ Please reduce any below
 - `gradient_accumulation_steps`
 - `sequence_len`
 
-If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command.
-
-Using adamw_bnb_8bit might also save you some memory.
-
 > `failed (exitcode: -9)`
 
 Usually means your system has run out of system memory.
````
````diff
@@ -1093,20 +971,6 @@ It's safe to ignore it.
 
 See the [NCCL](docs/nccl.md) guide.
 
-
-### Tokenization Mismatch b/w Inference & Training
-
-For many formats, Axolotl constructs prompts by concatenating token ids _after_ tokenizing strings. The reason for concatenating token ids rather than operating on strings is to maintain precise accounting for attention masks.
-
-If you decode a prompt constructed by axolotl, you might see spaces between tokens (or lack thereof) that you do not expect, especially around delimiters and special tokens. When you are starting out with a new format, you should always do the following:
-
-1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
-2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
-3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same adjust your inference server accordingly.
-4. As an additional troubleshooting step, you can look look at the token ids between 1 and 2 to make sure they are identical.
-
-Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
-
 ## Need help? 🙋♂️
 
 Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
````
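As a companion to step 1 of the tokenization-mismatch checklist in the hunk above, a minimal sketch of decoding a few preprocessed rows, assuming a dataset saved under `dataset_prepared_path` and a Hugging Face tokenizer; paths and the model name are placeholders:

```python
# Hedged sketch: decode a few preprocessed rows and eyeball the exact string the
# model will be trained on. Paths and tokenizer name are illustrative placeholders.
from datasets import load_from_disk
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_3b_v2")
ds = load_from_disk("data/last_run_prepared")  # the configured dataset_prepared_path

for row in ds.select(range(3)):
    print(repr(tokenizer.decode(row["input_ids"])))  # repr() makes spaces/newlines visible
```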
````diff
@@ -24,6 +24,16 @@
         "weight_decay": "auto"
       }
     },
+    "scheduler": {
+      "type": "WarmupDecayLR",
+      "params": {
+        "warmup_min_lr": "auto",
+        "warmup_max_lr": "auto",
+        "warmup_num_steps": "auto",
+        "warmup_type": "linear",
+        "total_num_steps": "auto"
+      }
+    },
     "gradient_accumulation_steps": "auto",
     "train_batch_size": "auto",
     "train_micro_batch_size_per_gpu": "auto",
````

````diff
@@ -28,6 +28,16 @@
         "weight_decay": "auto"
       }
     },
+    "scheduler": {
+      "type": "WarmupDecayLR",
+      "params": {
+        "warmup_min_lr": "auto",
+        "warmup_max_lr": "auto",
+        "warmup_num_steps": "auto",
+        "warmup_type": "linear",
+        "total_num_steps": "auto"
+      }
+    },
     "gradient_accumulation_steps": "auto",
     "train_batch_size": "auto",
     "train_micro_batch_size_per_gpu": "auto",
````

````diff
@@ -32,6 +32,16 @@
         "weight_decay": "auto"
       }
     },
+    "scheduler": {
+      "type": "WarmupDecayLR",
+      "params": {
+        "warmup_min_lr": "auto",
+        "warmup_max_lr": "auto",
+        "warmup_num_steps": "auto",
+        "warmup_type": "linear",
+        "total_num_steps": "auto"
+      }
+    },
     "gradient_accumulation_steps": "auto",
     "train_batch_size": "auto",
     "train_micro_batch_size_per_gpu": "auto",
````
````diff
@@ -1,39 +0,0 @@
-{
-  "zero_optimization": {
-    "stage": 3,
-    "overlap_comm": true,
-    "contiguous_gradients": true,
-    "sub_group_size": 0,
-    "reduce_bucket_size": "auto",
-    "stage3_prefetch_bucket_size": "auto",
-    "stage3_param_persistence_threshold": "auto",
-    "stage3_max_live_parameters": 0,
-    "stage3_max_reuse_distance": 0,
-    "stage3_gather_16bit_weights_on_model_save": true
-  },
-  "bf16": {
-    "enabled": true
-  },
-  "fp16": {
-    "enabled": "auto",
-    "auto_cast": false,
-    "loss_scale": 0,
-    "initial_scale_power": 32,
-    "loss_scale_window": 1000,
-    "hysteresis": 2,
-    "min_loss_scale": 1
-  },
-  "optimizer": {
-    "type": "AdamW",
-    "params": {
-      "lr": "auto",
-      "betas": "auto",
-      "eps": "auto",
-      "weight_decay": "auto"
-    }
-  },
-  "gradient_accumulation_steps": "auto",
-  "train_batch_size": "auto",
-  "train_micro_batch_size_per_gpu": "auto",
-  "wall_clock_breakdown": false
-}
````
````diff
@@ -10,7 +10,7 @@ ARG PYTORCH_VERSION="2.0.1"
 ENV PYTORCH_VERSION=$PYTORCH_VERSION
 
 RUN apt-get update && \
-    apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
+    apt-get install -y vim curl
 
 WORKDIR /workspace
 
@@ -19,15 +19,13 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
 WORKDIR /workspace/axolotl
 
 # If AXOLOTL_EXTRAS is set, append it in brackets
+RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
+        pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
     else \
-        pip install -e .[deepspeed,flash-attn]; \
+        pip install -e .[flash-attn]; \
     fi
 
-# So we can test the Docker image
-RUN pip install pytest
-
 # fix so that git fetch/pull from remote works
 RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
     git config --get remote.origin.fetch
````
````diff
@@ -10,10 +10,8 @@ ENV PATH="/root/miniconda3/bin:${PATH}"
 ARG PYTHON_VERSION="3.9"
 ARG PYTORCH_VERSION="2.0.1"
 ARG CUDA="118"
-ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
 
 ENV PYTHON_VERSION=$PYTHON_VERSION
-ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
 
 RUN apt-get update \
     && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
@@ -29,9 +27,47 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
 WORKDIR /workspace
 
 RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
-    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} deepspeed-kernels --extra-index-url https://download.pytorch.org/whl/cu$CUDA
+    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
 
-RUN git lfs install --skip-repo && \
-    pip3 install awscli && \
+FROM base-builder AS deepspeed-builder
+
+ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
+
+WORKDIR /workspace
+
+RUN git clone https://github.com/microsoft/DeepSpeed.git && \
+    cd DeepSpeed && \
+    MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 DS_BUILD_EVOFORMER_ATTN=0 python3 setup.py bdist_wheel
+
+FROM base-builder AS bnb-builder
+
+WORKDIR /workspace
+ARG CUDA="118"
+ENV CUDA=$CUDA
+ARG MAX_JOBS="-1"
+ENV MAX_JOBS=$MAX_JOBS
+
+RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
+    cd bitsandbytes && \
+    CUDA_VERSION=$CUDA make cuda11x && \
+    python setup.py bdist_wheel
+
+FROM base-builder
+
+ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
+ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
+
+RUN mkdir -p /workspace/builds
+COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
+
+RUN mkdir -p /workspace/wheels/bitsandbytes
+COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
+COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
+COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
+
+RUN pip3 install wheels/deepspeed-*.whl
+RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
+RUN git lfs install --skip-repo
+RUN pip3 install awscli && \
     # The base image ships with `pydantic==1.8.2` which is not working
     pip3 install -U --no-cache-dir pydantic==1.10.10
````
````diff
@@ -4,7 +4,6 @@ FROM winglian/axolotl:$BASE_TAG
 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
 ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
 ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
-ENV HF_HOME="/workspace/data/huggingface-cache/hub"
 
 COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
 
````
**docs/rlhf.md** (35 changes)

````diff
@@ -1,35 +0,0 @@
-# RLHF (Beta)
-
-### Overview
-
-Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
-feedback. Various methods include, but not limited to:
-
-- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
-- Direct Preference Optimization (DPO)
-- Identity Preference Optimization (IPO)
-
-
-### RLHF using Axolotl
-
-[!IMPORTANT]
-This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
-
-The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML
-
-#### DPO
-```yaml
-rl: true
-datasets:
-  - path: Intel/orca_dpo_pairs
-    split: train
-    type: intel_apply_chatml
-  - path: argilla/ultrafeedback-binarized-preferences
-    split: train
-    type: argilla_apply_chatml
-```
-
-#### IPO
-```yaml
-rl: ipo
-```
````
@@ -14,7 +14,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_prepared_run
|
dataset_prepared_path: last_prepared_run
|
||||||
val_set_size: 0.05
|
val_set_size: 0.01
|
||||||
|
|
||||||
adapter:
|
adapter:
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
output_dir: btlm-out
|
output_dir: btlm-out
|
||||||
@@ -72,8 +72,8 @@ gptq_groupsize:
|
|||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
|
|
||||||
warmup_steps: 32
|
warmup_steps: 32
|
||||||
evals_per_epoch: 4
|
eval_steps:
|
||||||
saves_per_epoch: 1
|
save_steps:
|
||||||
save_total_limit:
|
save_total_limit:
|
||||||
|
|
||||||
debug:
|
debug:
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ datasets:
|
|||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.01
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
|
|||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_name:
|
wandb_run_id:
|
||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
batch_size: 4
|
batch_size: 4
|
||||||
@@ -49,8 +49,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
evals_per_epoch: 4
|
eval_steps: 0.05
|
||||||
saves_per_epoch: 1
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 output_dir: ./lora-out

 sequence_len: 4096
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 4
@@ -54,8 +54,8 @@ xformers_attention:
 flash_attention: true

 warmup_steps: 10
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 0.05
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.0
@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 output_dir: ./qlora-out

 adapter: qlora
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 4
@@ -56,8 +56,8 @@ xformers_attention:
 flash_attention: true

 warmup_steps: 10
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 0.05
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.0
@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 output_dir: ./lora-out

 sequence_len: 4096
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 4
@@ -54,8 +54,8 @@ xformers_attention:
 flash_attention: true

 warmup_steps: 10
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 0.05
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.0
@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 output_dir: ./qlora-out

 adapter: qlora
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 4
@@ -56,8 +56,8 @@ xformers_attention:
 flash_attention: true

 warmup_steps: 10
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 0.05
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.0
@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 output_dir: ./lora-out

 sequence_len: 4096
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 4
@@ -54,8 +54,8 @@ xformers_attention:
 flash_attention: true

 warmup_steps: 10
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 0.05
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.0
@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 output_dir: ./qlora-out

 adapter: qlora
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 4
@@ -56,8 +56,8 @@ xformers_attention:
 flash_attention: true

 warmup_steps: 10
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 0.05
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.0
@@ -12,7 +12,7 @@ datasets:
   - path: teknium/GPT4-LLM-Cleaned
     type: alpaca:chat
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 adapter: lora
 lora_model_dir:
 sequence_len: 2048
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 output_dir: ./falcon-7b
 batch_size: 2
@@ -51,8 +51,8 @@ flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 40
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 5
+save_steps: 43
 debug:
 deepspeed:
 weight_decay: 0.0
@@ -18,7 +18,7 @@ datasets:
   - Chain-of-Thought/formatted_cot_data/gsm8k_train.json
     type: "alpaca:chat"
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 # enable QLoRA
 adapter: qlora
 lora_model_dir:
@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 output_dir: ./qlora-out

@@ -80,8 +80,8 @@ flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 10
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 5
+save_steps: 10
 debug:
 deepspeed:
 weight_decay: 0.000001
@@ -12,7 +12,7 @@ datasets:
   - path: teknium/GPT4-LLM-Cleaned
     type: alpaca:chat
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 adapter:
 lora_model_dir:
 sequence_len: 2048
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 output_dir: ./falcon-7b
 batch_size: 2
@@ -51,8 +51,8 @@ flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 40
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 5
+save_steps: 43
 debug:
 deepspeed:
 weight_decay: 0.0
@@ -7,7 +7,7 @@ datasets:
   - path: teknium/GPT4-LLM-Cleaned
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 adapter: qlora
 lora_model_dir:
 sequence_len: 2048
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 output_dir: ./qlora-out
 gradient_accumulation_steps: 2
@@ -46,8 +46,8 @@ flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 10
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 0.05
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.1
@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 output_dir: ./jeopardy-bot-7b
 gradient_accumulation_steps: 1
@@ -42,8 +42,8 @@ flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 20
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 110
+save_steps: 660
 debug:
 deepspeed:
 weight_decay: 0.1
@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path: last_run_prepared
-val_set_size: 0.05
+val_set_size: 0.01
 output_dir: ./out

 sequence_len: 4096
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 1
@@ -58,9 +58,9 @@ flash_attn_fuse_qkv: false
 flash_attn_fuse_mlp: true

 warmup_steps: 100
-evals_per_epoch: 4
+eval_steps: 0.05
 eval_table_size:
-saves_per_epoch: 1
+save_steps:
 debug:
 deepspeed: #deepspeed/zero2.json # multi-gpu only
 weight_decay: 0.1
@@ -15,7 +15,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 adapter: lora
 lora_model_dir:
 sequence_len: 4096
@@ -32,7 +32,7 @@ lora_target_linear:
 lora_fan_in_fan_out:
 wandb_project:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 output_dir: ./model-out
 gradient_accumulation_steps: 1
@@ -62,8 +62,8 @@ flash_attention:
 sdp_attention:
 flash_optimum:
 warmup_steps: 100
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps:
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.1
@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 output_dir: ./lora-out

 sequence_len: 4096
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 4
@@ -54,10 +54,10 @@ xformers_attention:
 flash_attention: true

 warmup_steps: 10
-evals_per_epoch: 4
+eval_steps: 0.05
 eval_table_size:
 eval_table_max_new_tokens: 128
-saves_per_epoch: 1
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.0
@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 output_dir: ./qlora-out

 adapter: qlora
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 4
@@ -56,9 +56,9 @@ xformers_attention:
 flash_attention: true

 warmup_steps: 10
-evals_per_epoch: 4
+eval_steps: 0.05
 eval_table_size:
-saves_per_epoch: 1
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.0
@@ -11,7 +11,7 @@ datasets:
   - path: teknium/GPT4-LLM-Cleaned
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 output_dir: ./relora-out

 adapter: qlora
@@ -35,7 +35,7 @@ relora_cpu_offload: false
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 4
@@ -60,8 +60,8 @@ xformers_attention:
 flash_attention: true

 warmup_steps: 10
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 0.05
+save_steps: 50
 debug:
 deepspeed:
 weight_decay: 0.0
@@ -1,4 +1,5 @@
-base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
+base_model: PY007/TinyLlama-1.1B-step-50K-105b

 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 is_llama_derived_model: true
@@ -11,12 +12,11 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 output_dir: ./lora-out

 sequence_len: 4096
 sample_packing: true
-pad_to_sequence_len: true

 adapter: lora
 lora_model_dir:
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 4
@@ -54,11 +54,15 @@ xformers_attention:
 flash_attention: true

 warmup_steps: 10
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 0.05
+eval_table_size:
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.0
 fsdp:
 fsdp_config:
 special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
@@ -1,61 +0,0 @@
-base_model: state-spaces/mamba-2.8b
-model_type: MambaLMHeadModel
-tokenizer_type: AutoTokenizer
-tokenizer_config: EleutherAI/gpt-neox-20b
-
-load_in_8bit: false
-load_in_4bit: false
-strict: false
-
-datasets:
-  - path: mhenrichsen/alpaca_2k_test
-    type: alpaca
-dataset_prepared_path:
-val_set_size: 0.0
-output_dir: ./out
-
-sequence_len: 2048
-sample_packing: false
-pad_to_sequence_len: false
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 4
-micro_batch_size: 1
-num_epochs: 2
-optimizer: paged_adamw_8bit
-lr_scheduler: cosine
-learning_rate: 5e-5
-
-train_on_inputs: false
-group_by_length: true
-
-bf16: true
-fp16: false
-tf32: true
-
-gradient_checkpointing: false
-early_stopping_patience:
-resume_from_checkpoint:
-local_rank:
-logging_steps: 1
-xformers_attention:
-flash_attention:
-
-warmup_steps: 10
-evals_per_epoch: 4
-eval_table_size:
-eval_table_max_new_tokens: 128
-saves_per_epoch: 1
-debug:
-deepspeed:
-weight_decay: 0.0
-fsdp:
-fsdp_config:
-special_tokens:
-tokens:
-save_safetensors: False
@@ -11,18 +11,17 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 output_dir: ./out

 sequence_len: 8192
 sample_packing: true
 pad_to_sequence_len: true
-eval_sample_packing: false

 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 4
@@ -47,10 +46,10 @@ xformers_attention:
 flash_attention: true

 warmup_steps: 10
-evals_per_epoch: 4
+eval_steps: 0.05
 eval_table_size:
 eval_table_max_new_tokens: 128
-saves_per_epoch: 1
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.0
@@ -1,91 +0,0 @@
-base_model: mistralai/Mixtral-8x7B-v0.1
-model_type: AutoModelForCausalLM
-tokenizer_type: LlamaTokenizer
-trust_remote_code: true
-
-load_in_8bit: false
-load_in_4bit: true
-strict: false
-
-datasets:
-  - path: tatsu-lab/alpaca
-    type: alpaca
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.0
-output_dir: ./qlora-out
-
-## You can optionally freeze the entire model and unfreeze a subset of parameters
-unfrozen_parameters:
-#  - lm_head.*
-#  - model.embed_tokens.*
-#  - model.layers.2[0-9]+.block_sparse_moe.gate.*
-#  - model.layers.2[0-9]+.block_sparse_moe.experts.*
-#  - model.layers.3[0-9]+.block_sparse_moe.gate.*
-#  - model.layers.3[0-9]+.block_sparse_moe.experts.*
-
-model_config:
-  output_router_logits: true
-
-adapter: qlora
-lora_model_dir:
-
-sequence_len: 4096
-sample_packing: true
-pad_to_sequence_len: true
-
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_linear: true
-lora_fan_in_fan_out:
-#lora_target_modules:
-#  - gate
-#  - q_proj
-#  - k_proj
-#  - v_proj
-#  - o_proj
-#  - w1
-#  - w2
-#  - w3
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 2
-micro_batch_size: 1
-num_epochs: 1
-optimizer: adamw_bnb_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-train_on_inputs: false
-group_by_length: false
-bf16: true
-fp16: false
-tf32: false
-
-gradient_checkpointing: true
-early_stopping_patience:
-resume_from_checkpoint:
-local_rank:
-logging_steps: 1
-xformers_attention:
-flash_attention: true
-
-loss_watchdog_threshold: 5.0
-loss_watchdog_patience: 3
-
-warmup_steps: 10
-evals_per_epoch: 4
-eval_table_size:
-eval_table_max_new_tokens: 128
-saves_per_epoch: 1
-debug:
-deepspeed: deepspeed/zero2.json
-weight_decay: 0.0
-fsdp:
-fsdp_config:
-special_tokens:
@@ -11,7 +11,7 @@ datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path: last_run_prepared
-val_set_size: 0.1
+val_set_size: 0.01
 output_dir: ./qlora-out

 adapter: qlora
@@ -38,7 +38,7 @@ lora_target_modules:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 4
@@ -62,14 +62,11 @@ logging_steps: 1
 xformers_attention:
 flash_attention: true

-loss_watchdog_threshold: 5.0
-loss_watchdog_patience: 3
-
 warmup_steps: 10
-evals_per_epoch: 4
+eval_steps: 0.05
 eval_table_size:
 eval_table_max_new_tokens: 128
-saves_per_epoch: 1
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.0
@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
 wandb_project: mpt-alpaca-7b
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 output_dir: ./mpt-alpaca-7b
 gradient_accumulation_steps: 1
@@ -44,8 +44,8 @@ flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 20
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 110
+save_steps: 660
 debug:
 deepspeed:
 weight_decay: 0.0001
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 output_dir: ./openllama-out
 gradient_accumulation_steps: 1
@@ -49,8 +49,8 @@ flash_attention: true
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 20
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 0.05
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.1
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 output_dir: ./lora-out
 gradient_accumulation_steps: 1
@@ -54,8 +54,8 @@ flash_attention: true
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 20
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 0.05
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.1
@@ -9,7 +9,7 @@ datasets:
   - path: teknium/GPT4-LLM-Cleaned
     type: alpaca
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 adapter: qlora
 lora_model_dir:
 sequence_len: 1024
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 output_dir: ./qlora-out
 gradient_accumulation_steps: 1
@@ -48,8 +48,8 @@ flash_attention: true
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 20
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 0.05
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.1
@@ -1,5 +1,5 @@
 base_model: microsoft/phi-1_5
-model_type: PhiForCausalLM
+model_type: MixFormerSequentialForCausalLM
 tokenizer_type: AutoTokenizer
 is_llama_derived_model: false
 trust_remote_code: true
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 1
@@ -59,8 +59,8 @@ xformers_attention:
 flash_attention:

 warmup_steps: 100
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 0.05
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.1
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:

 gradient_accumulation_steps: 1
@@ -59,8 +59,8 @@ xformers_attention:
 flash_attention:

 warmup_steps: 100
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 0.05
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0.1
@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 output_dir: ./pythia-12b
 gradient_accumulation_steps: 1
@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 output_dir: ./lora-alpaca-pythia
 gradient_accumulation_steps: 1
@@ -33,5 +33,5 @@ early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
 weight_decay: 0.1
-evals_per_epoch: 4
+eval_steps: 0.05
 logging_steps: 1
@@ -1,68 +0,0 @@
-base_model: Qwen/Qwen-7B
-model_type: AutoModelForCausalLM
-tokenizer_type: AutoTokenizer
-
-is_qwen_derived_model: true
-trust_remote_code: true
-
-load_in_8bit: true
-load_in_4bit: false
-strict: false
-
-datasets:
-  - path: mhenrichsen/alpaca_2k_test
-    type: alpaca
-dataset_prepared_path:
-val_set_size: 0.05
-output_dir: ./lora-out
-
-sequence_len: 2048 # supports up to 8192
-sample_packing: false
-pad_to_sequence_len:
-
-adapter: lora
-lora_model_dir:
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_linear: true
-lora_fan_in_fan_out:
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 4
-micro_batch_size: 2
-num_epochs: 4
-optimizer: adamw_bnb_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-train_on_inputs: false
-group_by_length: false
-bf16: true
-fp16: false
-tf32: false
-
-gradient_checkpointing: false
-early_stopping_patience:
-resume_from_checkpoint:
-local_rank:
-logging_steps: 1
-xformers_attention:
-flash_attention:
-
-warmup_steps: 10
-evals_per_epoch: 4
-eval_table_size:
-eval_table_max_new_tokens: 128
-saves_per_epoch: 1
-debug:
-deepspeed:
-weight_decay: 0.0
-fsdp:
-fsdp_config:
-special_tokens:
@@ -1,68 +0,0 @@
-base_model: Qwen/Qwen-7B
-model_type: AutoModelForCausalLM
-tokenizer_type: AutoTokenizer
-
-is_qwen_derived_model: true
-trust_remote_code: true
-
-load_in_8bit: false
-load_in_4bit: true
-strict: false
-
-datasets:
-  - path: mhenrichsen/alpaca_2k_test
-    type: alpaca
-dataset_prepared_path:
-val_set_size: 0.05
-output_dir: ./lora-out
-
-sequence_len: 2048 # supports up to 8192
-sample_packing: false
-pad_to_sequence_len:
-
-adapter: qlora
-lora_model_dir:
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_linear: true
-lora_fan_in_fan_out:
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 4
-micro_batch_size: 2
-num_epochs: 4
-optimizer: adamw_bnb_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-train_on_inputs: false
-group_by_length: false
-bf16: true
-fp16: false
-tf32: false
-
-gradient_checkpointing: false
-early_stopping_patience:
-resume_from_checkpoint:
-local_rank:
-logging_steps: 1
-xformers_attention:
-flash_attention:
-
-warmup_steps: 10
-evals_per_epoch: 4
-eval_table_size:
-eval_table_max_new_tokens: 128
-saves_per_epoch: 1
-debug:
-deepspeed:
-weight_decay: 0.0
-fsdp:
-fsdp_config:
-special_tokens:
@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
 wandb_project: redpajama-alpaca-3b
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 output_dir: ./redpajama-alpaca-3b
 batch_size: 4
@@ -45,8 +45,8 @@ flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 20
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 110
+save_steps: 660
 debug:
 deepspeed:
 weight_decay: 0.0001
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
 wandb_project: lora-replit
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 output_dir: ./lora-replit
 batch_size: 8
@@ -45,8 +45,8 @@ flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 20
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 50
+save_steps:
 debug:
 deepspeed:
 weight_decay: 0
@@ -1,17 +0,0 @@
-# Overview
-
-This is a simple example of how to finetune TinyLlama 1.1B using either LoRA or QLoRA:
-
-LoRA:
-
-```
-accelerate launch -m axolotl.cli.train examples/tiny-llama/lora.yml
-```
-
-QLoRA:
-
-```
-accelerate launch -m axolotl.cli.train examples/tiny-llama/qlora.yml
-```
-
-Both take about 10 minutes to complete on a 4090.
@@ -1,58 +0,0 @@
-base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
-
-model_type: LlamaForCausalLM
-tokenizer_type: LlamaTokenizer
-is_llama_derived_model: true
-
-load_in_8bit: false
-load_in_4bit: false
-strict: false
-
-max_steps: 200
-pretraining_dataset:
-  path: c4
-  name: en
-dataset_prepared_path:
-val_set_size: 0.0
-output_dir: ./model-out
-
-sequence_len: 2048
-sample_packing: true
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 4
-micro_batch_size: 2
-num_epochs: 4
-optimizer: adamw_bnb_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-train_on_inputs: false
-group_by_length: false
-bf16: true
-fp16: false
-tf32: false
-
-gradient_checkpointing: true
-early_stopping_patience:
-resume_from_checkpoint:
-local_rank:
-logging_steps: 1
-xformers_attention:
-flash_attention: true
-
-warmup_steps: 10
-evals_per_epoch:
-eval_table_size:
-saves_per_epoch: 1
-debug:
-deepspeed:
-weight_decay: 0.0
-fsdp:
-fsdp_config:
-special_tokens:
@@ -1,66 +0,0 @@
-base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
-model_type: LlamaForCausalLM
-tokenizer_type: LlamaTokenizer
-is_llama_derived_model: true
-
-load_in_8bit: false
-load_in_4bit: true
-strict: false
-
-datasets:
-  - path: mhenrichsen/alpaca_2k_test
-    type: alpaca
-dataset_prepared_path:
-val_set_size: 0.05
-output_dir: ./qlora-out
-
-adapter: qlora
-lora_model_dir:
-
-sequence_len: 4096
-sample_packing: true
-pad_to_sequence_len: true
-
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
-lora_target_linear: true
-lora_fan_in_fan_out:
-
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
-
-gradient_accumulation_steps: 4
-micro_batch_size: 2
-num_epochs: 4
-optimizer: paged_adamw_32bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-train_on_inputs: false
-group_by_length: false
-bf16: true
-fp16: false
-tf32: false
-
-gradient_checkpointing: true
-early_stopping_patience:
-resume_from_checkpoint:
-local_rank:
-logging_steps: 1
-xformers_attention:
-flash_attention: true
-
-warmup_steps: 10
-evals_per_epoch: 4
-saves_per_epoch: 1
-debug:
-deepspeed:
-weight_decay: 0.0
-fsdp:
-fsdp_config:
-special_tokens:
@@ -16,7 +16,7 @@ datasets:
   - openassistant_best_replies_train.jsonl
     type: "completion"
 dataset_prepared_path:
-val_set_size: 0.05
+val_set_size: 0.01
 # enable QLoRA
 adapter: qlora
 lora_model_dir:
@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_name:
+wandb_run_id:
 wandb_log_model:
 output_dir: ./qlora-out

@@ -78,8 +78,8 @@ flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 10
-evals_per_epoch: 4
-saves_per_epoch: 1
+eval_steps: 50
+save_steps: 50
 debug:
 deepspeed:
 weight_decay: 0.0
@@ -1,5 +0,0 @@
-# Overview
-
-This is an example of a Yi-34B-Chat configuration. It demonstrates that it is possible to finetune a 34B model on a GPU with 24GB of VRAM.
-
-Tested on an RTX 4090 with `python -m axolotl.cli.train examples/mistral/qlora.yml`, a single epoch of finetuning on the alpaca dataset using qlora runs in 47 mins, using 97% of available memory.
@@ -1,76 +0,0 @@
-base_model: 01-ai/Yi-34B-Chat
-model_type: LlamaForCausalLM
-tokenizer_type: LlamaTokenizer
-is_mistral_derived_model: false
-is_llama_derived_model: true
-load_in_8bit: false
-load_in_4bit: true
-strict: false
-sequence_len: 1024
-bf16: true
-fp16: false
-tf32: false
-flash_attention: true
-special_tokens:
-  bos_token: "<|startoftext|>"
-  eos_token: "<|endoftext|>"
-  unk_token: "<unk>"
-
-# Data
-datasets:
-  - path: mhenrichsen/alpaca_2k_test
-    type: alpaca
-warmup_steps: 10
-
-# Iterations
-num_epochs: 1
-
-# Evaluation
-val_set_size: 0.1
-evals_per_epoch: 5
-eval_table_size:
-eval_table_max_new_tokens: 128
-eval_sample_packing: false
-eval_batch_size: 1
-
-# LoRA
-output_dir: ./qlora-out
-adapter: qlora
-lora_model_dir:
-lora_r: 32
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_linear: true
-lora_fan_in_fan_out:
-lora_target_modules:
-
-# Sampling
-sample_packing: false
-pad_to_sequence_len: false
-
-# Batching
-gradient_accumulation_steps: 4
-micro_batch_size: 1
-gradient_checkpointing: true
-
-# wandb
-wandb_project:
-
-# Optimizer
-optimizer: paged_adamw_8bit
-lr_scheduler: cosine
-learning_rate: 0.0002
-
-# Misc
-train_on_inputs: false
-group_by_length: false
-early_stopping_patience:
-resume_from_checkpoint:
-local_rank:
-logging_steps: 1
-xformers_attention:
-debug:
-deepspeed:
-weight_decay: 0
-fsdp:
-fsdp_config:
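The deleted Yi example above leans on 4-bit (QLoRA) loading, which is what lets a 34B base model fit on a 24GB card. A minimal sketch of what `load_in_4bit: true` corresponds to outside of axolotl, assuming the Hugging Face transformers + bitsandbytes stack; the specific quantization settings below are illustrative assumptions, not values taken from this diff:

```
# Illustrative sketch only: load a base model with 4-bit (NF4) weights.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                    # keep base weights quantized to 4-bit
    bnb_4bit_quant_type="nf4",            # assumed quant type, common for QLoRA
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "01-ai/Yi-34B-Chat",
    quantization_config=quant_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-34B-Chat")
```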
1 gitbook/README.md Normal file
@@ -0,0 +1 @@
+# Page
4 gitbook/SUMMARY.md Normal file
@@ -0,0 +1,4 @@
+# Table of contents
+
+* [Page](README.md)
+* [Small dev details](small-dev-details.md)
3 gitbook/small-dev-details.md Normal file
@@ -0,0 +1,3 @@
+# Small dev details
+
+/
@@ -1,22 +1,23 @@
+--extra-index-url https://download.pytorch.org/whl/cu118
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
-auto-gptq==0.5.1
+torch==2.0.1
+auto-gptq
 packaging
-peft==0.6.0
+peft @ git+https://github.com/huggingface/peft.git
-transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
+transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697
-tokenizers==0.15.0
 bitsandbytes>=0.41.1
-accelerate==0.24.1
+accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
 deepspeed
 addict
 fire
 PyYAML>=6.0
-datasets>=2.15.0
+datasets
-flash-attn==2.3.3
+flash-attn>=2.3.0
 sentencepiece
 wandb
 einops
-xformers==0.0.22
+xformers>=0.0.22
-optimum==1.13.2
+optimum
 hf_transfer
 colorama
 numba
@@ -29,13 +30,5 @@ scipy
 scikit-learn==1.2.2
 pynvml
 art
-fschat==0.2.34
+fschat==0.2.29
-gradio==3.50.2
+tensor_parallel
-tensorboard
-
-# remote filesystems
-s3fs
-gcsfs
-# adlfs
-
-trl @ git+https://github.com/huggingface/trl.git@main
20 setup.py
@@ -1,7 +1,5 @@
 """setup.py for axolotl"""

-from importlib.metadata import PackageNotFoundError, version
-
 from setuptools import find_packages, setup
@@ -24,13 +22,12 @@ def parse_requirements():
             # Handle standard packages
             _install_requires.append(line)

-    try:
-        torch_version = version("torch")
-        if torch_version.startswith("2.1.1"):
-            _install_requires.pop(_install_requires.index("xformers==0.0.22"))
-            _install_requires.append("xformers==0.0.23")
-    except PackageNotFoundError:
-        pass
+    # TODO(wing) remove once xformers release supports torch 2.1.0
+    if "torch==2.1.0" in _install_requires:
+        _install_requires.pop(_install_requires.index("xformers>=0.0.22"))
+        _install_requires.append(
+            "xformers @ git+https://github.com/facebookresearch/xformers.git@main"
+        )

     return _install_requires, _dependency_links
@@ -49,13 +46,10 @@ setup(
     dependency_links=dependency_links,
     extras_require={
         "flash-attn": [
-            "flash-attn==2.3.3",
+            "flash-attn>=2.3.0",
         ],
         "deepspeed": [
            "deepspeed",
         ],
-        "mamba-ssm": [
-            "mamba-ssm==1.0.1",
-        ],
     },
 )
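Both sides of the setup.py hunk above feed off the requirements.txt shown earlier via a `parse_requirements()` helper that splits pinned packages from extra index URLs. A rough sketch under assumptions: only the names `parse_requirements`, `_install_requires`, `_dependency_links`, and the "Handle standard packages" branch appear in this diff, so the exact parsing rules here are illustrative:

```
# Hedged sketch of a requirements parser like the one setup.py uses:
# plain lines become install_requires entries, while "--extra-index-url"
# lines are collected separately as dependency links.
def parse_requirements():
    _install_requires = []
    _dependency_links = []
    with open("./requirements.txt", encoding="utf-8") as requirements_file:
        lines = [r.strip() for r in requirements_file.readlines()]
        for line in lines:
            if line.startswith("--extra-index-url"):
                # Keep only the URL part of the flag
                _, url = line.split()
                _dependency_links.append(url)
            elif line and not line.startswith("#"):
                # Handle standard packages
                _install_requires.append(line)
    return _install_requires, _dependency_links
```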
@@ -2,25 +2,21 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Thread
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import gradio as gr
|
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
# add src to the pythonpath so we don't need to pip install this
|
# add src to the pythonpath so we don't need to pip install this
|
||||||
from accelerate.commands.config import config_args
|
from accelerate.commands.config import config_args
|
||||||
from art import text2art
|
from art import text2art
|
||||||
from datasets import concatenate_datasets, load_dataset
|
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
from transformers import GenerationConfig, TextStreamer
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
@@ -31,7 +27,6 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
from axolotl.utils.trainer import prepare_optim_env
|
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
@@ -49,7 +44,7 @@ def print_axolotl_text_art(suffix=None):
|
|||||||
ascii_text = " axolotl"
|
ascii_text = " axolotl"
|
||||||
if suffix:
|
if suffix:
|
||||||
ascii_text += f" x {suffix}"
|
ascii_text += f" x {suffix}"
|
||||||
ascii_art = text2art(ascii_text, font=font)
|
ascii_art = text2art(" axolotl", font=font)
|
||||||
|
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
print(ascii_art)
|
print(ascii_art)
|
||||||
@@ -73,15 +68,14 @@ def do_merge_lora(
|
|||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
LOG.info("running merge of LoRA with base model")
|
LOG.info("running merge of LoRA with base model")
|
||||||
model = model.merge_and_unload(progressbar=True)
|
model = model.merge_and_unload()
|
||||||
model.to(dtype=cfg.torch_dtype)
|
model.to(dtype=torch.float16)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||||
model.save_pretrained(
|
model.save_pretrained(
|
||||||
str(Path(cfg.output_dir) / "merged"),
|
str(Path(cfg.output_dir) / "merged"),
|
||||||
safe_serialization=safe_serialization,
|
safe_serialization=safe_serialization,
|
||||||
progressbar=True,
|
|
||||||
)
|
)
|
||||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
|
|
||||||
@@ -106,7 +100,15 @@ def do_inference(
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
|
|
||||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
if cfg.landmark_attention:
|
||||||
|
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||||
|
|
||||||
|
set_model_mem_id(model, tokenizer)
|
||||||
|
model.set_mem_cache_args(
|
||||||
|
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||||
|
)
|
||||||
|
|
||||||
|
model = model.to(cfg.device)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
@@ -151,83 +153,6 @@ def do_inference(
|
|||||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||||
|
|
||||||
|
|
||||||
def do_inference_gradio(
|
|
||||||
*,
|
|
||||||
cfg: DictDefault,
|
|
||||||
cli_args: TrainerCliArgs,
|
|
||||||
):
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
|
||||||
prompter = cli_args.prompter
|
|
||||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
|
||||||
|
|
||||||
for token, symbol in default_tokens.items():
|
|
||||||
# If the token isn't already specified in the config, add it
|
|
||||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
|
||||||
tokenizer.add_special_tokens({token: symbol})
|
|
||||||
|
|
||||||
prompter_module = None
|
|
||||||
if prompter:
|
|
||||||
prompter_module = getattr(
|
|
||||||
importlib.import_module("axolotl.prompters"), prompter
|
|
||||||
)
|
|
||||||
|
|
||||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
|
||||||
|
|
||||||
def generate(instruction):
|
|
||||||
if not instruction:
|
|
||||||
return
|
|
||||||
if prompter_module:
|
|
||||||
# pylint: disable=stop-iteration-return
|
|
||||||
prompt: str = next(
|
|
||||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
prompt = instruction.strip()
|
|
||||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
generation_config = GenerationConfig(
|
|
||||||
repetition_penalty=1.1,
|
|
||||||
max_new_tokens=1024,
|
|
||||||
temperature=0.9,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=40,
|
|
||||||
bos_token_id=tokenizer.bos_token_id,
|
|
||||||
eos_token_id=tokenizer.eos_token_id,
|
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
|
||||||
do_sample=True,
|
|
||||||
use_cache=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
output_attentions=False,
|
|
||||||
output_hidden_states=False,
|
|
||||||
output_scores=False,
|
|
||||||
)
|
|
||||||
streamer = TextIteratorStreamer(tokenizer)
|
|
||||||
generation_kwargs = {
|
|
||||||
"inputs": batch["input_ids"].to(cfg.device),
|
|
||||||
"generation_config": generation_config,
|
|
||||||
"streamer": streamer,
|
|
||||||
}
|
|
||||||
|
|
||||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
|
||||||
thread.start()
|
|
||||||
|
|
||||||
all_text = ""
|
|
||||||
|
|
||||||
for new_text in streamer:
|
|
||||||
all_text += new_text
|
|
||||||
yield all_text
|
|
||||||
|
|
||||||
demo = gr.Interface(
|
|
||||||
fn=generate,
|
|
||||||
inputs="textbox",
|
|
||||||
outputs="text",
|
|
||||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
|
||||||
)
|
|
||||||
demo.queue().launch(show_api=False, share=True)
|
|
||||||
|
|
||||||
|
|
||||||
def choose_config(path: Path):
|
def choose_config(path: Path):
|
||||||
yaml_files = list(path.glob("*.yml"))
|
yaml_files = list(path.glob("*.yml"))
|
||||||
|
|
||||||
@@ -284,8 +209,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
prepare_optim_env(cfg)
|
|
||||||
|
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
@@ -328,94 +251,6 @@ def load_datasets(
    )


def load_rl_datasets(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,  # pylint: disable=unused-argument
) -> TrainDatasetMeta:
    train_datasets: List[Any] = []
    for i, ds_cfg in enumerate(cfg.datasets):
        train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"]))
    # eval_dataset = load_dataset(
    #     cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"]
    # )
    eval_dataset = None

    def argilla_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
        if "system" in sample and sample["system"]:
            sample["prompt"] = (
                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
            )
        else:
            sample["prompt"] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
        sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
        sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
        return sample

    def intel_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
        if "system" in sample and sample["system"]:
            sample["prompt"] = (
                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
            )
        else:
            sample["prompt"] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
        return sample

    def apply_chatml(sample):  # pylint: disable=possibly-unused-variable
        if "system" in sample and sample["system"]:
            sample["prompt"] = (
                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
            )
        else:
            sample["prompt"] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
        return sample

    def ultra_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
        if "system" in sample and sample["system"]:
            sample["prompt"] = (
                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
            )
        else:
            sample["prompt"] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
        sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
        sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
        return sample

    for i, data_set in enumerate(train_datasets):
        _type = cfg.datasets[i]["type"]
        ds_type_fn = locals()[_type]
        train_datasets[i] = data_set.map(ds_type_fn)
    train_dataset = concatenate_datasets(train_datasets)

    # eval_dataset = eval_dataset.map(intel_apply_chatml)

    total_num_steps = int(
        math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
    )

    return TrainDatasetMeta(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        total_num_steps=total_num_steps,
    )
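Note: the `*_apply_chatml` helpers above only reshape each preference record into ChatML `prompt`/`chosen`/`rejected` strings for the RL path. As a quick sanity check, here is a standalone sketch of the same mapping; the input record is hypothetical, shaped like an Argilla-style preference row, and is not data from the repository.

# Standalone sketch of the transformation performed by argilla_apply_chatml above.
sample = {
    "system": "You are a helpful assistant.",
    "instruction": "Name a prime number.",
    "chosen_response": "7 is prime.",
    "rejected_response": "9 is prime.",
}

sample["prompt"] = (
    f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
    f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"

print(sample["prompt"])
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Name a prime number.<|im_end|>
# <|im_start|>assistant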
def check_accelerate_default_config():
    if Path(config_args.default_yaml_config_file).exists():
        LOG.warning(
@@ -6,16 +6,11 @@ from pathlib import Path
import fire
import transformers

from axolotl.cli import (
    do_inference,
    do_inference_gradio,
    load_cfg,
    print_axolotl_text_art,
)
from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs


def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
def do_cli(config: Path = Path("examples/"), **kwargs):
    # pylint: disable=duplicate-code
    print_axolotl_text_art()
    parsed_cfg = load_cfg(config, **kwargs)
@@ -26,10 +21,7 @@ def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
    )
    parsed_cli_args.inference = True

    if gradio:
        do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
    else:
        do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
    do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)


if __name__ == "__main__":
@@ -18,22 +18,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
        return_remaining_strings=True
    )
    parsed_cli_args.merge_lora = True

    parsed_cfg = load_cfg(
        config,
        merge_lora=True,
        load_in_8bit=False,
        load_in_4bit=False,
        flash_attention=False,
        **kwargs,
    )
    parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)

    if not parsed_cfg.lora_model_dir and parsed_cfg.output_dir:
        parsed_cfg.lora_model_dir = parsed_cfg.output_dir
    if not Path(parsed_cfg.lora_model_dir).exists():
        raise ValueError(
            f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
        )

    do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
@@ -12,7 +12,6 @@ from axolotl.cli import (
    check_user_token,
    load_cfg,
    load_datasets,
    load_rl_datasets,
    print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
@@ -23,18 +22,15 @@ LOG = logging.getLogger("axolotl.cli.train")

def do_cli(config: Path = Path("examples/"), **kwargs):
    # pylint: disable=duplicate-code
    parsed_cfg = load_cfg(config, **kwargs)
    print_axolotl_text_art()
    parsed_cfg = load_cfg(config, **kwargs)
    check_accelerate_default_config()
    check_user_token()
    parser = transformers.HfArgumentParser((TrainerCliArgs))
    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )
    if parsed_cfg.rl:
        dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
    else:
        dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
    dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
    train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
@@ -6,37 +6,35 @@ import abc
import importlib
import logging
import math
import os
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
from functools import wraps
from functools import partial
from pathlib import Path
from typing import Optional
from typing import Optional, Union

import tensor_parallel as tp
import torch
import transformers
from datasets import Dataset
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_utils import seed_worker
from transformers.trainer_pt_utils import SequentialDistributedSampler
from trl import DPOTrainer

from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
    EvalFirstStepCallback,
    GPUStatsCallback,
    LossWatchDogCallback,
    SaveAxolotlConfigtoWandBCallback,
    SaveBetterTransformerModelCallback,
    bench_eval_callback_factory,
    log_prediction_callback_factory,
)
from axolotl.utils.collators import (
    BatchSamplerDataCollatorForSeq2Seq,
    MambaDataCollator,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.distributed import is_distributed
from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

try:
@@ -53,19 +51,10 @@ class AxolotlTrainingArguments(TrainingArguments):
    Extend the base TrainingArguments for axolotl helpers
    """

    model_type: Optional[str] = field(
        default=None, metadata={"help": "HF model configuration model_type."}
    )
    lr_quadratic_warmup: bool = field(
        default=False,
        metadata={"help": "Use quadratic warmup for cosine scheduling."},
    )
    pretraining: bool = field(
        default=False,
        metadata={
            "help": "Indicates to trainer whether we are doing continued pretraining."
        },
    )
    sample_packing: bool = field(
        default=False,
        metadata={"help": "Use sample packing for efficient training."},
@@ -115,9 +104,8 @@ class AxolotlTrainingArguments(TrainingArguments):
    bench_source_max_len: int = field(
        default=2048, metadata={"help": "Maximum source sequence length for bench."}
    )
    dataloader_prefetch_factor: Optional[int] = field(
        default=None,
        metadata={"help": "prefetch_factor argument to the dataloader"},
    )
    tensor_parallel: bool = field(
        default=False, metadata={"help": "Use tensor parallelism to train"}
    )
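Aside: on the tensor-par side, `tensor_parallel` is just another dataclass field on the extended TrainingArguments, so it is set like any other HF training argument. A rough, illustrative instantiation follows; the values are placeholders and only fields shown in this hunk (plus standard transformers.TrainingArguments fields) are assumed to exist.

# Illustrative only; not part of the diff.
args = AxolotlTrainingArguments(
    output_dir="./out",
    per_device_train_batch_size=2,
    sample_packing=True,
    tensor_parallel=True,  # tensor-par branch field; the NanoCode01 side does not define it
)
assert args.tensor_parallel is True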
@@ -127,7 +115,6 @@ class AxolotlTrainer(Trainer):
    """

    args = None  # type: AxolotlTrainingArguments
    tag_names = ["axolotl"]

    def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
        self.num_epochs = num_epochs
@@ -163,102 +150,70 @@ class AxolotlTrainer(Trainer):
        return self.lr_scheduler

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.args.sample_packing and not self.args.pretraining:
            return MultipackBatchSampler(
                RandomSampler(self.train_dataset),
                self.args.train_batch_size,
                drop_last=True,
                batch_max_len=self._train_batch_size * self.args.max_seq_length,
                lengths=(
                    self.train_dataset.data.column("position_ids")
                    .to_pandas()
                    .apply(lambda x: x[-1] + 1)
                    .values
                ),
                packing_efficiency_estimate=self.args.sample_packing_efficiency,
            )
        if self.args.world_size > 1 and self.args.sample_packing:
            return DistributedSampler(
                self.train_dataset,
                num_replicas=self.args.world_size,
                rank=self.args.process_index,
                seed=self.args.seed,
            )
        return super()._get_train_sampler()

    def _get_eval_sampler(
        self, eval_dataset: Dataset
    ) -> Optional[torch.utils.data.Sampler]:
        if self.args.sample_packing and self.args.eval_sample_packing is not False:
            return MultipackBatchSampler(
                SequentialSampler(eval_dataset),
                self.args.per_device_eval_batch_size,
                drop_last=True,
                batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
                lengths=(
                    eval_dataset.data.column("position_ids")
                    .to_pandas()
                    .apply(lambda x: x[-1] + 1)
                    .values
                ),
                packing_efficiency_estimate=self.args.sample_packing_efficiency,
            )
        if (
            self.args.world_size > 1
            and self.args.sample_packing
            and self.args.eval_sample_packing is not False
        ):
            return SequentialDistributedSampler(
                eval_dataset,
                num_replicas=self.args.world_size,
                rank=self.args.process_index,
                batch_size=self.args.per_device_eval_batch_size,
            )
        return super()._get_eval_sampler(eval_dataset)

    def get_train_dataloader(self) -> DataLoader:
    def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
        if self.args.sample_packing and not self.args.pretraining:
        if self.args.sample_packing:
            train_dataset = self.train_dataset
            train_dataset = train_dataset.remove_columns(["length"])
            data_collator = self.data_collator
            dataloader_params = {
                "batch_size": self._train_batch_size,
                "collate_fn": data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
            }
            if self.args.dataloader_prefetch_factor:
                dataloader_params[
                    "prefetch_factor"
                ] = self.args.dataloader_prefetch_factor

            sampler = self._get_train_sampler()
            if isinstance(sampler, BatchSampler):
                dataloader_params["batch_sampler"] = sampler
                del dataloader_params["batch_size"]
            else:
                dataloader_params["sampler"] = sampler
                dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker

            self.accelerator.even_batches = False
            return self.accelerator.prepare_data_loader(
                DataLoader(train_dataset, **dataloader_params)
            )
            train_sampler = self._get_train_sampler()
            return self.accelerator.prepare(
                MultipackDistributedDataloader(
                    self.train_dataset,
                    batch_size=self._train_batch_size,
                    seq_max_length=self.args.max_seq_length,
                    collate_fn=self.data_collator,
                    sampler=train_sampler,
                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
                    sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
                    device_count=int(os.environ.get("WORLD_SIZE", 1)),
                    num_epochs=self.num_epochs,
                )
            )
        return super().get_train_dataloader()

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
    def get_eval_dataloader(
        self, eval_dataset: Optional[Dataset] = None
    ) -> Union[DataLoader, MultipackDistributedDataloader]:
        if self.args.sample_packing and self.args.eval_sample_packing is not False:
            eval_dataset = (
                eval_dataset if eval_dataset is not None else self.eval_dataset
            )

            eval_sampler = self._get_eval_sampler(eval_dataset)
            eval_dataset = eval_dataset.remove_columns(["length"])
            data_collator = self.data_collator
            dataloader_params = {
                "batch_size": self.args.eval_batch_size,
                "collate_fn": data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
            }
            if self.args.dataloader_prefetch_factor:
                dataloader_params[
                    "prefetch_factor"
                ] = self.args.dataloader_prefetch_factor

            if isinstance(eval_sampler, BatchSampler):
                dataloader_params["batch_sampler"] = eval_sampler
                del dataloader_params["batch_size"]
            else:
                dataloader_params["sampler"] = eval_sampler
                dataloader_params["drop_last"] = self.args.dataloader_drop_last

            self.accelerator.even_batches = False
            return self.accelerator.prepare_data_loader(
                DataLoader(eval_dataset, **dataloader_params)
            )
            return self.accelerator.prepare(
                MultipackDistributedDataloader(
                    eval_dataset,
                    batch_size=self.args.eval_batch_size,
                    seq_max_length=self.args.max_seq_length,
                    collate_fn=self.data_collator,
                    sampler=eval_sampler,
                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
                    sample_packing_seq_len_multiplier=self.args.eval_batch_size,
                    device_count=int(os.environ.get("WORLD_SIZE", 1)),
                    num_epochs=self.num_epochs,
                )
            )
        return super().get_eval_dataloader(eval_dataset)
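The NanoCode01 side of these dataloader methods relies on a stock PyTorch rule: a sampler that yields whole batches (any `BatchSampler`, such as `MultipackBatchSampler`) must be passed as `batch_sampler`, and `batch_size` must then be removed. Below is a self-contained sketch of that pattern using plain PyTorch objects; the dataset and collate function are placeholders, not Axolotl code.

import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

# Placeholder dataset and collator, only here to make the sketch runnable.
dataset = TensorDataset(torch.arange(100).unsqueeze(1))

def collate(items):
    return torch.stack([x[0] for x in items])

sampler = BatchSampler(RandomSampler(dataset), batch_size=8, drop_last=True)

dataloader_params = {"collate_fn": collate, "num_workers": 0, "pin_memory": False}
if isinstance(sampler, BatchSampler):
    # a batch sampler already yields lists of indices, so batch_size must not be set
    dataloader_params["batch_sampler"] = sampler
else:
    dataloader_params["sampler"] = sampler
    dataloader_params["batch_size"] = 8

loader = DataLoader(dataset, **dataloader_params)
first_batch = next(iter(loader))  # tensor of shape (8, 1)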
@@ -272,15 +227,13 @@ class AxolotlTrainer(Trainer):
    def get_bench_dataloader(
        self,
        bench_dataset: Dataset,
    ) -> DataLoader:
    ) -> Union[DataLoader, MultipackDistributedDataloader]:
        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": self.bench_data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }
        if self.args.dataloader_prefetch_factor:
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
@@ -298,60 +251,13 @@ class AxolotlTrainer(Trainer):
        # return (loss, outputs) if return_outputs else loss
        return super().compute_loss(model, inputs, return_outputs=return_outputs)

    def _wrap_model(self, model, training=True, dataloader=None):
        if self.args.tensor_parallel:
            model = tp.tensor_parallel(model, distributed=is_distributed())
            model.hf_device_map = tp.infer_sharded_device_map(model)
        else:
            model = super()._wrap_model(model, training=training, dataloader=dataloader)
        return model

    def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
        if isinstance(tag_names, str):
            tag_names = [tag_names]

        if kwargs is not None:
            if "tags" not in kwargs:
                kwargs["tags"] = tag_names
            elif "tags" in kwargs and isinstance(kwargs["tags"], list):
                kwargs["tags"].extend(tag_names)
            elif "tags" in kwargs and isinstance(kwargs["tags"], str):
                tag_names.append(kwargs["tags"])
                kwargs["tags"] = tag_names

        return kwargs

    @wraps(Trainer.push_to_hub)
    def push_to_hub(self, *args, **kwargs) -> str:
        """
        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        kwargs = self._sanitize_kwargs_for_tagging(
            tag_names=self.tag_names, kwargs=kwargs
        )

        return super().push_to_hub(*args, **kwargs)


class AxolotlMambaTrainer(AxolotlTrainer):
    """
    Mamba specific trainer to handle loss calculation
    """

    tag_names = ["axolotl", "mamba"]

    def compute_loss(
        self,
        model,
        inputs,
        return_outputs=False,  # pylint: disable=unused-argument
    ):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
        )

        return lm_loss


class OneCycleLRSchedulerTrainer(AxolotlTrainer):
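For orientation, the `_wrap_model` override on the tensor-par side hands sharding off to the external `tensor_parallel` package. A rough standalone sketch of the same wrapping follows; the model id and device list are placeholders, and the call pattern simply mirrors the hunk above rather than a verified recipe.

# Rough sketch mirroring the calls in _wrap_model above; not a verified recipe.
import tensor_parallel as tp
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "openlm-research/open_llama_3b",  # placeholder model id for the sketch
    torch_dtype=torch.float16,
)

# Shard the module across visible GPUs in a single process; inside the trainer,
# distributed=is_distributed() instead shards across torch.distributed ranks.
model = tp.tensor_parallel(model, device_ids=["cuda:0", "cuda:1"])

# As in _wrap_model, expose a device map so downstream HF utilities know where
# each shard lives.
model.hf_device_map = tp.infer_sharded_device_map(model)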
@@ -359,8 +265,6 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
|||||||
Trainer subclass that uses the OneCycleLR scheduler
|
Trainer subclass that uses the OneCycleLR scheduler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tag_names = ["axolotl", "onecycle"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.lr_scheduler = None
|
self.lr_scheduler = None
|
||||||
@@ -390,8 +294,6 @@ class ReLoRATrainer(AxolotlTrainer):
|
|||||||
Trainer subclass that uses the OneCycleLR scheduler
|
Trainer subclass that uses the OneCycleLR scheduler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tag_names = ["axolotl", "relora"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.lr_scheduler = None
|
self.lr_scheduler = None
|
||||||
@@ -427,21 +329,12 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
|
|
||||||
_train_dataset = None
|
_train_dataset = None
|
||||||
_eval_dataset = None
|
_eval_dataset = None
|
||||||
_model_ref = None
|
|
||||||
|
|
||||||
def __init__(self, cfg, model, tokenizer):
|
def __init__(self, cfg, model, tokenizer):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
@property
|
|
||||||
def model_ref(self):
|
|
||||||
return self._model_ref
|
|
||||||
|
|
||||||
@model_ref.setter
|
|
||||||
def model_ref(self, model):
|
|
||||||
self._model_ref = model
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def train_dataset(self):
|
def train_dataset(self):
|
||||||
return self._train_dataset
|
return self._train_dataset
|
||||||
@@ -491,7 +384,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        return trainer_kwargs, trainer_cls

    def hook_post_create_trainer(self, trainer):
        # TODO
        if self.cfg.tensor_parallel:
            trainer.model = trainer.accelerator.prepare_model(
                trainer.model, device_placement=True
            )
        return trainer

    def get_callbacks(self):
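A note on the tensor-par side of this hook: re-running the already-wrapped model through Accelerate keeps device placement consistent with the rest of the trainer. A minimal standalone sketch of that single call is below; the model here is a stand-in, not the sharded language model.

# Minimal sketch; in the builder this is trainer.accelerator / trainer.model.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)  # stand-in for the tensor-parallel-wrapped model

# prepare_model wraps the module for the current setup; device_placement=True
# additionally asks Accelerate to move it onto the right device.
model = accelerator.prepare_model(model, device_placement=True)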
@@ -513,9 +409,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.loss_watchdog_threshold is not None:
|
|
||||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
@@ -544,19 +437,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return OneCycleLRSchedulerTrainer
|
return OneCycleLRSchedulerTrainer
|
||||||
if self.cfg.relora_steps:
|
if self.cfg.relora_steps:
|
||||||
return ReLoRATrainer
|
return ReLoRATrainer
|
||||||
if self.cfg.model_config_type == "mamba":
|
|
||||||
return AxolotlMambaTrainer
|
|
||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
    def build(self, total_num_steps):
        warmup_steps = None
        if self.cfg.warmup_steps is not None:
            warmup_steps = self.cfg.warmup_steps
        elif self.cfg.warmup_ratio is not None:
            warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
        else:
            warmup_steps = min(int(0.03 * total_num_steps), 100)
        warmup_steps = (
            self.cfg.warmup_steps
            if self.cfg.warmup_steps is not None
            else min(int(0.03 * total_num_steps), 100)
        )

        logging_steps = (
            self.cfg.logging_steps
            if self.cfg.logging_steps is not None
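Both sides fall back to `min(int(0.03 * total_num_steps), 100)` when nothing is configured; the NanoCode01 side additionally honors `warmup_ratio`. A quick standalone check of the three branches; the helper and its numbers are illustrative only.

def resolve_warmup_steps(total_num_steps, warmup_steps=None, warmup_ratio=None):
    # Mirrors the branching in build() above (NanoCode01 side).
    if warmup_steps is not None:
        return warmup_steps
    if warmup_ratio is not None:
        return max(int(warmup_ratio * total_num_steps), 0)
    return min(int(0.03 * total_num_steps), 100)

assert resolve_warmup_steps(10_000, warmup_steps=250) == 250
assert resolve_warmup_steps(10_000, warmup_ratio=0.05) == 500
assert resolve_warmup_steps(10_000) == 100  # 3% of 10k is 300, capped at 100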
@@ -582,14 +470,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"gradient_checkpointing"
|
"gradient_checkpointing"
|
||||||
] = self.cfg.gradient_checkpointing
|
] = self.cfg.gradient_checkpointing
|
||||||
if self.cfg.gradient_checkpointing_kwargs:
|
|
||||||
training_arguments_kwargs[
|
|
||||||
"gradient_checkpointing_kwargs"
|
|
||||||
] = self.cfg.gradient_checkpointing_kwargs
|
|
||||||
else:
|
|
||||||
training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
|
|
||||||
"use_reentrant": False
|
|
||||||
}
|
|
||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp:
|
||||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||||
if self.cfg.fsdp_config:
|
if self.cfg.fsdp_config:
|
||||||
@@ -617,12 +497,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
||||||
training_arguments_kwargs["push_to_hub"] = True
|
training_arguments_kwargs["push_to_hub"] = True
|
||||||
training_arguments_kwargs["hub_private_repo"] = True
|
training_arguments_kwargs["hub_private_repo"] = True
|
||||||
training_arguments_kwargs["hub_always_push"] = True
|
|
||||||
|
|
||||||
if self.cfg.hub_strategy:
|
if self.cfg.hub_strategy:
|
||||||
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||||
|
|
||||||
if self.cfg.save_safetensors is not None:
|
if self.cfg.save_safetensors:
|
||||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||||
|
|
||||||
if self.cfg.sample_packing_eff_est:
|
if self.cfg.sample_packing_eff_est:
|
||||||
@@ -630,29 +509,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
"sample_packing_efficiency"
|
"sample_packing_efficiency"
|
||||||
] = self.cfg.sample_packing_eff_est
|
] = self.cfg.sample_packing_eff_est
|
||||||
|
|
||||||
if self.cfg.dataloader_pin_memory is not None:
|
if self.cfg.eval_steps:
|
||||||
training_arguments_kwargs[
|
|
||||||
"dataloader_pin_memory"
|
|
||||||
] = self.cfg.dataloader_pin_memory
|
|
||||||
if self.cfg.dataloader_num_workers is not None:
|
|
||||||
training_arguments_kwargs[
|
|
||||||
"dataloader_num_workers"
|
|
||||||
] = self.cfg.dataloader_num_workers
|
|
||||||
if self.cfg.dataloader_prefetch_factor is not None:
|
|
||||||
training_arguments_kwargs[
|
|
||||||
"dataloader_prefetch_factor"
|
|
||||||
] = self.cfg.dataloader_prefetch_factor
|
|
||||||
|
|
||||||
if self.cfg.val_set_size == 0:
|
|
||||||
# no eval set, so don't eval
|
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
|
||||||
elif self.cfg.eval_steps:
|
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||||
elif self.cfg.evaluation_strategy:
|
elif self.cfg.evaluation_strategy:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"evaluation_strategy"
|
"evaluation_strategy"
|
||||||
] = self.cfg.evaluation_strategy
|
] = self.cfg.evaluation_strategy
|
||||||
|
elif self.cfg.val_set_size == 0:
|
||||||
|
# no eval set, so don't eval
|
||||||
|
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||||
else:
|
else:
|
||||||
# we have an eval set, but no steps defined, default to use epoch
|
# we have an eval set, but no steps defined, default to use epoch
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||||
@@ -740,7 +606,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||||
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
||||||
training_arguments_kwargs["run_name"] = (
|
training_arguments_kwargs["run_name"] = (
|
||||||
self.cfg.wandb_name if self.cfg.use_wandb else None
|
self.cfg.wandb_run_id if self.cfg.use_wandb else None
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["optim"] = (
|
training_arguments_kwargs["optim"] = (
|
||||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||||
@@ -751,9 +617,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
||||||
else "cosine"
|
else "cosine"
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["lr_scheduler_kwargs"] = (
|
|
||||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
|
||||||
)
|
|
||||||
training_arguments_kwargs["weight_decay"] = (
|
training_arguments_kwargs["weight_decay"] = (
|
||||||
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
||||||
)
|
)
|
||||||
@@ -761,26 +624,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
            self.cfg.sample_packing if self.cfg.sample_packing else False
        )
        training_arguments_kwargs["eval_sample_packing"] = (
            self.cfg.sample_packing
            if self.cfg.eval_sample_packing is not False
            else False
        )
        training_arguments_kwargs["eval_sample_packing"] = (
            self.cfg.sample_packing if self.cfg.sample_packing else False
        )
        training_arguments_kwargs[
            "sample_packing_seq_len_multiplier"
        ] = self.cfg.micro_batch_size
        training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
        training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
        training_arguments_kwargs["tensor_parallel"] = self.cfg.tensor_parallel is True

        training_arguments_kwargs = self.hook_pre_create_training_args(
            training_arguments_kwargs
        )
        training_arguments_kwargs["model_type"] = self.cfg.model_config_type
        training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)

        if self.cfg.neftune_noise_alpha is not None:
            training_arguments_kwargs[
                "neftune_noise_alpha"
            ] = self.cfg.neftune_noise_alpha

        training_args = (
            AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
                **training_arguments_kwargs,
@@ -806,6 +661,26 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||||
|
|
||||||
|
if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
|
||||||
|
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||||
|
add_mem_tokens,
|
||||||
|
get_mem_id,
|
||||||
|
set_model_mem_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
set_model_mem_id(self.model, self.tokenizer)
|
||||||
|
|
||||||
|
LOG.info("Adding landmark attention tokens to dataset")
|
||||||
|
|
||||||
|
for dataset in [self.train_dataset, self.eval_dataset]:
|
||||||
|
dataset = dataset.map(
|
||||||
|
partial(
|
||||||
|
add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer)
|
||||||
|
),
|
||||||
|
batched=False,
|
||||||
|
num_proc=32,
|
||||||
|
)
|
||||||
|
|
||||||
trainer_cls = self._get_trainer_cls()
|
trainer_cls = self._get_trainer_cls()
|
||||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||||
trainer_kwargs, trainer_cls
|
trainer_kwargs, trainer_cls
|
||||||
@@ -815,7 +690,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
eval_dataset=self.eval_dataset,
|
eval_dataset=self.eval_dataset,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
data_collator=DataCollatorForSeq2Seq(
|
||||||
|
self.tokenizer,
|
||||||
|
return_tensors="pt",
|
||||||
|
**data_collator_kwargs,
|
||||||
|
),
|
||||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
@@ -829,115 +708,4 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
for callback in self.get_post_trainer_create_callbacks(trainer):
|
for callback in self.get_post_trainer_create_callbacks(trainer):
|
||||||
trainer.add_callback(callback)
|
trainer.add_callback(callback)
|
||||||
|
|
||||||
if self.cfg.deepspeed and self.cfg.sample_packing:
|
|
||||||
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
|
|
||||||
"train_micro_batch_size_per_gpu"
|
|
||||||
] = self.cfg.micro_batch_size
|
|
||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
|
|
||||||
if training_args.pretraining:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if self.cfg.model_config_type == "mamba":
|
|
||||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
|
||||||
|
|
||||||
return BatchSamplerDataCollatorForSeq2Seq(
|
|
||||||
self.tokenizer,
|
|
||||||
return_tensors="pt",
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|
||||||
"""
|
|
||||||
Trainer factory class for DPO Trainer
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_callbacks(self):
|
|
||||||
callbacks = []
|
|
||||||
return callbacks
|
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
|
||||||
callbacks = []
|
|
||||||
return callbacks
|
|
||||||
|
|
||||||
def build_training_arguments(self, total_num_steps):
|
|
||||||
training_args_kwargs = {}
|
|
||||||
for arg in [
|
|
||||||
"adam_beta1",
|
|
||||||
"adam_beta2",
|
|
||||||
"adam_epsilon",
|
|
||||||
"dataloader_num_workers",
|
|
||||||
"dataloader_pin_memory",
|
|
||||||
]:
|
|
||||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
|
||||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
|
||||||
training_args = TrainingArguments(
|
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
|
||||||
max_steps=total_num_steps,
|
|
||||||
remove_unused_columns=False,
|
|
||||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
|
||||||
learning_rate=self.cfg.learning_rate,
|
|
||||||
evaluation_strategy="no",
|
|
||||||
# eval_steps=self.cfg.eval_steps,
|
|
||||||
save_strategy="steps",
|
|
||||||
save_steps=self.cfg.save_steps,
|
|
||||||
output_dir=self.cfg.output_dir,
|
|
||||||
warmup_steps=self.cfg.warmup_steps,
|
|
||||||
bf16=True,
|
|
||||||
gradient_checkpointing=self.cfg.gradient_checkpointing,
|
|
||||||
gradient_checkpointing_kwargs={"use_reentrant": False},
|
|
||||||
logging_first_step=True,
|
|
||||||
logging_steps=1,
|
|
||||||
optim=self.cfg.optimizer,
|
|
||||||
save_total_limit=self.cfg.save_total_limit or 5,
|
|
||||||
**training_args_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return training_args
|
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
|
||||||
training_args = self.build_training_arguments(total_num_steps)
|
|
||||||
dpo_trainer_kwargs = {}
|
|
||||||
if self.cfg.rl == "ipo":
|
|
||||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
|
||||||
if self.cfg.dpo_label_smoothing:
|
|
||||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
|
||||||
|
|
||||||
dpo_trainer = DPOTrainer(
|
|
||||||
self.model,
|
|
||||||
self.model_ref,
|
|
||||||
args=training_args,
|
|
||||||
beta=self.cfg.dpo_beta or 0.1,
|
|
||||||
train_dataset=self.train_dataset,
|
|
||||||
# eval_dataset=self.eval_dataset,
|
|
||||||
eval_dataset=None,
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
max_length=self.cfg.sequence_len,
|
|
||||||
max_target_length=None,
|
|
||||||
max_prompt_length=self.cfg.sequence_len,
|
|
||||||
generate_during_eval=True,
|
|
||||||
**dpo_trainer_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return dpo_trainer
|
|
||||||
|
|
||||||
|
|
||||||
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
|
||||||
"""
|
|
||||||
HF Factory class for PPO Trainer
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_callbacks(self):
|
|
||||||
callbacks = []
|
|
||||||
return callbacks
|
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
|
||||||
callbacks = []
|
|
||||||
return callbacks
|
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
|
||||||
# build PPOConfig
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -1,66 +0,0 @@
|
|||||||
"""
|
|
||||||
module for TRL PPO training
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
from trl import PPOTrainer
|
|
||||||
|
|
||||||
|
|
||||||
class TRLPPOTrainer(PPOTrainer):
|
|
||||||
"""
|
|
||||||
wrapper for ppo trainer to handle customizations
|
|
||||||
"""
|
|
||||||
|
|
||||||
def train(
|
|
||||||
self,
|
|
||||||
reward_pipe,
|
|
||||||
resume_from_checkpoint=None, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
generation_kwargs = {
|
|
||||||
"min_length": -1,
|
|
||||||
"top_k": 0.0,
|
|
||||||
"top_p": 1.0,
|
|
||||||
"do_sample": True,
|
|
||||||
"pad_token_id": self.tokenizer.eos_token_id,
|
|
||||||
"max_new_tokens": 32,
|
|
||||||
}
|
|
||||||
sent_kwargs = {
|
|
||||||
"return_all_scores": True,
|
|
||||||
"function_to_apply": "none",
|
|
||||||
"batch_size": 16,
|
|
||||||
}
|
|
||||||
|
|
||||||
for epoch, batch in tqdm( # pylint: disable=unused-variable
|
|
||||||
enumerate(self.dataloader)
|
|
||||||
):
|
|
||||||
query_tensors = batch["input_ids"]
|
|
||||||
|
|
||||||
# generate model response
|
|
||||||
response_tensors, ref_response_tensors = self.generate(
|
|
||||||
query_tensors,
|
|
||||||
return_prompt=False,
|
|
||||||
generate_ref_response=True,
|
|
||||||
**generation_kwargs
|
|
||||||
)
|
|
||||||
batch["response"] = self.tokenizer.batch_decode(response_tensors)
|
|
||||||
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
|
|
||||||
|
|
||||||
# Compute sentiment score
|
|
||||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
|
||||||
pipe_outputs = reward_pipe(texts, **sent_kwargs)
|
|
||||||
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
|
|
||||||
ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
|
|
||||||
ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
|
|
||||||
ref_rewards = [
|
|
||||||
torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
|
|
||||||
]
|
|
||||||
batch["ref_rewards"] = ref_rewards
|
|
||||||
|
|
||||||
# Run PPO step
|
|
||||||
stats = self.step(query_tensors, response_tensors, rewards)
|
|
||||||
self.log_stats(
|
|
||||||
stats,
|
|
||||||
batch,
|
|
||||||
rewards,
|
|
||||||
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
|
||||||
)
|
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
@@ -30,20 +30,14 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
self,
|
self,
|
||||||
prompt_tokenizer: PromptTokenizingStrategy,
|
prompt_tokenizer: PromptTokenizingStrategy,
|
||||||
dataset: IterableDataset,
|
dataset: IterableDataset,
|
||||||
process_count: Optional[int] = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.prompt_tokenizer = prompt_tokenizer
|
self.prompt_tokenizer = prompt_tokenizer
|
||||||
self.process_count = process_count
|
|
||||||
super().__init__(self.process(dataset).data, **kwargs)
|
super().__init__(self.process(dataset).data, **kwargs)
|
||||||
|
|
||||||
def process(self, dataset):
|
def process(self, dataset):
|
||||||
features = dataset.features.keys()
|
features = dataset.features.keys()
|
||||||
num_proc = (
|
num_proc = min(64, os.cpu_count())
|
||||||
min(64, self.process_count)
|
|
||||||
if self.process_count
|
|
||||||
else min(64, os.cpu_count())
|
|
||||||
)
|
|
||||||
map_kwargs = {}
|
map_kwargs = {}
|
||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
map_kwargs["batched"] = True
|
map_kwargs["batched"] = True
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
"""
Modeling module for Mamba models
"""


def fix_mamba_attn_for_loss():
    from mamba_ssm.models import mixer_seq_simple

    from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed

    mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
    return mixer_seq_simple.MambaLMHeadModel  # pylint: disable=invalid-name
@@ -1,42 +0,0 @@
|
|||||||
"""
|
|
||||||
HF Transformers MambaConfig
|
|
||||||
"""
|
|
||||||
from transformers import PretrainedConfig
|
|
||||||
|
|
||||||
|
|
||||||
class MambaConfig(PretrainedConfig):
|
|
||||||
"""
|
|
||||||
modeling configuration for state space model/mamba
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_type = "mamba"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size=50280,
|
|
||||||
d_model=2560,
|
|
||||||
n_layer=64,
|
|
||||||
rms_norm=True,
|
|
||||||
residual_in_fp32=True,
|
|
||||||
fused_add_norm=True,
|
|
||||||
pad_vocab_size_multiple=8,
|
|
||||||
pad_token_id=50277,
|
|
||||||
bos_token_id=0,
|
|
||||||
eos_token_id=0,
|
|
||||||
tie_word_embeddings=False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.d_model = d_model
|
|
||||||
self.n_layer = n_layer
|
|
||||||
self.rms_norm = rms_norm
|
|
||||||
self.residual_in_fp32 = residual_in_fp32
|
|
||||||
self.fused_add_norm = fused_add_norm
|
|
||||||
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
|
||||||
super().__init__(
|
|
||||||
pad_token_id=pad_token_id,
|
|
||||||
bos_token_id=bos_token_id,
|
|
||||||
eos_token_id=eos_token_id,
|
|
||||||
tie_word_embeddings=tie_word_embeddings,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
@@ -1,128 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
import os
|
|
||||||
from collections import namedtuple
|
|
||||||
from functools import partial
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
|
|
||||||
from mamba_ssm.utils.generation import GenerationMixin
|
|
||||||
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import CrossEntropyLoss
|
|
||||||
|
|
||||||
from axolotl.models.mamba.configuration_mamba import MambaConfig
|
|
||||||
|
|
||||||
|
|
||||||
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
d_model: int,
|
|
||||||
n_layer: int,
|
|
||||||
vocab_size: int,
|
|
||||||
initializer_cfg=None,
|
|
||||||
pad_vocab_size_multiple: int = 1,
|
|
||||||
device=None,
|
|
||||||
dtype=None,
|
|
||||||
**backbone_kwargs,
|
|
||||||
) -> None:
|
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
|
||||||
super().__init__()
|
|
||||||
if vocab_size % pad_vocab_size_multiple != 0:
|
|
||||||
vocab_size += pad_vocab_size_multiple - (
|
|
||||||
vocab_size % pad_vocab_size_multiple
|
|
||||||
)
|
|
||||||
self.config = MambaConfig(
|
|
||||||
vocab_size=vocab_size,
|
|
||||||
d_model=d_model,
|
|
||||||
n_layer=n_layer,
|
|
||||||
pad_vocab_size_multiple=pad_vocab_size_multiple,
|
|
||||||
)
|
|
||||||
self.backbone = MixerModel(
|
|
||||||
d_model=d_model,
|
|
||||||
n_layer=n_layer,
|
|
||||||
vocab_size=vocab_size,
|
|
||||||
initializer_cfg=initializer_cfg,
|
|
||||||
**backbone_kwargs,
|
|
||||||
**factory_kwargs,
|
|
||||||
)
|
|
||||||
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.apply(
|
|
||||||
partial(
|
|
||||||
_init_weights,
|
|
||||||
n_layer=n_layer,
|
|
||||||
**(initializer_cfg if initializer_cfg is not None else {}),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.tie_weights()
|
|
||||||
|
|
||||||
def tie_weights(self):
|
|
||||||
self.lm_head.weight = self.backbone.embedding.weight
|
|
||||||
|
|
||||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
|
||||||
return self.backbone.allocate_inference_cache(
|
|
||||||
batch_size, max_seqlen, dtype=dtype, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids,
|
|
||||||
position_ids=None,
|
|
||||||
inference_params=None,
|
|
||||||
num_last_tokens=0,
|
|
||||||
labels=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
|
||||||
num_last_tokens: if > 0, only return the logits for the last n tokens
|
|
||||||
"""
|
|
||||||
hidden_states = self.backbone(input_ids, inference_params=inference_params)
|
|
||||||
if num_last_tokens > 0:
|
|
||||||
hidden_states = hidden_states[:, -num_last_tokens:]
|
|
||||||
lm_logits = self.lm_head(hidden_states)
|
|
||||||
|
|
||||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
|
||||||
return CausalLMOutput(logits=lm_logits)
|
|
||||||
|
|
||||||
loss = None
|
|
||||||
if labels is not None:
|
|
||||||
logits = lm_logits
|
|
||||||
# Shift so that tokens < n predict n
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
# Flatten the tokens
|
|
||||||
loss_fct = CrossEntropyLoss()
|
|
||||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
|
||||||
shift_labels = shift_labels.view(-1)
|
|
||||||
# Enable model parallelism
|
|
||||||
shift_labels = shift_labels.to(shift_logits.device)
|
|
||||||
loss = loss_fct(shift_logits, shift_labels)
|
|
||||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
|
|
||||||
print(loss)
|
|
||||||
return CausalLMOutput(logits=lm_logits, loss=loss)
|
|
||||||
|
|
||||||
else:
|
|
||||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
|
||||||
return CausalLMOutput(logits=lm_logits)
|
|
||||||
|
|
||||||
def save_pretrained(
|
|
||||||
self,
|
|
||||||
save_directory: Union[str, os.PathLike],
|
|
||||||
state_dict: Optional[dict] = None,
|
|
||||||
safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
if state_dict is None:
|
|
||||||
state_dict = self.state_dict()
|
|
||||||
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
|
||||||
config = load_config_hf(pretrained_model_name)
|
|
||||||
model = cls(**config, device=device, dtype=dtype, **kwargs)
|
|
||||||
model.load_state_dict(
|
|
||||||
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
|
|
||||||
)
|
|
||||||
return model
|
|
||||||
@@ -3,6 +3,4 @@ MixFormers model architecture used for phi models
"""

from .configuration_mixformer_sequential import MixFormerSequentialConfig  # noqa
from .configuration_phi import PhiConfig  # noqa
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM  # noqa
from .modeling_phi import PhiForCausalLM  # noqa
@@ -1,65 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT license.
|
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
|
||||||
|
|
||||||
|
|
||||||
class PhiConfig(PretrainedConfig):
|
|
||||||
"""Phi configuration."""
|
|
||||||
|
|
||||||
model_type = "phi"
|
|
||||||
attribute_map = {
|
|
||||||
"max_position_embeddings": "n_positions",
|
|
||||||
"hidden_size": "n_embd",
|
|
||||||
"num_attention_heads": "n_head",
|
|
||||||
"num_hidden_layers": "n_layer",
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size: int = 50304,
|
|
||||||
n_positions: int = 2048,
|
|
||||||
n_embd: int = 1024,
|
|
||||||
n_layer: int = 20,
|
|
||||||
n_inner: Optional[int] = None,
|
|
||||||
n_head: int = 16,
|
|
||||||
n_head_kv: Optional[int] = None,
|
|
||||||
rotary_dim: Optional[int] = 32,
|
|
||||||
activation_function: Optional[str] = "gelu_new",
|
|
||||||
flash_attn: bool = False,
|
|
||||||
flash_rotary: bool = False,
|
|
||||||
fused_dense: bool = False,
|
|
||||||
attn_pdrop: float = 0.0,
|
|
||||||
embd_pdrop: float = 0.0,
|
|
||||||
resid_pdrop: float = 0.0,
|
|
||||||
layer_norm_epsilon: float = 1e-5,
|
|
||||||
initializer_range: float = 0.02,
|
|
||||||
tie_word_embeddings: bool = False,
|
|
||||||
pad_vocab_size_multiple: int = 64,
|
|
||||||
**kwargs
|
|
||||||
) -> None:
|
|
||||||
self.vocab_size = int(
|
|
||||||
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
|
||||||
)
|
|
||||||
self.n_positions = n_positions
|
|
||||||
self.n_embd = n_embd
|
|
||||||
self.n_layer = n_layer
|
|
||||||
self.n_inner = n_inner
|
|
||||||
self.n_head = n_head
|
|
||||||
self.n_head_kv = n_head_kv
|
|
||||||
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
|
||||||
self.activation_function = activation_function
|
|
||||||
self.flash_attn = flash_attn
|
|
||||||
self.flash_rotary = flash_rotary
|
|
||||||
self.fused_dense = fused_dense
|
|
||||||
self.attn_pdrop = attn_pdrop
|
|
||||||
self.embd_pdrop = embd_pdrop
|
|
||||||
self.resid_pdrop = resid_pdrop
|
|
||||||
self.layer_norm_epsilon = layer_norm_epsilon
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
|
|
||||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
|
||||||
File diff suppressed because it is too large
@@ -82,44 +82,15 @@ def get_turns( # pylint: disable=too-many-return-statements
|
|||||||
else:
|
else:
|
||||||
yield role + ":", ""
|
yield role + ":", ""
|
||||||
return
|
return
|
||||||
if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
|
if self.sep_style == SeparatorStyle.LLAMA2:
|
||||||
|
seps = [self.sep, self.sep2]
|
||||||
if self.system_message:
|
if self.system_message:
|
||||||
if self.messages:
|
|
||||||
# For llama, the system message is incorporated into the first human instruction
|
|
||||||
first_role, first_msg = self.messages[0]
|
|
||||||
if first_role == self.roles[0]:
|
|
||||||
system_prompt += first_msg
|
|
||||||
self.messages.pop(0)
|
|
||||||
yield "", system_prompt
|
yield "", system_prompt
|
||||||
for i, (role, message) in enumerate(self.messages):
|
else:
|
||||||
|
yield "", "[INST] "
|
||||||
|
for i, (role, message) in enumerate(self.messages[1:]):
|
||||||
if message:
|
if message:
|
||||||
if (i % 2 == 0 and not self.system_message) or (
|
yield role + " ", message + seps[i % 2]
|
||||||
i % 2 != 0 and self.system_message
|
|
||||||
):
|
|
||||||
role = "<s> " + role
|
|
||||||
yield role + " ", message
|
|
||||||
else:
|
|
||||||
yield role, ""
|
|
||||||
return
|
|
||||||
if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
|
|
||||||
contains_sys_msg = False
|
|
||||||
if self.system_message:
|
|
||||||
contains_sys_msg = True
|
|
||||||
if self.messages:
|
|
||||||
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline
|
|
||||||
first_role, first_msg = self.messages[0]
|
|
||||||
if first_role == self.roles[0]:
|
|
||||||
system_prompt = self.system_template.format(
|
|
||||||
system_message=" " + self.system_message
|
|
||||||
)
|
|
||||||
system_prompt += first_msg
|
|
||||||
self.messages.pop(0)
|
|
||||||
yield "", system_prompt
|
|
||||||
for i, (role, message) in enumerate(self.messages):
|
|
||||||
if message and i == 0 and not contains_sys_msg:
|
|
||||||
yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a `<s> [INST]` at the beginning of the first instruction.
|
|
||||||
elif message:
|
|
||||||
yield role + " ", message
|
|
||||||
else:
|
else:
|
||||||
yield role, ""
|
yield role, ""
|
||||||
return
|
return
|
||||||
@@ -147,15 +118,6 @@ def get_turns(  # pylint: disable=too-many-return-statements
             else:
                 yield role + "\n", ""
             return
-        if self.sep_style == SeparatorStyle.CHATGLM3:
-            if self.system_message:
-                yield "", system_prompt
-            for role, message in self.messages:
-                if message:
-                    yield role + "\n", " " + message
-                else:
-                    yield role
-            return
         if self.sep_style == SeparatorStyle.CHATINTERN:
             # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
             seps = [self.sep, self.sep2]
@@ -321,8 +321,6 @@ def flashattn_forward(
     # only on first autoregressive step q,k,v have same seqlen
     is_causal = key_states.shape == query_states.shape
 
-    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
-
     if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
         # special handling using sample packing
         qkv = torch.stack(
@@ -332,12 +330,7 @@ def flashattn_forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
 
         output = flash_attn_varlen_qkvpacked_func(
-            qkv,
-            cu_seqlens,
-            max_seqlen,
-            dropout_p=dropout_rate,
-            softmax_scale=None,
-            causal=True,
+            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
@@ -360,7 +353,7 @@ def flashattn_forward(
             qkv_unpad,
             cu_seqlens_q,
             max_seqlen_q,
-            dropout_p=dropout_rate,
+            0.0,
             softmax_scale=None,
             causal=is_causal,
         )
@@ -373,7 +366,6 @@ def flashattn_forward(
         output = flash_attn_kvpacked_func(
             query_states,
             torch.stack([key_states, value_states], 2),
-            dropout_p=dropout_rate,
             causal=is_causal,
         )
     else:
@@ -406,7 +398,7 @@ def flashattn_forward(
             cu_seqlens_k,
             max_seqlen_q,
             max_seqlen_k,
-            dropout_p=dropout_rate,
+            0.0,
             softmax_scale=None,
             causal=is_causal,
         )
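The packed-sample branch in the hunks above flattens the batch to `(batch * seq_len)` and relies on cumulative sequence lengths so attention stays within each packed sample. A minimal sketch of how such `cu_seqlens` are built (not the repo's own helper):

```python
# Minimal sketch: cu_seqlens for two packed samples of lengths 5 and 3, as
# consumed by flash_attn_varlen_qkvpacked_func.
import torch

lengths = torch.tensor([5, 3], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(torch.cumsum(lengths, 0, dtype=torch.int32), (1, 0))
max_seqlen = int(lengths.max())
print(cu_seqlens.tolist(), max_seqlen)  # [0, 5, 8] 5
```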
@@ -25,8 +25,6 @@ def sdp_attention_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
-    **kwargs,  # pylint: disable=unused-argument
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
@@ -29,8 +29,6 @@ def xformers_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
-    **kwargs,  # pylint: disable=unused-argument
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
1249 src/axolotl/monkeypatch/llama_landmark_attn.py Normal file
File diff suppressed because it is too large
@@ -201,8 +201,6 @@ def flashattn_forward(
     # only on first autoregressive step q,k,v have same seqlen
     is_causal = key_states.shape == query_states.shape
 
-    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
-
     if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
         # special handling using sample packing
         qkv = torch.stack(
@@ -215,7 +213,7 @@ def flashattn_forward(
             qkv,
             cu_seqlens,
             max_seqlen,
-            dropout_p=dropout_rate,
+            0.0,
             softmax_scale=None,
             causal=True,
             window_size=window_size,
@@ -241,7 +239,7 @@ def flashattn_forward(
             qkv_unpad,
             cu_seqlens_q,
             max_seqlen_q,
-            dropout_p=dropout_rate,
+            0.0,
             softmax_scale=None,
             causal=is_causal,
             window_size=window_size,
@@ -255,7 +253,6 @@ def flashattn_forward(
         output = flash_attn_kvpacked_func(
             query_states,
             torch.stack([key_states, value_states], 2),
-            dropout_p=dropout_rate,
             causal=is_causal,
             window_size=window_size,
         )
@@ -289,7 +286,7 @@ def flashattn_forward(
             cu_seqlens_k,
             max_seqlen_q,
             max_seqlen_k,
-            dropout_p=dropout_rate,
+            0.0,
             softmax_scale=None,
             causal=is_causal,
             window_size=window_size,
@@ -1,22 +0,0 @@
"""
Patches to support multipack for mixtral
"""
import transformers


def replace_mixtral_attn_with_multipack_flash_attn():
    from .modeling_mixtral import (
        MixtralMultipackFlashAttention2,
        mixtral_decoder_layer_forward,
        mixtral_model_forward,
    )

    transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
        mixtral_decoder_layer_forward
    )
    transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
        mixtral_model_forward
    )
    transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
        "flash_attention_2"
    ] = MixtralMultipackFlashAttention2
@@ -1,383 +0,0 @@
"""
Mixtral modeling for multipack
"""
# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
import logging
import warnings
from typing import List, Optional, Tuple, Union

import torch
from einops import rearrange
from flash_attn import flash_attn_varlen_qkvpacked_func
from transformers import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.mixtral.modeling_mixtral import (
    MixtralFlashAttention2,
    apply_rotary_pos_emb,
    repeat_kv,
)

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids

LOG = logging.getLogger("axolotl.monkeypatch.mixtral")


class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
    """
    Custom multipack implementation w flash attention 2
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._flash_attn_uses_top_left_mask = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids
        )

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
            # special handling using sample packing
            qkv = torch.stack(
                [query_states, key_states, value_states], dim=2
            )  # [bsz, nh, 3, q_len, hd]
            qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
            qkv = rearrange(qkv, "b s ... -> (b s) ...")

            attn_output = flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                dropout_p=self.attention_dropout,
                softmax_scale=None,
                causal=True,
            )
            attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


def mixtral_decoder_layer_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    output_router_logits: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )
    """
    Args:
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
            `(batch, sequence_length)` where padding elements are indicated by 0.
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        output_router_logits (`bool`, *optional*):
            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
            should not be returned during inference.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
            (see `past_key_values`).
    """

    residual = hidden_states

    hidden_states = self.input_layernorm(hidden_states)

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )
    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states, router_logits = self.block_sparse_moe(hidden_states)
    hidden_states = residual + hidden_states

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (self_attn_weights,)

    if use_cache:
        outputs += (present_key_value,)

    if output_router_logits:
        outputs += (router_logits,)

    return outputs


def mixtral_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, MoeModelOutputWithPast]:
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_router_logits = (
        output_router_logits
        if output_router_logits is not None
        else self.config.output_router_logits
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
        )
    if input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError(
            "You have to specify either decoder_input_ids or decoder_inputs_embeds"
        )

    past_key_values_length = 0

    if use_cache:
        use_legacy_cache = not isinstance(past_key_values, Cache)
        if use_legacy_cache:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)

    cu_seqlens = None
    max_seqlen = None
    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()
        cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
        cu_seqlens = cu_seqlens.squeeze()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if (
        attention_mask is not None
        and self._attn_implementation == "flash_attention_2"
        and use_cache
    ):
        is_padding_right = attention_mask[:, -1].sum().item() != batch_size
        if is_padding_right:
            raise ValueError(
                "You are attempting to perform batched generation with padding_side='right'"
                " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
                " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
            )

    if self._attn_implementation == "flash_attention_2":
        # 2d mask is passed through the layers
        attention_mask = (
            attention_mask
            if (attention_mask is not None and 0 in attention_mask)
            else None
        )
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
            sliding_window=self.config.sliding_window,
        )

    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            LOG.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    all_router_logits = () if output_router_logits else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values,
                output_attentions,
                output_router_logits,
                use_cache,
                cu_seqlens,
                max_seqlen,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                output_router_logits=output_router_logits,
                use_cache=use_cache,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

        if output_router_logits:
            all_router_logits += (layer_outputs[-1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = (
            next_decoder_cache.to_legacy_cache()
            if use_legacy_cache
            else next_decoder_cache
        )

    if not return_dict:
        return tuple(
            v
            for v in [
                hidden_states,
                next_cache,
                all_hidden_states,
                all_self_attns,
                all_router_logits,
            ]
            if v is not None
        )

    return MoeModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
        router_logits=all_router_logits,
    )
65 src/axolotl/monkeypatch/neft_embeddings.py Normal file
@@ -0,0 +1,65 @@
"""
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
"""
import torch
from peft import PeftModel
from transformers import PreTrainedModel


def patch_neft(alpha, model):
    embeddings = None
    if isinstance(model, PreTrainedModel):
        embeddings = model.get_input_embeddings()
    if isinstance(model, PeftModel):
        embeddings = model.base_model.get_input_embeddings()
    if not embeddings:
        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
    embeddings.noisy_embedding_alpha = alpha
    old_forward = embeddings.forward

    # This hack seems to be needed to properly use a custom forward pass
    # all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
    bound_method = neft_forward.__get__(  # pylint: disable=no-value-for-parameter
        embeddings, embeddings.__class__
    )
    setattr(embeddings, "forward", bound_method)

    embeddings._old_forward = old_forward  # pylint: disable=protected-access
    return model


def unpatch_neft(model):
    embeddings = None
    if isinstance(model, PreTrainedModel):
        embeddings = model.get_input_embeddings()
    if isinstance(model, PeftModel):
        embeddings = model.base_model.get_input_embeddings()
    if not embeddings:
        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
    if hasattr(embeddings, "_old_forward"):
        embeddings.forward = embeddings._old_forward  # pylint: disable=protected-access
        del embeddings._old_forward  # pylint: disable=protected-access
    del embeddings.noisy_embedding_alpha


def neft_forward(self, inputs: torch.Tensor):
    embeddings = self._old_forward(inputs)  # pylint: disable=protected-access

    if self.training:
        dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
        mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
        embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
            -mag_norm, mag_norm
        )

    return embeddings


def pretrain_hook(cfg, trainer):
    if cfg.noisy_embedding_alpha:
        trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)


def post_train_hook(cfg, trainer):
    if cfg.noisy_embedding_alpha:
        unpatch_neft(trainer.model)
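The core of NEFT as implemented above is uniform noise scaled by `alpha / sqrt(seq_len * hidden_dim)` added to embedding outputs during training only. A standalone sketch of that arithmetic (alpha and the tensor shape are arbitrary illustrative values):

```python
# Hedged sketch of the NEFT noise math used in neft_forward above.
import torch

alpha = 5.0
embeddings = torch.randn(1, 8, 16)                              # (batch, seq_len, hidden_dim)
mag_norm = alpha / (embeddings.size(1) * embeddings.size(2)) ** 0.5
noisy = embeddings + torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
```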
94 src/axolotl/monkeypatch/xpos_rope_llama_monkey_patch.py Normal file
@@ -0,0 +1,94 @@
# pylint: skip-file
"""
Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
"""
import torch
import transformers
import transformers.models.llama.modeling_llama
from einops import rearrange


class XposRotaryEmbedding(torch.nn.Module):
    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scale_base=2048,
        use_xpos=True,
    ):
        super().__init__()
        self.max_seq_len_cached = max_position_embeddings
        self.scale_base = scale_base

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq)
        freqs = torch.einsum("i , j -> i j", t, inv_freq)
        freqs = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("freqs_cached", freqs, persistent=False)

        if not use_xpos:
            self.register_buffer("scale", None)
            self.register_buffer("scale_cached", torch.ones(1))
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        power = (t - (self.max_seq_len_cached // 2)) / self.scale_base
        scale_cached = scale ** rearrange(power, "n -> n 1")
        scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)

        self.register_buffer("scale", scale, persistent=False)
        self.register_buffer("scale_cached", scale_cached, persistent=False)

    def forward(
        self,
        x,
        seq_len,
    ):
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device).type_as(
                self.inv_freq
            )
            freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
            freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype)

            self.register_buffer("freqs_cached", freqs)

            if self.scale is None:
                self.register_buffer(
                    "scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype)
                )

                return self.freqs_cached.to(dtype=x.dtype), self.scale_cached

            power = (t - (seq_len // 2)) / self.scale_base
            scale = self.scale ** rearrange(power, "n -> n 1")
            scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype)
            self.register_buffer("scale_cached", scale)

        return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype)


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None):
    freqs = freqs[position_ids, :]
    if scale.shape[-1] != 1:
        scale = scale[position_ids, :]

    q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale)
    k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale)

    return q_embed, k_embed


def replace_llama_rope_with_xpos_rope():
    transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding
    transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
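A brief usage sketch for the patch above: because it swaps the rotary embedding class on the transformers module, it should run before the Llama model is constructed. The model id below is only an example.

```python
# Hedged usage sketch of the xPos monkeypatch (example model id, not from this diff).
from transformers import AutoModelForCausalLM

replace_llama_rope_with_xpos_rope()  # must run before the model is instantiated
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")
```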
@@ -81,9 +81,8 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.tokenizer.add_special_tokens(
-            {"pad_token": getattr(self.tokenizer, "pad_token", "<pad>")}
-        )
+        self.sequence_len = 4096
+        self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
         # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
 
     def tokenize_prompt(self, prompt):
@@ -13,7 +13,7 @@ register_conv_template(
         system_message="You are a helpful assistant.",
         roles=["<|im_start|>user", "<|im_start|>assistant"],
         sep_style=SeparatorStyle.CHATML,
-        sep="<|im_end|>",
+        sep="<|im_end|>\n",
     )
 )
 
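The separator change above appends a newline after each `<|im_end|>`. For reference, a hand-written example of how a ChatML turn sequence looks with that separator (illustrative content only):

```python
# Illustrative ChatML layout with sep="<|im_end|>\n".
system = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
user = "<|im_start|>user\nWhat is 2 + 2?<|im_end|>\n"
assistant = "<|im_start|>assistant\n4<|im_end|>\n"
prompt = system + user + assistant
```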
@@ -39,23 +39,6 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     return strategy
 
 
-def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
-    conversation = (
-        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
-    )
-    strategy = UltrachatShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2(
-            conversation=conversation,
-        ),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
-    if ds_cfg and "strict" in ds_cfg:
-        strategy.strict = ds_cfg["strict"]
-    return strategy
-
-
 def load_role(tokenizer, cfg):
     return SimpleRoleShareGPTPromptTokenizingStrategy(
         ShareGPTPrompterV2(),
@@ -126,17 +109,3 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
             {"from": role_map[t["role"]], "value": t["text"]} for t in conversations
         ]
         return turns
-
-
-class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
-    """
-    sharegpt strategy that remaps ultrachat data to sharegpt format
-    """
-
-    def get_conversation_thread(self, prompt):
-        conversations = prompt["messages"]
-        role_map = {"user": "human", "assistant": "gpt"}
-        turns = [
-            {"from": role_map[t["role"]], "value": t["content"]} for t in conversations
-        ]
-        return turns
@@ -22,19 +22,13 @@ class PromptStyle(Enum):
     CHATML = "chatml"
 
 
-class Prompter:
-    """
-    Base prompter class for all prompters
-    """
-
-
-class AlpacaPrompter(Prompter):
+class AlpacaPrompter:
     """
     Base class for alpaca prompters
     """
 
-    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
-    system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
+    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
+    system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
     system_format: str = "{system}"
     turn_format: str
     turn_no_input_format: str
@@ -75,7 +69,7 @@ class AlpacaPrompter(Prompter):
         else:
             res = (
                 self.system_format.format(system=self.system_no_input_prompt)
-                if self.system_no_input_prompt
+                if self.system_prompt
                 else ""
             ) + self.turn_no_input_format.format(instruction=instruction)
         if output:
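To make the assembly above concrete, here is a small sketch of how the system prompt and a turn template compose into a final Alpaca-style prompt. The `turn_no_input_format` string below is an assumption for illustration, not taken from this diff.

```python
# Hedged sketch of Alpaca prompt assembly (turn format assumed).
system_no_input_prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
)
turn_no_input_format = "### Instruction:\n{instruction}\n\n### Response:\n"
prompt = "{system}".format(system=system_no_input_prompt) + turn_no_input_format.format(
    instruction="Name three primary colors."
)
```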
@@ -165,7 +159,7 @@ class NomicGPT4AllPrompter(AlpacaPrompter):
     """
 
 
-class ReflectAlpacaPrompter(Prompter):
+class ReflectAlpacaPrompter:
     """
     Prompter for ReflectAlpaca
     """
@@ -260,7 +254,7 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
 )
 
 
-class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
+class ShareGPTPrompter:  # pylint: disable=too-few-public-methods
     """
     A prompter that generates prompts for the ShareGPT
     """
@@ -355,7 +349,7 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
 )
 
 
-class UnsupportedPrompter(Prompter):
+class UnsupportedPrompter:
     """
     A dummy class for custom prompters
     """
@@ -1,5 +1,6 @@
 """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
 
+import logging
 import os
 import signal
 import sys
@@ -9,16 +10,14 @@ from typing import Optional
 
 import torch
 import transformers.modelcard
-from accelerate.logging import get_logger
 from datasets import Dataset
 from optimum.bettertransformer import BetterTransformer
-from pkg_resources import get_distribution  # type: ignore
 from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
+from axolotl.monkeypatch import neft_embeddings
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.freeze import freeze_parameters_except
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 
@@ -27,7 +26,7 @@ src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
 
 configure_logging()
-LOG = get_logger("axolotl.train")
+LOG = logging.getLogger("axolotl.train")
 
 
 @dataclass
@@ -45,10 +44,7 @@ def train(
     *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
 ):
     # load the tokenizer first
-    LOG.debug(
-        f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
-        main_process_only=True,
-    )
+    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
     tokenizer = load_tokenizer(cfg)
 
     train_dataset = dataset_meta.train_dataset
@@ -56,17 +52,8 @@ def train(
     total_num_steps = dataset_meta.total_num_steps
 
     # Load the model and tokenizer
-    msg = "loading model"
-    if cfg.adapter:
-        msg += " and peft_config..."
-    LOG.debug(msg)
+    LOG.info("loading model and (optionally) peft_config...")
     model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
-    model_ref = None
-    if cfg.rl:
-        # load the model again for model_ref/baseline
-        model_ref, _ = load_model(
-            cfg, tokenizer, inference=cli_args.inference, reference_model=True
-        )
 
     safe_serialization = cfg.save_safetensors is True
 
@@ -85,15 +72,11 @@ def train(
     )
     resume_from_checkpoint = cfg.resume_from_checkpoint
 
-    if cfg.unfrozen_parameters:
-        freeze_parameters_except(model, cfg.unfrozen_parameters)
-
     trainer = setup_trainer(
-        cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps
+        cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
     )
 
-    if hasattr(model, "config"):
-        model.config.use_cache = False
+    model.config.use_cache = False
 
     # go ahead and presave, so we have the adapter config available to inspect
     if peft_config:
@@ -103,8 +86,7 @@ def train(
     if not Path(cfg.output_dir).is_dir():
         os.makedirs(cfg.output_dir, exist_ok=True)
     tokenizer.save_pretrained(str(Path(cfg.output_dir)))
-    if hasattr(model, "config"):
-        model.config.save_pretrained(str(Path(cfg.output_dir)))
+    model.config.save_pretrained(str(Path(cfg.output_dir)))
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
@@ -122,12 +104,6 @@ def train(
         badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
         transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
 
-        if getattr(cfg, "axolotl_config_path"):
-            raw_axolotl_cfg = Path(cfg.axolotl_config_path)
-            version = get_distribution("axolotl").version
-            if raw_axolotl_cfg.is_file():
-                transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
-
     LOG.info("Starting trainer...")
     if cfg.group_by_length:
         LOG.info("hang tight... sorting dataset for group_by_length")
@@ -188,26 +164,25 @@ def train(
 
     if not cfg.hub_model_id:
         trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
-    elif cfg.hub_model_id:
-        # defensively push to the hub to ensure the model card is updated
-        trainer.push_to_hub()
 
     return model, tokenizer
 
 
-def pretrain_hooks(_cfg, _trainer):
+def pretrain_hooks(cfg, trainer):
     """
     Run hooks right before kicking off the training
     :param cfg:
     :param trainer:
     :return:
     """
+    neft_embeddings.pretrain_hook(cfg, trainer)
 
 
-def post_train_hooks(_cfg, _trainer):
+def post_train_hooks(cfg, trainer):
     """
     Run hooks right after training completes
     :param cfg:
     :param trainer:
     :return:
     """
+    neft_embeddings.post_train_hook(cfg, trainer)
@@ -1,10 +1,13 @@
 """Benchmarking and measurement utilities"""
 import functools
+import logging
 
 import pynvml
 import torch
 from pynvml.nvml import NVMLError
 
+LOG = logging.getLogger("axolotl.utils.bench")
+
 
 def check_cuda_device(default_value):
     """
@@ -62,7 +65,14 @@ def gpu_memory_usage_smi(device=0):
 
 
 def log_gpu_memory_usage(log, msg, device):
-    usage, cache, misc = gpu_memory_usage_all(device)
+    if not torch.cuda.is_available():
+        return (0, 0, 0)
+
+    try:
+        usage, cache, misc = gpu_memory_usage_all(device)
+    except ValueError as exc:
+        LOG.exception(exc)
+        return (0, 0, 0)
     extras = []
     if cache > 0:
         extras.append(f"+{cache:.03f}GB cache")
@@ -4,8 +4,6 @@ from __future__ import annotations
 
 import logging
 import os
-from shutil import copyfile
-from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Dict, List
 
 import evaluate
@@ -126,36 +124,6 @@ class GPUStatsCallback(
         return control
 
 
-class LossWatchDogCallback(TrainerCallback):
-    """Callback to track loss and stop training if loss is too high"""
-
-    def __init__(self, cfg):
-        self.cfg = cfg
-        self.logged = False
-        self.violations = 0
-        self.threshold = cfg.loss_watchdog_threshold
-        self.patience = cfg.loss_watchdog_patience or 3
-
-    def on_step_end(
-        self,
-        _args: TrainingArguments,
-        state: TrainerState,
-        control: TrainerControl,
-        **_kwargs,
-    ):
-        if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
-            if state.log_history[-1]["loss"] > self.threshold:
-                self.violations += 1
-                if self.violations >= self.patience:
-                    LOG.warning(
-                        "Loss is too high, stopping training (loss_watchdog_threshold)"
-                    )
-                    control.should_training_stop = True
-            else:
-                self.violations = 0
-        return control
-
-
 def bench_eval_callback_factory(trainer, tokenizer):
     accuracy = evaluate.load("accuracy")
     abcd_idx = [
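For clarity on what the removed watchdog did, the stopping rule reduces to "stop after `patience` consecutive steps whose loss exceeds `threshold`". A minimal standalone sketch of that rule:

```python
# Minimal sketch of the loss-watchdog rule from the callback removed above.
def should_stop(loss_history, threshold=10.0, patience=3):
    violations = 0
    for loss in loss_history:
        violations = violations + 1 if loss > threshold else 0
        if violations >= patience:
            return True
    return False

print(should_stop([2.1, 11.0, 12.5, 13.0]))  # True: three consecutive violations
```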
@@ -563,15 +531,10 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
     ):
         if is_main_process():
             try:
-                # sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
-                with NamedTemporaryFile(
-                    mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
-                ) as temp_file:
-                    copyfile(self.axolotl_config_path, temp_file.name)
-                    wandb.save(temp_file.name)
-                LOG.info(
-                    "The Axolotl config has been saved to the WandB run under files."
-                )
+                artifact = wandb.Artifact(name="axolotl-config", type="config")
+                artifact.add_file(local_path=self.axolotl_config_path)
+                wandb.run.log_artifact(artifact)
+                LOG.info("Axolotl config has been saved to WandB as an artifact.")
             except (FileNotFoundError, ConnectionError) as err:
                 LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
         return control
@@ -1,29 +0,0 @@
"""
This module provides functionality for selecting chat templates based on user choices.
These templates are used for formatting messages in a conversation.
"""


def chat_templates(user_choice: str):
    """
    Finds the correct chat_template for the tokenizer_config.

    Args:
        user_choice (str): The user's choice of template.

    Returns:
        str: The chosen template string.

    Raises:
        ValueError: If the user_choice is not found in the templates.
    """

    templates = {
        "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # I don't know what this one is called. Used by Mistral/Mixtral.
        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
    }

    if user_choice in templates:
        return templates[user_choice]

    raise ValueError(f"Template '{user_choice}' not found.")
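Typical use of a helper like the one removed above is to set `tokenizer.chat_template` and then render with `apply_chat_template`. A hedged sketch (the model id is only an example, and the `chat_templates` import path is assumed):

```python
# Hedged usage sketch for the chat_templates helper shown above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # example model id
tokenizer.chat_template = chat_templates("chatml")
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "hi"}], tokenize=False, add_generation_prompt=True
)
```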
@@ -2,16 +2,12 @@
 DataCollator for axolotl to pad labels and position_ids for packed sequences
 """
 from dataclasses import dataclass
-from typing import Any, Dict, Optional, Sequence, Union
+from typing import Any, Optional, Union
 
 import numpy as np
-import torch
-import transformers
 from transformers import PreTrainedTokenizerBase
 from transformers.utils import PaddingStrategy
 
-IGNORE_INDEX = -100
-
 
 @dataclass
 class DataCollatorForSeq2Seq:
@@ -123,79 +119,3 @@ class DataCollatorForSeq2Seq:
             features["decoder_input_ids"] = decoder_input_ids
 
         return features
-
-
-@dataclass
-class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
-    """
-    Collator for multipack specific to the using the BatchSampler
-    """
-
-    def __call__(self, features, return_tensors=None):
-        chunked_data = {}
-        for feature in features[0].keys():
-            if feature == "length":
-                continue
-            if feature == "attention_mask":
-                arrays = [
-                    (1) * np.array(item[feature])
-                    for item in features
-                    if feature in item
-                ]
-                chunked_data[feature] = np.concatenate(arrays)
-            else:
-                arrays = [
-                    np.array(item[feature]) for item in features if feature in item
-                ]
-                chunked_data[feature] = np.concatenate(arrays)
-        features = [chunked_data]
-        return super().__call__(features, return_tensors=return_tensors)
-
-
-@dataclass
-class MambaDataCollator:
-    """
-    Collator for State Space Models (Mamba)
-    """
-
-    tokenizer: transformers.PreTrainedTokenizer
-
-    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
-        input_ids, labels = tuple(
-            [torch.LongTensor(instance[key]) for instance in instances]
-            for key in ("input_ids", "labels")
-        )
-        input_ids = torch.nn.utils.rnn.pad_sequence(
-            input_ids,
-            batch_first=True,
-            padding_value=self.tokenizer.pad_token_id,
-        )
-        labels = torch.nn.utils.rnn.pad_sequence(
-            labels, batch_first=True, padding_value=IGNORE_INDEX
-        )
-
-        return {
-            "input_ids": input_ids,
-            "labels": labels,
-        }
-
-
-@dataclass
-class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
-    """
-    Collator for multipack specific to the using the BatchSampler
-    """
-
-    def __call__(self, features, return_tensors=None):
-        chunked_data = {}
-        for feature in features.keys():
-            if feature == "length":
-                continue
-            if feature == "attention_mask":
-                arrays = [(1) * np.array(item) for item in features[feature]]
-                chunked_data[feature] = np.concatenate(arrays)
-            else:
-                arrays = [np.array(item) for item in features[feature]]
-                chunked_data[feature] = np.concatenate(arrays)
-        features = [chunked_data]
-        return super().__call__(features, return_tensors=return_tensors)
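The multipack collators shown above all follow the same idea: concatenate each feature across the sampled examples into one long row before the base collator pads it. A standalone sketch of that flattening step:

```python
# Minimal sketch of the multipack flattening performed by the collators above.
import numpy as np

features = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
chunked = {"input_ids": np.concatenate([np.array(f["input_ids"]) for f in features])}
print(chunked["input_ids"])  # [1 2 3 4 5]
```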
@@ -27,7 +27,7 @@ def choose_device(cfg):
 
     cfg.device = get_device()
     if cfg.world_size == 1:
-        cfg.device_map = cfg.device_map or "auto"
+        cfg.device_map = "auto"
     else:
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": torch.cuda.current_device()}
@@ -77,15 +77,6 @@ def normalize_config(cfg):
     else:
         cfg.torch_dtype = torch.float32
 
-    if cfg.saves_per_epoch:
-        save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
-        if save_steps < 1.0:  # prevent saves on every step
-            cfg.save_steps = save_steps
-    if cfg.evals_per_epoch:
-        eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
-        if eval_steps < 1.0:  # prevent evals on every step
-            cfg.eval_steps = eval_steps
-
     cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
 
     if not cfg.base_model_config:
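A quick worked example of the fractional conversion removed above, since a float `save_steps` in (0, 1) is interpreted as a fraction of total training steps:

```python
# Worked example of the saves_per_epoch -> save_steps conversion.
saves_per_epoch, num_epochs = 4, 3
save_steps = 1.0 / (saves_per_epoch * num_epochs)
print(save_steps)  # 0.0833...: a checkpoint every 1/12 of the run, i.e. 4 per epoch over 3 epochs
```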
@@ -131,19 +122,6 @@ def normalize_config(cfg):
         or (cfg.model_type and "mistral" in cfg.model_type.lower())
     )
 
-    cfg.is_qwen_derived_model = (
-        (
-            hasattr(model_config, "model_type")
-            and model_config.model_type
-            in [
-                "qwen",
-            ]
-        )
-        or cfg.is_qwen_derived_model
-        or "qwen" in cfg.base_model.lower()
-        or (cfg.model_type and "qwen" in cfg.model_type.lower())
-    )
-
     if isinstance(cfg.learning_rate, str):
         cfg.learning_rate = float(cfg.learning_rate)
 
@@ -187,11 +165,7 @@ def validate_config(cfg):
             "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
             "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
         )
-    if (
-        cfg.eval_batch_size
-        and cfg.micro_batch_size
-        and cfg.eval_batch_size != cfg.micro_batch_size
-    ):
+    if cfg.eval_batch_size != cfg.micro_batch_size:
         LOG.warning(
             "eval_batch_size != micro_batch_size. This can lead to VRAM instability."
         )
@@ -361,27 +335,6 @@ def validate_config(cfg):
|
|||||||
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
||||||
"sharegpt_simple", "sharegpt"
|
"sharegpt_simple", "sharegpt"
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.saves_per_epoch and cfg.save_steps:
|
|
||||||
raise ValueError(
|
|
||||||
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
|
||||||
)
|
|
||||||
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
|
|
||||||
raise ValueError(
|
|
||||||
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
|
||||||
)
|
|
||||||
if cfg.evals_per_epoch and cfg.eval_steps:
|
|
||||||
raise ValueError(
|
|
||||||
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
cfg.evals_per_epoch
|
|
||||||
and cfg.evaluation_strategy
|
|
||||||
and cfg.evaluation_strategy != "steps"
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
|
||||||
)
|
|
||||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||||
@@ -416,52 +369,10 @@ def validate_config(cfg):
|
|||||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.rope_scaling:
|
if cfg.tensor_parallel and cfg.gradient_checkpointing:
|
||||||
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
|
||||||
|
|
||||||
if cfg.warmup_steps and cfg.warmup_ratio:
|
|
||||||
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
|
|
||||||
|
|
||||||
if cfg.wandb_run_id and not cfg.wandb_name:
|
|
||||||
cfg.wandb_name = cfg.wandb_run_id
|
|
||||||
|
|
||||||
LOG.warning(
|
|
||||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.noisy_embedding_alpha is not None:
|
|
||||||
# Deprecated, use neftune_noise_alpha
|
|
||||||
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
|
||||||
if cfg.neftune_noise_alpha is None:
|
|
||||||
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
|
|
||||||
else:
|
|
||||||
# User is providing both; bail and have them sort out their settings
|
|
||||||
raise ValueError(
|
|
||||||
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
|
||||||
raise ValueError("neftune_noise_alpha must be > 0.0")
|
|
||||||
|
|
||||||
if (
|
|
||||||
cfg.adapter
|
|
||||||
and cfg.tokens
|
|
||||||
and (
|
|
||||||
not cfg.lora_modules_to_save
|
|
||||||
or not all(
|
|
||||||
x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
|
"TensorParallelPreTrainedModel does not support gradient checkpointing"
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -2,9 +2,8 @@
|
|||||||
import functools
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import (
|
from datasets import (
|
||||||
@@ -15,7 +14,6 @@ from datasets import (
|
|||||||
load_from_disk,
|
load_from_disk,
|
||||||
)
|
)
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from torch.utils.data import RandomSampler
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
@@ -36,19 +34,15 @@ from axolotl.prompters import (
|
|||||||
JeopardyPrompter,
|
JeopardyPrompter,
|
||||||
MultipleChoiceConcisePrompter,
|
MultipleChoiceConcisePrompter,
|
||||||
MultipleChoiceExplainPrompter,
|
MultipleChoiceExplainPrompter,
|
||||||
Prompter,
|
|
||||||
ReflectAlpacaPrompter,
|
ReflectAlpacaPrompter,
|
||||||
SummarizeTLDRPrompter,
|
SummarizeTLDRPrompter,
|
||||||
UnsupportedPrompter,
|
UnsupportedPrompter,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
from axolotl.utils.samplers.multipack import MultipackBatchSampler
|
|
||||||
from axolotl.utils.trainer import (
|
from axolotl.utils.trainer import (
|
||||||
calculate_total_num_steps,
|
calculate_total_num_steps,
|
||||||
process_datasets_for_packing,
|
process_datasets_for_packing,
|
||||||
process_pretraining_datasets_for_packing,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
@@ -69,17 +63,9 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
path = cfg.pretraining_dataset
|
|
||||||
name = None
|
|
||||||
if isinstance(cfg.pretraining_dataset, dict):
|
|
||||||
path = cfg.pretraining_dataset["path"]
|
|
||||||
name = cfg.pretraining_dataset["name"]
|
|
||||||
|
|
||||||
train_dataset = load_pretraining_dataset(
|
train_dataset = load_pretraining_dataset(
|
||||||
path,
|
cfg.pretraining_dataset,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg,
|
|
||||||
name=name,
|
|
||||||
max_tokens=cfg.sequence_len,
|
max_tokens=cfg.sequence_len,
|
||||||
seed=cfg.seed or 42,
|
seed=cfg.seed or 42,
|
||||||
)
|
)
|
||||||
@@ -92,27 +78,19 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||||
cfg, train_dataset, eval_dataset, tokenizer
|
cfg, train_dataset, eval_dataset, tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
|
||||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
|
||||||
if total_eval_steps == 0:
|
|
||||||
raise ValueError(
|
|
||||||
"eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.max_steps:
|
if cfg.max_steps:
|
||||||
total_num_steps = min(
|
total_num_steps = min(
|
||||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
|
||||||
)
|
)
|
||||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||||
else:
|
else:
|
||||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||||
return train_dataset, eval_dataset, total_num_steps, prompters
|
return train_dataset, eval_dataset, total_num_steps, prompters
|
||||||
|
|
||||||
|
|
||||||
def load_tokenized_prepared_datasets(
|
def load_tokenized_prepared_datasets(
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
tokenizer, cfg, default_dataset_prepared_path
|
||||||
) -> Tuple[DatasetDict, List[Prompter]]:
|
) -> DatasetDict:
|
||||||
tokenizer_name = tokenizer.__class__.__name__
|
tokenizer_name = tokenizer.__class__.__name__
|
||||||
ds_hash = str(
|
ds_hash = str(
|
||||||
md5(
|
md5(
|
||||||
@@ -120,12 +98,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
str(cfg.sequence_len)
|
str(cfg.sequence_len)
|
||||||
+ "@"
|
+ "@"
|
||||||
+ "|".join(
|
+ "|".join(
|
||||||
sorted(
|
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
||||||
[
|
|
||||||
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
|
|
||||||
for d in cfg.datasets
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
+ "|"
|
+ "|"
|
||||||
+ tokenizer_name
|
+ tokenizer_name
|
||||||
@@ -191,66 +164,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
except (FileNotFoundError, ConnectionError):
|
except (FileNotFoundError, ConnectionError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
ds_from_cloud = False
|
|
||||||
storage_options = {}
|
|
||||||
remote_file_system = None
|
|
||||||
if config_dataset.path.startswith("s3://"):
|
|
||||||
try:
|
|
||||||
import aiobotocore.session # type: ignore
|
|
||||||
import s3fs # type: ignore
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"s3:// paths require aiobotocore and s3fs to be installed"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
# Takes credentials from ~/.aws/credentials for default profile
|
|
||||||
s3_session = aiobotocore.session.AioSession(profile="default")
|
|
||||||
storage_options = {"session": s3_session}
|
|
||||||
remote_file_system = s3fs.S3FileSystem(**storage_options)
|
|
||||||
elif config_dataset.path.startswith(
|
|
||||||
"gs://"
|
|
||||||
) or config_dataset.path.startswith("gcs://"):
|
|
||||||
try:
|
|
||||||
import gcsfs # type: ignore
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"gs:// or gcs:// paths require gcsfs to be installed"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
# gcsfs will use default credentials from the environment else anon
|
|
||||||
# https://gcsfs.readthedocs.io/en/latest/#credentials
|
|
||||||
storage_options = {"token": None}
|
|
||||||
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
|
|
||||||
# TODO: Figure out how to get auth creds passed
|
|
||||||
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
|
|
||||||
# try:
|
|
||||||
# import adlfs
|
|
||||||
# except ImportError as exc:
|
|
||||||
# raise ImportError(
|
|
||||||
# "adl:// or abfs:// paths require adlfs to be installed"
|
|
||||||
# ) from exc
|
|
||||||
|
|
||||||
# # Gen 1
|
|
||||||
# storage_options = {
|
|
||||||
# "tenant_id": TENANT_ID,
|
|
||||||
# "client_id": CLIENT_ID,
|
|
||||||
# "client_secret": CLIENT_SECRET,
|
|
||||||
# }
|
|
||||||
# # Gen 2
|
|
||||||
# storage_options = {
|
|
||||||
# "account_name": ACCOUNT_NAME,
|
|
||||||
# "account_key": ACCOUNT_KEY,
|
|
||||||
# }
|
|
||||||
|
|
||||||
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
|
|
||||||
try:
|
|
||||||
if remote_file_system and remote_file_system.exists(
|
|
||||||
config_dataset.path
|
|
||||||
):
|
|
||||||
ds_from_cloud = True
|
|
||||||
except (FileNotFoundError, ConnectionError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# prefer local dataset, even if hub exists
|
# prefer local dataset, even if hub exists
|
||||||
local_path = Path(config_dataset.path)
|
local_path = Path(config_dataset.path)
|
||||||
if local_path.exists():
|
if local_path.exists():
|
||||||
@@ -264,8 +177,17 @@ def load_tokenized_prepared_datasets(
|
|||||||
split=None,
|
split=None,
|
||||||
)
|
)
|
||||||
elif local_path.is_file():
|
elif local_path.is_file():
|
||||||
ds_type = get_ds_type(config_dataset)
|
ds_type = "json"
|
||||||
|
if config_dataset.ds_type:
|
||||||
|
ds_type = config_dataset.ds_type
|
||||||
|
elif ".parquet" in config_dataset.path:
|
||||||
|
ds_type = "parquet"
|
||||||
|
elif ".arrow" in config_dataset.path:
|
||||||
|
ds_type = "arrow"
|
||||||
|
elif ".csv" in config_dataset.path:
|
||||||
|
ds_type = "csv"
|
||||||
|
elif ".txt" in config_dataset.path:
|
||||||
|
ds_type = "text"
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
ds_type,
|
ds_type,
|
||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
@@ -285,22 +207,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
data_files=config_dataset.data_files,
|
data_files=config_dataset.data_files,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
elif ds_from_cloud and remote_file_system:
|
|
||||||
if remote_file_system.isdir(config_dataset.path):
|
|
||||||
ds = load_from_disk(
|
|
||||||
config_dataset.path,
|
|
||||||
storage_options=storage_options,
|
|
||||||
)
|
|
||||||
elif remote_file_system.isfile(config_dataset.path):
|
|
||||||
ds_type = get_ds_type(config_dataset)
|
|
||||||
ds = load_dataset(
|
|
||||||
ds_type,
|
|
||||||
name=config_dataset.name,
|
|
||||||
data_files=config_dataset.path,
|
|
||||||
streaming=False,
|
|
||||||
split=None,
|
|
||||||
storage_options=storage_options,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if isinstance(config_dataset.data_files, str):
|
if isinstance(config_dataset.data_files, str):
|
||||||
fp = hf_hub_download(
|
fp = hf_hub_download(
|
||||||
@@ -392,29 +298,11 @@ def load_tokenized_prepared_datasets(
|
|||||||
return dataset, prompters
|
return dataset, prompters
|
||||||
|
|
||||||
|
|
||||||
def get_ds_type(config_dataset: DictDefault):
|
|
||||||
"""
|
|
||||||
Get the dataset type from the path if it's not specified
|
|
||||||
"""
|
|
||||||
ds_type = "json"
|
|
||||||
if config_dataset.ds_type:
|
|
||||||
ds_type = config_dataset.ds_type
|
|
||||||
elif ".parquet" in config_dataset.path:
|
|
||||||
ds_type = "parquet"
|
|
||||||
elif ".arrow" in config_dataset.path:
|
|
||||||
ds_type = "arrow"
|
|
||||||
elif ".csv" in config_dataset.path:
|
|
||||||
ds_type = "csv"
|
|
||||||
elif ".txt" in config_dataset.path:
|
|
||||||
ds_type = "text"
|
|
||||||
return ds_type
|
|
||||||
|
|
||||||
|
|
||||||
def load_prepare_datasets(
|
def load_prepare_datasets(
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
cfg,
|
cfg,
|
||||||
default_dataset_prepared_path,
|
default_dataset_prepared_path,
|
||||||
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
) -> Tuple[Dataset, Dataset, List[Any]]:
|
||||||
max_packed_sequence_len = (
|
max_packed_sequence_len = (
|
||||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||||
)
|
)
|
||||||
@@ -423,7 +311,7 @@ def load_prepare_datasets(
|
|||||||
) # make sure we don't accidentally set it larger than sequence_len
|
) # make sure we don't accidentally set it larger than sequence_len
|
||||||
|
|
||||||
tokenizer_name = tokenizer.__class__.__name__
|
tokenizer_name = tokenizer.__class__.__name__
|
||||||
prompters: List[Prompter] = []
|
prompters = []
|
||||||
if cfg.max_packed_sequence_len is not None:
|
if cfg.max_packed_sequence_len is not None:
|
||||||
# see if we can go ahead and load the stacked dataset
|
# see if we can go ahead and load the stacked dataset
|
||||||
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
||||||
@@ -557,13 +445,14 @@ def load_prepare_datasets(
|
|||||||
train_fingerprint = md5(to_hash_train)
|
train_fingerprint = md5(to_hash_train)
|
||||||
test_fingerprint = md5(to_hash_test)
|
test_fingerprint = md5(to_hash_test)
|
||||||
|
|
||||||
dataset = dataset.train_test_split(
|
with zero_first(is_main_process()):
|
||||||
test_size=cfg.val_set_size,
|
dataset = dataset.train_test_split(
|
||||||
shuffle=False,
|
test_size=cfg.val_set_size,
|
||||||
seed=cfg.seed or 42,
|
shuffle=False,
|
||||||
train_new_fingerprint=train_fingerprint,
|
seed=cfg.seed or 42,
|
||||||
test_new_fingerprint=test_fingerprint,
|
train_new_fingerprint=train_fingerprint,
|
||||||
)
|
test_new_fingerprint=test_fingerprint,
|
||||||
|
)
|
||||||
|
|
||||||
train_dataset = dataset["train"]
|
train_dataset = dataset["train"]
|
||||||
eval_dataset = dataset["test"]
|
eval_dataset = dataset["test"]
|
||||||
@@ -593,14 +482,10 @@ def get_dataset_wrapper(
|
|||||||
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
||||||
)
|
)
|
||||||
dataset_prompter = UnsupportedPrompter()
|
dataset_prompter = UnsupportedPrompter()
|
||||||
dataset_wrapper = TokenizedPromptDataset(
|
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
|
||||||
)
|
|
||||||
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
||||||
dataset_prompter = UnsupportedPrompter()
|
dataset_prompter = UnsupportedPrompter()
|
||||||
dataset_wrapper = TokenizedPromptDataset(
|
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
|
||||||
)
|
|
||||||
elif d_base_type == "alpaca":
|
elif d_base_type == "alpaca":
|
||||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||||
@@ -609,9 +494,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
|
||||||
)
|
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "explainchoice":
|
elif d_base_type == "explainchoice":
|
||||||
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
|
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
|
||||||
@@ -621,9 +504,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
|
||||||
)
|
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "concisechoice":
|
elif d_base_type == "concisechoice":
|
||||||
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
|
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
|
||||||
@@ -633,9 +514,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
|
||||||
)
|
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "summarizetldr":
|
elif d_base_type == "summarizetldr":
|
||||||
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
|
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
|
||||||
@@ -645,9 +524,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
|
||||||
)
|
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "jeopardy":
|
elif d_base_type == "jeopardy":
|
||||||
dataset_prompter = JeopardyPrompter(d_prompt_style)
|
dataset_prompter = JeopardyPrompter(d_prompt_style)
|
||||||
@@ -657,9 +534,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
|
||||||
)
|
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "oasst":
|
elif d_base_type == "oasst":
|
||||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||||
@@ -669,9 +544,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
|
||||||
)
|
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "gpteacher":
|
elif d_base_type == "gpteacher":
|
||||||
dataset_prompter = GPTeacherPrompter(d_prompt_style)
|
dataset_prompter = GPTeacherPrompter(d_prompt_style)
|
||||||
@@ -681,9 +554,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
|
||||||
)
|
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "reflection":
|
elif d_base_type == "reflection":
|
||||||
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
|
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
|
||||||
@@ -693,9 +564,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
|
||||||
)
|
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
else:
|
else:
|
||||||
suffix = ""
|
suffix = ""
|
||||||
@@ -819,27 +688,9 @@ def encode_pretraining(
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
|
def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
|
||||||
if cfg.sample_packing:
|
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
||||||
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
|
dataset = load_dataset(path, streaming=True, split="train")
|
||||||
tokenizer,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding=True,
|
|
||||||
pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
|
|
||||||
)
|
|
||||||
encode = functools.partial(
|
|
||||||
encode_packed_pretraining,
|
|
||||||
tokenizer,
|
|
||||||
collate_fn,
|
|
||||||
max_seq_length=max_tokens,
|
|
||||||
batch_size=cfg.micro_batch_size,
|
|
||||||
)
|
|
||||||
# set this to 1 so downstream data_loader doesn't try to increase the batch again
|
|
||||||
cfg.micro_batch_size = 1
|
|
||||||
else:
|
|
||||||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
|
||||||
|
|
||||||
dataset = load_dataset(path, streaming=True, split="train", name=name)
|
|
||||||
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
|
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
|
||||||
dataset = dataset.map(
|
dataset = dataset.map(
|
||||||
encode,
|
encode,
|
||||||
@@ -850,63 +701,3 @@ def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, s
|
|||||||
remove_columns=dataset.features.keys(),
|
remove_columns=dataset.features.keys(),
|
||||||
)
|
)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def encode_packed_pretraining(
|
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
|
||||||
collate_fn,
|
|
||||||
examples: List[str],
|
|
||||||
max_seq_length: int = 2048,
|
|
||||||
batch_size: int = 4,
|
|
||||||
) -> Dict[str, List]:
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
# tokenize all the examples
|
|
||||||
# rows get split with stride (overlap)
|
|
||||||
res = tokenizer(
|
|
||||||
examples,
|
|
||||||
truncation=True,
|
|
||||||
max_length=max_seq_length - 1,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_overflowing_tokens=True,
|
|
||||||
stride=256,
|
|
||||||
)
|
|
||||||
|
|
||||||
input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
|
|
||||||
attention_mask = [seq + [1] for seq in res["attention_mask"]]
|
|
||||||
|
|
||||||
tokenized_examples = {
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"attention_mask": attention_mask,
|
|
||||||
}
|
|
||||||
|
|
||||||
train_dataset = Dataset.from_dict(tokenized_examples)
|
|
||||||
train_dataset = process_pretraining_datasets_for_packing(
|
|
||||||
train_dataset, max_seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
sampler = MultipackBatchSampler(
|
|
||||||
RandomSampler(train_dataset),
|
|
||||||
batch_size=batch_size,
|
|
||||||
drop_last=True,
|
|
||||||
batch_max_len=batch_size * max_seq_length,
|
|
||||||
lengths=(
|
|
||||||
train_dataset.data.column("position_ids")
|
|
||||||
.to_pandas()
|
|
||||||
.apply(lambda x: x[-1] + 1)
|
|
||||||
.values
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
chunked_data = defaultdict(list)
|
|
||||||
|
|
||||||
for data in sampler:
|
|
||||||
features = train_dataset[data]
|
|
||||||
features["labels"] = features["input_ids"].copy()
|
|
||||||
collated_features = collate_fn(features)
|
|
||||||
|
|
||||||
for feature in features.keys():
|
|
||||||
if feature == "length":
|
|
||||||
continue
|
|
||||||
chunked_data[feature].append(collated_features[feature].squeeze(0))
|
|
||||||
|
|
||||||
return chunked_data
|
|
||||||
|
|||||||
342
src/axolotl/utils/dataloader.py
Normal file
342
src/axolotl/utils/dataloader.py
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
# pylint: skip-file
|
||||||
|
import hashlib
|
||||||
|
import itertools
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
from queue import Queue
|
||||||
|
from threading import Thread
|
||||||
|
from typing import Any, Callable, List, Union
|
||||||
|
|
||||||
|
import numba
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.data import DistributedSampler, Sampler
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.utils.dataloader")
|
||||||
|
|
||||||
|
|
||||||
|
@numba.njit
|
||||||
|
def ffd_check(a: np.ndarray, c: int, n: int):
|
||||||
|
# First-fit-decreasing bin packing
|
||||||
|
# Check if a[] could fit in n bins with capacity c
|
||||||
|
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
||||||
|
|
||||||
|
a = np.sort(a)[::-1]
|
||||||
|
bins = np.full((n,), c, dtype=a.dtype)
|
||||||
|
for size in a:
|
||||||
|
not_found = True
|
||||||
|
for idx in range(n):
|
||||||
|
if bins[idx] >= size:
|
||||||
|
bins[idx] -= size
|
||||||
|
not_found = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if not_found:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@numba.njit
|
||||||
|
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
||||||
|
# First-fit-decreasing bin packing (with result return)
|
||||||
|
|
||||||
|
indices = np.argsort(a)[::-1]
|
||||||
|
a = a[indices]
|
||||||
|
|
||||||
|
bins: List[Any] = []
|
||||||
|
bins_result: List[Any] = []
|
||||||
|
for a_id, size in enumerate(a):
|
||||||
|
add_new = True
|
||||||
|
for idx in range(len(bins)):
|
||||||
|
if bins[idx] >= size:
|
||||||
|
bins[idx] -= size
|
||||||
|
bins_result[idx].append(indices[a_id] + start_index)
|
||||||
|
add_new = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if add_new:
|
||||||
|
bins.append(c - size)
|
||||||
|
bins_result.append([indices[a_id] + start_index])
|
||||||
|
|
||||||
|
return bins_result, len(a)
|
||||||
|
|
||||||
|
|
||||||
|
@numba.njit
|
||||||
|
def allocate(
|
||||||
|
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
:param lengths: array of lengths of each sample
|
||||||
|
:param lengths_cumsum: cumulative sum of consecutive lengths
|
||||||
|
:param rank: rank for this process
|
||||||
|
:param c: length of tokens per batch
|
||||||
|
:param n: number of ranks
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# Dynamic batch allocator, similar to Multifit
|
||||||
|
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
||||||
|
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
||||||
|
|
||||||
|
s = 0
|
||||||
|
start_index = 0
|
||||||
|
result = []
|
||||||
|
result_totseqs = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# binary search [left, right)
|
||||||
|
left = 1
|
||||||
|
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
||||||
|
|
||||||
|
while right - left > 1:
|
||||||
|
mid = (left + right) // 2
|
||||||
|
if ffd_check(lengths[start_index : start_index + mid], c, n):
|
||||||
|
left = mid
|
||||||
|
else:
|
||||||
|
right = mid
|
||||||
|
|
||||||
|
# use length left
|
||||||
|
batch, tot_seqs = ffd_with_result(
|
||||||
|
lengths[start_index : start_index + left], c, start_index
|
||||||
|
)
|
||||||
|
if len(batch) < n:
|
||||||
|
break
|
||||||
|
|
||||||
|
start_index += left
|
||||||
|
s = lengths_cumsum[start_index - 1]
|
||||||
|
|
||||||
|
# add local rank
|
||||||
|
result.append(batch[rank])
|
||||||
|
# add total seqs for all ranks
|
||||||
|
result_totseqs.append(tot_seqs)
|
||||||
|
# yield batch[rank], tot_seqs, s, len(result) * c * n
|
||||||
|
return result, result_totseqs, s, len(result) * c * n
|
||||||
|
|
||||||
|
|
||||||
|
def chunk(iterable, n):
|
||||||
|
"""
|
||||||
|
Chunk data into tuples of length n
|
||||||
|
"""
|
||||||
|
# batched('ABCDEFG', 3) --> ABC DEF G
|
||||||
|
if n < 1:
|
||||||
|
raise ValueError("n must be at least one")
|
||||||
|
it = iter(iterable)
|
||||||
|
while batch := tuple(itertools.islice(it, n)):
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
|
def hash_indices(lst: List[int]) -> str:
|
||||||
|
# Convert the list of integers to a string representation
|
||||||
|
concatenated = ",".join(map(str, lst))
|
||||||
|
|
||||||
|
# Generate the hash
|
||||||
|
sha256 = hashlib.sha256()
|
||||||
|
sha256.update(concatenated.encode())
|
||||||
|
|
||||||
|
return sha256.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
class MultipackDistributedDataloader:
|
||||||
|
"""Unpadded data loading using Multipack.
|
||||||
|
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
|
||||||
|
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset: Any,
|
||||||
|
collate_fn: Callable,
|
||||||
|
seq_max_length: int = 2048,
|
||||||
|
batch_size: int = 1,
|
||||||
|
sampler: Union[Sampler, DistributedSampler] = None,
|
||||||
|
packing_efficiency_estimate: float = 1.0,
|
||||||
|
sample_packing_seq_len_multiplier: int = 1,
|
||||||
|
device_count: int = 1,
|
||||||
|
prefetch_max: int = 1000,
|
||||||
|
num_epochs: int = 1,
|
||||||
|
):
|
||||||
|
# Dataset
|
||||||
|
self.dataset = dataset
|
||||||
|
self.lengths = (
|
||||||
|
dataset.data.column("position_ids")
|
||||||
|
.to_pandas()
|
||||||
|
.apply(lambda x: x[-1] + 1)
|
||||||
|
.values
|
||||||
|
)
|
||||||
|
assert isinstance(self.lengths, np.ndarray)
|
||||||
|
assert batch_size % sample_packing_seq_len_multiplier == 0
|
||||||
|
assert batch_size >= sample_packing_seq_len_multiplier
|
||||||
|
self.sampler = sampler
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
|
||||||
|
self.seq_max_length = seq_max_length
|
||||||
|
self.batch_max_length = batch_size * seq_max_length
|
||||||
|
self.collate_fn = collate_fn
|
||||||
|
self.num_epochs = num_epochs
|
||||||
|
|
||||||
|
self.num_replicas = 1
|
||||||
|
self.rank = 0
|
||||||
|
|
||||||
|
# statistics
|
||||||
|
self.eff_total_used = 0
|
||||||
|
self.eff_total_slots = 0
|
||||||
|
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||||
|
self.device_count = device_count
|
||||||
|
|
||||||
|
# maxsize is maximum number of samples in queue
|
||||||
|
self.prefetch_max = prefetch_max
|
||||||
|
self.queue: Queue = Queue(maxsize=prefetch_max)
|
||||||
|
self.thread = None
|
||||||
|
|
||||||
|
def _worker(self):
|
||||||
|
LOG.info(
|
||||||
|
f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
|
||||||
|
)
|
||||||
|
for epoch in range(self.num_epochs):
|
||||||
|
for sample in self._internal_batch_generator():
|
||||||
|
while True:
|
||||||
|
if self.queue.full():
|
||||||
|
time.sleep(1)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
self.queue.put(sample)
|
||||||
|
|
||||||
|
# stop the queue when epoch is done
|
||||||
|
self.queue.put(None)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
if hasattr(self.sampler, "set_epoch"):
|
||||||
|
new_epoch = self.sampler.epoch + 1
|
||||||
|
self.sampler.set_epoch(new_epoch)
|
||||||
|
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||||
|
|
||||||
|
if self.thread is None:
|
||||||
|
self.thread = Thread(target=self._worker, daemon=True)
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
item = self.queue.get()
|
||||||
|
|
||||||
|
if item is None:
|
||||||
|
break
|
||||||
|
yield item
|
||||||
|
|
||||||
|
def generate_batches(self, set_stats=False):
|
||||||
|
LOG.info("generating packed batches")
|
||||||
|
if self.sampler:
|
||||||
|
indices = [idx for idx in self.sampler]
|
||||||
|
else:
|
||||||
|
indices = range(0, len(self.dataset))
|
||||||
|
|
||||||
|
LOG.info(hash_indices(indices))
|
||||||
|
lengths = self.lengths[indices]
|
||||||
|
lengths_cumsum = np.cumsum(lengths)
|
||||||
|
|
||||||
|
batches, totseqs, total_used, total_slots = allocate(
|
||||||
|
lengths=lengths,
|
||||||
|
lengths_cumsum=lengths_cumsum,
|
||||||
|
rank=self.rank,
|
||||||
|
# c=self.batch_max_length,
|
||||||
|
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
|
||||||
|
n=self.num_replicas,
|
||||||
|
)
|
||||||
|
|
||||||
|
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
||||||
|
|
||||||
|
# statistics
|
||||||
|
if set_stats:
|
||||||
|
self.eff_total_used += total_used
|
||||||
|
self.eff_total_slots += total_slots
|
||||||
|
|
||||||
|
return batches, totseqs
|
||||||
|
|
||||||
|
def _internal_batch_generator(self):
|
||||||
|
all_batches, _ = self.generate_batches(set_stats=True)
|
||||||
|
features = self.dataset.features.keys()
|
||||||
|
len_remaining = self._len_est()
|
||||||
|
for batches in chunk(
|
||||||
|
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
|
||||||
|
):
|
||||||
|
chunked_data = []
|
||||||
|
attn_mask_cum_idx = 0
|
||||||
|
for batch in batches:
|
||||||
|
concatenated = {}
|
||||||
|
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
||||||
|
for feature in features:
|
||||||
|
if feature == "length":
|
||||||
|
continue
|
||||||
|
if feature == "attention_mask":
|
||||||
|
arrays = [
|
||||||
|
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
|
||||||
|
for idx, item in enumerate(batched_data)
|
||||||
|
if feature in item
|
||||||
|
]
|
||||||
|
attn_mask_cum_idx += len(batched_data)
|
||||||
|
concatenated[feature] = np.concatenate(arrays)
|
||||||
|
else:
|
||||||
|
arrays = [
|
||||||
|
np.array(item[feature])
|
||||||
|
for item in batched_data
|
||||||
|
if feature in item
|
||||||
|
]
|
||||||
|
concatenated[feature] = np.concatenate(arrays)
|
||||||
|
chunked_data.append(concatenated)
|
||||||
|
yield self.collate_fn(chunked_data)
|
||||||
|
len_remaining -= 1
|
||||||
|
if not len_remaining:
|
||||||
|
return
|
||||||
|
# yield a no-op for cases where we don't have any data left to pack
|
||||||
|
for i in range(0, len_remaining):
|
||||||
|
yield self.collate_fn(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"input_ids": [0],
|
||||||
|
"labels": [-100],
|
||||||
|
"attention_mask": [True],
|
||||||
|
"position_ids": [0],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _len_est(self):
|
||||||
|
lengths_sum = np.sum(self.lengths)
|
||||||
|
lengths_sum_per_device = lengths_sum // self.device_count
|
||||||
|
LOG.info(
|
||||||
|
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||||
|
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
||||||
|
return (
|
||||||
|
math.floor(
|
||||||
|
0.99
|
||||||
|
* lengths_sum_per_device
|
||||||
|
/ self.packing_efficiency_estimate
|
||||||
|
// self.seq_max_length
|
||||||
|
// self.batch_size
|
||||||
|
)
|
||||||
|
- 1
|
||||||
|
)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
|
||||||
|
# the same share of total tokens
|
||||||
|
# if not self.eff_total_used:
|
||||||
|
# batches, _ = self.generate_batches(set_stats=True)
|
||||||
|
# LOG.info(
|
||||||
|
# f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||||
|
# f"actual packing efficiency: {self.efficiency()}"
|
||||||
|
# )
|
||||||
|
return max(1, self._len_est())
|
||||||
|
|
||||||
|
def len_w_stats(self):
|
||||||
|
if not self.eff_total_used:
|
||||||
|
batches, _ = self.generate_batches(set_stats=True)
|
||||||
|
LOG.info(
|
||||||
|
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||||
|
f"actual packing efficiency: {self.efficiency()}"
|
||||||
|
)
|
||||||
|
return max(1, self._len_est())
|
||||||
|
|
||||||
|
def efficiency(self):
|
||||||
|
return self.eff_total_used / self.eff_total_slots
|
||||||
@@ -50,17 +50,6 @@ def get_world_size():
|
|||||||
return int(os.getenv("WORLD_SIZE", "1"))
|
return int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def zero_only():
|
|
||||||
"""
|
|
||||||
Context manager that only runs the enclosed block on the main rank.
|
|
||||||
"""
|
|
||||||
if is_main_process():
|
|
||||||
yield
|
|
||||||
else:
|
|
||||||
yield None
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def zero_first(is_main):
|
def zero_first(is_main):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,38 +0,0 @@
|
|||||||
"""
|
|
||||||
module to freeze/unfreeze parameters by name
|
|
||||||
"""
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.freeze")
|
|
||||||
|
|
||||||
|
|
||||||
def freeze_parameters_except(model, regex_patterns):
|
|
||||||
"""
|
|
||||||
Freezes all layers of the given model except for the layers that match given regex patterns.
|
|
||||||
Periods in the patterns are treated as literal periods, not as wildcard characters.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- model (nn.Module): The PyTorch model to be modified.
|
|
||||||
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None; the model is modified in place.
|
|
||||||
"""
|
|
||||||
# Escape periods and compile the regex patterns
|
|
||||||
compiled_patterns = [
|
|
||||||
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
|
|
||||||
]
|
|
||||||
|
|
||||||
# First, freeze all parameters in the model
|
|
||||||
for param in model.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
# Unfreeze layers that match the regex patterns
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
if any(pattern.match(name) for pattern in compiled_patterns):
|
|
||||||
if is_main_process():
|
|
||||||
LOG.debug(f"unfreezing {name}")
|
|
||||||
param.requires_grad = True
|
|
||||||
@@ -2,12 +2,12 @@
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from typing import Any, Optional, Tuple # noqa: F401
|
from typing import Optional, Tuple # noqa: F401
|
||||||
|
|
||||||
import addict
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
import transformers.utils.bitsandbytes
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
from peft import PeftConfig, prepare_model_for_kbit_training
|
from peft import PeftConfig, prepare_model_for_kbit_training
|
||||||
from peft.tuners.lora import QuantLinear
|
from peft.tuners.lora import QuantLinear
|
||||||
@@ -18,65 +18,24 @@ from transformers import ( # noqa: F401
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
|
LlamaConfig,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
)
|
)
|
||||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
|
||||||
|
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
|
||||||
quant_config_exists = hasattr(model_config, "quantization_config")
|
|
||||||
quant_config_method_is_gptq = (
|
|
||||||
quant_config_exists
|
|
||||||
and "quant_method" in model_config.quantization_config
|
|
||||||
and model_config.quantization_config["quant_method"] == "gptq"
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.gptq and not quant_config_method_is_gptq:
|
|
||||||
raise ValueError(
|
|
||||||
"model_config.quantization_config is not set or quant_method is not set to gptq. "
|
|
||||||
"Please make sure to point to a GPTQ model."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not cfg.gptq and quant_config_exists:
|
|
||||||
raise ValueError(
|
|
||||||
"model_config.quantization_config is set but `gptq` flag is not. "
|
|
||||||
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_config(cfg):
|
def load_model_config(cfg):
|
||||||
model_config_name = cfg.base_model_config or cfg.base_model
|
model_config_name = cfg.base_model_config or cfg.base_model
|
||||||
trust_remote_code = cfg.trust_remote_code is True
|
trust_remote_code = cfg.trust_remote_code is True
|
||||||
|
return AutoConfig.from_pretrained(
|
||||||
try:
|
model_config_name, trust_remote_code=trust_remote_code
|
||||||
model_config = AutoConfig.from_pretrained(
|
)
|
||||||
model_config_name, trust_remote_code=trust_remote_code
|
|
||||||
)
|
|
||||||
except ValueError as err:
|
|
||||||
if "mamba" in model_config_name:
|
|
||||||
return addict.Dict(
|
|
||||||
{
|
|
||||||
"model_type": "mamba",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
raise err
|
|
||||||
|
|
||||||
if cfg.model_config:
|
|
||||||
for key, val in cfg.model_config.items():
|
|
||||||
setattr(model_config, key, val)
|
|
||||||
|
|
||||||
check_model_config(cfg, model_config)
|
|
||||||
|
|
||||||
return model_config
|
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(cfg):
|
def load_tokenizer(cfg):
|
||||||
@@ -93,7 +52,7 @@ def load_tokenizer(cfg):
|
|||||||
if cfg.tokenizer_type:
|
if cfg.tokenizer_type:
|
||||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||||
|
|
||||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
|
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||||
tokenizer = tokenizer_cls.from_pretrained(
|
tokenizer = tokenizer_cls.from_pretrained(
|
||||||
tokenizer_config,
|
tokenizer_config,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
@@ -107,7 +66,6 @@ def load_tokenizer(cfg):
|
|||||||
"LlamaTokenizer",
|
"LlamaTokenizer",
|
||||||
"LlamaTokenizerFast",
|
"LlamaTokenizerFast",
|
||||||
"CodeLlamaTokenizer",
|
"CodeLlamaTokenizer",
|
||||||
"CodeLlamaTokenizerFast",
|
|
||||||
]
|
]
|
||||||
and hasattr(tokenizer, "pad_token")
|
and hasattr(tokenizer, "pad_token")
|
||||||
and not tokenizer.pad_token
|
and not tokenizer.pad_token
|
||||||
@@ -123,57 +81,11 @@ def load_tokenizer(cfg):
|
|||||||
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
|
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
|
||||||
tokenizer.padding_side = "left"
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
# Qwen base only has single token, so we need to set the special tokens
|
|
||||||
if cfg.is_qwen_derived_model:
|
|
||||||
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
|
|
||||||
for attr_name in token_ids:
|
|
||||||
if getattr(tokenizer, attr_name) is None:
|
|
||||||
setattr(tokenizer, attr_name, tokenizer.eod_id)
|
|
||||||
|
|
||||||
token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
|
|
||||||
for attr_name in token_names:
|
|
||||||
if getattr(tokenizer, attr_name) is None:
|
|
||||||
setattr(tokenizer, attr_name, "<|endoftext|>")
|
|
||||||
|
|
||||||
if cfg.special_tokens:
|
if cfg.special_tokens:
|
||||||
for k, val in cfg.special_tokens.items():
|
for k, val in cfg.special_tokens.items():
|
||||||
# check if new special token is not already in tokenizer and
|
|
||||||
# is adapter training to make sure lora_modules_to_save is set
|
|
||||||
if (
|
|
||||||
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
|
|
||||||
and cfg.adapter
|
|
||||||
and (
|
|
||||||
not cfg.lora_modules_to_save
|
|
||||||
or not all(
|
|
||||||
x in cfg.lora_modules_to_save
|
|
||||||
for x in ["embed_tokens", "lm_head"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
|
|
||||||
)
|
|
||||||
|
|
||||||
tokenizer.add_special_tokens(
|
tokenizer.add_special_tokens(
|
||||||
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
||||||
)
|
)
|
||||||
|
|
||||||
# If we add bos_token and eos_token, we need to update the post processor to
|
|
||||||
# handle them correctly.
|
|
||||||
# https://github.com/huggingface/transformers/pull/24132
|
|
||||||
bos_or_eos_in_special_tokens = (
|
|
||||||
"bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
tokenizer.__class__.__name__
|
|
||||||
in (
|
|
||||||
"LlamaTokenizerFast",
|
|
||||||
"CodeLlamaTokenizerFast",
|
|
||||||
)
|
|
||||||
and bos_or_eos_in_special_tokens
|
|
||||||
):
|
|
||||||
tokenizer.update_post_processor()
|
|
||||||
|
|
||||||
if cfg.tokens:
|
if cfg.tokens:
|
||||||
tokenizer.add_tokens(
|
tokenizer.add_tokens(
|
||||||
[
|
[
|
||||||
@@ -187,12 +99,6 @@ def load_tokenizer(cfg):
|
|||||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||||
|
|
||||||
if cfg.chat_template:
|
|
||||||
tokenizer.chat_template = chat_templates(cfg.chat_template)
|
|
||||||
else:
|
|
||||||
LOG.info(
|
|
||||||
"No Chat template selected. Consider adding a chat template for easier inference."
|
|
||||||
)
|
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@@ -200,12 +106,12 @@ def load_model(
|
|||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
inference: bool = False,
|
inference: bool = False,
|
||||||
reference_model: bool = False,
|
|
||||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||||
"""
|
"""
|
||||||
Load a model for a given configuration and tokenizer.
|
Load a model for a given configuration and tokenizer.
|
||||||
"""
|
"""
|
||||||
base_model = cfg.base_model
|
base_model = cfg.base_model
|
||||||
|
base_model_config = cfg.base_model_config
|
||||||
model_type = cfg.model_type
|
model_type = cfg.model_type
|
||||||
model_config = load_model_config(cfg)
|
model_config = load_model_config(cfg)
|
||||||
|
|
||||||
@@ -255,6 +161,17 @@ def load_model(
|
|||||||
|
|
||||||
LOG.info("patching with sdp attention")
|
LOG.info("patching with sdp attention")
|
||||||
hijack_llama_sdp_attention()
|
hijack_llama_sdp_attention()
|
||||||
|
elif cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||||
|
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||||
|
MEM_TOKEN,
|
||||||
|
patch_llama_with_landmark_attn,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("patching with landmark attention")
|
||||||
|
patch_llama_with_landmark_attn()
|
||||||
|
|
||||||
|
# Note: This might overwrite previous additional_special_tokens
|
||||||
|
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
||||||
|
|
||||||
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
|
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
|
||||||
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
||||||
@@ -264,17 +181,13 @@ def load_model(
|
|||||||
LOG.info("patching with flash attention")
|
LOG.info("patching with flash attention")
|
||||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||||
|
|
||||||
if (
|
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||||
cfg.model_config_type == "mixtral"
|
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||||
and cfg.flash_attention
|
replace_llama_rope_with_xpos_rope,
|
||||||
and cfg.sample_packing
|
|
||||||
):
|
|
||||||
from axolotl.monkeypatch.mixtral import (
|
|
||||||
replace_mixtral_attn_with_multipack_flash_attn,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.info("patching with flash attention")
|
LOG.info("patching with xpos rope")
|
||||||
replace_mixtral_attn_with_multipack_flash_attn()
|
replace_llama_rope_with_xpos_rope()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
cfg.is_llama_derived_model
|
cfg.is_llama_derived_model
|
||||||
@@ -288,50 +201,8 @@ def load_model(
|
|||||||
|
|
||||||
model_kwargs = {}
|
model_kwargs = {}
|
||||||
|
|
||||||
max_memory = cfg.max_memory
|
model_kwargs["device_map"] = cfg.device_map
|
||||||
device_map = cfg.device_map
|
|
||||||
|
|
||||||
if cfg.gpu_memory_limit:
|
|
||||||
gpu_memory_limit = (
|
|
||||||
str(cfg.gpu_memory_limit) + "GiB"
|
|
||||||
if isinstance(cfg.gpu_memory_limit, int)
|
|
||||||
else cfg.gpu_memory_limit
|
|
||||||
)
|
|
||||||
|
|
||||||
max_memory = {}
|
|
||||||
for i in range(torch.cuda.device_count()):
|
|
||||||
max_memory[i] = gpu_memory_limit
|
|
||||||
max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything
|
|
||||||
|
|
||||||
if max_memory is not None:
|
|
||||||
# Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
|
|
||||||
from accelerate import infer_auto_device_map, init_empty_weights
|
|
||||||
|
|
||||||
with init_empty_weights():
|
|
||||||
model_canvas = AutoModelForCausalLM.from_config(model_config)
|
|
||||||
model_canvas.tie_weights()
|
|
||||||
device_map = infer_auto_device_map(
|
|
||||||
model_canvas,
|
|
||||||
max_memory=max_memory,
|
|
||||||
dtype=cfg.torch_dtype,
|
|
||||||
)
|
|
||||||
# We can discard max_memory now as we have a device map set up for us
|
|
||||||
max_memory = None
|
|
||||||
|
|
||||||
model_kwargs["device_map"] = device_map
|
|
||||||
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
||||||
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
|
|
||||||
# if cfg.rl:
|
|
||||||
# if torch.cuda.device_count() > 1:
|
|
||||||
# if reference_model:
|
|
||||||
# model_kwargs["device_map"] = "cuda:" + str(
|
|
||||||
# torch.cuda.current_device() + 1
|
|
||||||
# )
|
|
||||||
# else:
|
|
||||||
# model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device())
|
|
||||||
|
|
||||||
if is_deepspeed_zero3_enabled():
|
|
||||||
del model_kwargs["device_map"]
|
|
||||||
|
|
||||||
if cfg.model_revision:
|
if cfg.model_revision:
|
||||||
model_kwargs["revision"] = cfg.model_revision
|
model_kwargs["revision"] = cfg.model_revision
|
||||||
@@ -347,53 +218,42 @@ def load_model(
                 **model_config.quantization_config
             )
     if cfg.adapter == "qlora" and cfg.load_in_4bit:
-        bnb_config = {
-            "load_in_4bit": True,
-            "llm_int8_threshold": 6.0,
-            "llm_int8_has_fp16_weight": False,
-            "bnb_4bit_compute_dtype": cfg.torch_dtype,
-            "bnb_4bit_use_double_quant": True,
-            "bnb_4bit_quant_type": "nf4",
-        }
-
-        if cfg.bnb_config_kwargs:
-            bnb_config.update(cfg.bnb_config_kwargs)
-
         model_kwargs["quantization_config"] = BitsAndBytesConfig(
-            **bnb_config,
+            load_in_4bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
         )
     # sample packing uses custom FA2 patch
-    if cfg.flash_attention:
-        if not cfg.sample_packing:
-            if (
-                cfg.is_llama_derived_model
-                or cfg.is_falcon_derived_model
-                or cfg.is_mistral_derived_model
-                or model_config.model_type == "mixtral"
-            ):
-                model_kwargs["attn_implementation"] = "flash_attention_2"
-                model_config._attn_implementation = (  # pylint: disable=protected-access
-                    "flash_attention_2"
-                )
-        else:
-            if model_config.model_type == "mixtral":
-                model_kwargs["attn_implementation"] = "flash_attention_2"
-                model_config._attn_implementation = (  # pylint: disable=protected-access
-                    "flash_attention_2"
-                )
-            else:
-                model_kwargs["attn_implementation"] = "eager"
-                model_config._attn_implementation = (  # pylint: disable=protected-access
-                    "eager"
-                )
+    if cfg.flash_attention and not cfg.sample_packing:
+        if (
+            cfg.is_llama_derived_model
+            or cfg.is_falcon_derived_model
+            or cfg.is_mistral_derived_model
+        ):
+            model_kwargs["use_flash_attention_2"] = True

     try:
-        if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
+        if (
+            cfg.is_llama_derived_model
+            and not cfg.trust_remote_code
+            and not cfg.gptq
+            and not cfg.tensor_parallel
+        ):
             from transformers import LlamaForCausalLM

+            config_kwargs = {}
+            if cfg.rope_scaling:
+                config_kwargs["rope_scaling"] = cfg.rope_scaling
+            config = LlamaConfig.from_pretrained(
+                base_model_config,
+                **config_kwargs,
+            )
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
-                config=model_config,
+                config=config,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 **model_kwargs,
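The hunk above swaps a dict-based bnb_config for explicit BitsAndBytesConfig keyword arguments. For context, this is the standard transformers quantization API; a minimal sketch of passing such a config to from_pretrained follows (the model id is a placeholder, and 4-bit loading needs bitsandbytes plus a CUDA device).

# Illustrative only: a standalone 4-bit NF4 load with the same knobs as above.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # matmuls run in fp16
    bnb_4bit_use_double_quant=True,  # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",  # NormalFloat4 weights
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder model id
    quantization_config=bnb_config,
    device_map="auto",
)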
@@ -438,98 +298,106 @@ def load_model(
         # device=cfg.device,
         # )
         # model.train() # sets to train instead of eval mode
-        elif model_type == "PhiForCausalLM":
-            from axolotl.models.phi import PhiForCausalLM
+        elif model_type == "MixFormerSequentialForCausalLM":
+            from axolotl.models.phi import MixFormerSequentialForCausalLM

-            model = PhiForCausalLM.from_pretrained(
+            model = MixFormerSequentialForCausalLM.from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 **model_kwargs,
             )
-        elif model_type == "MambaLMHeadModel":
-            # FIXME this is janky at best and hacked together to make it work
-            MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
-
-            model_kwargs["dtype"] = model_kwargs["torch_dtype"]
-            model_kwargs["device"] = torch.cuda.current_device()
-            del model_kwargs["torch_dtype"]
-            del model_kwargs["device_map"]
-
-            model = MambaLMHeadModel.from_pretrained(
-                base_model,
-                **model_kwargs,
-            )
-        elif model_type and not cfg.trust_remote_code:
+        elif model_type and not cfg.trust_remote_code and not cfg.tensor_parallel:
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
-                    config=model_config,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
             else:
                 model = getattr(transformers, model_type).from_pretrained(
                     base_model,
-                    config=model_config,
                     load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                     load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
+        elif cfg.tensor_parallel:
+            model_kwargs.pop("device_map")
+            model = AutoModelForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+                low_cpu_mem_usage=True,
+                offload_state_dict=True,
+                trust_remote_code=cfg.trust_remote_code or False,
+                **model_kwargs,
+            )
         else:
+            config = AutoConfig.from_pretrained(
+                base_model,
+                trust_remote_code=cfg.trust_remote_code or False,
+            )
             # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
             # when training starts
             if (
-                hasattr(model_config, "max_seq_len")
-                and model_config.max_seq_len
-                and cfg.sequence_len > model_config.max_seq_len
+                hasattr(config, "max_seq_len")
+                and config.max_seq_len
+                and cfg.sequence_len > config.max_seq_len
             ):
-                model_config.max_seq_len = cfg.sequence_len
+                config.max_seq_len = cfg.sequence_len
                 LOG.warning(f"increasing context length to {cfg.sequence_len}")
             elif (
-                hasattr(model_config, "max_sequence_length")
-                and model_config.max_sequence_length
-                and cfg.sequence_len > model_config.max_sequence_length
+                hasattr(config, "max_sequence_length")
+                and config.max_sequence_length
+                and cfg.sequence_len > config.max_sequence_length
             ):
-                model_config.max_sequence_length = cfg.sequence_len
+                config.max_sequence_length = cfg.sequence_len
                 LOG.warning(f"increasing context length to {cfg.sequence_len}")
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
-                    config=model_config,
+                    config=config,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
             else:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
-                    config=model_config,
+                    config=config,
                     load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                     load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
     except Exception as err:  # pylint: disable=broad-exception-caught
+        LOG.error(
+            "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
+        )
         LOG.exception(err)
-        raise err
+        model = AutoModelForCausalLM.from_pretrained(
+            base_model,
+            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+            trust_remote_code=cfg.trust_remote_code or False,
+            **model_kwargs,
+        )

-    embeddings_len = (
-        math.ceil(len(tokenizer) / 32) * 32
-        if cfg.resize_token_embeddings_to_32x
-        else len(tokenizer)
-    )
-    if (
-        hasattr(model, "get_input_embeddings")
-        and model.get_input_embeddings().num_embeddings < embeddings_len
-    ):
-        model.resize_token_embeddings(embeddings_len)
-    else:
-        model.tie_weights()
+    try:
+        embeddings_len = (
+            math.ceil(len(tokenizer) / 32) * 32
+            if cfg.resize_token_embeddings_to_32x
+            else len(tokenizer)
+        )
+        if model.get_input_embeddings().num_embeddings < embeddings_len:
+            model.resize_token_embeddings(embeddings_len)
+        else:
+            model.tie_weights()
+    except NotImplementedError:
+        LOG.warning("`resize_token_embeddings` not implemented on model")

     if (
-        hasattr(model, "config")
-        and hasattr(model.config, "max_position_embeddings")
+        hasattr(model.config, "max_position_embeddings")
         and model.config.max_position_embeddings
         and cfg.sequence_len > model.config.max_position_embeddings
     ):
@@ -539,22 +407,20 @@ def load_model(
         model.config.max_position_embeddings = cfg.sequence_len

     if (
-        hasattr(model, "config")
-        and hasattr(model.config, "bos_token_id")
+        hasattr(model.config, "bos_token_id")
         and model.config.bos_token_id
         and model.config.bos_token_id != tokenizer.bos_token_id
     ):
         model.config.bos_token_id = tokenizer.bos_token_id

     if (
-        hasattr(model, "config")
-        and hasattr(model.config, "eos_token_id")
+        hasattr(model.config, "eos_token_id")
         and model.config.eos_token_id
         and model.config.eos_token_id != tokenizer.eos_token_id
     ):
         model.config.eos_token_id = tokenizer.eos_token_id

-    if hasattr(model, "device") and model.device.type == "cuda":
+    if model.device.type == "cuda":
         log_gpu_memory_usage(LOG, "after model load", model.device)

     # make sure these are fp32 per Ramesh et al. (2021)
@@ -569,22 +435,15 @@ def load_model(
                 module.to(torch.float32)

     needs_fa2_dtype = cfg.adapter or cfg.fsdp
-    skip_prepare_model_for_kbit_training = False
-
-    if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
-        # Qwen doesn't play nicely with LoRA if this is enabled
-        skip_prepare_model_for_kbit_training = True
-
     if (cfg.adapter == "lora" and load_in_8bit) or (
         cfg.adapter == "qlora" and cfg.load_in_4bit
     ):
         LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
         if cfg.gradient_checkpointing:
             model.gradient_checkpointing_enable()
-        if not skip_prepare_model_for_kbit_training:
-            model = prepare_model_for_kbit_training(
-                model, use_gradient_checkpointing=cfg.gradient_checkpointing
-            )
+        model = prepare_model_for_kbit_training(
+            model, use_gradient_checkpointing=cfg.gradient_checkpointing
+        )
         needs_fa2_dtype = True

     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
|||||||
if hasattr(module, "weight"):
|
if hasattr(module, "weight"):
|
||||||
module.to(cfg.torch_dtype)
|
module.to(cfg.torch_dtype)
|
||||||
|
|
||||||
lora_config = None
|
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
if not reference_model or cfg.lora_model_dir:
|
|
||||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
|
||||||
|
|
||||||
if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
|
if cfg.ddp and not load_in_8bit:
|
||||||
model.to(f"cuda:{cfg.local_rank}")
|
model.to(f"cuda:{cfg.local_rank}")
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
if (
|
||||||
|
torch.cuda.device_count() > 1
|
||||||
|
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
||||||
|
and (cfg.load_in_4bit)
|
||||||
|
):
|
||||||
|
# llama is PROBABLY model parallelizable, but the default isn't that it is
|
||||||
|
# so let's only set it for the 4bit, see
|
||||||
|
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
|
||||||
setattr(model, "is_parallelizable", True)
|
setattr(model, "is_parallelizable", True)
|
||||||
setattr(model, "model_parallel", True)
|
setattr(model, "model_parallel", True)
|
||||||
|
|
||||||
@@ -615,8 +479,7 @@ def load_model(
|
|||||||
requires_grad.append(f"{name}: {param.requires_grad}")
|
requires_grad.append(f"{name}: {param.requires_grad}")
|
||||||
if len(requires_grad) == 0:
|
if len(requires_grad) == 0:
|
||||||
LOG.warning("there are no parameters that require gradient updates")
|
LOG.warning("there are no parameters that require gradient updates")
|
||||||
if hasattr(model, "config"):
|
model.config.use_cache = False
|
||||||
model.config.use_cache = False
|
|
||||||
|
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.transform(model)
|
model = BetterTransformer.transform(model)
|
||||||
@@ -634,7 +497,12 @@ def load_adapter(model, cfg, adapter, inference=False):
|
|||||||
if adapter is None:
|
if adapter is None:
|
||||||
return model, None
|
return model, None
|
||||||
if hasattr(model, "enable_input_require_grads"):
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
model.enable_input_require_grads()
|
try:
|
||||||
|
model.enable_input_require_grads()
|
||||||
|
except NotImplementedError:
|
||||||
|
LOG.warning("enable_input_require_grads not implemented on model")
|
||||||
|
if adapter == "qlora" and cfg.tensor_parallel:
|
||||||
|
model, _ = load_tp_qlora(model)
|
||||||
if adapter in ["lora", "qlora"]:
|
if adapter in ["lora", "qlora"]:
|
||||||
return load_lora(model, cfg, inference=inference)
|
return load_lora(model, cfg, inference=inference)
|
||||||
if adapter == "llama-adapter":
|
if adapter == "llama-adapter":
|
||||||
@@ -686,6 +554,25 @@ def find_all_linear_names(model):
|
|||||||
return list(lora_module_names)
|
return list(lora_module_names)
|
||||||
|
|
||||||
|
|
||||||
|
def load_tp_qlora(model):
|
||||||
|
from transformers.utils.bitsandbytes import replace_with_bnb_linear
|
||||||
|
|
||||||
|
model = replace_with_bnb_linear(
|
||||||
|
model,
|
||||||
|
quantization_config=BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
llm_int8_threshold=6.0,
|
||||||
|
llm_int8_has_fp16_weight=False,
|
||||||
|
bnb_4bit_compute_dtype=torch.float16,
|
||||||
|
bnb_4bit_use_double_quant=True,
|
||||||
|
bnb_4bit_quant_type="nf4",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
model.is_loaded_in_4bit = True
|
||||||
|
|
||||||
|
return model, None
|
||||||
|
|
||||||
|
|
||||||
def load_lora(model, cfg, inference=False):
|
def load_lora(model, cfg, inference=False):
|
||||||
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
|
|
||||||
@@ -711,15 +598,10 @@ def load_lora(model, cfg, inference=False):
|
|||||||
|
|
||||||
if cfg.lora_model_dir:
|
if cfg.lora_model_dir:
|
||||||
LOG.debug("Loading pretained PEFT - LoRA")
|
LOG.debug("Loading pretained PEFT - LoRA")
|
||||||
model_kwargs: Any = {}
|
|
||||||
if cfg.lora_on_cpu:
|
|
||||||
model_kwargs["max_memory"] = {"cpu": "256GiB"}
|
|
||||||
model_kwargs["device_map"] = {"": "cpu"}
|
|
||||||
model = PeftModel.from_pretrained(
|
model = PeftModel.from_pretrained(
|
||||||
model,
|
model,
|
||||||
cfg.lora_model_dir,
|
cfg.lora_model_dir,
|
||||||
is_trainable=(not inference),
|
is_trainable=(not inference),
|
||||||
**model_kwargs,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = get_peft_model(model, lora_config)
|
model = get_peft_model(model, lora_config)
|
||||||
|
|||||||
@@ -1,4 +0,0 @@
|
|||||||
"""
|
|
||||||
axolotl samplers module
|
|
||||||
"""
|
|
||||||
from .multipack import MultipackBatchSampler # noqa: F401
|
|
||||||
@@ -1,196 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
"""
|
|
||||||
Multipack Batch Sampler
|
|
||||||
"""
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
from typing import Any, Iterable, List, Union
|
|
||||||
|
|
||||||
import numba
|
|
||||||
import numpy as np
|
|
||||||
from torch.utils.data import BatchSampler, Sampler
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
|
|
||||||
|
|
||||||
|
|
||||||
@numba.njit
|
|
||||||
def ffd_check(a: np.ndarray, c: int, n: int):
|
|
||||||
# First-fit-decreasing bin packing
|
|
||||||
# Check if a[] could fit in n bins with capacity c
|
|
||||||
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
|
||||||
|
|
||||||
a = np.sort(a)[::-1]
|
|
||||||
bins = np.full((n,), c, dtype=a.dtype)
|
|
||||||
for size in a:
|
|
||||||
not_found = True
|
|
||||||
for idx in range(n):
|
|
||||||
if bins[idx] >= size:
|
|
||||||
bins[idx] -= size
|
|
||||||
not_found = False
|
|
||||||
break
|
|
||||||
|
|
||||||
if not_found:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
@numba.njit
|
|
||||||
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
|
||||||
# First-fit-decreasing bin packing (with result return)
|
|
||||||
|
|
||||||
indices = np.argsort(a)[::-1]
|
|
||||||
a = a[indices]
|
|
||||||
|
|
||||||
bins: List[Any] = []
|
|
||||||
bins_result: List[Any] = []
|
|
||||||
for a_id, size in enumerate(a):
|
|
||||||
add_new = True
|
|
||||||
for idx in range(len(bins)):
|
|
||||||
if bins[idx] >= size:
|
|
||||||
bins[idx] -= size
|
|
||||||
bins_result[idx].append(indices[a_id] + start_index)
|
|
||||||
add_new = False
|
|
||||||
break
|
|
||||||
|
|
||||||
if add_new:
|
|
||||||
bins.append(c - size)
|
|
||||||
bins_result.append([indices[a_id] + start_index])
|
|
||||||
|
|
||||||
return bins_result
|
|
||||||
|
|
||||||
|
|
||||||
@numba.njit
|
|
||||||
def allocate(
|
|
||||||
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
|
||||||
):
|
|
||||||
# Dynamic batch allocator, similar to Multifit
|
|
||||||
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
|
||||||
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
|
||||||
|
|
||||||
s = 0
|
|
||||||
start_index = 0
|
|
||||||
result = []
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# binary search [l, r)
|
|
||||||
left = 1
|
|
||||||
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
|
||||||
|
|
||||||
while right - left > 1:
|
|
||||||
mid = (left + right) // 2
|
|
||||||
if ffd_check(lengths[start_index : start_index + mid], c, n):
|
|
||||||
left = mid
|
|
||||||
else:
|
|
||||||
right = mid
|
|
||||||
|
|
||||||
# use length l
|
|
||||||
batch = ffd_with_result(
|
|
||||||
lengths[start_index : start_index + left], c, start_index
|
|
||||||
)
|
|
||||||
assert len(batch) <= n
|
|
||||||
if len(batch) < n:
|
|
||||||
break
|
|
||||||
|
|
||||||
start_index += left
|
|
||||||
s = lengths_cumsum[start_index - 1]
|
|
||||||
|
|
||||||
# add local rank
|
|
||||||
result.append(batch[rank])
|
|
||||||
|
|
||||||
return result, s, len(result) * c * n
|
|
||||||
|
|
||||||
|
|
||||||
class MultipackBatchSampler(BatchSampler):
|
|
||||||
"""
|
|
||||||
Batch Sampler class for multipack
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
sampler: Union[Sampler[int], Iterable[int]],
|
|
||||||
batch_size: int,
|
|
||||||
drop_last: bool,
|
|
||||||
batch_max_len: int,
|
|
||||||
lengths: np.ndarray,
|
|
||||||
packing_efficiency_estimate: float = 1.0,
|
|
||||||
):
|
|
||||||
super().__init__(sampler, batch_size, drop_last)
|
|
||||||
self.batch_size = None
|
|
||||||
self.batch_max_len = batch_max_len
|
|
||||||
self.lengths: np.ndarray = lengths
|
|
||||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
|
||||||
|
|
||||||
assert isinstance(self.lengths, np.ndarray)
|
|
||||||
|
|
||||||
self.epoch = 0
|
|
||||||
|
|
||||||
# statistics
|
|
||||||
self.eff_total_used = 0
|
|
||||||
self.eff_total_slots = 0
|
|
||||||
|
|
||||||
def set_epoch(self, epoch: int):
|
|
||||||
self.epoch = epoch
|
|
||||||
|
|
||||||
def generate_batches(self, set_stats=False):
|
|
||||||
indices = [idx for idx in self.sampler]
|
|
||||||
|
|
||||||
lengths = self.lengths[indices]
|
|
||||||
lengths_cumsum = np.cumsum(lengths)
|
|
||||||
|
|
||||||
batches, total_used, total_slots = allocate(
|
|
||||||
lengths=lengths,
|
|
||||||
lengths_cumsum=lengths_cumsum,
|
|
||||||
rank=0,
|
|
||||||
c=self.batch_max_len,
|
|
||||||
n=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
|
||||||
|
|
||||||
# statistics
|
|
||||||
if set_stats:
|
|
||||||
self.eff_total_used += total_used
|
|
||||||
self.eff_total_slots += total_slots
|
|
||||||
|
|
||||||
return batches
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
batches = self.generate_batches(set_stats=True)
|
|
||||||
return iter(batches)
|
|
||||||
|
|
||||||
def num_batches(self):
|
|
||||||
batches = self.generate_batches(set_stats=True)
|
|
||||||
return len(batches)
|
|
||||||
|
|
||||||
def efficiency(self):
|
|
||||||
return self.eff_total_used / self.eff_total_slots
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
self.num_batches()
|
|
||||||
return self._len_est()
|
|
||||||
|
|
||||||
def _len_est(self):
|
|
||||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
|
||||||
lengths_sum = np.sum(self.lengths)
|
|
||||||
lengths_sum_per_device = lengths_sum // world_size
|
|
||||||
LOG.info(
|
|
||||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
|
||||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
|
||||||
return max(
|
|
||||||
0,
|
|
||||||
(
|
|
||||||
world_size
|
|
||||||
* math.floor(
|
|
||||||
0.99
|
|
||||||
* lengths_sum_per_device
|
|
||||||
/ self.packing_efficiency_estimate
|
|
||||||
// self.batch_max_len
|
|
||||||
)
|
|
||||||
- 1
|
|
||||||
),
|
|
||||||
)
|
|
||||||
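The sampler deleted above packs variable-length sequences into fixed token budgets with first-fit-decreasing bin packing (see the Wikipedia link in the removed comments). A tiny pure-Python illustration of the same check, with made-up lengths and capacities:

# Plain-Python version of the removed numba-jitted ffd_check, for illustration only.
def ffd_fits(lengths, capacity, n_bins):
    remaining = [capacity] * n_bins
    for size in sorted(lengths, reverse=True):  # place the largest sequences first
        for i, free in enumerate(remaining):
            if free >= size:  # first bin with room wins
                remaining[i] -= size
                break
        else:
            return False  # no bin could take this sequence
    return True

# e.g. can four sequences be packed into two 2048-token bins?
print(ffd_fits([1500, 500, 1200, 800], capacity=2048, n_bins=2))  # True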
Some files were not shown because too many files have changed in this diff.