Compare commits

..

9 Commits

Author      SHA1         Message                                                   Date
Wing Lian   17330c05a3   shampoo checkpoint save workaround                        2024-09-23 15:21:00 -04:00
Wing Lian   992ea517b7   setup precision config for bf16                           2024-09-18 12:01:38 -07:00
Wing Lian   beaee36191   ddp shampoo                                               2024-09-18 10:50:46 -07:00
Wing Lian   69a29382e1   fix casting of optim args                                 2024-09-18 10:48:15 -07:00
Wing Lian   84dad0bd12   ensure epsilon is cast to float                           2024-09-18 10:42:26 -07:00
Wing Lian   05f61a0ea5   remove accidental duplidcated line                        2024-09-18 10:41:03 -07:00
Wing Lian   5334d0fc01   fixes                                                     2024-09-18 10:38:59 -07:00
Wing Lian   52e6249d2e   additional grafting config types and basic example doc    2024-09-18 08:16:11 -07:00
Wing Lian   eb3eab3450   wip shampoo optim support                                 2024-09-18 08:10:52 -07:00
161 changed files with 3814 additions and 7932 deletions
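The nine commits above sketch experimental Distributed Shampoo support: a DDP-aware optimizer path, a bf16 precision config, optimizer args cast to float (notably epsilon), additional grafting config types, and a checkpoint-save workaround. As a loosely hedged sketch of what the branch's "basic example doc" commit might describe (the optimizer name and `optim_args` keys below are illustrative assumptions, not confirmed against the branch):

```yaml
# hypothetical axolotl config sketch for the shampoo commits above;
# key names are assumptions, check commit 52e6249d2e's example doc for the real ones
optimizer: shampoo
optim_args:
  epsilon: 1.0e-8        # cast to float per commit 84dad0bd12
  grafting_type: adagrad # one of the added grafting config types
  precision: bf16        # per commit 992ea517b7
```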

View File

@@ -24,41 +24,27 @@ jobs:
             python_version: "3.11"
             pytorch: 2.3.1
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-          - cuda: "124"
-            cuda_version: 12.4.1
-            cudnn_version: ""
-            python_version: "3.10"
-            pytorch: 2.4.1
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
           - cuda: "124"
             cuda_version: 12.4.1
             cudnn_version: ""
             python_version: "3.11"
-            pytorch: 2.4.1
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-          - cuda: "124"
-            cuda_version: 12.4.1
-            cudnn_version: ""
-            python_version: "3.11"
-            pytorch: 2.5.1
+            pytorch: 2.4.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
     steps:
       - name: Checkout
-        uses: actions/checkout@v4
+        uses: actions/checkout@v3
       - name: Docker metadata
         id: metadata
-        uses: docker/metadata-action@v5
+        uses: docker/metadata-action@v3
         with:
-          images: |
-            winglian/axolotl-base
-            axolotlai/axolotl-base
+          images: winglian/axolotl-base
       - name: Login to Docker Hub
         uses: docker/login-action@v2
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_TOKEN }}
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@v2
       - name: Build
         uses: docker/build-push-action@v4
         with:

View File

@@ -17,7 +17,7 @@ jobs:
       - name: Set up Quarto
         uses: quarto-dev/quarto-actions/setup@v2
       - name: Setup Python
-        uses: actions/setup-python@v5
+        uses: actions/setup-python@v3
         with:
           python-version: '3.10'
       - name: install dependencies

View File

@@ -15,9 +15,9 @@ jobs:
     name: pre-commit
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-python@v5
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v4
         with:
           python-version: "3.10"
           cache: 'pip' # caching pip dependencies
-      - uses: pre-commit/action@v3.0.1
+      - uses: pre-commit/action@v3.0.0

View File

@@ -4,8 +4,6 @@ on:
   push:
     branches:
       - "main"
-    tags:
-      - "v*"
   workflow_dispatch:

 jobs:
@@ -29,12 +27,7 @@ jobs:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
-            pytorch: 2.4.1
-            axolotl_extras:
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.5.1
+            pytorch: 2.4.0
             axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
@@ -44,12 +37,7 @@ jobs:
         id: metadata
         uses: docker/metadata-action@v5
         with:
-          images: |
-            winglian/axolotl
-            axolotlai/axolotl
-          tags: |
-            type=ref,event=branch
-            type=semver,pattern={{version}}
+          images: winglian/axolotl
       - name: Set up Docker Buildx
         uses: docker/setup-buildx-action@v3
       - name: Login to Docker Hub
@@ -63,7 +51,7 @@ jobs:
         with:
           context: .
           build-args: |
-            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
+            BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
             CUDA=${{ matrix.cuda }}
             PYTORCH_VERSION=${{ matrix.pytorch }}
             AXOLOTL_ARGS=${{ matrix.axolotl_args }}
@@ -96,12 +84,7 @@ jobs:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
-            pytorch: 2.4.1
-            axolotl_extras:
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.5.1
+            pytorch: 2.4.0
             axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
@@ -111,25 +94,20 @@ jobs:
         id: metadata
         uses: docker/metadata-action@v5
         with:
-          images: |
-            winglian/axolotl-cloud
-            axolotlai/axolotl-cloud
-          tags: |
-            type=ref,event=branch
-            type=semver,pattern={{version}}
+          images: winglian/axolotl-cloud
       - name: Login to Docker Hub
         uses: docker/login-action@v3
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_TOKEN }}
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@v2
       - name: Build
         uses: docker/build-push-action@v5
         with:
           context: .
           build-args: |
-            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+            BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
             CUDA=${{ matrix.cuda }}
           file: ./docker/Dockerfile-cloud
           push: ${{ github.event_name != 'pull_request' }}
@@ -158,25 +136,20 @@ jobs:
         id: metadata
         uses: docker/metadata-action@v5
         with:
-          images: |
-            winglian/axolotl-cloud-term
-            axolotlai/axolotl-cloud-term
-          tags: |
-            type=ref,event=branch
-            type=semver,pattern={{version}}
+          images: winglian/axolotl-cloud-term
       - name: Login to Docker Hub
         uses: docker/login-action@v3
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_TOKEN }}
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@v2
       - name: Build
         uses: docker/build-push-action@v5
         with:
           context: .
           build-args: |
-            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+            BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
             CUDA=${{ matrix.cuda }}
           file: ./docker/Dockerfile-cloud-no-tmux
           push: ${{ github.event_name != 'pull_request' }}
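A side note on the `BASE_TAG` expression that only the left side carries: GitHub Actions has no ternary operator, so it uses the `cond && a || b` idiom, meaning tag builds fall back to the `main` base image while branch builds use the branch name. A minimal standalone illustration of the idiom (not itself part of this diff):

```yaml
# the `cond && a || b` ternary idiom from the BASE_TAG build-arg above;
# caveat: if the middle value were falsy, the fallback would always win
env:
  IMAGE_TAG: ${{ github.ref_type == 'tag' && 'main' || github.ref_name }}
```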

View File

@@ -8,11 +8,6 @@ on:
   schedule:
     - cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday

-# Cancel jobs on the same ref if a new one is triggered
-concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
-  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
-
 jobs:
   test-axolotl-multigpu:
     if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -26,17 +21,10 @@ jobs:
             pytorch: 2.3.1
             axolotl_extras:
             num_gpus: 2
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 121
+            cuda_version: 12.1.1
             python_version: "3.11"
-            pytorch: 2.4.1
-            axolotl_extras:
-            num_gpus: 2
-            nightly_build: "true"
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.5.1
+            pytorch: 2.3.1
             axolotl_extras:
             num_gpus: 2
             nightly_build: "true"

View File

@@ -26,12 +26,7 @@ jobs:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
-            pytorch: 2.4.1
-            axolotl_extras:
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.5.1
+            pytorch: 2.4.0
             axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
@@ -41,9 +36,7 @@ jobs:
         id: metadata
         uses: docker/metadata-action@v5
         with:
-          images: |
-            winglian/axolotl
-            axolotlai/axolotl
+          images: winglian/axolotl
           tags: |
             type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
       - name: Set up Docker Buildx
@@ -90,12 +83,7 @@ jobs:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
-            pytorch: 2.4.1
-            axolotl_extras:
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.5.1
+            pytorch: 2.4.0
             axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
@@ -105,9 +93,7 @@ jobs:
         id: metadata
         uses: docker/metadata-action@v5
         with:
-          images: |
-            winglian/axolotl-cloud
-            axolotlai/axolotl-cloud
+          images: winglian/axolotl-cloud
           tags: |
             type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
       - name: Login to Docker Hub
@@ -116,7 +102,7 @@ jobs:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_TOKEN }}
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@v2
       - name: Build
         uses: docker/build-push-action@v5
         with:

View File

@@ -3,31 +3,12 @@ name: publish pypi

 on:
   push:
     tags:
-      - 'v*'
-  workflow_dispatch:
+      - '*'

 jobs:
-  setup_release:
-    name: Create Release
-    runs-on: ubuntu-latest
-    steps:
-      - name: Get the tag version
-        id: extract_branch
-        run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
-        shell: bash
-      - name: Create Release
-        id: create_release
-        uses: actions/create-release@v1
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          tag_name: ${{ steps.extract_branch.outputs.branch }}
-          release_name: ${{ steps.extract_branch.outputs.branch }}
-
   pypi-publish:
     name: Upload release to PyPI
     runs-on: ubuntu-latest
-    needs: [setup_release]
     environment:
       name: pypi
       url: https://pypi.org/p/axolotl
@@ -35,10 +16,10 @@ jobs:
       id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
     steps:
       - name: Check out repository code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v3
       - name: Setup Python
-        uses: actions/setup-python@v5
+        uses: actions/setup-python@v4
         with:
           python-version: "3.10"
@@ -46,7 +27,7 @@ jobs:
         run: |
           pip3 install wheel packaging
           pip3 install -e .
-          pip3 install -r requirements-dev.txt -r requirements-tests.txt
+          pip3 install -r requirements-tests.txt
       - name: Extract tag name
         id: tag
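The `setup_release` job present only on the left side leans on the archived `actions/create-release@v1` and the deprecated `::set-output` workflow command. As a hedged sketch of the same job on maintained tooling (using the third-party `softprops/action-gh-release` action, which this diff itself does not adopt):

```yaml
# hypothetical modern replacement for the setup_release job above
setup_release:
  name: Create Release
  runs-on: ubuntu-latest
  permissions:
    contents: write  # needed to create releases with the default GITHUB_TOKEN
  steps:
    - name: Create Release
      uses: softprops/action-gh-release@v2
      with:
        tag_name: ${{ github.ref_name }}
        name: ${{ github.ref_name }}
```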

View File

@@ -9,12 +9,12 @@ jobs:
     name: pre-commit
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-python@v5
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v4
         with:
           python-version: "3.10"
           cache: 'pip' # caching pip dependencies
-      - uses: pre-commit/action@v3.0.1
+      - uses: pre-commit/action@v3.0.0
         env:
           SKIP: no-commit-to-branch
@@ -25,15 +25,15 @@ jobs:
       fail-fast: false
       matrix:
         python_version: ["3.10", "3.11"]
-        pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
+        pytorch_version: ["2.3.1", "2.4.0"]
     timeout-minutes: 20

     steps:
       - name: Check out repository code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v3
       - name: Setup Python
-        uses: actions/setup-python@v5
+        uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python_version }}
           cache: 'pip' # caching pip dependencies
@@ -47,15 +47,13 @@ jobs:
           sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
           sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
           sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
-          sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
-          sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
       - name: Install dependencies
         run: |
           pip3 install --upgrade pip
           pip3 install --upgrade packaging
           pip3 install -U -e .
-          pip3 install -r requirements-dev.txt -r requirements-tests.txt
+          pip3 install -r requirements-tests.txt
       - name: Run tests
         run: |
@@ -83,17 +81,17 @@ jobs:
             num_gpus: 1
             axolotl_extras: mamba-ssm
             nightly_build: "true"
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 121
+            cuda_version: 12.1.1
             python_version: "3.11"
-            pytorch: 2.4.1
+            pytorch: 2.3.1
             num_gpus: 1
-            axolotl_extras:
+            axolotl_extras: mamba-ssm
             nightly_build: "true"
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
-            pytorch: 2.5.1
+            pytorch: 2.4.0
             num_gpus: 1
             axolotl_extras:
             nightly_build: "true"

View File

@@ -15,22 +15,17 @@ on:
       - '.github/workflows/*.yml'
   workflow_dispatch:

-# Cancel jobs on the same ref if a new one is triggered
-concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
-  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
-
 jobs:
   pre-commit:
     name: pre-commit
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-python@v5
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v4
         with:
           python-version: "3.10"
           cache: 'pip' # caching pip dependencies
-      - uses: pre-commit/action@v3.0.1
+      - uses: pre-commit/action@v3.0.0
         env:
           SKIP: no-commit-to-branch
@@ -41,33 +36,29 @@ jobs:
       fail-fast: false
       matrix:
         python_version: ["3.10", "3.11"]
-        pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
+        pytorch_version: ["2.3.1", "2.4.0"]
     timeout-minutes: 20

     steps:
       - name: Check out repository code
-        uses: actions/checkout@v4
+        uses: actions/checkout@v3
       - name: Setup Python
-        uses: actions/setup-python@v5
+        uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python_version }}
           cache: 'pip' # caching pip dependencies
-      - name: upgrade pip
-        run: |
-          pip3 install --upgrade pip
-          pip3 install --upgrade packaging setuptools wheel
       - name: Install PyTorch
         run: |
-          pip3 install torch==${{ matrix.pytorch_version }}
+          pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
       - name: Install dependencies
         run: |
-          pip3 show torch
+          pip3 install --upgrade pip
+          pip3 install --upgrade packaging
           pip3 install -U -e .
-          pip3 install -r requirements-dev.txt -r requirements-tests.txt
+          pip3 install -r requirements-tests.txt
       - name: Run tests
         run: |
@@ -77,52 +68,12 @@ jobs:
       run: |
         find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;

-  docker-e2e-tests-1st:
-    if: github.repository_owner == 'axolotl-ai-cloud'
-    # this job needs to be run on self-hosted GPU runners...
-    runs-on: [self-hosted, modal]
-    timeout-minutes: 90
-    needs: [pre-commit, pytest]
-    strategy:
-      fail-fast: false
-      matrix:
-        include:
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.4.1
-            num_gpus: 1
-            axolotl_extras:
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-      - name: Install Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: "3.10"
-      - name: Install Modal
-        run: |
-          python -m pip install --upgrade pip
-          pip install modal==0.63.64 jinja2
-      - name: Update env vars
-        run: |
-          echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
-          echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
-          echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
-          echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
-          echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
-          echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
-      - name: Run tests job on Modal
-        run: |
-          modal run cicd.tests
-
   docker-e2e-tests:
     if: github.repository_owner == 'axolotl-ai-cloud'
     # this job needs to be run on self-hosted GPU runners...
     runs-on: [self-hosted, modal]
-    timeout-minutes: 90
-    needs: [pre-commit, pytest, docker-e2e-tests-1st]
+    timeout-minutes: 60
+    needs: [pre-commit, pytest]
     strategy:
       fail-fast: false
@@ -134,10 +85,16 @@ jobs:
             pytorch: 2.3.1
             num_gpus: 1
             axolotl_extras: mamba-ssm
+          - cuda: 121
+            cuda_version: 12.1.1
+            python_version: "3.11"
+            pytorch: 2.3.1
+            num_gpus: 1
+            axolotl_extras: mamba-ssm
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
-            pytorch: 2.5.1
+            pytorch: 2.4.0
             num_gpus: 1
             axolotl_extras:
     steps:

View File

@@ -1,3 +1,3 @@
 [settings]
 profile=black
-known_third_party=wandb,comet_ml
+known_third_party=wandb

View File

@@ -14,7 +14,7 @@ Features:
 - Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
 - Works with single GPU or multiple GPUs via FSDP or Deepspeed
 - Easily run with Docker locally or on the cloud
-- Log results and optionally checkpoints to wandb, mlflow or Comet
+- Log results and optionally checkpoints to wandb or mlflow
 - And more!

 <a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
@@ -121,7 +121,7 @@ Features:
 Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.

-**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
+**Requirements**: Python >=3.10 and Pytorch >=2.1.1.

 ```bash
 git clone https://github.com/axolotl-ai-cloud/axolotl
@@ -159,7 +159,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
 #### Docker

 ```bash
-docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
+docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest
 ```

 Or run on the current files for development:
@@ -178,7 +178,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
 A more powerful Docker command to run would be this:

 ```bash
-docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-latest
+docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest
 ```

 It additionally:
@@ -210,7 +210,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
 #### Cloud GPU

-For cloud GPU providers that support docker images, use [`axolotlai/axolotl-cloud:main-latest`](https://hub.docker.com/r/axolotlai/axolotl-cloud/tags)
+For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)

 - on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
 - on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
@@ -319,7 +319,7 @@ Write a job description in YAML as below:
 # dstack.yaml
 type: task

-image: axolotlai/axolotl-cloud:main-latest
+image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2

 env:
   - HUGGING_FACE_HUB_TOKEN
@@ -383,10 +383,11 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
   - typescript
     type: ... # unimplemented custom format

-  # chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template
+  # fastchat conversation
+  # See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
   - path: ...
-    type: chat_template
-    chat_template: chatml # defaults to tokenizer's chat_template
+    type: sharegpt
+    conversation: chatml # default: vicuna_v1.1

   # local
   - path: data.jsonl # or json
@@ -514,22 +515,6 @@ wandb_name:
 wandb_log_model:
 ```

-##### Comet Logging
-
-Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`.
-
-- wandb options
-
-```yaml
-use_comet:
-comet_api_key:
-comet_workspace:
-comet_project_name:
-comet_experiment_key:
-comet_mode:
-comet_online:
-comet_experiment_config:
-```
-
 ##### Special Tokens

 It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
@@ -561,8 +546,7 @@ plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rope: true
 liger_rms_norm: true
-liger_glu_activation: true
-liger_layer_norm: true
+liger_swiglu: true
 liger_fused_linear_cross_entropy: true
 ```

View File

@@ -1,4 +1,4 @@
-FROM axolotlai/axolotl-base:{{ BASE_TAG }}
+FROM winglian/axolotl-base:{{ BASE_TAG }}

 ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
@@ -23,12 +23,11 @@ RUN git fetch origin +$GITHUB_REF && \
     git checkout FETCH_HEAD

 # If AXOLOTL_EXTRAS is set, append it in brackets
+RUN pip install causal_conv1d
 RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
         sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
         sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
         sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
-        sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
-        sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
     fi

 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
@@ -38,7 +37,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
     fi

 # So we can test the Docker image
-RUN pip install -r requirements-dev.txt -r requirements-tests.txt
+RUN pip install -r requirements-tests.txt

 # fix so that git fetch/pull from remote works
 RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -1,6 +1,6 @@
 #!/bin/bash
 set -e

-pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/
+pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
 pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
 pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -10,7 +10,7 @@ import tempfile

 import jinja2
 import modal
 from jinja2 import select_autoescape
-from modal import App, Image
+from modal import Image, Stub

 cicd_path = pathlib.Path(__file__).parent.resolve()
@@ -46,7 +46,7 @@ cicd_image = (
     .pip_install("fastapi==0.110.0", "pydantic==2.6.3")
 )

-app = App("Axolotl CI/CD", secrets=[])
+stub = Stub("Axolotl CI/CD", secrets=[])

 N_GPUS = int(os.environ.get("N_GPUS", 2))
@@ -61,10 +61,10 @@ def run_cmd(cmd: str, run_folder: str):
         exit(exit_code)  # pylint: disable=consider-using-sys-exit

-@app.function(
+@stub.function(
     image=cicd_image,
     gpu=GPU_CONFIG,
-    timeout=60 * 60,
+    timeout=45 * 60,
     cpu=8.0,
     memory=131072 * N_GPUS,
 )
@@ -72,6 +72,6 @@ def cicd_pytest():
     run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")

-@app.local_entrypoint()
+@stub.local_entrypoint()
 def main():
     cicd_pytest.remote()

View File

@@ -2,4 +2,4 @@
 set -e

 # only run one test at a time so as not to OOM the GPU
-pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
+pytest -n1 /workspace/axolotl/tests/e2e/multigpu/

View File

@@ -10,7 +10,7 @@ import tempfile

 import jinja2
 import modal
 from jinja2 import select_autoescape
-from modal import App, Image
+from modal import Image, Stub

 cicd_path = pathlib.Path(__file__).parent.resolve()
@@ -47,7 +47,7 @@ cicd_image = (
     .pip_install("fastapi==0.110.0", "pydantic==2.6.3")
 )

-app = App("Axolotl CI/CD", secrets=[])
+stub = Stub("Axolotl CI/CD", secrets=[])

 N_GPUS = int(os.environ.get("N_GPUS", 1))
@@ -62,10 +62,10 @@ def run_cmd(cmd: str, run_folder: str):
         exit(exit_code)  # pylint: disable=consider-using-sys-exit

-@app.function(
+@stub.function(
     image=cicd_image,
     gpu=GPU_CONFIG,
-    timeout=60 * 60,
+    timeout=45 * 60,
     cpu=8.0,
     memory=131072,
 )
@@ -73,6 +73,6 @@ def cicd_pytest():
     run_cmd("./cicd/cicd.sh", "/workspace/axolotl")

-@app.local_entrypoint()
+@stub.local_entrypoint()
 def main():
     cicd_pytest.remote()
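The only functional difference between the two CI entrypoints in these diffs is the Modal API generation: newer Modal releases renamed `Stub` to `App`, and the decorators follow the renamed object. A minimal self-contained sketch of the newer `App` style seen on the left side (the app name and function body here are placeholders, not the repo's actual test invocation):

```python
# minimal sketch of the newer modal.App API used on the left side of this diff;
# app name and function body are placeholders
import modal

app = modal.App("axolotl-cicd-sketch")

@app.function(timeout=60 * 60, cpu=8.0)
def run_tests() -> None:
    # the real entrypoint shells out to the repo's pytest scripts instead
    print("would run the pytest suite here")

@app.local_entrypoint()
def main():
    run_tests.remote()
```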

View File

@@ -14,6 +14,15 @@
   "bf16": {
     "enabled": true
   },
+  "fp16": {
+    "enabled": "auto",
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 32,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
   "gradient_accumulation_steps": "auto",
   "gradient_clipping": "auto",
   "train_batch_size": "auto",

View File

@@ -24,6 +24,15 @@
   "bf16": {
     "enabled": true
   },
+  "fp16": {
+    "enabled": "auto",
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 32,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
   "gradient_accumulation_steps": "auto",
   "gradient_clipping": "auto",
   "train_batch_size": "auto",

View File

@@ -20,6 +20,15 @@
   "bf16": {
     "enabled": true
   },
+  "fp16": {
+    "enabled": "auto",
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 32,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
   "gradient_accumulation_steps": "auto",
   "gradient_clipping": "auto",
   "train_batch_size": "auto",

View File

@@ -1,4 +1,4 @@
-# Example config for debugging the chat_template prompt format
+# Example config for debugging the sharegpt prompt format
 base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
@@ -7,8 +7,8 @@ load_in_8bit: true
 load_in_4bit: false

 datasets:
-  - path: fozziethebeat/alpaca_messages_2k_test
-    type: chat_template
+  - path: philschmid/guanaco-sharegpt-style
+    type: sharegpt
     shards: 10
 val_set_size: 0
 output_dir: temp_debug/axolotl_outputs/model

View File

@@ -1,5 +1,5 @@
 ARG BASE_TAG=main-base
-FROM axolotlai/axolotl-base:$BASE_TAG
+FROM winglian/axolotl-base:$BASE_TAG

 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ARG AXOLOTL_EXTRAS=""
@@ -20,6 +20,7 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
 WORKDIR /workspace/axolotl

 # If AXOLOTL_EXTRAS is set, append it in brackets
+RUN pip install causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
         pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \

View File

@@ -1,5 +1,5 @@
 ARG BASE_TAG=main
-FROM axolotlai/axolotl:$BASE_TAG
+FROM winglian/axolotl:$BASE_TAG

 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
 ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"

View File

@@ -1,5 +1,5 @@
 ARG BASE_TAG=main
-FROM axolotlai/axolotl:$BASE_TAG
+FROM winglian/axolotl:$BASE_TAG

 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
 ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"

View File

@@ -1,5 +1,5 @@
 ARG BASE_TAG=main-base
-FROM axolotlai/axolotl-base:$BASE_TAG
+FROM winglian/axolotl-base:$BASE_TAG

 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ARG AXOLOTL_EXTRAS=""

View File

@@ -83,15 +83,22 @@ lora_on_cpu: true
 datasets:
   # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
   - path: vicgalle/alpaca-gpt4
-    # The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
+    # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
     type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
     ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
     data_files: # Optional[str] path to source data files
     shards: # Optional[int] number of shards to split data into
     name: # Optional[str] name of dataset configuration to load
     train_on_split: train # Optional[str] name of dataset split to load from
-    revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
-    trust_remote_code: # Optional[bool] Trust remote code for untrusted source
+
+    # Optional[str] fastchat conversation type, only used with type: sharegpt
+    conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
+    field_human: # Optional[str]. Human key to use for conversation.
+    field_model: # Optional[str]. Assistant key to use for conversation.
+    # Add additional keys from your dataset as input or output roles
+    roles:
+      input: # Optional[List[str]]. These will be masked based on train_on_input
+      output: # Optional[List[str]].

   # Custom user instruction prompt
   - path: repo
@@ -116,48 +123,6 @@ datasets:
     # For `completion` datsets only, uses the provided field instead of `text` column
     field:

-  # Using chat template
-  - path: ...
-    # Set type to `chat_template` to use this strategy
-    type: chat_template
-    # Specify the name of the chat template to use
-    # The name of the chat template to use for training, following values are supported:
-    # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
-    # - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
-    # - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
-    # - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
-    chat_template: tokenizer_default
-    # Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).
-    chat_template_jinja:
-    # The key in the data example that contains the messages. Default is "messages".
-    field_messages: messages
-    # The key in the message turn that contains the role. Default is "role".
-    message_field_role: role
-    # The key in the message turn that contains the content. Default is "content".
-    message_field_content: content
-    # Optional[Dict[str, List]]. Roles mapping for the messages.
-    roles:
-      user: ["human", "user"]
-      assistant: ["gpt", "assistant", "ai"]
-      system: ["system"]
-
-    ## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.
-
-    # Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
-    roles_to_train: ["gpt", "assistant"]
-    # Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
-    # - all: train on all EOS tokens
-    # - turn: train on the EOS token at the end of each trainable turn
-    # - last: train on the last EOS token in the conversation
-    train_on_eos: last
-    # The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
-    message_field_training: training
-    # The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
-    # The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
-    # See example at `docs/dataset-formats/conversation.qmd`
-    message_field_training_detail: train_detail
-
 # If false, the datasets will not be shuffled and will keep their original order in `datasets`.
 # The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
 shuffle_merged_datasets: true
@@ -175,19 +140,10 @@ test_datasets:
 # use RL training: 'dpo', 'ipo', 'kto'
 rl:

-# whether to perform weighting if doing DPO training. Boolean.
-dpo_use_weighting:
-
-# The name of the chat template to use for training, following values are supported:
-# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
-# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
-# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
-# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
-# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
-# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
-chat_template: tokenizer_default
-# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
-chat_template_jinja: null
+# Saves the desired chat template to the tokenizer_config.json for easier inferencing
+# Currently supports chatml and inst (mistral/mixtral)
+chat_template: chatml

 # Changes the default system message
 default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
 # Axolotl attempts to save the dataset as an arrow after packing the data together so
@@ -309,21 +265,8 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
 # mlflow configuration if you're using it
 mlflow_tracking_uri: # URI to mlflow
 mlflow_experiment_name: # Your experiment name
-mlflow_run_name: # Your run name
 hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry

-# Comet configuration if you're using it
-# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
-# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
-use_comet: # Enable or disable Comet integration.
-comet_api_key: # API key for Comet. Recommended to set via `comet login`.
-comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
-comet_project_name: # Project name in Comet. Defaults to Uncategorized.
-comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
-comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
-comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
-comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
-
 # Where to save the full-finetuned model to
 output_dir: ./completed-model
@@ -358,7 +301,7 @@ max_steps:
 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
 eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
-eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
+eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]

 loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
 loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
@@ -406,7 +349,6 @@ lr_div_factor: # Learning rate div factor
 #  - adamw_torch_fused
 #  - adamw_torch_xla
 #  - adamw_apex_fused
-#  - adopt_adamw  (only for torch version >= 2.5.1)
 #  - adafactor
 #  - adamw_anyprecision
 #  - sgd

View File

@@ -6,8 +6,31 @@ order: 3
 ## sharegpt

-IMPORTANT: ShareGPT is deprecated!. Please see `chat_template` section below.
+conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
+
+```{.json filename="data.jsonl"}
+{"conversations": [{"from": "...", "value": "..."}]}
+```
+
+Note: `type: sharegpt` opens special configs:
+- `conversation`: enables conversions to many Conversation types. Refer to the 'name' [here](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py) for options.
+- `roles`: allows you to specify the roles for input and output. This is useful for datasets with custom roles such as `tool` etc to support masking.
+- `field_human`: specify the key to use instead of `human` in the conversation.
+- `field_model`: specify the key to use instead of `gpt` in the conversation.
+
+```yaml
+datasets:
+  path: ...
+  type: sharegpt
+  conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
+  field_human: # Optional[str]. Human key to use for conversation.
+  field_model: # Optional[str]. Assistant key to use for conversation.
+  # Add additional keys from your dataset as input or output roles
+  roles:
+    input: # Optional[List[str]]. These will be masked based on train_on_input
+    output: # Optional[List[str]].
+```

 ## pygmalion
@@ -15,137 +38,34 @@ IMPORTANT: ShareGPT is deprecated!. Please see `chat_template` section below.
 {"conversations": [{"role": "...", "value": "..."}]}
 ```

-## chat_template
+## sharegpt.load_role

-Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
+conversations where `role` is used instead of `from`

 ```{.json filename="data.jsonl"}
-{"conversations": [{"role": "...", "content": "..."}]}
+{"conversations": [{"role": "...", "value": "..."}]}
 ```

-See `config.qmd` for full configs and supported templates.
+## sharegpt.load_guanaco

-### Migrating from sharegpt
+conversations where `from` is `prompter` `assistant` instead of default sharegpt

-Most configs can be adapted as follows:
-
-```yaml
-# old
-chat_template: chatml
-datasets:
-  - path: ...
-    type: sharegpt
-    conversation: chatml
-
-# new (if using tokenizer's chat_template)
-datasets:
-  - path: ...
-    type: chat_template
-    field_messages: conversations
-    message_field_role: from
-    message_field_content: value
-
-# new (if setting a new chat_template like chatml, gemma, etc)
-chat_template: chatml
-datasets:
-  - path: ...
-    type: chat_template
-    field_messages: conversations
-    message_field_role: from
-    message_field_content: value
-```
-
-We recommend checking the below examples for other usecases.
-
-### Examples
-
-1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
-
-```yaml
-datasets:
-  - path: ...
-    type: chat_template
-```
-
-2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
-
-```yaml
-chat_template: gemma # this overwrites the tokenizer's chat_template
-datasets:
-  - path: ...
-    type: chat_template
-    roles_to_train: ["assistant"]
-```
-
-3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
-
-```yaml
-chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
-datasets:
-  - path: ...
-    type: chat_template
-    roles_to_train: ["assistant"]
-```
-
-4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
-
-```yaml
-# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
-chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
-
-datasets:
-  - path: ...
-    type: chat_template
-    roles_to_train: ["assistant"]
-```
-
-5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
-
-For a data sample that looks like:

 ```{.json filename="data.jsonl"}
-{
-  "conversations": [
-    {"from": "system", "value": "You are an AI assistant.", "train": false},
-    {"from": "human", "value": "Hello", "train": false},
-    {"from": "assistant", "value": "Hello", "train": true},
-    {"from": "human", "value": "How are you?", "train": true},
-    {
-      "from": "assistant",
-      "value": "I'm doing very well, thank you!",
-      "train_detail": [
-        {"begin_offset": 0, "end_offset": 8, "train": false},
-        {"begin_offset": 9, "end_offset": 18, "train": true},
-        {"begin_offset": 19, "end_offset": 30, "train": false},
-      ],
-    },
-    {
-      "from": "human",
-      "value": "I'm doing very well, thank you!",
-      "train": true,
-    },
-    {"from": "assistant", "value": "Hi there!", "train": true}
-  ]
-}
+{"conversations": [{"from": "...", "value": "..."}]}
 ```

-The configuration would look like:
+## sharegpt.load_ultrachat

-```yaml
-datasets:
-  - path: ...
-    type: chat_template
-    chat_template: tokenizer_default
-    field_messages: conversations
-    message_field_role: from
-    message_field_content: value
-    roles_to_train: []
-    train_on_eos: turn
-    message_field_training: train
-    message_field_training_detail: train_detail
-```
+conversations where the turns field is 'messages', human is 'user' and gpt is 'assistant'.

-Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at a time.
+```{.json filename="data.jsonl"}
+{"messages": [{"user": "...", "assistant": "..."}]}
+```
+
+## sharegpt_jokes
+
+creates a chat where bot is asked to tell a joke, then explain why the joke is funny
+
+```{.json filename="data.jsonl"}
+{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
+```

View File

@@ -51,12 +51,12 @@ While debugging it's helpful to simplify your test scenario as much as possible.
 ### Background

-The below example shows how to configure VSCode to debug data preprocessing of the `chat_template` format. This is the format used when you have the following in your axolotl config:
+The below example shows how to configure VSCode to debug data preprocessing of the `sharegpt` format. This is the format used when you have the following in your axolotl config:

 ```yaml
 datasets:
-  - path: <path to your chat_template formatted dataset> # example on HF Hub: fozziethebeat/alpaca_messages_2k_test
-    type: chat_template
+  - path: <path to your sharegpt formatted dataset> # example on HF Hub: philschmid/guanaco-sharegpt-style
+    type: sharegpt
 ```

 >[!Important]
@@ -83,7 +83,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.
 The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.

-For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
+For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_sharegpt.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.

 ```jsonc
 // .vscode/launch.json
@@ -91,12 +91,12 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
     "version": "0.2.0",
     "configurations": [
         {
-            "name": "Debug axolotl prompt - chat_template",
+            "name": "Debug axolotl prompt - sharegpt",
             "type": "python",
             "module": "accelerate.commands.launch",
             "request": "launch",
             "args": [
-                "-m", "axolotl.cli.train", "dev_chat_template.yml",
+                "-m", "axolotl.cli.train", "dev_sharegpt.yml",
                 // The flags below simplify debugging by overriding the axolotl config
                 // with the debugging tips above. Modify as needed.
                 "--dataset_processes=1", // limits data preprocessing to one process
@@ -185,7 +185,7 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
 ## Debugging With Docker

-Using [official Axolotl Docker images](https://hub.docker.com/r/axolotlai/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
+Using [official Axolotl Docker images](https://hub.docker.com/r/winglian/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.

 ### Setup
@@ -202,11 +202,11 @@ cd axolotl
 Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]

 ```bash
-docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
+docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
 ```

 >[!Tip]
-> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
+> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/winglian/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).

 You will now be in the container. Next, perform an editable install of Axolotl:
@@ -240,6 +240,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
 </div>
 <br>

-[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.
+[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/sharegpt.yml`, but this is the same thing.

 [^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).

View File

@@ -205,7 +205,7 @@ ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
hi there!. goodbye farewell</s> hi there!. goodbye farewell</s>
``` ```
We can check that the right tokens are ignored by comparing the labels We can check that the right tokens are ingored by comparing the labels
to each token: to each token:
```python ```python

View File

@@ -1,28 +0,0 @@
# MultiModal / Vision Language Models (BETA)
### Supported Models
- Mllama, i.e. llama with vision models
### Usage
Currently, multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama with LoRA,
you'll need to use the following YAML in combination with the rest of the required hyperparameters.
```yaml
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
processor_type: AutoProcessor
skip_prepare_dataset: true
chat_template: llama3_2_vision
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
remove_unused_columns: false
sample_packing: false
# only finetune the Language model, leave the vision model and vision tower frozen
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
```
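Since a string `lora_target_modules` is treated as a regular expression matched against fully qualified module names, the pattern above can be sanity-checked with plain `re` before launching a run. A minimal sketch (the module names below are illustrative, not taken from the actual model):

```python
import re

# The pattern from the config above; matched against full module names.
pattern = re.compile(
    r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj"
)

# Language-model projections match, so they receive LoRA adapters...
assert pattern.fullmatch("language_model.model.layers.0.self_attn.q_proj")
assert pattern.fullmatch("language_model.model.layers.17.mlp.down_proj")
# ...while vision-tower modules do not match, leaving them frozen.
assert not pattern.fullmatch("vision_model.transformer.layers.0.self_attn.q_proj")
```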

17
docs/optimizers.qmd Normal file
View File

@@ -0,0 +1,17 @@
# Optimizers
## Shampoo
```yaml
optimizer: shampoo
optim_shampoo_betas: [0.9, 0.999]
optim_args:
epsilon: 1e-12
max_preconditioner_dim: 8192
precondition_frequency: 100
use_decoupled_weight_decay: true
optim_shampoo_grafting_config_type: adam
optim_shampoo_grafting_config_kwargs:
beta2: 0.999
epsilon: 1e-12
```
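For orientation, here is a rough sketch of how these YAML keys plausibly map onto `DistributedShampoo` from facebookresearch/optimizers (the `distributed_shampoo` dependency pinned in requirements.txt below). The constructor argument names are taken from that library; the exact wiring inside Axolotl may differ, so treat this as an assumption rather than the implementation:

```python
import torch
# Import locations may vary by library revision; this is an assumption.
from distributed_shampoo import AdamGraftingConfig, DistributedShampoo

model = torch.nn.Linear(16, 16)  # stand-in model for illustration

optimizer = DistributedShampoo(
    model.parameters(),
    lr=2e-4,
    betas=(0.9, 0.999),               # optim_shampoo_betas
    epsilon=1e-12,                    # optim_args.epsilon, cast to float
    max_preconditioner_dim=8192,      # optim_args.max_preconditioner_dim
    precondition_frequency=100,       # optim_args.precondition_frequency
    use_decoupled_weight_decay=True,  # optim_args.use_decoupled_weight_decay
    grafting_config=AdamGraftingConfig(  # optim_shampoo_grafting_config_*
        beta2=0.999,
        epsilon=1e-12,
    ),
)
```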

View File

@@ -44,7 +44,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n", "!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.7.0.post2\"\n", "!pip install flash-attn==\"2.5.0\"\n",
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\"" "!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
] ]
}, },

View File

@@ -9,17 +9,14 @@ strict: false
plugins: plugins:
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rms_norm: true liger_rms_norm: true
liger_glu_activation: true liger_swiglu: true
liger_fused_linear_cross_entropy: true liger_fused_linear_cross_entropy: true
chat_template: deepseek_v2 chat_template: deepseek_v2
datasets: datasets:
- path: mlabonne/FineTome-100k - path: mlabonne/FineTome-100k
type: chat_template type: chat_template
split: train[:20%] split: train
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.0 val_set_size: 0.0

View File

@@ -11,11 +11,8 @@ chat_template: gemma
datasets: datasets:
- path: cgato/SlimOrcaDedupCleaned - path: cgato/SlimOrcaDedupCleaned
type: chat_template type: chat_template
chat_template: gemma
drop_system_message: true drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./outputs/out output_dir: ./outputs/out

View File

@@ -1,63 +0,0 @@
base_model: google/gemma-2-2b
model_type: AutoModelForSequenceClassification
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
reward_model: true
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
remove_unused_columns: false
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -4,15 +4,11 @@ tokenizer_type: AutoTokenizer
load_in_4bit: true load_in_4bit: true
strict: false strict: false
use_tensorboard: true use_tensorboard: true
chat_template: jamba
datasets: datasets:
- path: cgato/SlimOrcaDedupCleaned - path: cgato/SlimOrcaDedupCleaned
type: chat_template type: chat_template
chat_template: jamba
drop_system_message: true drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.0 val_set_size: 0.0
output_dir: jamba-large-fsdp-qlora-ft output_dir: jamba-large-fsdp-qlora-ft

View File

@@ -1,63 +0,0 @@
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates with images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: llama3_2_vision
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -4,7 +4,7 @@ plugins:
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true
liger_rms_norm: true liger_rms_norm: true
liger_glu_activation: true liger_swiglu: true
liger_fused_linear_cross_entropy: true liger_fused_linear_cross_entropy: true
strict: false strict: false
@@ -14,10 +14,6 @@ datasets:
- path: mlabonne/FineTome-100k - path: mlabonne/FineTome-100k
type: chat_template type: chat_template
split: train[:20%] split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.02 val_set_size: 0.02
output_dir: ./outputs/out output_dir: ./outputs/out

View File

@@ -11,6 +11,7 @@ rl: dpo
datasets: datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test - path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default type: chat_template.default
chat_template: llama3
field_messages: conversation field_messages: conversation
field_chosen: chosen field_chosen: chosen
field_rejected: rejected field_rejected: rejected

View File

@@ -10,6 +10,7 @@ chat_template: llama3
datasets: datasets:
- path: fozziethebeat/alpaca_messages_2k_test - path: fozziethebeat/alpaca_messages_2k_test
type: chat_template type: chat_template
chat_template: llama3
field_messages: messages field_messages: messages
message_field_role: role message_field_role: role
message_field_content: content message_field_content: content

View File

@@ -1,77 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -1,93 +0,0 @@
# Note that we are switching from the regular chat template to chatml.
# If you experience problems with the special tokens, training for more epochs can help.
# After training, merge the model before inference, otherwise you might
# face problems with the special tokens.
base_model: mistralai/Mistral-7B-Instruct-v0.2
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
chat_template: chatml
rl: dpo
datasets:
- path: olivermolenschot/alpaca_messages_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-qlora
sequence_len: 2048
sample_packing: false
pad_to_sequence_len: true
adapter: qlora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.2
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 16
num_epochs: 6
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<|im_start|>"
eos_token: "<|im_end|>"

View File

@@ -10,6 +10,7 @@ chat_template: phi_3
datasets: datasets:
- path: fozziethebeat/alpaca_messages_2k_test - path: fozziethebeat/alpaca_messages_2k_test
type: chat_template type: chat_template
chat_template: phi_3
field_messages: messages field_messages: messages
message_field_role: role message_field_role: role
message_field_content: content message_field_content: content

View File

@@ -1,67 +0,0 @@
base_model: Qwen/Qwen2.5-0.5B
strict: false
chat_template: qwen_25
rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/dpo-out
sequence_len: 2048
sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -2,4 +2,3 @@ pre-commit
black black
mypy mypy
types-requests types-requests
tbparse

View File

@@ -1,3 +1,2 @@
pytest pytest
pytest-xdist pytest-xdist
pytest-retry

View File

@@ -1,22 +1,22 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.13.2 peft==0.12.0
transformers==4.46.2 transformers @ git+https://github.com/huggingface/transformers.git@0963229e287501bed52ae1dabc17922524de6992
tokenizers>=0.20.1 tokenizers>=0.19.1
bitsandbytes==0.44.1 bitsandbytes==0.43.3
accelerate==1.1.0 accelerate==0.34.2
datasets==3.1.0 datasets==2.21.0
deepspeed==0.15.3 deepspeed==0.14.4
pydantic==2.6.3 pydantic==2.6.3
addict addict
fire fire
PyYAML>=6.0 PyYAML>=6.0
requests requests
flash-attn==2.7.0.post2 flash-attn==2.6.3
sentencepiece sentencepiece
wandb wandb
einops einops
xformers>=0.0.23.post1 xformers==0.0.27
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer
colorama colorama
@@ -28,12 +28,14 @@ scipy
scikit-learn==1.4.2 scikit-learn==1.4.2
pynvml pynvml
art art
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
gradio==3.50.2 gradio==3.50.2
tensorboard tensorboard
python-dotenv==1.0.1 python-dotenv==1.0.1
autoawq>=0.2.5 autoawq>=0.2.5
triton>=2.3.0 triton>=2.3.0
liger-kernel==0.4.1 liger-kernel==0.2.1
distributed_shampoo @ git+https://github.com/facebookresearch/optimizers.git@main
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
@@ -42,15 +44,6 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0 gcsfs>=2024.5.0
# adlfs # adlfs
trl==0.12.0 trl==0.9.6
zstandard==0.22.0 zstandard==0.22.0
fastcore fastcore
# lm eval harness
lm_eval==0.4.4
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.5.0
schedulefree==1.3.0

View File

@@ -1,315 +0,0 @@
accelerate==0.34.1
addict==2.4.0
aiofiles==23.2.1
aiohttp==3.9.0
aiosignal==1.3.1
aiostream==0.5.2
alembic==1.13.1
annotated-types==0.6.0
annoy==1.17.3
ansible==6.7.0
ansible-core==2.13.13
ansible-vault==2.1.0
anyio==3.7.1
appdirs==1.4.4
art==6.0
asgiref==3.7.2
async-timeout==4.0.2
attrdict==2.0.1
attrs==22.2.0
awscli==1.32.75
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
backoff==2.2.1
base58==2.1.1
beartype==0.17.2
bitnet==0.2.1
bitsandbytes==0.42.0
bittensor==6.7.0
black==23.7.0
blinker==1.7.0
boto3==1.34.75
botocore==1.34.75
cachetools==5.3.3
cachy==0.1.1
certifi==2023.7.22
cffi==1.16.0
cfgv==3.3.1
chai-guanaco==1.2.4
charset-normalizer==3.2.0
cleo==0.6.8
click==8.1.7
cloudpickle==2.0.0
cohere==4.11.2
colorama==0.4.4
coloredlogs==15.0.1
CoLT5-attention==0.10.20
contextlib2==21.6.0
contourpy==1.2.0
cryptography==41.0.3
cycler==0.12.1
cytoolz==0.12.3
databricks-cli==0.18.0
dataclasses-json==0.5.7
datasets==2.11.0
ddt==1.6.0
decorator==5.1.1
deepspeed==0.15.0
# Editable Git install with no remote (dialogpt==0.1)
-e /Users/wing/Projects/ml/dialogpt/src
dill==0.3.6
distlib==0.3.6
docker==7.0.0
docker-pycreds==0.4.0
docstring-parser==0.15
docutils==0.16
ecdsa==0.18.0
einops==0.7.0
einops-exts==0.0.4
einx==0.1.3
entrypoints==0.4
eth-hash==0.6.0
eth-keys==0.5.0
eth-typing==4.0.0
eth-utils==2.3.1
evaluate==0.4.0
exceptiongroup==1.1.1
fastapi==0.109.2
fastcore==1.5.29
ffmpy==0.4.0
filelock==3.12.2
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
fire==0.5.0
first==2.0.2
flake8==7.0.0
Flask==3.0.1
fonttools==4.47.2
frozendict==2.4.1
frozenlist==1.3.3
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
fsspec==2023.6.0
fuzzywuzzy==0.18.0
gitdb==4.0.10
GitPython==3.1.31
google-pasta==0.2.0
gradio==4.42.0
gradio_client==1.3.0
greenlet==2.0.2
grpclib==0.4.7
gunicorn==21.2.0
h11==0.14.0
h2==4.1.0
hpack==4.0.0
httpcore==0.17.3
httpx==0.24.1
huggingface-hub==0.23.4
humanfriendly==10.0
hyperframe==6.0.1
identify==2.5.24
idna==3.4
immutables==0.20
importlib-metadata==6.7.0
importlib-resources==6.1.1
inflection==0.5.1
iniconfig==2.0.0
itsdangerous==2.1.2
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.3.2
jsonlines==3.1.0
jsonschema==2.6.0
kiwisolver==1.4.5
langchain==0.0.144
Levenshtein==0.24.0
libcst==1.1.0
liger-kernel==0.0.0
lion-pytorch==0.1.2
llama-cpp-python==0.1.36
llvmlite==0.40.1
local-attention==1.9.0
loguru==0.7.0
Mako==1.3.2
Markdown==3.5.2
markdown-it-py==3.0.0
markdown2==2.4.10
MarkupSafe==2.1.2
marshmallow==3.19.0
marshmallow-enum==1.5.1
matplotlib==3.8.2
mccabe==0.7.0
mdurl==0.1.2
MEGABYTE-pytorch==0.0.7
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
mlflow==2.10.0
modal==0.62.77
more-itertools==10.2.0
mpmath==1.2.1
msgpack==1.0.7
msgpack-numpy-opentensor==0.5.0
multidict==6.0.4
multiprocess==0.70.14
munch==2.5.0
mypy==1.3.0
mypy-extensions==1.0.0
nest-asyncio==1.6.0
netaddr==0.10.1
networkx==3.0rc1
nh3==0.2.14
nodeenv==1.8.0
nomic==2.0.2
numba==0.57.1
numexpr==2.8.4
numpy==1.24.4
oauthlib==3.2.2
openai==0.27.4
openapi==1.1.0
openapi-schema-pydantic==1.2.4
optimum==1.8.6
orjson==3.10.7
packaging==23.1
pandas==2.0.0
parameterized==0.9.0
password-strength==0.0.3.post2
pastel==0.1.1
pathos==0.3.0
pathspec==0.11.1
pathtools==0.1.2
peft==0.11.1
pendulum==3.0.0
Pillow==9.5.0
pip-tools==1.11.0
platformdirs==3.2.0
pluggy==1.4.0
poetry==0.7.1
pox==0.3.2
ppft==1.7.6.6
pre-commit==3.3.2
prettytable==3.10.0
prompt-toolkit==3.0.39
protobuf==3.20.2
protobuf3-to-dict==0.1.5
psutil==5.9.5
psycopg==3.1.18
PuLP==2.8.0
py==1.11.0
py-bip39-bindings==0.1.11
py-cpuinfo==9.0.0
py-ed25519-zebra-bindings==1.0.1
py-sr25519-bindings==0.2.0
pyarrow==11.0.0
pyasn1==0.6.0
pycodestyle==2.11.1
pycparser==2.21
pycryptodome==3.20.0
pydantic==2.5.3
pydantic_core==2.14.6
pydub==0.25.1
pyfiglet==0.8.post1
pyflakes==3.2.0
Pygments==2.15.1
PyJWT==2.8.0
pylev==1.4.0
PyNaCl==1.5.0
pynvml==11.5.0
pyparsing==2.4.7
pyrsistent==0.14.11
pytest==8.0.2
pytest-asyncio==0.23.4
python-dateutil==2.8.2
python-dotenv==1.0.1
python-Levenshtein==0.24.0
python-multipart==0.0.9
pytz==2023.3
PyYAML==6.0.1
querystring-parser==1.2.4
rapidfuzz==3.6.1
regex==2023.6.3
requests==2.31.0
requests-toolbelt==0.8.0
resolvelib==0.8.1
responses==0.18.0
retry==0.9.2
rich==13.7.0
rsa==4.7.2
ruff==0.6.3
s3transfer==0.10.1
safetensors==0.4.5
sagemaker==2.148.0
scalecodec==1.2.7
schedulefree==1.2.1
schema==0.7.5
scikit-learn==1.4.0
scipy==1.9.3
seaborn==0.13.2
semantic-version==2.10.0
sentencepiece==0.2.0
sentry-sdk==1.19.1
setproctitle==1.3.2
shellingham==1.5.4
shortuuid==1.0.11
shtab==1.6.5
sigtools==4.0.1
six==1.16.0
skypilot==0.4.1
smdebug-rulesconfig==1.0.1
smmap==5.0.0
sniffio==1.3.0
SQLAlchemy==1.4.47
sqlparse==0.4.4
starlette==0.36.3
substrate-interface==1.5.2
svgwrite==1.4.3
sympy==1.11.1
synchronicity==0.6.7
tabulate==0.9.0
tblib==1.7.0
tenacity==8.2.2
tensor-parallel==2.0.0
termcolor==2.2.0
text2art==0.2.0
threadpoolctl==3.2.0
tiktoken==0.6.0
time-machine==2.14.1
timm==0.9.16
tokenizers==0.19.1
tokenmonster==1.1.12
toml==0.9.6
tomli==2.0.1
tomlkit==0.12.0
toolz==0.12.1
torch==2.2.0
torchdata==0.6.1
torchdiffeq==0.2.3
TorchFix==0.4.0
torchtext==0.15.2
torchvision==0.17.0
tqdm==4.66.2
transformers==4.44.2
trl==0.9.6
typer==0.12.5
types-certifi==2021.10.8.3
types-requests==2.31.0.20240125
types-setuptools==69.0.0.20240125
types-toml==0.10.8.7
typing==3.7.4.3
typing-inspect==0.8.0
typing_extensions==4.9.0
tyro==0.5.18
tzdata==2023.3
unique-names-generator==1.0.2
urllib3==2.2.2
uvicorn==0.22.0
vector_quantize_pytorch==1.14.1
virtualenv==20.23.0
voyager==2.0.2
wandb==0.16.2
watchfiles==0.21.0
wavedrom==2.0.3.post3
wcwidth==0.2.6
websocket-client==1.7.0
websockets==12.0
Werkzeug==3.0.1
wonderwords==2.2.0
xxhash==3.2.0
yarl==1.8.2
zetascale==2.2.7
zipp==3.15.0

View File

@@ -1,60 +0,0 @@
"""
helper script to parse chat datasets into a usable yaml
"""
import click
import yaml
from datasets import load_dataset
@click.command()
@click.argument("dataset", type=str)
@click.option("--split", type=str, default="train")
def parse_dataset(dataset=None, split="train"):
ds_cfg = {}
ds_cfg["path"] = dataset
ds_cfg["split"] = split
ds_cfg["type"] = "chat_template"
ds_cfg["chat_template"] = "<<<Replace based on your model>>>"
dataset = load_dataset(dataset, split=split)
features = dataset.features
feature_keys = features.keys()
field_messages = None
for key in ["conversation", "conversations", "messages"]:
if key in feature_keys:
field_messages = key
break
if not field_messages:
raise ValueError(
f'No conversation field found in dataset: {", ".join(feature_keys)}'
)
ds_cfg["field_messages"] = field_messages
message_fields = features[field_messages][0].keys()
message_field_role = None
for key in ["from", "role"]:
if key in message_fields:
message_field_role = key
break
if not message_field_role:
raise ValueError(
f'No role field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_role"] = message_field_role
message_field_content = None
for key in ["content", "text", "value"]:
if key in message_fields:
message_field_content = key
break
if not message_field_content:
raise ValueError(
f'No content field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_content"] = message_field_content
print(yaml.dump({"datasets": [ds_cfg]}))
if __name__ == "__main__":
parse_dataset()
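The field-detection heuristic above can be exercised without touching the Hub. A minimal sketch with hypothetical in-memory ShareGPT-style rows:

```python
from datasets import Dataset

# Hypothetical rows standing in for a Hub dataset.
ds = Dataset.from_list(
    [
        {
            "conversations": [
                {"from": "human", "value": "hi"},
                {"from": "gpt", "value": "hello!"},
            ]
        }
    ]
)

features = ds.features
# Same heuristic as the script: take the first candidate key present.
field_messages = next(
    k for k in ("conversation", "conversations", "messages") if k in features
)
message_fields = features[field_messages][0].keys()
role_key = next(k for k in ("from", "role") if k in message_fields)
content_key = next(k for k in ("content", "text", "value") if k in message_fields)
print(field_messages, role_key, content_key)  # conversations from value
```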

View File

@@ -2,7 +2,7 @@
# Export specific ENV variables to /etc/rp_environment # Export specific ENV variables to /etc/rp_environment
echo "Exporting environment variables..." echo "Exporting environment variables..."
printenv | grep -E '^HF_|^BNB_|^CUDA_|^NCCL_|^NV|^RUNPOD_|^PATH=|^_=' | sed 's/^\([^=]*\)=\(.*\)$/export \1="\2"/' | grep -v 'printenv' >> /etc/rp_environment printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
echo 'source /etc/rp_environment' >> ~/.bashrc echo 'source /etc/rp_environment' >> ~/.bashrc
add_keys_to_authorized() { add_keys_to_authorized() {

View File

@@ -30,19 +30,13 @@ def parse_requirements():
try: try:
xformers_version = [req for req in _install_requires if "xformers" in req][0] xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system(): if "Darwin" in platform.system():
# don't install xformers on MacOS # don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
else: else:
# detect the version of torch already installed # detect the version of torch already installed
# and set it so dependencies don't clobber the torch version # and set it so dependencies don't clobber the torch version
try: torch_version = version("torch")
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.5.1"
_install_requires.append(f"torch=={torch_version}") _install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version) version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -55,39 +49,20 @@ def parse_requirements():
else: else:
raise ValueError("Invalid version format") raise ValueError("Invalid version format")
if (major, minor) >= (2, 5): if (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers==0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
if patch == 0: if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1") _install_requires.append("xformers>=0.0.26.post1")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2): elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1") _install_requires.append("xformers>=0.0.25.post1")
else: else:
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1") _install_requires.append("xformers>=0.0.23.post1")
except PackageNotFoundError: except PackageNotFoundError:
pass pass
return _install_requires, _dependency_links return _install_requires, _dependency_links
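As a side note on the version handling above, the regex tolerates local version tags such as `+cu124`. A quick illustration with a hypothetical version string:

```python
import re

# Same pattern setup.py uses to split the installed torch version.
match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", "2.4.1+cu124")
major, minor = int(match.group(1)), int(match.group(2))
patch = int(match.group(3)) if match.group(3) else None
print(major, minor, patch)  # 2 4 1 -> the (2, 4) branch pins xformers==0.0.28.post1
```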
@@ -96,7 +71,7 @@ install_requires, dependency_links = parse_requirements()
setup( setup(
name="axolotl", name="axolotl",
version="0.5.0", version="0.4.1",
description="LLM Trainer", description="LLM Trainer",
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.", long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
package_dir={"": "src"}, package_dir={"": "src"},
@@ -105,7 +80,10 @@ setup(
dependency_links=dependency_links, dependency_links=dependency_links,
extras_require={ extras_require={
"flash-attn": [ "flash-attn": [
"flash-attn==2.7.0.post2", "flash-attn==2.6.3",
],
"fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed==0.14.4", "deepspeed==0.14.4",
@@ -113,7 +91,6 @@ setup(
], ],
"mamba-ssm": [ "mamba-ssm": [
"mamba-ssm==1.2.0.post1", "mamba-ssm==1.2.0.post1",
"causal_conv1d",
], ],
"auto-gptq": [ "auto-gptq": [
"auto-gptq==0.5.1", "auto-gptq==0.5.1",

View File

@@ -30,8 +30,6 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import ( from axolotl.utils.config import (
normalize_cfg_datasets, normalize_cfg_datasets,
normalize_config, normalize_config,
@@ -41,7 +39,7 @@ from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -55,22 +53,8 @@ LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
AXOLOTL_LOGO = """
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
"""
def print_axolotl_text_art(suffix=None):
def print_legacy_axolotl_text_art(suffix=None):
font = "nancyj" font = "nancyj"
ascii_text = " axolotl" ascii_text = " axolotl"
if suffix: if suffix:
@@ -83,13 +67,6 @@ def print_legacy_axolotl_text_art(suffix=None):
print_dep_versions() print_dep_versions()
def print_axolotl_text_art(
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
print(AXOLOTL_LOGO)
def print_dep_versions(): def print_dep_versions():
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"] packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages) max_len = max(len(pkg) for pkg in packages)
@@ -190,15 +167,18 @@ def do_inference(
): ):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None prompter_module = None
chat_template_str = None
if prompter: if prompter:
prompter_module = getattr( prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter importlib.import_module("axolotl.prompters"), prompter
) )
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
model = model.to(cfg.device, dtype=cfg.torch_dtype) model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -208,31 +188,13 @@ def do_inference(
instruction = get_multi_line_input() instruction = get_multi_line_input()
if not instruction: if not instruction:
return return
if prompter_module: if prompter_module:
prompt: str = next( prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n")) prompter_module().build_prompt(instruction=instruction.strip("\n"))
) )
else: else:
prompt = instruction.strip() prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40) print("=" * 40)
model.eval() model.eval()
@@ -272,15 +234,18 @@ def do_inference_gradio(
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None prompter_module = None
chat_template_str = None
if prompter: if prompter:
prompter_module = getattr( prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter importlib.import_module("axolotl.prompters"), prompter
) )
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype) model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -294,24 +259,7 @@ def do_inference_gradio(
) )
else: else:
prompt = instruction.strip() prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
@@ -334,7 +282,6 @@ def do_inference_gradio(
streamer = TextIteratorStreamer(tokenizer) streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = { generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device), "inputs": batch["input_ids"].to(cfg.device),
"attention_mask": batch["attention_mask"].to(cfg.device),
"generation_config": generation_config, "generation_config": generation_config,
"streamer": streamer, "streamer": streamer,
} }
@@ -451,8 +398,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
setup_mlflow_env_vars(cfg) setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
return cfg return cfg
@@ -462,20 +407,12 @@ def load_datasets(
cli_args: TrainerCliArgs, cli_args: TrainerCliArgs,
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg, cfg, tokenizer
tokenizer,
processor=processor,
) )
if ( if cli_args.debug or cfg.debug:
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")
check_dataset_labels( check_dataset_labels(
train_dataset.select( train_dataset.select(

View File

@@ -23,7 +23,10 @@ from axolotl.cli import (
) )
from axolotl.common.cli import PreprocessCliArgs from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.utils.trainer import disable_datasets_caching from axolotl.prompt_strategies.sharegpt import (
register_chatml_template,
register_llama3_template,
)
LOG = logging.getLogger("axolotl.cli.preprocess") LOG = logging.getLogger("axolotl.cli.preprocess")
@@ -40,6 +43,23 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
return_remaining_strings=True return_remaining_strings=True
) )
if parsed_cfg.chat_template == "chatml":
if parsed_cfg.default_system_message:
LOG.info(
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
)
register_chatml_template(parsed_cfg.default_system_message)
else:
register_chatml_template()
elif parsed_cfg.chat_template == "llama3":
if parsed_cfg.default_system_message:
LOG.info(
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
)
register_llama3_template(parsed_cfg.default_system_message)
else:
register_llama3_template()
if not parsed_cfg.dataset_prepared_path: if not parsed_cfg.dataset_prepared_path:
msg = ( msg = (
Fore.RED Fore.RED
@@ -50,11 +70,10 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.warning(msg) LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
with disable_datasets_caching(): if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo": load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) else:
else: load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.download: if parsed_cli_args.download:
model_name = parsed_cfg.base_model model_name = parsed_cfg.base_model

View File

@@ -3,11 +3,13 @@ CLI to run training on a model
""" """
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Union from typing import Tuple, Union
import fire import fire
from dotenv import load_dotenv from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser from transformers.hf_argparser import HfArgumentParser
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from axolotl.cli import ( from axolotl.cli import (
check_accelerate_default_config, check_accelerate_default_config,
@@ -18,7 +20,10 @@ from axolotl.cli import (
print_axolotl_text_art, print_axolotl_text_art,
) )
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.integrations.base import PluginManager from axolotl.prompt_strategies.sharegpt import (
register_chatml_template,
register_llama3_template,
)
from axolotl.train import train from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train") LOG = logging.getLogger("axolotl.cli.train")
@@ -34,23 +39,32 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
return do_train(parsed_cfg, parsed_cli_args) return do_train(parsed_cfg, parsed_cli_args)
def do_train(cfg, cli_args) -> None: def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
print_axolotl_text_art() print_axolotl_text_art()
check_accelerate_default_config() check_accelerate_default_config()
check_user_token() check_user_token()
if cfg.chat_template == "chatml" and cfg.default_system_message:
LOG.info(
f"ChatML set. Adding default system message: {cfg.default_system_message}"
)
register_chatml_template(cfg.default_system_message)
else:
register_chatml_template()
if cfg.chat_template == "llama3" and cfg.default_system_message:
LOG.info(
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
)
register_llama3_template(cfg.default_system_message)
else:
register_llama3_template()
if cfg.rl: # and cfg.rl != "orpo": if cfg.rl: # and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
plugin_manager.post_train_unload(cfg)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -10,7 +10,6 @@ MOE_ARCH_BLOCK = {
"JetMoeMoE", "JetMoeMoE",
], ],
"mixtral": "MixtralSparseMoeBlock", "mixtral": "MixtralSparseMoeBlock",
"phimoe": "PhiMoESparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock", "qwen2_moe": "Qwen2MoeSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE", "deepseek_v2": "DeepseekV2MoE",
} }

View File

@@ -23,7 +23,7 @@ class TrainerCliArgs:
debug: bool = field(default=False) debug: bool = field(default=False)
debug_text_only: bool = field(default=False) debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0) debug_num_examples: int = field(default=5)
inference: bool = field(default=False) inference: bool = field(default=False)
merge_lora: bool = field(default=False) merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None) prompter: Optional[str] = field(default=None)

View File

@@ -1,34 +0,0 @@
"""
ChatML transformation functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages
from .shared import wrap_tools
def format_message(
message: Messages,
message_index: Optional[int] = None, # pylint: disable=unused-argument
) -> Messages:
if message.is_chat_formatted:
return message
# prepend the role prefix within a MessageContents to message.content
message.content.insert(
0,
MessageContents(
type="text",
value=f"<|im_start|>{message.role}\n",
weight=0,
),
)
message.content.append(
MessageContents(type="text", value="<|im_end|>", weight=message.weight)
)
message.content.append(MessageContents(type="text", value="\n", weight=0))
message = wrap_tools(message)
message.is_chat_formatted = True
return message
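To make the formatting concrete, a minimal sketch of running a plain user message through this transform, assuming this `format_message` is in scope alongside the messages module below:

```python
from axolotl.core.chat.messages import MessageContents, Messages

msg = Messages(
    role="user",
    content=[MessageContents(type="text", value="hi", weight=0)],
    weight=0,
)
print(str(format_message(msg)))
# <|im_start|>user
# hi<|im_end|>
```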

View File

@@ -1,45 +0,0 @@
"""
Llama 3.x chat formatting functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages
from .shared import wrap_tools
def format_message(message: Messages, message_index: Optional[int] = None) -> Messages:
if message.is_chat_formatted:
return message
message_role = message.role
if message.role == "tool":
message_role = "ipython"
# prepend the role prefix within a MessageContents to message.content
message.content.insert(
0,
MessageContents(
type="text",
value=f"<|start_header_id|>{message_role}<|end_header_id|>\n\n",
weight=0,
),
)
message.content.append(
MessageContents(type="text", value="<|eot_id|>", weight=message.weight)
)
message = wrap_tools(message)
if message_index == 0:
message.content.insert(
0,
MessageContents(
type="text",
value="<|begin_of_text|>",
weight=0,
),
)
message.is_chat_formatted = True
return message

View File

@@ -1,47 +0,0 @@
"""
shared functions for format transforms
"""
from axolotl.core.chat.messages import MessageContents, Messages
def wrap_tools(message: Messages):
# loop over message.content by index to find tool calls, we need to wrap each with tags,
# so be wary of indexing issues when changing the list while iterating.
# iterate over the range in reverse order to avoid index shifting
for i in range(len(message.content) - 1, -1, -1):
if message.content[i].type == "tool_call":
# append a </tool_call> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_call>\n", weight=message.weight
),
)
# make sure the actual tool call content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_call> MessageContents text tag before
message.content.insert(
i,
MessageContents(
type="text", value="<tool_call>\n", weight=message.weight
),
)
elif message.content[i].type == "tool_response":
# append a </tool_response> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_response>\n", weight=message.weight
),
)
# make sure the actual tool response content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_response> MessageContents text tag before
message.content.insert(
i,
MessageContents(
type="text", value="<tool_response>\n", weight=message.weight
),
)
return message

View File

@@ -1,230 +0,0 @@
"""
internal message representations of chat messages
"""
import json
from enum import Enum
from typing import Any, Callable, List, Optional, Union
from pydantic import BaseModel
from transformers import PreTrainedTokenizer
class MessageRoles(str, Enum):
"""
Message roles for the system, user, assistant, and tools
"""
system = "system" # pylint: disable=invalid-name
user = "user" # pylint: disable=invalid-name
assistant = "assistant" # pylint: disable=invalid-name
tool = "tool" # pylint: disable=invalid-name
ipython = ( # pylint: disable=invalid-name
# for responses from builtin tools
"ipython"
)
class MessageContentTypes(str, Enum):
"""
Message content types for text, image, audio, tool calls, and tool responses
"""
special_token = "special_token" # pylint: disable=invalid-name # nosec B105
text = "text" # pylint: disable=invalid-name
image = "image" # pylint: disable=invalid-name
audio = "audio" # pylint: disable=invalid-name
tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant
tool_response = "tool_response" # pylint: disable=invalid-name
class SpecialToken(str, Enum):
"""
Special tokens for beginning of string and end of string
"""
bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105
eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105
class ToolCallFunction(BaseModel):
"""
Tool call function with name and arguments
"""
name: str
arguments: dict[str, str]
class Tool(BaseModel):
"""
Tool with description, function, and parameters
"""
description: str
function: ToolCallFunction
parameters: dict[str, str] # .properties
class ToolCallContents(BaseModel):
"""
Tool call contents with name, arguments, and optional id
"""
name: str
arguments: dict[str, Union[str, int]]
id: Optional[str] = None # pylint: disable=invalid-name
def __str__(self) -> str:
data = {"name": self.name, "arguments": self.arguments}
if self.id is not None:
data["id"] = self.id
return json.dumps(data)
class ToolResponseContents(BaseModel):
"""
Tool response contents with name, content, and optional id
"""
name: str
content: Union[str, dict[str, Union[str, int, float]]]
id: Optional[str] = None # pylint: disable=invalid-name
def __str__(self) -> str:
data = {"name": self.name, "content": self.content}
if self.id is not None:
data["id"] = self.id
return json.dumps(data)
class MessageContents(BaseModel):
"""
Message contents with type, value, metadata, weight, newline, and end of contents
"""
type: Union[str, MessageContentTypes]
value: Union[str, ToolCallContents, ToolResponseContents, SpecialToken]
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
weight: Optional[Union[int, float]] = None
has_newline: bool = False
eoc: bool = False # end of contents
def __str__(self) -> str:
str_val = str(self.value)
if self.has_newline and not str_val.endswith("\n"):
str_val += "\n"
return str_val
class Messages(BaseModel):
"""
Messages with role, content, metadata, weight, and chat formatting
"""
role: Union[MessageRoles, str] # allows for arbitrary roles
content: List["MessageContents"]
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
weight: Optional[Union[int, float]] = None
is_chat_formatted: bool = False
def __str__(self) -> str:
return "".join(str(c) for c in self.content)
def tokenized(
self, tokenizer: PreTrainedTokenizer, ignore_index=-100
) -> dict[str, List[int]]:
# iterate over the contents, tokenizing the concatenated string values up to the current MessageContents
# returns a dictionary mapping w input_ids, attention_mask, and labels
input_ids: List[int] = []
labels: List[int] = []
pending_input_ids: List[int] = []
pending_weight = self.weight
running_content = ""
for _, msg_content in enumerate(self.content):
# TODO also handle non-text content types
if msg_content.type in [
MessageContentTypes.text.value,
MessageContentTypes.tool_call.value,
MessageContentTypes.tool_response.value,
]:
running_content += str(msg_content)
tok_results = tokenizer(running_content, add_special_tokens=False)
tok_input_ids = tok_results["input_ids"]
if pending_input_ids:
new_pending_inputs = tok_input_ids[
len(input_ids) : len(input_ids) + len(pending_input_ids)
]
if new_pending_inputs != pending_input_ids:
# logging.warning("tokenization mismatch from concatenation.")
pending_input_ids = new_pending_inputs
input_ids.extend(pending_input_ids)
if pending_weight:
labels.extend(pending_input_ids)
else:
labels.extend([ignore_index] * len(pending_input_ids))
pending_input_ids = tok_results["input_ids"][len(input_ids) :]
pending_weight = self.weight and msg_content.weight not in [0, 0.0]
input_ids.extend(pending_input_ids)
if pending_weight:
labels.extend(pending_input_ids)
else:
labels.extend([ignore_index] * len(pending_input_ids))
attention_mask = [1] * len(input_ids)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
class Chats(BaseModel):
"""
top level data structure for chat conversations
"""
conversation: List[Messages]
def __str__(self) -> str:
return "".join(str(c) for c in self.conversation)
def tokenized(
self, tokenizer: Callable[[str], dict[str, List[int]]], ignore_index=-100
) -> dict[str, List[int]]:
input_ids = []
attention_mask = []
labels = []
for msg in self.conversation:
msg_results = msg.tokenized(tokenizer, ignore_index)
input_ids.extend(msg_results["input_ids"])
attention_mask.extend(msg_results["attention_mask"])
labels.extend(msg_results["labels"])
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
class ChatFormattedChats(Chats):
"""
Chat formatted chats with formatter and optional train on inputs
"""
formatter: Callable # [[Union[dict, Chats]], Chats]
train_on_inputs: bool = False
def model_post_init(self, __context):
for i, msg in enumerate(self.conversation):
self.conversation[i] = self.formatter(msg, message_index=i)
if self.train_on_inputs:
self.conversation[i].weight = 1
class PreferenceChats(BaseModel):
"""
representation for preference data for chat
"""
prompt: List[Messages]
chosen: Messages
rejected: Messages
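A hedged end-to-end smoke test of `ChatFormattedChats.tokenized`, assuming the chatml `format_message` above is importable (its module path is a guess based on the relative imports in these files) and using any off-the-shelf tokenizer:

```python
from transformers import AutoTokenizer
from axolotl.core.chat.messages import ChatFormattedChats
# Hypothetical import path for the chatml transform shown earlier.
from axolotl.core.chat.format.chatml import format_message

tok = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works for a length check
chat = ChatFormattedChats(
    formatter=format_message,
    conversation=[
        {"role": "user",
         "content": [{"type": "text", "value": "hi", "weight": 0}], "weight": 0},
        {"role": "assistant",
         "content": [{"type": "text", "value": "hello!", "weight": 1}], "weight": 1},
    ],
)
out = chat.tokenized(tok)
# Only the assistant's weighted contents receive labels; the rest are -100.
assert len(out["input_ids"]) == len(out["labels"]) == len(out["attention_mask"])
```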

View File

@@ -1,55 +0,0 @@
"""
chat dataset module
"""
import os
from typing import Callable, Optional, Union
from datasets import Dataset
from transformers import PreTrainedTokenizer
from axolotl.core.chat.messages import ChatFormattedChats
class TokenizedChatDataset(Dataset):
"""
Tokenized chat dataset
"""
def __init__(
self,
data: Dataset,
model_transform: Union[PreTrainedTokenizer, Callable],
*args,
message_transform: Optional[Callable] = None,
formatter=None,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs,
):
def map_fn(ex):
if message_transform is not None:
ex = message_transform(ex)
if formatter is not None:
ex = ChatFormattedChats(
formatter=formatter,
**ex,
)
else:
ex = ChatFormattedChats(
**ex,
)
return ex.tokenized(model_transform)
process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment]
)
num_proc = min(64, process_or_cpu_count)
features = data.features.keys()
tokenized_data = data.map(
map_fn,
num_proc=num_proc,
keep_in_memory=keep_in_memory,
remove_columns=features,
desc="Tokenizing Chats",
)
super().__init__(tokenized_data.data, *args, **kwargs)

View File

@@ -1,150 +0,0 @@
"""
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
"""
from typing import Any, Mapping, Union
def chat_message_transform_builder( # pylint: disable=dangerous-default-value
train_on_inputs=False,
conversations_field: str = "conversations",
message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role"
message_field_content: Union[str, list[str]] = [
"value",
"text",
"content",
], # commonly "content"
message_field_training: Union[str, list[str]] = [
"train",
"weight",
], # commonly "weight"
):
"""Builds a transform that takes a row from the dataset and converts it to a Chat
Args:
train_on_inputs (bool, optional):
If True, the transform will train on the inputs. If False, the transform will train on the targets.
Defaults to False.
conversations_field (str, optional):
The field name of the conversations. Defaults to "conversations".
message_field_role (str | list[str], optional):
The candidate field name(s) for the role. Defaults to ["role", "from"].
message_field_content (str | list[str], optional):
The candidate field name(s) for the message content. Defaults to ["value", "text", "content"].
message_field_training (str | list[str], optional):
The candidate field name(s) for the per-message train/weight flag. Defaults to ["train", "weight"].
Returns:
Callable:
A transform that maps a dataset row to a {"conversation": [...]} dict of messages.
message_field_role = (
[message_field_role]
if isinstance(message_field_role, str)
else message_field_role
)
message_field_content = (
[message_field_content]
if isinstance(message_field_content, str)
else message_field_content
)
message_weight_fields = (
[message_field_training]
if isinstance(message_field_training, str)
else message_field_training
)
role_value_mappings = {
"system": "system",
"user": "user",
"human": "user",
"assistant": "assistant",
"gpt": "assistant",
"tool": "tool",
"ipython": "ipython",
}
if train_on_inputs:
role_default_weights_mappings = {
"system": 1,
"user": 1,
"assistant": 1,
"tool": 1,
"ipython": 1,
}
else:
role_default_weights_mappings = {
"system": 0,
"user": 0,
"assistant": 1,
"tool": 0,
"ipython": 0,
}
def transform_builder(sample: Mapping[str, Any]):
if conversations_field not in sample:
raise ValueError(f"Field '{conversations_field}' not found in sample.")
# if none of the role fields are in the message, raise an error
if not any(
role in sample[conversations_field][0] for role in message_field_role
):
raise ValueError("No role field found in message.")
role_field = next(
role
for role in message_field_role
if role in sample[conversations_field][0]
)
if not any(
field in sample[conversations_field][0] for field in message_field_content
):
raise ValueError("No message_content field found in message.")
message_content_field = next(
field
for field in message_field_content
if field in sample[conversations_field][0]
)
if not any(
field in sample[conversations_field][0] for field in message_field_training
):
message_weight_field = None
else:
message_weight_field = next(
field
for field in message_weight_fields
if field in sample[conversations_field][0]
)
messages = []
for message in sample[conversations_field]:
role = role_value_mappings[message[role_field]]
weight = (
int(message[message_weight_field])
if message_weight_field
else role_default_weights_mappings[role]
)
# TODO if "tool_calls" in message[message_content_field]: then convert tool call to ToolCallContents
if isinstance(message[message_content_field], str):
messages.append(
{
"role": role,
"content": [
{
"type": "text",
"value": message[message_content_field],
}
],
"weight": weight,
}
)
else:
messages.append(
{
"role": role,
"content": message[message_content_field],
"weight": weight,
}
)
return {"conversation": messages}
return transform_builder
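For reference, a minimal usage sketch of the builder removed above; the sample row is invented for illustration:

```python
transform = chat_message_transform_builder(train_on_inputs=False)

row = {
    "conversations": [
        {"from": "human", "value": "Hi there!"},
        {"from": "gpt", "value": "Hello! How can I help?"},
    ]
}

# "human" maps to role "user" (default weight 0), "gpt" to "assistant" (weight 1)
print(transform(row))
# {'conversation': [
#     {'role': 'user', 'content': [{'type': 'text', 'value': 'Hi there!'}], 'weight': 0},
#     {'role': 'assistant', 'content': [{'type': 'text', 'value': 'Hello! How can I help?'}], 'weight': 1},
# ]}
```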

View File

@@ -7,7 +7,6 @@ import abc
import gc import gc
import importlib import importlib
import importlib.util import importlib.util
import inspect
import logging import logging
import math import math
import os import os
@@ -17,17 +16,17 @@ from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Type, Union from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
import torch import torch
import transformers import transformers
from datasets import Dataset from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import ( from transformers import (
EarlyStoppingCallback, EarlyStoppingCallback,
PreTrainedModel,
Trainer, Trainer,
TrainerCallback, TrainerCallback,
TrainingArguments, TrainingArguments,
@@ -43,15 +42,13 @@ from trl import (
KTOTrainer, KTOTrainer,
ORPOConfig, ORPOConfig,
ORPOTrainer, ORPOTrainer,
RewardConfig,
RewardTrainer,
) )
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length from trl.trainer.utils import pad_to_length
from axolotl.integrations.base import PluginManager from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils import is_mlflow_available
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
EvalFirstStepCallback, EvalFirstStepCallback,
GPUStatsCallback, GPUStatsCallback,
@@ -64,14 +61,12 @@ from axolotl.utils.callbacks import (
log_prediction_callback_factory, log_prediction_callback_factory,
) )
from axolotl.utils.callbacks.lisa import lisa_callback_factory from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.collators import ( from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq, DataCollatorForSeq2Seq,
MambaDataCollator, MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq,
) )
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype from axolotl.utils.models import ensure_dtype
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import ( from axolotl.utils.schedulers import (
@@ -255,10 +250,11 @@ class AxolotlTrainingMixins:
"help": "workaround to pass an alternate lr scheduler to the HF trainer" "help": "workaround to pass an alternate lr scheduler to the HF trainer"
}, },
) )
chat_template: Optional[str] = field( optim_shampoo_grafting_config_type: Optional[
default=None, Literal["adam", "sgd", "adagrad"]
metadata={"help": "Chat template converting chat messages to text"}, ] = None
) optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
optim_shampoo_betas: Optional[Tuple[float, float]] = None
@dataclass @dataclass
@@ -304,13 +300,6 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
) )
@dataclass
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
"""
Reward config for Reward training
"""
class SchedulerMixin(Trainer): class SchedulerMixin(Trainer):
""" """
Mixin class for scheduler setup in CausalTrainer. Mixin class for scheduler setup in CausalTrainer.
@@ -408,10 +397,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
def __init__( def __init__(
self, self,
*_args, *_args,
num_epochs=1,
bench_data_collator=None, bench_data_collator=None,
eval_data_collator=None, eval_data_collator=None,
**kwargs, **kwargs,
): ):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator self.eval_data_collator = eval_data_collator
super().__init__(*_args, **kwargs) super().__init__(*_args, **kwargs)
@@ -441,7 +432,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
"ao_adamw_8bit", "ao_adamw_8bit",
"ao_adamw_4bit", "ao_adamw_4bit",
"ao_adamw_fp8", "ao_adamw_fp8",
"adopt_adamw", "shampoo",
] ]
): ):
return super().create_optimizer() return super().create_optimizer()
@@ -476,14 +467,110 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
if self.args.loraplus_lr_ratio is not None: if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr( loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6 self.args, "loraplus_lr_embedding", None
) )
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model, opt_model,
optimizer_cls, optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio, optimizer_kwargs,
loraplus_lr_embedding=loraplus_lr_embedding, loraplus_lr_ratio,
**optimizer_kwargs, loraplus_lr_embedding,
)
elif self.args.alternate_optimizer == "shampoo":
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
AdamGraftingConfig,
CommunicationDType,
DDPShampooConfig,
FSDPShampooConfig,
PrecisionConfig,
SGDGraftingConfig,
)
from distributed_shampoo.utils.shampoo_fsdp_utils import (
compile_fsdp_parameter_metadata,
)
# parse args.optim_args
optim_args = {}
if self.args.optim_args:
for mapping in self.args.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optim_args[key] = value
optim_args["betas"] = self.args.optim_shampoo_betas
if "max_preconditioner_dim" in optim_args:
optim_args["max_preconditioner_dim"] = int(
optim_args["max_preconditioner_dim"]
)
if "precondition_frequency" in optim_args:
optim_args["precondition_frequency"] = int(
optim_args["precondition_frequency"]
)
if "use_decoupled_weight_decay" in optim_args:
optim_args["use_decoupled_weight_decay"] = bool(
optim_args["use_decoupled_weight_decay"]
)
if isinstance(optim_args["epsilon"], str):
optim_args["epsilon"] = float(optim_args["epsilon"])
optim_args["lr"] = self.args.learning_rate
optim_args["weight_decay"] = self.args.weight_decay
if "epsilon" in self.args.optim_shampoo_grafting_config_kwargs:
if isinstance(
self.args.optim_shampoo_grafting_config_kwargs["epsilon"], str
):
self.args.optim_shampoo_grafting_config_kwargs[
"epsilon"
] = float(
self.args.optim_shampoo_grafting_config_kwargs["epsilon"]
)
if self.args.optim_shampoo_grafting_config_type == "adam":
grafting_config = AdamGraftingConfig(
**self.args.optim_shampoo_grafting_config_kwargs
)
elif self.args.optim_shampoo_grafting_config_type == "sgd":
grafting_config = SGDGraftingConfig(
**self.args.optim_shampoo_grafting_config_kwargs
)
elif self.args.optim_shampoo_grafting_config_type == "adagrad":
grafting_config = AdaGradGraftingConfig(
**self.args.optim_shampoo_grafting_config_kwargs
)
distributed_config = None
if self.args.world_size > 1:
if self.args.fsdp and self.args.fsdp_config:
distributed_config = FSDPShampooConfig(
param_to_metadata=compile_fsdp_parameter_metadata(
self.model_wrapped
)
)
else:
distributed_config = DDPShampooConfig(
communication_dtype=CommunicationDType.BF16,
num_trainers_per_group=self.args.world_size,
communicate_params=False,
)
precision_config = None
if self.args.bf16:
precision_config = PrecisionConfig(
computation_dtype=torch.bfloat16,
factor_matrix_dtype=torch.bfloat16,
inv_factor_matrix_dtype=torch.bfloat16,
filtered_grad_dtype=torch.bfloat16,
momentum_dtype=torch.bfloat16,
grafting_state_dtype=torch.bfloat16,
)
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
DistributedShampoo(
optimizer_grouped_parameters,
grafting_config=grafting_config,
distributed_config=distributed_config,
precision_config=precision_config,
**optim_args,
)
) )
elif self.args.alternate_optimizer == "optimi_adamw": elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW from optimi import AdamW
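For reference, a standalone sketch of what the `shampoo` branch above assembles; the concrete `optim_args` values and grafting kwargs are illustrative, and `model` is assumed to be the training model:

```python
import torch
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import AdamGraftingConfig, PrecisionConfig

# grafting config built from optim_shampoo_grafting_config_kwargs
# (beta2/epsilon values are illustrative; epsilon is cast to float upstream)
grafting_config = AdamGraftingConfig(beta2=0.999, epsilon=1e-8)

# precision config applied when bf16 is enabled
precision_config = PrecisionConfig(
    computation_dtype=torch.bfloat16,
    factor_matrix_dtype=torch.bfloat16,
    inv_factor_matrix_dtype=torch.bfloat16,
    filtered_grad_dtype=torch.bfloat16,
    momentum_dtype=torch.bfloat16,
    grafting_state_dtype=torch.bfloat16,
)

optimizer = DistributedShampoo(
    model.parameters(),            # assumed model
    lr=1e-4,                       # args.learning_rate
    betas=(0.9, 0.999),            # optim_shampoo_betas
    epsilon=1e-8,                  # cast to float from optim_args
    weight_decay=0.0,              # args.weight_decay
    max_preconditioner_dim=8192,   # parsed from optim_args as int
    precondition_frequency=100,    # parsed from optim_args as int
    grafting_config=grafting_config,
    precision_config=precision_config,
)
```

In the multi-GPU paths above, a `DDPShampooConfig` (or `FSDPShampooConfig` when FSDP is active) would additionally be passed as `distributed_config`.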
@@ -511,14 +598,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs) AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
) )
elif self.args.alternate_optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
)
)
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -681,9 +760,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return DataLoader(bench_dataset, **dataloader_params) return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss( def compute_loss(self, model, inputs, return_outputs=False):
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
# use one's weighted cross entropy loss calc # use one's weighted cross entropy loss calc
# if self.args.sample_packing: # if self.args.sample_packing:
# labels = inputs.pop("labels") # labels = inputs.pop("labels")
@@ -691,18 +768,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True) # loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss # return (loss, outputs) if return_outputs else loss
if self.args.orpo_alpha: if self.args.orpo_alpha:
return self.orpo_compute_loss( return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
model, return super().compute_loss(model, inputs, return_outputs=return_outputs)
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
@staticmethod @staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
@@ -798,13 +865,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
).squeeze(2) ).squeeze(2)
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1) return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
def orpo_compute_loss( def orpo_compute_loss(self, model, inputs, return_outputs=False):
self,
model,
inputs,
return_outputs=False,
num_items_in_batch=None, # pylint: disable=unused-argument
):
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs( concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
inputs, inputs,
label_pad_token=-100, label_pad_token=-100,
@@ -910,13 +971,17 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, value in metrics.items(): for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value) self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, **kwargs): def _save_checkpoint(self, model, trial, metrics=None):
# make sure the checkpoint dir exists, since trainer is flaky # make sure the checkpoint dir exists, since trainer is flaky
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial) run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder) output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs) try:
return super()._save_checkpoint(model, trial, metrics=metrics)
except NotImplementedError as exc:
LOG.warning(f"Failed to save checkpoint: {exc}")
return None
class AxolotlMambaTrainer(AxolotlTrainer): class AxolotlMambaTrainer(AxolotlTrainer):
@@ -931,7 +996,6 @@ class AxolotlMambaTrainer(AxolotlTrainer):
model, model,
inputs, inputs,
return_outputs=False, # pylint: disable=unused-argument return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
): ):
input_ids = inputs.pop("input_ids") input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits lm_logits = model(input_ids).logits
@@ -1016,9 +1080,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model, opt_model,
optimizer_cls, optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio, optimizer_kwargs,
loraplus_lr_embedding=loraplus_lr_embedding, loraplus_lr_ratio,
**optimizer_kwargs, loraplus_lr_embedding,
) )
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
@@ -1038,46 +1102,19 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
return super().push_to_hub(*args, **kwargs) return super().push_to_hub(*args, **kwargs)
@staticmethod
def tokenize_row( def tokenize_row(
features, self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict: ) -> Dict:
res = DPOTrainer.tokenize_row( res = super().tokenize_row(feature, model=model)
features, if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
for key in res.keys(): for key in res.keys():
res[key] = res[key][1:] res[key] = res[key][1:]
if processing_class.bos_token and processing_class.bos_token_id is not None:
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
res["chosen_labels"] = res["chosen_labels"][1:]
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
res["rejected_labels"] = res["rejected_labels"][1:]
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
return res return res
def training_step( def training_step(
self, self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
) -> torch.Tensor: ) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch) loss: torch.Tensor = super().training_step(model, inputs)
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
return loss return loss
@@ -1107,14 +1144,6 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"] tag_names = ["axolotl", "cpo"]
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
""" """
Base class for trainer builder Base class for trainer builder
@@ -1125,11 +1154,10 @@ class TrainerBuilderBase(abc.ABC):
_model_ref = None _model_ref = None
_peft_config = None _peft_config = None
def __init__(self, cfg, model, tokenizer, processor=None): def __init__(self, cfg, model, tokenizer):
self.cfg = cfg self.cfg = cfg
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.processor = processor
# in case the model supports tagging, add the axolotl tag. # in case the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls # This makes sure the tag is correctly pushed even if a user calls
@@ -1175,49 +1203,26 @@ class TrainerBuilderBase(abc.ABC):
def get_callbacks(self) -> List[TrainerCallback]: def get_callbacks(self) -> List[TrainerCallback]:
callbacks = [] callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.use_wandb: if self.cfg.use_wandb:
callbacks.append( callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
) )
if self.cfg.use_mlflow and is_mlflow_available(): if self.cfg.use_mlflow and is_mlflow_available():
from transformers.integrations.integration_utils import MLflowCallback
from axolotl.utils.callbacks.mlflow_ import ( from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback, SaveAxolotlConfigtoMlflowCallback,
) )
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
MLflowCallback,
]
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
callbacks.append( callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path) SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
) )
return callbacks return callbacks
@abstractmethod
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
""" """
Callbacks added after the trainer is created, usually b/c these need access to the trainer Callbacks added after the trainer is created, usually b/c these need access to the trainer
""" """
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs): def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO # TODO
@@ -1263,7 +1268,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks return callbacks
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0: if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory( LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb" trainer, self.tokenizer, "wandb"
@@ -1278,11 +1283,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer, self.tokenizer, "mlflow" trainer, self.tokenizer, "mlflow"
) )
callbacks.append(LogPredictionCallback(self.cfg)) callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "comet_ml"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval: if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer)) callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
@@ -1300,18 +1300,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers: if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer)) callbacks.append(lisa_callback_factory(trainer))
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
callbacks.extend(
[
cb
for cb in plugin_manager.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
return callbacks return callbacks
def _get_trainer_cls(self): def _get_trainer_cls(self):
@@ -1319,8 +1307,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return ReLoRATrainer return ReLoRATrainer
if self.cfg.model_config_type == "mamba": if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
return AxolotlTrainer return AxolotlTrainer
def build(self, total_num_steps): def build(self, total_num_steps):
@@ -1429,15 +1415,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if not self.cfg.test_datasets and self.cfg.val_set_size == 0: if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
# no eval set, so don't eval # no eval set, so don't eval
training_arguments_kwargs["eval_strategy"] = "no" training_arguments_kwargs["evaluation_strategy"] = "no"
elif self.cfg.eval_steps: elif self.cfg.eval_steps:
training_arguments_kwargs["eval_strategy"] = "steps" training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.eval_strategy: elif self.cfg.evaluation_strategy:
training_arguments_kwargs["eval_strategy"] = self.cfg.eval_strategy training_arguments_kwargs[
"evaluation_strategy"
] = self.cfg.evaluation_strategy
else: else:
# we have an eval set, but no steps defined, default to use epoch # we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["eval_strategy"] = "epoch" training_arguments_kwargs["evaluation_strategy"] = "epoch"
if self.cfg.save_steps: if self.cfg.save_steps:
training_arguments_kwargs["save_strategy"] = "steps" training_arguments_kwargs["save_strategy"] = "steps"
@@ -1540,22 +1528,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
report_to = [] report_to = []
if self.cfg.use_wandb: if self.cfg.use_wandb:
report_to.append("wandb") report_to.append("wandb")
if self.cfg.wandb_name:
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.use_mlflow: if self.cfg.use_mlflow:
report_to.append("mlflow") report_to.append("mlflow")
if self.cfg.use_tensorboard: if self.cfg.use_tensorboard:
report_to.append("tensorboard") report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
training_arguments_kwargs["report_to"] = report_to training_arguments_kwargs["report_to"] = report_to
if self.cfg.use_wandb: training_arguments_kwargs["run_name"] = (
training_arguments_kwargs["run_name"] = self.cfg.wandb_name self.cfg.wandb_name if self.cfg.use_wandb else None
elif self.cfg.use_mlflow: )
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["optim"] = ( training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
) )
@@ -1571,6 +1552,21 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[ training_arguments_kwargs[
"optim_target_modules" "optim_target_modules"
] = self.cfg.optim_target_modules ] = self.cfg.optim_target_modules
# shampoo optimizer config
if self.cfg.optim_shampoo_betas:
training_arguments_kwargs[
"optim_shampoo_betas"
] = self.cfg.optim_shampoo_betas
if self.cfg.optim_shampoo_grafting_config_type:
training_arguments_kwargs[
"optim_shampoo_grafting_config_type"
] = self.cfg.optim_shampoo_grafting_config_type
if self.cfg.optim_shampoo_grafting_config_kwargs:
training_arguments_kwargs[
"optim_shampoo_grafting_config_kwargs"
] = self.cfg.optim_shampoo_grafting_config_kwargs
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[ training_arguments_kwargs[
"loraplus_lr_embedding" "loraplus_lr_embedding"
@@ -1643,11 +1639,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
) )
training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template,
tokenizer=self.tokenizer,
)
if self.cfg.rl == "orpo": if self.cfg.rl == "orpo":
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
@@ -1659,16 +1650,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {} trainer_kwargs = {}
if self.cfg.reward_model:
trainer_kwargs["max_length"] = self.cfg.sequence_len
# pylint: disable=duplicate-code
if self.cfg.optimizer in [ if self.cfg.optimizer in [
# pylint: disable=duplicate-code
"optimi_adamw", "optimi_adamw",
"ao_adamw_4bit", "ao_adamw_4bit",
"ao_adamw_8bit", "ao_adamw_8bit",
"ao_adamw_fp8", "ao_adamw_fp8",
"adopt_adamw", "shampoo",
]: ]:
# Set default so transformers doesn't throw # Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf" training_arguments_kwargs["optim"] = "adamw_hf"
@@ -1707,22 +1695,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
"accelerator_config" "accelerator_config"
] = self.cfg.accelerator_config ] = self.cfg.accelerator_config
training_args_cls = ( training_args = (
AxolotlTrainingArguments AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
if not self.cfg.reward_model **training_arguments_kwargs,
else AxolotlRewardConfig )
)
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
) )
training_args = self.hook_post_create_training_args(training_args) training_args = self.hook_post_create_training_args(training_args)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
data_collator_kwargs = { data_collator_kwargs = {
"padding": True, # True/"longest" is the default "padding": True, # True/"longest" is the default
} }
@@ -1735,37 +1714,27 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64 data_collator_kwargs["pad_to_multiple_of"] = 64
if self.cfg.reward_model:
data_collator_kwargs["max_length"] = self.cfg.sequence_len
trainer_cls = self._get_trainer_cls() trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer( trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls trainer_kwargs, trainer_cls
) )
if eval_data_collator := self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
):
if not self.cfg.reward_model:
trainer_kwargs["eval_data_collator"] = eval_data_collator
if not self.cfg.reward_model:
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
)
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
trainer_kwargs["processing_class"] = self.tokenizer
else:
trainer_kwargs["tokenizer"] = self.tokenizer
trainer = trainer_cls( trainer = trainer_cls(
model=self.model, model=self.model,
train_dataset=self.train_dataset, train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset, eval_dataset=self.eval_dataset,
args=training_args, args=training_args,
tokenizer=self.tokenizer,
data_collator=self.build_collator(training_args, **data_collator_kwargs), data_collator=self.build_collator(training_args, **data_collator_kwargs),
eval_data_collator=self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=self.get_callbacks(), callbacks=self.get_callbacks(),
num_epochs=self.cfg.num_epochs,
**trainer_kwargs, **trainer_kwargs,
) )
trainer = self.hook_post_create_trainer(trainer) trainer = self.hook_post_create_trainer(trainer)
@@ -1799,14 +1768,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
V2BatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq, DataCollatorForSeq2Seq,
RewardDataCollatorWithPadding,
] ]
] ]
if self.cfg.reward_model: if use_batch_sampler_collator:
collator = RewardDataCollatorWithPadding
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES: if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq collator = V2BatchSamplerDataCollatorForSeq2Seq
elif ( elif (
@@ -1817,12 +1781,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
else: else:
collator = BatchSamplerDataCollatorForSeq2Seq collator = BatchSamplerDataCollatorForSeq2Seq
else: else:
if self.cfg.processor_type and self.processor: collator = DataCollatorForSeq2Seq
collator = MultiModalChatDataCollator
kwargs["processor"] = self.processor
kwargs["chat_template"] = training_args.chat_template
else:
collator = DataCollatorForSeq2Seq
return collator( return collator(
self.tokenizer, self.tokenizer,
@@ -1843,7 +1802,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
return callbacks return callbacks
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) callbacks = []
return callbacks return callbacks
def build_training_arguments(self, total_num_steps): def build_training_arguments(self, total_num_steps):
@@ -1871,10 +1830,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.eval_dataset: if self.eval_dataset:
training_args_kwargs["eval_strategy"] = "steps" training_args_kwargs["evaluation_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps training_args_kwargs["eval_steps"] = self.cfg.eval_steps
else: else:
training_args_kwargs["eval_strategy"] = "no" training_args_kwargs["evaluation_strategy"] = "no"
if self.cfg.bf16 or self.cfg.bfloat16: if self.cfg.bf16 or self.cfg.bfloat16:
training_args_kwargs["bf16"] = True training_args_kwargs["bf16"] = True
@@ -1929,18 +1888,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined # default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch" training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.rl_beta: if self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.rl_beta training_args_kwargs["beta"] = self.cfg.rl_beta
if self.cfg.orpo_alpha: if self.cfg.orpo_alpha:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ??? # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha training_args_kwargs["beta"] = self.cfg.orpo_alpha
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None: if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_cls = None
if self.cfg.rl == "simpo": if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo" training_args_kwargs["loss_type"] = "simpo"
@@ -1949,13 +1907,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None: if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl == "orpo": if self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig training_args_cls = AxolotlORPOConfig
training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len: if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "kto": if self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = ( training_args_kwargs["desirable_weight"] = (
@@ -1970,17 +1928,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.max_prompt_len: if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
else:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
output_dir=self.cfg.output_dir, output_dir=self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size, per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -2001,6 +1948,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args = self.build_training_arguments(total_num_steps) training_args = self.build_training_arguments(total_num_steps)
dpo_trainer_kwargs = {} dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo": if self.cfg.rl == "ipo":
dpo_trainer_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing: if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
if self.eval_dataset: if self.eval_dataset:
@@ -2014,6 +1962,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl in ["dpo", "ipo"]: if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer trainer_cls = AxolotlDPOTrainer
trainer_cls_args = [self.model, self.model_ref] trainer_cls_args = [self.model, self.model_ref]
# these aren't used for the ORPO trainer
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = True
elif self.cfg.rl == "orpo": elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
@@ -2025,17 +1979,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
else: else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}") raise ValueError(f"Unsupported RL: {self.cfg.rl}")
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
dpo_trainer_kwargs["processing_class"] = self.tokenizer
else:
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
dpo_trainer = trainer_cls( dpo_trainer = trainer_cls(
*trainer_cls_args, *trainer_cls_args,
args=training_args, args=training_args,
train_dataset=self.train_dataset, train_dataset=self.train_dataset,
tokenizer=self.tokenizer,
callbacks=self.get_callbacks(), callbacks=self.get_callbacks(),
**dpo_trainer_kwargs, **dpo_trainer_kwargs,
) )
@@ -2057,11 +2005,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
""" """
def get_callbacks(self): def get_callbacks(self):
callbacks = super().get_callbacks() callbacks = []
return callbacks return callbacks
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) callbacks = []
return callbacks return callbacks
def build(self, total_num_steps): def build(self, total_num_steps):

View File

@@ -18,10 +18,9 @@ Plugins can be used to integrate third-party models, modify the training process
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods. To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
""" """
import collections
import importlib import importlib
import logging import logging
from typing import OrderedDict from typing import List
class BasePlugin: class BasePlugin:
@@ -48,7 +47,7 @@ class BasePlugin:
Initializes the BasePlugin. Initializes the BasePlugin.
""" """
def register(self, cfg): # pylint: disable=unused-argument def register(self, cfg):
""" """
Registers the plugin with the given configuration. Registers the plugin with the given configuration.
@@ -64,7 +63,7 @@ class BasePlugin:
Returns a pydantic model for the plugin's input arguments. Returns a pydantic model for the plugin's input arguments.
""" """
def pre_model_load(self, cfg): # pylint: disable=unused-argument def pre_model_load(self, cfg):
""" """
Performs actions before the model is loaded. Performs actions before the model is loaded.
@@ -75,7 +74,7 @@ class BasePlugin:
None None
""" """
def post_model_load(self, cfg, model): # pylint: disable=unused-argument def post_model_load(self, cfg, model):
""" """
Performs actions after the model is loaded. Performs actions after the model is loaded.
@@ -87,7 +86,7 @@ class BasePlugin:
None None
""" """
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument def pre_lora_load(self, cfg, model):
""" """
Performs actions before LoRA weights are loaded. Performs actions before LoRA weights are loaded.
@@ -99,7 +98,7 @@ class BasePlugin:
None None
""" """
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument def post_lora_load(self, cfg, model):
""" """
Performs actions after LoRA weights are loaded. Performs actions after LoRA weights are loaded.
@@ -111,7 +110,7 @@ class BasePlugin:
None None
""" """
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument def create_optimizer(self, cfg, trainer):
""" """
Creates and returns an optimizer for training. Creates and returns an optimizer for training.
@@ -123,9 +122,7 @@ class BasePlugin:
object: The created optimizer. object: The created optimizer.
""" """
def create_lr_scheduler( def create_lr_scheduler(self, cfg, trainer, optimizer):
self, cfg, trainer, optimizer
): # pylint: disable=unused-argument
""" """
Creates and returns a learning rate scheduler. Creates and returns a learning rate scheduler.
@@ -138,9 +135,9 @@ class BasePlugin:
object: The created learning rate scheduler. object: The created learning rate scheduler.
""" """
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument def add_callbacks_pre_trainer(self, cfg, model):
""" """
set up callbacks before creating the trainer. Adds callbacks to the trainer before training.
Parameters: Parameters:
cfg (dict): The configuration for the plugin. cfg (dict): The configuration for the plugin.
@@ -149,45 +146,17 @@ class BasePlugin:
Returns: Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs List[callable]: A list of callback functions to be added to the TrainingArgs
""" """
return []
def add_callbacks_post_trainer( def add_callbacks_post_trainer(self, cfg, trainer):
self, cfg, trainer
): # pylint: disable=unused-argument
""" """
Adds callbacks to the trainer after creating the trainer. Adds callbacks to the trainer after training.
This is useful for callbacks that require access to the model or trainer.
Parameters: Parameters:
cfg (dict): The configuration for the plugin. cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training. trainer (object): The trainer object for training.
Returns: Returns:
List[callable]: A list of callback functions to be added List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
def post_train(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after training is complete.
Parameters:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Returns:
None
"""
def post_train_unload(self, cfg): # pylint: disable=unused-argument
"""
Performs actions after training is complete and the model is unloaded.
Parameters:
cfg (dict): The configuration for the plugin.
Returns:
None
""" """
@@ -235,7 +204,7 @@ class PluginManager:
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins. pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
""" """
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict() plugins: List[BasePlugin] = []
_instance = None _instance = None
@@ -245,7 +214,7 @@ class PluginManager:
""" """
if cls._instance is None: if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls) cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins = collections.OrderedDict() cls._instance.plugins: List[BasePlugin] = []
return cls._instance return cls._instance
@staticmethod @staticmethod
@@ -273,7 +242,7 @@ class PluginManager:
""" """
try: try:
plugin = load_plugin(plugin_name) plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin self.plugins.append(plugin)
except ImportError: except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}") logging.error(f"Failed to load plugin: {plugin_name}")
@@ -285,7 +254,7 @@ class PluginManager:
list[str]: A list of Pydantic classes for all registered plugins' input arguments.' list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
""" """
input_args = [] input_args = []
for plugin in self.plugins.values(): for plugin in self.plugins:
input_args_from_plugin = plugin.get_input_args() input_args_from_plugin = plugin.get_input_args()
if input_args_from_plugin is not None: if input_args_from_plugin is not None:
input_args.append(input_args_from_plugin) input_args.append(input_args_from_plugin)
@@ -301,7 +270,7 @@ class PluginManager:
Returns: Returns:
None None
""" """
for plugin in self.plugins.values(): for plugin in self.plugins:
plugin.pre_model_load(cfg) plugin.pre_model_load(cfg)
def post_model_load(self, cfg, model): def post_model_load(self, cfg, model):
@@ -315,7 +284,7 @@ class PluginManager:
Returns: Returns:
None None
""" """
for plugin in self.plugins.values(): for plugin in self.plugins:
plugin.post_model_load(cfg, model) plugin.post_model_load(cfg, model)
def pre_lora_load(self, cfg, model): def pre_lora_load(self, cfg, model):
@@ -329,7 +298,7 @@ class PluginManager:
Returns: Returns:
None None
""" """
for plugin in self.plugins.values(): for plugin in self.plugins:
plugin.pre_lora_load(cfg, model) plugin.pre_lora_load(cfg, model)
def post_lora_load(self, cfg, model): def post_lora_load(self, cfg, model):
@@ -343,7 +312,7 @@ class PluginManager:
Returns: Returns:
None None
""" """
for plugin in self.plugins.values(): for plugin in self.plugins:
plugin.post_lora_load(cfg, model) plugin.post_lora_load(cfg, model)
def create_optimizer(self, cfg, trainer): def create_optimizer(self, cfg, trainer):
@@ -357,7 +326,7 @@ class PluginManager:
Returns: Returns:
object: The created optimizer, or None if none was found. object: The created optimizer, or None if none was found.
""" """
for plugin in self.plugins.values(): for plugin in self.plugins:
optimizer = plugin.create_optimizer(cfg, trainer) optimizer = plugin.create_optimizer(cfg, trainer)
if optimizer is not None: if optimizer is not None:
return optimizer return optimizer
@@ -375,7 +344,7 @@ class PluginManager:
Returns: Returns:
object: The created learning rate scheduler, or None if none was found. object: The created learning rate scheduler, or None if none was found.
""" """
for plugin in self.plugins.values(): for plugin in self.plugins:
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer) scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
if scheduler is not None: if scheduler is not None:
return scheduler return scheduler
@@ -393,10 +362,8 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs. List[callable]: A list of callback functions to be added to the TrainingArgs.
""" """
callbacks = [] callbacks = []
for plugin in self.plugins.values(): for plugin in self.plugins:
plugin_callbacks = plugin.add_callbacks_pre_trainer(cfg, model) callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
if plugin_callbacks: # if the plugin returned a list of callbacks
callbacks.extend(plugin_callbacks)
return callbacks return callbacks
def add_callbacks_post_trainer(self, cfg, trainer): def add_callbacks_post_trainer(self, cfg, trainer):
@@ -411,22 +378,6 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs. List[callable]: A list of callback functions to be added to the TrainingArgs.
""" """
callbacks = [] callbacks = []
for plugin in self.plugins.values(): for plugin in self.plugins:
plugin_callbacks = plugin.add_callbacks_post_trainer(cfg, trainer) callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
if plugin_callbacks:
callbacks.extend(plugin_callbacks)
return callbacks return callbacks
def post_train_unload(self, cfg):
"""
Calls the post_train_unload method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
Returns:
None
"""
for plugin in self.plugins.values():
plugin.post_train_unload(cfg)
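Taken together, a minimal sketch of a downstream plugin written against the `BasePlugin` hooks above; the plugin class, callback, and dotted import path are hypothetical, and the manager's `register()` entry point is assumed from the body shown in this diff:

```python
from transformers import TrainerCallback

from axolotl.integrations.base import BasePlugin, PluginManager


class PrintStartCallback(TrainerCallback):
    """Hypothetical callback that logs when training begins."""

    def on_train_begin(self, args, state, control, **kwargs):
        print("training started")
        return control


class MyPlugin(BasePlugin):
    """Hypothetical plugin attaching a callback once the trainer exists."""

    def add_callbacks_post_trainer(self, cfg, trainer):
        return [PrintStartCallback()]


# PluginManager is a singleton; plugins are loaded by dotted path
manager = PluginManager.get_instance()
manager.register("my_package.MyPlugin")  # hypothetical import path
```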

View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,13 +0,0 @@
# Grokfast Optimizer
See https://github.com/ironjr/grokfast
### Usage
```yaml
plugins:
- axolotl.integrations.grokfast.GrokfastPlugin
grokfast_alpha: 0.98
grokfast_lamb: 2.0
```

View File

@@ -1,50 +0,0 @@
"""
Grokfast plugin for Axolotl
"""
import logging
from transformers.trainer_callback import TrainerCallback
from ..base import BasePlugin
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
from .optimizer import gradfilter_ema
LOG = logging.getLogger("axolotl.integrations.grokfast")
class GrokfastCallbackHandler(TrainerCallback):
"""
Transformer trainer callbacks for Grokfast
"""
def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs):
super().__init__(*args_, **kwargs)
self.grads = None
self.alpha = alpha
self.lamb = lamb
def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument
self.grads = None
def on_pre_optimizer_step(
self, args_, state, control, **kwargs
): # pylint: disable=unused-argument
model = kwargs.pop("model")
self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
return control
class GrokfastPlugin(BasePlugin):
"""
Plugin for Grokfast optimizer integration with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.grokfast.GrokfastArgs"
def add_callbacks_post_trainer(self, cfg, trainer):
LOG.info("Adding Grokfast callback to the trainer")
callback = GrokfastCallbackHandler(
alpha=cfg.grokfast_alpha, lamb=cfg.grokfast_lamb
)
return [callback]

View File

@@ -1,15 +0,0 @@
"""
config args for grokfast plugin
"""
from typing import Optional
from pydantic import BaseModel
class GrokfastArgs(BaseModel):
"""
Input args for Grokfast optimizer.
"""
grokfast_alpha: Optional[float] = 0.98
grokfast_lamb: Optional[float] = 2.0

View File

@@ -1,63 +0,0 @@
# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
# Reference: https://github.com/ironjr/grokfast
# pylint: skip-file
from collections import deque
from typing import Dict, Literal, Optional
import torch
import torch.nn as nn
def gradfilter_ma(
m: nn.Module,
grads: Optional[Dict[str, deque]] = None,
window_size: int = 100,
lamb: float = 5.0,
filter_type: Literal["mean", "sum"] = "mean",
warmup: bool = True,
trigger: bool = False, # For ablation study.
) -> Dict[str, deque]:
if grads is None:
grads = {
n: deque(maxlen=window_size)
for n, p in m.named_parameters()
if p.requires_grad and p.grad is not None
}
for n, p in m.named_parameters():
if p.requires_grad and p.grad is not None:
grads[n].append(p.grad.data.detach()) # .cpu())
# Modify the gradients.
if not warmup or len(grads[n]) == window_size and not trigger:
if filter_type == "mean":
avg = sum(grads[n]) / len(grads[n])
elif filter_type == "sum":
avg = sum(grads[n])
else:
raise ValueError(f"Unrecognized filter_type {filter_type}")
p.grad.data = p.grad.data + avg * lamb
return grads
def gradfilter_ema(
m: nn.Module,
grads: Optional[Dict[str, torch.Tensor]] = None,
alpha: float = 0.98,
lamb: float = 2.0,
) -> Dict[str, torch.Tensor]:
if grads is None:
grads = {
n: p.grad.data.detach()
for n, p in m.named_parameters()
if p.requires_grad and p.grad is not None
}
for n, p in m.named_parameters():
if p.requires_grad and p.grad is not None:
grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
p.grad.data = p.grad.data + grads[n] * lamb
return grads
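A minimal sketch of using `gradfilter_ema` in a hand-rolled loop, mirroring what `GrokfastCallbackHandler.on_pre_optimizer_step` does above; `model`, `optimizer`, and `dataloader` are assumed to exist:

```python
grads = None  # EMA state carried across steps
for batch in dataloader:
    loss = model(**batch).loss
    loss.backward()
    # amplify the slow-varying (EMA) component of the gradients before stepping
    grads = gradfilter_ema(model, grads, alpha=0.98, lamb=2.0)
    optimizer.step()
    optimizer.zero_grad()
```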

View File

@@ -18,24 +18,20 @@ Module for the Plugin for LIGER integration with Axolotl. Module for the Plugin for LIGER integration with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training. Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight. It is designed to be performant, correct, and light-weight.
""" """
import inspect
import logging import logging
import sys import sys
from functools import partial
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.liger")
class LigerPlugin(BasePlugin): class LigerPlugin(BasePlugin):
""" """
@@ -46,31 +42,59 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs" return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg): def pre_model_load(self, cfg):
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN: if cfg.model_config_type == "llama":
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type] from liger_kernel.transformers.model.llama import (
liger_fn_sig = inspect.signature(apply_liger_fn) lce_forward as llama_lce_forward,
kwargs = {} )
if "rope" in liger_fn_sig.parameters: from transformers.models.llama import modeling_llama
kwargs["rope"] = cfg.liger_rope
if "cross_entropy" in liger_fn_sig.parameters: if cfg.liger_rope:
kwargs["cross_entropy"] = cfg.liger_cross_entropy modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
if "fused_linear_cross_entropy" in liger_fn_sig.parameters: if cfg.liger_rms_norm:
kwargs[ modeling_llama.LlamaRMSNorm = LigerRMSNorm
"fused_linear_cross_entropy" if cfg.liger_swiglu:
] = cfg.liger_fused_linear_cross_entropy modeling_llama.LlamaMLP = LigerSwiGLUMLP
if "rms_norm" in liger_fn_sig.parameters: if cfg.liger_cross_entropy:
kwargs["rms_norm"] = cfg.liger_rms_norm modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
if "layer_norm" in liger_fn_sig.parameters: elif cfg.liger_fused_linear_cross_entropy:
kwargs["layer_norm"] = cfg.liger_layer_norm modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
if "geglu" in liger_fn_sig.parameters:
kwargs["geglu"] = cfg.liger_glu_activation elif cfg.model_config_type == "mistral":
elif "swiglu" in liger_fn_sig.parameters: from liger_kernel.transformers.model.mistral import (
kwargs["swiglu"] = cfg.liger_glu_activation lce_forward as mistral_lce_forward,
with zero_only(): )
LOG.info( from transformers.models.mistral import modeling_mistral
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
if cfg.liger_rope:
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_mistral.MistralRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_mistral.MistralMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
elif cfg.model_config_type == "gemma":
from liger_kernel.transformers.model.gemma import (
lce_forward as gemma_lce_forward,
)
from transformers.models.gemma import modeling_gemma
if cfg.liger_rope:
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma.GemmaRMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
) )
apply_liger_fn(**kwargs) if cfg.liger_swiglu:
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
        elif cfg.model_config_type == "jamba":
            from transformers.models.jamba import modeling_jamba
@@ -80,14 +104,30 @@ class LigerPlugin(BasePlugin):
                modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
            if cfg.liger_rms_norm:
                modeling_jamba.JambaRMSNorm = LigerRMSNorm
-            if cfg.liger_glu_activation:
+            if cfg.liger_swiglu:
                modeling_jamba.JambaMLP = LigerSwiGLUMLP
            if cfg.liger_cross_entropy:
-                from transformers.loss.loss_utils import nn
-
-                nn.functional.cross_entropy = liger_cross_entropy
+                modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
            if cfg.liger_fused_linear_cross_entropy:
                modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "qwen2":
from liger_kernel.transformers.model.qwen2 import (
lce_forward as qwen2_lce_forward,
)
from transformers.models.qwen2 import modeling_qwen2
if cfg.liger_rope:
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
        elif cfg.model_config_type == "deepseek_v2":
            from accelerate import init_empty_weights
            from transformers import AutoModelForCausalLM
@@ -106,11 +146,44 @@ class LigerPlugin(BasePlugin):
                logging.warning("Fused liger_rope is not supported for DeepseekV2.")
            if cfg.liger_rms_norm:
                modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
-            if cfg.liger_glu_activation:
+            if cfg.liger_swiglu:
                modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
            if cfg.liger_cross_entropy:
-                # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
-                # nn.CrossEntropyLoss in the forward method.
                modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
            if cfg.liger_fused_linear_cross_entropy:
                modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "gemma2":
from transformers.models.gemma2 import modeling_gemma2
if cfg.liger_rope:
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma2.Gemma2RMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
if cfg.liger_swiglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
logging.warning(
"Fused linear cross entropy is not supported for Gemma 2."
)
elif cfg.model_config_type == "phi3":
from liger_kernel.transformers.model.phi3 import (
lce_forward as phi3_lce_forward,
)
from transformers.models.phi3 import modeling_phi3
if cfg.liger_rope:
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward


@@ -15,12 +15,9 @@
""" """
Module for handling LIGER input arguments. Module for handling LIGER input arguments.
""" """
import logging
from typing import Optional from typing import Optional
from pydantic import BaseModel, model_validator from pydantic import BaseModel
LOG = logging.getLogger("axolotl.integrations.liger.args")
class LigerArgs(BaseModel): class LigerArgs(BaseModel):
@@ -30,24 +27,6 @@ class LigerArgs(BaseModel):
liger_rope: Optional[bool] = None liger_rope: Optional[bool] = None
liger_rms_norm: Optional[bool] = None liger_rms_norm: Optional[bool] = None
liger_layer_norm: Optional[bool] = None
liger_swiglu: Optional[bool] = None liger_swiglu: Optional[bool] = None
liger_glu_activation: Optional[bool] = None
liger_cross_entropy: Optional[bool] = None liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None liger_fused_linear_cross_entropy: Optional[bool] = None
@model_validator(mode="before")
@classmethod
def check_deprecated_swiglu(cls, data):
if data.get("liger_swiglu") is not None:
if data.get("liger_glu_activation") is not None:
raise ValueError(
"You cannot have both `liger_swiglu` and `liger_glu_activation` set."
)
LOG.warning(
"The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
"Please use 'liger_glu_activation' instead."
)
data["liger_glu_activation"] = data.pop("liger_swiglu")
return data
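Taken together, a minimal config sketch for these flags might look like the following; the plugin class path and the combination of values are inferred from the `LigerArgs` fields above, not quoted from the docs:

```yaml
plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true  # newer spelling; `liger_swiglu` is the deprecated alias
liger_fused_linear_cross_entropy: true
```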


@@ -1,13 +0,0 @@
# LM Eval Harness
### Usage
```yaml
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
```


@@ -1,42 +0,0 @@
"""
Module for the Plugin for LM Eval Harness
"""
import subprocess # nosec
from datetime import datetime
from axolotl.integrations.base import BasePlugin
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
class LMEvalPlugin(BasePlugin):
"""
Plugin for LM Evaluation Harness integration with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.lm_eval.LMEvalArgs"
def post_train_unload(self, cfg):
tasks = ",".join(cfg.lm_eval_tasks)
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
output_path = cfg.output_dir
output_path += "" if cfg.output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
subprocess.run( # nosec
[
"lm_eval",
"--model",
"hf",
"--model_args",
f"pretrained={cfg.output_dir}{fa2}{dtype}",
"--tasks",
tasks,
"--batch_size",
str(cfg.lm_eval_batch_size),
"--output_path",
output_path,
],
check=True,
)


@@ -1,15 +0,0 @@
"""
Module for handling lm eval harness input arguments.
"""
from typing import List, Optional
from pydantic import BaseModel
class LMEvalArgs(BaseModel):
"""
Input args for lm eval harness
"""
lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8
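The removed README above only listed tasks; a fuller config sketch combining both fields defined in `LMEvalArgs` (values illustrative):

```yaml
plugins:
  - axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
  - gsm8k
  - hellaswag
lm_eval_batch_size: 8  # default per LMEvalArgs
```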

src/axolotl/loraplus.py (new file)

@@ -0,0 +1,133 @@
"""Module for LoRA+"""
# MIT License
#
# Copyright (c) 2024 nikhil-ghosh-berkeley
# https://github.com/nikhil-ghosh-berkeley/loraplus
import logging
from functools import reduce
from peft.tuners import lora
from torch import nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
LOG = logging.getLogger("axolotl.loraplus")
def get_module(name, opt_model):
"""
Retrieve a module from a model using its parameter name.
Args:
name (str): Full name of the parameter, typically including module path.
opt_model (torch.nn.Module): The model from which to retrieve the module.
Returns:
Module corresponding to the given name.
"""
parent_idx = 2 if "lora" in name else 1
module_names = name.split(sep=".")[:-parent_idx]
module = reduce(getattr, module_names, opt_model)
return module
def create_loraplus_optimizer(
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding=None,
):
"""
Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.
Args:
opt_model (torch.nn.Module): The model for which the optimizer is being created.
optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.
Returns:
An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
"""
assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."
if loraplus_lr_embedding is None:
loraplus_lr_embedding = 1e-6
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
param_groups = {
"groupA": {},
"groupB": {},
"groupB_no_decay": {},
"embedding": {},
}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
module = get_module(name, opt_model)
if isinstance(module, lora.Embedding):
param_groups["embedding"][name] = param
elif "lora_B" in name or param.ndim == 1:
if name in decay_parameters:
param_groups["groupB"][name] = param
else:
param_groups["groupB_no_decay"][name] = param
else:
param_groups["groupA"][name] = param
assigned_param_groups = ""
for group, group_params in param_groups.items():
assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
LOG.info(assigned_param_groups)
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
optimizer_grouped_parameters = [
{
"params": list(param_groups["groupA"].values()),
"weight_decay": weight_decay,
"lr": lr,
},
{
"params": list(param_groups["embedding"].values()),
"weight_decay": weight_decay,
"lr": loraplus_lr_embedding,
},
{
"params": list(param_groups["groupB"].values()),
"weight_decay": weight_decay,
"lr": lr * loraplus_lr_ratio,
},
{
"params": list(param_groups["groupB_no_decay"].values()),
"weight_decay": 0.0,
"lr": lr * loraplus_lr_ratio,
},
]
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{p.data_ptr(): p.numel() for p in module.parameters()}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
return optimizer
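A minimal usage sketch for `create_loraplus_optimizer`, assuming `model` is a PEFT-wrapped model with LoRA adapters attached; the hyperparameter values are illustrative only:

```python
import torch

optimizer = create_loraplus_optimizer(
    opt_model=model,  # a peft.PeftModel with lora_A/lora_B modules
    optimizer_cls=torch.optim.AdamW,
    optimizer_kwargs={"lr": 2e-4, "weight_decay": 0.01},
    loraplus_lr_ratio=16.0,  # lora_B groups get 16x the base learning rate
)
```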


@@ -1,229 +0,0 @@
"""
Monkeypatch for Vision Llama for FA2 support
"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple
import torch
from flash_attn.flash_attn_interface import flash_attn_func
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.models.mllama.configuration_mllama import MllamaTextConfig
from transformers.models.mllama.modeling_mllama import (
MllamaTextCrossAttention,
MllamaTextSelfAttention,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import is_flash_attn_greater_or_equal_2_10
class MllamaTextCrossFlashAttention2(MllamaTextCrossAttention):
"""
Mllama flash cross-attention module. This module inherits from `MllamaTextCrossAttention` and
implements the forward pass using Flash Attention for improved performance.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Check if flash attention version is greater or equal to 2.1
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
attention_mask: Optional[ # pylint: disable=unused-argument
torch.Tensor
] = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
query_states = self.q_norm(query_states)
if cross_attention_states is not None:
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(
bsz, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
key_states = self.k_norm(key_states)
if past_key_value is not None:
key_states, value_states = past_key_value.update(
key_states,
value_states,
self.layer_idx,
{"cache_position": cache_position},
)
elif cache_position[0] != 0:
key_states, value_states = (
past_key_value.key_cache[self.layer_idx],
past_key_value.value_cache[self.layer_idx],
)
else:
raise ValueError(
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
)
# Transpose to get the expected layout for flash attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# Apply Flash Attention
dropout_rate = self.dropout if self.training else 0.0
output = flash_attn_func(
query_states,
key_states,
value_states,
dropout_p=dropout_rate,
softmax_scale=None,
causal=False,
return_attn_probs=output_attentions,
)
attn_output = output.contiguous().view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class MllamaTextSelfFlashAttention2(MllamaTextSelfAttention):
"""
Mllama flash self-attention module. This module inherits from `MllamaTextSelfAttention` and
implements the forward pass using Flash Attention for improved performance.
"""
def __init__(self, config: MllamaTextConfig, layer_idx: int, *args, **kwargs):
super().__init__(config, layer_idx, *args, **kwargs)
# Check if flash attention version is greater or equal to 2.1
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
past_key_value=None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# Transpose to get the expected layout for flash attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.dropout if self.training else 0.0
# Handle potential silent casting to float32
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = (
self.config._pre_quantization_dtype # pylint: disable=protected-access
)
else:
target_dtype = self.q_proj.weight.dtype
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=True,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def patch_mllama():
from transformers.models.mllama.modeling_mllama import (
MLLAMA_TEXT_ATTENTION_CLASSES,
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES,
MLLAMA_VISION_ATTENTION_CLASSES,
MllamaPreTrainedModel,
)
MllamaPreTrainedModel._supports_flash_attn_2 = ( # pylint: disable=protected-access
True
)
MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
"flash_attention_2"
] = MllamaTextCrossFlashAttention2
# fallback to SDPA
MLLAMA_VISION_ATTENTION_CLASSES[
"flash_attention_2"
] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
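A sketch of how the patch would be applied; the checkpoint id is illustrative, and the call must happen before the model is constructed so the patched attention classes are picked up:

```python
from transformers import MllamaForConditionalGeneration

patch_mllama()
model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision",  # illustrative checkpoint
    attn_implementation="flash_attention_2",
)
```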


@@ -0,0 +1,231 @@
"""
monkeypatch to add a get_turns method
"""
import logging
from typing import Generator, Tuple
from fastchat.conversation import SeparatorStyle
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
def get_prompt(self) -> str:
ret = ""
for role, msg in self.get_turns():
ret += role + msg
return ret
def get_turns( # pylint: disable=too-many-return-statements
self,
) -> Generator[Tuple[str, str], None, None]:
"""Get the prompt for generation."""
system_prompt = self.system_template.format(system_message=self.system_message)
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message + seps[i % 2]
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ": ", "" # must be end with a space
return
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
yield "", "" if system_prompt == "" else system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role, message + self.sep
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role, message + seps[i % 2]
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.RWKV:
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message.replace("\r\n", "\n").replace(
"\n\n", "\n"
) + "\n\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
if self.system_message:
if self.messages:
# For llama, the system message is incorporated into the first human instruction
first_role, first_msg = self.messages[0]
if first_role == self.roles[0]:
system_prompt += first_msg
self.messages.pop(0)
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
if (i % 2 == 0 and not self.system_message) or (
i % 2 != 0 and self.system_message
):
role = "<s> " + role
yield role + " ", message
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
contains_sys_msg = False
if self.system_message:
contains_sys_msg = True
if self.messages:
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction separated by a newline
first_role, first_msg = self.messages[0]
if first_role == self.roles[0]:
system_prompt = self.system_template.format(
system_message=" " + self.system_message
)
system_prompt += first_msg
self.messages.pop(0)
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message and i == 0 and not contains_sys_msg:
yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a `<s> [INST]` at the beginning of the first instruction.
elif message:
yield role + " ", message
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.LLAMA3:
if self.system_message:
# For llama3, the system message is NOT incorporated into the first human instruction
# All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
else:
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
return
if self.sep_style == SeparatorStyle.GEMMA:
if self.system_message:
raise ValueError("Gemma chat template does not support system messages")
for i, (role, message) in enumerate(self.messages):
prefix = "<bos>" if i == 0 else ""
message_str = message if message else ""
yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
return
if self.sep_style == SeparatorStyle.CHATGLM:
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
round_add_n = 1 if self.name == "chatglm2" else 0
if system_prompt:
yield "", system_prompt + self.sep
for i, (role, message) in enumerate(self.messages):
if i % 2 == 0:
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
if message:
yield f"{role}", f"{message}{self.sep}"
else:
yield f"{role}", ""
return
if self.sep_style == SeparatorStyle.CHATML:
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep + "\n"
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.CHATGLM3:
if self.system_message:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role + "\n", " " + message
else:
yield role + "\n", ""  # yield a (role, message) pair like the other branches
return
if self.sep_style == SeparatorStyle.CHATINTERN:
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
prefix = "<s>" if i % 2 == 0 else ""
if message:
yield prefix + role + ":", message + seps[i % 2] + "\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.DOLLY:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
suffix = "\n\n" if i % 2 == 1 else ""
yield role + ":\n", message + seps[i % 2] + suffix
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.PHOENIX:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role + ": ", "<s>" + message + "</s>"
else:
yield role + ": " + "<s>", ""
return
if self.sep_style == SeparatorStyle.ROBIN:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ":\n", message + self.sep
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.FALCON_CHAT:
if self.system_message:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def add_get_turns_to_conversation():
import fastchat.conversation
fastchat.conversation.Conversation.get_turns = get_turns
fastchat.conversation.Conversation.get_prompt = get_prompt
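A small sketch of the patched behavior using a stock FastChat template (the template name is illustrative): `get_prompt()` is now just the concatenation of the `(role, message)` pairs yielded by `get_turns()`.

```python
import fastchat.conversation

add_get_turns_to_conversation()

conv = fastchat.conversation.get_conv_template("vicuna_v1.1")
conv.append_message(conv.roles[0], "Hello!")
conv.append_message(conv.roles[1], "Hi there.")

for role, message in conv.get_turns():
    print(repr(role), repr(message))
assert conv.get_prompt() == "".join(r + m for r, m in conv.get_turns())
```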


@@ -22,6 +22,7 @@ from transformers.models.llama.modeling_llama import (
    apply_rotary_pos_emb,
    repeat_kv,
)
+from xformers.ops import SwiGLU

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
@@ -43,19 +44,7 @@ except ImportError:
LOG = logging.getLogger("axolotl")

-def is_xformers_available() -> bool:
-    try:
-        import xformers  # pylint: disable=unused-import  # noqa: F401
-
-        return True
-    except ImportError:
-        return False

def is_xformers_swiglu_available() -> bool:
-    if not is_xformers_available():
-        return False

    from xformers.ops.common import get_xformers_operator

    try:
@@ -68,11 +57,6 @@ def is_xformers_swiglu_available() -> bool:

def replace_llama_mlp_with_swiglu(model):
-    if is_xformers_swiglu_available():
-        from axolotl.monkeypatch.xformers_ import FusedMLP
-    else:
-        raise RuntimeError("xformers SwiGLU not available for this environment")

    for name, module in model.named_modules():
        if isinstance(module, LlamaMLP):
            mlp = FusedMLP(
@@ -197,6 +181,49 @@ class FusedAttention(LlamaAttention):
        set_module_name(model, name, new_attn)
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(


@@ -1,5 +1,4 @@
"""multipack patching for v2 of sample packing""" """multipack patching for v2 of sample packing"""
import importlib import importlib
import transformers import transformers
@@ -11,7 +10,6 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mllama_text_model",
"llama", "llama",
"mistral", "mistral",
"mixtral", "mixtral",
@@ -20,7 +18,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"falcon", "falcon",
"phi", "phi",
"phi3", "phi3",
"phimoe",
"gemma", "gemma",
"gemma2", "gemma2",
"gemmoe", "gemmoe",
@@ -29,28 +26,71 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
] ]
-def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
-    if has_remote_code:
-        patch_remote(model_name)
-    elif hasattr(transformers, "modeling_flash_attention_utils"):
-        transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    if model_type == "mixtral" and is_deepspeed_zero3_enabled():
-        patch_mixtral_moe_forward_zero3()
+def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
+    if model_type == "gemmoe":
+        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
+    elif model_type == "deepseek_v2":
+        patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
+    elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
+        transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+        if model_type == "mixtral" and is_deepspeed_zero3_enabled():
+            patch_mixtral_moe_forward_zero3()
+        return
+
+    # retain for legacy
+    if model_type == "mixtral":
+        transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+        if is_deepspeed_zero3_enabled():
+            patch_mixtral_moe_forward_zero3()
+    elif model_type == "llama":
+        if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
+            transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+    elif model_type == "mistral":
+        if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
+            transformers.models.mistral.modeling_mistral._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+    elif model_type == "qwen2":
+        transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "qwen2_moe":
+        transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "falcon":
+        transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "phi":
+        transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "gemma":
+        transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "gemma2":
+        transformers.models.gemma2.modeling_gemma2._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "starcoder2":
+        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )

-def patch_remote(model_name):
+def patch_remote(model_name, config_name, modeling_name):
    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    # we need to load the model here in order for modeling_* to be available
    with init_empty_weights():
        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
-    parts = model_config.__class__.__module__.split(".")
-    parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
-    module_name = ".".join(parts)
+    module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
    modeling_arch = importlib.import_module(module_name)
-    if hasattr(modeling_arch, "_get_unpad_data"):
-        modeling_arch._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
+    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
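A usage sketch for the right-hand version of `patch_for_multipack`: the model loader would call it once before instantiating the model so the packing-aware `_get_unpad_data` is already in place. The checkpoint id is illustrative:

```python
from transformers import AutoModelForCausalLM

model_type = "llama"  # taken from the model config in practice
if model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
    patch_for_multipack(model_type)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
```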


@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
def reset_optimizer(
    optimizer: torch.optim.Optimizer,
    *,
-    reset_params: List[str],  # where str is the key to a torch.nn.Parameter
-    optimizer_state_keys: List[str],
+    reset_params: list[str],  # where str is the key to a torch.nn.Parameter
+    optimizer_state_keys: list[str],
    prune_ratio: float = 0.9,
):
    pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)


@@ -16,7 +16,6 @@
# This code is based off the following work:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
-# pylint: disable=duplicate-code
""" PyTorch StableLM Epoch model. """
import importlib
import math


@@ -1,83 +0,0 @@
"""
fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/34645
"""
import inspect
from accelerate.logging import get_logger
from transformers.trainer import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = get_logger("axolotl.monkeypatch.trainer_fsdp_grad_accumulation")
ORIGINAL_CONTEXT_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i == len(batch_samples) - 1
else contextlib.nullcontext
)
"""
PATCHED_CONTEXT_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
train_loop = get_training_loop_code()
train_loop, _ = detab_code(train_loop)
return ORIGINAL_CONTEXT_CODE in train_loop
def patch_training_loop_for_fsdp_grad_accum():
"""
monkeypatch for fixing the training loop for FSDP gradient accumulation
"""
train_loop = get_training_loop_code()
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
train_loop
)
train_loop, _ = detab_code(train_loop)
assert (
ORIGINAL_CONTEXT_CODE in train_loop
), "Original _inner_training_loop code not found"
train_loop = train_loop.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
train_loop = train_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in train_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(train_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop", main_process_only=True)
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)
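The one-line change above flips which micro-batches run inside `accelerator.no_sync`, so gradients are synchronized only on the final micro-batch of each accumulation window. A standalone sketch of the patched pattern (names assumed, not taken from the trainer source):

```python
import contextlib
import functools

def grad_sync_context(accelerator, model, i, batch_samples):
    # skip DDP/FSDP gradient sync on every micro-batch except the final one
    return (
        functools.partial(accelerator.no_sync, model=model)
        if i != len(batch_samples) - 1
        else contextlib.nullcontext
    )
```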


@@ -16,6 +16,26 @@ from transformers.models.llama.modeling_llama import (
LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
"""
PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
)
"""
ORIGINAL_QKV_CODE = """
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
@@ -60,6 +80,12 @@ def get_forward_code() -> str:
    return forward
def check_cel_is_patchable() -> bool:
forward = get_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_CEL_CODE in forward
def get_self_attn_code() -> str:
    forward = inspect.getsource(LlamaFlashAttention2.forward)
    return forward
@@ -72,31 +98,48 @@ def check_self_attn_is_patchable() -> bool:
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
-    from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
-
-    def UnslothForCausalLMLoss(  # pylint: disable=invalid-name
-        logits,
-        labels,
-        vocab_size: int,  # pylint: disable=unused-argument
-        num_items_in_batch: int = None,
-        ignore_index: int = -100,  # pylint: disable=unused-argument
-        **kwargs,  # pylint: disable=unused-argument
-    ):
-        # Upcast to float if we need to compute the loss to avoid potential precision issues
-        logits = logits.float()
-        # Shift so that tokens < n predict n
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        loss = fast_cross_entropy_loss(
-            logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
-        )
-        return loss
-
    if model_type == "llama":
-        from transformers.loss import loss_utils
-
-        loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss  # type: ignore[assignment]
+        forward = get_forward_code()
+        LlamaForCausalLM._original_forward = forward  # pylint: disable=protected-access
+        forward, _ = detab_code(forward)
+        assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
+
+        forward = forward.replace(
+            "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
+        )
+        forward = forward.replace(
+            "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
+            "",
+        )
+        forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
+        forward = forward.replace(
+            "def forward(",
+            "def fast_cross_entropy_loss_forward(",
+            1,
+        )
+
+        # load imports necessary
+        import transformers.models.llama.modeling_llama
+
+        items_to_import = []
+        for item in dir(transformers.models.llama.modeling_llama):
+            if item in forward:
+                items_to_import.append(item)
+
+        exec(  # pylint: disable=exec-used  # nosec B102
+            "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
+            globals(),
+        )
+        exec(  # pylint: disable=exec-used  # nosec B102
+            "from transformers.models.llama.modeling_llama import ("
+            + ", ".join(x for x in items_to_import)
+            + ")",
+            globals(),
+        )
+        exec(forward, globals())  # pylint: disable=exec-used  # nosec B102
+        LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
+        LlamaForCausalLM.forward = fast_cross_entropy_loss_forward  # pylint: disable=undefined-variable  # noqa: F821
    else:
        raise ValueError("Unsupported model type")
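Once defined, applying the right-hand version is a single call before training starts; a sketch assuming unsloth is installed and the upstream llama forward still matches `ORIGINAL_CEL_CODE`:

```python
if check_cel_is_patchable():
    # swap LlamaForCausalLM.forward for the exec-compiled variant that calls
    # unsloth's fused fast_cross_entropy_loss kernel
    integrate_cross_entropy_loss_patch(model_type="llama")
```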


@@ -1,51 +0,0 @@
"""
Fused MLP layer for incrementally improved training efficiency
"""
import torch
from transformers.models.llama.modeling_llama import LlamaMLP
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import set_module_name
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)


@@ -9,12 +9,8 @@ from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies")

-def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
+def load(strategy, tokenizer, cfg, ds_cfg):
    try:
-        if strategy == "messages":
-            from .messages import load as messages_load
-
-            return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
        load_fn = "load"
        if strategy.split(".")[-1].startswith("load_"):
            load_fn = strategy.split(".")[-1]
@@ -28,12 +24,9 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
        sig = inspect.signature(func)
        if "ds_cfg" in sig.parameters:
            load_kwargs["ds_cfg"] = ds_cfg
-        if "processor" in sig.parameters:
-            load_kwargs["processor"] = processor
        return func(tokenizer, cfg, **load_kwargs)
    except ModuleNotFoundError:
        return None
    except Exception as exc:  # pylint: disable=broad-exception-caught
        LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
-        raise exc
-    return None
+        return None


@@ -1,10 +0,0 @@
### example yaml
```yaml
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
```


@@ -1,35 +0,0 @@
"""Module to load prompt strategies."""
import importlib
import inspect
import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
def load(strategy, tokenizer, cfg, ds_cfg):
# pylint: disable=duplicate-code
try:
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(
f".{strategy}", "axolotl.prompt_strategies.bradley_terry"
)
func = getattr(mod, load_fn)
load_kwargs = {}
if strategy == "user_defined":
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
else:
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
return None


@@ -1,102 +0,0 @@
"""
Bradley-Terry model with chat template prompt strategy.
"""
import logging
from typing import Any, Dict, Optional
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
ChatTemplateStrategy,
)
from axolotl.utils.chat_templates import get_chat_template_from_config
# Configure the logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
LOG.setLevel(logging.INFO)
class BTChatTemplateStrategy(ChatTemplateStrategy):
"""
Bradley-Terry reward model pairwise chat template prompt strategy.
"""
def tokenize_prompt(self, prompt):
"""
:param prompt: the actual row of data from the underlying dataset
:return:
"""
self.messages = "chosen_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)
self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]}
)
rejected_tokenized = super().tokenize_prompt(prompt)
return {
"input_ids_chosen": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"],
"labels_chosen": 1.0,
"input_ids_rejected": rejected_tokenized["input_ids"],
"attention_mask_rejected": rejected_tokenized["attention_mask"],
"labels_rejected": 0.0,
}
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", None
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1
if not cfg.reward_model
else cfg.sequence_len,
}
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", []),
"train_on_eos": ds_cfg.get("train_on_eos", None),
}
strategy = BTChatTemplateStrategy(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy


@@ -1,27 +0,0 @@
"""
chatml transforms for datasets with system, input, chosen, rejected to match llama3 chat template
"""
def icr(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
chatml transforms for datasets with system, input, chosen, rejected
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
"""
def transform_fn(sample):
if "system" in sample and sample["system"]:
prompt = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = prompt + f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = prompt + f"{sample['rejected']}<|eot_id|>"
return sample
return transform_fn
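A quick sketch of the transform applied to a toy row (no system message, so the shorter prompt branch is taken):

```python
transform_fn = icr(cfg=None)
row = {"system": "", "input": "What is 2 + 2?", "chosen": "4", "rejected": "5"}
row = transform_fn(row)
# row["chosen"] and row["rejected"] now share the same llama3-formatted prompt
# prefix and end with "4<|eot_id|>" and "5<|eot_id|>" respectively
```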


@@ -5,11 +5,9 @@ HF Chat Templates prompt strategy
import logging
from typing import Any, Dict, List, Optional

-from transformers import ProcessorMixin

from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
-from axolotl.utils.chat_templates import get_chat_template_from_config
+from axolotl.utils.chat_templates import chat_templates

# Configure the logger
LOG = logging.getLogger("axolotl")
@@ -22,7 +20,6 @@ class ChatTemplatePrompter(Prompter):
    def __init__(
        self,
        tokenizer,
-        processor=None,
        chat_template=None,
        max_length=2048,
        message_field_role: str = "from",
@@ -47,12 +44,11 @@ class ChatTemplatePrompter(Prompter):
        self.message_field_training = message_field_training
        self.message_field_training_detail = message_field_training_detail
        self.tokenizer = tokenizer
-        self.processor: ProcessorMixin = processor
        self.chat_template = chat_template
        self.max_length = max_length
        self.drop_system_message = drop_system_message

-    def build_prompt(self, conversation, add_generation_prompt=False, images=None):
+    def build_prompt(self, conversation, add_generation_prompt=False):
        turns = [
            {
                "role": self.roles[t[self.message_field_role]],
@@ -65,28 +61,6 @@ class ChatTemplatePrompter(Prompter):
        if self.drop_system_message and turns[0]["role"] == "system":
            turns = turns[1:]
if self.processor:
text = self.processor.apply_chat_template(
turns,
chat_template=self.chat_template,
tokenize=False,
add_generation_prompt=add_generation_prompt,
)
batch = self.processor(
text=text,
images=images,
return_tensors="pt",
truncation=True,
max_length=self.max_length,
)
# workaround since processor works in batches instead of single examples
for k, val in batch.items():
if k in ["pixel_values"]:
batch[k] = val.tolist()
else:
batch[k] = val.squeeze().tolist()
return batch
        return self.tokenizer.apply_chat_template(
            turns,
            truncation=True,
@@ -217,7 +191,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
        super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
        self.roles_to_train = roles_to_train if roles_to_train is not None else []
        self.train_on_eos = train_on_eos
-        self.images = "images"

    @property
    def messages(self):
@@ -236,21 +209,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
            and not self.prompter.message_field_training_detail
        ):
            turns = self.get_conversation_thread(prompt)
-            images = self.get_images(prompt)
            prompt_ids = self.prompter.build_prompt(
-                turns[:-1],
-                add_generation_prompt=True,
-                images=images,
+                turns[:-1], add_generation_prompt=True
            )
-            tokenized_res = self.prompter.build_prompt(turns, images=images)
-            tokenized_prompt = {}
-            if isinstance(tokenized_res, list):
-                input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
-                tokenized_prompt["input_ids"] = input_ids
-                tokenized_prompt["attention_mask"] = [1] * len(input_ids)
-            else:
-                input_ids = tokenized_res["input_ids"]
-                tokenized_prompt = tokenized_res
+            input_ids = self.prompter.build_prompt(turns)

            if not self.train_on_inputs:
                user_prompt_len = len(prompt_ids)
@@ -258,9 +220,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
            else:
                labels = input_ids

-            tokenized_prompt["labels"] = labels
+            tokenized_prompt = {
+                "input_ids": input_ids,
+                "labels": labels,
+                "attention_mask": [1] * len(input_ids),
+            }

            return tokenized_prompt

+        LOG.info(self.roles_to_train)
+        LOG.info(self.train_on_eos)
+        LOG.info(self.prompter.message_field_training)
+        LOG.info(self.prompter.message_field_training_detail)

        turns = prompt[self.messages]
        input_ids = self.prompter.build_prompt(turns)
@@ -398,23 +368,15 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
    def get_conversation_thread(self, prompt):
        return prompt[self.messages]

-    def get_images(self, prompt):
-        return prompt.get(self.images, None)

-def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
-    # pylint: disable=duplicate-code
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    ds_cfg = ds_cfg or {}
-    chat_template_string = get_chat_template_from_config(
-        cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
-    )
-    LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")

    prompter_params = {
        "tokenizer": tokenizer,
-        "chat_template": chat_template_string,
-        "message_field_role": ds_cfg.get("message_field_role", "role"),
-        "message_field_content": ds_cfg.get("message_field_content", "content"),
+        "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
+        "message_field_role": ds_cfg.get("message_field_role", "from"),
+        "message_field_content": ds_cfg.get("message_field_content", "value"),
        "message_field_training": ds_cfg.get("message_field_training", None),
        "message_field_training_detail": ds_cfg.get(
            "message_field_training_detail",
@@ -424,7 +386,6 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
        "drop_system_message": ds_cfg.get("drop_system_message", False),
        # we need to add one for detecting sequences exceeding the `sequence_len` limit.
        "max_length": cfg.sequence_len + 1,
-        "processor": processor,
    }

    strategy_params = {


@@ -2,16 +2,15 @@
 DPO prompt strategies for using tokenizer chat templates.
 """
-from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
+from axolotl.utils.chat_templates import chat_templates
 
 
 def default(
     cfg, dataset_idx=0, **kwargs
 ):  # pylint: disable=possibly-unused-variable,unused-argument
     ds_cfg = cfg["datasets"][dataset_idx]
-    chat_template_choice, chat_template_jinja = extract_chat_template_args(
-        cfg=cfg, ds_cfg=ds_cfg
-    )
+    chat_template_str = chat_templates(cfg.chat_template)
     field_messages = ds_cfg.get("field_messages", "messages")
     field_chosen = ds_cfg.get("field_chosen", "chosen")
     field_rejected = ds_cfg.get("field_rejected", "rejected")
@@ -31,12 +30,6 @@ def default(
             role_map[source] = target
 
     def transform_fn(sample, tokenizer=None):
-        chat_template_string = get_chat_template(
-            user_choice=chat_template_choice,
-            jinja_template=chat_template_jinja,
-            tokenizer=tokenizer,
-        )
         messages = sample[field_messages]
         messages = [
             {
@@ -53,29 +46,28 @@ def default(
                 "role": role_map[sample[field_rejected][field_message_role]],
                 "content": sample[field_rejected][field_message_content],
             }
-        dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
         result = {}
         result["prompt"] = tokenizer.apply_chat_template(
             messages,
             add_generation_prompt=True,
-            chat_template=chat_template_string,
+            chat_template=chat_template_str,
             tokenize=False,
         )
         result["chosen"] = tokenizer.apply_chat_template(
-            [dummy_user_message, chosen],
+            [chosen],
             add_generation_prompt=False,
-            chat_template=chat_template_string,
+            chat_template=chat_template_str,
             tokenize=False,
         )
         chosen_strip_index = result["chosen"].find(chosen["content"])
         result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
         result["rejected"] = tokenizer.apply_chat_template(
-            [dummy_user_message, rejected],
+            [rejected],
             add_generation_prompt=False,
-            chat_template=chat_template_string,
+            chat_template=chat_template_str,
             tokenize=False,
         )
         rejected_strip_index = result["rejected"].find(rejected["content"])
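The `[dummy_user_message, chosen]` plus find/slice dance on the "-" side exists because many chat templates misrender or raise when a conversation opens with an assistant turn; rendering a throwaway user turn first and then cutting at the answer text yields just the completion. A standalone sketch of the trick, assuming any HF tokenizer that ships a chat template (the model name here is only an example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")

chosen = {"role": "assistant", "content": "Paris."}
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}

# Render dummy + answer so the template sees a user turn first...
full = tokenizer.apply_chat_template(
    [dummy_user_message, chosen], add_generation_prompt=False, tokenize=False
)
# ...then slice at the answer text to drop the dummy prefix.
completion = full[full.find(chosen["content"]):].rstrip()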

View File

@@ -0,0 +1,33 @@
+"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
+from typing import Any, Dict, Optional
+
+from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
+from axolotl.prompters import ShareGPTPrompterV2
+
+
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    conversation = (
+        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
+    )
+    strategy = InstructShareGPTPromptTokenizingStrategy(
+        # pylint: disable=duplicate-code
+        ShareGPTPrompterV2(
+            conversation=conversation,
+        ),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+    return strategy
+
+
+class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+    """
+    basic sharegpt strategy to grab conversations from the sample row
+    """
+
+    def get_conversation_thread(self, prompt):
+        return [
+            {"from": "human", "value": prompt["instruction"]},
+            {"from": "gpt", "value": prompt["output"]},
+        ]
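This added strategy simply reshapes instruction/output rows into a two-turn ShareGPT thread before the usual ShareGPT tokenization runs. On a sample row (invented) the transform produces:

row = {"instruction": "Summarize the article.", "output": "It argues X."}
# get_conversation_thread(row) returns:
thread = [
    {"from": "human", "value": "Summarize the article."},
    {"from": "gpt", "value": "It argues X."},
]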

View File

@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
 from typing import Generator, List, Sequence
 
 from axolotl.prompt_tokenizers import PromptTokenizingStrategy
-from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID
+from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
 
 
 @dataclass
@@ -75,7 +75,7 @@ class Llama2ChatConversation:
 class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
     """
-    Tokenizing strategy for Llama2 prompts.
+    Tokenizing strategy for ShareGPT prompts.
     adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
     """
@@ -191,7 +191,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
             conv.messages = []  # pylint: disable=R0801
             for j, sentence in enumerate(source):
                 role = roles[sentence["from"]]
-                assert role == conv.roles[j % 2], ALTERNATING_ASSERTION_FAILED_ROLE
+                assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
                 if sentence["value"]:
                     conv.append_message(role, sentence["value"])
             yield conv
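Only the constant's name changes here; both sides assert that turns strictly alternate between the two conversation roles. A self-contained sketch of the check (the role table and error message below are illustrative placeholders, not the library's):

ASSERTION_MESSAGE = "roles did not alternate between turns"  # placeholder text

roles = {"human": "USER", "gpt": "ASSISTANT"}
conv_roles = ("USER", "ASSISTANT")
source = [
    {"from": "human", "value": "Hi"},
    {"from": "gpt", "value": "Hello!"},
]
for j, sentence in enumerate(source):
    role = roles[sentence["from"]]
    # even-indexed turns must be USER, odd-indexed turns ASSISTANT
    assert role == conv_roles[j % 2], ASSERTION_MESSAGE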

View File

@@ -1,34 +0,0 @@
-"""Module to load message prompt strategies."""
-
-import importlib
-import inspect
-import logging
-
-LOG = logging.getLogger("axolotl.prompt_strategies.messages")
-
-
-def load(tokenizer, cfg, ds_cfg, processor=None):
-    try:
-        strategy = ds_cfg.get("input_transform", "chat")
-        # pylint: disable=duplicate-code
-        load_fn = "load"
-        if strategy.split(".")[-1].startswith("load_"):
-            load_fn = strategy.split(".")[-1]
-            strategy = ".".join(strategy.split(".")[:-1])
-        mod = importlib.import_module(
-            f".{strategy}", "axolotl.prompt_strategies.messages"
-        )
-        func = getattr(mod, load_fn)
-        load_kwargs = {}
-        sig = inspect.signature(func)
-        if "ds_cfg" in sig.parameters:
-            load_kwargs["ds_cfg"] = ds_cfg
-        if "processor" in sig.parameters:
-            load_kwargs["processor"] = processor
-        return func(tokenizer, cfg, **load_kwargs)
-    except ModuleNotFoundError:
-        return None
-    except Exception as exc:  # pylint: disable=broad-exception-caught
-        LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
-        raise exc
-    return None
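The removed loader's string parsing lets `input_transform` name either a module ("chat", which loads its `load` entrypoint) or a specific entrypoint ("chat.load_custom"). The resolution rule in isolation, with the example strategy names invented:

def resolve(strategy: str):
    load_fn = "load"
    if strategy.split(".")[-1].startswith("load_"):
        load_fn = strategy.split(".")[-1]
        strategy = ".".join(strategy.split(".")[:-1])
    return strategy, load_fn

assert resolve("chat") == ("chat", "load")
assert resolve("chat.load_custom") == ("chat", "load_custom")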

View File

@@ -1,84 +0,0 @@
-"""
-Chat dataset wrapping strategy for new internal messages representations
-"""
-from typing import Any, Callable, Dict, Optional
-
-from axolotl.core.datasets.chat import TokenizedChatDataset
-from axolotl.core.datasets.transforms.chat_builder import chat_message_transform_builder
-from axolotl.prompt_tokenizers import DatasetWrappingStrategy
-
-
-class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy):
-    """
-    Chat dataset wrapping strategy for new internal messages representations
-    """
-
-    def __init__(
-        self,
-        processor,
-        message_transform=None,
-        formatter=None,
-        **kwargs,  # pylint: disable=unused-argument
-    ):
-        """
-        :param processor: tokenizer or image processor
-        :param kwargs:
-        """
-        self.processor = processor
-        self.dataset = None
-        self.message_transform = message_transform
-        self.formatter = formatter
-
-    def wrap_dataset(
-        self,
-        dataset,
-        process_count: Optional[int] = None,
-        keep_in_memory: Optional[bool] = False,
-        **kwargs,  # pylint: disable=unused-argument
-    ):
-        self.dataset = TokenizedChatDataset(
-            dataset,
-            message_transform=self.message_transform,
-            model_transform=self.processor,
-            formatter=self.formatter,
-            process_count=process_count,
-            keep_in_memory=keep_in_memory,
-        )
-        return self.dataset
-
-
-def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
-    ds_cfg = ds_cfg or {}
-
-    field_messages = ds_cfg.get("field_messages")
-    message_field_role = ds_cfg.get("message_field_role")
-    message_field_content = ds_cfg.get("message_field_content")
-    message_field_training = ds_cfg.get("message_field_training")
-
-    builder_kwargs = {}
-    if field_messages:
-        builder_kwargs["conversations_field"] = field_messages
-    if message_field_role:
-        builder_kwargs["message_field_role"] = message_field_role
-    if message_field_content:
-        builder_kwargs["message_field_content"] = message_field_content
-    if message_field_training:
-        builder_kwargs["message_field_training"] = message_field_training
-
-    chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml"))
-    format_message = (
-        lambda x: x  # noqa E731  # pylint: disable=unnecessary-lambda-assignment
-    )
-    if chat_template == "chatml":
-        from axolotl.core.chat.format.chatml import format_message  # noqa F811
-    if chat_template.startswith("llama3"):
-        from axolotl.core.chat.format.llama3x import format_message  # noqa F811
-    message_transform: Callable = chat_message_transform_builder(
-        train_on_inputs=ds_cfg.get("train_on_inputs", False),
-        **builder_kwargs,
-    )
-    strategy = ChatMessageDatasetWrappingStrategy(
-        tokenizer, message_transform=message_transform, formatter=format_message
-    )
-    return strategy
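Note the pattern in the removed `load`: it only forwards dataset-config keys that were actually set, so `chat_message_transform_builder` keeps its own defaults for everything else. The same pattern in isolation (the config dict is invented):

ds_cfg = {"field_messages": "conversations"}  # example dataset config

builder_kwargs = {}
if ds_cfg.get("field_messages"):
    builder_kwargs["conversations_field"] = ds_cfg["field_messages"]
if ds_cfg.get("message_field_role"):
    builder_kwargs["message_field_role"] = ds_cfg["message_field_role"]

# builder_kwargs == {"conversations_field": "conversations"}; the builder's
# own default covers message_field_role.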

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel
 
 from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
 from axolotl.prompters import Prompter
-from axolotl.utils.chat_templates import get_chat_template_from_config
+from axolotl.utils.chat_templates import chat_templates
 
 
 class Message(BaseModel):
@@ -28,13 +28,18 @@ def load(
     """
     chatml transforms for datasets with system, input, chosen, rejected
     """
-    chat_template_string = get_chat_template_from_config(
-        cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
-    )
-    tokenizer.chat_template = chat_template_string
+    chat_template = chat_templates("chatml")
+    if ds_cfg and "chat_template" in ds_cfg:
+        chat_template = ds_cfg["chat_template"]
+        try:
+            chat_template = chat_templates(chat_template)
+        except ValueError:
+            pass
+    tokenizer.chat_template = chat_template
     return ORPOTokenizingStrategy(
-        ORPOPrompter(chat_template_string, tokenizer),
+        ORPOPrompter(chat_template, tokenizer),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
@@ -243,30 +248,28 @@ class ORPOPrompter(Prompter):
 
 def argilla(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
     dataset_parser = ORPODatasetParsingStrategy()
 
+    chat_template_str = chat_templates(cfg.chat_template)
+
     def transform_fn(sample, tokenizer=None):
         res = {}
 
-        chat_template_string = get_chat_template_from_config(
-            cfg=cfg, tokenizer=tokenizer
-        )
         res["prompt"] = tokenizer.apply_chat_template(
             [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
             add_generation_prompt=True,
-            chat_template=chat_template_string,
+            chat_template=chat_template_str,
             tokenize=False,
         )
         prompt_str_len = len(res["prompt"])
         res["chosen"] = tokenizer.apply_chat_template(
             [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
             add_generation_prompt=False,
-            chat_template=chat_template_string,
+            chat_template=chat_template_str,
             tokenize=False,
         )[prompt_str_len:]
         res["rejected"] = tokenizer.apply_chat_template(
             [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
             add_generation_prompt=False,
-            chat_template=chat_template_string,
+            chat_template=chat_template_str,
            tokenize=False,
         )[prompt_str_len:]
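Both sides of the `argilla` hunk rely on the same prefix-slicing trick: prompt, chosen, and rejected are rendered as full conversations sharing the same prompt prefix, and `prompt_str_len` cuts that prefix off so only the completions remain. In miniature (strings invented, ChatML-flavored):

prompt = "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"
full_chosen = prompt + "Hello!<|im_end|>\n"

prompt_str_len = len(prompt)
completion = full_chosen[prompt_str_len:]
assert completion == "Hello!<|im_end|>\n"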

View File

@@ -0,0 +1,220 @@
+"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
+
+import logging
+from typing import Any, Dict, Optional, Type
+
+from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
+
+from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
+from axolotl.prompters import ShareGPTPrompterV2
+from axolotl.utils.tokenization import (
+    chatml_to_conversation,
+    merge_consecutive_messages,
+)
+
+LOG = logging.getLogger("axolotl")
+
+
+def register_chatml_template(system_message=None):
+    system_message = system_message or "You are a helpful assistant."
+    register_conv_template(
+        Conversation(
+            name="chatml",
+            system_template="<|im_start|>system\n{system_message}",
+            system_message=system_message,
+            roles=("<|im_start|>user", "<|im_start|>assistant"),
+            sep_style=SeparatorStyle.CHATML,
+            sep="<|im_end|>",
+        )
+    )
+    register_conv_template(
+        Conversation(
+            name="chatml_glaive",
+            system_template="<|im_start|>system\n{system_message}",
+            system_message=system_message,
+            roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
+            sep_style=SeparatorStyle.CHATML,
+            sep="<|im_end|>",
+        )
+    )
+
+
+def register_llama3_template(system_message=None):
+    system_message = system_message or "You are a helpful assistant."
+    register_conv_template(
+        Conversation(
+            name="llama3",
+            system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
+            system_message=system_message,
+            roles=("user", "assistant"),
+            sep_style=SeparatorStyle.LLAMA3,
+            sep="",
+            stop_str="<|eot_id|>",
+            stop_token_ids=[128001, 128009],
+        )
+    )
+
+
+def build_loader(
+    tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
+    prompter_cls: Type["ShareGPTPrompterV2"],
+    default_conversation: Optional[str] = None,
+):
+    def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+        conversation = (
+            ds_cfg["conversation"]
+            if ds_cfg and "conversation" in ds_cfg
+            else default_conversation
+        )
+        field_human = (
+            ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
+        )
+        field_model = (
+            ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
+        )
+        roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
+        strategy = tokenization_strategy_cls(
+            prompter_cls(
+                conversation=conversation,
+                role_key_model=field_model,
+                role_key_human=field_human,
+                roles=roles,
+            ),
+            tokenizer,
+            cfg.train_on_inputs,
+            cfg.sequence_len,
+        )
+        if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
+            strategy.strict = ds_cfg["strict"]
+        if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
+            strategy.messages = ds_cfg["field_messages"]
+        return strategy
+
+    return _load
+
+
+class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+    """
+    basic sharegpt strategy to grab conversations from the sample row
+    """
+
+    _strict = False
+    _messages = "conversations"
+
+    @property
+    def strict(self):
+        return self._strict
+
+    @strict.setter
+    def strict(self, strict):
+        self._strict = strict
+
+    @property
+    def messages(self):
+        return self._messages
+
+    @messages.setter
+    def messages(self, messages):
+        self._messages = messages
+
+    def get_conversation_thread(self, prompt):
+        conversations = prompt[self.messages]
+        if self.strict:
+            return conversations
+        role_key = "from"
+        if "role" in conversations[0].keys():
+            role_key = "role"
+        value_key = "value"
+        if "text" in conversations[0].keys():
+            value_key = "text"
+        elif "content" in conversations[0].keys():
+            value_key = "content"
+        # remap roles - allow for assistant turn
+        role_map = {
+            "user": "human",
+            "human": "human",
+            "assistant": "gpt",
+            "gpt": "gpt",
+            "system": "system",
+        }
+        turns = [
+            {
+                "from": (
+                    role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
+                ),
+                "value": t[value_key],
+                "weight": 1
+                if "weight" not in t or t["weight"] is None
+                else t["weight"],
+            }
+            for t in conversations
+        ]
+        return turns
+
+
+class SimpleRoleShareGPTPromptTokenizingStrategy(
+    SimpleShareGPTPromptTokenizingStrategy
+):
+    """
+    basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
+    """
+
+    def get_conversation_thread(self, prompt):
+        conversations = prompt["conversations"]
+        # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
+        turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
+        return turns
+
+
+class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+    """
+    sharegpt strategy that remaps oasst data to sharegpt format
+    """
+
+    def get_conversation_thread(self, prompt):
+        conversations = prompt["conversations"]
+        # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
+        role_map = {"prompter": "human", "assistant": "gpt"}
+        turns = [
+            {"from": role_map[t["role"]], "value": t["text"]} for t in conversations
+        ]
+        return turns
+
+
+class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
+    """
+    sharegpt strategy that remaps ultrachat data to sharegpt format
+    """
+
+    def get_conversation_thread(self, prompt):
+        conversations = prompt["messages"]
+        role_map = {"user": "human", "assistant": "gpt"}
+        turns = [
+            {"from": role_map[t["role"]], "value": t["content"]} for t in conversations
+        ]
+        return turns
+
+
+class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
+    """
+    sharegpt strategy that remaps glaive data to sharegpt format
+    """
+
+    def get_conversation_thread(self, prompt):
+        conversation = chatml_to_conversation(prompt)
+        conversation = merge_consecutive_messages(conversation)
+        return conversation
+
+
+load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
+load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
+load_ultrachat = build_loader(
+    UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
+)
+load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
+load_glaive = build_loader(
+    GlaiveShareGPTPromptTokenizingStrategy,
+    ShareGPTPrompterV2,
+    default_conversation="chatml_glaive",
+)
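In the non-strict path, `get_conversation_thread` sniffs the key names from the first turn, remaps roles into the ShareGPT vocabulary, and backfills a training weight of 1. One worked row (invented) to make the normalization concrete:

row = {
    "conversations": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!", "weight": 0},
    ]
}
# "role"/"content" are detected as the key names, user->human and
# assistant->gpt are remapped, and the missing weight defaults to 1:
expected = [
    {"from": "human", "value": "Hi", "weight": 1},
    {"from": "gpt", "value": "Hello!", "weight": 0},
]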

Some files were not shown because too many files have changed in this diff.