Compare commits

..

3 Commits

Author SHA1 Message Date
Wing Lian
d46d7dfe30 wip 2024-02-01 00:28:16 -05:00
Wing Lian
047d9e1d5b helper utils 2024-01-31 12:49:29 -05:00
Wing Lian
88a0c05d2c wip 2024-01-31 12:07:39 -05:00
151 changed files with 1472 additions and 7346 deletions

View File

@@ -59,7 +59,6 @@ body:
label: Config yaml label: Config yaml
description: | description: |
Please attach the config yaml! Please attach the config yaml!
render: yaml
- type: textarea - type: textarea
id: possible-solution id: possible-solution

View File

@@ -7,11 +7,16 @@ jobs:
build-base: build-base:
if: github.repository_owner == 'OpenAccess-AI-Collective' if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
runs-on: axolotl-gpu-runner runs-on: self-hosted
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "118" - cuda: "118"
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"

View File

@@ -17,6 +17,6 @@ jobs:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-python@v4 - uses: actions/setup-python@v4
with: with:
python-version: "3.10" python-version: "3.9"
cache: 'pip' # caching pip dependencies cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0 - uses: pre-commit/action@v3.0.0

View File

@@ -9,16 +9,21 @@ on:
jobs: jobs:
build-axolotl: build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }} if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
# this job needs to be run on self-hosted GPU runners...
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118 - cuda: 118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.2 pytorch: 2.1.2
axolotl_extras: axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
is_latest: true is_latest: true
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
@@ -30,7 +35,7 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.1.2 pytorch: 2.1.2
axolotl_extras: axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: [self-hosted, gpu, docker]
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -51,17 +56,27 @@ jobs:
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
context: . context: .
load: true
build-args: | build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }} CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }} PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
file: ./docker/Dockerfile file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: | tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }} ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
- name: Unit Tests
run: |
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
- name: Push to Docker Hub
if: github.event_name != 'pull_request'
run: |
docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
if [ -n "$latest_tag" ]; then
docker push "$latest_tag"
fi
build-axolotl-runpod: build-axolotl-runpod:
needs: build-axolotl needs: build-axolotl
@@ -70,6 +85,11 @@ jobs:
strategy: strategy:
matrix: matrix:
include: include:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
- cuda: 118 - cuda: 118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
@@ -86,7 +106,7 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.1.2 pytorch: 2.1.2
axolotl_extras: axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: [self-hosted, gpu, docker]
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4

View File

@@ -23,7 +23,7 @@ jobs:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-python@v4 - uses: actions/setup-python@v4
with: with:
python-version: "3.10" python-version: "3.9"
cache: 'pip' # caching pip dependencies cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0 - uses: pre-commit/action@v3.0.0
@@ -33,7 +33,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
python_version: ["3.10", "3.11"] python_version: ["3.9", "3.10", "3.11"]
timeout-minutes: 10 timeout-minutes: 10
steps: steps:
@@ -58,8 +58,8 @@ jobs:
docker-e2e-tests: docker-e2e-tests:
if: github.repository_owner == 'OpenAccess-AI-Collective' if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal] runs-on: [self-hosted, gpu, docker]
timeout-minutes: 60 timeout-minutes: 30
needs: [pre-commit, pytest] needs: [pre-commit, pytest]
strategy: strategy:
@@ -69,32 +69,44 @@ jobs:
- cuda: 118 - cuda: 118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.2 pytorch: 2.0.1
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
num_gpus: 1
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.2 pytorch: 2.1.2
num_gpus: 1
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install Python - name: Docker metadata
uses: actions/setup-python@v5 id: metadata
uses: docker/metadata-action@v5
with: with:
python-version: "3.10" images: winglian/axolotl-tests
- name: Install Modal - name: Build Docker image
run: | run: |
python -m pip install --upgrade pip # Set up build arguments
pip install modal jinja2 BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}"
- name: Update env vars CUDA="${{ matrix.cuda }}"
PYTORCH_VERSION="${{ matrix.pytorch }}"
# Build the Docker image
docker build . \
--file ./docker/Dockerfile-tests \
--build-arg BASE_TAG=$BASE_TAG \
--build-arg CUDA=$CUDA \
--build-arg GITHUB_REF=$GITHUB_REF \
--build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
--tag ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} \
--no-cache
- name: Unit Tests w docker image
run: | run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV - name: GPU Unit Tests w docker image
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: | run: |
modal run cicd.tests docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
- name: GPU Unit Tests monkeypatched w docker image
run: |
docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest /workspace/axolotl/tests/e2e/patched/
- name: Prune image from docker
if: github.ref != 'refs/heads/main'
run: |
docker rmi -f ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}

5
.gitignore vendored
View File

@@ -167,8 +167,3 @@ cython_debug/
# WandB # WandB
# wandb creates a folder to store logs for training runs # wandb creates a folder to store logs for training runs
wandb wandb
# Runs
lora-out/*
qlora-out/*
mlruns/*

View File

@@ -1,5 +1,5 @@
[mypy] [mypy]
plugins = pydantic.mypy
exclude = venv exclude = venv
[mypy-alpaca_lora_4bit.*] [mypy-alpaca_lora_4bit.*]
@@ -32,9 +32,6 @@ ignore_missing_imports = True
[mypy-bitsandbytes] [mypy-bitsandbytes]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-requests]
ignore_missing_imports = True
[mypy-datasets] [mypy-datasets]
ignore_missing_imports = True ignore_missing_imports = True

View File

@@ -31,7 +31,6 @@ repos:
additional_dependencies: additional_dependencies:
[ [
'types-PyYAML', 'types-PyYAML',
'pydantic>=2.5.3',
] ]
- repo: https://github.com/PyCQA/bandit - repo: https://github.com/PyCQA/bandit
rev: 1.7.5 rev: 1.7.5

285
README.md
View File

@@ -13,9 +13,6 @@ Features:
- Log results and optionally checkpoints to wandb or mlflow - Log results and optionally checkpoints to wandb or mlflow
- And more! - And more!
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
<table> <table>
<tr> <tr>
@@ -25,25 +22,21 @@ Features:
- [Introduction](#axolotl) - [Introduction](#axolotl)
- [Supported Features](#axolotl-supports) - [Supported Features](#axolotl-supports)
- [Quickstart](#quickstart-) - [Quickstart](#quickstart-)
- [Environment](#environment) - [Installation](#installation)
- [Docker](#docker) - [Docker](#docker)
- [Conda/Pip venv](#condapip-venv) - [Conda/Pip venv](#condapip-venv)
- [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod - [Cloud GPU](#cloud-gpu) - Runpod, Latitude
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu) - [LambdaLabs](#lambdalabs)
- [Windows](#windows) - [Windows](#windows)
- [Mac](#mac)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot) - [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Dataset](#dataset) - [Dataset](#dataset)
- [How to Add Custom Prompts](#how-to-add-custom-prompts) - [How to Add Custom Prompts](#how-to-add-custom-prompts)
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset) - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config) - [Config](#config)
- [Train](#train) - [Train](#train)
- [Inference](#inference-playground) - [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base) - [Merge LORA to Base](#merge-lora-to-base)
- [Special Tokens](#special-tokens) - [Special Tokens](#special-tokens)
- Advanced Topics
- [Multipack](./docs/multipack.md)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
- [RLHF & DPO](./docs/rlhf.md)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
- [Common Errors](#common-errors-) - [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training) - [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl) - [Debugging Axolotl](#debugging-axolotl)
@@ -91,18 +84,17 @@ Features:
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ | | phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | | RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ | | Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
✅: supported
❌: not supported
❓: untested
## Quickstart ⚡ ## Quickstart ⚡
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task. Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: Python >=3.10 and Pytorch >=2.1.1. **Requirements**: Python >=3.9 and Pytorch >=2.0.
`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
### For developers
```bash ```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl cd axolotl
@@ -126,20 +118,15 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
# gradio # gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./lora-out" --gradio --lora_model_dir="./lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
``` ```
## Advanced Setup ## Installation
### Environment ### Environment
#### Docker #### Docker
```bash ```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
``` ```
Or run on the current files for development: Or run on the current files for development:
@@ -158,7 +145,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAcc
A more powerful Docker command to run would be this: A more powerful Docker command to run would be this:
```bash ```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
``` ```
It additionally: It additionally:
@@ -173,7 +160,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
</details> </details>
#### Conda/Pip venv #### Conda/Pip venv
1. Install python >=**3.10** 1. Install python >=**3.9**
2. Install pytorch stable https://pytorch.org/get-started/locally/ 2. Install pytorch stable https://pytorch.org/get-started/locally/
@@ -192,14 +179,9 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags) For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz) - on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
#### Bare Metal Cloud GPU #### LambdaLabs
##### LambdaLabs
<details> <details>
<summary>Click to Expand</summary> <summary>Click to Expand</summary>
@@ -207,11 +189,11 @@ For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud
1. Install python 1. Install python
```bash ```bash
sudo apt update sudo apt update
sudo apt install -y python3.10 sudo apt install -y python3.9
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
sudo update-alternatives --config python # pick 3.10 if given option sudo update-alternatives --config python # pick 3.9 if given option
python -V # should be 3.10 python -V # should be 3.9
``` ```
@@ -243,46 +225,21 @@ For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud
``` ```
</details> </details>
##### GCP
<details>
<summary>Click to Expand</summary>
Use a Deeplearning linux OS with cuda and pytorch installed. Then follow instructions on quickstart.
Make sure to run the below to uninstall xla.
```bash
pip uninstall -y torch_xla[tpu]
```
</details>
#### Windows #### Windows
Please use WSL or Docker! Please use WSL or Docker!
#### Mac
Use the below instead of the install method in QuickStart.
```
pip3 install -e '.'
```
More info: [mac.md](/docs/mac.md)
#### Launching on public clouds via SkyPilot #### Launching on public clouds via SkyPilot
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html): To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
```bash ```bash
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
sky check sky check
``` ```
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`: Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
``` ```
git clone https://github.com/skypilot-org/skypilot.git git clone https://github.com/skypilot-org/skypilot.git
cd skypilot/llm/axolotl cd skypilot/llm/axolotl
``` ```
Use one command to launch: Use one command to launch:
```bash ```bash
# On-demand # On-demand
@@ -292,32 +249,31 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
``` ```
### Dataset ### Dataset
Axolotl supports a variety of dataset formats. Below are some of the formats you can use. Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
Have dataset(s) in one of the following format (JSONL recommended): Have dataset(s) in one of the following format (JSONL recommended):
#### Pretraining
- `completion`: raw corpus
```json
{"text": "..."}
```
Note: Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:
```yaml
pretraining_dataset: # hf path only
```
#### Supervised finetuning
##### Instruction
- `alpaca`: instruction; input(optional) - `alpaca`: instruction; input(optional)
```json ```json
{"instruction": "...", "input": "...", "output": "..."} {"instruction": "...", "input": "...", "output": "..."}
``` ```
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
- `llama-2`: the json is the same format as `sharegpt` above, with the following config (see the [config section](#config) for more details)
```yml
datasets:
- path: <your-path>
type: sharegpt
conversation: llama-2
```
- `completion`: raw corpus
```json
{"text": "..."}
```
<details> <details>
@@ -395,37 +351,14 @@ pretraining_dataset: # hf path only
```json ```json
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "...", "revision": "..."} {"scores": "...", "critiques": "...", "instruction": "...", "answer": "...", "revision": "..."}
``` ```
- `metharme`: instruction, adds additional eos tokens
```json
{"prompt": "...", "generation": "..."}
```
</details>
##### Template-Free
- `input_output`: template-free prompt construction
```json
{"segments": [{"label": true|false, "text": "..."}]}
```
This is a special format that allows you to construct prompts without using templates. This is for advanced users who want more freedom with prompt construction. See [these docs](docs/input_output.md) for more details.
##### Conversation
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
<details>
<summary>See other formats</summary>
- `pygmalion`: pygmalion - `pygmalion`: pygmalion
```json ```json
{"conversations": [{"role": "...", "value": "..."}]} {"conversations": [{"role": "...", "value": "..."}]}
``` ```
- `metharme`: instruction, adds additional eos tokens
```json
{"prompt": "...", "generation": "..."}
```
- `sharegpt.load_role`: conversations where `role` is used instead of `from` - `sharegpt.load_role`: conversations where `role` is used instead of `from`
```json ```json
{"conversations": [{"role": "...", "value": "..."}]} {"conversations": [{"role": "...", "value": "..."}]}
@@ -441,8 +374,6 @@ This is a special format that allows you to construct prompts without using temp
</details> </details>
Note: `type: sharegpt` opens a special config `conversation:` that enables conversions to many Conversation types. See dataset section under [all yaml options](#all-yaml-options).
#### How to add custom prompts #### How to add custom prompts
For a dataset that is preprocessed for instruction purposes: For a dataset that is preprocessed for instruction purposes:
@@ -464,16 +395,12 @@ datasets:
format: "[INST] {instruction} [/INST]" format: "[INST] {instruction} [/INST]"
no_input_format: "[INST] {instruction} [/INST]" no_input_format: "[INST] {instruction} [/INST]"
``` ```
See full config options under [all yaml options](#all-yaml-options).
#### How to use your custom pretokenized dataset #### How to use your custom pretokenized dataset
- Do not pass a `type:` - Do not pass a `type:`
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels` - Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
```yaml
- path: ...
```
### Config ### Config
@@ -487,18 +414,22 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- dataset - dataset
```yaml ```yaml
datasets: sequence_len: 2048 # max token length for prompt
# huggingface repo
- path: vicgalle/alpaca-gpt4
type: alpaca
# huggingface repo with specific configuration/subset # huggingface repo
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca # format from earlier
# huggingface repo with specific configuration/subset
datasets:
- path: EleutherAI/pile - path: EleutherAI/pile
name: enron_emails name: enron_emails
type: completion # format from earlier type: completion # format from earlier
field: text # Optional[str] default: text, field to use for completion data field: text # Optional[str] default: text, field to use for completion data
# huggingface repo with multiple named configurations/subsets # huggingface repo with multiple named configurations/subsets
datasets:
- path: bigcode/commitpackft - path: bigcode/commitpackft
name: name:
- ruby - ruby
@@ -506,42 +437,39 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript - typescript
type: ... # unimplemented custom format type: ... # unimplemented custom format
# fastchat conversation # fastchat conversation
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py # See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
datasets:
- path: ... - path: ...
type: sharegpt type: sharegpt
conversation: chatml # default: vicuna_v1.1 conversation: chatml
# local # local
datasets:
- path: data.jsonl # or json - path: data.jsonl # or json
ds_type: json # see other options below ds_type: json # see other options below
type: alpaca type: alpaca
# dataset with splits, but no train split # dataset with splits, but no train split
dataset:
- path: knowrohit07/know_sql - path: knowrohit07/know_sql
type: context_qa.load_v2 type: context_qa.load_v2
train_on_split: validation train_on_split: validation
# loading from s3 or gcs # loading from s3 or gcs
# s3 creds will be loaded from the system default and gcs only supports public access # s3 creds will be loaded from the system default and gcs only supports public access
dataset:
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs. - path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
... ...
# Loading Data From a Public URL
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
ds_type: json # this is the default, see other options below.
``` ```
- loading - loading
```yaml ```yaml
load_in_4bit: true load_in_4bit: true
load_in_8bit: true load_in_8bit: true
bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically. bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically.
fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32 fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32
tf32: true # require >=ampere tf32: true # require >=ampere
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision) bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
float16: true # use instead of fp16 when you don't want AMP float16: true # use instead of fp16 when you don't want AMP
``` ```
@@ -549,7 +477,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- lora - lora
```yaml ```yaml
adapter: lora # 'qlora' or leave blank for full finetune adapter: lora # qlora or leave blank for full finetune
lora_r: 8 lora_r: 8
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05
@@ -558,9 +486,9 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- v_proj - v_proj
``` ```
<details id="all-yaml-options"> <details>
<summary>All yaml options (click to expand)</summary> <summary>All yaml options (click me)</summary>
```yaml ```yaml
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files # This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
@@ -572,8 +500,8 @@ base_model_ignore_patterns:
# You can set that here, or leave this empty to default to base_model # You can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf base_model_config: ./llama-7b-hf
# You can specify to choose a specific model revision from huggingface hub # You can specify to choose a specific model revision from huggingface hub
revision_of_model: model_revision:
# Optional tokenizer configuration path in case you want to use a different tokenizer # Optional tokenizer configuration override in case you want to use a different tokenizer
# than the one defined in the base model # than the one defined in the base model
tokenizer_config: tokenizer_config:
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too # If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
@@ -590,16 +518,15 @@ tokenizer_legacy:
# This is reported to improve training speed on some models # This is reported to improve training speed on some models
resize_token_embeddings_to_32x: resize_token_embeddings_to_32x:
# (Internal use only)
# Used to identify which the model is based on # Used to identify which the model is based on
is_falcon_derived_model: is_falcon_derived_model:
is_llama_derived_model: is_llama_derived_model:
is_qwen_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default # Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model: is_mistral_derived_model:
is_qwen_derived_model:
# optional overrides to the base model configuration # optional overrides to the base model configuration
overrides_of_model_config: model_config:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653 # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling: rope_scaling:
type: # linear | dynamic type: # linear | dynamic
@@ -616,6 +543,8 @@ bnb_config_kwargs:
# Whether you are training a 4-bit GPTQ quantized model # Whether you are training a 4-bit GPTQ quantized model
gptq: true gptq: true
gptq_groupsize: 128 # group size
gptq_model_v1: false # v1 or v2
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true load_in_8bit: true
@@ -651,13 +580,9 @@ datasets:
train_on_split: train # Optional[str] name of dataset split to load from train_on_split: train # Optional[str] name of dataset split to load from
# Optional[str] fastchat conversation type, only used with type: sharegpt # Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
field_human: # Optional[str]. Human key to use for conversation. field_human: # Optional[str]. Human key to use for conversation.
field_model: # Optional[str]. Assistant key to use for conversation. field_model: # Optional[str]. Assistant key to use for conversation.
# Add additional keys from your dataset as input or output roles
roles:
input: # Optional[List[str]]. These will be masked based on train_on_input
output: # Optional[List[str]].
# Custom user instruction prompt # Custom user instruction prompt
- path: repo - path: repo
@@ -682,10 +607,6 @@ datasets:
# For `completion` datsets only, uses the provided field instead of `text` column # For `completion` datsets only, uses the provided field instead of `text` column
field: field:
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
# A list of one or more datasets to eval the model with. # A list of one or more datasets to eval the model with.
# You can use either test_datasets, or val_set_size, but not both. # You can use either test_datasets, or val_set_size, but not both.
test_datasets: test_datasets:
@@ -697,7 +618,7 @@ test_datasets:
data_files: data_files:
- /workspace/data/eval.jsonl - /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto_pair' # use RL training: dpo, ipo, kto_pair
rl: rl:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing # Saves the desired chat template to the tokenizer_config.json for easier inferencing
@@ -717,7 +638,7 @@ dataset_processes: # defaults to os.cpu_count() if not set
# Only needed if cached dataset is taking too much storage # Only needed if cached dataset is taking too much storage
dataset_keep_in_memory: dataset_keep_in_memory:
# push checkpoints to hub # push checkpoints to hub
hub_model_id: # private repo path to push finetuned model hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub # how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy # https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy: hub_strategy:
@@ -796,8 +717,6 @@ peft:
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # Number of steps per ReLoRA restart relora_steps: # Number of steps per ReLoRA restart
relora_warmup_steps: # Number of per-restart warmup steps relora_warmup_steps: # Number of per-restart warmup steps
relora_anneal_steps: # Number of anneal steps for each relora cycle
relora_prune_ratio: # threshold for optimizer magnitude when pruning
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it # wandb configuration if you're using it
@@ -813,7 +732,6 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# mlflow configuration if you're using it # mlflow configuration if you're using it
mlflow_tracking_uri: # URI to mlflow mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name mlflow_experiment_name: # Your experiment name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
# Where to save the full-finetuned model to # Where to save the full-finetuned model to
output_dir: ./completed-model output_dir: ./completed-model
@@ -847,8 +765,7 @@ save_total_limit: # Checkpoints saved at a time
max_steps: max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training) loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3) loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
@@ -867,7 +784,7 @@ group_by_length: false
gradient_checkpointing: false gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing # additional kwargs to pass to the trainer for gradient checkpointing
# gradient_checkpointing_kwargs: # gradient_checkpointing_kwargs:
# use_reentrant: true # use_reentrant: false
# Stop training after this many evaluation losses have increased in a row # Stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
@@ -877,11 +794,14 @@ early_stopping_patience: 3
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs: lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
# For one_cycle optim # For one_cycle optim
lr_div_factor: # Learning rate div factor lr_div_factor: # Learning rate div factor
# For log_sweep optim
log_sweep_min_lr:
log_sweep_max_lr:
# Specify optimizer # Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see: # Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134 # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
@@ -907,26 +827,7 @@ lr_div_factor: # Learning rate div factor
# - paged_adamw_8bit # - paged_adamw_8bit
# - paged_lion_32bit # - paged_lion_32bit
# - paged_lion_8bit # - paged_lion_8bit
# - galore_adamw
# - galore_adamw_8bit
# - galore_adafactor
# - galore_adamw_layerwise
# - galore_adamw_8bit_layerwise
# - galore_adafactor_layerwise
optimizer: optimizer:
# Dictionary of arguments to pass to the optimizer
optim_args:
# For Galore Optimizers the following optim_args are available
# rank: # type: int
# update_proj_gap # type: int
# scale # type: float
# proj_type: # type: str, default = std
# The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm
optim_target_modules:
# - self_attn # for llama
# - mlp
# Specify weight decay # Specify weight decay
weight_decay: weight_decay:
# adamw hyperparams # adamw hyperparams
@@ -1072,9 +973,6 @@ Run
accelerate launch -m axolotl.cli.train your_config.yml accelerate launch -m axolotl.cli.train your_config.yml
``` ```
> [!TIP]
> You can also reference a config file that is hosted on a public URL, for example `accelerate launch -m axolotl.cli.train https://yourdomain.com/your_config.yml`
#### Preprocess dataset #### Preprocess dataset
You can optionally pre-tokenize dataset with the following before finetuning. You can optionally pre-tokenize dataset with the following before finetuning.
@@ -1123,10 +1021,6 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
``` ```
##### FSDP + QLoRA
Axolotl supports training with FSDP and QLoRA, see [these docs](docs/fsdp_qlora.md) for more information.
##### Weights & Biases Logging ##### Weights & Biases Logging
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`. Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
@@ -1188,7 +1082,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
### Merge LORA to base ### Merge LORA to base
The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`. The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
```bash ```bash
python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model" python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model"
@@ -1249,7 +1143,7 @@ If you decode a prompt constructed by axolotl, you might see spaces between toke
1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer. 1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string. 2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly. 3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same adjust your inference server accordingly.
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical. 4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example. Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
@@ -1258,11 +1152,9 @@ Having misalignment between your prompts during training and inference can cause
See [this debugging guide](docs/debugging.md) for tips on debugging Axolotl, along with an example configuration for debugging with VSCode. See [this debugging guide](docs/debugging.md) for tips on debugging Axolotl, along with an example configuration for debugging with VSCode.
## Need help? 🙋 ## Need help? 🙋♂️
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we our community members can help you. Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
Need dedicated support? Please contact us at [✉wing@openaccessaicollective.org](mailto:wing@openaccessaicollective.org) for dedicated support options.
## Badge ❤🏷️ ## Badge ❤🏷️
@@ -1296,28 +1188,13 @@ PRs are **greatly welcome**!
Please run below to setup env Please run below to setup env
```bash ```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install packaging
pip3 install -e '.[flash-attn,deepspeed]'
pip3 install -r requirements-dev.txt -r requirements-tests.txt pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install pre-commit install
# test # test
pytest tests/ pytest tests/
# optional: run against all files
pre-commit run --all-files
``` ```
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
<a href="https://github.com/openaccess-ai-collective/axolotl/graphs/contributors">
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
</a>
## Sponsors 🤝❤ ## Sponsors 🤝❤
OpenAccess AI Collective is run by volunteer contributors such as [winglian](https://github.com/winglian), OpenAccess AI Collective is run by volunteer contributors such as [winglian](https://github.com/winglian),
@@ -1346,6 +1223,4 @@ consider sponsoring the project via [GitHub Sponsors](https://github.com/sponsor
#### 🥉 Bronze Sponsors - $500/mo #### 🥉 Bronze Sponsors - $500/mo
- [JarvisLabs.ai](https://jarvislabs.ai)
--- ---

View File

@@ -1,39 +0,0 @@
FROM winglian/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV BNB_CUDA_VERSION="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image
RUN pip install pytest
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
RUN git config --global credential.helper store

View File

@@ -1,5 +0,0 @@
#!/bin/bash
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest /workspace/axolotl/tests/e2e/patched/
pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/

View File

@@ -1,75 +0,0 @@
"""
modal application to run axolotl gpu tests in Modal
"""
import os
import pathlib
import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import Image, Stub
cicd_path = pathlib.Path(__file__).parent.resolve()
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.0.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.10-cu118-2.0.1"),
"CUDA": os.environ.get("CUDA", "118"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
}
dockerfile_contents = df_template.render(**df_args)
temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
force_build=True,
gpu="A10G",
)
.env(df_args)
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
stub = Stub("Axolotl CI/CD", secrets=[])
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=45 * 60,
cpu=8.0,
memory=131072,
)
def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
@stub.local_entrypoint()
def main():
cicd_pytest.remote()

View File

@@ -16,7 +16,6 @@
"min_loss_scale": 1 "min_loss_scale": 1
}, },
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false "wall_clock_breakdown": false

View File

@@ -20,7 +20,6 @@
"min_loss_scale": 1 "min_loss_scale": 1
}, },
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false "wall_clock_breakdown": false

View File

@@ -24,7 +24,6 @@
"min_loss_scale": 1 "min_loss_scale": 1
}, },
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false "wall_clock_breakdown": false

View File

@@ -24,7 +24,6 @@
"min_loss_scale": 1 "min_loss_scale": 1
}, },
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false "wall_clock_breakdown": false

View File

@@ -2,6 +2,7 @@
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -3,10 +3,9 @@ FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS="" ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118" ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.1.2" ARG PYTORCH_VERSION="2.0.1"
ENV PYTORCH_VERSION=$PYTORCH_VERSION ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -21,9 +20,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
else \ else \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
fi fi
# So we can test the Docker image # So we can test the Docker image

View File

@@ -7,8 +7,8 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION a
ENV PATH="/root/miniconda3/bin:${PATH}" ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.10" ARG PYTHON_VERSION="3.9"
ARG PYTORCH_VERSION="2.1.2" ARG PYTORCH_VERSION="2.0.1"
ARG CUDA="118" ARG CUDA="118"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX" ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"

View File

@@ -11,7 +11,6 @@ EXPOSE 8888
EXPOSE 22 EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \ RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean jupyter lab clean
@@ -19,7 +18,6 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
mkdir -p ~/.ssh && \ mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \ chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \ printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \ chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh chmod +x /root/cloud-entrypoint.sh

View File

@@ -3,10 +3,9 @@ FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS="" ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118" ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.1.2" ARG PYTORCH_VERSION="2.0.1"
ARG GITHUB_REF="main" ARG GITHUB_REF="main"
ENV PYTORCH_VERSION=$PYTORCH_VERSION ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -25,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
else \ else \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
fi fi
# So we can test the Docker image # So we can test the Docker image

View File

@@ -74,6 +74,7 @@ pip3 install -e '.[flash-attn,deepspeed]'
If you developing on a remote host, you can easily use VSCode to debug remotely. To do so, you will need to follow this [remote - SSH guide](https://code.visualstudio.com/docs/remote/ssh). You can also see the video below on [Docker and Remote SSH debugging](#video---attaching-to-docker-on-remote-host). If you developing on a remote host, you can easily use VSCode to debug remotely. To do so, you will need to follow this [remote - SSH guide](https://code.visualstudio.com/docs/remote/ssh). You can also see the video below on [Docker and Remote SSH debugging](#video---attaching-to-docker-on-remote-host).
```bash
### Configuration ### Configuration

View File

@@ -1,37 +0,0 @@
# FDSP + QLoRA
## Background
Using FSDP with QLoRA is essential for **fine-tuning larger (70b+ parameter) LLMs on consumer GPUs.** For example, you can use FSDP + QLoRA to train a 70b model on two 24GB GPUs[^1].
Below, we describe how to use this feature in Axolotl.
## Usage
To enable `QLoRA` with `FSDP`, you need to perform the following steps:
> ![Tip]
> See the [example config](#example-config) file in addition to reading these instructions.
1. Set `adapter: qlora` in your axolotl config file.
2. Enable FSDP in your axolotl config, as [described here](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#fsdp).
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
## Example Config
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
## References
- [PR #1378](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.
- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP.
- Related HuggingFace PRs Enabling FDSP + QLoRA:
- Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544 )
- Transformers [PR#29587](https://github.com/huggingface/transformers/pull/29587)
- TRL [PR#1416](https://github.com/huggingface/trl/pull/1416)
- PEFT [PR#1550](https://github.com/huggingface/peft/pull/1550)
[^1]: This was enabled by [this work](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the Answer.AI team.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 239 KiB

View File

@@ -1,260 +0,0 @@
# Template-free prompt construction with the `input_output` format
<!-- TOC -->
- [Background](#background)
- [Masking Inputs](#masking-inputs)
- [You may not want prompt templates](#you-may-not-want-prompt-templates)
- [The `input_output` format](#the-input_output-format)
- [Usage](#usage)
- [1. Prepare Data](#1-prepare-data)
- [2. Use `type: input_output`](#2-use-type-input_output)
- [3. Check the prompts](#3-check-the-prompts)
<!-- /TOC -->
<a id="markdown-background" name="background"></a>
## Background
<a id="markdown-masking-inputs" name="masking-inputs"></a>
### Masking Inputs
One of the most popular features of
[axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) is
setting the following configuration value:
```yaml
train_on_inputs: false
```
If you declare a [dataset formats](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#dataset)
such as `alpaca` or `chatml`, axolotl knows what is an input
(i.e. human) vs. an output (i.e. the assistant) and masks the input
labels so that your model can focus on predicting the outputs only.
<a id="markdown-you-may-not-want-prompt-templates" name="you-may-not-want-prompt-templates"></a>
### You may not want prompt templates
However, there are many situations where you don't want to use one of
these formats or templates (I usually don't!). This is because they can:
- Add unnecessary boilerplate to your prompts.
- Create artifacts like special delimiters `<|im_start|>` that can
quickly become footguns if you don't include them correctly at
inference time.
- Enforce a *chat* interface when you do not want one. Sometimes you
just want to fine-tune a model to a very specific task and do NOT
want multi-turn conversations, roles, etc.
- Limit you to only certain roles that the template allows.
<a id="markdown-the-inputoutput-format" name="the-inputoutput-format"></a>
### The `input_output` format
You can construct your prompts without a template by using the
`input_output` format, by setting `type: input_output` in your
configuration file like this:
**config.yml**
```yaml
train_on_inputs: false # Mask segments of your data
datasets:
- path: output.jsonl
type: input_output # use template free prompt construction
```
Unlike `type: completion`, which is also template-free,
`type: input_output` allows you to mask segments of your text. More
details on how this works are described below.
<a id="markdown-usage" name="usage"></a>
## Usage
This is how you can use the `input_output` format:
<a id="markdown-1-prepare-data" name="1-prepare-data"></a>
### 1. Prepare Data
To use the `input_output` format, collect your data in the following
format into a jsonl file (below is the first row from the file
`output`.jsonl` pretty printed):
```bash
$ head -n1 output.jsonl | python -m json.tool
{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
```
Set `label:false` when you want to mask a segment of text so that the
model isn't trained on it. Some things to keep in mind:
> [!IMPORTANT]
> 1. **EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl
concatenates all the segments as-is.** The tokenizer doesn't add
anything additional. Notice how I added spaces, newlines, `<s>`
(BOS), and `</s>` (EOS) myself.
> 2. Make sure you check the materialized output to validate that the
prompt is getting assembled how you like.
<a id="markdown-2-use-type-inputoutput" name="2-use-type-inputoutput"></a>
### 2. Use `type: input_output`
Let's materialize data with our `output.jsonl` file by setting
`type: input_output` in our axolotl config:
```yaml
# training_config.yaml
base_model: mistralai/Mistral-7B-v0.1
data_seed: 49
seed: 49
datasets:
- path: output.jsonl
type: input_output
val_set_size: 0.1
sequence_len: 896
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 3
eval_batch_size: 2
num_epochs: 1
learning_rate: 0.0002
train_on_inputs: false
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
```
You can use the following command to materialize your data. The
`--debug` flag will print the tokens, along with the labels so you can
verify that the correct items are being ignored:
```bash
$ python -m axolotl.cli.preprocess training_config.yaml --debug
...
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
```
The format is `decoded_token`(`label`, `token_id`), for example,
`<s>(1, 1)` means that the token is `<s>`, the label is `1` and the
token_id is `1`. When the label is `-100` then that token is ignored for
training.
<a id="markdown-3-check-the-prompts" name="3-check-the-prompts"></a>
### 3. Check the prompts
Here is another way to check the materialized output:
```python
from transformers import AutoTokenizer
from datasets import load_from_disk
import yaml
directory = !ls last_run_prepared/
with open('training_config.yaml', 'r') as f:
cfg = yaml.safe_load(f)
model_id = cfg['base_model']
tok = AutoTokenizer.from_pretrained(model_id)
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
```
```python
>>> row = ds[0]
>>> print(tok.decode(row['input_ids']))
<s> Hello
hi there!. goodbye farewell</s>
```
We can check that the right tokens are ingored by comparing the labels
to each token:
```python
import pandas as pd
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in
zip(row['input_ids'], row['labels'])])
```
| token | label | id |
|-------|-------|-------|
| 0 | \<s\> | 1 |
| 1 | Hello | 22557 |
| 2 | \\n | 13 |
| 3 | hi | 12014 |
| 4 | there | 736 |
| 5 | ! | 28808 |
| 6 | . | 28723 |
| 7 | | 28705 |
| 8 | good | -100 |
| 9 | bye | -100 |
| 10 | | -100 |
| 11 | fare | 19111 |
| 12 | well | 5458 |
| 13 | \</s\>| 2 |
If we look at the input data, the above table seems correct! (The jsonl
version is repeated below for reference):
```bash
$ head -n1 output.jsonl | python -m json.tool
{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
```

View File

@@ -1,18 +0,0 @@
# Mac M series support
Currently Axolotl on Mac is partially usable, many of the dependencies of Axolotl including Pytorch do not support MPS or have incomplete support.
Current support:
- [x] Support for all models
- [x] Full training of models
- [x] LoRA training
- [x] Sample packing
- [ ] FP16 and BF16 (awaiting AMP support for MPS in Pytorch)
- [ ] Tri-dao's flash-attn (until it is supported use spd_attention as an alternative)
- [ ] xformers
- [ ] bitsandbytes (meaning no 4/8 bits loading and bnb optimizers)
- [ ] qlora
- [ ] DeepSpeed
Untested:
- FSDP

View File

@@ -1,11 +1,4 @@
# Multipack (Sample Packing) # Multipack
## Visualization of Multipack with Flash Attention
Because Flash Attention simply drops the attention mask, we do not need to
construct a 4d attention mask. We only need to concatenate the sequences into
a single batch and let flash attention know where each new sequence begins.
4k context, bsz =4, 4k context, bsz =4,
each character represents 256 tokens each character represents 256 tokens
@@ -56,18 +49,3 @@ w packing ( note it's the same effective number of tokens per step, but a true b
E E E E F F F F F G G G H H H H E E E E F F F F F G G G H H H H
I I I J J J J K K K K K L L L X ]] I I I J J J J K K K K K L L L X ]]
``` ```
cu_seqlens:
[[ 0, 11, 17, 24, 28, 36, 41 44, 48, 51, 55, 60, 64]]
## Multipack without Flash Attention
Multipack can still be achieved without Flash attention, but with lower packing
efficiency as we are not able to join multiple batches into a single batch due to
context length limits without flash attention. We can use either Pytorch's Scaled
Dot Product Attention implementation or native Pytorch attention implementation
along with [4d attention masks](https://github.com/huggingface/transformers/pull/27539)
to pack sequences together and avoid cross attention.
<img src="./images/4d-mask.png" alt="axolotl" width="800">

View File

@@ -1,29 +0,0 @@
# Optimizers
Optimizers are an important component when training LLMs. Optimizers are responsible for updating the model's weights (parameters) based on the gradients computed during backpropagation.
The goal of an optimizer is to minimize the loss function.
### Adam/AdamW Optimizers
```yaml
adam_beta1: 0.9
adam_beta2: 0.999
adam_epsilon: 1e-8
weight_decay: 0.0
```
### GaLore Optimizer
https://huggingface.co/papers/2403.03507
```yaml
optimizer: galore_adamw | galore_adamw_8bit | galore_adafactor
optim_args:
rank: 128
update_proj_gap: 200
scale: 0.25
proj_type: std
optim_target_modules:
- mlp
- attn
```

View File

@@ -12,8 +12,8 @@ feedback. Various methods include, but not limited to:
### RLHF using Axolotl ### RLHF using Axolotl
>[!IMPORTANT] [!IMPORTANT]
>This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality. This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML
@@ -34,21 +34,6 @@ datasets:
rl: ipo rl: ipo
``` ```
#### ORPO
Paper: https://arxiv.org/abs/2403.07691
```yaml
rl: orpo
orpo_alpha: 0.1
remove_unused_columns: false
chat_template: chatml
datasets:
- path: argilla/ultrafeedback-binarized-preferences-cleaned
type: orpo.chat_template
```
#### Using local dataset files #### Using local dataset files
```yaml ```yaml
datasets: datasets:

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-13b-hf base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-13b-hf base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-34b-hf base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-34b-hf base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-7b-hf base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false

View File

@@ -1,6 +1,7 @@
base_model: codellama/CodeLlama-7b-hf base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -43,7 +43,6 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"!pip install torch==\"2.1.2\"\n",
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n", "!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.5.0\"\n", "!pip install flash-attn==\"2.5.0\"\n",
"!pip install deepspeed==\"0.13.1\"" "!pip install deepspeed==\"0.13.1\""
@@ -177,24 +176,6 @@
"# Buy using the ! the comand will be executed as a bash command\n", "# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml" "!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
] ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Play with inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
" --qlora_model_dir=\"./qlora-out\" --gradio"
]
} }
], ],
"metadata": { "metadata": {

View File

@@ -2,7 +2,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false
gptq: false gptq: false

View File

@@ -5,7 +5,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false load_in_8bit: false
# enable 4bit for QLoRA # enable 4bit for QLoRA
load_in_4bit: true load_in_4bit: true

View File

@@ -2,7 +2,7 @@ base_model: tiiuae/falcon-7b
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
gptq: false gptq: false

View File

@@ -1,65 +0,0 @@
# use google/gemma-7b if you have access
base_model: mhenrichsen/gemma-7b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
# huggingface repo
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
val_set_size: 0.1
output_dir: ./out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 3
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false

View File

@@ -1,4 +1,5 @@
base_model: TheBloke/Llama-2-7B-GPTQ base_model: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false
gptq: true gptq: true
gptq_disable_exllama: true gptq_disable_exllama: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
@@ -59,7 +60,7 @@ s2_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false
@@ -56,7 +57,7 @@ s2_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,70 +0,0 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 512
sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 4
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
special_tokens:

View File

@@ -1,6 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -1,7 +1,7 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -49,7 +49,7 @@ flash_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -2,6 +2,7 @@
base_model: mistralai/Mistral-7B-v0.1 base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false
@@ -60,7 +61,7 @@ flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
#default deepspeed, can use more aggresive if needed like zero2, zero3 #default deepspeed, can use more aggresive if needed like zero2, zero3

View File

@@ -1,6 +1,7 @@
base_model: mistralai/Mistral-7B-v0.1 base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
@@ -48,7 +49,7 @@ flash_attention: true
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,79 +0,0 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./lora-out
eval_sample_packing: false
adapter: lora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
sdp_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,74 +0,0 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out
model_config:
output_router_logits: true
adapter: qlora
lora_model_dir:
sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
special_tokens:

View File

@@ -16,12 +16,12 @@ output_dir: ./qlora-out
## You can optionally freeze the entire model and unfreeze a subset of parameters ## You can optionally freeze the entire model and unfreeze a subset of parameters
unfrozen_parameters: unfrozen_parameters:
# - ^lm_head.weight$ # - lm_head.*
# - ^model.embed_tokens.weight$[:32000] # - model.embed_tokens.*
# - model.layers.2[0-9]+.block_sparse_moe.gate # - model.layers.2[0-9]+.block_sparse_moe.gate.*
# - model.layers.2[0-9]+.block_sparse_moe.experts # - model.layers.2[0-9]+.block_sparse_moe.experts.*
# - model.layers.3[0-9]+.block_sparse_moe.gate # - model.layers.3[0-9]+.block_sparse_moe.gate.*
# - model.layers.3[0-9]+.block_sparse_moe.experts # - model.layers.3[0-9]+.block_sparse_moe.experts.*
model_config: model_config:
output_router_logits: true output_router_logits: true
@@ -81,7 +81,7 @@ loss_watchdog_patience: 3
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed_configs/zero2.json deepspeed: deepspeed_configs/zero2.json

View File

@@ -1,6 +1,7 @@
base_model: mistralai/Mistral-7B-v0.1 base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true
@@ -67,7 +68,7 @@ loss_watchdog_patience: 3
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -2,6 +2,7 @@ base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true trust_remote_code: true
load_in_8bit: true load_in_8bit: true
@@ -57,7 +58,7 @@ flash_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -2,6 +2,7 @@ base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true trust_remote_code: true
load_in_8bit: false load_in_8bit: false
@@ -57,7 +58,7 @@ flash_attention:
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
saves_per_epoch: 1 saves_per_epoch: 1
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,69 +0,0 @@
base_model: stabilityai/stablelm-2-1_6b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_steps: 100
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,66 +0,0 @@
base_model: stabilityai/stablelm-2-1_6b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,36 +0,0 @@
# StableLM 2
This repository contains examples for training and processing using StableLM-2. It also includes a section to help you estimate the GPU requirements for your specific use case.
## Estimating GPU Requirements
| type | deepspeed | batch size | context length | vRAM GPU (GBs) |
|---------------|-----------|------------|----------------|----------------|
| full finetune | N/A | 1 | 4096 | ~21.5GBs |
| full finetune | zero2 | 1 | 4096 | ~20GBs |
| lora | N/A | 1 | 4096 | ~16.6GBs |
The above are estimates and might differ slight depending on the setup for example whether you pack your sequence lengths or not (the above assumes you do to length 4096).
This blog post from Hamel Husain was a great resource for estimating these numbers: https://hamel.dev/notes/llm/03_estimating_vram.html
## Training
We have example scripts here for both full finetuning and lora using the popular alpaca dataset:
```shell
# preprocess the dataset
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess examples/stablelm-2/1.6b/lora.yml
```
Single GPU Training:
```shell
python -m axolotl.cli.train examples/stablelm-2/fft.yml --deepspeed deepspeed_configs/zero2.json
# OR
python -m axolotl.cli.train examples/stablelm-2/1.6b/lora.yml
```
Multinode GPU Training with `accelerate`:
```shell
# make sure you've configured accelerate properly
accelerate launch -m axolotl.cli.train examples/stablelm-2/1.6b/fft.yml --deepspeed deepspeed_configs/zero2.json
```

View File

@@ -1,69 +0,0 @@
base_model: bigcode/starcoder2-3b
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.2
output_dir: ./qlora
adapter: qlora
lora_model_dir:
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 20
evals_per_epoch: 4
eval_steps:
eval_table_size:
saves_per_epoch: 4
save_steps:
save_total_limit: 2
debug:
deepspeed:
weight_decay:
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,64 +0,0 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
warmup_steps: 10
evals_per_epoch: 0
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,6 +1,7 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true load_in_8bit: true
load_in_4bit: false load_in_4bit: false
@@ -15,7 +16,6 @@ output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true pad_to_sequence_len: true
adapter: lora adapter: lora

View File

@@ -2,6 +2,7 @@ base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
@@ -11,7 +12,6 @@ max_steps: 200
pretraining_dataset: pretraining_dataset:
path: c4 path: c4
name: en name: en
type: pretrain
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./model-out output_dir: ./model-out

View File

@@ -1,6 +1,7 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -1,7 +1,8 @@
base_model: 01-ai/Yi-34B-Chat base_model: 01-ai/Yi-34B-Chat
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_mistral_derived_model: false
is_llama_derived_model: true
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true
strict: false strict: false
@@ -28,7 +29,7 @@ num_epochs: 1
val_set_size: 0.1 val_set_size: 0.1
evals_per_epoch: 5 evals_per_epoch: 5
eval_table_size: eval_table_size:
eval_max_new_tokens: 128 eval_table_max_new_tokens: 128
eval_sample_packing: false eval_sample_packing: false
eval_batch_size: 1 eval_batch_size: 1

View File

@@ -1,4 +1,3 @@
pre-commit pre-commit
black black
mypy mypy
types-requests

View File

@@ -1,18 +1,16 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.9.0 peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git@f6261d7d81edd036fc53bfede65fe91f01a661aa transformers==4.37.0
tokenizers==0.15.0 tokenizers==0.15.0
bitsandbytes>=0.43.0 bitsandbytes>=0.41.1
accelerate==0.26.1 accelerate==0.26.1
deepspeed==0.13.1 deepspeed>=0.13.1
pydantic==2.6.3
addict addict
fire fire
PyYAML>=6.0 PyYAML>=6.0
requests
datasets>=2.15.0 datasets>=2.15.0
flash-attn==2.5.5 flash-attn==2.3.3
sentencepiece sentencepiece
wandb wandb
einops einops
@@ -22,13 +20,14 @@ hf_transfer
colorama colorama
numba numba
numpy>=1.24.4 numpy>=1.24.4
mlflow
# qlora things # qlora things
evaluate==0.4.1 evaluate==0.4.0
scipy scipy
scikit-learn==1.2.2 scikit-learn==1.2.2
pynvml pynvml
art art
fschat==0.2.36 fschat==0.2.34
gradio==3.50.2 gradio==3.50.2
tensorboard tensorboard
@@ -39,8 +38,4 @@ s3fs
gcsfs gcsfs
# adlfs # adlfs
trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90 trl>=0.7.9
fastcore>=1.5.29
lpmm @ git+https://github.com/thu-ml/low-bit-optimizers.git@main
yacs

View File

@@ -1,17 +0,0 @@
dP dP dP
88 88 88
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
88' `88 `8bd8' 88' `88 88 88' `88 88 88
88. .88 .d88b. 88. .88 88 88. .88 88 88
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory ie empty, run the following commands:
```
cd /workspace
rm -rf /workspace/axolotl
git clone https://github.com/OpenAccess-AI-Collective/axolotl.git
cd axolotl
pip install --no-deps -e .
```

View File

@@ -1,7 +1,5 @@
"""setup.py for axolotl""" """setup.py for axolotl"""
import platform
import re
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from setuptools import find_packages, setup from setuptools import find_packages, setup
@@ -18,7 +16,6 @@ def parse_requirements():
or "flash-attention" in line or "flash-attention" in line
or "deepspeed" in line or "deepspeed" in line
or "mamba-ssm" in line or "mamba-ssm" in line
or "lion-pytorch" in line
) )
if line.startswith("--extra-index-url"): if line.startswith("--extra-index-url"):
# Handle custom index URLs # Handle custom index URLs
@@ -29,25 +26,11 @@ def parse_requirements():
_install_requires.append(line) _install_requires.append(line)
try: try:
if "Darwin" in platform.system(): torch_version = version("torch")
_install_requires.append(f"torch=={torch_version}")
if torch_version.startswith("2.1."):
_install_requires.pop(_install_requires.index("xformers==0.0.22")) _install_requires.pop(_install_requires.index("xformers==0.0.22"))
else: _install_requires.append("xformers>=0.0.23")
torch_version = version("torch")
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if version_match:
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = (
int(patch) if patch is not None else 0
) # Default patch to 0 if not present
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 1):
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
_install_requires.append("xformers>=0.0.23")
except PackageNotFoundError: except PackageNotFoundError:
pass pass
@@ -68,13 +51,13 @@ setup(
dependency_links=dependency_links, dependency_links=dependency_links,
extras_require={ extras_require={
"flash-attn": [ "flash-attn": [
"flash-attn==2.5.5", "flash-attn==2.5.0",
], ],
"fused-dense-lib": [ "fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib", "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed==0.13.1", "deepspeed>=0.13.1",
"deepspeed-kernels", "deepspeed-kernels",
], ],
"mamba-ssm": [ "mamba-ssm": [
@@ -83,14 +66,5 @@ setup(
"auto-gptq": [ "auto-gptq": [
"auto-gptq==0.5.1", "auto-gptq==0.5.1",
], ],
"mlflow": [
"mlflow",
],
"lion-pytorch": [
"lion-pytorch==0.1.2",
],
"galore": [
"galore_torch",
],
}, },
) )

View File

@@ -1,19 +1,16 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib import importlib
import json
import logging import logging
import math import math
import os import os
import random import random
import sys import sys
import tempfile
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import requests import gradio as gr
import torch import torch
import yaml import yaml
@@ -23,7 +20,6 @@ from art import text2art
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
@@ -63,52 +59,6 @@ def print_axolotl_text_art(suffix=None):
print(ascii_art) print(ascii_art)
def check_remote_config(config: Union[str, Path]):
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
if not (isinstance(config, str) and config.startswith("https://")):
return config # Return the original value if it's not a valid URL
filename = os.path.basename(urlparse(config).path)
temp_dir = tempfile.mkdtemp()
try:
response = requests.get(config, timeout=30)
response.raise_for_status() # Check for HTTP errors
content = response.content
try:
# Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
json.loads(content)
# Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
LOG.warning(
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
)
except json.JSONDecodeError:
# If it's not valid JSON, verify it's valid YAML
try:
yaml.safe_load(content)
except yaml.YAMLError as err:
raise ValueError(
f"Failed to parse the content at {config} as YAML: {err}"
) from err
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
output_path = Path(temp_dir) / filename
with open(output_path, "wb") as file:
file.write(content)
LOG.info(
f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
)
return output_path
except requests.RequestException as err:
# This catches all requests-related exceptions including HTTPError
raise RuntimeError(f"Failed to download {config}: {err}") from err
except Exception as err:
# Catch-all for any other exceptions
raise err
def get_multi_line_input() -> Optional[str]: def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to submit): ") print("Give me an instruction (Ctrl + D to submit): ")
instruction = "" instruction = ""
@@ -214,8 +164,6 @@ def do_inference_gradio(
cfg: DictDefault, cfg: DictDefault,
cli_args: TrainerCliArgs, cli_args: TrainerCliArgs,
): ):
import gradio as gr
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"} default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
@@ -322,14 +270,14 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
return not any(el in list2 for el in list1) return not any(el in list2 for el in list1)
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): def load_cfg(config: Path = Path("examples/"), **kwargs):
config = check_remote_config(config)
if Path(config).is_dir(): if Path(config).is_dir():
config = choose_config(Path(config)) config = choose_config(config)
# load the config from the yaml file # load the config from the yaml file
with open(config, encoding="utf-8") as file: with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file)) cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
# if there are any options passed in the cli, if it is something that seems valid from the yaml, # if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value # then overwrite the value
cfg_keys = cfg.keys() cfg_keys = cfg.keys()
@@ -342,22 +290,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
else: else:
cfg[k] = kwargs[k] cfg[k] = kwargs[k]
cfg.axolotl_config_path = config validate_config(cfg)
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
cfg = validate_config(
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"n_gpu": os.environ.get("WORLD_SIZE", 1),
"compute_capability": gpu_version,
},
)
prepare_optim_env(cfg) prepare_optim_env(cfg)

View File

@@ -3,7 +3,6 @@ CLI to run training on a model
""" """
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Union
import fire import fire
import transformers import transformers
@@ -24,7 +23,7 @@ from axolotl.prompt_strategies.sharegpt import register_chatml_template
LOG = logging.getLogger("axolotl.cli.preprocess") LOG = logging.getLogger("axolotl.cli.preprocess")
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art() print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
@@ -54,7 +53,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.warning(msg) LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
if parsed_cfg.rl and parsed_cfg.rl != "orpo": if parsed_cfg.rl:
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else: else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)

View File

@@ -3,7 +3,6 @@ CLI to shard a trained model into 10GiB chunks
""" """
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Union
import fire import fire
import transformers import transformers
@@ -26,7 +25,7 @@ def shard(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art() print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)

View File

@@ -3,12 +3,11 @@ CLI to run training on a model
""" """
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Tuple, Union from typing import Tuple
import fire import fire
from transformers.hf_argparser import HfArgumentParser import transformers
from transformers.modeling_utils import PreTrainedModel from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from axolotl.cli import ( from axolotl.cli import (
check_accelerate_default_config, check_accelerate_default_config,
@@ -25,10 +24,10 @@ from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train") LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser((TrainerCliArgs)) parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True
) )
@@ -47,7 +46,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
else: else:
register_chatml_template() register_chatml_template()
if cfg.rl and cfg.rl != "orpo": if cfg.rl:
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -6,7 +6,6 @@ import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer

View File

@@ -1,55 +0,0 @@
"""module for building the auto wrap policy for FSDP"""
import functools
from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
from torch.distributed.fsdp.wrap import (
_or_policy,
lambda_auto_wrap_policy,
transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
"llama",
"mistral",
"mixtral",
]
def get_wrapping_policy_factory(model_type):
if model_type == "llama":
layer_to_wrap = LlamaDecoderLayer
elif model_type == "mistral":
layer_to_wrap = MistralDecoderLayer
elif model_type == "mixtral":
layer_to_wrap = MixtralDecoderLayer
def get_wrapping_policy():
"""This checks for lora layers (has weight and requires_grad)"""
def lambda_policy_fn(module):
return (
len(list(module.named_children())) == 0
and getattr(module, "weight", None) is not None
and module.weight.requires_grad
)
lambda_policy = functools.partial(
lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
)
transformer_layer_name = layer_to_wrap
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=(
PrefixEncoder,
PromptEncoder,
PromptEmbedding,
transformer_layer_name,
),
)
policies = [lambda_policy, transformer_wrap_policy]
return functools.partial(_or_policy, policies=policies)
return get_wrapping_policy

File diff suppressed because it is too large Load Diff

View File

@@ -1,40 +0,0 @@
"""module for trainer helpers like OptimizerNames"""
from transformers.utils import ExplicitEnum
class OptimizerNames(ExplicitEnum):
"""
Stores the acceptable string identifiers for optimizers.
"""
ADAMW_HF = "adamw_hf"
ADAMW_TORCH = "adamw_torch"
ADAMW_TORCH_FUSED = "adamw_torch_fused"
ADAMW_TORCH_XLA = "adamw_torch_xla"
ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
ADAMW_APEX_FUSED = "adamw_apex_fused"
ADAFACTOR = "adafactor"
ADAMW_ANYPRECISION = "adamw_anyprecision"
SGD = "sgd"
ADAGRAD = "adagrad"
ADAMW_BNB = "adamw_bnb_8bit"
ADAMW_8BIT = "adamw_8bit" # just an alias for adamw_bnb_8bit
LION_8BIT = "lion_8bit"
LION = "lion_32bit"
PAGED_ADAMW = "paged_adamw_32bit"
PAGED_ADAMW_8BIT = "paged_adamw_8bit"
PAGED_LION = "paged_lion_32bit"
PAGED_LION_8BIT = "paged_lion_8bit"
RMSPROP = "rmsprop"
RMSPROP_BNB = "rmsprop_bnb"
RMSPROP_8BIT = "rmsprop_bnb_8bit"
RMSPROP_32BIT = "rmsprop_bnb_32bit"
GALORE_ADAMW = "galore_adamw"
GALORE_ADAMW_8BIT = "galore_adamw_8bit"
GALORE_ADAFACTOR = "galore_adafactor"
GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
LPMM_ADAMW_4BIT = "lmpp_adamw_4bit"
LPMM_ADAMW_4BIT_FUSED = "lmpp_adamw_4bit_fused"

View File

@@ -31,7 +31,7 @@ class TokenizedPromptDataset(Dataset):
def __init__( # pylint: disable=super-init-not-called def __init__( # pylint: disable=super-init-not-called
self, self,
prompt_tokenizer: PromptTokenizingStrategy, prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset, dataset: IterableDataset,
process_count: Optional[int] = None, process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False, keep_in_memory: Optional[bool] = False,
**kwargs, **kwargs,

View File

@@ -30,7 +30,6 @@ class ColorfulFormatter(Formatter):
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
"version": 1, "version": 1,
"disable_existing_loggers": False,
"formatters": { "formatters": {
"simple": { "simple": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s", "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",

View File

@@ -1,133 +0,0 @@
"""Module for LoRA+"""
# MIT License
#
# Copyright (c) 2024 nikhil-ghosh-berkeley
# https://github.com/nikhil-ghosh-berkeley/loraplus
import logging
from functools import reduce
from peft.tuners import lora
from torch import nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
LOG = logging.getLogger("axolotl.loraplus")
def get_module(name, opt_model):
"""
Retrieve a module from a model using its parameter name.
Args:
name (str): Full name of the parameter, typically including module path.
opt_model (torch.nn.Module): The model from which to retrieve the module.
Returns:
Module corresponding to the given name.
"""
parent_idx = 2 if "lora" in name else 1
module_names = name.split(sep=".")[:-parent_idx]
module = reduce(getattr, module_names, opt_model)
return module
def create_loraplus_optimizer(
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding=None,
):
"""
Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.
Args:
opt_model (torch.nn.Module): The model for which the optimizer is being created.
optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.
Returns:
An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
"""
assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."
if loraplus_lr_embedding is None:
loraplus_lr_embedding = 1e-6
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
param_groups = {
"groupA": {},
"groupB": {},
"groupB_no_decay": {},
"embedding": {},
}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
module = get_module(name, opt_model)
if isinstance(module, lora.Embedding):
param_groups["embedding"][name] = param
elif "lora_B" in name or param.ndim == 1:
if name in decay_parameters:
param_groups["groupB"][name] = param
else:
param_groups["groupB_no_decay"][name] = param
else:
param_groups["groupA"][name] = param
assigned_param_groups = ""
for group, group_params in param_groups.items():
assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
LOG.info(assigned_param_groups)
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
optimizer_grouped_parameters = [
{
"params": list(param_groups["groupA"].values()),
"weight_decay": weight_decay,
"lr": lr,
},
{
"params": list(param_groups["embedding"].values()),
"weight_decay": weight_decay,
"lr": loraplus_lr_embedding,
},
{
"params": list(param_groups["groupB"].values()),
"weight_decay": weight_decay,
"lr": lr * loraplus_lr_ratio,
},
{
"params": list(param_groups["groupB_no_decay"].values()),
"weight_decay": 0.0,
"lr": lr * loraplus_lr_ratio,
},
]
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{p.data_ptr(): p.numel() for p in module.parameters()}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
return optimizer

View File

@@ -1,46 +0,0 @@
"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
# pylint: disable=protected-access
import torch
from torch.utils.data._utils.fetch import _BaseDatasetFetcher
from torch.utils.data._utils.worker import _worker_loop
class _MapDatasetFetcher(_BaseDatasetFetcher):
def fetch(self, possibly_batched_index):
if isinstance(possibly_batched_index[0], list):
data = [None for i in possibly_batched_index]
for i, possibly_batched_index_ in enumerate(possibly_batched_index):
if self.auto_collation:
if (
hasattr(self.dataset, "__getitems__")
and self.dataset.__getitems__
):
data[i] = self.dataset.__getitems__(possibly_batched_index_)
else:
data[i] = [self.dataset[idx] for idx in possibly_batched_index_]
else:
data[i] = self.dataset[possibly_batched_index_]
else:
if self.auto_collation:
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
data = self.dataset.__getitems__(possibly_batched_index)
else:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
def patch_fetchers():
torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
def patched_worker_loop(*args, **kwargs):
patch_fetchers()
return _worker_loop(*args, **kwargs)
torch.utils.data._utils.worker._worker_loop = patched_worker_loop
patch_fetchers()

View File

@@ -0,0 +1,12 @@
"""
Patches to support multipack for falcon
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_falcon_attn_with_multipack_flash_attn():
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -106,7 +106,7 @@ def get_turns( # pylint: disable=too-many-return-statements
if self.system_message: if self.system_message:
contains_sys_msg = True contains_sys_msg = True
if self.messages: if self.messages:
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction separated by a newline # There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline
first_role, first_msg = self.messages[0] first_role, first_msg = self.messages[0]
if first_role == self.roles[0]: if first_role == self.roles[0]:
system_prompt = self.system_template.format( system_prompt = self.system_template.format(

View File

@@ -44,18 +44,6 @@ except ImportError:
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def is_xformers_swiglu_available() -> bool:
from xformers.ops.common import get_xformers_operator
try:
get_xformers_operator("swiglu_packedw")()
return True
except RuntimeError as exc:
if "No such operator xformers::swiglu_packedw " in str(exc):
return False
return True
def replace_llama_mlp_with_swiglu(model): def replace_llama_mlp_with_swiglu(model):
for name, module in model.named_modules(): for name, module in model.named_modules():
if isinstance(module, LlamaMLP): if isinstance(module, LlamaMLP):
@@ -287,9 +275,7 @@ def flashattn_forward_with_s2attn(
kv_seq_len = key_states.shape[-2] kv_seq_len = key_states.shape[-2]
if past_key_value is not None: if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2] kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb( cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
query_states, key_states = apply_rotary_pos_emb( query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids query_states, key_states, cos, sin, position_ids
) )
@@ -439,9 +425,7 @@ def flashattn_forward(
if past_key_value is not None: if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2] kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb( cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
query_states, key_states = apply_rotary_pos_emb( query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids query_states, key_states, cos, sin, position_ids
) )
@@ -704,9 +688,6 @@ def llama_model_forward(
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[ # pylint: disable=unused-argument
torch.LongTensor
] = None,
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = ( output_attentions = (
output_attentions output_attentions

View File

@@ -0,0 +1,142 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
import warnings
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
def hijack_llama_sdp_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
sdp_attention_forward
)
def sdp_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# sdp-attn start
#
with torch.backends.cuda.sdp_kernel():
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=False,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
# sdp-attn end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value

View File

@@ -5,11 +5,38 @@ from typing import Optional
import torch import torch
from axolotl.monkeypatch.utils import mask_2d_to_4d
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len) """
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1).to(dtype),
torch.tensor(0).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
inverted_mask = 1.0 - masked_zero_one_mask inverted_mask = 1.0 - masked_zero_one_mask
return inverted_mask.masked_fill( return inverted_mask.masked_fill(

View File

@@ -1,26 +0,0 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
from axolotl.monkeypatch.utils import (
patched_prepare_4d_causal_attention_mask,
patched_prepare_4d_causal_attention_mask_for_sdpa,
)
def hijack_llama_prepare_4d_mask():
import transformers.modeling_attn_mask_utils
import transformers.models.llama.modeling_llama
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)

View File

@@ -2,6 +2,9 @@
Patches to support multipack for mixtral Patches to support multipack for mixtral
""" """
import torch import torch
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def patch_mixtral_moe_forward_zero3() -> None: def patch_mixtral_moe_forward_zero3() -> None:
@@ -48,3 +51,11 @@ def patch_mixtral_moe_forward_zero3() -> None:
MixtralBLockSparseTop2MLP.forward = mlp_forward MixtralBLockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward MixtralSparseMoeBlock.forward = moe_forward
def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if for_zero3:
patch_mixtral_moe_forward_zero3()

View File

@@ -1,61 +0,0 @@
"""multipack patching for v2 of sample packing"""
import importlib
import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mixtral",
"qwen2",
"falcon",
"phi",
"gemma",
"gemmoe",
"starcoder2",
]
def patch_for_multipack(model_type, model_name=None):
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemmoe":
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_gemmoe to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_gemmoe", ".modeling_gemmoe"
)
modeling_gemmoe = importlib.import_module(module_name)
modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -0,0 +1,12 @@
"""
Patches to support multipack for phi2
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_phi_attn_with_multipack_flash_attn():
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -0,0 +1,12 @@
"""
Patches to support multipack for qwen2
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_qwen2_attn_with_multipack_flash_attn():
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -4,16 +4,14 @@ import json
import logging import logging
import os.path import os.path
import shutil import shutil
from functools import partial
from pathlib import Path from pathlib import Path
from typing import Dict, List, Sequence, Union from typing import Dict, List, Sequence
import bitsandbytes as bnb import bitsandbytes as bnb
import peft import peft
import safetensors.torch as st import safetensors.torch as st
import torch import torch
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
from transformers import ( from transformers import (
@@ -25,51 +23,23 @@ from transformers import (
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import barrier, is_main_process from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger("axolotl.relora") LOG = logging.getLogger("axolotl.relora")
@torch.no_grad() def reset_optimizer(optimizer: torch.optim.Optimizer):
def magnitude_pruning_(tensor, prune_ratio): for group in optimizer.param_groups:
tensor_magnitude = torch.abs(tensor) for param in group["params"]:
threshold = torch.quantile( param_state = optimizer.state[param]
tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio for key in param_state:
).to(dtype=tensor.dtype) if "qmap" in key:
continue
mask = tensor_magnitude > threshold if key == "step" and isinstance(param_state[key], int):
tensor.mul_(mask.to(dtype=tensor.dtype)) param_state[key] = 0
else:
param_state[key] = torch.zeros_like(param_state[key])
def reset_optimizer(
optimizer: torch.optim.Optimizer,
*,
reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str],
prune_ratio: float = 0.9,
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
n_zeros = 0
n_total = 0
optimizer_state = optimizer.state
if isinstance(optimizer, ZeroRedundancyOptimizer):
optimizer_state = optimizer.optim.state
for param in reset_params:
param_state = optimizer_state[param]
if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer
continue
for key in optimizer_state_keys:
pruning_fn(
param_state[key]
) # pruning fn has to be inplace to keep the same keys in the dict
n_total += param_state[key].numel()
n_zeros += torch.sum(param_state[key] == 0).item()
_zeroed = n_zeros / (1e-7 + n_total) * 100
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
LOG.info(f"absolute n of optimizer states zeroed: {n_zeros}")
class ReLoRACallback(TrainerCallback): class ReLoRACallback(TrainerCallback):
@@ -127,25 +97,6 @@ class ReLoRACallback(TrainerCallback):
"relora", "relora",
) )
if "adam" in args.optim.lower():
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
else:
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
lora_params = [
n
for n, p in model.named_parameters()
if p.requires_grad and "lora_" in n
]
model.save_pretrained(
os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
"adapter",
),
safe_serialization=True,
)
with torch.no_grad(): with torch.no_grad():
merge_and_save( merge_and_save(
model, model,
@@ -156,12 +107,7 @@ class ReLoRACallback(TrainerCallback):
actually_save=is_main_process(), actually_save=is_main_process(),
cpu_offload=self.cpu_offload, cpu_offload=self.cpu_offload,
) )
reset_optimizer( reset_optimizer(optimizer)
optimizer,
reset_params=lora_params,
optimizer_state_keys=optimizer_state_keys,
prune_ratio=args.relora_prune_ratio,
)
if self.quantized: if self.quantized:
self.last_full_model = checkpoint_folder self.last_full_model = checkpoint_folder
@@ -251,13 +197,11 @@ class ReLoRAScheduler(LRScheduler):
inner_schedule: LRScheduler, inner_schedule: LRScheduler,
relora_steps: int, relora_steps: int,
warmup_steps: int, warmup_steps: int,
anneal_steps: int = 1,
min_lr_scale: float = 0.001, min_lr_scale: float = 0.001,
) -> None: ) -> None:
self.inner_schedule = inner_schedule self.inner_schedule = inner_schedule
self.relora_steps = relora_steps self.relora_steps = relora_steps
self.warmup_steps = warmup_steps self.warmup_steps = warmup_steps
self.anneal_steps = anneal_steps
self.min_lr_scale = min_lr_scale self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose) super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
@@ -266,20 +210,10 @@ class ReLoRAScheduler(LRScheduler):
original = self.inner_schedule.get_lr() original = self.inner_schedule.get_lr()
step = self.last_epoch step = self.last_epoch
if step < self.relora_steps:
if step < self.relora_steps - self.warmup_steps:
scale = 1 scale = 1
else: else:
per_relora_progress = step % self.relora_steps cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
if per_relora_progress < self.warmup_steps:
cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
elif per_relora_progress > (self.relora_steps - self.anneal_steps):
cycle_t = min(
1.0,
(self.relora_steps - per_relora_progress) / self.anneal_steps,
)
else:
cycle_t = 1
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
if isinstance(original, Sequence): if isinstance(original, Sequence):
@@ -304,11 +238,7 @@ def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor: def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)): if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
adapter: Union[List[str], str] = layer.active_adapter adapter = layer.active_adapter
if isinstance(adapter, list):
if len(adapter) > 1:
raise ValueError("unhandled relora for multiple adapters")
adapter = adapter[0]
return ( return (
peft.utils.transpose( peft.utils.transpose(
layer.lora_B[adapter].weight.detach().to(device) layer.lora_B[adapter].weight.detach().to(device)
@@ -318,7 +248,7 @@ def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor
* layer.scaling[adapter] * layer.scaling[adapter]
) )
raise ValueError("unhandled lora layer type") return layer.get_delta_weight().to(device)
def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]: def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
@@ -343,9 +273,9 @@ def update_weights(
): ):
if reinit: if reinit:
for adapter_name in target.lora_A: for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name, True) target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A: for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name, True) target.reset_lora_parameters(adapter_name)
if isinstance(target, peft.tuners.lora.Linear4bit): if isinstance(target, peft.tuners.lora.Linear4bit):
# This could be faster, but the quantization of Linear4bit weights occurs # This could be faster, but the quantization of Linear4bit weights occurs
@@ -356,9 +286,7 @@ def update_weights(
target.weight.data = new_weight.cpu() target.weight.data = new_weight.cpu()
target.to(device) target.to(device)
elif isinstance(target, peft.tuners.lora.Linear8bitLt): elif isinstance(target, peft.tuners.lora.Linear8bitLt):
target.weight.data = ( target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
)
else: else:
target.weight.data = new_weight.to(device) target.weight.data = new_weight.to(device)
@@ -376,17 +304,14 @@ def merge_and_save(
if not quantized: if not quantized:
for module_name, target in modules.items(): for module_name, target in modules.items():
active_adapter = target.active_adapter update = target.get_delta_weight(target.active_adapter).detach()
if isinstance(active_adapter, list):
active_adapter = active_adapter[0]
update = target.get_delta_weight(active_adapter).detach()
target.weight.data += update target.weight.data += update
if reinit: if reinit:
for adapter_name in target.lora_A: for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name, True) target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A: for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name, True) target.reset_lora_parameters(adapter_name)
return return
os.makedirs(model_dst, exist_ok=True) os.makedirs(model_dst, exist_ok=True)
@@ -438,7 +363,6 @@ def merge_and_save(
LOG.info(f"saving tensors to {shard_fn}") LOG.info(f"saving tensors to {shard_fn}")
st.save_file(out_tensors, shard_fn, metadata={"format": "pt"}) st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
barrier()
del in_tensors del in_tensors
del out_tensors del out_tensors
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@@ -1,15 +1,8 @@
""" """
Shared utils for the monkeypatches Shared utils for the monkeypatches
""" """
from typing import Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.utils import is_torch_bf16_gpu_available
@torch.jit.script @torch.jit.script
@@ -96,6 +89,7 @@ def get_cu_seqlens(attn_mask):
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
@torch.jit.script
def get_cu_seqlens_from_pos_ids(position_ids): def get_cu_seqlens_from_pos_ids(position_ids):
"""generate a cumulative sequence length mask for flash attention using pos ids""" """generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1: if len(position_ids.shape) == 1:
@@ -141,18 +135,7 @@ def get_cu_seqlens_from_pos_ids(position_ids):
results.append(cu_seqlens) results.append(cu_seqlens)
max_seq_lens.append(max_seq_len) max_seq_lens.append(max_seq_len)
# Find the maximum value across all tensors return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
max_value = max(t.max() for t in results)
# Find the length of the longest tensor
max_length = max(t.size(0) for t in results)
# Pad each tensor to the same length and collect them in a list
padded_results = [
F.pad(t, (0, max_length - t.size(0)), "constant", max_value) for t in results
]
return torch.stack(padded_results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def set_module_name(model, name, value): def set_module_name(model, name, value):
@@ -166,62 +149,3 @@ def set_module_name(model, name, value):
child_name = name child_name = name
setattr(parent, child_name, value) setattr(parent, child_name, value)
def mask_2d_to_4d(
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1, device=mask.device).to(dtype),
torch.tensor(0, device=mask.device).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
return masked_zero_one_mask
def patched_prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def patched_prepare_4d_causal_attention_mask_for_sdpa(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)

View File

@@ -1,20 +0,0 @@
"""
module for base dataset transform strategies
"""
import importlib
import logging
LOG = logging.getLogger("axolotl")
def load(strategy, cfg, module_base=None, **kwargs):
try:
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", module_base)
func = getattr(mod, load_fn)
return func(cfg, **kwargs)
except Exception: # pylint: disable=broad-exception-caught
LOG.warning(f"unable to load strategy {strategy}")
return None

View File

@@ -1,78 +0,0 @@
"""
HF Chat Templates prompt strategy
"""
from typing import Any, Dict, Optional
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates
class ChatTemplatePrompter(Prompter):
"""prompter for HF chat templates"""
def __init__(self, tokenizer, chat_template=None, max_length=2048):
self.tokenizer = tokenizer
self.chat_template = chat_template
self.max_length = max_length
def build_prompt(self, conversation, add_generation_prompt=False):
return self.tokenizer.apply_chat_template(
conversation,
truncation=True,
max_length=self.max_length,
add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template,
)
class ChatTemplateStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for instruction-based prompts.
"""
def tokenize_prompt(self, prompt):
turns = self.get_conversation_thread(prompt)
prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
input_ids = self.prompter.build_prompt(turns)
if not self.train_on_inputs:
user_prompt_len = len(prompt_ids)
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
else:
labels = input_ids
tokenized_prompt = {
"input_ids": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids),
}
return tokenized_prompt
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap roles - allow for assistant turn
role_map = {
"human": "user",
"user": "user",
"assistant": "assistant",
"gpt": "assistant",
}
turns = [
{"role": role_map[t["from"]], "content": t["value"]} for t in conversations
]
return turns
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
chat_template = (
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
return strategy

View File

@@ -1,8 +1,21 @@
""" """
module for DPO style dataset transform strategies module for DPO style dataset transform strategies
""" """
from functools import partial
from ..base import load as load_base import importlib
import logging
load = partial(load_base, module="axolotl.prompt_strategies.dpo") LOG = logging.getLogger("axolotl")
def load(strategy, cfg):
try:
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
func = getattr(mod, load_fn)
load_kwargs = {}
return func(cfg, **load_kwargs)
except Exception: # pylint: disable=broad-exception-caught
LOG.warning(f"unable to load strategy {strategy}")
return None

View File

@@ -5,7 +5,6 @@ DPO strategies for chatml
def argilla( def argilla(
cfg, cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument ): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample): def transform_fn(sample):
if "system" in sample and sample["system"]: if "system" in sample and sample["system"]:
@@ -24,28 +23,8 @@ def argilla(
return transform_fn return transform_fn
def argilla_chat(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
for argilla/dpo-mix-7k conversations
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample
return transform_fn
def icr( def icr(
cfg, cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument ): # pylint: disable=possibly-unused-variable,unused-argument
""" """
chatml transforms for datasets with system, input, chosen, rejected chatml transforms for datasets with system, input, chosen, rejected
@@ -69,7 +48,7 @@ def icr(
return transform_fn return transform_fn
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
""" """
For Intel Orca DPO Pairs For Intel Orca DPO Pairs
""" """
@@ -91,9 +70,7 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
return transform_fn return transform_fn
def prompt_pairs( def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
cfg, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample): def transform_fn(sample):
if "system" in sample and sample["system"]: if "system" in sample and sample["system"]:
sample["prompt"] = ( sample["prompt"] = (
@@ -111,7 +88,7 @@ def prompt_pairs(
return transform_fn return transform_fn
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument def ultra(cfg): # pylint: disable=possibly-unused-variable,unused-argument
""" """
for ultrafeedback binarized conversations for ultrafeedback binarized conversations
""" """

View File

@@ -1,41 +0,0 @@
"""
User-defined DPO strategies
"""
def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
ds_cfg = cfg["datasets"][dataset_idx]["type"]
if not isinstance(ds_cfg, dict):
raise ValueError(
f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
)
field_prompt = ds_cfg.get("field_prompt", "prompt")
field_system = ds_cfg.get("field_system", "system")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
prompt_format = ds_cfg.get("prompt_format")
if not prompt_format:
prompt_format = "{" + field_prompt + "}"
chosen_format = ds_cfg.get("chosen_format")
if not chosen_format:
chosen_format = "{" + field_chosen + "}"
rejected_format = ds_cfg.get("rejected_format")
if not rejected_format:
rejected_format = "{" + field_rejected + "}"
def transform_fn(sample):
if (
"{" + field_system + "}" in prompt_format
and field_system in sample
and sample[field_system]
):
sample["prompt"] = prompt_format.format(
system=sample[field_system], prompt=sample[field_prompt]
)
else:
sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
return sample
return transform_fn

Some files were not shown because too many files have changed in this diff Show More