Compare commits
1 Commits
pixtral_in
...
shampoo-lo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f1b4030cdd |
22
.github/workflows/base.yml
vendored
@@ -1,16 +1,6 @@
|
||||
name: ci-cd-base
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- 'Dockerfile-base'
|
||||
- '.github/workflows/base.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'Dockerfile-base'
|
||||
- '.github/workflows/base.yml'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
@@ -37,7 +27,7 @@ jobs:
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.10"
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "124"
|
||||
@@ -54,21 +44,19 @@ jobs:
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
- name: Docker metadata
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
uses: docker/metadata-action@v3
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-base
|
||||
axolotlai/axolotl-base
|
||||
images: winglian/axolotl-base
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
|
||||
2
.github/workflows/docs.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
- name: Set up Quarto
|
||||
uses: quarto-dev/quarto-actions/setup@v2
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: install dependencies
|
||||
|
||||
6
.github/workflows/lint.yml
vendored
@@ -15,9 +15,9 @@ jobs:
|
||||
name: pre-commit
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
|
||||
43
.github/workflows/main.yml
vendored
@@ -4,13 +4,11 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
tags:
|
||||
- "v*"
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build-axolotl:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -34,7 +32,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.5.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -44,12 +42,7 @@ jobs:
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl
|
||||
axolotlai/axolotl
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=pep440,pattern={{version}}
|
||||
images: winglian/axolotl
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to Docker Hub
|
||||
@@ -63,7 +56,7 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
|
||||
@@ -77,7 +70,7 @@ jobs:
|
||||
|
||||
build-axolotl-cloud:
|
||||
needs: build-axolotl
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -101,7 +94,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.5.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -111,25 +104,20 @@ jobs:
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-cloud
|
||||
axolotlai/axolotl-cloud
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=pep440,pattern={{version}}
|
||||
images: winglian/axolotl-cloud
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
file: ./docker/Dockerfile-cloud
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
@@ -140,7 +128,7 @@ jobs:
|
||||
|
||||
build-axolotl-cloud-no-tmux:
|
||||
needs: build-axolotl
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -158,25 +146,20 @@ jobs:
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-cloud-term
|
||||
axolotlai/axolotl-cloud-term
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=pep440,pattern={{version}}
|
||||
images: winglian/axolotl-cloud-term
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
file: ./docker/Dockerfile-cloud-no-tmux
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
|
||||
9
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -8,14 +8,9 @@ on:
|
||||
schedule:
|
||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
||||
|
||||
# Cancel jobs on the same ref if a new one is triggered
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
jobs:
|
||||
test-axolotl-multigpu:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -36,7 +31,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.5.0
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
|
||||
18
.github/workflows/nightlies.yml
vendored
@@ -7,7 +7,7 @@ on:
|
||||
|
||||
jobs:
|
||||
build-axolotl:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -31,7 +31,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.5.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -41,9 +41,7 @@ jobs:
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl
|
||||
axolotlai/axolotl
|
||||
images: winglian/axolotl
|
||||
tags: |
|
||||
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
||||
- name: Set up Docker Buildx
|
||||
@@ -71,7 +69,7 @@ jobs:
|
||||
|
||||
build-axolotl-cloud:
|
||||
needs: build-axolotl
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -95,7 +93,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.5.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -105,9 +103,7 @@ jobs:
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-cloud
|
||||
axolotlai/axolotl-cloud
|
||||
images: winglian/axolotl-cloud
|
||||
tags: |
|
||||
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
||||
- name: Login to Docker Hub
|
||||
@@ -116,7 +112,7 @@ jobs:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
|
||||
22
.github/workflows/pypi.yml
vendored
@@ -3,24 +3,12 @@ name: publish pypi
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
- '*'
|
||||
|
||||
jobs:
|
||||
setup_release:
|
||||
name: Create Release
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Create release
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: gh release create "$GITHUB_REF_NAME" # GITHUB_REF_NAME is the tag name in `on.push.tags` workflows
|
||||
pypi-publish:
|
||||
name: Upload release to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
needs: [setup_release]
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/axolotl
|
||||
@@ -28,10 +16,10 @@ jobs:
|
||||
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
@@ -49,9 +37,9 @@ jobs:
|
||||
run: |
|
||||
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
|
||||
|
||||
- name: Build a source dist
|
||||
- name: Build a binary wheel
|
||||
run: |
|
||||
python setup.py sdist
|
||||
python setup.py sdist bdist_wheel
|
||||
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
|
||||
16
.github/workflows/tests-nightly.yml
vendored
@@ -9,12 +9,12 @@ jobs:
|
||||
name: pre-commit
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
||||
@@ -25,15 +25,15 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.10", "3.11"]
|
||||
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
|
||||
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
@@ -48,14 +48,12 @@ jobs:
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging
|
||||
pip3 install -U -e .
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Run tests
|
||||
@@ -94,7 +92,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.5.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
|
||||
75
.github/workflows/tests.yml
vendored
@@ -8,33 +8,24 @@ on:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
- '.github/workflows/*.yml'
|
||||
- 'requirements-tests.txt'
|
||||
- 'cicd/cicd.sh'
|
||||
pull_request:
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
- '.github/workflows/*.yml'
|
||||
- 'requirements-tests.txt'
|
||||
- 'cicd/cicd.sh'
|
||||
workflow_dispatch:
|
||||
|
||||
# Cancel jobs on the same ref if a new one is triggered
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
name: pre-commit
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
||||
@@ -45,15 +36,15 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.10", "3.11"]
|
||||
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
|
||||
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
@@ -71,68 +62,22 @@ jobs:
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -n8 --ignore=tests/e2e/ tests/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
pytest-sdist:
|
||||
name: PyTest from Source Dist
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.4.1", "2.5.1"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging setuptools wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
python3 setup.py sdist
|
||||
pip3 install dist/axolotl*.tar.gz
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -n8 --ignore=tests/e2e/ tests/
|
||||
pytest --ignore=tests/e2e/ tests/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
docker-e2e-tests-1st:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 90
|
||||
needs: [pre-commit, pytest, pytest-sdist]
|
||||
needs: [pre-commit, pytest]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -187,7 +132,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.5.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
|
||||
3
.gitignore
vendored
@@ -182,6 +182,3 @@ submit.sh
|
||||
|
||||
typings/
|
||||
out/
|
||||
|
||||
# vim
|
||||
*.swp
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
include requirements.txt
|
||||
include README.md
|
||||
include LICENSE
|
||||
recursive-include axolotl *.py
|
||||
40
README.md
@@ -1,21 +1,8 @@
|
||||
<p align="center">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="image/axolotl_logo_digital_white.svg">
|
||||
<source media="(prefers-color-scheme: light)" srcset="image/axolotl_logo_digital_black.svg">
|
||||
<img alt="Axolotl" src="image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
|
||||
</picture>
|
||||
</p>
|
||||
# Axolotl
|
||||
|
||||
<p align="center">
|
||||
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
|
||||
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
|
||||
<img src="https://img.shields.io/github/stars/axolotl-ai-cloud/axolotl" alt="GitHub Repo stars">
|
||||
</p>
|
||||
<p align="center">
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
|
||||
</p>
|
||||

|
||||

|
||||

|
||||
|
||||
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
|
||||
|
||||
@@ -88,7 +75,7 @@ Features:
|
||||
<td>
|
||||
|
||||
<div align="center">
|
||||
<img src="image/axolotl_symbol_digital_white.svg" alt="axolotl" width="160">
|
||||
<img src="image/axolotl.png" alt="axolotl" width="160">
|
||||
<div>
|
||||
<p>
|
||||
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
|
||||
@@ -147,7 +134,7 @@ pip3 install -e '.[flash-attn,deepspeed]'
|
||||
### Usage
|
||||
```bash
|
||||
# preprocess datasets - optional but recommended
|
||||
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
|
||||
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
|
||||
|
||||
# finetune lora
|
||||
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||
@@ -172,7 +159,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
|
||||
#### Docker
|
||||
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest
|
||||
```
|
||||
|
||||
Or run on the current files for development:
|
||||
@@ -191,7 +178,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
|
||||
A more powerful Docker command to run would be this:
|
||||
|
||||
```bash
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-latest
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest
|
||||
```
|
||||
|
||||
It additionally:
|
||||
@@ -223,7 +210,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
||||
|
||||
#### Cloud GPU
|
||||
|
||||
For cloud GPU providers that support docker images, use [`axolotlai/axolotl-cloud:main-latest`](https://hub.docker.com/r/axolotlai/axolotl-cloud/tags)
|
||||
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
|
||||
|
||||
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
|
||||
@@ -332,7 +319,7 @@ Write a job description in YAML as below:
|
||||
# dstack.yaml
|
||||
type: task
|
||||
|
||||
image: axolotlai/axolotl-cloud:main-latest
|
||||
image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2
|
||||
|
||||
env:
|
||||
- HUGGING_FACE_HUB_TOKEN
|
||||
@@ -396,10 +383,11 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- typescript
|
||||
type: ... # unimplemented custom format
|
||||
|
||||
# chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template
|
||||
# fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template)
|
||||
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
- path: ...
|
||||
type: chat_template
|
||||
chat_template: chatml # defaults to tokenizer's chat_template
|
||||
type: sharegpt
|
||||
conversation: chatml # default: vicuna_v1.1
|
||||
|
||||
# local
|
||||
- path: data.jsonl # or json
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
|
||||
FROM winglian/axolotl-base:{{ BASE_TAG }}
|
||||
|
||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
||||
@@ -28,7 +28,6 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
@@ -37,9 +36,6 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
RUN python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
# So we can test the Docker image
|
||||
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
pytest -v --durations=10 -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
|
||||
@@ -10,7 +10,7 @@ import tempfile
|
||||
import jinja2
|
||||
import modal
|
||||
from jinja2 import select_autoescape
|
||||
from modal import App, Image
|
||||
from modal import Image, Stub
|
||||
|
||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||
|
||||
@@ -46,7 +46,7 @@ cicd_image = (
|
||||
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||
)
|
||||
|
||||
app = App("Axolotl CI/CD", secrets=[])
|
||||
stub = Stub("Axolotl CI/CD", secrets=[])
|
||||
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 2))
|
||||
@@ -61,7 +61,7 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
|
||||
|
||||
@app.function(
|
||||
@stub.function(
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=60 * 60,
|
||||
@@ -72,6 +72,6 @@ def cicd_pytest():
|
||||
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
@stub.local_entrypoint()
|
||||
def main():
|
||||
cicd_pytest.remote()
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
set -e
|
||||
|
||||
# only run one test at a time so as not to OOM the GPU
|
||||
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
|
||||
pytest -n1 /workspace/axolotl/tests/e2e/multigpu/
|
||||
|
||||
@@ -10,7 +10,7 @@ import tempfile
|
||||
import jinja2
|
||||
import modal
|
||||
from jinja2 import select_autoescape
|
||||
from modal import App, Image
|
||||
from modal import Image, Stub
|
||||
|
||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||
|
||||
@@ -40,7 +40,6 @@ with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
||||
cicd_image = (
|
||||
Image.from_dockerfile(
|
||||
pathlib.Path(temp_dir) / "Dockerfile",
|
||||
context_mount=None,
|
||||
force_build=True,
|
||||
gpu="A10G",
|
||||
)
|
||||
@@ -48,7 +47,7 @@ cicd_image = (
|
||||
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||
)
|
||||
|
||||
app = App("Axolotl CI/CD", secrets=[])
|
||||
stub = Stub("Axolotl CI/CD", secrets=[])
|
||||
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
||||
@@ -63,7 +62,7 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
|
||||
|
||||
@app.function(
|
||||
@stub.function(
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=60 * 60,
|
||||
@@ -74,6 +73,6 @@ def cicd_pytest():
|
||||
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
@stub.local_entrypoint()
|
||||
def main():
|
||||
cicd_pytest.remote()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Example config for debugging the chat_template prompt format
|
||||
# Example config for debugging the sharegpt prompt format
|
||||
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
ARG BASE_TAG=main-base
|
||||
FROM axolotlai/axolotl-base:$BASE_TAG
|
||||
FROM winglian/axolotl-base:$BASE_TAG
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
@@ -26,9 +26,6 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
RUN python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
# So we can test the Docker image
|
||||
RUN pip install pytest
|
||||
|
||||
|
||||
@@ -29,11 +29,13 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
|
||||
RUN git lfs install --skip-repo && \
|
||||
pip3 install awscli && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
|
||||
RUN if [ "$PYTHON_VERSION" != "2.5.1" ] ; then \
|
||||
pip3 install flash-attn==2.6.3; \
|
||||
fi
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
ARG BASE_TAG=main
|
||||
FROM axolotlai/axolotl:$BASE_TAG
|
||||
FROM winglian/axolotl:$BASE_TAG
|
||||
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
ARG BASE_TAG=main
|
||||
FROM axolotlai/axolotl:$BASE_TAG
|
||||
FROM winglian/axolotl:$BASE_TAG
|
||||
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
ARG BASE_TAG=main-base
|
||||
FROM axolotlai/axolotl-base:$BASE_TAG
|
||||
FROM winglian/axolotl-base:$BASE_TAG
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
|
||||
@@ -83,7 +83,7 @@ lora_on_cpu: true
|
||||
datasets:
|
||||
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
|
||||
data_files: # Optional[str] path to source data files
|
||||
@@ -91,7 +91,15 @@ datasets:
|
||||
name: # Optional[str] name of dataset configuration to load
|
||||
train_on_split: train # Optional[str] name of dataset split to load from
|
||||
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
|
||||
trust_remote_code: # Optional[bool] Trust remote code for untrusted source
|
||||
|
||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
field_human: # Optional[str]. Human key to use for conversation.
|
||||
field_model: # Optional[str]. Assistant key to use for conversation.
|
||||
# Add additional keys from your dataset as input or output roles
|
||||
roles:
|
||||
input: # Optional[List[str]]. These will be masked based on train_on_input
|
||||
output: # Optional[List[str]].
|
||||
|
||||
# Custom user instruction prompt
|
||||
- path: repo
|
||||
@@ -162,9 +170,6 @@ datasets:
|
||||
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
|
||||
shuffle_merged_datasets: true
|
||||
|
||||
Deduplicates datasets and test_datasets with identical entries.
|
||||
dataset_exact_deduplication: true
|
||||
|
||||
# A list of one or more datasets to eval the model with.
|
||||
# You can use either test_datasets, or val_set_size, but not both.
|
||||
test_datasets:
|
||||
@@ -178,8 +183,6 @@ test_datasets:
|
||||
|
||||
# use RL training: 'dpo', 'ipo', 'kto'
|
||||
rl:
|
||||
# whether to perform weighting if doing DPO training. Boolean.
|
||||
dpo_use_weighting:
|
||||
|
||||
# The name of the chat template to use for training, following values are supported:
|
||||
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
|
||||
@@ -409,7 +412,6 @@ lr_div_factor: # Learning rate div factor
|
||||
# - adamw_torch_fused
|
||||
# - adamw_torch_xla
|
||||
# - adamw_apex_fused
|
||||
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
||||
# - adafactor
|
||||
# - adamw_anyprecision
|
||||
# - sgd
|
||||
|
||||
@@ -6,8 +6,33 @@ order: 3
|
||||
|
||||
## sharegpt
|
||||
|
||||
IMPORTANT: ShareGPT is deprecated!. Please see `chat_template` section below.
|
||||
UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below.
|
||||
|
||||
conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
Note: `type: sharegpt` opens special configs:
|
||||
- `conversation`: enables conversions to many Conversation types. Refer to the 'name' [here](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py) for options.
|
||||
- `roles`: allows you to specify the roles for input and output. This is useful for datasets with custom roles such as `tool` etc to support masking.
|
||||
- `field_human`: specify the key to use instead of `human` in the conversation.
|
||||
- `field_model`: specify the key to use instead of `gpt` in the conversation.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
path: ...
|
||||
type: sharegpt
|
||||
|
||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
field_human: # Optional[str]. Human key to use for conversation.
|
||||
field_model: # Optional[str]. Assistant key to use for conversation.
|
||||
# Add additional keys from your dataset as input or output roles
|
||||
roles:
|
||||
input: # Optional[List[str]]. These will be masked based on train_on_input
|
||||
output: # Optional[List[str]].
|
||||
```
|
||||
|
||||
## pygmalion
|
||||
|
||||
@@ -15,6 +40,38 @@ IMPORTANT: ShareGPT is deprecated!. Please see `chat_template` section below.
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
## sharegpt.load_role
|
||||
|
||||
conversations where `role` is used instead of `from`
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
## sharegpt.load_guanaco
|
||||
|
||||
conversations where `from` is `prompter` `assistant` instead of default sharegpt
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
## sharegpt.load_ultrachat
|
||||
|
||||
conversations where the turns field is 'messages', human is 'user' and gpt is 'assistant'.
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"messages": [{"user": "...", "assistant": "..."}]}
|
||||
```
|
||||
|
||||
## sharegpt_jokes
|
||||
|
||||
creates a chat where bot is asked to tell a joke, then explain why the joke is funny
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
|
||||
```
|
||||
|
||||
|
||||
## chat_template
|
||||
|
||||
|
||||
@@ -185,7 +185,7 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
|
||||
|
||||
## Debugging With Docker
|
||||
|
||||
Using [official Axolotl Docker images](https://hub.docker.com/r/axolotlai/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
|
||||
Using [official Axolotl Docker images](https://hub.docker.com/r/winglian/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
|
||||
|
||||
### Setup
|
||||
|
||||
@@ -202,11 +202,11 @@ cd axolotl
|
||||
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
|
||||
|
||||
```bash
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
|
||||
>[!Tip]
|
||||
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
|
||||
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/winglian/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
|
||||
|
||||
You will now be in the container. Next, perform an editable install of Axolotl:
|
||||
|
||||
|
||||
@@ -11,10 +11,12 @@ standard industry baselines.
|
||||
|
||||
### Installation
|
||||
|
||||
The following will install the correct unsloth and extras from source.
|
||||
The following will install unsloth from source and downgrade xformers as unsloth is incompatible with the most up
|
||||
to date libraries.
|
||||
|
||||
```bash
|
||||
python scripts/unsloth_install.py | sh
|
||||
pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
|
||||
pip install --no-deps --force-reinstall xformers==0.0.26.post1
|
||||
```
|
||||
|
||||
### Using unsloth w Axolotl
|
||||
|
||||
@@ -2,15 +2,19 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "AKjdG7tbTb-n"
|
||||
},
|
||||
"source": [
|
||||
"## Setting up"
|
||||
"# Example notebook for running Axolotl on google colab"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "RcbNpOgWRcii"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
@@ -18,76 +22,82 @@
|
||||
"assert (torch.cuda.is_available()==True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "h3nLav8oTRA5"
|
||||
},
|
||||
"source": [
|
||||
"## Install Axolotl and dependencies"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "3c3yGAwnOIdi",
|
||||
"outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install axolotl[deepspeed]"
|
||||
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
|
||||
"!pip install flash-attn==\"2.5.0\"\n",
|
||||
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "BW2MFr7HTjub"
|
||||
},
|
||||
"source": [
|
||||
"## Hugging Face login (optional)"
|
||||
"## Create an yaml config file"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from huggingface_hub import notebook_login\n",
|
||||
"notebook_login()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Example configuration"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "9pkF2dSoQEUN"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import yaml\n",
|
||||
"\n",
|
||||
"# Your YAML string\n",
|
||||
"yaml_string = \"\"\"\n",
|
||||
"base_model: NousResearch/Meta-Llama-3.1-8B\n",
|
||||
"base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
|
||||
"model_type: LlamaForCausalLM\n",
|
||||
"tokenizer_type: LlamaTokenizer\n",
|
||||
"\n",
|
||||
"load_in_8bit: false\n",
|
||||
"load_in_4bit: true\n",
|
||||
"strict: false\n",
|
||||
"\n",
|
||||
"datasets:\n",
|
||||
" - path: tatsu-lab/alpaca\n",
|
||||
" - path: mhenrichsen/alpaca_2k_test\n",
|
||||
" type: alpaca\n",
|
||||
"dataset_prepared_path: last_run_prepared\n",
|
||||
"dataset_prepared_path:\n",
|
||||
"val_set_size: 0.05\n",
|
||||
"output_dir: ./outputs/lora-out\n",
|
||||
"\n",
|
||||
"sequence_len: 2048\n",
|
||||
"sample_packing: true\n",
|
||||
"eval_sample_packing: true\n",
|
||||
"pad_to_sequence_len: true\n",
|
||||
"output_dir: ./outputs/qlora-out\n",
|
||||
"\n",
|
||||
"adapter: qlora\n",
|
||||
"lora_model_dir:\n",
|
||||
"\n",
|
||||
"sequence_len: 4096\n",
|
||||
"sample_packing: true\n",
|
||||
"eval_sample_packing: false\n",
|
||||
"pad_to_sequence_len: true\n",
|
||||
"\n",
|
||||
"lora_r: 32\n",
|
||||
"lora_alpha: 16\n",
|
||||
"lora_dropout: 0.05\n",
|
||||
"lora_target_modules:\n",
|
||||
"lora_target_linear: true\n",
|
||||
"lora_fan_in_fan_out:\n",
|
||||
"lora_modules_to_save:\n",
|
||||
" - embed_tokens\n",
|
||||
" - lm_head\n",
|
||||
"\n",
|
||||
"wandb_project:\n",
|
||||
"wandb_entity:\n",
|
||||
@@ -95,12 +105,12 @@
|
||||
"wandb_name:\n",
|
||||
"wandb_log_model:\n",
|
||||
"\n",
|
||||
"gradient_accumulation_steps: 2\n",
|
||||
"micro_batch_size: 1\n",
|
||||
"num_epochs: 1\n",
|
||||
"optimizer: paged_adamw_8bit\n",
|
||||
"gradient_accumulation_steps: 4\n",
|
||||
"micro_batch_size: 2\n",
|
||||
"num_epochs: 4\n",
|
||||
"optimizer: paged_adamw_32bit\n",
|
||||
"lr_scheduler: cosine\n",
|
||||
"learning_rate: 2e-5\n",
|
||||
"learning_rate: 0.0002\n",
|
||||
"\n",
|
||||
"train_on_inputs: false\n",
|
||||
"group_by_length: false\n",
|
||||
@@ -111,15 +121,13 @@
|
||||
"gradient_checkpointing: true\n",
|
||||
"early_stopping_patience:\n",
|
||||
"resume_from_checkpoint:\n",
|
||||
"local_rank:\n",
|
||||
"logging_steps: 1\n",
|
||||
"xformers_attention:\n",
|
||||
"flash_attention: false\n",
|
||||
"sdp_attention: true\n",
|
||||
"flash_attention: true\n",
|
||||
"\n",
|
||||
"warmup_steps: 1\n",
|
||||
"max_steps: 25\n",
|
||||
"evals_per_epoch: 1\n",
|
||||
"eval_table_size:\n",
|
||||
"warmup_steps: 10\n",
|
||||
"evals_per_epoch: 4\n",
|
||||
"saves_per_epoch: 1\n",
|
||||
"debug:\n",
|
||||
"deepspeed:\n",
|
||||
@@ -127,9 +135,8 @@
|
||||
"fsdp:\n",
|
||||
"fsdp_config:\n",
|
||||
"special_tokens:\n",
|
||||
" pad_token: <|end_of_text|>\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"# Convert the YAML string to a Python dictionary\n",
|
||||
"yaml_dict = yaml.safe_load(yaml_string)\n",
|
||||
@@ -139,124 +146,31 @@
|
||||
"\n",
|
||||
"# Write the YAML file\n",
|
||||
"with open(file_path, 'w') as file:\n",
|
||||
" yaml.dump(yaml_dict, file)"
|
||||
" yaml.dump(yaml_dict, file)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "bidoj8YLTusD"
|
||||
},
|
||||
"source": [
|
||||
"Above we have a configuration file with base LLM model and datasets specified, among many other things. Axolotl can automatically detect whether the specified datasets are on HuggingFace repo or local machine.\n",
|
||||
"\n",
|
||||
"The Axolotl configuration options encompass model and dataset selection, data pre-processing, and training. Let's go through them line by line:\n",
|
||||
"\n",
|
||||
"* \"base model\": String value, specifies the underlying pre-trained LLM that will be used for finetuning\n",
|
||||
"\n",
|
||||
"Next we have options for model weights quantization. Quantization allows for reduction in occupied memory on GPUs.\n",
|
||||
"\n",
|
||||
"* \"load_in_8bit\": Boolean value, whether to quantize the model weights into 8-bit integer.\n",
|
||||
"\n",
|
||||
"* \"load_in_4bit\": Boolean value, whether to quantize the model weights into 4-bit integer.\n",
|
||||
"\n",
|
||||
"* \"strict\": Boolean value. If false, it allows for overriding established configuration options in the yaml file when executing in command-line interface.\n",
|
||||
"\n",
|
||||
"* \"datasets\": a list of dicts that contain path and type of data sets as well as other optional configurations where datasets are concerned. Supports multiple datasets.\n",
|
||||
"\n",
|
||||
"* \"val_set_size\": Either a float value less than one or an integer less than the total size of dataset. Sets the size of validation set from the whole dataset. If float, sets the proportion of the dataset assigned for validation. If integer, sets the direct size of validation set.\n",
|
||||
"\n",
|
||||
"* \"output_dir\": String value. Path of trained model.\n",
|
||||
"\n",
|
||||
"For data preprocessing:\n",
|
||||
"\n",
|
||||
"* \"sequence_len\": Integer. Specifies the maximum sequence length of the input. Typically 2048 or less.\n",
|
||||
"\n",
|
||||
"* \"pad_to_sequence_len\": Boolean. Padding input to maximum sequence length.\n",
|
||||
"\n",
|
||||
"* \"sample_packing\": Boolean. Specifies whether to use multi-packing with block diagonal attention.\n",
|
||||
"\n",
|
||||
"* \"special_tokens\": Python dict, optional. Allows users to specify the additional special tokens to be ignored by the tokenizer.\n",
|
||||
"\n",
|
||||
"For LoRA configuration and its hyperparamters:\n",
|
||||
"\n",
|
||||
"* \"adapter\": String. Either \"lora\" or \"qlora\", depending on user's choice.\n",
|
||||
"\n",
|
||||
"* \"lora_model_dir\": String, Optional. Path to directory that contains LoRA model, if there is already a trained LoRA model the user would like to use.\n",
|
||||
"\n",
|
||||
"* \"lora_r\": Integer. Refers to the rank of LoRA decomposition matrices. Higher value will reduce LoRA efficiency. Recommended to be set to 8.\n",
|
||||
"\n",
|
||||
"* \"lora_alpha\": Integer. Scale the weight matrices by $\\frac{\\text{lora_alpha}}{\\text{lora_r}}$Recommended to be fixed at 16.\n",
|
||||
"\n",
|
||||
"* \"lora_dropout\": Float that is 1 or less. The dropout probability of a lora layer.\n",
|
||||
"\n",
|
||||
"* \"lora_target_linear\": Boolean. If true, lora will target all linear modules in the transformers architecture.\n",
|
||||
"\n",
|
||||
"* \"lora_modules_to_save\": If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.\n",
|
||||
"\n",
|
||||
"See [LoRA](https://arxiv.org/abs/2106.09685) for detailed explanation of LoRA implementation.\n",
|
||||
"\n",
|
||||
"For the training configurations:\n",
|
||||
"\n",
|
||||
"* \"gradient_accumulation_steps\": Integer. The number of steps over which to accumulate gradient for batch training. E.g. if 2, backprop is performed every two steps.\n",
|
||||
"\n",
|
||||
"* \"micro_batch_size\": Integer. Batch size per gpu / gradient_accumulation_steps\n",
|
||||
"\n",
|
||||
"* \"num_epochs\": Integer. Number of epochs. One epoch is when training has looped over every batch in the whole data set once.\n",
|
||||
"\n",
|
||||
"* \"optimizer\": The optimizer to use for the training.\n",
|
||||
"\n",
|
||||
"* \"learning_rate\": The learning rate.\n",
|
||||
"\n",
|
||||
"* \"lr_scheduler\": The learning rate scheduler to use for adjusting learning rate during training.\n",
|
||||
"\n",
|
||||
"* \"train_on_inputs\": Boolean. Whether to ignore or include the user's prompt from the training labels.\n",
|
||||
"\n",
|
||||
"* \"group_by_length\": Boolean. Whether to group similarly sized data to minimize padding.\n",
|
||||
"\n",
|
||||
"* \"bf16\": Either \"auto\", \"true\", or \"false\". Whether to use CUDA bf16 floating point format. If set to \"auto\", will automatically apply bf16 should the gpu supports it.\n",
|
||||
"\n",
|
||||
"* \"fp16\": Optional. Specifies whether to use CUDA fp16. Automatically set to true if \"bf16\" is set to true. Otherwise false.\n",
|
||||
"\n",
|
||||
"* \"tf32\": Boolean. Whether to use CUDA tf32. Will override bf16.\n",
|
||||
"\n",
|
||||
"* \"gradient_checkpointing\": Boolean. Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing\n",
|
||||
"\n",
|
||||
"* \"gradient_checkpointing_kwargs\": Python Dict. Fed into the trainer.\n",
|
||||
"\n",
|
||||
"* \"logging_steps\": Integer. Log training information over every specified number of steps.\n",
|
||||
"\n",
|
||||
"* \"flash_attention\": Boolean. Whether to use the [flash attention](https://github.com/Dao-AILab/flash-attention) mechanism.\n",
|
||||
"\n",
|
||||
"* \"sdp_attention\": Boolean. Whether to use the Scaled Dot Product attention mechanism (the attention mechanism in the [original implementation](https://arxiv.org/abs/1706.03762) of transformers.)\n",
|
||||
"\n",
|
||||
"* \"warmup_steps\": Integer. The number of pre-training steps where a very low learning rate is used.\n",
|
||||
"\n",
|
||||
"* \"evals_per_epoch\": Integer. Number of evaluations to be performed within one training epoch.\n",
|
||||
"\n",
|
||||
"* \"saves_per_epoch\": Integer. Number of times the model is saved in one training epoch.\n",
|
||||
"\n",
|
||||
"* \"weight_decay\": Positive Float. Sets the \"strength\" of weight decay (i.e. setting the coefficient of L2 regularization)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The above is but a snippet aiming to get users familiarized with the types of streamlined configuration options axolotl provides. For a full list of configuration options, see [here](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Train the model"
|
||||
"## Launch the training"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "ydTI2Jk2RStU",
|
||||
"outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# By using the ! the comand will be executed as a bash command\n",
|
||||
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
|
||||
]
|
||||
},
|
||||
@@ -264,7 +178,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Predict with trained model"
|
||||
"## Play with inference"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -273,85 +187,36 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# By using the ! the comand will be executed as a bash command\n",
|
||||
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
|
||||
" --lora_model_dir=\"./outputs/lora-out\" --gradio"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Deeper Dive"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"It is also helpful to gain some familiarity over some of the core inner workings of axolotl"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Configuration Normalization"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Axolotl uses a custom Dict class, called ```DictDefault```\n",
|
||||
"to store configurations specified in the yaml configuration file (into a Python variable named ```cfg```). The definition for this custom Dict can be found in the [utils/dict.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/dict.py)\n",
|
||||
"\n",
|
||||
"```DictDefault``` is amended such that calling a missing key from it will result in a ```None``` return type. This is important because if some configuration options aren't specified by the user, the ```None``` type allows Axolotl to perform boolean operations to determine the default settings for missing configurations. For more examples on how this is done, check out [utils/config/__init__.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/__init__.py)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Loading Models, Tokenizers, and Trainer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we inspect [cli.train.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/cli/train.py), we will find that most of the heavy lifting were done by the function ```train()``` which is itself imported from [src/axolotl/train.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/train.py).\n",
|
||||
"\n",
|
||||
"```train()``` takes care of loading the appropriate tokenizer and pre-trained model through ```load_model()``` and ```load_tokenizer()``` from [src/axolotl/utils/models.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/models.py) respectively.\n",
|
||||
"\n",
|
||||
"```load_tokenizer()``` loads in the appropriate tokenizer given the desired model, as well as chat templates.\n",
|
||||
"\n",
|
||||
"```ModelLoader``` class follows after tokenizer has been selected. It will automatically discern the base model type, load in the desired model, as well as applying model-appropriate attention mechanism modifications (e.g. flash attention). Depending on which base model the user chooses in the configuration, ```ModelLoader``` will utilize the corresponding \"attention hijacking\" script. For example, if the user specified the base model to be ```NousResearch/Meta-Llama-3.1-8B```, which is of llama type, and set ```flash_attn``` to ```True```, ```ModelLoader``` will load in [llama_attn_hijack_flash.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/llama_attn_hijack_flash.py). For a list of supported attention hijacking, please refer to the directory [/src/axolotl/monkeypatch/](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/monkeypatch)\n",
|
||||
"\n",
|
||||
"Another important operation encompassed in ```train()``` is setting up the training that takes into account of user-specified traning configurations (e.g. num_epochs, optimizer) through the use of ```setup_trainer()``` from [/src/axolotl/utils/trainer.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/trainer.py), which in turn relies on modules from [/src/axolotl/core/trainer_builder.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/core/trainer_builder.py).\n",
|
||||
"```trainer_builder.py``` provides a list of trainer object options bespoke for the task type (Causal or Reinforcement learning ('dpo', 'ipo', 'kto') )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Monkey patch\n",
|
||||
"\n",
|
||||
"The [Monkey patch directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/monkeypatch) is where model architecture/optimization patching scripts are stored (these are modifications that are not implemented in the official releases, hence the name monkey patch). It includes attention jacking, ReLoRA, and unsloth optimization."
|
||||
" --qlora_model_dir=\"./qlora-out\" --gradio"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"gpuType": "T4",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"version": "3.9.6"
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
chat_template: llama3
|
||||
rl: dpo
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||
type: chat_template.default
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||
type: chat_template.default
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
dataset_exact_deduplication: true
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
s2_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
@@ -1,76 +0,0 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
dataset_exact_deduplication: true
|
||||
test_value: true
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_modules_to_save:
|
||||
- embed_tokens
|
||||
- lm_head
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
s2_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
@@ -1,63 +0,0 @@
|
||||
base_model: llava-hf/llava-1.5-7b-hf
|
||||
processor_type: AutoProcessor
|
||||
strict: false
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: llava
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
@@ -1,93 +0,0 @@
|
||||
#Note that we are switching from the regular chat template to chatml.
|
||||
#If you experience problems with the special tokens, training for more epochs can help.
|
||||
#After training, merge the model before inference otherwise you might
|
||||
#face problems with the special tokens.
|
||||
|
||||
base_model: mistralai/Mistral-7B-Instruct-v0.2
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
chat_template: chatml
|
||||
rl: dpo
|
||||
datasets:
|
||||
- path: olivermolenschot/alpaca_messages_dpo_test
|
||||
type: chat_template.default
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/dpo-qlora
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.2
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
lora_modules_to_save:
|
||||
- embed_tokens
|
||||
- lm_head
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 16
|
||||
num_epochs: 6
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0001
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: false
|
||||
s2_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<|im_start|>"
|
||||
eos_token: "<|im_end|>"
|
||||
@@ -1,65 +0,0 @@
|
||||
base_model: mistral-community/pixtral-12b
|
||||
processor_type: AutoProcessor
|
||||
strict: false
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: pixtral
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
@@ -1,63 +0,0 @@
|
||||
base_model: Qwen/Qwen2-VL-7B-Instruct
|
||||
processor_type: AutoProcessor
|
||||
strict: false
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: qwen2_vl
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
@@ -1,67 +0,0 @@
|
||||
base_model: Qwen/Qwen2.5-0.5B
|
||||
|
||||
strict: false
|
||||
|
||||
chat_template: qwen_25
|
||||
rl: dpo
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||
type: chat_template.default
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/dpo-out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
|
Before Width: | Height: | Size: 11 KiB |
|
Before Width: | Height: | Size: 24 KiB After Width: | Height: | Size: 11 KiB |
@@ -1,19 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 1113 283.5">
|
||||
<path fill="#141310" d="M435,234.3l-12.1-48.8h-54.4l-12.1,48.8h-24.7l48.2-185.1h31.6l47.9,185.1h-24.5ZM417.7,164.9l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
|
||||
<path fill="#141310" d="M568.2,234.3l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.6-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.4,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.2,49.3,71.7h-28.5Z"/>
|
||||
<path fill="#141310" d="M658.6,236.3c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM658.6,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path fill="#141310" d="M860.6,236.3c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM860.6,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path fill="#141310" d="M773.9,234c-18,0-32.6-14.6-32.6-32.6V48.8h24.1v152.6c0,4.7,3.8,8.5,8.5,8.5h16.8v24.1h-16.8Z"/>
|
||||
<path fill="#141310" d="M1036.2,234.3V81.4c0-4.7-3.8-8.5-8.5-8.5h-16.8v-24.1h16.8c18,0,32.6,14.6,32.6,32.6v152.9h-24.1Z"/>
|
||||
<path fill="#141310" d="M978.6,234.3c-18,0-32.6-14.6-32.6-32.6v-85.1h-20.3v-22.1h20.3v-45.3h24.1v45.3h30.2v22.1h-30.2v85.1c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
|
||||
<path fill="#141310" d="M51.5,49h12.2v-20.6h-12.2c-16,0-29,13-29,29v32.8h20.6v-32.8c0-4.7,3.8-8.4,8.4-8.4Z"/>
|
||||
<path fill="#141310" d="M92.8,49h12.2v-20.6h-12.2c-16,0-29,13-29,29v12.2h20.6v-12.2c0-4.7,3.8-8.4,8.4-8.4Z"/>
|
||||
<path fill="#141310" d="M249.3,57.4c0-16-13-29-29-29h-12.2v20.6h12.2c4.7,0,8.4,3.8,8.4,8.4v32.8h20.6v-32.8Z"/>
|
||||
<path fill="#141310" d="M187.4,90.2v-20.6h-103.1v20.6h-41.2v20.6h-20.6v41.2c0,11.4,9.2,20.6,20.6,20.6h185.5c11.4,0,20.6-9.2,20.6-20.6v-41.2h-20.6v-20.6h-41.2ZM166.8,141.7c0-5.7-4.6-10.3-10.3-10.3s-10.3,4.6-10.3,10.3v10.3h-20.6v-20.6c0-11.4,9.2-20.6,20.6-20.6s20.6,9.2,20.6,20.6v10.3ZM228.7,141.7c0-5.7-4.6-10.3-10.3-10.3s-10.3,4.6-10.3,10.3v10.3h-20.6v-20.6c0-11.4,9.2-20.6,20.6-20.6s20.6,9.2,20.6,20.6v10.3Z"/>
|
||||
<path fill="#141310" d="M208,57.4c0-16-13-29-29-29h-12.2v20.6h12.2c4.7,0,8.4,3.8,8.4,8.4v12.2h20.6v-12.2Z"/>
|
||||
<rect fill="#141310" x="22.5" y="234.5" width="41.2" height="20.6"/>
|
||||
<rect fill="#141310" x="84.3" y="234.5" width="164.9" height="20.6"/>
|
||||
<rect fill="#141310" x="208" y="193.3" width="41.2" height="20.6"/>
|
||||
<rect fill="#141310" x="22.5" y="193.3" width="164.9" height="20.6"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 3.2 KiB |
@@ -1,11 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 1113 283.5">
|
||||
<path fill="#fff" d="M462.9,234.2l-12.1-48.8h-54.4l-12.1,48.8h-24.7l48.2-185h31.6l47.9,185h-24.4ZM445.7,164.8l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
|
||||
<path fill="#fff" d="M596.1,234.2l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.5-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.3,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.1,49.3,71.6h-28.5Z"/>
|
||||
<path fill="#fff" d="M686.4,236.2c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.6,14.8-41.4,9.8-9.7,23.4-14.7,40.2-14.7s30.4,4.9,40.2,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.4-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM686.4,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.8v-36.7c0-10.5-2.8-18.5-8.2-23.8-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path fill="#fff" d="M888.3,236.2c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.6,14.8-41.4,9.8-9.7,23.4-14.7,40.2-14.7s30.4,4.9,40.2,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.4-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM888.3,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.8v-36.7c0-10.5-2.8-18.5-8.2-23.8-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path fill="#fff" d="M801.7,234c-18,0-32.6-14.6-32.6-32.6V48.8h24.1v152.5c0,4.7,3.8,8.5,8.5,8.5h16.7v24.1h-16.7Z"/>
|
||||
<path fill="#fff" d="M1063.8,234.2V81.4c0-4.7-3.8-8.5-8.5-8.5h-16.7v-24.1h16.7c18,0,32.6,14.6,32.6,32.6v152.8h-24.1Z"/>
|
||||
<path fill="#fff" d="M1006.2,234.2c-18,0-32.6-14.6-32.6-32.6v-85h-20.3v-22.1h20.3v-45.2h24.1v45.2h30.2v22.1h-30.2v85c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
|
||||
<path fill="#fff" d="M160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM277.3,57.4c0-23.8-19.3-43.1-43.1-43.1h-12.2c-3.9,0-7.6,1.6-10.2,4.4-5.9-2.9-12.3-4.4-18.9-4.4h-12.2c-7.7,0-14.1,6.3-14.1,14.1v20.6c0,2.4.6,4.6,1.6,6.6h-37c1-2,1.6-4.2,1.6-6.6v-20.6c0-7.7-6.3-14.1-14.1-14.1h-12.2c-6.5,0-13,1.5-18.9,4.4-2.6-2.8-6.3-4.4-10.2-4.4h-12.2c-23.8,0-43.1,19.3-43.1,43.1v32.8c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v41.2c0,11,5.2,20.8,13.2,27.2-7.3.4-13.2,6.6-13.2,14v20.6c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v20.6c0,7.7,6.3,14.1,14.1,14.1h41.2c4.1,0,7.7-1.7,10.3-4.5,2.6,2.8,6.2,4.5,10.3,4.5h164.9c7.7,0,14.1-6.3,14.1-14.1v-20.6c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-20.6c0-7.5-5.8-13.6-13.2-14,8-6.4,13.2-16.2,13.2-27.2v-41.2c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-32.8ZM77.8,255.1h-41.2v-20.6h41.2v20.6ZM36.5,213.9v-20.6h164.9v20.6H36.5ZM263.3,255.1H98.4v-20.6h164.9v20.6ZM263.3,213.9h-41.2v-20.6h41.2v20.6ZM263.3,90.2h-20.6v20.6h20.6v41.2c0,11.4-9.2,20.6-20.6,20.6H57.2c-11.4,0-20.6-9.2-20.6-20.6v-41.2h20.6v-20.6h-20.6v-32.8c0-16,13-29,29-29h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v32.8h41.2v-20.6h-20.6v-12.2c0-16,13-29,29-29h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v12.2h103.1v-12.2c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16,0,29,13,29,29v12.2h-20.6v20.6h41.2v-32.8c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16,0,29,13,29,29v32.8ZM201.4,152h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6s-20.6,9.2-20.6,20.6v20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6Z"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 6.6 KiB |
@@ -1,26 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 283.5 283.5">
|
||||
<defs>
|
||||
<style>
|
||||
.cls-1 {
|
||||
fill: #141310;
|
||||
}
|
||||
</style>
|
||||
</defs>
|
||||
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
|
||||
<g>
|
||||
<g id="Layer_1">
|
||||
<g>
|
||||
<path class="cls-1" d="M46.9,37.4h13.7V14.2h-13.7c-18,0-32.7,14.6-32.7,32.7v36.9h23.2v-36.9c0-5.2,4.2-9.5,9.5-9.5Z"/>
|
||||
<path class="cls-1" d="M93.2,37.4h13.7V14.2h-13.7c-18,0-32.7,14.6-32.7,32.7v13.7h23.2v-13.7c0-5.2,4.2-9.5,9.5-9.5Z"/>
|
||||
<path class="cls-1" d="M269.3,46.9c0-18-14.6-32.7-32.7-32.7h-13.7v23.2h13.7c5.2,0,9.5,4.2,9.5,9.5v36.9h23.2v-36.9Z"/>
|
||||
<path class="cls-1" d="M199.7,83.8v-23.2h-116v23.2h-46.4v23.2H14.2v46.4c0,12.8,10.4,23.2,23.2,23.2h208.7c12.8,0,23.2-10.4,23.2-23.2v-46.4h-23.2v-23.2h-46.4ZM176.5,141.7c0-6.4-5.2-11.6-11.6-11.6s-11.6,5.2-11.6,11.6v11.6h-23.2v-23.2c0-12.8,10.4-23.2,23.2-23.2s23.2,10.4,23.2,23.2v11.6ZM246.1,141.7c0-6.4-5.2-11.6-11.6-11.6s-11.6,5.2-11.6,11.6v11.6h-23.2v-23.2c0-12.8,10.4-23.2,23.2-23.2s23.2,10.4,23.2,23.2v11.6Z"/>
|
||||
<path class="cls-1" d="M222.9,46.9c0-18-14.6-32.7-32.7-32.7h-13.7v23.2h13.7c5.2,0,9.5,4.2,9.5,9.5v13.7h23.2v-13.7Z"/>
|
||||
<rect class="cls-1" x="14.2" y="246.1" width="46.4" height="23.2"/>
|
||||
<rect class="cls-1" x="83.8" y="246.1" width="185.5" height="23.2"/>
|
||||
<rect class="cls-1" x="222.9" y="199.7" width="46.4" height="23.2"/>
|
||||
<rect class="cls-1" x="14.2" y="199.7" width="185.5" height="23.2"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 1.6 KiB |
@@ -1,16 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 283.5 283.5">
|
||||
<defs>
|
||||
<style>
|
||||
.cls-1 {
|
||||
fill: #fff;
|
||||
}
|
||||
</style>
|
||||
</defs>
|
||||
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
|
||||
<g>
|
||||
<g id="Layer_1">
|
||||
<path class="cls-1" d="M152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM269.3,57.3c0-23.8-19.4-43.1-43.1-43.1h-12.2c-3.9,0-7.6,1.6-10.2,4.4-5.9-2.9-12.3-4.4-18.9-4.4h-12.2c-7.8,0-14.1,6.3-14.1,14.1v20.6c0,2.4.6,4.6,1.6,6.6h-37c1-2,1.6-4.2,1.6-6.6v-20.6c0-7.8-6.3-14.1-14.1-14.1h-12.2c-6.6,0-13,1.5-18.9,4.4-2.6-2.8-6.3-4.4-10.2-4.4h-12.2c-23.8,0-43.1,19.4-43.1,43.1v32.8c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v41.3c0,11,5.2,20.9,13.2,27.2-7.4.4-13.2,6.6-13.2,14v20.6c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v20.6c0,7.8,6.3,14.1,14.1,14.1h41.3c4.1,0,7.7-1.7,10.3-4.5,2.6,2.8,6.2,4.5,10.3,4.5h165.1c7.8,0,14.1-6.3,14.1-14.1v-20.6c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-20.6c0-7.5-5.9-13.6-13.2-14,8-6.4,13.2-16.2,13.2-27.2v-41.3c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-32.8ZM69.5,255.2H28.2v-20.6h41.3v20.6ZM28.2,214v-20.6h165.1v20.6H28.2ZM255.2,255.2H90.1v-20.6h165.1v20.6ZM255.2,214h-41.3v-20.6h41.3v20.6ZM255.2,90.1h-20.6v20.6h20.6v41.3c0,11.4-9.2,20.6-20.6,20.6H48.9c-11.4,0-20.6-9.2-20.6-20.6v-41.3h20.6v-20.6h-20.6v-32.8c0-16.1,13-29.1,29.1-29.1h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v32.8h41.3v-20.6h-20.6v-12.2c0-16.1,13-29.1,29.1-29.1h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v12.2h103.2v-12.2c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16.1,0,29.1,13,29.1,29.1v12.2h-20.6v20.6h41.3v-32.8c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16.1,0,29.1,13,29.1,29.1v32.8ZM193.3,152h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6s-20.6,9.2-20.6,20.6v20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6Z"/>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 5.0 KiB |
@@ -1,17 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 765.4 212.6">
|
||||
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
|
||||
<g>
|
||||
<g id="Layer_1">
|
||||
<g>
|
||||
<path d="M121.6,198.1l-12.1-48.8h-54.4l-12.1,48.8h-24.7L66.6,12.9h31.6l47.9,185.1h-24.5ZM104.4,128.6l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
|
||||
<path d="M254.9,198.1l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.6-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.4,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.2,49.3,71.7h-28.5Z"/>
|
||||
<path d="M345.2,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM345.2,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path d="M547.3,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM547.3,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path d="M460.6,197.8c-18,0-32.6-14.6-32.6-32.6V12.5h24.1v152.6c0,4.7,3.8,8.5,8.5,8.5h16.8v24.1h-16.8Z"/>
|
||||
<path d="M722.8,198.1V45.2c0-4.7-3.8-8.5-8.5-8.5h-16.8V12.5h16.8c18,0,32.6,14.6,32.6,32.6v152.9h-24.1Z"/>
|
||||
<path d="M665.2,198.1c-18,0-32.6-14.6-32.6-32.6v-85.1h-20.3v-22.1h20.3V12.9h24.1v45.3h30.2v22.1h-30.2v85.1c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 2.1 KiB |
@@ -1,24 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 765.4 212.6">
|
||||
<defs>
|
||||
<style>
|
||||
.cls-1 {
|
||||
fill: #fff;
|
||||
}
|
||||
</style>
|
||||
</defs>
|
||||
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
|
||||
<g>
|
||||
<g id="Layer_1">
|
||||
<g>
|
||||
<path class="cls-1" d="M121.6,198.1l-12.1-48.8h-54.4l-12.1,48.8h-24.7L66.6,12.9h31.6l47.9,185.1h-24.5ZM104.4,128.6l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
|
||||
<path class="cls-1" d="M254.9,198.1l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.6-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.4,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.2,49.3,71.7h-28.5Z"/>
|
||||
<path class="cls-1" d="M345.2,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM345.2,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path class="cls-1" d="M547.3,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM547.3,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path class="cls-1" d="M460.6,197.8c-18,0-32.6-14.6-32.6-32.6V12.5h24.1v152.6c0,4.7,3.8,8.5,8.5,8.5h16.8v24.1h-16.8Z"/>
|
||||
<path class="cls-1" d="M722.8,198.1V45.2c0-4.7-3.8-8.5-8.5-8.5h-16.8V12.5h16.8c18,0,32.6,14.6,32.6,32.6v152.9h-24.1Z"/>
|
||||
<path class="cls-1" d="M665.2,198.1c-18,0-32.6-14.6-32.6-32.6v-85.1h-20.3v-22.1h20.3V12.9h24.1v45.3h30.2v22.1h-30.2v85.1c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 2.3 KiB |
@@ -2,3 +2,4 @@ pre-commit
|
||||
black
|
||||
mypy
|
||||
types-requests
|
||||
tbparse
|
||||
|
||||
@@ -1,5 +1,2 @@
|
||||
pytest
|
||||
pytest-xdist
|
||||
pytest-retry
|
||||
pytest-sugar
|
||||
tbparse
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
packaging==23.2
|
||||
peft==0.13.2
|
||||
transformers==4.46.3
|
||||
transformers==4.46.1
|
||||
tokenizers>=0.20.1
|
||||
bitsandbytes==0.44.1
|
||||
accelerate==1.1.0
|
||||
datasets==3.1.0
|
||||
deepspeed==0.15.4
|
||||
datasets==3.0.1
|
||||
deepspeed==0.15.3
|
||||
pydantic==2.6.3
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
requests
|
||||
flash-attn==2.7.0.post2
|
||||
flash-attn==2.6.3
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
@@ -26,14 +26,15 @@ numpy>=1.24.4,<=2.0.1
|
||||
evaluate==0.4.1
|
||||
scipy
|
||||
scikit-learn==1.4.2
|
||||
nvidia-ml-py==12.560.30
|
||||
pynvml
|
||||
art
|
||||
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
python-dotenv==1.0.1
|
||||
autoawq==0.2.7.post2
|
||||
autoawq>=0.2.5
|
||||
triton>=2.3.0
|
||||
liger-kernel==0.4.2
|
||||
liger-kernel==0.4.0
|
||||
|
||||
mamba-ssm==1.2.0.post1
|
||||
|
||||
@@ -42,7 +43,7 @@ s3fs>=2024.5.0
|
||||
gcsfs>=2024.5.0
|
||||
# adlfs
|
||||
|
||||
trl==0.12.0
|
||||
trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
|
||||
zstandard==0.22.0
|
||||
fastcore
|
||||
|
||||
@@ -53,4 +54,3 @@ immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.5.0
|
||||
schedulefree==1.3.0
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
# Export specific ENV variables to /etc/rp_environment
|
||||
echo "Exporting environment variables..."
|
||||
printenv | grep -E '^HF_|^BNB_|^CUDA_|^NCCL_|^NV|^RUNPOD_|^PATH=|^_=' | sed 's/^\([^=]*\)=\(.*\)$/export \1="\2"/' | grep -v 'printenv' >> /etc/rp_environment
|
||||
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
|
||||
echo 'source /etc/rp_environment' >> ~/.bashrc
|
||||
|
||||
add_keys_to_authorized() {
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Script to output the correct installation command for cut-cross-entropy."""
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError as exc:
|
||||
raise ImportError("Install torch via `pip install torch`") from exc
|
||||
from packaging.version import Version as V
|
||||
|
||||
v = V(torch.__version__)
|
||||
|
||||
# no cut-cross-entropy support for torch < 2.4.0
|
||||
if v < V("2.4.0"):
|
||||
print("")
|
||||
sys.exit(0)
|
||||
|
||||
cce_spec = importlib.util.find_spec("cut_cross_entropy")
|
||||
cce_spec_transformers = importlib.util.find_spec("cut_cross_entropy.transformers")
|
||||
|
||||
UNINSTALL_PREFIX = ""
|
||||
if cce_spec and not cce_spec_transformers:
|
||||
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
|
||||
)
|
||||
@@ -1,36 +0,0 @@
|
||||
# noqa
|
||||
# pylint: skip-file
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError("Install torch via `pip install torch`")
|
||||
from packaging.version import Version as V
|
||||
|
||||
v = V(torch.__version__)
|
||||
cuda = str(torch.version.cuda)
|
||||
try:
|
||||
is_ampere = torch.cuda.get_device_capability()[0] >= 8
|
||||
except RuntimeError:
|
||||
is_ampere = False
|
||||
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
|
||||
raise RuntimeError(f"CUDA = {cuda} not supported!")
|
||||
if v <= V("2.1.0"):
|
||||
raise RuntimeError(f"Torch = {v} too old!")
|
||||
elif v <= V("2.1.1"):
|
||||
x = "cu{}{}-torch211"
|
||||
elif v <= V("2.1.2"):
|
||||
x = "cu{}{}-torch212"
|
||||
elif v < V("2.3.0"):
|
||||
x = "cu{}{}-torch220"
|
||||
elif v < V("2.4.0"):
|
||||
x = "cu{}{}-torch230"
|
||||
elif v < V("2.5.0"):
|
||||
x = "cu{}{}-torch240"
|
||||
elif v < V("2.6.0"):
|
||||
x = "cu{}{}-torch250"
|
||||
else:
|
||||
raise RuntimeError(f"Torch = {v} too new!")
|
||||
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
||||
print(
|
||||
f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
|
||||
)
|
||||
20
setup.py
@@ -39,10 +39,7 @@ def parse_requirements():
|
||||
else:
|
||||
# detect the version of torch already installed
|
||||
# and set it so dependencies don't clobber the torch version
|
||||
try:
|
||||
torch_version = version("torch")
|
||||
except PackageNotFoundError:
|
||||
torch_version = "2.5.1"
|
||||
torch_version = version("torch")
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
|
||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||
@@ -57,10 +54,6 @@ def parse_requirements():
|
||||
|
||||
if (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
_install_requires.append("xformers==0.0.28.post2")
|
||||
else:
|
||||
_install_requires.append("xformers==0.0.28.post3")
|
||||
_install_requires.pop(_install_requires.index(autoawq_version))
|
||||
elif (major, minor) >= (2, 4):
|
||||
if patch == 0:
|
||||
@@ -96,19 +89,22 @@ install_requires, dependency_links = parse_requirements()
|
||||
|
||||
setup(
|
||||
name="axolotl",
|
||||
version="0.5.2",
|
||||
version="0.4.1",
|
||||
description="LLM Trainer",
|
||||
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
packages=find_packages(),
|
||||
install_requires=install_requires,
|
||||
dependency_links=dependency_links,
|
||||
extras_require={
|
||||
"flash-attn": [
|
||||
"flash-attn==2.7.0.post2",
|
||||
"flash-attn==2.6.3",
|
||||
],
|
||||
"fused-dense-lib": [
|
||||
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.15.4",
|
||||
"deepspeed==0.14.4",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
|
||||
@@ -27,17 +27,14 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils.import_utils import _is_package_available
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils.chat_templates import (
|
||||
get_chat_template,
|
||||
get_chat_template_from_config,
|
||||
)
|
||||
from axolotl.utils.chat_templates import get_chat_template
|
||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
from axolotl.utils.config import (
|
||||
normalize_cfg_datasets,
|
||||
normalize_config,
|
||||
prepare_plugins,
|
||||
validate_config,
|
||||
)
|
||||
from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
|
||||
@@ -100,8 +97,8 @@ def print_dep_versions():
|
||||
print("*" * 40)
|
||||
print("**** Axolotl Dependency Versions *****")
|
||||
for pkg in packages:
|
||||
pkg_version = _is_package_available(pkg, return_version=True)
|
||||
print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
|
||||
version = _is_package_available(pkg, return_version=True)
|
||||
print(f"{pkg: >{max_len}}: {version[1]: <15}")
|
||||
print("*" * 40)
|
||||
|
||||
|
||||
@@ -139,7 +136,7 @@ def check_remote_config(config: Union[str, Path]):
|
||||
with open(output_path, "wb") as file:
|
||||
file.write(content)
|
||||
LOG.info(
|
||||
f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n"
|
||||
f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
|
||||
)
|
||||
return output_path
|
||||
|
||||
@@ -193,19 +190,18 @@ def do_inference(
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
prompter = cli_args.prompter
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
chat_template_str = None
|
||||
if prompter:
|
||||
prompter_module = getattr(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
elif cfg.chat_template:
|
||||
chat_template_str = get_chat_template(cfg.chat_template)
|
||||
elif cfg.datasets[0].type == "chat_template":
|
||||
chat_template_str = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
||||
)
|
||||
|
||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
@@ -215,31 +211,13 @@ def do_inference(
|
||||
instruction = get_multi_line_input()
|
||||
if not instruction:
|
||||
return
|
||||
|
||||
if prompter_module:
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
else:
|
||||
prompt = instruction.strip()
|
||||
|
||||
if chat_template_str:
|
||||
batch = tokenizer.apply_chat_template(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
return_tensors="pt",
|
||||
add_special_tokens=True,
|
||||
add_generation_prompt=True,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
else:
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
print("=" * 40)
|
||||
model.eval()
|
||||
@@ -279,6 +257,13 @@ def do_inference_gradio(
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
prompter = cli_args.prompter
|
||||
# default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
default_tokens: Dict[str, str] = {}
|
||||
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
chat_template_str = None
|
||||
@@ -426,6 +411,11 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
|
||||
cfg.axolotl_config_path = config
|
||||
|
||||
if cfg.get("plugins"):
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
for plugin_name in cfg["plugins"]:
|
||||
plugin_manager.register(plugin_name)
|
||||
|
||||
try:
|
||||
device_props = torch.cuda.get_device_properties("cuda")
|
||||
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
||||
@@ -439,13 +429,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||
"compute_capability": gpu_version,
|
||||
},
|
||||
env_capabilities={
|
||||
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
|
||||
},
|
||||
)
|
||||
|
||||
prepare_plugins(cfg)
|
||||
|
||||
prepare_optim_env(cfg)
|
||||
|
||||
prepare_opinionated_env(cfg)
|
||||
|
||||
@@ -19,7 +19,7 @@ from axolotl.common.cli import TrainerCliArgs
|
||||
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, inference=True, **kwargs)
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parsed_cfg.sample_packing = False
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
|
||||
@@ -23,6 +23,10 @@ from axolotl.cli import (
|
||||
)
|
||||
from axolotl.common.cli import PreprocessCliArgs
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.prompt_strategies.sharegpt import (
|
||||
register_chatml_template,
|
||||
register_llama3_template,
|
||||
)
|
||||
from axolotl.utils.trainer import disable_datasets_caching
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||
@@ -40,6 +44,23 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
return_remaining_strings=True
|
||||
)
|
||||
|
||||
if parsed_cfg.chat_template == "chatml":
|
||||
if parsed_cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||
)
|
||||
register_chatml_template(parsed_cfg.default_system_message)
|
||||
else:
|
||||
register_chatml_template()
|
||||
elif parsed_cfg.chat_template == "llama3":
|
||||
if parsed_cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||
)
|
||||
register_llama3_template(parsed_cfg.default_system_message)
|
||||
else:
|
||||
register_llama3_template()
|
||||
|
||||
if not parsed_cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
Fore.RED
|
||||
|
||||
@@ -19,6 +19,10 @@ from axolotl.cli import (
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.prompt_strategies.sharegpt import (
|
||||
register_chatml_template,
|
||||
register_llama3_template,
|
||||
)
|
||||
from axolotl.train import train
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.train")
|
||||
@@ -38,6 +42,21 @@ def do_train(cfg, cli_args) -> None:
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
if cfg.chat_template == "chatml" and cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"ChatML set. Adding default system message: {cfg.default_system_message}"
|
||||
)
|
||||
register_chatml_template(cfg.default_system_message)
|
||||
else:
|
||||
register_chatml_template()
|
||||
|
||||
if cfg.chat_template == "llama3" and cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
|
||||
)
|
||||
register_llama3_template(cfg.default_system_message)
|
||||
else:
|
||||
register_llama3_template()
|
||||
|
||||
if cfg.rl: # and cfg.rl != "orpo":
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -107,22 +107,6 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||
return kwargs
|
||||
|
||||
|
||||
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
||||
if isinstance(dataset_tags, str):
|
||||
dataset_tags = [dataset_tags]
|
||||
|
||||
if (dataset_tags is not None) and (kwargs is not None):
|
||||
if "dataset_tags" not in kwargs:
|
||||
kwargs["dataset_tags"] = dataset_tags
|
||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
||||
kwargs["dataset_tags"].extend(dataset_tags)
|
||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
||||
dataset_tags.append(kwargs["dataset_tags"])
|
||||
kwargs["dataset_tags"] = dataset_tags
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
"""
|
||||
@@ -236,14 +220,6 @@ class AxolotlTrainingMixins:
|
||||
default=1e-6,
|
||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||
)
|
||||
embedding_lr_scale: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||
)
|
||||
embedding_lr: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||
)
|
||||
qlora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether this is a qlora training"},
|
||||
@@ -410,7 +386,7 @@ class SchedulerMixin(Trainer):
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
else:
|
||||
if use_cosine_quadratic:
|
||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||
@@ -434,12 +410,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
*_args,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
dataset_tags=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.bench_data_collator = bench_data_collator
|
||||
self.eval_data_collator = eval_data_collator
|
||||
self.dataset_tags = dataset_tags
|
||||
super().__init__(*_args, **kwargs)
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
@@ -461,75 +435,38 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
def create_optimizer(self):
|
||||
if (
|
||||
self.args.loraplus_lr_ratio is None
|
||||
and self.args.embedding_lr_scale is None
|
||||
and self.args.embedding_lr is None
|
||||
and self.args.alternate_optimizer
|
||||
not in [
|
||||
"optimi_adamw",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
]
|
||||
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
|
||||
):
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||
params = {
|
||||
"to_weight_decay": {}, # LayerNorm and bias
|
||||
"embeddings": {}, # lm_head, embed_tokens,
|
||||
"no_weight_decay": {},
|
||||
}
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in opt_model.named_parameters()
|
||||
if (n in decay_parameters and p.requires_grad)
|
||||
],
|
||||
"weight_decay": self.args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in opt_model.named_parameters()
|
||||
if (n not in decay_parameters and p.requires_grad)
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||
self.args,
|
||||
opt_model,
|
||||
)
|
||||
|
||||
for name, param in opt_model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
if name.endswith("modules_to_save.default.weight") or any(
|
||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||
):
|
||||
params["embeddings"][name] = param
|
||||
elif name in decay_parameters:
|
||||
params["to_weight_decay"][name] = param
|
||||
else:
|
||||
params["no_weight_decay"][name] = param
|
||||
optimizer_grouped_parameters = []
|
||||
if params["to_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["to_weight_decay"].values()),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
if params["embeddings"]:
|
||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||
if self.args.embedding_lr_scale:
|
||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||
elif self.args.embedding_lr:
|
||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["embeddings"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": lr,
|
||||
}
|
||||
)
|
||||
if params["no_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["no_weight_decay"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
|
||||
if self.args.loraplus_lr_ratio is not None:
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
loraplus_lr_embedding = getattr(
|
||||
@@ -542,13 +479,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
elif (
|
||||
self.args.embedding_lr_scale is not None
|
||||
or self.args.embedding_lr is not None
|
||||
):
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "optimi_adamw":
|
||||
from optimi import AdamW
|
||||
|
||||
@@ -575,16 +505,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "adopt_adamw":
|
||||
from axolotl.utils.optimizers.adopt import ADOPT
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
ADOPT(
|
||||
optimizer_grouped_parameters,
|
||||
decouple=True,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
@@ -937,9 +857,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||
)
|
||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
@@ -1063,9 +980,8 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
|
||||
tag_names = ["axolotl", "dpo"]
|
||||
|
||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_tags = dataset_tags
|
||||
self.optimizer = None
|
||||
|
||||
def create_optimizer(self):
|
||||
@@ -1104,44 +1020,28 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||
)
|
||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def tokenize_row(
|
||||
self,
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
) -> Dict:
|
||||
res = DPOTrainer.tokenize_row(
|
||||
res = super().tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
)
|
||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
||||
for key in res.keys():
|
||||
res[key] = res[key][1:]
|
||||
|
||||
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||
|
||||
return res
|
||||
|
||||
def training_step(
|
||||
@@ -1285,17 +1185,11 @@ class TrainerBuilderBase(abc.ABC):
|
||||
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
||||
"""
|
||||
callbacks = []
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
[
|
||||
cb
|
||||
for cb in plugin_manager.add_callbacks_post_trainer(
|
||||
self.cfg, trainer
|
||||
)
|
||||
if cb
|
||||
]
|
||||
)
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
|
||||
)
|
||||
return callbacks
|
||||
|
||||
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
||||
@@ -1342,7 +1236,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer, "wandb"
|
||||
@@ -1379,8 +1273,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||
callbacks.append(lisa_callback_factory(trainer))
|
||||
|
||||
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self):
|
||||
@@ -1498,15 +1390,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["eval_strategy"] = "no"
|
||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
training_arguments_kwargs["eval_strategy"] = "steps"
|
||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
elif self.cfg.eval_strategy:
|
||||
training_arguments_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||
elif self.cfg.evaluation_strategy:
|
||||
training_arguments_kwargs[
|
||||
"evaluation_strategy"
|
||||
] = self.cfg.evaluation_strategy
|
||||
else:
|
||||
# we have an eval set, but no steps defined, default to use epoch
|
||||
training_arguments_kwargs["eval_strategy"] = "epoch"
|
||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.save_steps:
|
||||
training_arguments_kwargs["save_strategy"] = "steps"
|
||||
@@ -1644,9 +1538,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs[
|
||||
"loraplus_lr_embedding"
|
||||
] = self.cfg.loraplus_lr_embedding
|
||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||
training_arguments_kwargs[
|
||||
@@ -1734,13 +1625,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.reward_model:
|
||||
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
if self.cfg.optimizer in [
|
||||
"optimi_adamw",
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
]:
|
||||
# Set default so transformers doesn't throw
|
||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||
@@ -1831,10 +1720,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
|
||||
if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
|
||||
trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
trainer = trainer_cls(
|
||||
model=self.model,
|
||||
train_dataset=self.train_dataset,
|
||||
@@ -1897,7 +1782,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
collator = MultiModalChatDataCollator
|
||||
kwargs["processor"] = self.processor
|
||||
kwargs["chat_template"] = training_args.chat_template
|
||||
kwargs["chat_template_type"] = self.cfg.chat_template
|
||||
else:
|
||||
collator = DataCollatorForSeq2Seq
|
||||
|
||||
@@ -1948,10 +1832,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||
|
||||
if self.eval_dataset:
|
||||
training_args_kwargs["eval_strategy"] = "steps"
|
||||
training_args_kwargs["evaluation_strategy"] = "steps"
|
||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
else:
|
||||
training_args_kwargs["eval_strategy"] = "no"
|
||||
training_args_kwargs["evaluation_strategy"] = "no"
|
||||
|
||||
if self.cfg.bf16 or self.cfg.bfloat16:
|
||||
training_args_kwargs["bf16"] = True
|
||||
@@ -2006,18 +1890,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
# default to saving each epoch if not defined
|
||||
training_args_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
if self.cfg.rl_beta:
|
||||
training_args_kwargs["beta"] = self.cfg.rl_beta
|
||||
if self.cfg.orpo_alpha:
|
||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
if self.cfg.rpo_alpha is not None:
|
||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||
|
||||
training_args_cls = None
|
||||
if self.cfg.rl == "simpo":
|
||||
training_args_cls = AxolotlCPOConfig
|
||||
training_args_kwargs["loss_type"] = "simpo"
|
||||
@@ -2026,13 +1909,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.cpo_alpha is not None:
|
||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||
|
||||
elif self.cfg.rl == "orpo":
|
||||
if self.cfg.rl == "orpo":
|
||||
training_args_cls = AxolotlORPOConfig
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl == "kto":
|
||||
if self.cfg.rl == "kto":
|
||||
training_args_cls = AxolotlKTOConfig
|
||||
|
||||
training_args_kwargs["desirable_weight"] = (
|
||||
@@ -2047,17 +1930,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
else:
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
if self.cfg.rl == "ipo":
|
||||
training_args_kwargs["loss_type"] = "ipo"
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||
if self.cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
output_dir=self.cfg.output_dir,
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
@@ -2078,6 +1950,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args = self.build_training_arguments(total_num_steps)
|
||||
dpo_trainer_kwargs = {}
|
||||
if self.cfg.rl == "ipo":
|
||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
||||
if self.cfg.dpo_label_smoothing:
|
||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||
if self.eval_dataset:
|
||||
@@ -2091,6 +1964,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.rl in ["dpo", "ipo"]:
|
||||
trainer_cls = AxolotlDPOTrainer
|
||||
trainer_cls_args = [self.model, self.model_ref]
|
||||
|
||||
# these aren't used for the ORPO trainer
|
||||
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||
dpo_trainer_kwargs["max_target_length"] = None
|
||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||
elif self.cfg.rl == "orpo":
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
@@ -2109,10 +1988,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
|
||||
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
|
||||
dpo_trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
dpo_trainer = trainer_cls(
|
||||
*trainer_cls_args,
|
||||
args=training_args,
|
||||
|
||||
@@ -40,7 +40,7 @@ class TRLPPOTrainer(PPOTrainer):
|
||||
query_tensors,
|
||||
return_prompt=False,
|
||||
generate_ref_response=True,
|
||||
**generation_kwargs,
|
||||
**generation_kwargs
|
||||
)
|
||||
batch["response"] = self.tokenizer.batch_decode(response_tensors)
|
||||
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
|
||||
|
||||
@@ -140,7 +140,7 @@ class BasePlugin:
|
||||
|
||||
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
setup callbacks before creating the trainer.
|
||||
Adds callbacks to the trainer before training.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
@@ -155,15 +155,14 @@ class BasePlugin:
|
||||
self, cfg, trainer
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
Adds callbacks to the trainer after creating the trainer.
|
||||
This is useful for callbacks that require access to the model or trainer.
|
||||
Adds callbacks to the trainer after training.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
|
||||
Returns:
|
||||
List[callable]: A list of callback functions to be added
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||
"""
|
||||
return []
|
||||
|
||||
@@ -394,9 +393,7 @@ class PluginManager:
|
||||
"""
|
||||
callbacks = []
|
||||
for plugin in self.plugins.values():
|
||||
plugin_callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
|
||||
if plugin_callbacks: # if the plugin returned a list of callbacks
|
||||
callbacks.extend(plugin_callbacks)
|
||||
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
|
||||
return callbacks
|
||||
|
||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||
@@ -412,9 +409,7 @@ class PluginManager:
|
||||
"""
|
||||
callbacks = []
|
||||
for plugin in self.plugins.values():
|
||||
plugin_callbacks = plugin.add_callbacks_post_trainer(cfg, trainer)
|
||||
if plugin_callbacks:
|
||||
callbacks.extend(plugin_callbacks)
|
||||
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
||||
return callbacks
|
||||
|
||||
def post_train_unload(self, cfg):
|
||||
|
||||
@@ -1,325 +0,0 @@
|
||||
Acknowledgements
|
||||
|
||||
Portions of this Cut Cross Entropy Software may utilize the following copyrighted
|
||||
material, the use of which is hereby acknowledged.
|
||||
|
||||
|
||||
------
|
||||
|
||||
|
||||
PyTorch
|
||||
|
||||
From PyTorch:
|
||||
|
||||
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
||||
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
||||
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
||||
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
||||
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
||||
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
||||
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
||||
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
||||
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
||||
|
||||
From Caffe2:
|
||||
|
||||
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
|
||||
|
||||
All contributions by Facebook:
|
||||
Copyright (c) 2016 Facebook Inc.
|
||||
|
||||
All contributions by Google:
|
||||
Copyright (c) 2015 Google Inc.
|
||||
All rights reserved.
|
||||
|
||||
All contributions by Yangqing Jia:
|
||||
Copyright (c) 2015 Yangqing Jia
|
||||
All rights reserved.
|
||||
|
||||
All contributions by Kakao Brain:
|
||||
Copyright 2019-2020 Kakao Brain
|
||||
|
||||
All contributions by Cruise LLC:
|
||||
Copyright (c) 2022 Cruise LLC.
|
||||
All rights reserved.
|
||||
|
||||
All contributions by Arm:
|
||||
Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
|
||||
|
||||
All contributions from Caffe:
|
||||
Copyright(c) 2013, 2014, 2015, the respective contributors
|
||||
All rights reserved.
|
||||
|
||||
All other contributions:
|
||||
Copyright(c) 2015, 2016 the respective contributors
|
||||
All rights reserved.
|
||||
|
||||
Caffe2 uses a copyright model similar to Caffe: each contributor holds
|
||||
copyright over their contributions to Caffe2. The project versioning records
|
||||
all such contribution and copyright details. If a contributor wants to further
|
||||
mark their specific copyright on a particular contribution, they should
|
||||
indicate their copyright solely in the commit message of the change when it is
|
||||
committed.
|
||||
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
|
||||
and IDIAP Research Institute nor the names of its contributors may be
|
||||
used to endorse or promote products derived from this software without
|
||||
specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
Triton
|
||||
|
||||
/*
|
||||
* Copyright 2018-2020 Philippe Tillet
|
||||
* Copyright 2020-2022 OpenAI
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
|
||||
Transformers
|
||||
|
||||
Copyright 2018- The Hugging Face team. All rights reserved.
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -1,47 +0,0 @@
|
||||
Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
IMPORTANT: This Apple software is supplied to you by Apple
|
||||
Inc. ("Apple") in consideration of your agreement to the following
|
||||
terms, and your use, installation, modification or redistribution of
|
||||
this Apple software constitutes acceptance of these terms. If you do
|
||||
not agree with these terms, please do not use, install, modify or
|
||||
redistribute this Apple software.
|
||||
|
||||
In consideration of your agreement to abide by the following terms, and
|
||||
subject to these terms, Apple grants you a personal, non-exclusive
|
||||
license, under Apple's copyrights in this original Apple software (the
|
||||
"Apple Software"), to use, reproduce, modify and redistribute the Apple
|
||||
Software, with or without modifications, in source and/or binary forms;
|
||||
provided that if you redistribute the Apple Software in its entirety and
|
||||
without modifications, you must retain this notice and the following
|
||||
text and disclaimers in all such redistributions of the Apple Software.
|
||||
Neither the name, trademarks, service marks or logos of Apple Inc. may
|
||||
be used to endorse or promote products derived from the Apple Software
|
||||
without specific prior written permission from Apple. Except as
|
||||
expressly stated in this notice, no other rights or licenses, express or
|
||||
implied, are granted by Apple herein, including but not limited to any
|
||||
patent rights that may be infringed by your derivative works or by other
|
||||
works in which the Apple Software may be incorporated.
|
||||
|
||||
The Apple Software is provided by Apple on an "AS IS" basis. APPLE
|
||||
MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
|
||||
THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
|
||||
FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
|
||||
OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
|
||||
|
||||
IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
|
||||
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
|
||||
MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
|
||||
AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
|
||||
STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
|
||||
POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
SOFTWARE DISTRIBUTED WITH CUT CROSS ENTROPY:
|
||||
|
||||
The Cut Cross Entropy software includes a number of subcomponents with separate
|
||||
copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.md.
|
||||
-------------------------------------------------------------------------------
|
||||
@@ -1,10 +0,0 @@
|
||||
# Cut Cross Entropy
|
||||
|
||||
### Usage
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
cut_cross_entropy: true
|
||||
```
|
||||
@@ -1,83 +0,0 @@
|
||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Module for the Plugin for Cut Cross Entropy integration with Axolotl.
|
||||
|
||||
Cut Cross Entropy is an optimized implementation of cross entropy loss
|
||||
from Apple's ML team.
|
||||
"""
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils import get_pytorch_version
|
||||
|
||||
from ...utils.distributed import zero_only
|
||||
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
|
||||
)
|
||||
|
||||
|
||||
class CutCrossEntropyPlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for Cut Cross Entropy integration with Axolotl.
|
||||
"""
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.cut_cross_entropy.CutCrossEntropyArgs"
|
||||
|
||||
def _check_requirements(self):
|
||||
"""Check if all requirements are met."""
|
||||
# Check PyTorch version
|
||||
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
raise ImportError(
|
||||
"Cut Cross Entropy requires PyTorch >= 2.4.0. "
|
||||
f"Current version: {torch.__version__}"
|
||||
)
|
||||
|
||||
# Check if cut_cross_entropy is installed
|
||||
cce_spec = importlib.util.find_spec("cut_cross_entropy")
|
||||
if cce_spec is None:
|
||||
raise ImportError(_CCE_INSTALL_MESSAGE)
|
||||
|
||||
cce_spec_transformers = importlib.util.find_spec(
|
||||
"cut_cross_entropy.transformers"
|
||||
)
|
||||
if cce_spec_transformers is None:
|
||||
raise ImportError(_CCE_INSTALL_MESSAGE)
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
"""Apply cut cross entropy before model loading if enabled."""
|
||||
if cfg.cut_cross_entropy:
|
||||
self._check_requirements()
|
||||
|
||||
from cut_cross_entropy.transformers import cce_patch
|
||||
|
||||
with zero_only():
|
||||
LOG.info(
|
||||
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
|
||||
)
|
||||
|
||||
# The patch checks model_type internally
|
||||
cce_patch(cfg.model_config_type)
|
||||
@@ -1,42 +0,0 @@
|
||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Module for handling Cut Cross Entropy input arguments.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args")
|
||||
|
||||
|
||||
class CutCrossEntropyArgs(BaseModel):
|
||||
"""
|
||||
Input args for Cut Cross Entropy.
|
||||
"""
|
||||
|
||||
cut_cross_entropy: Optional[bool] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_dtype_is_half(cls, data):
|
||||
if not (data.get("bf16") or data.get("fp16")):
|
||||
raise ValueError(
|
||||
"Cut Cross Entropy requires fp16/bf16 training for backward pass. "
|
||||
"Please set `bf16` or `fp16` to `True`."
|
||||
)
|
||||
|
||||
return data
|
||||
@@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -1,13 +0,0 @@
|
||||
# Grokfast Optimizer
|
||||
|
||||
See https://github.com/ironjr/grokfast
|
||||
|
||||
### Usage
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.grokfast.GrokfastPlugin
|
||||
|
||||
grokfast_alpha: 2.0
|
||||
grokfast_lamb: 0.98
|
||||
```
|
||||
@@ -1,50 +0,0 @@
|
||||
"""
|
||||
Grokfast plugin for Axolotl
|
||||
"""
|
||||
import logging
|
||||
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
|
||||
from ..base import BasePlugin
|
||||
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
|
||||
from .optimizer import gradfilter_ema
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.grokfast")
|
||||
|
||||
|
||||
class GrokfastCallbackHandler(TrainerCallback):
|
||||
"""
|
||||
Transformer trainer callbacks for Grokfast
|
||||
"""
|
||||
|
||||
def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs):
|
||||
super().__init__(*args_, **kwargs)
|
||||
self.grads = None
|
||||
self.alpha = alpha
|
||||
self.lamb = lamb
|
||||
|
||||
def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument
|
||||
self.grads = None
|
||||
|
||||
def on_pre_optimizer_step(
|
||||
self, args_, state, control, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
model = kwargs.pop("model")
|
||||
self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
|
||||
return control
|
||||
|
||||
|
||||
class GrokfastPlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for Grokfast optimizer integraton with Axolotl.
|
||||
"""
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.grokfast.GrokfastArgs"
|
||||
|
||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||
LOG.info("Adding Grokfast callback to the trainer")
|
||||
callback = GrokfastCallbackHandler(
|
||||
alpha=cfg.grokfast_alpha, lamb=cfg.grokfast_lamb
|
||||
)
|
||||
return [callback]
|
||||
@@ -1,15 +0,0 @@
|
||||
"""
|
||||
config args for grokfast plugin
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GrokfastArgs(BaseModel):
|
||||
"""
|
||||
Input args for Grokfast optimizer.
|
||||
"""
|
||||
|
||||
grokfast_alpha: Optional[float] = 0.98
|
||||
grokfast_lamb: Optional[float] = 2.0
|
||||
@@ -1,63 +0,0 @@
|
||||
# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
|
||||
# Reference: https://github.com/ironjr/grokfast
|
||||
|
||||
# pylint: skip-file
|
||||
from collections import deque
|
||||
from typing import Dict, Literal, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def gradfilter_ma(
|
||||
m: nn.Module,
|
||||
grads: Optional[Dict[str, deque]] = None,
|
||||
window_size: int = 100,
|
||||
lamb: float = 5.0,
|
||||
filter_type: Literal["mean", "sum"] = "mean",
|
||||
warmup: bool = True,
|
||||
trigger: bool = False, # For ablation study.
|
||||
) -> Dict[str, deque]:
|
||||
if grads is None:
|
||||
grads = {
|
||||
n: deque(maxlen=window_size)
|
||||
for n, p in m.named_parameters()
|
||||
if p.requires_grad and p.grad is not None
|
||||
}
|
||||
|
||||
for n, p in m.named_parameters():
|
||||
if p.requires_grad and p.grad is not None:
|
||||
grads[n].append(p.grad.data.detach()) # .cpu())
|
||||
|
||||
# Modify the gradients.
|
||||
if not warmup or len(grads[n]) == window_size and not trigger:
|
||||
if filter_type == "mean":
|
||||
avg = sum(grads[n]) / len(grads[n])
|
||||
elif filter_type == "sum":
|
||||
avg = sum(grads[n])
|
||||
else:
|
||||
raise ValueError(f"Unrecognized filter_type {filter_type}")
|
||||
p.grad.data = p.grad.data + avg * lamb
|
||||
|
||||
return grads
|
||||
|
||||
|
||||
def gradfilter_ema(
|
||||
m: nn.Module,
|
||||
grads: Optional[Dict[str, torch.Tensor]] = None,
|
||||
alpha: float = 0.98,
|
||||
lamb: float = 2.0,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if grads is None:
|
||||
grads = {
|
||||
n: p.grad.data.detach()
|
||||
for n, p in m.named_parameters()
|
||||
if p.requires_grad and p.grad is not None
|
||||
}
|
||||
|
||||
for n, p in m.named_parameters():
|
||||
if p.requires_grad and p.grad is not None:
|
||||
grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
|
||||
p.grad.data = p.grad.data + grads[n] * lamb
|
||||
|
||||
return grads
|
||||
@@ -23,7 +23,6 @@ import logging
|
||||
import sys
|
||||
|
||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||
@@ -83,9 +82,7 @@ class LigerPlugin(BasePlugin):
|
||||
if cfg.liger_glu_activation:
|
||||
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
||||
elif cfg.model_config_type == "deepseek_v2":
|
||||
@@ -109,8 +106,6 @@ class LigerPlugin(BasePlugin):
|
||||
if cfg.liger_glu_activation:
|
||||
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
||||
if cfg.liger_cross_entropy:
|
||||
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
|
||||
# nn.CrossEntropyLoss in the forward method.
|
||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||
|
||||
231
src/axolotl/monkeypatch/fastchat_conversation_turns.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
monkeypatch to add a get_turns method
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Generator, Tuple
|
||||
|
||||
from fastchat.conversation import SeparatorStyle
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
|
||||
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
ret = ""
|
||||
for role, msg in self.get_turns():
|
||||
ret += role + msg
|
||||
return ret
|
||||
|
||||
|
||||
def get_turns( # pylint: disable=too-many-return-statements
|
||||
self,
|
||||
) -> Generator[Tuple[str, str], None, None]:
|
||||
"""Get the prompt for generation."""
|
||||
system_prompt = self.system_template.format(system_message=self.system_message)
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt + seps[0]
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role + ": ", message + seps[i % 2]
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ": ", "" # must be end with a space
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
|
||||
yield "", "" if system_prompt == "" else system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", message + self.sep
|
||||
else:
|
||||
yield role + "\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role, message + self.sep
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role, message + seps[i % 2]
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.RWKV:
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role + ": ", message.replace("\r\n", "\n").replace(
|
||||
"\n\n", "\n"
|
||||
) + "\n\n"
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
|
||||
if self.system_message:
|
||||
if self.messages:
|
||||
# For llama, the system message is incorporated into the first human instruction
|
||||
first_role, first_msg = self.messages[0]
|
||||
if first_role == self.roles[0]:
|
||||
system_prompt += first_msg
|
||||
self.messages.pop(0)
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
if (i % 2 == 0 and not self.system_message) or (
|
||||
i % 2 != 0 and self.system_message
|
||||
):
|
||||
role = "<s> " + role
|
||||
yield role + " ", message
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
|
||||
contains_sys_msg = False
|
||||
if self.system_message:
|
||||
contains_sys_msg = True
|
||||
if self.messages:
|
||||
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction separated by a newline
|
||||
first_role, first_msg = self.messages[0]
|
||||
if first_role == self.roles[0]:
|
||||
system_prompt = self.system_template.format(
|
||||
system_message=" " + self.system_message
|
||||
)
|
||||
system_prompt += first_msg
|
||||
self.messages.pop(0)
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message and i == 0 and not contains_sys_msg:
|
||||
yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a `<s> [INST]` at the beginning of the first instruction.
|
||||
elif message:
|
||||
yield role + " ", message
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA3:
|
||||
if self.system_message:
|
||||
# For llama3, the system message is NOT incorporated into the first human instruction
|
||||
# All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
|
||||
else:
|
||||
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.GEMMA:
|
||||
if self.system_message:
|
||||
raise ValueError("Gemma chat template does not support system messages")
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
prefix = "<bos>" if i == 0 else ""
|
||||
message_str = message if message else ""
|
||||
yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATGLM:
|
||||
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
||||
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||
round_add_n = 1 if self.name == "chatglm2" else 0
|
||||
if system_prompt:
|
||||
yield "", system_prompt + self.sep
|
||||
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if i % 2 == 0:
|
||||
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
|
||||
|
||||
if message:
|
||||
yield f"{role}:", f"{message}{self.sep}"
|
||||
else:
|
||||
yield f"{role}:", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATML:
|
||||
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", message + self.sep + "\n"
|
||||
else:
|
||||
yield role + "\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATGLM3:
|
||||
if self.system_message:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", " " + message
|
||||
else:
|
||||
yield role
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATINTERN:
|
||||
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
prefix = "<s>" if i % 2 == 0 else ""
|
||||
if message:
|
||||
yield prefix + role + ":", message + seps[i % 2] + "\n"
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.DOLLY:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
suffix = "\n\n" if i % 2 == 1 else ""
|
||||
yield role + ":\n", message + seps[i % 2] + suffix
|
||||
else:
|
||||
yield role + ":\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.PHOENIX:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", "<s>" + message + "</s>"
|
||||
else:
|
||||
yield role + ": " + "<s>", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ROBIN:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ":\n", message + self.sep
|
||||
else:
|
||||
yield role + ":\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.FALCON_CHAT:
|
||||
if self.system_message:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ":", ""
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
|
||||
def add_get_turns_to_conversation():
|
||||
import fastchat.conversation
|
||||
|
||||
fastchat.conversation.Conversation.get_turns = get_turns
|
||||
fastchat.conversation.Conversation.get_prompt = get_prompt
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -93,32 +94,13 @@ def replace_llama_qkv_with_fused(model):
|
||||
set_module_name(model, name, qkv)
|
||||
|
||||
|
||||
def patch_fa_llama_cross_entropy():
|
||||
LOG.info(
|
||||
"patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
|
||||
)
|
||||
from flash_attn.ops.triton.cross_entropy import (
|
||||
cross_entropy_loss as flash_attn_cross_entropy_loss,
|
||||
)
|
||||
def patch_llama_cross_entropy():
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
|
||||
def fa2_fixed_cross_entropy(
|
||||
source,
|
||||
target,
|
||||
num_items_in_batch: int = None,
|
||||
ignore_index: int = -100,
|
||||
**kwargs,
|
||||
): # pylint: disable=unused-argument
|
||||
reduction = "sum" if num_items_in_batch is not None else "mean"
|
||||
loss, _ = flash_attn_cross_entropy_loss(
|
||||
source, target, ignore_index=ignore_index
|
||||
)
|
||||
if reduction == "sum":
|
||||
loss = loss.sum() / num_items_in_batch
|
||||
else:
|
||||
loss = loss.sum() / (target != ignore_index).sum()
|
||||
return loss
|
||||
|
||||
transformers.loss.loss_utils.fixed_cross_entropy = fa2_fixed_cross_entropy
|
||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||
CrossEntropyLoss, inplace_backward=True
|
||||
)
|
||||
|
||||
|
||||
def patch_llama_rms_norm():
|
||||
@@ -165,7 +147,7 @@ def replace_llama_attn_with_flash_attn(
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if cross_entropy:
|
||||
patch_fa_llama_cross_entropy()
|
||||
patch_llama_cross_entropy()
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if rms_norm:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""multipack patching for v2 of sample packing"""
|
||||
|
||||
import importlib
|
||||
|
||||
import transformers
|
||||
@@ -28,28 +27,71 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
]
|
||||
|
||||
|
||||
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
||||
if has_remote_code:
|
||||
patch_remote(model_name)
|
||||
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||
def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
|
||||
if model_type == "gemmoe":
|
||||
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
|
||||
elif model_type == "deepseek_v2":
|
||||
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
|
||||
elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
|
||||
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
return
|
||||
|
||||
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
# retain for legacy
|
||||
if model_type == "mixtral":
|
||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
elif model_type == "llama":
|
||||
if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
|
||||
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "mistral":
|
||||
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
|
||||
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "qwen2":
|
||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "qwen2_moe":
|
||||
transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "falcon":
|
||||
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "phi":
|
||||
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "gemma":
|
||||
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "gemma2":
|
||||
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "starcoder2":
|
||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
|
||||
|
||||
def patch_remote(model_name):
|
||||
def patch_remote(model_name, config_name, modeling_name):
|
||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
# we need to load the model here in order for modeling_* to be available
|
||||
with init_empty_weights():
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
parts = model_config.__class__.__module__.split(".")
|
||||
parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
|
||||
module_name = ".".join(parts)
|
||||
module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
|
||||
modeling_arch = importlib.import_module(module_name)
|
||||
if hasattr(modeling_arch, "_get_unpad_data"):
|
||||
modeling_arch._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access
|
||||
|
||||
@@ -46,10 +46,9 @@ def reset_optimizer(
|
||||
*,
|
||||
reset_params: List[str], # where str is the key to a torch.nn.Parameter
|
||||
optimizer_state_keys: List[str],
|
||||
optimizer_magnitude_pruning: float = 0.9,
|
||||
prune_ratio: float = 0.9,
|
||||
):
|
||||
# pylint:disable=unused-argument
|
||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
|
||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
|
||||
n_zeros = 0
|
||||
n_total = 0
|
||||
|
||||
@@ -57,22 +56,16 @@ def reset_optimizer(
|
||||
if isinstance(optimizer, ZeroRedundancyOptimizer):
|
||||
optimizer_state = optimizer.optim.state
|
||||
|
||||
for group in optimizer.param_groups:
|
||||
for param in group["params"]:
|
||||
state = optimizer_state[param]
|
||||
for key, value in state.items():
|
||||
if key not in optimizer_state_keys:
|
||||
continue
|
||||
if torch.is_tensor(value):
|
||||
try:
|
||||
pruning_fn(value)
|
||||
n_total += value.numel()
|
||||
n_zeros += torch.sum(value == 0).item()
|
||||
except RuntimeError as exc:
|
||||
if "quantile() input tensor is too large" in str(exc):
|
||||
pass
|
||||
else:
|
||||
raise exc
|
||||
for param in reset_params:
|
||||
param_state = optimizer_state[param]
|
||||
if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer
|
||||
continue
|
||||
for key in optimizer_state_keys:
|
||||
pruning_fn(
|
||||
param_state[key]
|
||||
) # pruning fn has to be inplace to keep the same keys in the dict
|
||||
n_total += param_state[key].numel()
|
||||
n_zeros += torch.sum(param_state[key] == 0).item()
|
||||
|
||||
_zeroed = n_zeros / (1e-7 + n_total) * 100
|
||||
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
|
||||
@@ -136,9 +129,6 @@ class ReLoRACallback(TrainerCallback):
|
||||
|
||||
if "adam" in args.optim.lower():
|
||||
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
|
||||
if "8bit" in args.optim.lower():
|
||||
optimizer_state_keys.append("state1")
|
||||
optimizer_state_keys.append("state2")
|
||||
else:
|
||||
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
|
||||
|
||||
@@ -170,7 +160,7 @@ class ReLoRACallback(TrainerCallback):
|
||||
optimizer,
|
||||
reset_params=lora_params,
|
||||
optimizer_state_keys=optimizer_state_keys,
|
||||
optimizer_magnitude_pruning=args.relora_prune_ratio,
|
||||
prune_ratio=args.relora_prune_ratio,
|
||||
)
|
||||
|
||||
if self.quantized:
|
||||
|
||||
@@ -188,7 +188,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
||||
for module in layer_modules
|
||||
)
|
||||
mlp_not_dora = all(
|
||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
getattr(module, "lora_magnitude_vector", None) is None
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
@@ -213,7 +213,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
||||
for module in layer_modules
|
||||
)
|
||||
qkv_not_dora = all(
|
||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
getattr(module, "lora_magnitude_vector", None) is None
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
@@ -232,7 +232,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
||||
for module in layer_modules
|
||||
)
|
||||
o_not_dora = all(
|
||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
getattr(module, "lora_magnitude_vector", None) is None
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
|
||||
33
src/axolotl/prompt_strategies/instruct.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||
)
|
||||
strategy = InstructShareGPTPromptTokenizingStrategy(
|
||||
# pylint: disable=duplicate-code
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
return strategy
|
||||
|
||||
|
||||
class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return [
|
||||
{"from": "human", "value": prompt["instruction"]},
|
||||
{"from": "gpt", "value": prompt["output"]},
|
||||
]
|
||||
@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Generator, List, Sequence
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -75,7 +75,7 @@ class Llama2ChatConversation:
|
||||
|
||||
class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for Llama2 prompts.
|
||||
Tokenizing strategy for ShareGPT prompts.
|
||||
adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
|
||||
"""
|
||||
|
||||
@@ -191,7 +191,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
|
||||
conv.messages = [] # pylint: disable=R0801
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
assert role == conv.roles[j % 2], ALTERNATING_ASSERTION_FAILED_ROLE
|
||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
if sentence["value"]:
|
||||
conv.append_message(role, sentence["value"])
|
||||
yield conv
|
||||
|
||||
223
src/axolotl/prompt_strategies/sharegpt.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
from axolotl.utils.tokenization import (
|
||||
chatml_to_conversation,
|
||||
merge_consecutive_messages,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def register_chatml_template(system_message=None):
|
||||
system_message = system_message or "You are a helpful assistant."
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="chatml",
|
||||
system_template="<|im_start|>system\n{system_message}",
|
||||
system_message=system_message,
|
||||
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
||||
sep_style=SeparatorStyle.CHATML,
|
||||
sep="<|im_end|>",
|
||||
)
|
||||
)
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="chatml_glaive",
|
||||
system_template="<|im_start|>system\n{system_message}",
|
||||
system_message=system_message,
|
||||
roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
|
||||
sep_style=SeparatorStyle.CHATML,
|
||||
sep="<|im_end|>",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def register_llama3_template(system_message=None):
|
||||
system_message = system_message or "You are a helpful assistant."
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="llama3",
|
||||
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
|
||||
system_message=system_message,
|
||||
roles=("user", "assistant"),
|
||||
sep_style=SeparatorStyle.LLAMA3,
|
||||
sep="",
|
||||
stop_str="<|eot_id|>",
|
||||
stop_token_ids=[128001, 128009],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_loader(
|
||||
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
|
||||
prompter_cls: Type["ShareGPTPrompterV2"],
|
||||
default_conversation: Optional[str] = None,
|
||||
):
|
||||
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
LOG.warning(
|
||||
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template",
|
||||
)
|
||||
conversation = (
|
||||
ds_cfg["conversation"]
|
||||
if ds_cfg and "conversation" in ds_cfg
|
||||
else default_conversation
|
||||
)
|
||||
field_human = (
|
||||
ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
)
|
||||
field_model = (
|
||||
ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
)
|
||||
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
|
||||
strategy = tokenization_strategy_cls(
|
||||
prompter_cls(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
role_key_human=field_human,
|
||||
roles=roles,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
|
||||
strategy.strict = ds_cfg["strict"]
|
||||
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
return strategy
|
||||
|
||||
return _load
|
||||
|
||||
|
||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
"""
|
||||
|
||||
_strict = False
|
||||
_messages = "conversations"
|
||||
|
||||
@property
|
||||
def strict(self):
|
||||
return self._strict
|
||||
|
||||
@strict.setter
|
||||
def strict(self, strict):
|
||||
self._strict = strict
|
||||
|
||||
@property
|
||||
def messages(self):
|
||||
return self._messages
|
||||
|
||||
@messages.setter
|
||||
def messages(self, messages):
|
||||
self._messages = messages
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt[self.messages]
|
||||
if self.strict:
|
||||
return conversations
|
||||
role_key = "from"
|
||||
if "role" in conversations[0].keys():
|
||||
role_key = "role"
|
||||
value_key = "value"
|
||||
if "text" in conversations[0].keys():
|
||||
value_key = "text"
|
||||
elif "content" in conversations[0].keys():
|
||||
value_key = "content"
|
||||
# remap roles - allow for assistant turn"
|
||||
role_map = {
|
||||
"user": "human",
|
||||
"human": "human",
|
||||
"assistant": "gpt",
|
||||
"gpt": "gpt",
|
||||
"system": "system",
|
||||
}
|
||||
turns = [
|
||||
{
|
||||
"from": (
|
||||
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
|
||||
),
|
||||
"value": t[value_key],
|
||||
"weight": 1
|
||||
if "weight" not in t or t["weight"] is None
|
||||
else t["weight"],
|
||||
}
|
||||
for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
class SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||
SimpleShareGPTPromptTokenizingStrategy
|
||||
):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
|
||||
return turns
|
||||
|
||||
|
||||
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
sharegpt strategy that remaps oasst data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
role_map = {"prompter": "human", "assistant": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
sharegpt strategy that remaps ultrachat data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["messages"]
|
||||
role_map = {"user": "human", "assistant": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
sharegpt strategy that remaps glaive data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversation = chatml_to_conversation(prompt)
|
||||
conversation = merge_consecutive_messages(conversation)
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||
load_ultrachat = build_loader(
|
||||
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
|
||||
)
|
||||
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||
load_glaive = build_loader(
|
||||
GlaiveShareGPTPromptTokenizingStrategy,
|
||||
ShareGPTPrompterV2,
|
||||
default_conversation="chatml_glaive",
|
||||
)
|
||||
28
src/axolotl/prompt_strategies/sharegpt_jokes.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Module for Jokes prompts using sharegpt style """
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return SimpleJokesShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenization strategy for asking bot to tell a joke and then explain why its funny
|
||||
"""
|
||||
|
||||
# title, text, explanation
|
||||
def get_conversation_thread(self, prompt):
|
||||
title = "" if not prompt["title"] else prompt["title"] + " "
|
||||
return [
|
||||
{"from": "human", "value": "Tell me a joke."},
|
||||
{"from": "gpt", "value": title + prompt["text"]},
|
||||
{"from": "human", "value": "Why is that joke funny?"},
|
||||
{"from": "gpt", "value": prompt["explanation"]},
|
||||
]
|
||||
@@ -1,12 +1,17 @@
|
||||
"""Module containing PromptTokenizingStrategy and Prompter classes"""
|
||||
|
||||
import abc
|
||||
import copy
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from fastchat.conversation import Conversation
|
||||
from transformers import BatchEncoding, PreTrainedTokenizer
|
||||
|
||||
from axolotl.prompters import Prompter
|
||||
from axolotl.monkeypatch.fastchat_conversation_turns import (
|
||||
add_get_turns_to_conversation,
|
||||
)
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -16,6 +21,8 @@ LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
|
||||
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
|
||||
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
|
||||
|
||||
add_get_turns_to_conversation()
|
||||
|
||||
|
||||
class InvalidDataException(Exception):
|
||||
"""
|
||||
@@ -324,6 +331,154 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
|
||||
)
|
||||
|
||||
|
||||
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for ShareGPT prompts.
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt["conversations"]
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# Initial values. We will append to these as we go through the conversation.
|
||||
result, current_len = tokenize_prompt_default()
|
||||
conversation: Conversation = (
|
||||
self.prompter._conversation.copy() # pylint: disable=protected-access
|
||||
)
|
||||
|
||||
input_roles = {conversation.roles[0]}
|
||||
output_roles = {conversation.roles[1]}
|
||||
|
||||
if len(conversation.roles) == 3:
|
||||
tool_role_label = conversation.roles[2]
|
||||
input_roles.add(tool_role_label)
|
||||
|
||||
# Add roles from the config
|
||||
if self.prompter.roles:
|
||||
if "input" in self.prompter.roles and self.prompter.roles["input"]:
|
||||
for role in self.prompter.roles["input"]:
|
||||
input_roles.add(role)
|
||||
|
||||
if "output" in self.prompter.roles and self.prompter.roles["output"]:
|
||||
for role in self.prompter.roles["output"]:
|
||||
output_roles.add(role)
|
||||
|
||||
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
||||
role_remap = []
|
||||
if (
|
||||
conversation.name == "vicuna_v1.1"
|
||||
and "roles" in prompt
|
||||
and len(prompt["roles"]) >= 2
|
||||
):
|
||||
role_remap = [
|
||||
{"from": conversation.roles[0], "to": prompt["roles"][0]},
|
||||
{"from": conversation.roles[1], "to": prompt["roles"][1]},
|
||||
]
|
||||
|
||||
try:
|
||||
for _, part in enumerate(
|
||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
||||
):
|
||||
if not isinstance(part, tuple):
|
||||
LOG.warning(f"expected tuple, got {part}")
|
||||
continue
|
||||
|
||||
if len(part) <= 2:
|
||||
role, content = part
|
||||
weight = 1
|
||||
else:
|
||||
role, content, weight = part
|
||||
|
||||
# Uses "in" because role contains extra characters
|
||||
input_turn = any(r.lower() in role.lower() for r in input_roles)
|
||||
output_turn = any(r.lower() in role.lower() for r in output_roles)
|
||||
empty_role = role.strip() == ""
|
||||
|
||||
if not any([input_turn, output_turn, empty_role]):
|
||||
LOG.warning(f"unhandled role: {role}")
|
||||
continue
|
||||
|
||||
if input_turn:
|
||||
role = (
|
||||
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||
if role_remap
|
||||
else role
|
||||
)
|
||||
turn = role + content
|
||||
# this is still the user query, we should
|
||||
if not content.strip():
|
||||
LOG.warning(f"user turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if self.train_on_inputs and weight == 1:
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
else:
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif output_turn:
|
||||
role = (
|
||||
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||
if role_remap
|
||||
else role
|
||||
)
|
||||
turn = role + content
|
||||
# this should be the assistant response, should end with an eos token
|
||||
if not content.strip():
|
||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
||||
add_eos_token = not (
|
||||
conversation.name == "chatml"
|
||||
and conversation.sep == self.tokenizer.eos_token
|
||||
)
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=add_eos_token,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
role_res = self._tokenize(
|
||||
role.rstrip(),
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
if not self.train_on_inputs:
|
||||
# mask out role tokens from the labels
|
||||
len_role = len(role_res["input_ids"])
|
||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
||||
len_role, len(labels)
|
||||
)
|
||||
if weight == 0:
|
||||
# everything from this is masked out from the labels
|
||||
# (role is masked out too because it makes no sense if contents is masked out)
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
|
||||
elif empty_role:
|
||||
turn = content
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
turn, add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
if self.train_on_inputs and weight == 1:
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
else:
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
result,
|
||||
current_len,
|
||||
res,
|
||||
labels,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
return result
|
||||
except (KeyError, AssertionError, IndexError) as err:
|
||||
raise InvalidDataException(str(err)) from err
|
||||
|
||||
|
||||
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
|
||||
"""
|
||||
Returns the default values for the tokenize prompt function
|
||||
|
||||
@@ -5,6 +5,7 @@ from enum import Enum
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
from colorama import Fore
|
||||
from fastchat.conversation import Conversation, get_conv_template
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
IGNORE_TOKEN_ID = -100
|
||||
@@ -261,10 +262,166 @@ class ReflectAlpacaPrompter(Prompter):
|
||||
)
|
||||
|
||||
|
||||
ALTERNATING_ASSERTION_FAILED_ROLE = (
|
||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
"Role did not alternate between turns (gpt and human). Please check your data."
|
||||
)
|
||||
|
||||
CONVERSATION_ROLE_FORMAT = {
|
||||
"chatml": "<|im_start|>{ROLE}",
|
||||
"zephyr": "<|{ROLE}|>",
|
||||
"vicuna_v1.1": "{ROLE}",
|
||||
"llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
|
||||
}
|
||||
|
||||
|
||||
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
|
||||
role_key_human = "human"
|
||||
role_key_model = "gpt"
|
||||
# Optional, only used for tool usage datasets.
|
||||
role_key_tool: Optional[str] = None
|
||||
# Optional, role input/output mapping
|
||||
roles: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_style=None, # pylint: disable=unused-argument
|
||||
conversation: Optional[Union[str, Conversation]] = None,
|
||||
role_key_human: Optional[str] = None,
|
||||
role_key_model: Optional[str] = None,
|
||||
role_key_tool: Optional[str] = None,
|
||||
roles: Optional[dict] = None,
|
||||
):
|
||||
if conversation:
|
||||
if isinstance(conversation, Conversation):
|
||||
self._conversation = conversation
|
||||
else:
|
||||
self._conversation = get_conv_template(conversation)
|
||||
else:
|
||||
self._conversation = get_conv_template("vicuna_v1.1")
|
||||
if role_key_human:
|
||||
self.role_key_human = role_key_human
|
||||
if role_key_model:
|
||||
self.role_key_model = role_key_model
|
||||
if role_key_tool:
|
||||
self.role_key_tool = role_key_tool
|
||||
if roles:
|
||||
self.roles = roles
|
||||
|
||||
def _build_result(self, source):
|
||||
if len(source) < 2:
|
||||
# If there isn't a back and forth conversation, ignore it
|
||||
# also happens on the data splitting leaving empty conversations
|
||||
raise IndexError(
|
||||
f"A conversation entry has less than 2 messages :\n{source}"
|
||||
)
|
||||
|
||||
conv = self._conversation.copy()
|
||||
|
||||
original_source = source.copy()
|
||||
# Add the conversation system prompt if provided, otherwise use the default one
|
||||
if source[0]["from"] == "system":
|
||||
conv.set_system_message(source[0]["value"])
|
||||
source.pop(0)
|
||||
|
||||
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
|
||||
if self.role_key_tool:
|
||||
roles[self.role_key_tool] = conv.roles[2]
|
||||
|
||||
try:
|
||||
# Apply prompt templates
|
||||
if source[0]["from"] not in roles:
|
||||
# Skip the first one if it is not from human
|
||||
source = source[1:]
|
||||
except IndexError as err:
|
||||
# sometimes there is a bing or system chat
|
||||
raise err
|
||||
|
||||
conv.messages = []
|
||||
for _, sentence in enumerate(source):
|
||||
from_role = sentence["from"]
|
||||
if from_role in roles:
|
||||
role = roles[from_role]
|
||||
else:
|
||||
if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
|
||||
raise NotImplementedError(
|
||||
f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
|
||||
"Please help us by creating an Issue to add support for this conversation type."
|
||||
)
|
||||
|
||||
if self._conversation.name in ["llama3"]:
|
||||
role = from_role
|
||||
else:
|
||||
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
|
||||
ROLE=from_role
|
||||
)
|
||||
|
||||
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
|
||||
if (
|
||||
role != "assistant"
|
||||
): # back to back assistant calls may be okay for tool calls
|
||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||
|
||||
conv.append_message(role, sentence["value"])
|
||||
turns = list(conv.get_turns())
|
||||
original_source_length = len(original_source)
|
||||
assert len(turns) in [
|
||||
original_source_length - 1,
|
||||
original_source_length,
|
||||
original_source_length + 1,
|
||||
]
|
||||
if len(turns) == original_source_length + 1:
|
||||
original_source = [{"weight": None}] + original_source
|
||||
elif len(turns) == original_source_length - 1:
|
||||
original_source = original_source[1:]
|
||||
return [
|
||||
(*turn, weight)
|
||||
for turn, weight in zip(
|
||||
turns,
|
||||
[
|
||||
1 if "weight" not in e or e["weight"] is None else e["weight"]
|
||||
for e in original_source
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
turns = self._build_result(source)
|
||||
|
||||
for part in turns:
|
||||
if part[0] and not part[1]:
|
||||
LOG.warning(f"role with empty message: {part[0]}")
|
||||
yield part
|
||||
|
||||
def __repr__(self) -> str:
|
||||
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
|
||||
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
|
||||
|
||||
|
||||
class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
"""
|
||||
A V2 prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conversation: Optional[Union[str, Conversation]] = None,
|
||||
role_key_human: Optional[str] = None,
|
||||
role_key_model: Optional[str] = None,
|
||||
role_key_tool: Optional[str] = None,
|
||||
roles: Optional[dict] = None,
|
||||
):
|
||||
super().__init__(
|
||||
conversation=conversation,
|
||||
role_key_human=role_key_human,
|
||||
role_key_model=role_key_model,
|
||||
role_key_tool=role_key_tool,
|
||||
roles=roles,
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedPrompter(Prompter):
|
||||
"""
|
||||
|
||||
@@ -259,31 +259,11 @@ def train(
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
if not cfg.hub_model_id:
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
try:
|
||||
# Check to make sure the base model is from HuggingFace not a local directory
|
||||
hf_api = HfApi()
|
||||
hf_api.model_info(cfg.base_model)
|
||||
|
||||
model_card_kwarg = {
|
||||
"model_name": cfg.output_dir.lstrip("./")
|
||||
.encode("utf-8")
|
||||
.decode("utf-8")
|
||||
}
|
||||
if cfg.datasets is not None:
|
||||
if cfg.rl is not None or cfg.reward_model:
|
||||
model_card_kwarg["dataset_name"] = [
|
||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
else:
|
||||
model_card_kwarg["dataset_tags"] = [
|
||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
|
||||
trainer.create_model_card(**model_card_kwarg)
|
||||
except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
|
||||
trainer.create_model_card(
|
||||
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
|
||||
)
|
||||
except (AttributeError, UnicodeDecodeError):
|
||||
pass
|
||||
elif cfg.hub_model_id:
|
||||
# defensively push to the hub to ensure the model card is updated
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
"""
|
||||
Basic utils for Axolotl
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def is_mlflow_available():
|
||||
@@ -14,23 +10,3 @@ def is_mlflow_available():
|
||||
|
||||
def is_comet_available():
|
||||
return importlib.util.find_spec("comet_ml") is not None
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def get_pytorch_version() -> tuple[int, int, int]:
|
||||
"""
|
||||
Get Pytorch version as a tuple of (major, minor, patch).
|
||||
"""
|
||||
torch_version = torch.__version__
|
||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||
|
||||
if not version_match:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
major, minor, patch = version_match.groups()
|
||||
major, minor = int(major), int(minor)
|
||||
patch = int(patch) if patch is not None else 0 # Default patch to 0 if not present
|
||||
return major, minor, patch
|
||||
|
||||
|
||||
# pylint: enable=duplicate-code
|
||||
|
||||
@@ -1,23 +1,9 @@
|
||||
"""Benchmarking and measurement utilities"""
|
||||
import functools
|
||||
|
||||
import pynvml
|
||||
import torch
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.utils.distributed import get_device_type
|
||||
|
||||
try:
|
||||
from pynvml import (
|
||||
NVMLError,
|
||||
nvmlDeviceGetHandleByIndex,
|
||||
nvmlDeviceGetMemoryInfo,
|
||||
nvmlInit,
|
||||
)
|
||||
except ImportError:
|
||||
NVMLError = None
|
||||
nvmlDeviceGetHandleByIndex = None
|
||||
nvmlDeviceGetMemoryInfo = None
|
||||
nvmlInit = None
|
||||
from pynvml.nvml import NVMLError
|
||||
|
||||
|
||||
def check_cuda_device(default_value):
|
||||
@@ -67,35 +53,24 @@ def mps_memory_usage_all():
|
||||
return usage, reserved - usage, 0
|
||||
|
||||
|
||||
def npu_memory_usage_all(device=0):
|
||||
usage = torch.npu.memory_allocated(device) / 1024.0**3
|
||||
reserved = torch.npu.memory_reserved(device) / 1024.0**3
|
||||
return usage, reserved - usage, 0
|
||||
|
||||
|
||||
@check_cuda_device(0.0)
|
||||
def gpu_memory_usage_smi(device=0):
|
||||
if isinstance(device, torch.device):
|
||||
device = device.index
|
||||
if isinstance(device, str) and device.startswith("cuda:"):
|
||||
device = int(device[5:])
|
||||
if not nvmlInit:
|
||||
return 0.0
|
||||
try:
|
||||
nvmlInit()
|
||||
handle = nvmlDeviceGetHandleByIndex(device)
|
||||
info = nvmlDeviceGetMemoryInfo(handle)
|
||||
pynvml.nvmlInit()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
||||
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
return info.used / 1024.0**3
|
||||
except NVMLError:
|
||||
return 0.0
|
||||
|
||||
|
||||
def log_gpu_memory_usage(log, msg, device):
|
||||
cur_device = get_device_type()
|
||||
if torch.backends.mps.is_available():
|
||||
usage, cache, misc = mps_memory_usage_all()
|
||||
elif "npu" in str(cur_device) and is_torch_npu_available():
|
||||
usage, cache, misc = npu_memory_usage_all(device)
|
||||
else:
|
||||
usage, cache, misc = gpu_memory_usage_all(device)
|
||||
extras = []
|
||||
@@ -104,7 +79,6 @@ def log_gpu_memory_usage(log, msg, device):
|
||||
if misc > 0:
|
||||
extras.append(f"+{misc:.03f}GB misc")
|
||||
log.info(
|
||||
f"{str(cur_device)} memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})",
|
||||
stacklevel=2,
|
||||
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
|
||||
)
|
||||
return usage, cache, misc
|
||||
|
||||
@@ -28,7 +28,6 @@ from transformers import (
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
from trl.models import unwrap_model_for_generation
|
||||
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
@@ -47,7 +46,6 @@ from axolotl.utils.distributed import (
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
@@ -380,10 +378,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
||||
for metric in self.cfg.eval_causal_lm_metrics:
|
||||
if metric == "perplexity":
|
||||
max_seq_len = self.cfg.eval_max_new_tokens
|
||||
metrics[metric] = Perplexity(
|
||||
tokenizer=tokenizer,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
metrics[metric] = Perplexity(trainer.model, tokenizer, max_seq_len)
|
||||
else:
|
||||
try:
|
||||
metrics[metric] = evaluate.load(metric)
|
||||
@@ -400,11 +395,8 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
||||
eval_dataloader,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
trainer.model_wrapped.eval()
|
||||
|
||||
device = torch.device(
|
||||
self.cfg.device
|
||||
) # Use this instead of trainer.model_wrapped.device as it may return cpu if fsdp offloaded
|
||||
trainer.model.eval()
|
||||
device = torch.device(self.cfg.device)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
generation_config = GenerationConfig(
|
||||
@@ -441,10 +433,6 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
||||
for k in metric._feature_names() # pylint: disable=protected-access
|
||||
if k in kwargs
|
||||
}
|
||||
|
||||
if isinstance(metric, Perplexity):
|
||||
metric_kwargs["model"] = trainer.model_wrapped
|
||||
|
||||
metric_score = metric.compute(**metric_kwargs)
|
||||
return (
|
||||
metric_score["score"]
|
||||
@@ -480,97 +468,89 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
||||
def predict_with_generate():
|
||||
eval_src, eval_pred, eval_ref = [], [], []
|
||||
|
||||
with unwrap_model_for_generation(
|
||||
trainer.model_wrapped, trainer.accelerator
|
||||
) as unwrapped_model:
|
||||
for batch in tqdm(eval_dataloader, disable=not is_main_process()):
|
||||
batch_labels = batch["labels"].to(device)
|
||||
batch_input_ids = batch["input_ids"].to(device)
|
||||
for batch in tqdm(eval_dataloader):
|
||||
batch_labels = batch["labels"].to(device)
|
||||
batch_input_ids = batch["input_ids"].to(device)
|
||||
|
||||
if "position_ids" in batch:
|
||||
batch_pos_ids = batch["position_ids"].tolist()
|
||||
if "position_ids" in batch:
|
||||
batch_pos_ids = batch["position_ids"].tolist()
|
||||
else:
|
||||
batch_pos_ids = [None] * len(batch["input_ids"])
|
||||
|
||||
prompt_token_ids_list = []
|
||||
completion_token_ids_list = []
|
||||
|
||||
for input_ids_all, labels_all, pos_ids in zip(
|
||||
batch_input_ids,
|
||||
batch_labels,
|
||||
batch_pos_ids,
|
||||
):
|
||||
if pos_ids is None:
|
||||
pos_ranges = [(0, len(input_ids_all) - 1)]
|
||||
else:
|
||||
batch_pos_ids = [None] * len(batch["input_ids"])
|
||||
pos_ranges = find_ranges(pos_ids)
|
||||
|
||||
prompt_token_ids_list = []
|
||||
completion_token_ids_list = []
|
||||
for pos_range in pos_ranges:
|
||||
start, end = pos_range
|
||||
if start == end:
|
||||
continue
|
||||
|
||||
for input_ids_all, labels_all, pos_ids in zip(
|
||||
batch_input_ids,
|
||||
batch_labels,
|
||||
batch_pos_ids,
|
||||
):
|
||||
if pos_ids is None:
|
||||
pos_ranges = [(0, len(input_ids_all) - 1)]
|
||||
else:
|
||||
pos_ranges = find_ranges(pos_ids)
|
||||
input_ids = input_ids_all[start : end + 1]
|
||||
labels = labels_all[start : end + 1]
|
||||
|
||||
for pos_range in pos_ranges:
|
||||
start, end = pos_range
|
||||
if start == end:
|
||||
continue
|
||||
|
||||
input_ids = input_ids_all[start : end + 1]
|
||||
labels = labels_all[start : end + 1]
|
||||
|
||||
tokens_without_loss = labels == IGNORE_INDEX
|
||||
tokens_with_loss = labels != IGNORE_INDEX
|
||||
tokens_exclude_padding = (
|
||||
input_ids != tokenizer.pad_token_id
|
||||
)
|
||||
prompt_token_includes = (
|
||||
tokens_without_loss & tokens_exclude_padding
|
||||
)
|
||||
|
||||
prompt_token_ids = input_ids[prompt_token_includes]
|
||||
prompt_token_ids_list.append(prompt_token_ids)
|
||||
|
||||
completion_token_ids = input_ids[tokens_with_loss]
|
||||
completion_token_ids_list.append(completion_token_ids)
|
||||
|
||||
prompt_texts = tokenizer.batch_decode(
|
||||
prompt_token_ids_list, skip_special_tokens=True
|
||||
)
|
||||
completion_texts = tokenizer.batch_decode(
|
||||
completion_token_ids_list, skip_special_tokens=True
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
prompt_encoding = tokenizer(
|
||||
prompt_texts, padding=True, return_tensors="pt"
|
||||
).to(device)
|
||||
|
||||
predictions = unwrapped_model.generate(
|
||||
**prompt_encoding, generation_config=generation_config
|
||||
tokens_without_loss = labels == IGNORE_INDEX
|
||||
tokens_with_loss = labels != IGNORE_INDEX
|
||||
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
|
||||
prompt_token_includes = (
|
||||
tokens_without_loss & tokens_exclude_padding
|
||||
)
|
||||
|
||||
del prompt_encoding
|
||||
prompt_token_ids = input_ids[prompt_token_includes]
|
||||
prompt_token_ids_list.append(prompt_token_ids)
|
||||
|
||||
prediction_all_tokens = predictions["sequences"].cpu().tolist()
|
||||
prediction_without_prompt_tokens_list = []
|
||||
for prompt_token_ids, prediction_tokens in zip(
|
||||
prompt_token_ids_list, prediction_all_tokens
|
||||
):
|
||||
prediction_without_prompt_tokens = prediction_tokens[
|
||||
len(prompt_token_ids) :
|
||||
]
|
||||
prediction_without_prompt_tokens_list.append(
|
||||
prediction_without_prompt_tokens
|
||||
)
|
||||
completion_token_ids = input_ids[tokens_with_loss]
|
||||
completion_token_ids_list.append(completion_token_ids)
|
||||
|
||||
predicted_texts = tokenizer.batch_decode(
|
||||
prediction_without_prompt_tokens_list,
|
||||
skip_special_tokens=True,
|
||||
prompt_texts = tokenizer.batch_decode(
|
||||
prompt_token_ids_list, skip_special_tokens=True
|
||||
)
|
||||
completion_texts = tokenizer.batch_decode(
|
||||
completion_token_ids_list, skip_special_tokens=True
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
prompt_encoding = tokenizer(
|
||||
prompt_texts, padding=True, return_tensors="pt"
|
||||
).to(self.cfg.device)
|
||||
predictions = trainer.model.generate(
|
||||
**prompt_encoding, generation_config=generation_config
|
||||
)
|
||||
|
||||
eval_src.extend(prompt_texts)
|
||||
eval_pred.extend(predicted_texts)
|
||||
eval_ref.extend(completion_texts)
|
||||
prediction_all_tokens = predictions["sequences"].cpu().tolist()
|
||||
prediction_without_prompt_tokens_list = []
|
||||
for prompt_token_ids, prediction_tokens in zip(
|
||||
prompt_token_ids_list, prediction_all_tokens
|
||||
):
|
||||
prediction_without_prompt_tokens = prediction_tokens[
|
||||
len(prompt_token_ids) :
|
||||
]
|
||||
prediction_without_prompt_tokens_list.append(
|
||||
prediction_without_prompt_tokens
|
||||
)
|
||||
|
||||
predicted_texts = tokenizer.batch_decode(
|
||||
prediction_without_prompt_tokens_list, skip_special_tokens=True
|
||||
)
|
||||
|
||||
eval_src.extend(prompt_texts)
|
||||
eval_pred.extend(predicted_texts)
|
||||
eval_ref.extend(completion_texts)
|
||||
|
||||
return eval_src, eval_pred, eval_ref
|
||||
|
||||
eval_preds = predict_with_generate()
|
||||
trainer.log(evaluate_preds(*eval_preds))
|
||||
if is_main_process():
|
||||
eval_preds = predict_with_generate()
|
||||
trainer.log(evaluate_preds(*eval_preds))
|
||||
|
||||
return control
|
||||
|
||||
|
||||
@@ -8,8 +8,6 @@ from transformers.modeling_outputs import CausalLMOutput
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
|
||||
class Perplexity:
|
||||
"""
|
||||
@@ -19,13 +17,16 @@ class Perplexity:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_seq_len: int,
|
||||
stride: int = 512,
|
||||
) -> None:
|
||||
self.max_seq_len = max_seq_len
|
||||
self.stride = stride
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = model.device
|
||||
self.name = "perplexity"
|
||||
|
||||
def _feature_names(self) -> List[str]:
|
||||
@@ -33,7 +34,6 @@ class Perplexity:
|
||||
|
||||
def compute(
|
||||
self,
|
||||
model: PreTrainedModel,
|
||||
references: Optional[List[str]] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
@@ -41,21 +41,17 @@ class Perplexity:
|
||||
"""
|
||||
assert references is not None, "Missing parameter: references"
|
||||
|
||||
model.eval()
|
||||
|
||||
references_tokenized = self.tokenizer(
|
||||
references, return_tensors="pt", padding=True, truncation=True
|
||||
)
|
||||
input_ids: Tensor = references_tokenized["input_ids"] # type: ignore
|
||||
input_ids = input_ids.to(model.device)
|
||||
input_ids = input_ids.to(self.device)
|
||||
|
||||
sequence_length = input_ids.size(1)
|
||||
|
||||
losses = []
|
||||
prev_end_loc = 0
|
||||
for begin_loc in tqdm(
|
||||
range(0, sequence_length, self.stride), disable=not is_main_process()
|
||||
):
|
||||
for begin_loc in tqdm(range(0, sequence_length, self.stride)):
|
||||
end_loc = min(begin_loc + self.max_seq_len, sequence_length)
|
||||
trg_len = end_loc - prev_end_loc
|
||||
input_ids_slice = input_ids[:, begin_loc:end_loc]
|
||||
@@ -63,7 +59,7 @@ class Perplexity:
|
||||
labels_slice[:, :-trg_len] = -100
|
||||
|
||||
with torch.no_grad():
|
||||
outputs: CausalLMOutput = model(
|
||||
outputs: CausalLMOutput = self.model(
|
||||
input_ids=input_ids_slice, labels=labels_slice
|
||||
)
|
||||
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
"""
|
||||
Collators for multi-modal chat messages and packing
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from PIL import Image
|
||||
from transformers import PreTrainedTokenizerBase, ProcessorMixin
|
||||
@@ -22,7 +20,6 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
processor: ProcessorMixin
|
||||
return_tensors: str = "pt"
|
||||
chat_template: Optional[str] = None
|
||||
chat_template_type: Optional[str] = None
|
||||
packing: bool = False
|
||||
max_images: int = -1
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
@@ -33,190 +30,38 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
raise ValueError("Packing is currently not supported.")
|
||||
|
||||
def torch_call(
|
||||
self, examples: list[Union[list[int], Any, dict[str, Any]]]
|
||||
) -> dict[str, Any]:
|
||||
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
|
||||
) -> Dict[str, Any]:
|
||||
# Handle dict or lists with proper padding and conversion to tensor.
|
||||
|
||||
return self.__class__.process_rows(
|
||||
examples,
|
||||
self.processor,
|
||||
self.chat_template,
|
||||
self.max_images,
|
||||
chat_template_type=self.chat_template_type,
|
||||
examples, self.processor, self.chat_template, self.max_images
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def preprocess(examples: list[dict]) -> list[dict]:
|
||||
"""
|
||||
Preprocess conversation examples to ensure consistent format.
|
||||
Converts different conversation formats to OpenAI format with 'messages'.
|
||||
Supports two formats:
|
||||
1. OpenAI format with 'messages'
|
||||
2. Legacy format with 'conversations'
|
||||
|
||||
Args:
|
||||
examples: list of conversation dictionaries
|
||||
Returns:
|
||||
dict in OpenAI format with 'messages' key
|
||||
|
||||
Raises:
|
||||
ValueError: If the conversation format is not supported
|
||||
"""
|
||||
role_mapping = {
|
||||
"human": "user",
|
||||
"gpt": "assistant",
|
||||
}
|
||||
|
||||
def normalize_role(role: str) -> str:
|
||||
"""Normalize role names to OpenAI format. Default to original role if not found."""
|
||||
return role_mapping.get(role, role)
|
||||
|
||||
def convert_legacy_format(example: dict) -> dict:
|
||||
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
|
||||
messages = [
|
||||
{
|
||||
"role": normalize_role(convo["from"]),
|
||||
"content": convo["value"],
|
||||
}
|
||||
for convo in example["conversations"]
|
||||
]
|
||||
|
||||
# Create new dict without 'conversations' key
|
||||
result = deepcopy(example)
|
||||
result.pop("conversations")
|
||||
return {"messages": messages, **result}
|
||||
|
||||
processed_examples = []
|
||||
for example in examples:
|
||||
# OpenAI format
|
||||
if "messages" in example:
|
||||
processed_examples.append(example)
|
||||
|
||||
# Legacy format
|
||||
elif "conversations" in example:
|
||||
processed_examples.append(convert_legacy_format(example))
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only `messages` and `conversations` message keys are currently supported."
|
||||
)
|
||||
|
||||
return processed_examples
|
||||
|
||||
@staticmethod
|
||||
def process_images(examples, max_images):
|
||||
"""
|
||||
Process images from examples, ensuring consistency in image presence and applying max_images limit.
|
||||
|
||||
Args:
|
||||
examples: List of dictionaries that may contain 'images' key
|
||||
max_images: Maximum number of images to keep per example (0 means no limit)
|
||||
|
||||
Returns:
|
||||
Either None (if no images) or List[Image objects] (if all examples have images)
|
||||
|
||||
Raises:
|
||||
ValueError: If there's a mix of None and non-None images
|
||||
"""
|
||||
|
||||
def get_image(example):
|
||||
if "images" not in example:
|
||||
return None
|
||||
images = example["images"]
|
||||
if isinstance(images, str):
|
||||
return Image.open(images)
|
||||
return images
|
||||
|
||||
images = [get_image(example) for example in examples]
|
||||
|
||||
# Count None and non-None images
|
||||
none_count = sum(1 for img in images if img is None)
|
||||
|
||||
# All images are None
|
||||
if none_count == len(images):
|
||||
return None
|
||||
|
||||
# Mix of None and non-None images
|
||||
if none_count > 0:
|
||||
raise ValueError(
|
||||
"All images should be either None or not None. "
|
||||
"Please provide images for all examples or None."
|
||||
)
|
||||
|
||||
# Apply max_images limit if specified
|
||||
if max_images > 0:
|
||||
images = [
|
||||
(
|
||||
img_batch[:max_images]
|
||||
if isinstance(img_batch, (list, tuple))
|
||||
else img_batch
|
||||
)
|
||||
for img_batch in images
|
||||
]
|
||||
|
||||
return images
|
||||
|
||||
@staticmethod
|
||||
def pixtral_chat_conversion(messages):
|
||||
is_single_message = not isinstance(messages, list)
|
||||
if is_single_message:
|
||||
messages = [messages]
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
if message["role"] == "user":
|
||||
for j, content in enumerate(message["content"]):
|
||||
if "type" in content and content["type"] == "text":
|
||||
messages[i]["content"][j] = {
|
||||
"type": "text",
|
||||
"content": content["text"],
|
||||
}
|
||||
|
||||
if message["role"] == "assistant":
|
||||
messages[i]["content"] = message["content"][0]["text"]
|
||||
|
||||
if is_single_message:
|
||||
return messages[0]
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def process_rows(
|
||||
examples,
|
||||
processor,
|
||||
chat_template,
|
||||
max_images,
|
||||
length_only=False,
|
||||
chat_template_type=None,
|
||||
):
|
||||
def process_rows(examples, processor, chat_template, max_images, length_only=False):
|
||||
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
||||
# see also DataCollatorWithFlattening and DefaultDataCollator
|
||||
|
||||
# *** This is COPIED from the trl example sft_vlm.py code ***
|
||||
# use this as a starting point
|
||||
|
||||
# Preprocess the examples
|
||||
examples = __class__.preprocess(examples)
|
||||
|
||||
# Get the texts and images, and apply the chat template
|
||||
if chat_template_type == "pixtral":
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
__class__.pixtral_chat_conversion(example["messages"]),
|
||||
chat_template=chat_template,
|
||||
tokenize=False,
|
||||
)
|
||||
for example in examples
|
||||
]
|
||||
else:
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
example["messages"], chat_template=chat_template, tokenize=False
|
||||
)
|
||||
for example in examples
|
||||
]
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
example["messages"], chat_template=chat_template, tokenize=False
|
||||
)
|
||||
for example in examples
|
||||
]
|
||||
images = [
|
||||
Image.open(example["images"])
|
||||
if isinstance(example["images"], str)
|
||||
else example["images"]
|
||||
for example in examples
|
||||
]
|
||||
|
||||
images = __class__.process_images(examples, max_images=max_images)
|
||||
if chat_template_type == "llava":
|
||||
# LLava1.5 does not support multiple images
|
||||
images = [image[0] for image in images]
|
||||
if max_images > 0:
|
||||
images = [img_batch[:max_images] for img_batch in images]
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
||||
@@ -225,12 +70,9 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
labels = batch["input_ids"].clone()
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100 #
|
||||
# Ignore the image token index in the loss computation (model specific)
|
||||
if chat_template_type == "qwen2_vl":
|
||||
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
|
||||
else:
|
||||
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||
processor.image_token
|
||||
)
|
||||
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||
processor.image_token
|
||||
)
|
||||
labels[labels == image_token_id] = -100
|
||||
batch["labels"] = labels
|
||||
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
"""Module for working with config dicts"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.integrations.config import merge_input_args
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS
|
||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||
)
|
||||
@@ -31,10 +32,7 @@ def choose_device(cfg):
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
|
||||
if is_torch_npu_available():
|
||||
return f"npu:{cfg.local_rank}"
|
||||
|
||||
raise SystemError("No CUDA/mps/npu device found")
|
||||
raise SystemError("No CUDA/mps device found")
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return "cpu"
|
||||
|
||||
@@ -44,8 +42,6 @@ def choose_device(cfg):
|
||||
else:
|
||||
if cfg.device.startswith("cuda"):
|
||||
cfg.device_map = {"": torch.cuda.current_device()}
|
||||
elif cfg.device.startswith("npu"):
|
||||
cfg.device_map = {"npu": torch.npu.current_device()}
|
||||
else:
|
||||
cfg.device_map = {"": cfg.device}
|
||||
|
||||
@@ -132,7 +128,7 @@ def normalize_config(cfg):
|
||||
|
||||
cfg.is_multimodal = (
|
||||
hasattr(model_config, "model_type")
|
||||
and model_config.model_type in ["llava", "mllama", "qwen2_vl", "qwen2_5_vl"]
|
||||
and model_config.model_type in ["llava", "mllama"]
|
||||
or any(
|
||||
multimodal_name in cfg.base_model.lower()
|
||||
for multimodal_name in [
|
||||
@@ -145,12 +141,7 @@ def normalize_config(cfg):
|
||||
cfg.processor_config = (
|
||||
cfg.processor_config or cfg.base_model_config or cfg.base_model
|
||||
)
|
||||
|
||||
try:
|
||||
model_config = model_config.text_config
|
||||
except AttributeError:
|
||||
# for qwen2_vl
|
||||
model_config = model_config.get_text_config()
|
||||
model_config = model_config.text_config
|
||||
|
||||
cfg.model_config_type = model_config.model_type
|
||||
|
||||
@@ -224,6 +215,11 @@ def normalize_cfg_datasets(cfg):
|
||||
if cfg.chat_template:
|
||||
if cfg.datasets:
|
||||
for idx, ds_cfg in enumerate(cfg.datasets):
|
||||
if ds_cfg.type == "sharegpt" and not ds_cfg.conversation:
|
||||
LOG.info(
|
||||
f"updating dataset {ds_cfg.path} with `conversation: {cfg.chat_template}` to match your chat_template"
|
||||
)
|
||||
cfg.datasets[idx].conversation = cfg.chat_template
|
||||
if (
|
||||
ds_cfg.type in ["orpo.chat_template", "chat_template"]
|
||||
and not ds_cfg.chat_template
|
||||
@@ -235,11 +231,7 @@ def normalize_cfg_datasets(cfg):
|
||||
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
|
||||
|
||||
|
||||
def validate_config(
|
||||
cfg: DictDefault,
|
||||
capabilities: Optional[dict] = None,
|
||||
env_capabilities: Optional[dict] = None,
|
||||
):
|
||||
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
||||
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
|
||||
AxolotlInputConfig = AxolotlInputConfigBase
|
||||
|
||||
@@ -249,35 +241,402 @@ def validate_config(
|
||||
AxolotlInputConfig, # pylint: disable=invalid-name
|
||||
) = merge_input_args()
|
||||
|
||||
if capabilities or env_capabilities:
|
||||
if (capabilities and not env_capabilities) or (
|
||||
env_capabilities and not capabilities
|
||||
):
|
||||
raise ValueError(
|
||||
"Both capabilities and env_capabilities must be provided or not provided."
|
||||
)
|
||||
|
||||
if capabilities:
|
||||
return DictDefault(
|
||||
dict(
|
||||
AxolotlConfigWCapabilities(
|
||||
**cfg.to_dict(),
|
||||
capabilities=capabilities,
|
||||
env_capabilities=env_capabilities,
|
||||
**cfg.to_dict(), capabilities=capabilities
|
||||
).model_dump(exclude_none=True)
|
||||
)
|
||||
)
|
||||
|
||||
return DictDefault(
|
||||
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
|
||||
)
|
||||
|
||||
|
||||
def prepare_plugins(cfg):
|
||||
def legacy_validate_config(cfg):
|
||||
"""
|
||||
Prepare the plugins for the configuration
|
||||
This is a "pre-validation" step that handles the yaml configuration before we have any
|
||||
information about the model architecture
|
||||
"""
|
||||
if is_torch_bf16_gpu_available():
|
||||
if not cfg.bf16 and not cfg.bfloat16:
|
||||
LOG.info("bf16 support detected, but not enabled for this configuration.")
|
||||
else:
|
||||
if (
|
||||
not cfg.merge_lora
|
||||
and not cfg.is_preprocess
|
||||
and (cfg.bf16 is True or cfg.bfloat16 is True)
|
||||
):
|
||||
raise ValueError(
|
||||
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
||||
)
|
||||
if (
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
not (cfg.bf16 or cfg.bfloat16)
|
||||
and (cfg.fp16 or cfg.float16)
|
||||
and not cfg.adapter
|
||||
and not cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
LOG.warning(
|
||||
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
|
||||
)
|
||||
# ValueError: Attempting to unscale FP16 gradients.
|
||||
# OR
|
||||
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
|
||||
if cfg.max_packed_sequence_len:
|
||||
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
||||
|
||||
if cfg.get("plugins"):
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
for plugin_name in cfg["plugins"]:
|
||||
plugin_manager.register(plugin_name)
|
||||
if cfg.sample_packing and cfg.rl:
|
||||
raise ValueError("`sample_packing: true` does not work with RLHF training")
|
||||
|
||||
if cfg.sample_packing and not cfg.pad_to_sequence_len:
|
||||
LOG.warning(
|
||||
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||
)
|
||||
|
||||
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||
raise ValueError(
|
||||
"please set only one of gradient_accumulation_steps or batch_size"
|
||||
)
|
||||
if cfg.batch_size:
|
||||
LOG.warning(
|
||||
"%s\n%s",
|
||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||
)
|
||||
if (
|
||||
cfg.eval_batch_size
|
||||
and cfg.micro_batch_size
|
||||
and cfg.eval_batch_size != cfg.micro_batch_size
|
||||
):
|
||||
LOG.warning(
|
||||
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
||||
)
|
||||
|
||||
if cfg.adapter == "qlora":
|
||||
if cfg.merge_lora:
|
||||
# can't merge qlora if loaded in 8bit or 4bit
|
||||
if cfg.load_in_8bit:
|
||||
raise ValueError("Can't merge qlora if loaded in 8bit")
|
||||
|
||||
if cfg.gptq:
|
||||
raise ValueError("Can't merge qlora if gptq")
|
||||
|
||||
if cfg.load_in_4bit:
|
||||
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||
|
||||
else:
|
||||
if cfg.load_in_8bit:
|
||||
raise ValueError("Can't load qlora in 8bit")
|
||||
|
||||
if cfg.gptq:
|
||||
raise ValueError("Can't load qlora if gptq")
|
||||
|
||||
if not cfg.load_in_4bit:
|
||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||
|
||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with QLoRA")
|
||||
|
||||
loftq = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
|
||||
if not cfg.load_in_8bit and cfg.adapter == "lora" and not loftq:
|
||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||
|
||||
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
||||
raise ValueError("Fused modules are not supported with LoRA")
|
||||
|
||||
if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters:
|
||||
raise ValueError(
|
||||
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
|
||||
)
|
||||
|
||||
if cfg.relora_steps:
|
||||
if cfg.adapter not in ("lora", "qlora"):
|
||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||
|
||||
if cfg.fsdp:
|
||||
raise ValueError("fsdp not supported with ReLoRA")
|
||||
|
||||
if cfg.deepspeed:
|
||||
raise ValueError("deepspeed not supported with ReLoRA")
|
||||
|
||||
if cfg.lr_scheduler == "one_cycle":
|
||||
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
||||
|
||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with ReLoRA")
|
||||
|
||||
if cfg.trust_remote_code:
|
||||
LOG.warning(
|
||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||
)
|
||||
|
||||
if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
|
||||
raise ValueError(
|
||||
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
||||
)
|
||||
|
||||
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
||||
raise ValueError("FSDP is not supported for falcon models")
|
||||
|
||||
if (
|
||||
cfg.base_model and "mpt" in cfg.base_model.lower()
|
||||
) and cfg.gradient_checkpointing:
|
||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||
|
||||
if cfg.flash_optimum is True:
|
||||
if cfg.adapter:
|
||||
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
|
||||
if cfg.fp16 or cfg.bf16:
|
||||
raise ValueError("AMP is not supported with BetterTransformer")
|
||||
if cfg.float16 is not True and cfg.bfloat16 is not True:
|
||||
LOG.warning(
|
||||
"You should probably set bfloat16 or float16 to true to "
|
||||
"load the model in float16 for BetterTransformers"
|
||||
)
|
||||
if int(torch.__version__.split(".", maxsplit=1)[0]) < 2:
|
||||
LOG.warning("torch>=2.0.0 required")
|
||||
raise ValueError(
|
||||
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
|
||||
)
|
||||
|
||||
if cfg.pretraining_dataset and cfg.group_by_length:
|
||||
LOG.warning(
|
||||
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
||||
)
|
||||
if cfg.pretraining_dataset and not cfg.max_steps:
|
||||
raise ValueError(
|
||||
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
|
||||
)
|
||||
|
||||
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
||||
not cfg.optimizer or "adamw" not in cfg.optimizer
|
||||
):
|
||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||
|
||||
if cfg.push_to_hub_model_id:
|
||||
raise ValueError(
|
||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||
)
|
||||
|
||||
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
|
||||
LOG.warning(
|
||||
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
|
||||
)
|
||||
|
||||
if cfg.gptq and cfg.revision_of_model:
|
||||
raise ValueError(
|
||||
"revision_of_model is not supported for GPTQ models. "
|
||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||
+ "point to its path, and remove revision_of_model from the config."
|
||||
)
|
||||
|
||||
# if cfg.sample_packing and cfg.sdp_attention:
|
||||
# # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||
# raise ValueError(
|
||||
# "sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||
# )
|
||||
|
||||
if cfg.sample_packing and cfg.xformers_attention:
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
|
||||
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
|
||||
LOG.warning(
|
||||
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
|
||||
"This may work on H100s."
|
||||
)
|
||||
|
||||
if cfg.early_stopping_patience:
|
||||
if not cfg.save_steps or not cfg.eval_steps:
|
||||
raise ValueError(
|
||||
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
|
||||
)
|
||||
if cfg.save_steps % cfg.eval_steps != 0:
|
||||
raise ValueError(
|
||||
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
|
||||
)
|
||||
|
||||
if cfg.datasets:
|
||||
for idx, ds_cfg in enumerate(cfg.datasets):
|
||||
if not ds_cfg.type:
|
||||
continue
|
||||
if ds_cfg.type == "sharegpt:chat":
|
||||
LOG.warning(
|
||||
PendingDeprecationWarning(
|
||||
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||
)
|
||||
)
|
||||
cfg.datasets[idx].type = "sharegpt"
|
||||
if "sharegpt_simple" in ds_cfg.type:
|
||||
LOG.warning(
|
||||
PendingDeprecationWarning(
|
||||
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||
)
|
||||
)
|
||||
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
||||
"sharegpt_simple", "sharegpt"
|
||||
)
|
||||
|
||||
if cfg.saves_per_epoch and cfg.save_steps:
|
||||
raise ValueError(
|
||||
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
||||
)
|
||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||
)
|
||||
if cfg.evals_per_epoch and cfg.eval_steps:
|
||||
raise ValueError(
|
||||
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if (
|
||||
cfg.evals_per_epoch
|
||||
and cfg.evaluation_strategy
|
||||
and cfg.evaluation_strategy != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
if (
|
||||
cfg.evaluation_strategy
|
||||
and cfg.eval_steps
|
||||
and cfg.evaluation_strategy != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.val_set_size == 0
|
||||
and (cfg.eval_steps or cfg.evaluation_strategy)
|
||||
and not cfg.test_datasets
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.sample_packing
|
||||
and cfg.eval_table_size
|
||||
and cfg.eval_sample_packing is not False
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
|
||||
)
|
||||
|
||||
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
|
||||
raise ValueError(
|
||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
|
||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||
)
|
||||
|
||||
if cfg.rope_scaling:
|
||||
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
||||
|
||||
if cfg.wandb_run_id and not cfg.wandb_name:
|
||||
cfg.wandb_name = cfg.wandb_run_id
|
||||
|
||||
LOG.warning(
|
||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||
)
|
||||
|
||||
if cfg.noisy_embedding_alpha is not None:
|
||||
# Deprecated, use neftune_noise_alpha
|
||||
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
||||
if cfg.neftune_noise_alpha is None:
|
||||
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
|
||||
else:
|
||||
# User is providing both; bail and have them sort out their settings
|
||||
raise ValueError(
|
||||
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
|
||||
)
|
||||
|
||||
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
||||
raise ValueError("neftune_noise_alpha must be > 0.0")
|
||||
|
||||
if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
|
||||
raise ValueError(
|
||||
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.unfrozen_parameters
|
||||
and cfg.gradient_checkpointing_kwargs
|
||||
and cfg.gradient_checkpointing_kwargs.use_reentrant is True
|
||||
):
|
||||
# https://github.com/huggingface/transformers/issues/21381
|
||||
raise ValueError(
|
||||
"`use_reentrant` must be false when used with partially frozen model."
|
||||
)
|
||||
|
||||
if cfg.deepspeed and Path(cfg.deepspeed).is_file():
|
||||
with open(cfg.deepspeed, encoding="utf-8") as file:
|
||||
contents = file.read()
|
||||
deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
|
||||
if cfg.flash_attention:
|
||||
if (
|
||||
deepspeed_cfg.zero_optimization
|
||||
and deepspeed_cfg.zero_optimization.stage == 3
|
||||
):
|
||||
if not (
|
||||
(
|
||||
deepspeed_cfg.bf16
|
||||
and deepspeed_cfg.bf16.enabled # pylint: disable=no-member
|
||||
is True
|
||||
)
|
||||
or (
|
||||
deepspeed_cfg.fp16
|
||||
and deepspeed_cfg.fp16.enabled # pylint: disable=no-member
|
||||
is True
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
|
||||
)
|
||||
if "8bit" in cfg.optimizer and deepspeed_cfg.optimizer:
|
||||
LOG.warning(
|
||||
f"conflicting optimizer: {cfg.optimizer} used alongside deepspeed optimizer."
|
||||
)
|
||||
|
||||
if cfg.test_datasets and cfg.val_set_size:
|
||||
raise ValueError(
|
||||
"non-zero val_set_size should not be used with test_datasets configuration"
|
||||
)
|
||||
|
||||
if cfg.fsdp and "bnb" in cfg.optimizer:
|
||||
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
|
||||
|
||||
if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
|
||||
raise ValueError(
|
||||
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
|
||||
)
|
||||
|
||||
if cfg.eval_causal_lm_metrics:
|
||||
if not isinstance(cfg.eval_causal_lm_metrics, list):
|
||||
raise ValueError("eval_causal_lm_metrics must be a list")
|
||||
# only ["sacrebleu", "comet", "ter", "chrf"] supported
|
||||
if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS:
|
||||
raise ValueError(
|
||||
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
|
||||
)
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
# no 8bit adaAmw w bf16
|
||||
|
||||
# GPT-NeoX
|
||||
# evals broken when extending context len
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
|
||||
# attention_mask = causal_mask + attention_mask
|
||||
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
|
||||
|
||||
@@ -7,9 +7,9 @@ Module for pydantic models for configuration
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from importlib.metadata import version
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from packaging import version
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
@@ -20,9 +20,8 @@ from pydantic import (
|
||||
)
|
||||
from transformers import SchedulerType
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
|
||||
from axolotl.utils.config.models.internals import GPUCapabilities
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
||||
|
||||
@@ -51,7 +50,6 @@ class ChatTemplate(str, Enum):
|
||||
cohere = "cohere" # pylint: disable=invalid-name
|
||||
llama3 = "llama3" # pylint: disable=invalid-name
|
||||
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
||||
llava = "llava" # pylint: disable=invalid-name
|
||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||
phi_35 = "phi_35" # pylint: disable=invalid-name
|
||||
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
||||
@@ -60,9 +58,6 @@ class ChatTemplate(str, Enum):
|
||||
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||
exaone = "exaone" # pylint: disable=invalid-name
|
||||
metharme = "metharme" # pylint: disable=invalid-name
|
||||
pixtral = "pixtral" # pylint: disable=invalid-name
|
||||
qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class DeprecatedParameters(BaseModel):
|
||||
@@ -72,7 +67,6 @@ class DeprecatedParameters(BaseModel):
|
||||
rope_scaling: Optional[Any] = None
|
||||
noisy_embedding_alpha: Optional[float] = None
|
||||
dpo_beta: Optional[float] = None
|
||||
evaluation_strategy: Optional[str] = None
|
||||
|
||||
@field_validator("max_packed_sequence_len")
|
||||
@classmethod
|
||||
@@ -104,13 +98,6 @@ class DeprecatedParameters(BaseModel):
|
||||
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
|
||||
return dpo_beta
|
||||
|
||||
@field_validator("evaluation_strategy")
|
||||
@classmethod
|
||||
def validate_evaluation_strategy(cls, evaluation_strategy):
|
||||
if evaluation_strategy is not None:
|
||||
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
|
||||
return evaluation_strategy
|
||||
|
||||
|
||||
class RemappedParameters(BaseModel):
|
||||
"""parameters that have been remapped to other names"""
|
||||
@@ -254,10 +241,8 @@ class KTODataset(BaseModel):
|
||||
class LoftQConfig(BaseModel):
|
||||
"""LoftQ configuration subset"""
|
||||
|
||||
loftq_bits: int = Field(
|
||||
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
|
||||
)
|
||||
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
|
||||
loftq_bits: int = Field(default=4, metadata={"help": "Quantization bits for LoftQ"})
|
||||
# loftq_iter: int = Field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
|
||||
|
||||
|
||||
class PeftConfig(BaseModel):
|
||||
@@ -300,8 +285,8 @@ class LoraConfig(BaseModel):
|
||||
|
||||
qlora_sharded_model_loading: Optional[bool] = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
|
||||
metadata={
|
||||
"help": "load qlora model in sharded format for FSDP using answer.ai technique."
|
||||
},
|
||||
)
|
||||
lora_on_cpu: Optional[bool] = None
|
||||
@@ -310,15 +295,13 @@ class LoraConfig(BaseModel):
|
||||
|
||||
loraplus_lr_ratio: Optional[float] = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
|
||||
metadata={
|
||||
"help": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
|
||||
},
|
||||
)
|
||||
loraplus_lr_embedding: Optional[float] = Field(
|
||||
default=1e-6,
|
||||
json_schema_extra={
|
||||
"description": "loraplus learning rate for lora embedding layers."
|
||||
},
|
||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||
)
|
||||
|
||||
merge_lora: Optional[bool] = None
|
||||
@@ -326,13 +309,11 @@ class LoraConfig(BaseModel):
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_adapter(cls, data):
|
||||
if (
|
||||
not data.get("adapter")
|
||||
and not data.get("inference")
|
||||
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
|
||||
if not data.get("adapter") and (
|
||||
data.get("load_in_8bit") or data.get("load_in_4bit")
|
||||
):
|
||||
raise ValueError(
|
||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
|
||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
|
||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||
)
|
||||
return data
|
||||
@@ -390,10 +371,10 @@ class ModelInputConfig(BaseModel):
|
||||
tokenizer_use_fast: Optional[bool] = None
|
||||
tokenizer_legacy: Optional[bool] = None
|
||||
tokenizer_type: Optional[str] = Field(
|
||||
default=None, json_schema_extra={"description": "transformers tokenizer class"}
|
||||
default=None, metadata={"help": "transformers tokenizer class"}
|
||||
)
|
||||
processor_type: Optional[str] = Field(
|
||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||
default=None, metadata={"help": "transformers processor class"}
|
||||
)
|
||||
trust_remote_code: Optional[bool] = None
|
||||
|
||||
@@ -415,18 +396,18 @@ class HyperparametersConfig(BaseModel):
|
||||
gradient_accumulation_steps: Optional[int] = Field(default=1)
|
||||
micro_batch_size: Optional[int] = Field(
|
||||
default=1,
|
||||
json_schema_extra={"description": "per gpu micro batch size for training"},
|
||||
metadata={"help": "per gpu micro batch size for training"},
|
||||
)
|
||||
batch_size: Optional[int] = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Total batch size, we do not recommended setting this manually"
|
||||
metadata={
|
||||
"help": "Total batch size, we do not recommended setting this manually"
|
||||
},
|
||||
)
|
||||
eval_batch_size: Optional[int] = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
|
||||
metadata={
|
||||
"help": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
|
||||
},
|
||||
)
|
||||
|
||||
@@ -436,8 +417,6 @@ class HyperparametersConfig(BaseModel):
|
||||
group_by_length: Optional[bool] = None
|
||||
|
||||
learning_rate: Union[str, float]
|
||||
embedding_lr: Optional[float] = None
|
||||
embedding_lr_scale: Optional[float] = None
|
||||
weight_decay: Optional[float] = 0.0
|
||||
optimizer: Optional[
|
||||
Union[
|
||||
@@ -448,18 +427,16 @@ class HyperparametersConfig(BaseModel):
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
],
|
||||
]
|
||||
] = OptimizerNames.ADAMW_HF.value
|
||||
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
||||
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
|
||||
)
|
||||
optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The target modules to optimize, i.e. the module names that you would like to train."
|
||||
metadata={
|
||||
"help": "The target modules to optimize, i.e. the module names that you would like to train."
|
||||
},
|
||||
)
|
||||
torchdistx_path: Optional[str] = None
|
||||
@@ -519,15 +496,15 @@ class LISAConfig(BaseModel):
|
||||
|
||||
lisa_n_layers: Optional[int] = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "the number of activate layers in LISA"},
|
||||
metadata={"help": "the number of activate layers in LISA"},
|
||||
)
|
||||
lisa_step_interval: Optional[int] = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "how often to switch layers in LISA"},
|
||||
metadata={"help": "how often to switch layers in LISA"},
|
||||
)
|
||||
lisa_layers_attribute: Optional[str] = Field(
|
||||
default="model.layers",
|
||||
json_schema_extra={"description": "path under the model to access the layers"},
|
||||
metadata={"help": "path under the model to access the layers"},
|
||||
)
|
||||
|
||||
|
||||
@@ -611,9 +588,6 @@ class AxolotlInputConfig(
|
||||
|
||||
rl: Optional[RLType] = None
|
||||
reward_model: Optional[bool] = None
|
||||
dpo_use_weighting: Optional[
|
||||
bool
|
||||
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||
|
||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||
@@ -626,11 +600,9 @@ class AxolotlInputConfig(
|
||||
pretraining_dataset: Optional[ # type: ignore
|
||||
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
|
||||
] = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
||||
default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
|
||||
)
|
||||
dataset_processes: Optional[int] = Field(default=os.cpu_count())
|
||||
dataset_exact_deduplication: Optional[bool] = None
|
||||
dataset_keep_in_memory: Optional[bool] = None
|
||||
dataloader_pin_memory: Optional[bool] = None
|
||||
dataloader_num_workers: Optional[int] = None
|
||||
@@ -688,8 +660,7 @@ class AxolotlInputConfig(
|
||||
sequence_len: int = Field(default=512)
|
||||
min_sample_len: Optional[int] = None
|
||||
max_prompt_len: int = Field(
|
||||
default=512,
|
||||
json_schema_extra={"description": "maximum prompt length for RL training"},
|
||||
default=512, metadata={"help": "maximum prompt length for RL training"}
|
||||
)
|
||||
sample_packing: Optional[bool] = None
|
||||
sample_packing_group_size: Optional[int] = 100_000
|
||||
@@ -708,8 +679,8 @@ class AxolotlInputConfig(
|
||||
pretrain_multipack_buffer_size: Optional[int] = 10_000
|
||||
pretrain_multipack_attn: Optional[bool] = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"description": "whether to prevent cross attention for packed sequences during pretraining",
|
||||
metadata={
|
||||
"help": "whether to prevent cross attention for packed sequences during pretraining",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -755,7 +726,7 @@ class AxolotlInputConfig(
|
||||
warmup_ratio: Optional[float] = None
|
||||
eval_steps: Optional[Union[int, float]] = None
|
||||
evals_per_epoch: Optional[Union[int]] = None
|
||||
eval_strategy: Optional[str] = None
|
||||
evaluation_strategy: Optional[str] = None
|
||||
save_steps: Optional[Union[int, float]] = None
|
||||
saves_per_epoch: Optional[int] = None
|
||||
save_strategy: Optional[str] = None
|
||||
@@ -807,25 +778,28 @@ class AxolotlInputConfig(
|
||||
is_mistral_derived_model: Optional[bool] = Field(default=None)
|
||||
is_qwen_derived_model: Optional[bool] = Field(default=None)
|
||||
|
||||
plugins: Optional[List[str]] = Field(default=None)
|
||||
|
||||
@field_validator("datasets", mode="before")
|
||||
@classmethod
|
||||
def deprecate_sharegpt_datasets(cls, datasets):
|
||||
for _, ds_cfg in enumerate(datasets):
|
||||
if not ds_cfg.get("type"):
|
||||
def fix_sharegpt_datasets(cls, datasets):
|
||||
for idx, ds_cfg in enumerate(datasets):
|
||||
if not ds_cfg["type"]:
|
||||
continue
|
||||
|
||||
ds_type = ds_cfg["type"]
|
||||
# skip if it's a dict (for custom user instruction prompt)
|
||||
if isinstance(ds_type, dict):
|
||||
continue
|
||||
|
||||
if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
|
||||
raise ValueError(
|
||||
"`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
|
||||
if ds_cfg["type"] == "sharegpt:chat":
|
||||
LOG.warning(
|
||||
PendingDeprecationWarning(
|
||||
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||
)
|
||||
)
|
||||
datasets[idx]["type"] = "sharegpt"
|
||||
if "sharegpt_simple" in ds_cfg["type"]:
|
||||
LOG.warning(
|
||||
PendingDeprecationWarning(
|
||||
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||
)
|
||||
)
|
||||
datasets[idx]["type"] = datasets[idx]["type"].replace(
|
||||
"sharegpt_simple", "sharegpt"
|
||||
)
|
||||
|
||||
return datasets
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -1057,21 +1031,21 @@ class AxolotlInputConfig(
|
||||
@classmethod
|
||||
def check_evals(cls, data):
|
||||
if (
|
||||
data.get("eval_strategy")
|
||||
data.get("evaluation_strategy")
|
||||
and data.get("eval_steps")
|
||||
and data.get("eval_strategy") != "steps"
|
||||
and data.get("evaluation_strategy") != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps."
|
||||
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||
)
|
||||
|
||||
if (
|
||||
data.get("val_set_size") == 0
|
||||
and (data.get("eval_steps") or data.get("eval_strategy"))
|
||||
and (data.get("eval_steps") or data.get("evaluation_strategy"))
|
||||
and not data.get("test_datasets")
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_steps and eval_strategy are not supported with val_set_size == 0"
|
||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||
)
|
||||
if data.get("evals_per_epoch") and data.get("eval_steps"):
|
||||
raise ValueError(
|
||||
@@ -1079,11 +1053,11 @@ class AxolotlInputConfig(
|
||||
)
|
||||
if (
|
||||
data.get("evals_per_epoch")
|
||||
and data.get("eval_strategy")
|
||||
and data.get("eval_strategy") != "steps"
|
||||
and data.get("evaluation_strategy")
|
||||
and data.get("evaluation_strategy") != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
|
||||
if data.get("do_bench_eval") and not (
|
||||
@@ -1315,26 +1289,6 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def warn_qlora_zero3_w_use_reentrant(cls, data):
|
||||
if (
|
||||
data.get("adapter") == "qlora"
|
||||
and data.get("gradient_checkpointing_kwargs", {})
|
||||
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
|
||||
is False
|
||||
and data.get("deepspeed", "") is not None
|
||||
and "zero3" in data.get("deepspeed", "")
|
||||
):
|
||||
# may result in:
|
||||
# torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint:
|
||||
# Recomputed values for the following tensors have different metadata
|
||||
# than during the forward pass.
|
||||
LOG.warning(
|
||||
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_val_w_test_datasets(cls, data):
|
||||
@@ -1344,19 +1298,6 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_eval_strategy(cls, data):
|
||||
if (
|
||||
data.get("evaluation_strategy") is not None
|
||||
and data.get("eval_strategy") is None
|
||||
):
|
||||
LOG.info(
|
||||
"explicitly setting `eval_strategy` from the `evaluation_strategy`"
|
||||
)
|
||||
data["eval_strategy"] = data.get("evaluation_strategy")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_offload_w_8bit_optimizer(cls, data):
|
||||
@@ -1435,6 +1376,21 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_unsloth_xformers_version(cls, data):
|
||||
if (
|
||||
data.get("unsloth_lora_mlp")
|
||||
or data.get("unsloth_lora_qkv")
|
||||
or data.get("unsloth_lora_o")
|
||||
):
|
||||
xformers_version = version("xformers")
|
||||
if xformers_version == "0.0.27":
|
||||
raise ValueError(
|
||||
"xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_torch_compile_deepspeed(cls, data):
|
||||
@@ -1444,46 +1400,11 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_npu_config(cls, data):
|
||||
if is_torch_npu_available():
|
||||
# check attention config
|
||||
attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
|
||||
for attn in attn_list:
|
||||
if data.get(attn):
|
||||
raise NotImplementedError(
|
||||
f"{attn} is currently not supported in Ascend npu, please disable this configuration."
|
||||
)
|
||||
|
||||
# check quant config
|
||||
if data.get("optimizer") is not None and "bit" in data.get("optimizer"):
|
||||
optimizer = data.get("optimizer")
|
||||
raise NotImplementedError(
|
||||
f"{optimizer} is currently not supported in Ascend npu, choose another one please."
|
||||
)
|
||||
|
||||
quant_list = ["load_in_8bit", "load_in_4bit"]
|
||||
for quant in quant_list:
|
||||
if data.get(quant):
|
||||
raise NotImplementedError(
|
||||
f"Quantification is currently not supported in Ascend npu, please disable {quant}."
|
||||
)
|
||||
|
||||
# check dtype config
|
||||
if data.get("tf32"):
|
||||
raise NotImplementedError(
|
||||
"tf32 dtype is currently not supported in Ascend npu, please disable this configuration"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||
|
||||
capabilities: GPUCapabilities
|
||||
env_capabilities: EnvCapabilities
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_bf16(self):
|
||||
@@ -1558,21 +1479,3 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_adopt_torch_version(cls, data):
|
||||
if (data.get("optimizer") is not None) and ("adopt" in data.get("optimizer")):
|
||||
env_capabilities = data.get("env_capabilities", {})
|
||||
torch_version = env_capabilities.get("torch_version")
|
||||
|
||||
if torch_version is None:
|
||||
import torch
|
||||
|
||||
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||
|
||||
if version.parse(torch_version) < version.parse("2.5.1"):
|
||||
raise ValueError(
|
||||
"ADOPT optimizer is incompatible with torch version < 2.5.1"
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -12,9 +12,3 @@ class GPUCapabilities(BaseModel):
|
||||
n_gpu: int = Field(default=1)
|
||||
n_node: int = Field(default=1)
|
||||
compute_capability: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class EnvCapabilities(BaseModel):
|
||||
"""model to manage the environment capabilities statically"""
|
||||
|
||||
torch_version: Optional[str] = Field(default=None)
|
||||
|
||||
@@ -13,7 +13,7 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||
from axolotl.prompt_strategies.kto import load as load_kto
|
||||
from axolotl.prompt_strategies.orpo import load as load_orpo
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
||||
from axolotl.utils.data.utils import md5
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process, zero_first
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
@@ -64,57 +64,15 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
|
||||
|
||||
if isinstance(data_set, DatasetDict):
|
||||
data_set = data_set["train"]
|
||||
|
||||
data_set = data_set.map(
|
||||
ds_transform_fn,
|
||||
desc="Mapping RL Dataset",
|
||||
)
|
||||
|
||||
if isinstance(data_set, DatasetDict):
|
||||
data_set = data_set["train"]
|
||||
return data_set
|
||||
|
||||
|
||||
def drop_long_rl_seq(
|
||||
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
|
||||
):
|
||||
if rl in ("dpo", "ipo", "orpo", "simpo"):
|
||||
if not (
|
||||
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
|
||||
):
|
||||
raise ValueError(
|
||||
"Prompt, chosen and rejected keys are required for DPO/ORPO datasets"
|
||||
)
|
||||
|
||||
prompt = sample["prompt"]
|
||||
chosen = sample["chosen"]
|
||||
rejected = sample["rejected"]
|
||||
|
||||
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
|
||||
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
|
||||
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
||||
|
||||
return (len_prompt + len_chosen) <= sequence_len and (
|
||||
len_prompt + len_rejected
|
||||
) <= sequence_len
|
||||
|
||||
if rl == "kto":
|
||||
if not (sample.get("prompt") and sample.get("completion")):
|
||||
raise ValueError("Prompt and completion keys are required for KTO datasets")
|
||||
|
||||
prompt = sample["prompt"]
|
||||
completion = sample["completion"]
|
||||
|
||||
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
|
||||
len_completion = len(
|
||||
tokenizer(completion, add_special_tokens=False)["input_ids"]
|
||||
)
|
||||
|
||||
return (len_prompt + len_completion) <= sequence_len
|
||||
|
||||
raise ValueError("Unknown RL type")
|
||||
|
||||
|
||||
def load_prepare_dpo_datasets(cfg):
|
||||
def load_split(dataset_cfgs, _cfg):
|
||||
split_datasets: List[Any] = []
|
||||
@@ -136,7 +94,7 @@ def load_prepare_dpo_datasets(cfg):
|
||||
)
|
||||
split_datasets.insert(i, ds)
|
||||
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
tokenizer = None
|
||||
|
||||
for i, data_set in enumerate(split_datasets):
|
||||
_type = dataset_cfgs[i]["type"]
|
||||
@@ -163,28 +121,7 @@ def load_prepare_dpo_datasets(cfg):
|
||||
# "prompt", "chosen" and "rejected" already preprocessed
|
||||
split_datasets[i] = data_set
|
||||
|
||||
drop_long = partial(
|
||||
drop_long_rl_seq,
|
||||
rl=_cfg.rl,
|
||||
tokenizer=tokenizer,
|
||||
sequence_len=cfg.sequence_len,
|
||||
)
|
||||
|
||||
prior_len = len(split_datasets[i])
|
||||
split_datasets[i] = split_datasets[i].filter(
|
||||
drop_long,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
)
|
||||
dropped = prior_len - len(split_datasets[i])
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
|
||||
|
||||
combined_datasets = concatenate_datasets(split_datasets)
|
||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
||||
|
||||
return combined_datasets
|
||||
return concatenate_datasets(split_datasets)
|
||||
|
||||
with zero_first(is_main_process()):
|
||||
train_is_preprocessed = False
|
||||
@@ -208,9 +145,4 @@ def load_prepare_dpo_datasets(cfg):
|
||||
if eval_dataset and not eval_is_preprocessed:
|
||||
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
|
||||
|
||||
if cfg.dataset_exact_deduplication:
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=train_dataset, eval_dataset=eval_dataset
|
||||
)
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
@@ -44,7 +44,7 @@ from axolotl.prompters import (
|
||||
UnsupportedPrompter,
|
||||
)
|
||||
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
||||
from axolotl.utils.data.utils import md5
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_local_main_process, zero_first
|
||||
from axolotl.utils.trainer import (
|
||||
@@ -136,9 +136,8 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
eval_dataset = None
|
||||
if cfg.dataset_exact_deduplication:
|
||||
LOG.info("Deduplication not available for pretrained datasets")
|
||||
return train_dataset, eval_dataset, cfg.max_steps, prompters
|
||||
|
||||
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||
if total_eval_steps == 0:
|
||||
@@ -261,7 +260,6 @@ def load_tokenized_prepared_datasets(
|
||||
for config_dataset in for_d_in_datasets(cfg_datasets):
|
||||
ds: Optional[Union[Dataset, DatasetDict]] = None
|
||||
ds_from_hub = False
|
||||
ds_trust_remote_code = config_dataset.trust_remote_code
|
||||
try:
|
||||
# this is just a basic check to see if the path is a
|
||||
# valid HF dataset that's loadable
|
||||
@@ -271,7 +269,6 @@ def load_tokenized_prepared_datasets(
|
||||
streaming=True,
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
trust_remote_code=ds_trust_remote_code,
|
||||
)
|
||||
ds_from_hub = True
|
||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
||||
@@ -351,15 +348,7 @@ def load_tokenized_prepared_datasets(
|
||||
split=None,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
ds = load_from_disk(config_dataset.path)
|
||||
except FileNotFoundError:
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
ds = load_from_disk(config_dataset.path)
|
||||
elif local_path.is_file():
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
|
||||
@@ -377,7 +366,7 @@ def load_tokenized_prepared_datasets(
|
||||
elif ds_from_hub:
|
||||
load_ds_kwargs = {}
|
||||
if config_dataset.split:
|
||||
load_ds_kwargs["split"] = config_dataset.split
|
||||
load_ds_kwargs = {"split": config_dataset.split}
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
@@ -385,7 +374,6 @@ def load_tokenized_prepared_datasets(
|
||||
data_files=config_dataset.data_files,
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif ds_from_cloud and remote_file_system:
|
||||
@@ -403,7 +391,6 @@ def load_tokenized_prepared_datasets(
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
)
|
||||
elif config_dataset.path.startswith("https://"):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
@@ -414,7 +401,6 @@ def load_tokenized_prepared_datasets(
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
)
|
||||
else:
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
@@ -585,8 +571,7 @@ def load_prepare_datasets(
|
||||
)
|
||||
train_fingerprint = md5(to_hash_train)
|
||||
test_fingerprint = md5(to_hash_test)
|
||||
if cfg.dataset_exact_deduplication:
|
||||
_, _, dataset = deduplicate_and_log_datasets(dataset=dataset)
|
||||
|
||||
dataset = dataset.train_test_split(
|
||||
test_size=val_set_size,
|
||||
shuffle=False,
|
||||
@@ -598,17 +583,12 @@ def load_prepare_datasets(
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
elif split == "test":
|
||||
if cfg.dataset_exact_deduplication:
|
||||
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
|
||||
else:
|
||||
eval_dataset = dataset
|
||||
train_dataset = None
|
||||
eval_dataset = dataset
|
||||
else:
|
||||
if cfg.dataset_exact_deduplication:
|
||||
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
|
||||
else:
|
||||
train_dataset = dataset
|
||||
train_dataset = dataset
|
||||
eval_dataset = None
|
||||
|
||||
return train_dataset, eval_dataset, prompters
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
"""data handling helpers"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
@@ -13,96 +8,3 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
|
||||
except TypeError:
|
||||
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
||||
|
||||
|
||||
def sha256(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
return hashlib.sha256(to_hash.encode(encoding)).hexdigest()
|
||||
|
||||
|
||||
def deduplicate_dataset(
|
||||
dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None
|
||||
) -> Dataset:
|
||||
unique_indices = []
|
||||
|
||||
for idx, row in enumerate(dataset):
|
||||
row_hash = sha256(str(row)) # Using SHA256 for collision resistance.
|
||||
if row_hash not in seen_hashes:
|
||||
seen_hashes[row_hash] = [idx]
|
||||
unique_indices.append(idx)
|
||||
else:
|
||||
# Check for collision by looking up the original dataset indices
|
||||
original_indices = seen_hashes[row_hash]
|
||||
is_duplicate = False
|
||||
for original_idx in original_indices:
|
||||
if (
|
||||
not idx == original_idx
|
||||
and original_idx < len(dataset)
|
||||
and str(dataset[original_idx]) == str(row)
|
||||
):
|
||||
is_duplicate = True
|
||||
break
|
||||
# Check in the other dataset if provided
|
||||
if other_dataset is not None:
|
||||
if original_idx < len(other_dataset) and str(
|
||||
other_dataset[original_idx]
|
||||
) == str(row):
|
||||
is_duplicate = True
|
||||
break
|
||||
if not is_duplicate:
|
||||
seen_hashes[row_hash].append(idx)
|
||||
unique_indices.append(idx)
|
||||
continue
|
||||
return dataset.select(unique_indices)
|
||||
|
||||
|
||||
def deduplicate_and_log_datasets(
|
||||
*,
|
||||
train_dataset: Dataset = None,
|
||||
eval_dataset: Dataset = None,
|
||||
dataset: Dataset = None,
|
||||
) -> tuple[Dataset, Dataset, Dataset]:
|
||||
"""
|
||||
Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes.
|
||||
|
||||
Returns:
|
||||
tuple: Deduplicated train, eval, and additional datasets.
|
||||
"""
|
||||
seen_hashes: dict[str, list[int]] = {}
|
||||
|
||||
# Handle cases where datasets are None
|
||||
if train_dataset is not None:
|
||||
LOG.info(
|
||||
f"Starting deduplication for train dataset. Original size: {len(train_dataset)}"
|
||||
)
|
||||
train_dataset = deduplicate_dataset(
|
||||
dataset=train_dataset, seen_hashes=seen_hashes
|
||||
)
|
||||
LOG.info(
|
||||
f"Deduplication complete for train dataset. New size: {len(train_dataset)}"
|
||||
)
|
||||
else:
|
||||
LOG.info("Train dataset is None. Skipping deduplication.")
|
||||
|
||||
if eval_dataset is not None:
|
||||
LOG.info(
|
||||
f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}"
|
||||
)
|
||||
eval_dataset = deduplicate_dataset(
|
||||
dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset
|
||||
)
|
||||
LOG.info(
|
||||
f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}"
|
||||
)
|
||||
else:
|
||||
LOG.info("Eval dataset is None. Skipping deduplication.")
|
||||
|
||||
if dataset is not None and (eval_dataset is None and train_dataset is None):
|
||||
LOG.info(
|
||||
f"Starting deduplication for combined dataset. Original size: {len(dataset)}"
|
||||
)
|
||||
dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes)
|
||||
LOG.info(
|
||||
f"Deduplication complete for combined dataset. New size: {len(dataset)}"
|
||||
)
|
||||
|
||||
return train_dataset, eval_dataset, dataset
|
||||
|
||||
@@ -9,44 +9,10 @@ from datetime import timedelta
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate import PartialState
|
||||
from transformers.utils.import_utils import (
|
||||
is_torch_cuda_available,
|
||||
is_torch_mps_available,
|
||||
is_torch_npu_available,
|
||||
)
|
||||
|
||||
distributed_state = None # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def get_device_type():
|
||||
device = torch.device("cpu")
|
||||
if is_torch_cuda_available():
|
||||
device = torch.device("cuda")
|
||||
elif is_torch_mps_available():
|
||||
device = torch.device("mps")
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device("npu")
|
||||
return device
|
||||
|
||||
|
||||
def get_device_count():
|
||||
cur_device = get_device_type()
|
||||
if "cuda" in str(cur_device):
|
||||
return torch.cuda.device_count()
|
||||
if "npu" in str(cur_device):
|
||||
return torch.npu.device_count()
|
||||
return 1
|
||||
|
||||
|
||||
def get_current_device():
|
||||
cur_device = get_device_type()
|
||||
if "cuda" in str(cur_device):
|
||||
return torch.cuda.current_device()
|
||||
if "npu" in str(cur_device):
|
||||
return torch.npu.current_device()
|
||||
return 0
|
||||
|
||||
|
||||
def is_distributed():
|
||||
"""
|
||||
Check if distributed training is initialized.
|
||||
@@ -125,7 +91,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
|
||||
if not is_distributed():
|
||||
return [value_scalar]
|
||||
value_tensor = torch.tensor(
|
||||
value_scalar, device=f"{get_device_type()}:{get_current_device()}"
|
||||
value_scalar, device=torch.cuda.current_device()
|
||||
).float()
|
||||
|
||||
if not is_main_process():
|
||||
@@ -149,14 +115,13 @@ def broadcast_dict(vals: dict):
|
||||
if not is_distributed():
|
||||
return vals
|
||||
|
||||
cur_device = get_device_type()
|
||||
if is_main_process():
|
||||
data_byte = pickle.dumps(vals)
|
||||
data_tensor = torch.ByteTensor(list(data_byte)).to(cur_device)
|
||||
data_size = torch.IntTensor([len(data_byte)]).to(cur_device)
|
||||
data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
|
||||
data_size = torch.IntTensor([len(data_byte)]).to("cuda")
|
||||
else:
|
||||
data_tensor = torch.empty([1024], dtype=torch.uint8, device=cur_device)
|
||||
data_size = torch.IntTensor([0]).to(cur_device)
|
||||
data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
|
||||
data_size = torch.IntTensor([0]).to("cuda")
|
||||
|
||||
dist.broadcast(data_size, 0)
|
||||
if not is_main_process():
|
||||
@@ -185,15 +150,14 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
|
||||
Returns:
|
||||
- The computed value (int or float).
|
||||
"""
|
||||
cur_device = f"{get_device_type()}:{get_current_device()}"
|
||||
if is_main_process():
|
||||
value_scalar = fn()
|
||||
value_tensor = torch.tensor(
|
||||
value_scalar, device=cur_device, dtype=torch.float32
|
||||
value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
value_tensor = torch.tensor(
|
||||
0.0, device=cur_device, dtype=torch.float32
|
||||
0.0, device=torch.cuda.current_device(), dtype=torch.float32
|
||||
) # Placeholder tensor
|
||||
|
||||
# Broadcast the tensor to all processes.
|
||||
@@ -220,7 +184,7 @@ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
|
||||
"""
|
||||
value_scalar = fn()
|
||||
value_tensor = torch.tensor(
|
||||
value_scalar, device=f"{get_device_type()}:{get_current_device()}"
|
||||
value_scalar, device=torch.cuda.current_device()
|
||||
).float()
|
||||
|
||||
# Placeholder tensor for gathering results
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
"""
|
||||
utils to get GPU info for the current environment
|
||||
"""
|
||||
from accelerate.utils.environment import (
|
||||
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
|
||||
)
|
||||
from accelerate.utils.environment import get_gpu_info
|
||||
|
||||
|
||||
def check_cuda_p2p_ib_support():
|
||||
if not accelerate_check_cuda_p2p_ib_support():
|
||||
return False
|
||||
unsupported_devices = {"RTX 6000 Ada"}
|
||||
try:
|
||||
device_names, device_count = get_gpu_info()
|
||||
if 1 < device_count < 8:
|
||||
if any(
|
||||
unsupported_device in device_name
|
||||
for device_name in device_names
|
||||
for unsupported_device in unsupported_devices
|
||||
):
|
||||
return False
|
||||
except Exception: # pylint: disable=broad-except # nosec
|
||||
pass
|
||||
return True
|
||||
@@ -14,16 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
torch_version = version.parse(torch.__version__)
|
||||
|
||||
if torch_version < version.parse("2.4.0"):
|
||||
torch_cuda_amp_custom_fwd = torch.cuda.amp.custom_fwd
|
||||
torch_cuda_amp_custom_bwd = torch.cuda.amp.custom_bwd
|
||||
else:
|
||||
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
||||
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||
|
||||
|
||||
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||
@@ -35,7 +25,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@torch_cuda_amp_custom_fwd
|
||||
@torch.cuda.amp.custom_fwd
|
||||
def forward(ctx, forward_function, hidden_states, *args):
|
||||
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
|
||||
with torch.no_grad():
|
||||
@@ -46,7 +36,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch_cuda_amp_custom_bwd
|
||||
@torch.cuda.amp.custom_bwd
|
||||
def backward(ctx, dY):
|
||||
(hidden_states,) = ctx.saved_tensors
|
||||
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
import gc
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import types
|
||||
from functools import cached_property
|
||||
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
||||
|
||||
import addict
|
||||
@@ -30,7 +28,6 @@ from transformers import ( # noqa: F401
|
||||
AddedToken,
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForImageTextToText,
|
||||
AutoModelForVision2Seq,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
@@ -58,7 +55,7 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
|
||||
from axolotl.utils.distributed import zero_only
|
||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
|
||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||
@@ -92,11 +89,7 @@ def get_module_class_from_name(module, name):
|
||||
|
||||
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
||||
if cfg.is_multimodal:
|
||||
try:
|
||||
model_config = model_config.text_config
|
||||
except AttributeError:
|
||||
# for qwen2_vl
|
||||
model_config = model_config.get_text_config()
|
||||
model_config = model_config.text_config
|
||||
|
||||
quant_config_exists = (
|
||||
hasattr(model_config, "quantization_config")
|
||||
@@ -245,7 +238,6 @@ def load_tokenizer(cfg):
|
||||
x in cfg.lora_modules_to_save for x in lora_modules_to_save
|
||||
)
|
||||
)
|
||||
and k != "pad_token"
|
||||
):
|
||||
lora_modules_to_save = ", ".join(
|
||||
[f"`{x}`" for x in lora_modules_to_save]
|
||||
@@ -372,11 +364,7 @@ class ModelLoader:
|
||||
# init model config
|
||||
self.model_config = load_model_config(cfg)
|
||||
if cfg.is_multimodal:
|
||||
try:
|
||||
self.text_model_config = self.model_config.text_config
|
||||
except AttributeError:
|
||||
# for qwen2_vl
|
||||
self.text_model_config = self.model_config.get_text_config()
|
||||
self.text_model_config = self.model_config.text_config
|
||||
else:
|
||||
self.text_model_config = self.model_config
|
||||
|
||||
@@ -406,21 +394,14 @@ class ModelLoader:
|
||||
and self.cfg.flash_attention
|
||||
and self.cfg.sample_packing
|
||||
):
|
||||
has_remote_code = (
|
||||
"auto_map" in self.model_config
|
||||
and "AutoModelForCausalLM" in self.model_config["auto_map"]
|
||||
)
|
||||
if has_remote_code and self.cfg.trust_remote_code is False:
|
||||
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
|
||||
has_remote_code = self.cfg.trust_remote_code
|
||||
patch_for_multipack(
|
||||
self.cfg.model_config_type,
|
||||
model_name=self.cfg.base_model,
|
||||
has_remote_code=has_remote_code,
|
||||
is_remote_code=self.cfg.trust_remote_code,
|
||||
)
|
||||
|
||||
if self.cfg.is_llama_derived_model:
|
||||
self.patch_loss_llama()
|
||||
self.patch_loss()
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
@@ -462,34 +443,27 @@ class ModelLoader:
|
||||
|
||||
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
|
||||
|
||||
@cached_property
|
||||
def has_flash_attn(self) -> bool:
|
||||
"""Check if flash attention is installed"""
|
||||
return importlib.util.find_spec("flash_attn") is not None
|
||||
|
||||
def patch_loss_llama(self) -> None:
|
||||
def patch_loss(self) -> None:
|
||||
"""
|
||||
Patch loss functions
|
||||
"""
|
||||
if self.has_flash_attn:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
patch_fa_llama_cross_entropy,
|
||||
patch_llama_rms_norm,
|
||||
)
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
patch_llama_cross_entropy,
|
||||
patch_llama_rms_norm,
|
||||
)
|
||||
|
||||
if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
|
||||
patch_fa_llama_cross_entropy()
|
||||
elif self.cfg.unsloth_cross_entropy_loss:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||
|
||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||
|
||||
if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
|
||||
if self.cfg.flash_attn_cross_entropy:
|
||||
patch_llama_cross_entropy()
|
||||
if self.cfg.flash_attn_rms_norm:
|
||||
patch_llama_rms_norm()
|
||||
elif self.cfg.unsloth_rms_norm:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
|
||||
|
||||
patch_unsloth_layernorm()
|
||||
if self.cfg.unsloth_cross_entropy_loss:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||
|
||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
@@ -499,7 +473,6 @@ class ModelLoader:
|
||||
"""
|
||||
Modify all llama derived models in one block
|
||||
"""
|
||||
self.patch_loss_llama()
|
||||
|
||||
if self.cfg.flash_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
@@ -547,6 +520,16 @@ class ModelLoader:
|
||||
"Shifted-sparse attention not currently implemented without flash attention."
|
||||
)
|
||||
|
||||
if self.cfg.unsloth_cross_entropy_loss:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||
|
||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora()
|
||||
|
||||
def set_auto_model_loader(self) -> None:
|
||||
"""set self.AutoModelLoader
|
||||
- default value: AutoModelForCausalLM (set at __init__)
|
||||
@@ -562,10 +545,6 @@ class ModelLoader:
|
||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
||||
MllamaForConditionalGeneration
|
||||
)
|
||||
elif self.model_config.model_type == "qwen2_vl":
|
||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
||||
AutoModelForImageTextToText
|
||||
)
|
||||
else:
|
||||
self.AutoModelLoader = (
|
||||
AutoModelForVision2Seq # pylint: disable=invalid-name
|
||||
@@ -583,8 +562,7 @@ class ModelLoader:
|
||||
)
|
||||
|
||||
max_memory = {}
|
||||
num_device = get_device_count()
|
||||
for i in range(num_device):
|
||||
for i in range(torch.cuda.device_count()):
|
||||
max_memory[i] = gpu_memory_limit
|
||||
max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything
|
||||
|
||||
@@ -609,11 +587,8 @@ class ModelLoader:
|
||||
self.model_kwargs["device_map"] = device_map
|
||||
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
|
||||
|
||||
cur_device = get_device_type()
|
||||
if "mps" in str(cur_device):
|
||||
if torch.backends.mps.is_available():
|
||||
self.model_kwargs["device_map"] = "mps:0"
|
||||
elif "npu" in str(cur_device):
|
||||
self.model_kwargs["device_map"] = "npu:0"
|
||||
|
||||
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
|
||||
# if cfg.rl:
|
||||
@@ -1058,9 +1033,7 @@ class ModelLoader:
|
||||
and self.model.get_input_embeddings().num_embeddings < embeddings_len
|
||||
):
|
||||
resize_kwargs = {}
|
||||
if self.cfg.mean_resizing_embeddings is not None and not (
|
||||
self.model_config.model_type == "llava"
|
||||
):
|
||||
if self.cfg.mean_resizing_embeddings is not None:
|
||||
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
|
||||
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
|
||||
else:
|
||||
@@ -1069,11 +1042,7 @@ class ModelLoader:
|
||||
self.ajust_model_config()
|
||||
|
||||
# log device memory usage
|
||||
if hasattr(self.model, "device") and self.model.device.type in (
|
||||
"cuda",
|
||||
"mps",
|
||||
"npu",
|
||||
):
|
||||
if hasattr(self.model, "device") and self.model.device.type in ("cuda", "mps"):
|
||||
log_gpu_memory_usage(LOG, "after model load", self.model.device)
|
||||
|
||||
# make sure these are fp32 per Ramesh et al. (2021)
|
||||
@@ -1099,17 +1068,14 @@ class ModelLoader:
|
||||
|
||||
self.prepare_model(qlora_fsdp)
|
||||
|
||||
should_convert = (
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
|
||||
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
|
||||
)
|
||||
|
||||
if should_convert:
|
||||
LOG.info("Converting modules to %s", self.cfg.torch_dtype)
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
if (needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp:
|
||||
LOG.info(
|
||||
"converting modules to %s for flash attention", self.cfg.torch_dtype
|
||||
)
|
||||
self.convert_embedding_modules_dtype(
|
||||
embedding_modules=embedding_modules,
|
||||
embedding_modules,
|
||||
dist_dtype=self.cfg.torch_dtype,
|
||||
before_kbit_train_or_finetune=False,
|
||||
)
|
||||
@@ -1144,9 +1110,9 @@ class ModelLoader:
|
||||
and not skip_move_to_device
|
||||
):
|
||||
# TODO revaldate this conditional
|
||||
self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
|
||||
self.model.to(f"cuda:{self.cfg.local_rank}")
|
||||
|
||||
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
||||
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
||||
setattr(self.model, "is_parallelizable", True)
|
||||
setattr(self.model, "model_parallel", True)
|
||||
|
||||
|
||||
@@ -1,539 +0,0 @@
|
||||
"""
|
||||
Copied from https://github.com/iShohei220/adopt
|
||||
|
||||
ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024)
|
||||
Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka
|
||||
"""
|
||||
# mypy: ignore-errors
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Callable, List, Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.optim.optimizer import ( # DeviceDict,; _capturable_doc,; _differentiable_doc,; _foreach_doc,; _fused_doc,; _maximize_doc,; _stack_if_compiling,
|
||||
DeviceDict,
|
||||
Optimizer,
|
||||
ParamsT,
|
||||
_capturable_doc,
|
||||
_default_to_fused_or_foreach,
|
||||
_device_dtype_check_for_fused,
|
||||
_differentiable_doc,
|
||||
_disable_dynamo_if_unsupported,
|
||||
_foreach_doc,
|
||||
_fused_doc,
|
||||
_get_capturable_supported_devices,
|
||||
_get_scalar_dtype,
|
||||
_get_value,
|
||||
_maximize_doc,
|
||||
_stack_if_compiling,
|
||||
_use_grad_for_differentiable,
|
||||
_view_as_real,
|
||||
)
|
||||
|
||||
__all__ = ["ADOPT", "adopt"]
|
||||
|
||||
|
||||
class ADOPT(Optimizer):
|
||||
def __init__(
|
||||
self,
|
||||
params: ParamsT,
|
||||
lr: Union[float, Tensor] = 1e-3,
|
||||
betas: Tuple[float, float] = (0.9, 0.9999),
|
||||
eps: float = 1e-6,
|
||||
clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
|
||||
weight_decay: float = 0.0,
|
||||
decouple: bool = False,
|
||||
*,
|
||||
foreach: Optional[bool] = None,
|
||||
maximize: bool = False,
|
||||
capturable: bool = False,
|
||||
differentiable: bool = False,
|
||||
fused: Optional[bool] = None,
|
||||
):
|
||||
if isinstance(lr, Tensor):
|
||||
if foreach and not capturable:
|
||||
raise ValueError(
|
||||
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||
)
|
||||
if lr.numel() != 1:
|
||||
raise ValueError("Tensor lr must be 1-element")
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError(f"Invalid learning rate: {lr}")
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError(f"Invalid epsilon value: {eps}")
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
|
||||
if not 0.0 <= weight_decay:
|
||||
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||
|
||||
self.clip_lambda = clip_lambda
|
||||
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
decouple=decouple,
|
||||
maximize=maximize,
|
||||
foreach=foreach,
|
||||
capturable=capturable,
|
||||
differentiable=differentiable,
|
||||
fused=fused,
|
||||
)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
if fused:
|
||||
# TODO: support fused
|
||||
raise RuntimeError("`fused` is not currently supported")
|
||||
|
||||
if differentiable:
|
||||
raise RuntimeError("`fused` does not support `differentiable`")
|
||||
self._step_supports_amp_scaling = True
|
||||
# TODO(crcrpar): [low prec params & their higher prec copy]
|
||||
# Support AMP with FP16/BF16 model params which would need
|
||||
# higher prec copy of params to do update math in higher prec to
|
||||
# alleviate the loss of information.
|
||||
if foreach:
|
||||
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
|
||||
|
||||
def __setstate__(self, state):
|
||||
super().__setstate__(state)
|
||||
for group in self.param_groups:
|
||||
group.setdefault("maximize", False)
|
||||
group.setdefault("foreach", None)
|
||||
group.setdefault("capturable", False)
|
||||
group.setdefault("differentiable", False)
|
||||
fused = group.setdefault("fused", None)
|
||||
for p in group["params"]:
|
||||
p_state = self.state.get(p, [])
|
||||
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
|
||||
step_val = float(p_state["step"])
|
||||
p_state["step"] = (
|
||||
torch.tensor(
|
||||
step_val,
|
||||
dtype=_get_scalar_dtype(is_fused=fused),
|
||||
device=p.device,
|
||||
)
|
||||
if group["capturable"] or group["fused"]
|
||||
else torch.tensor(step_val, dtype=_get_scalar_dtype())
|
||||
)
|
||||
|
||||
def _init_group(
|
||||
self,
|
||||
group,
|
||||
params_with_grad,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
state_steps,
|
||||
):
|
||||
has_complex = False
|
||||
for p in group["params"]:
|
||||
if p.grad is not None:
|
||||
has_complex |= torch.is_complex(p)
|
||||
params_with_grad.append(p)
|
||||
if p.grad.is_sparse:
|
||||
raise RuntimeError("ADOPT does not support sparse gradients")
|
||||
grads.append(p.grad)
|
||||
|
||||
state = self.state[p]
|
||||
# Lazy state initialization
|
||||
if len(state) == 0:
|
||||
if group["fused"]:
|
||||
_device_dtype_check_for_fused(p)
|
||||
# note(crcrpar): [special device hosting for step]
|
||||
# Deliberately host `step` on CPU if both capturable and fused are off.
|
||||
# This is because kernel launches are costly on CUDA and XLA.
|
||||
state["step"] = (
|
||||
torch.zeros(
|
||||
(),
|
||||
dtype=_get_scalar_dtype(is_fused=group["fused"]),
|
||||
device=p.device,
|
||||
)
|
||||
if group["capturable"] or group["fused"]
|
||||
else torch.tensor(0.0, dtype=_get_scalar_dtype())
|
||||
)
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
# Exponential moving average of squared gradient values
|
||||
state["exp_avg_sq"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
|
||||
exp_avgs.append(state["exp_avg"])
|
||||
exp_avg_sqs.append(state["exp_avg_sq"])
|
||||
|
||||
if group["differentiable"] and state["step"].requires_grad:
|
||||
raise RuntimeError(
|
||||
"`requires_grad` is not supported for `step` in differentiable mode"
|
||||
)
|
||||
|
||||
# Foreach without capturable does not support a tensor lr
|
||||
if (
|
||||
group["foreach"]
|
||||
and torch.is_tensor(group["lr"])
|
||||
and not group["capturable"]
|
||||
):
|
||||
raise RuntimeError(
|
||||
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||
)
|
||||
|
||||
state_steps.append(state["step"])
|
||||
return has_complex
|
||||
|
||||
@_use_grad_for_differentiable
|
||||
def step(self, closure=None):
|
||||
"""Perform a single optimization step.
|
||||
|
||||
Args:
|
||||
closure (Callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
self._cuda_graph_capture_health_check()
|
||||
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad: List[Tensor] = []
|
||||
grads: List[Tensor] = []
|
||||
exp_avgs: List[Tensor] = []
|
||||
exp_avg_sqs: List[Tensor] = []
|
||||
state_steps: List[Tensor] = []
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
has_complex = self._init_group(
|
||||
group,
|
||||
params_with_grad,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
state_steps,
|
||||
)
|
||||
|
||||
adopt(
|
||||
params_with_grad,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
state_steps,
|
||||
has_complex=has_complex,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
lr=group["lr"],
|
||||
clip_lambda=self.clip_lambda,
|
||||
weight_decay=group["weight_decay"],
|
||||
decouple=group["decouple"],
|
||||
eps=group["eps"],
|
||||
maximize=group["maximize"],
|
||||
foreach=group["foreach"],
|
||||
capturable=group["capturable"],
|
||||
differentiable=group["differentiable"],
|
||||
fused=group["fused"],
|
||||
grad_scale=getattr(self, "grad_scale", None),
|
||||
found_inf=getattr(self, "found_inf", None),
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def _single_tensor_adopt(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
exp_avg_sqs: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
has_complex: bool,
|
||||
beta1: float,
|
||||
beta2: float,
|
||||
lr: Union[float, Tensor],
|
||||
clip_lambda: Optional[Callable[[int], float]],
|
||||
weight_decay: float,
|
||||
decouple: bool,
|
||||
eps: float,
|
||||
maximize: bool,
|
||||
capturable: bool,
|
||||
differentiable: bool,
|
||||
):
|
||||
assert grad_scale is None and found_inf is None
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
# this assert is due to JIT being dumb and not realizing that the ops below
|
||||
# have overloads to handle both float and Tensor lrs, so we just assert it's
|
||||
# a float since most people using JIT are using floats
|
||||
assert isinstance(lr, float)
|
||||
|
||||
for i, param in enumerate(params):
|
||||
grad = grads[i] if not maximize else -grads[i]
|
||||
exp_avg = exp_avgs[i]
|
||||
exp_avg_sq = exp_avg_sqs[i]
|
||||
step_t = state_steps[i]
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices()
|
||||
assert (
|
||||
param.device.type == step_t.device.type
|
||||
and param.device.type in capturable_supported_devices
|
||||
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||
|
||||
step = step_t if capturable or differentiable else _get_value(step_t)
|
||||
|
||||
if weight_decay != 0 and not decouple:
|
||||
grad = grad.add(param, alpha=weight_decay)
|
||||
|
||||
if torch.is_complex(param):
|
||||
grad = torch.view_as_real(grad)
|
||||
if exp_avg is not None:
|
||||
exp_avg = torch.view_as_real(exp_avg)
|
||||
if exp_avg_sq is not None:
|
||||
exp_avg_sq = torch.view_as_real(exp_avg_sq)
|
||||
param = torch.view_as_real(param)
|
||||
|
||||
if step == 0:
|
||||
exp_avg_sq.addcmul_(grad, grad.conj())
|
||||
# update step
|
||||
step_t += 1
|
||||
continue
|
||||
|
||||
if weight_decay != 0 and decouple:
|
||||
param.add_(param, alpha=-lr * weight_decay)
|
||||
|
||||
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
|
||||
normed_grad = grad.div(denom)
|
||||
if clip_lambda is not None:
|
||||
clip = clip_lambda(step)
|
||||
normed_grad.clamp_(-clip, clip)
|
||||
|
||||
exp_avg.lerp_(normed_grad, 1 - beta1)
|
||||
|
||||
param.add_(exp_avg, alpha=-lr)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
|
||||
|
||||
# update step
|
||||
step_t += 1
|
||||
|
||||
|
||||
def _multi_tensor_adopt(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
exp_avg_sqs: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
has_complex: bool,
|
||||
beta1: float,
|
||||
beta2: float,
|
||||
lr: Union[float, Tensor],
|
||||
clip_lambda: Optional[Callable[[int], float]],
|
||||
weight_decay: float,
|
||||
decouple: bool,
|
||||
eps: float,
|
||||
maximize: bool,
|
||||
capturable: bool,
|
||||
differentiable: bool,
|
||||
):
|
||||
if len(params) == 0:
|
||||
return
|
||||
|
||||
if isinstance(lr, Tensor) and not capturable:
|
||||
raise RuntimeError(
|
||||
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||
)
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices(
|
||||
supports_xla=False
|
||||
)
|
||||
assert all(
|
||||
p.device.type == step.device.type
|
||||
and p.device.type in capturable_supported_devices
|
||||
for p, step in zip(params, state_steps)
|
||||
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||
|
||||
assert grad_scale is None and found_inf is None
|
||||
|
||||
assert not differentiable, "_foreach ops don't support autograd"
|
||||
|
||||
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
|
||||
[params, grads, exp_avgs, exp_avg_sqs, state_steps] # type: ignore[list-item]
|
||||
)
|
||||
for (
|
||||
device_params_,
|
||||
device_grads_,
|
||||
device_exp_avgs_,
|
||||
device_exp_avg_sqs_,
|
||||
device_state_steps_,
|
||||
), _ in grouped_tensors.values():
|
||||
device_params = cast(List[Tensor], device_params_)
|
||||
device_grads = cast(List[Tensor], device_grads_)
|
||||
device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
|
||||
device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
|
||||
device_state_steps = cast(List[Tensor], device_state_steps_)
|
||||
|
||||
# Handle complex parameters
|
||||
if has_complex:
|
||||
_view_as_real(
|
||||
device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
|
||||
)
|
||||
|
||||
if maximize:
|
||||
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
|
||||
|
||||
if weight_decay != 0 and not decouple:
|
||||
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
||||
if maximize:
|
||||
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
||||
else:
|
||||
device_grads = torch._foreach_add( # type: ignore[assignment]
|
||||
device_grads, device_params, alpha=weight_decay
|
||||
)
|
||||
|
||||
if device_state_steps[0] == 0:
|
||||
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
|
||||
|
||||
# Update steps
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
else:
|
||||
torch._foreach_add_(device_state_steps, 1)
|
||||
|
||||
continue
|
||||
|
||||
if weight_decay != 0 and decouple:
|
||||
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
|
||||
|
||||
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
||||
torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
|
||||
|
||||
normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
|
||||
if clip_lambda is not None:
|
||||
clip = clip_lambda(device_state_steps[0])
|
||||
torch._foreach_maximum_(normed_grad, -clip)
|
||||
torch._foreach_minimum_(normed_grad, clip)
|
||||
|
||||
torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
|
||||
|
||||
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
|
||||
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
||||
torch._foreach_addcmul_(
|
||||
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
|
||||
)
|
||||
|
||||
# Update steps
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
else:
|
||||
torch._foreach_add_(device_state_steps, 1)
|
||||
|
||||
|
||||
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
|
||||
def adopt(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
exp_avg_sqs: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
|
||||
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
|
||||
foreach: Optional[bool] = None,
|
||||
capturable: bool = False,
|
||||
differentiable: bool = False,
|
||||
fused: Optional[bool] = None,
|
||||
grad_scale: Optional[Tensor] = None,
|
||||
found_inf: Optional[Tensor] = None,
|
||||
has_complex: bool = False,
|
||||
*,
|
||||
beta1: float,
|
||||
beta2: float,
|
||||
lr: Union[float, Tensor],
|
||||
clip_lambda: Optional[Callable[[int], float]],
|
||||
weight_decay: float,
|
||||
decouple: bool,
|
||||
eps: float,
|
||||
maximize: bool,
|
||||
):
|
||||
r"""Functional API that performs ADOPT algorithm computation."""
|
||||
# Respect when the user inputs False/True for foreach or fused. We only want to change
|
||||
# the default when neither have been user-specified. Note that we default to foreach
|
||||
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
|
||||
# bake-in time before making it the default, even if it is typically faster.
|
||||
if fused is None and foreach is None:
|
||||
_, foreach = _default_to_fused_or_foreach(
|
||||
params, differentiable, use_fused=False
|
||||
)
|
||||
# Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
|
||||
if foreach and isinstance(lr, Tensor) and not capturable:
|
||||
foreach = False
|
||||
if fused is None:
|
||||
fused = False
|
||||
if foreach is None:
|
||||
foreach = False
|
||||
|
||||
# this check is slow during compilation, so we skip it
|
||||
# if it's strictly needed we can add this check back in dynamo
|
||||
if not torch._utils.is_compiling() and not all(
|
||||
isinstance(t, torch.Tensor) for t in state_steps
|
||||
):
|
||||
raise RuntimeError(
|
||||
"API has changed, `state_steps` argument must contain a list of singleton tensors"
|
||||
)
|
||||
|
||||
if foreach and torch.jit.is_scripting():
|
||||
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
|
||||
if fused and torch.jit.is_scripting():
|
||||
raise RuntimeError("torch.jit.script not supported with fused optimizers")
|
||||
|
||||
# if fused and not torch.jit.is_scripting():
|
||||
# func = _fused_adopt
|
||||
# elif foreach and not torch.jit.is_scripting():
|
||||
if foreach and not torch.jit.is_scripting():
|
||||
func = _multi_tensor_adopt
|
||||
else:
|
||||
func = _single_tensor_adopt
|
||||
|
||||
func(
|
||||
params,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
state_steps,
|
||||
has_complex=has_complex,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
lr=lr,
|
||||
clip_lambda=clip_lambda,
|
||||
weight_decay=weight_decay,
|
||||
decouple=decouple,
|
||||
eps=eps,
|
||||
maximize=maximize,
|
||||
capturable=capturable,
|
||||
differentiable=differentiable,
|
||||
grad_scale=grad_scale,
|
||||
found_inf=found_inf,
|
||||
)
|
||||
250
src/axolotl/utils/optimizers/shampoo.py
Normal file
@@ -0,0 +1,250 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.distributed._tensor import DTensor
|
||||
from torch.optim import Optimizer
|
||||
from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
|
||||
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
|
||||
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8
|
||||
|
||||
|
||||
class _ShampooBase(Optimizer):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-1,
|
||||
momentum=0.0,
|
||||
weight_decay=0.0,
|
||||
eps=1e-4,
|
||||
update_freq=1,
|
||||
*,
|
||||
block_size,
|
||||
quantization_bits,
|
||||
optimizer_state_class,
|
||||
):
|
||||
if lr <= 0.0:
|
||||
raise ValueError(f"Invalid learning rate: {lr}")
|
||||
if momentum < 0.0:
|
||||
raise ValueError(f"Invalid momentum value: {momentum}")
|
||||
if weight_decay < 0.0:
|
||||
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||
if eps < 0.0:
|
||||
raise ValueError(f"Invalid eps value: {eps}")
|
||||
if update_freq < 1:
|
||||
raise ValueError(f"Invalid update_freq value: {update_freq}")
|
||||
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
momentum=momentum,
|
||||
weight_decay=weight_decay,
|
||||
eps=eps,
|
||||
update_freq=update_freq,
|
||||
)
|
||||
super().__init__(params, defaults)
|
||||
self.block_size = block_size
|
||||
self.quantization_bits = quantization_bits
|
||||
self.optimizer_state_class = optimizer_state_class
|
||||
|
||||
def step(self, closure: Optional[callable] = None) -> Optional[float]:
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.data
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state["step"] = 0
|
||||
state["momentum_buffer"] = self._new_buffer(grad, True)
|
||||
state["preconds"] = []
|
||||
state["inv_preconds"] = []
|
||||
for dim in grad.size():
|
||||
state["preconds"].append(
|
||||
self.optimizer_state_class.zeros(
|
||||
(dim, dim),
|
||||
signed=False,
|
||||
block_size=self.block_size,
|
||||
device=grad.device,
|
||||
)
|
||||
)
|
||||
state["inv_preconds"].append(
|
||||
torch.zeros((dim, dim), device=grad.device)
|
||||
)
|
||||
|
||||
state["step"] += 1
|
||||
beta = group["momentum"]
|
||||
weight_decay = group["weight_decay"]
|
||||
lr = group["lr"]
|
||||
eps = group["eps"]
|
||||
update_freq = group["update_freq"]
|
||||
|
||||
# Apply momentum
|
||||
if beta > 0:
|
||||
state["momentum_buffer"].mul_(beta).add_(grad, alpha=1 - beta)
|
||||
grad = state["momentum_buffer"]
|
||||
|
||||
# Apply weight decay
|
||||
if weight_decay > 0:
|
||||
grad = grad.add(p.data, alpha=weight_decay)
|
||||
|
||||
# Preconditioning
|
||||
order = grad.ndimension()
|
||||
original_size = grad.size()
|
||||
for dim_id, dim in enumerate(grad.size()):
|
||||
precond = state["preconds"][dim_id]
|
||||
inv_precond = state["inv_preconds"][dim_id]
|
||||
|
||||
# Reshape grad
|
||||
grad = grad.transpose(0, dim_id).contiguous()
|
||||
transposed_size = grad.size()
|
||||
grad = grad.view(dim, -1)
|
||||
|
||||
grad_t = grad.t()
|
||||
|
||||
# Update preconditioner
|
||||
precond_fp32 = precond.dequantize()
|
||||
precond_update = grad @ grad_t
|
||||
precond_fp32.add_(precond_update)
|
||||
|
||||
# Quantize preconditioner back
|
||||
precond.copy_(precond_fp32)
|
||||
|
||||
# Update inverse preconditioner
|
||||
if state["step"] % update_freq == 0:
|
||||
inv_precond.copy_(
|
||||
self._compute_inv_precond(precond_fp32, eps, order)
|
||||
)
|
||||
|
||||
# Precondition grad
|
||||
if dim_id == order - 1:
|
||||
# Last dimension
|
||||
grad = grad_t @ inv_precond
|
||||
grad = grad.view(original_size)
|
||||
else:
|
||||
grad = inv_precond @ grad
|
||||
grad = grad.view(transposed_size)
|
||||
|
||||
# Update parameter
|
||||
p.data.add_(grad, alpha=-lr)
|
||||
|
||||
return loss
|
||||
|
||||
def _compute_inv_precond(self, precond: Tensor, eps: float, order: int):
|
||||
# Add eps for numerical stability
|
||||
precond = precond + torch.eye(precond.size(0), device=precond.device) * eps
|
||||
|
||||
# Compute matrix power
|
||||
inv_precond = self._matrix_power(precond, -1.0 / (2 * order))
|
||||
|
||||
return inv_precond
|
||||
|
||||
def _matrix_power(self, matrix: Tensor, power: float) -> Tensor:
|
||||
# Compute matrix power using SVD
|
||||
u, s, v = torch.svd(matrix)
|
||||
s_pow = s.pow(power)
|
||||
return u @ torch.diag(s_pow) @ v.t()
|
||||
|
||||
# bring your own function to create zero-filled subclass
|
||||
@staticmethod
|
||||
def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
|
||||
raise NotImplementedError
|
||||
|
||||
# follow bitsandbytes, only quantize tensors >= 4096 values
|
||||
# also wrap subclass in DTensor when needed
|
||||
def _new_buffer(self, p: Tensor, signed: bool):
|
||||
if p.numel() >= 4096 and p.numel() % self.block_size == 0:
|
||||
if isinstance(p, DTensor):
|
||||
out = DTensor.from_local(
|
||||
local_tensor=self._subclass_zeros(
|
||||
p.to_local(), signed, self.block_size
|
||||
),
|
||||
device_mesh=p.device_mesh,
|
||||
placements=p.placements,
|
||||
run_check=False,
|
||||
)
|
||||
else:
|
||||
out = self._subclass_zeros(p, signed, self.block_size)
|
||||
else:
|
||||
out = torch.zeros_like(p)
|
||||
return out
|
||||
|
||||
|
||||
class Shampoo8bit(_ShampooBase):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-1,
|
||||
momentum=0.0,
|
||||
weight_decay=0.0,
|
||||
eps=1e-4,
|
||||
update_freq=1,
|
||||
*,
|
||||
block_size=256,
|
||||
):
|
||||
super().__init__(
|
||||
params,
|
||||
lr,
|
||||
momentum,
|
||||
weight_decay,
|
||||
eps,
|
||||
update_freq,
|
||||
block_size=block_size,
|
||||
quantization_bits=8,
|
||||
optimizer_state_class=OptimState8bit,
|
||||
)
|
||||
|
||||
|
||||
class Shampoo4bit(_ShampooBase):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-1,
|
||||
momentum=0.0,
|
||||
weight_decay=0.0,
|
||||
eps=1e-4,
|
||||
update_freq=1,
|
||||
*,
|
||||
block_size=128,
|
||||
):
|
||||
super().__init__(
|
||||
params,
|
||||
lr,
|
||||
momentum,
|
||||
weight_decay,
|
||||
eps,
|
||||
update_freq,
|
||||
block_size=block_size,
|
||||
quantization_bits=4,
|
||||
optimizer_state_class=OptimState4bit,
|
||||
)
|
||||
|
||||
|
||||
class ShampooFp8(_ShampooBase):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-1,
|
||||
momentum=0.0,
|
||||
weight_decay=0.0,
|
||||
eps=1e-4,
|
||||
update_freq=1,
|
||||
*,
|
||||
block_size=256,
|
||||
):
|
||||
super().__init__(
|
||||
params,
|
||||
lr,
|
||||
momentum,
|
||||
weight_decay,
|
||||
eps,
|
||||
update_freq,
|
||||
block_size=block_size,
|
||||
quantization_bits=8, # FP8 uses 8 bits
|
||||
optimizer_state_class=OptimStateFp8,
|
||||
)
|
||||